diff --git a/src-tauri/src/clientinfo.rs b/src-tauri/src/clientinfo.rs index 2285dc9..d05dd97 100644 --- a/src-tauri/src/clientinfo.rs +++ b/src-tauri/src/clientinfo.rs @@ -1,8 +1,15 @@ use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo}; use sysinfo::{System, SystemExt, Pid, ProcessExt}; +use serde::{Serialize, Deserialize}; use crate::errors::*; -use crate::ipc::Client; + + +#[derive(Clone, Serialize, Deserialize)] +pub struct Client { + pub pid: u32, + pub exe: String, +} fn get_associated_pids(local_port: u16) -> Result, netstat2::error::Error> { diff --git a/src-tauri/src/ipc.rs b/src-tauri/src/ipc.rs index 5052b4a..14c6aa1 100644 --- a/src-tauri/src/ipc.rs +++ b/src-tauri/src/ipc.rs @@ -1,16 +1,10 @@ use serde::{Serialize, Deserialize}; use tauri::State; +use crate::clientinfo::Client; use crate::state::{AppState, Session, Credentials}; -#[derive(Clone, Serialize, Deserialize)] -pub struct Client { - pub pid: u32, - pub exe: String, -} - - #[derive(Clone, Serialize, Deserialize)] pub struct Request { pub id: u64, diff --git a/src-tauri/src/server.rs b/src-tauri/src/server.rs index fccdd29..3669f37 100644 --- a/src-tauri/src/server.rs +++ b/src-tauri/src/server.rs @@ -54,7 +54,12 @@ async fn handle(mut stream: TcpStream, app_handle: AppHandle) -> Result<(), Requ let clients = clientinfo::get_clients(peer_addr.port())?; let req = Request {id: request_id, clients}; - app_handle.emit_all("credentials-request", req)?; + if req.clients.iter().any(|c| app_state.is_banned(c.pid)) { + stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?; + return Ok(()) + } + + app_handle.emit_all("credentials-request", &req)?; let window = app_handle.get_window("main").ok_or(RequestError::NoMainWindow)?; window.show()?; window.set_focus()?; @@ -85,6 +90,9 @@ async fn handle(mut stream: TcpStream, app_handle: AppHandle) -> Result<(), Requ // might need more time than that gives us (especially if entering the passphrase). // Fortunately most AWS libs automatically retry if the request dies uncompleted, allowing // us to respond with a proper error status. + for client in req.clients { + app_state.add_ban(client.pid, app_handle.clone()); + } return Ok(()); } diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index 14dc344..7e1a42a 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -1,8 +1,10 @@ -use std::collections::HashMap; +use core::time::Duration; +use std::collections::{HashMap, HashSet}; use std::sync::RwLock; use serde::{Serialize, Deserialize}; use tokio::sync::oneshot::Sender; +use tokio::time::sleep; use sqlx::{SqlitePool, sqlite::SqlitePoolOptions, sqlite::SqliteConnectOptions}; use sodiumoxide::crypto::{ pwhash, @@ -11,6 +13,7 @@ use sodiumoxide::crypto::{ secretbox::{Nonce, Key} }; use tauri::async_runtime as runtime; +use tauri::Manager; use crate::ipc; use crate::errors::*; @@ -53,6 +56,7 @@ pub struct AppState { pub session: RwLock, pub request_count: RwLock, pub open_requests: RwLock>>, + pub bans: RwLock>, pool: SqlitePool, } @@ -71,6 +75,7 @@ impl AppState { session: RwLock::new(creds), request_count: RwLock::new(0), open_requests: RwLock::new(HashMap::new()), + bans: RwLock::new(HashSet::new()), pool, }; @@ -158,6 +163,22 @@ impl AppState { .map_err(|_e| SendResponseError::Abandoned) } + pub fn add_ban(&self, pid: u32, app: tauri::AppHandle) { + let mut bans = self.bans.write().unwrap(); + bans.insert(pid); + + runtime::spawn(async move { + sleep(Duration::from_secs(5)).await; + let state = app.state::(); + let mut bans = state.bans.write().unwrap(); + bans.remove(&pid); + }); + } + + pub fn is_banned(&self, pid: u32) -> bool { + self.bans.read().unwrap().contains(&pid) + } + pub async fn decrypt(&self, passphrase: &str) -> Result<(), UnlockError> { let (key_id, secret) = { // do this all in a block so rustc doesn't complain about holding a lock across an await