276 lines
8.6 KiB
Rust
276 lines
8.6 KiB
Rust
use core::time::Duration;
|
|
use std::io;
|
|
use std::net::{
|
|
Ipv4Addr,
|
|
SocketAddr,
|
|
SocketAddrV4,
|
|
};
|
|
use tokio::net::{
|
|
TcpListener,
|
|
TcpStream,
|
|
};
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::sync::oneshot::{self, Sender, Receiver};
|
|
use tokio::time::sleep;
|
|
|
|
use tauri::{AppHandle, Manager};
|
|
use tauri::async_runtime as rt;
|
|
use tauri::async_runtime::JoinHandle;
|
|
|
|
use crate::{clientinfo, clientinfo::Client};
|
|
use crate::errors::*;
|
|
use crate::ipc::{Request, Approval};
|
|
use crate::state::AppState;
|
|
|
|
|
|
#[derive(Debug)]
|
|
pub struct RequestWaiter {
|
|
pub rehide_after: bool,
|
|
pub sender: Option<Sender<Approval>>,
|
|
}
|
|
|
|
impl RequestWaiter {
|
|
pub fn notify(&mut self, approval: Approval) -> Result<(), SendResponseError> {
|
|
let chan = self.sender
|
|
.take()
|
|
.ok_or(SendResponseError::Fulfilled)?;
|
|
|
|
chan.send(approval)
|
|
.map_err(|_| SendResponseError::Abandoned)
|
|
}
|
|
}
|
|
|
|
|
|
struct Handler {
|
|
request_id: u64,
|
|
stream: TcpStream,
|
|
rehide_after: bool,
|
|
receiver: Option<Receiver<Approval>>,
|
|
app: AppHandle,
|
|
}
|
|
|
|
impl Handler {
|
|
async fn new(stream: TcpStream, app: AppHandle) -> Result<Self, HandlerError> {
|
|
let state = app.state::<AppState>();
|
|
|
|
// determine whether we should re-hide the window after handling this request
|
|
let is_currently_visible = app.get_window("main")
|
|
.ok_or(HandlerError::NoMainWindow)?
|
|
.is_visible()?;
|
|
let rehide_after = state.current_rehide_status()
|
|
.await
|
|
.unwrap_or(!is_currently_visible);
|
|
|
|
let (chan_send, chan_recv) = oneshot::channel();
|
|
let waiter = RequestWaiter {rehide_after, sender: Some(chan_send)};
|
|
let request_id = state.register_request(waiter).await;
|
|
let handler = Handler {
|
|
request_id,
|
|
stream,
|
|
rehide_after,
|
|
receiver: Some(chan_recv),
|
|
app
|
|
};
|
|
Ok(handler)
|
|
}
|
|
|
|
async fn handle(mut self) {
|
|
if let Err(e) = self.try_handle().await {
|
|
eprintln!("{e}");
|
|
}
|
|
let state = self.app.state::<AppState>();
|
|
state.unregister_request(self.request_id).await;
|
|
}
|
|
|
|
async fn try_handle(&mut self) -> Result<(), HandlerError> {
|
|
let req_path = self.recv_request().await?;
|
|
let clients = self.get_clients().await?;
|
|
if self.includes_banned(&clients).await {
|
|
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
|
|
return Ok(())
|
|
}
|
|
let base = req_path == b"/creddy/base-credentials";
|
|
|
|
let req = Request {id: self.request_id, clients, base};
|
|
self.app.emit_all("credentials-request", &req)?;
|
|
self.show_window()?;
|
|
|
|
match self.wait_for_response().await? {
|
|
Approval::Approved => {
|
|
let state = self.app.state::<AppState>();
|
|
let creds = if base {
|
|
state.serialize_base_creds().await?
|
|
}
|
|
else {
|
|
state.serialize_session_creds().await?
|
|
};
|
|
self.send_body(creds.as_bytes()).await?;
|
|
},
|
|
Approval::Denied => {
|
|
let state = self.app.state::<AppState>();
|
|
for client in req.clients {
|
|
state.add_ban(client).await;
|
|
}
|
|
self.send_body(b"Denied!").await?;
|
|
self.stream.shutdown().await?;
|
|
}
|
|
}
|
|
|
|
// only hide the window if a) it was hidden to start with
|
|
// and b) there are no other pending requests
|
|
let state = self.app.state::<AppState>();
|
|
let delay = {
|
|
let config = state.config.read().await;
|
|
Duration::from_millis(config.rehide_ms)
|
|
};
|
|
sleep(delay).await;
|
|
|
|
if self.rehide_after && state.req_count().await == 1 {
|
|
self.app
|
|
.get_window("main")
|
|
.ok_or(HandlerError::NoMainWindow)?
|
|
.hide()?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn recv_request(&mut self) -> Result<Vec<u8>, HandlerError> {
|
|
let mut buf = vec![0; 8192]; // it's what tokio's BufReader uses
|
|
let mut n = 0;
|
|
loop {
|
|
n += self.stream.read(&mut buf[n..]).await?;
|
|
if n >= 4 && &buf[(n - 4)..n] == b"\r\n\r\n" {break;}
|
|
if n == buf.len() {return Err(HandlerError::RequestTooLarge);}
|
|
}
|
|
|
|
let path = buf.split(|&c| &[c] == b" ")
|
|
.skip(1)
|
|
.next()
|
|
.ok_or(HandlerError::BadRequest(buf.clone()))?;
|
|
|
|
#[cfg(debug_assertions)] {
|
|
println!("Path: {}", std::str::from_utf8(&path).unwrap());
|
|
println!("{}", std::str::from_utf8(&buf).unwrap());
|
|
}
|
|
|
|
Ok(path.into())
|
|
}
|
|
|
|
async fn get_clients(&self) -> Result<Vec<Option<Client>>, HandlerError> {
|
|
let peer_addr = match self.stream.peer_addr()? {
|
|
SocketAddr::V4(addr) => addr,
|
|
_ => unreachable!(), // we only listen on IPv4
|
|
};
|
|
let clients = clientinfo::get_clients(peer_addr.port()).await?;
|
|
Ok(clients)
|
|
}
|
|
|
|
async fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
|
|
let state = self.app.state::<AppState>();
|
|
for client in clients {
|
|
if state.is_banned(client).await {
|
|
return true;
|
|
}
|
|
}
|
|
false
|
|
}
|
|
|
|
fn show_window(&self) -> Result<(), HandlerError> {
|
|
let window = self.app.get_window("main").ok_or(HandlerError::NoMainWindow)?;
|
|
if !window.is_visible()? {
|
|
window.unminimize()?;
|
|
window.show()?;
|
|
}
|
|
window.set_focus()?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn wait_for_response(&mut self) -> Result<Approval, HandlerError> {
|
|
self.stream.write(b"HTTP/1.0 200 OK\r\n").await?;
|
|
self.stream.write(b"Content-Type: application/json\r\n").await?;
|
|
self.stream.write(b"X-Creddy-delaying-tactic: ").await?;
|
|
|
|
#[allow(unreachable_code)] // seems necessary for type inference
|
|
let stall = async {
|
|
let delay = std::time::Duration::from_secs(1);
|
|
loop {
|
|
tokio::time::sleep(delay).await;
|
|
self.stream.write(b"x").await?;
|
|
}
|
|
Ok(Approval::Denied)
|
|
};
|
|
|
|
// this is the only place we even read this field, so it's safe to unwrap
|
|
let receiver = self.receiver.take().unwrap();
|
|
tokio::select!{
|
|
r = receiver => Ok(r.unwrap()), // only panics if the sender is dropped without sending, which shouldn't be possible
|
|
e = stall => e,
|
|
}
|
|
}
|
|
|
|
async fn send_body(&mut self, body: &[u8]) -> Result<(), HandlerError> {
|
|
self.stream.write(b"\r\nContent-Length: ").await?;
|
|
self.stream.write(body.len().to_string().as_bytes()).await?;
|
|
self.stream.write(b"\r\n\r\n").await?;
|
|
self.stream.write(body).await?;
|
|
self.stream.shutdown().await?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
|
|
#[derive(Debug)]
|
|
pub struct Server {
|
|
addr: Ipv4Addr,
|
|
port: u16,
|
|
app_handle: AppHandle,
|
|
task: JoinHandle<()>,
|
|
}
|
|
|
|
|
|
impl Server {
|
|
pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> {
|
|
let task = Self::start_server(addr, port, app_handle.app_handle()).await?;
|
|
Ok(Server { addr, port, app_handle, task})
|
|
}
|
|
|
|
pub async fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
|
|
if addr == self.addr && port == self.port {
|
|
return Ok(())
|
|
}
|
|
|
|
let new_task = Self::start_server(addr, port, self.app_handle.app_handle()).await?;
|
|
self.task.abort();
|
|
|
|
self.addr = addr;
|
|
self.port = port;
|
|
self.task = new_task;
|
|
Ok(())
|
|
}
|
|
|
|
// construct the listener before spawning the task so that we can return early if it fails
|
|
async fn start_server(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<JoinHandle<()>> {
|
|
let sock_addr = SocketAddrV4::new(addr, port);
|
|
let listener = TcpListener::bind(&sock_addr).await?;
|
|
let task = rt::spawn(
|
|
Self::serve(listener, app_handle.app_handle())
|
|
);
|
|
Ok(task)
|
|
}
|
|
|
|
async fn serve(listener: TcpListener, app_handle: AppHandle) {
|
|
loop {
|
|
match listener.accept().await {
|
|
Ok((stream, _)) => {
|
|
match Handler::new(stream, app_handle.app_handle()).await {
|
|
Ok(handler) => { rt::spawn(handler.handle()); }
|
|
Err(e) => { eprintln!("Error handling request: {e}"); }
|
|
}
|
|
},
|
|
Err(e) => { eprintln!("Error accepting connection: {e}"); }
|
|
}
|
|
}
|
|
}
|
|
}
|