diff --git a/src-tauri/src/config.rs b/src-tauri/src/config.rs index ead2ab2..e6edd0a 100644 --- a/src-tauri/src/config.rs +++ b/src-tauri/src/config.rs @@ -66,7 +66,7 @@ impl AppConfig { } -pub fn set_auto_launch(enable: bool) -> Result<(), SetupError> { +pub fn set_auto_launch(is_configured: bool) -> Result<(), SetupError> { let path_buf = std::env::current_exe() .map_err(|e| auto_launch::Error::Io(e))?; let path = path_buf @@ -75,13 +75,14 @@ pub fn set_auto_launch(enable: bool) -> Result<(), SetupError> { let auto = AutoLaunchBuilder::new() .set_app_name("Creddy") .set_app_path(&path) - .build()?; + .build().expect("Failed to build"); - if enable { - auto.enable()?; + let is_enabled = auto.is_enabled()?; + if is_configured && !is_enabled { + auto.enable().expect("Failed to enable"); } - else { - auto.disable()?; + else if !is_configured && is_enabled { + auto.disable().expect("Failed to disable"); } Ok(()) diff --git a/src-tauri/src/errors.rs b/src-tauri/src/errors.rs index 2c1df9e..12c8734 100644 --- a/src-tauri/src/errors.rs +++ b/src-tauri/src/errors.rs @@ -89,6 +89,8 @@ pub enum SetupError { ConfigParseError(#[from] serde_json::Error), #[error("Failed to set up start-on-login: {0}")] AutoLaunchError(#[from] auto_launch::Error), + #[error("Failed to start listener: {0}")] + ServerSetupError(#[from] std::io::Error), } diff --git a/src-tauri/src/ipc.rs b/src-tauri/src/ipc.rs index 555f0c2..dbbb613 100644 --- a/src-tauri/src/ipc.rs +++ b/src-tauri/src/ipc.rs @@ -73,5 +73,6 @@ pub fn get_config(app_state: State<'_, AppState>) -> AppConfig { pub async fn save_config(config: AppConfig, app_state: State<'_, AppState>) -> Result<(), String> { app_state.update_config(config) .await - .map_err(|e| format!("Error saving config to database: {e}")) + .map_err(|e| format!("Error saving config: {e}"))?; + Ok(()) } diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 2cd076f..5ccf387 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -2,8 +2,19 @@ all(not(debug_assertions), target_os = "windows"), windows_subsystem = "windows" )] +use std::error::Error; -use tauri::{AppHandle, Manager, async_runtime as rt}; +use sqlx::{ + SqlitePool, + sqlite::SqlitePoolOptions, + sqlite::SqliteConnectOptions, +}; +use tauri::{ + App, + AppHandle, + Manager, + async_runtime as rt, +}; use once_cell::sync::OnceCell; mod config; @@ -14,20 +25,44 @@ mod state; mod server; mod tray; -use crate::errors::*; +use config::AppConfig; +use server::Server; +use errors::*; use state::AppState; pub static APP: OnceCell = OnceCell::new(); -fn main() { - let initial_state = match rt::block_on(AppState::load()) { - Ok(state) => state, - Err(e) => {eprintln!("{}", e); return;} - }; +async fn setup(app: &mut App) -> Result<(), Box> { + APP.set(app.handle()).unwrap(); + + let conn_opts = SqliteConnectOptions::new() + .filename(config::get_or_create_db_path()) + .create_if_missing(true); + let pool_opts = SqlitePoolOptions::new(); + let pool: SqlitePool = pool_opts.connect_with(conn_opts).await?; + sqlx::migrate!().run(&pool).await?; + + let conf = AppConfig::load(&pool).await?; + let session = AppState::load_creds(&pool).await?; + let srv = Server::new(conf.listen_addr, conf.listen_port, app.handle()).await?; + + config::set_auto_launch(conf.start_on_login)?; + if !conf.start_minimized { + app.get_window("main") + .ok_or(RequestError::NoMainWindow)? + .show()?; + } + + let state = AppState::new(conf, session, srv, pool); + app.manage(state); + Ok(()) +} + + +fn main() { tauri::Builder::default() - .manage(initial_state) .system_tray(tray::create()) .on_system_tray_event(tray::handle_event) .invoke_handler(tauri::generate_handler![ @@ -38,22 +73,7 @@ fn main() { ipc::get_config, ipc::save_config, ]) - .setup(|app| { - APP.set(app.handle()).unwrap(); - let state = app.state::(); - let config = state.config.read().unwrap(); - config::set_auto_launch(config.start_on_login)?; - - let addr = std::net::SocketAddrV4::new(config.listen_addr, config.listen_port); - tauri::async_runtime::spawn(server::serve(addr, app.handle())); - - if !config.start_minimized { - app.get_window("main") - .ok_or(RequestError::NoMainWindow)? - .show()?; - } - Ok(()) - }) + .setup(|app| rt::block_on(setup(app))) .build(tauri::generate_context!()) .expect("error while running tauri application") .run(|app, run_event| match run_event { diff --git a/src-tauri/src/server.rs b/src-tauri/src/server.rs index 5bfd6b8..1e10e22 100644 --- a/src-tauri/src/server.rs +++ b/src-tauri/src/server.rs @@ -1,12 +1,22 @@ use core::time::Duration; use std::io; -use std::net::{SocketAddr, SocketAddrV4}; -use tokio::net::{TcpListener, TcpStream}; +use std::net::{ + Ipv4Addr, + SocketAddr, + SocketAddrV4, + TcpListener as StdTcpListener, +}; +use tokio::net::{ + TcpListener, + TcpStream, +}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::oneshot; use tokio::time::sleep; use tauri::{AppHandle, Manager}; +use tauri::async_runtime as rt; +use tauri::async_runtime::JoinHandle; use crate::{clientinfo, clientinfo::Client}; use crate::errors::*; @@ -159,17 +169,55 @@ impl Handler { } -pub async fn serve(addr: SocketAddrV4, app_handle: AppHandle) -> io::Result<()> { - let listener = TcpListener::bind(&addr).await?; - println!("Listening on {addr}"); - loop { - match listener.accept().await { - Ok((stream, _)) => { - let handler = Handler::new(stream, app_handle.app_handle()); - tauri::async_runtime::spawn(handler.handle()); - }, - Err(e) => { - eprintln!("Error accepting connection: {e}"); +#[derive(Debug)] +pub struct Server { + addr: Ipv4Addr, + port: u16, + app_handle: AppHandle, + task: JoinHandle<()>, +} + + +impl Server { + pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result { + // construct the listener before passing it to the task so that we know if it fails + let sock_addr = SocketAddrV4::new(addr, port); + let listener = TcpListener::bind(&sock_addr).await?; + let task = rt::spawn(Self::serve(listener, app_handle.app_handle())); + + Ok(Server { addr, port, app_handle, task}) + } + + // this is blocking because it's too much of a paint to juggle mutexes otherwise + pub fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> { + if addr == self.addr && port == self.port { + return Ok(()) + } + + let sock_addr = SocketAddrV4::new(addr, port); + let std_listener = StdTcpListener::bind(&sock_addr)?; + std_listener.set_nonblocking(true)?; + let async_listener = TcpListener::from_std(std_listener)?; + let new_task = rt::spawn(Self::serve(async_listener, self.app_handle.app_handle())); + self.task.abort(); + + self.addr = addr; + self.port = port; + self.task = new_task; + + Ok(()) + } + + async fn serve(listener: TcpListener, app_handle: AppHandle) { + loop { + match listener.accept().await { + Ok((stream, _)) => { + let handler = Handler::new(stream, app_handle.app_handle()); + rt::spawn(handler.handle()); + }, + Err(e) => { + eprintln!("Error accepting connection: {e}"); + } } } } diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index dcc8ba9..aa063df 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -5,7 +5,7 @@ use std::sync::RwLock; use serde::{Serialize, Deserialize}; use tokio::sync::oneshot::Sender; use tokio::time::sleep; -use sqlx::{SqlitePool, sqlite::SqlitePoolOptions, sqlite::SqliteConnectOptions}; +use sqlx::SqlitePool; use sodiumoxide::crypto::{ pwhash, pwhash::Salt, @@ -19,6 +19,7 @@ use crate::{config, config::AppConfig}; use crate::ipc; use crate::clientinfo::Client; use crate::errors::*; +use crate::server::Server; #[derive(Debug, Serialize, Deserialize)] @@ -63,34 +64,50 @@ pub struct AppState { pub request_count: RwLock, pub open_requests: RwLock>>, pub bans: RwLock>>, - pool: SqlitePool, + server: RwLock, + pool: sqlx::SqlitePool, } impl AppState { - pub async fn load() -> Result { - let conn_opts = SqliteConnectOptions::new() - .filename(config::get_or_create_db_path()) - .create_if_missing(true); - let pool_opts = SqlitePoolOptions::new(); - - let pool: SqlitePool = pool_opts.connect_with(conn_opts).await?; - sqlx::migrate!().run(&pool).await?; - let creds = Self::load_creds(&pool).await?; - let conf = AppConfig::load(&pool).await?; - - let state = AppState { - config: RwLock::new(conf), - session: RwLock::new(creds), - request_count: RwLock::new(0), - open_requests: RwLock::new(HashMap::new()), - bans: RwLock::new(HashSet::new()), - pool, - }; - - Ok(state) +pub fn new(config: AppConfig, session: Session, server: Server, pool: SqlitePool) -> AppState { + AppState { + config: RwLock::new(config), + session: RwLock::new(session), + request_count: RwLock::new(0), + open_requests: RwLock::new(HashMap::new()), + bans: RwLock::new(HashSet::new()), + server: RwLock::new(server), + pool, } +} - async fn load_creds(pool: &SqlitePool) -> Result { + // pub async fn load(app_handle: AppHandle) -> Result { + // let conn_opts = SqliteConnectOptions::new() + // .filename(config::get_or_create_db_path()) + // .create_if_missing(true); + // let pool_opts = SqlitePoolOptions::new(); + + // let pool: SqlitePool = pool_opts.connect_with(conn_opts).await?; + // sqlx::migrate!().run(&pool).await?; + + // let creds = Self::load_creds(&pool).await?; + // let conf = AppConfig::load(&pool).await?; + // let server = Server::new(conf.listen_addr, conf.listen_port, app_handle)?; + + // let state = AppState { + // config: RwLock::new(conf), + // session: RwLock::new(creds), + // request_count: RwLock::new(0), + // open_requests: RwLock::new(HashMap::new()), + // bans: RwLock::new(HashSet::new()), + // server: RwLock::new(server), + // pool, + // }; + + // Ok(state) + // } + + pub async fn load_creds(pool: &SqlitePool) -> Result { let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc") .fetch_optional(pool) .await?; @@ -151,14 +168,22 @@ impl AppState { } pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> { - new_config.save(&self.pool).await?; - - let mut live_config = self.config.write().unwrap(); - if new_config.start_on_login != live_config.start_on_login { - config::set_auto_launch(new_config.start_on_login)?; + { + let orig_config = self.config.read().unwrap(); + if new_config.start_on_login != orig_config.start_on_login { + config::set_auto_launch(new_config.start_on_login)?; + } + if new_config.listen_addr != orig_config.listen_addr + || new_config.listen_port != orig_config.listen_port + { + let mut sv = self.server.write().unwrap(); + sv.rebind(new_config.listen_addr, new_config.listen_port)?; + } } - *live_config = new_config; + new_config.save(&self.pool).await?; + let mut live_config = self.config.write().unwrap(); + *live_config = new_config; Ok(()) }