completely reorganize http server

This commit is contained in:
Joseph Montanaro
2022-12-20 16:11:49 -08:00
parent 80b92ebe69
commit 414379b74e
4 changed files with 153 additions and 99 deletions

View File

@ -1,28 +1,148 @@
use std::io;
use std::net::SocketAddrV4;
use std::net::{SocketAddr, SocketAddrV4};
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::oneshot;
use tauri::{AppHandle, Manager};
use crate::clientinfo;
use crate::errors::RequestError;
use crate::{clientinfo, clientinfo::Client};
use crate::errors::*;
use crate::ipc::{Request, Approval};
use crate::state::AppState;
struct Handler {
request_id: u64,
stream: TcpStream,
receiver: Option<oneshot::Receiver<Approval>>,
app: AppHandle,
}
impl Handler {
fn new(stream: TcpStream, app: AppHandle) -> Self {
let state = app.state::<AppState>();
let (chan_send, chan_recv) = oneshot::channel();
let request_id = state.register_request(chan_send);
Handler {
request_id,
stream,
receiver: Some(chan_recv),
app
}
}
async fn handle(mut self) {
if let Err(e) = self.try_handle().await {
eprintln!("{e}");
}
let state = self.app.state::<AppState>();
state.unregister_request(self.request_id);
}
async fn try_handle(&mut self) -> Result<(), RequestError> {
let _ = self.recv_request().await?;
let clients = self.get_clients()?;
if self.includes_banned(&clients) {
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
return Ok(())
}
let req = Request {id: self.request_id, clients};
self.notify_frontend(&req).await?;
match self.wait_for_response().await? {
Approval::Approved => self.send_credentials().await?,
Approval::Denied => {
let state = self.app.state::<AppState>();
for client in req.clients {
state.add_ban(client, self.app.clone());
}
}
}
Ok(())
}
async fn recv_request(&mut self) -> Result<Vec<u8>, RequestError> {
let mut buf = vec![0; 8192]; // it's what tokio's BufReader uses
let mut n = 0;
loop {
n += self.stream.read(&mut buf[n..]).await?;
if n >= 4 && &buf[(n - 4)..n] == b"\r\n\r\n" {break;}
if n == buf.len() {return Err(RequestError::RequestTooLarge);}
}
println!("{}", std::str::from_utf8(&buf).unwrap());
Ok(buf)
}
fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> {
let peer_addr = match self.stream.peer_addr()? {
SocketAddr::V4(addr) => addr,
_ => unreachable!(), // we only listen on IPv4
};
let clients = clientinfo::get_clients(peer_addr.port())?;
Ok(clients)
}
fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
let state = self.app.state::<AppState>();
clients.iter().any(|c| state.is_banned(c))
}
async fn notify_frontend(&self, req: &Request) -> Result<(), RequestError> {
self.app.emit_all("credentials-request", req)?;
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?;
window.unminimize()?;
window.set_focus()?;
Ok(())
}
async fn wait_for_response(&mut self) -> Result<Approval, RequestError> {
self.stream.write(b"HTTP/1.0 200 OK\r\n").await?;
self.stream.write(b"Content-Type: application/json\r\n").await?;
self.stream.write(b"X-Creddy-delaying-tactic: ").await?;
#[allow(unreachable_code)] // seems necessary for type inference
let stall = async {
let delay = std::time::Duration::from_secs(1);
loop {
tokio::time::sleep(delay).await;
self.stream.write(b"x").await?;
}
Ok(Approval::Denied)
};
// this is the only place we even read this field, so it's safe to unwrap
let receiver = self.receiver.take().unwrap();
tokio::select!{
r = receiver => Ok(r.unwrap()), // only panics if the sender is dropped without sending, which shouldn't be possible
e = stall => e,
}
}
async fn send_credentials(&mut self) -> Result<(), RequestError> {
let state = self.app.state::<AppState>();
let creds = state.get_creds_serialized()?;
self.stream.write(b"\r\nContent-Length: ").await?;
self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?;
self.stream.write(b"\r\n\r\n").await?;
self.stream.write(creds.as_bytes()).await?;
self.stream.write(b"\r\n\r\n").await?;
Ok(())
}
}
pub async fn serve(addr: SocketAddrV4, app_handle: AppHandle) -> io::Result<()> {
let listener = TcpListener::bind(&addr).await?;
println!("Listening on {addr}");
loop {
let new_handle = app_handle.app_handle();
match listener.accept().await {
Ok((stream, _)) => {
tokio::spawn(async {
if let Err(e) = handle(stream, new_handle).await {
eprintln!("{e}");
}
});
let handler = Handler::new(stream, app_handle.app_handle());
tauri::async_runtime::spawn(handler.handle());
},
Err(e) => {
eprintln!("Error accepting connection: {e}");
@ -30,79 +150,3 @@ pub async fn serve(addr: SocketAddrV4, app_handle: AppHandle) -> io::Result<()>
}
}
}
// it doesn't really return Approval, we just need to placate the compiler
async fn stall(stream: &mut TcpStream) -> Result<Approval, tokio::io::Error> {
let delay = std::time::Duration::from_secs(1);
loop {
tokio::time::sleep(delay).await;
stream.write(b"x").await?;
}
}
async fn handle(mut stream: TcpStream, app_handle: AppHandle) -> Result<(), RequestError> {
let (chan_send, chan_recv) = oneshot::channel();
let app_state = app_handle.state::<crate::state::AppState>();
let request_id = app_state.register_request(chan_send);
let peer_addr = match stream.peer_addr()? {
std::net::SocketAddr::V4(addr) => addr,
_ => unreachable!(), // we only listen on IPv4
};
let clients = clientinfo::get_clients(peer_addr.port())?;
let req = Request {id: request_id, clients};
if req.clients.iter().any(|c| app_state.is_banned(c.pid)) {
stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
return Ok(())
}
app_handle.emit_all("credentials-request", &req)?;
let window = app_handle.get_window("main").ok_or(RequestError::NoMainWindow)?;
window.unminimize()?;
// window.show()?;
window.set_focus()?;
let mut buf = [0; 8192]; // it's what tokio's BufReader uses
let mut n = 0;
loop {
n += stream.read(&mut buf[n..]).await?;
if &buf[(n - 4)..n] == b"\r\n\r\n" {break;}
if n == buf.len() {return Err(RequestError::RequestTooLarge);}
}
println!("{}", std::str::from_utf8(&buf).unwrap());
stream.write(b"HTTP/1.0 200 OK\r\n").await?;
stream.write(b"Content-Type: application/json\r\n").await?;
stream.write(b"X-Creddy-delaying-tactic: ").await?;
let approval = tokio::select!{
e = stall(&mut stream) => e?, // this will never return Ok, just Err if it can't write to the stream
r = chan_recv => r.unwrap(), // only panics if the sender is dropped without sending, which shouldn't happen
};
if matches!(approval, Approval::Denied) {
// because we own the stream, it gets closed when we return.
// Unfortunately we've already signaled 200 OK, there's no way around this -
// we have to write the status code first thing, and we have to assume that the user
// might need more time than that gives us (especially if entering the passphrase).
// Fortunately most AWS libs automatically retry if the request dies uncompleted, allowing
// us to respond with a proper error status.
for client in req.clients {
app_state.add_ban(client.pid, app_handle.clone());
}
return Ok(());
}
let creds = app_state.get_creds_serialized()?;
stream.write(b"\r\nContent-Length: ").await?;
stream.write(creds.as_bytes().len().to_string().as_bytes()).await?;
stream.write(b"\r\n\r\n").await?;
stream.write(creds.as_bytes()).await?;
stream.write(b"\r\n\r\n").await?;
Ok(())
}