From 414379b74eb2969f2b31666684ba056138fe66fc Mon Sep 17 00:00:00 2001 From: Joseph Montanaro Date: Tue, 20 Dec 2022 16:11:49 -0800 Subject: [PATCH] completely reorganize http server --- src-tauri/src/clientinfo.rs | 18 +-- src-tauri/src/ipc.rs | 2 +- src-tauri/src/server.rs | 214 ++++++++++++++++++++++-------------- src-tauri/src/state.rs | 18 ++- 4 files changed, 153 insertions(+), 99 deletions(-) diff --git a/src-tauri/src/clientinfo.rs b/src-tauri/src/clientinfo.rs index 7e34eb9..c0ae19f 100644 --- a/src-tauri/src/clientinfo.rs +++ b/src-tauri/src/clientinfo.rs @@ -5,7 +5,7 @@ use serde::{Serialize, Deserialize}; use crate::errors::*; -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] pub struct Client { pub pid: u32, pub exe: String, @@ -13,12 +13,12 @@ pub struct Client { fn get_associated_pids(local_port: u16) -> Result, netstat2::error::Error> { - let mut it = netstat2::iterate_sockets_info( + let sockets_iter = netstat2::iterate_sockets_info( AddressFamilyFlags::IPV4, ProtocolFlags::TCP )?; - for (i, item) in it.enumerate() { + for item in sockets_iter { let sock_info = item?; let proto_info = match sock_info.protocol_socket_info { ProtocolSocketInfo::Tcp(tcp_info) => tcp_info, @@ -37,9 +37,8 @@ fn get_associated_pids(local_port: u16) -> Result, netstat2::error::Err } -// Theoretically, on some systems, multiple processes can share a socket. We have to -// account for this even though 99% of the time there will be only one. -pub fn get_clients(local_port: u16) -> Result, ClientInfoError> { +// Theoretically, on some systems, multiple processes can share a socket +pub fn get_clients(local_port: u16) -> Result>, ClientInfoError> { let mut clients = Vec::new(); let mut sys = System::new(); for p in get_associated_pids(local_port)? { @@ -52,7 +51,12 @@ pub fn get_clients(local_port: u16) -> Result, ClientInfoError> { pid: p, exe: proc.exe().to_string_lossy().into_owned(), }; - clients.push(client); + clients.push(Some(client)); } + + if clients.is_empty() { + clients.push(None); + } + Ok(clients) } diff --git a/src-tauri/src/ipc.rs b/src-tauri/src/ipc.rs index 14c6aa1..cfa284f 100644 --- a/src-tauri/src/ipc.rs +++ b/src-tauri/src/ipc.rs @@ -8,7 +8,7 @@ use crate::state::{AppState, Session, Credentials}; #[derive(Clone, Serialize, Deserialize)] pub struct Request { pub id: u64, - pub clients: Vec, + pub clients: Vec>, } diff --git a/src-tauri/src/server.rs b/src-tauri/src/server.rs index 57e7bf4..1cb45d7 100644 --- a/src-tauri/src/server.rs +++ b/src-tauri/src/server.rs @@ -1,28 +1,148 @@ use std::io; -use std::net::SocketAddrV4; +use std::net::{SocketAddr, SocketAddrV4}; use tokio::net::{TcpListener, TcpStream}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::oneshot; use tauri::{AppHandle, Manager}; -use crate::clientinfo; -use crate::errors::RequestError; +use crate::{clientinfo, clientinfo::Client}; +use crate::errors::*; use crate::ipc::{Request, Approval}; +use crate::state::AppState; + + +struct Handler { + request_id: u64, + stream: TcpStream, + receiver: Option>, + app: AppHandle, +} + +impl Handler { + fn new(stream: TcpStream, app: AppHandle) -> Self { + let state = app.state::(); + let (chan_send, chan_recv) = oneshot::channel(); + let request_id = state.register_request(chan_send); + Handler { + request_id, + stream, + receiver: Some(chan_recv), + app + } + } + + async fn handle(mut self) { + if let Err(e) = self.try_handle().await { + eprintln!("{e}"); + } + let state = self.app.state::(); + state.unregister_request(self.request_id); + } + + async fn try_handle(&mut self) -> Result<(), RequestError> { + let _ = self.recv_request().await?; + let clients = self.get_clients()?; + if self.includes_banned(&clients) { + self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?; + return Ok(()) + } + + let req = Request {id: self.request_id, clients}; + self.notify_frontend(&req).await?; + + match self.wait_for_response().await? { + Approval::Approved => self.send_credentials().await?, + Approval::Denied => { + let state = self.app.state::(); + for client in req.clients { + state.add_ban(client, self.app.clone()); + } + } + } + + Ok(()) + } + + async fn recv_request(&mut self) -> Result, RequestError> { + let mut buf = vec![0; 8192]; // it's what tokio's BufReader uses + let mut n = 0; + loop { + n += self.stream.read(&mut buf[n..]).await?; + if n >= 4 && &buf[(n - 4)..n] == b"\r\n\r\n" {break;} + if n == buf.len() {return Err(RequestError::RequestTooLarge);} + } + println!("{}", std::str::from_utf8(&buf).unwrap()); + Ok(buf) + } + + fn get_clients(&self) -> Result>, RequestError> { + let peer_addr = match self.stream.peer_addr()? { + SocketAddr::V4(addr) => addr, + _ => unreachable!(), // we only listen on IPv4 + }; + let clients = clientinfo::get_clients(peer_addr.port())?; + Ok(clients) + } + + fn includes_banned(&self, clients: &Vec>) -> bool { + let state = self.app.state::(); + clients.iter().any(|c| state.is_banned(c)) + } + + async fn notify_frontend(&self, req: &Request) -> Result<(), RequestError> { + self.app.emit_all("credentials-request", req)?; + let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?; + window.unminimize()?; + window.set_focus()?; + Ok(()) + } + + async fn wait_for_response(&mut self) -> Result { + self.stream.write(b"HTTP/1.0 200 OK\r\n").await?; + self.stream.write(b"Content-Type: application/json\r\n").await?; + self.stream.write(b"X-Creddy-delaying-tactic: ").await?; + + #[allow(unreachable_code)] // seems necessary for type inference + let stall = async { + let delay = std::time::Duration::from_secs(1); + loop { + tokio::time::sleep(delay).await; + self.stream.write(b"x").await?; + } + Ok(Approval::Denied) + }; + + // this is the only place we even read this field, so it's safe to unwrap + let receiver = self.receiver.take().unwrap(); + tokio::select!{ + r = receiver => Ok(r.unwrap()), // only panics if the sender is dropped without sending, which shouldn't be possible + e = stall => e, + } + } + + async fn send_credentials(&mut self) -> Result<(), RequestError> { + let state = self.app.state::(); + let creds = state.get_creds_serialized()?; + + self.stream.write(b"\r\nContent-Length: ").await?; + self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?; + self.stream.write(b"\r\n\r\n").await?; + self.stream.write(creds.as_bytes()).await?; + self.stream.write(b"\r\n\r\n").await?; + Ok(()) + } +} pub async fn serve(addr: SocketAddrV4, app_handle: AppHandle) -> io::Result<()> { let listener = TcpListener::bind(&addr).await?; println!("Listening on {addr}"); loop { - let new_handle = app_handle.app_handle(); match listener.accept().await { Ok((stream, _)) => { - tokio::spawn(async { - if let Err(e) = handle(stream, new_handle).await { - eprintln!("{e}"); - } - }); + let handler = Handler::new(stream, app_handle.app_handle()); + tauri::async_runtime::spawn(handler.handle()); }, Err(e) => { eprintln!("Error accepting connection: {e}"); @@ -30,79 +150,3 @@ pub async fn serve(addr: SocketAddrV4, app_handle: AppHandle) -> io::Result<()> } } } - - -// it doesn't really return Approval, we just need to placate the compiler -async fn stall(stream: &mut TcpStream) -> Result { - let delay = std::time::Duration::from_secs(1); - loop { - tokio::time::sleep(delay).await; - stream.write(b"x").await?; - } -} - - -async fn handle(mut stream: TcpStream, app_handle: AppHandle) -> Result<(), RequestError> { - let (chan_send, chan_recv) = oneshot::channel(); - let app_state = app_handle.state::(); - let request_id = app_state.register_request(chan_send); - - let peer_addr = match stream.peer_addr()? { - std::net::SocketAddr::V4(addr) => addr, - _ => unreachable!(), // we only listen on IPv4 - }; - let clients = clientinfo::get_clients(peer_addr.port())?; - - let req = Request {id: request_id, clients}; - if req.clients.iter().any(|c| app_state.is_banned(c.pid)) { - stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?; - return Ok(()) - } - - app_handle.emit_all("credentials-request", &req)?; - let window = app_handle.get_window("main").ok_or(RequestError::NoMainWindow)?; - window.unminimize()?; - // window.show()?; - window.set_focus()?; - - let mut buf = [0; 8192]; // it's what tokio's BufReader uses - let mut n = 0; - loop { - n += stream.read(&mut buf[n..]).await?; - if &buf[(n - 4)..n] == b"\r\n\r\n" {break;} - if n == buf.len() {return Err(RequestError::RequestTooLarge);} - } - - println!("{}", std::str::from_utf8(&buf).unwrap()); - - stream.write(b"HTTP/1.0 200 OK\r\n").await?; - stream.write(b"Content-Type: application/json\r\n").await?; - stream.write(b"X-Creddy-delaying-tactic: ").await?; - - let approval = tokio::select!{ - e = stall(&mut stream) => e?, // this will never return Ok, just Err if it can't write to the stream - r = chan_recv => r.unwrap(), // only panics if the sender is dropped without sending, which shouldn't happen - }; - - if matches!(approval, Approval::Denied) { - // because we own the stream, it gets closed when we return. - // Unfortunately we've already signaled 200 OK, there's no way around this - - // we have to write the status code first thing, and we have to assume that the user - // might need more time than that gives us (especially if entering the passphrase). - // Fortunately most AWS libs automatically retry if the request dies uncompleted, allowing - // us to respond with a proper error status. - for client in req.clients { - app_state.add_ban(client.pid, app_handle.clone()); - } - return Ok(()); - } - - let creds = app_state.get_creds_serialized()?; - - stream.write(b"\r\nContent-Length: ").await?; - stream.write(creds.as_bytes().len().to_string().as_bytes()).await?; - stream.write(b"\r\n\r\n").await?; - stream.write(creds.as_bytes()).await?; - stream.write(b"\r\n\r\n").await?; - Ok(()) -} diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index 7e1a42a..a9f6ac7 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -16,6 +16,7 @@ use tauri::async_runtime as runtime; use tauri::Manager; use crate::ipc; +use crate::clientinfo::Client; use crate::errors::*; @@ -56,7 +57,7 @@ pub struct AppState { pub session: RwLock, pub request_count: RwLock, pub open_requests: RwLock>>, - pub bans: RwLock>, + pub bans: RwLock>>, pool: SqlitePool, } @@ -152,6 +153,11 @@ impl AppState { *count } + pub fn unregister_request(&self, id: u64) { + let mut open_requests = self.open_requests.write().unwrap(); + open_requests.remove(&id); + } + pub fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { let mut open_requests = self.open_requests.write().unwrap(); let chan = open_requests @@ -163,20 +169,20 @@ impl AppState { .map_err(|_e| SendResponseError::Abandoned) } - pub fn add_ban(&self, pid: u32, app: tauri::AppHandle) { + pub fn add_ban(&self, client: Option, app: tauri::AppHandle) { let mut bans = self.bans.write().unwrap(); - bans.insert(pid); + bans.insert(client.clone()); runtime::spawn(async move { sleep(Duration::from_secs(5)).await; let state = app.state::(); let mut bans = state.bans.write().unwrap(); - bans.remove(&pid); + bans.remove(&client); }); } - pub fn is_banned(&self, pid: u32) -> bool { - self.bans.read().unwrap().contains(&pid) + pub fn is_banned(&self, client: &Option) -> bool { + self.bans.read().unwrap().contains(&client) } pub async fn decrypt(&self, passphrase: &str) -> Result<(), UnlockError> {