From 61d674199f276f5cb8ce0893d649135f80796da9 Mon Sep 17 00:00:00 2001 From: Joseph Montanaro Date: Thu, 22 Dec 2022 16:36:32 -0800 Subject: [PATCH] store config in database, macro for state access --- src-tauri/Cargo.lock | 5 +- src-tauri/Cargo.toml | 1 + .../migrations/20221201002355_initial.sql | 4 +- src-tauri/src/clientinfo.rs | 10 +-- src-tauri/src/config.rs | 67 ++++++++++++++++--- src-tauri/src/errors.rs | 4 +- src-tauri/src/ipc.rs | 14 +++- src-tauri/src/main.rs | 43 +++++++++++- src-tauri/src/server.rs | 16 +++-- src-tauri/src/state.rs | 18 +++-- 10 files changed, 144 insertions(+), 38 deletions(-) diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index c9e8dcb..040500e 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -73,6 +73,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "netstat2", + "once_cell", "serde", "serde_json", "sodiumoxide", @@ -2337,9 +2338,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.13.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18a6dbe30758c9f83eb00cbea4ac95966305f5a7772f3f42ebfc7fc7eddbd8e1" +checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860" [[package]] name = "open" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 63fb592..cd41769 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -28,6 +28,7 @@ aws-sdk-sts = "0.22.0" aws-smithy-types = "0.52.0" aws-config = "0.52.0" thiserror = "1.0.38" +once_cell = "1.16.0" [features] # by default Tauri runs in production mode diff --git a/src-tauri/migrations/20221201002355_initial.sql b/src-tauri/migrations/20221201002355_initial.sql index c0a870f..86de386 100644 --- a/src-tauri/migrations/20221201002355_initial.sql +++ b/src-tauri/migrations/20221201002355_initial.sql @@ -7,8 +7,8 @@ CREATE TABLE credentials ( ); CREATE TABLE config ( - name TEXT, - data TEXT + name TEXT NOT NULL, + data TEXT NOT NULL ); CREATE TABLE clients ( diff --git a/src-tauri/src/clientinfo.rs b/src-tauri/src/clientinfo.rs index 7fdf44f..10c5028 100644 --- a/src-tauri/src/clientinfo.rs +++ b/src-tauri/src/clientinfo.rs @@ -3,9 +3,10 @@ use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt}; use serde::{Serialize, Deserialize}; use crate::errors::*; +use crate::get_state; -#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)] pub struct Client { pub pid: u32, pub exe: String, @@ -18,6 +19,7 @@ fn get_associated_pids(local_port: u16) -> Result, netstat2::error::Err ProtocolFlags::TCP )?; + get_state!(config as app_config); for item in sockets_iter { let sock_info = item?; let proto_info = match sock_info.protocol_socket_info { @@ -26,9 +28,9 @@ fn get_associated_pids(local_port: u16) -> Result, netstat2::error::Err }; if proto_info.local_port == local_port - && proto_info.remote_port == 19_923 - && proto_info.local_addr == std::net::Ipv4Addr::LOCALHOST - && proto_info.remote_addr == std::net::Ipv4Addr::LOCALHOST + && proto_info.remote_port == app_config.listen_port + && proto_info.local_addr == app_config.listen_addr + && proto_info.remote_addr == app_config.listen_addr { return Ok(sock_info.associated_pids) } diff --git a/src-tauri/src/config.rs b/src-tauri/src/config.rs index 5920f1d..54882d1 100644 --- a/src-tauri/src/config.rs +++ b/src-tauri/src/config.rs @@ -1,32 +1,79 @@ use std::net::Ipv4Addr; use std::path::PathBuf; +use serde::{Serialize, Deserialize}; +use sqlx::SqlitePool; +use crate::errors::*; + + +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct AppConfig { pub db_path: PathBuf, pub listen_addr: Ipv4Addr, pub listen_port: u16, + pub rehide_ms: u64, } + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct DbAppConfig { + listen_addr: Option, + listen_port: Option, + rehide_ms: Option, +} + + impl Default for AppConfig { fn default() -> Self { - let listen_port = if cfg!(debug_assertions) { - 12_345 - } - else { - 19_923 - }; - AppConfig { db_path: get_or_create_db_path(), listen_addr: Ipv4Addr::LOCALHOST, - listen_port, + listen_port: listen_port(), + rehide_ms: 1000, } } } -fn get_or_create_db_path() -> PathBuf { +impl From for AppConfig { + fn from(db_config: DbAppConfig) -> Self { + AppConfig { + db_path: get_or_create_db_path(), + listen_addr: db_config.listen_addr.unwrap_or(Ipv4Addr::LOCALHOST), + listen_port: db_config.listen_port.unwrap_or_else(|| listen_port()), + rehide_ms: db_config.rehide_ms.unwrap_or(1000), + } + } +} + + +pub async fn load(pool: &SqlitePool) -> Result { + let res = sqlx::query!("SELECT * from config where name = 'main'") + .fetch_optional(pool) + .await?; + + let row = match res { + Some(row) => row, + None => return Ok(AppConfig::default()), + }; + + let db_config: DbAppConfig = serde_json::from_str(&row.data)?; + Ok(AppConfig::from(db_config)) +} + + +fn listen_port() -> u16 { + if cfg!(debug_assertions) { + 12_345 + } + else { + 19_923 + } +} + + +pub fn get_or_create_db_path() -> PathBuf { if cfg!(debug_assertions) { return PathBuf::from("./creddy.db"); } @@ -41,4 +88,4 @@ fn get_or_create_db_path() -> PathBuf { parent.push("creddy.db"); parent -} \ No newline at end of file +} diff --git a/src-tauri/src/errors.rs b/src-tauri/src/errors.rs index 05ae03f..e883a96 100644 --- a/src-tauri/src/errors.rs +++ b/src-tauri/src/errors.rs @@ -18,7 +18,9 @@ pub enum SetupError { #[error("Error from database: {0}")] DbError(#[from] SqlxError), #[error("Error running migrations: {0}")] - MigrationError(#[from] MigrateError) + MigrationError(#[from] MigrateError), + #[error("Error parsing configuration from database")] + ConfigParseError(#[from] serde_json::Error), } diff --git a/src-tauri/src/ipc.rs b/src-tauri/src/ipc.rs index bfbfacc..fa3d165 100644 --- a/src-tauri/src/ipc.rs +++ b/src-tauri/src/ipc.rs @@ -1,25 +1,26 @@ use serde::{Serialize, Deserialize}; use tauri::State; +use crate::config::AppConfig; use crate::clientinfo::Client; use crate::state::{AppState, Session, Credentials}; -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Request { pub id: u64, pub clients: Vec>, } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct RequestResponse { pub id: u64, pub approval: Approval, } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] pub enum Approval { Approved, Denied, @@ -62,3 +63,10 @@ pub async fn save_credentials( .await .map_err(|e| {eprintln!("{e:?}"); e.to_string()}) } + + +#[tauri::command] +pub fn get_config(app_state: State<'_, AppState>) -> AppConfig { + let config = app_state.config.read().unwrap(); + config.clone() +} diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 21e5bde..3f783c5 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -3,7 +3,8 @@ windows_subsystem = "windows" )] -use tauri::Manager; +use tauri::{AppHandle, Manager}; +use once_cell::sync::OnceCell; mod config; mod errors; @@ -16,6 +17,8 @@ mod tray; use state::AppState; +pub static APP: OnceCell = OnceCell::new(); + fn main() { let initial_state = match state::AppState::new() { Ok(state) => state, @@ -31,8 +34,10 @@ fn main() { ipc::respond, ipc::get_session_status, ipc::save_credentials, + ipc::get_config, ]) .setup(|app| { + APP.set(app.handle()).unwrap(); let state = app.state::(); let config = state.config.read().unwrap(); let addr = std::net::SocketAddrV4::new(config.listen_addr, config.listen_port); @@ -51,4 +56,38 @@ fn main() { } _ => () }) -} \ No newline at end of file +} + + +macro_rules! get_state { + ($prop:ident as $name:ident) => { + use tauri::Manager; + let app = crate::APP.get().unwrap(); // as long as the app is running, this is fine + let state = app.state::(); + let $name = state.$prop.read().unwrap(); // only panics if another thread has already panicked + }; + (config.$prop:ident as $name:ident) => { + use tauri::Manager; + let app = crate::APP.get().unwrap(); + let state = app.state::(); + let config = state.config.read().unwrap(); + let $name = config.$prop; + }; + + (mut $prop:ident as $name:ident) => { + use tauri::Manager; + let app = crate::APP.get().unwrap(); + let state = app.state::(); + let $name = state.$prop.write().unwrap(); + }; + (mut config.$prop:ident as $name:ident) => { + use tauri::Manager; + let app = crate::APP.get().unwrap(); + let state = app.state::(); + let config = state.config.write().unwrap(); + let $name = config.$prop; + } +} + + +pub(crate) use get_state; diff --git a/src-tauri/src/server.rs b/src-tauri/src/server.rs index ad98660..04d695a 100644 --- a/src-tauri/src/server.rs +++ b/src-tauri/src/server.rs @@ -52,7 +52,7 @@ impl Handler { let req = Request {id: self.request_id, clients}; self.app.emit_all("credentials-request", &req)?; - let starting_visibility = self.ensure_visible()?; + let starting_visibility = self.show_window()?; match self.wait_for_response().await? { Approval::Approved => self.send_credentials().await?, @@ -68,11 +68,13 @@ impl Handler { // and b) there are no other pending requests let state = self.app.state::(); if !starting_visibility && state.req_count() == 0 { - let handle = self.app.app_handle(); - tauri::async_runtime::spawn(async move { - sleep(Duration::from_secs(3)).await; - let _ = handle.get_window("main").map(|w| w.hide()); - }); + let delay = { + let config = state.config.read().unwrap(); + Duration::from_millis(config.rehide_ms) + }; + sleep(delay).await; + let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?; + window.hide()?; } Ok(()) @@ -104,7 +106,7 @@ impl Handler { clients.iter().any(|c| state.is_banned(c)) } - fn ensure_visible(&self) -> Result { + fn show_window(&self) -> Result { let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?; let starting_visibility = window.is_visible()?; if !starting_visibility { diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index e5da31c..455a427 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -15,13 +15,13 @@ use sodiumoxide::crypto::{ use tauri::async_runtime as runtime; use tauri::Manager; -use crate::config::AppConfig; +use crate::{config, config::AppConfig}; use crate::ipc; use crate::clientinfo::Client; use crate::errors::*; -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum Credentials { #[serde(rename_all = "PascalCase")] @@ -39,6 +39,7 @@ pub enum Credentials { } +#[derive(Debug)] pub struct LockedCredentials { access_key_id: String, secret_key_enc: Vec, @@ -47,6 +48,7 @@ pub struct LockedCredentials { } +#[derive(Debug)] pub enum Session { Unlocked(Credentials), Locked(LockedCredentials), @@ -54,6 +56,7 @@ pub enum Session { } +#[derive(Debug)] pub struct AppState { pub config: RwLock, pub session: RwLock, @@ -65,16 +68,15 @@ pub struct AppState { impl AppState { pub fn new() -> Result { - let conf = AppConfig::default(); - let conn_opts = SqliteConnectOptions::new() - .filename(&conf.db_path) + .filename(config::get_or_create_db_path()) .create_if_missing(true); let pool_opts = SqlitePoolOptions::new(); let pool: SqlitePool = runtime::block_on(pool_opts.connect_with(conn_opts))?; runtime::block_on(sqlx::migrate!().run(&pool))?; let creds = runtime::block_on(Self::load_creds(&pool))?; + let conf = runtime::block_on(config::load(&pool))?; let state = AppState { config: RwLock::new(conf), @@ -197,7 +199,7 @@ impl AppState { pub async fn decrypt(&self, passphrase: &str) -> Result<(), UnlockError> { let (key_id, secret) = { - // do this all in a block so rustc doesn't complain about holding a lock across an await + // do this all in a block so that we aren't holding a lock across an await let session = self.session.read().unwrap(); let locked = match *session { Session::Empty => {return Err(UnlockError::NoCredentials);}, @@ -272,7 +274,9 @@ impl AppState { expiration, }; - println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap()); + if cfg!(debug_assertions) { + println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap()); + } *app_session = Session::Unlocked(session_creds);