store config in database, macro for state access

This commit is contained in:
Joseph Montanaro 2022-12-22 16:36:32 -08:00
parent 398916fe10
commit 61d674199f
10 changed files with 144 additions and 38 deletions

5
src-tauri/Cargo.lock generated
View File

@ -73,6 +73,7 @@ dependencies = [
"aws-smithy-types",
"aws-types",
"netstat2",
"once_cell",
"serde",
"serde_json",
"sodiumoxide",
@ -2337,9 +2338,9 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.13.0"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18a6dbe30758c9f83eb00cbea4ac95966305f5a7772f3f42ebfc7fc7eddbd8e1"
checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
[[package]]
name = "open"

View File

@ -28,6 +28,7 @@ aws-sdk-sts = "0.22.0"
aws-smithy-types = "0.52.0"
aws-config = "0.52.0"
thiserror = "1.0.38"
once_cell = "1.16.0"
[features]
# by default Tauri runs in production mode

View File

@ -7,8 +7,8 @@ CREATE TABLE credentials (
);
CREATE TABLE config (
name TEXT,
data TEXT
name TEXT NOT NULL,
data TEXT NOT NULL
);
CREATE TABLE clients (

View File

@ -3,9 +3,10 @@ use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt};
use serde::{Serialize, Deserialize};
use crate::errors::*;
use crate::get_state;
#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct Client {
pub pid: u32,
pub exe: String,
@ -18,6 +19,7 @@ fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Err
ProtocolFlags::TCP
)?;
get_state!(config as app_config);
for item in sockets_iter {
let sock_info = item?;
let proto_info = match sock_info.protocol_socket_info {
@ -26,9 +28,9 @@ fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Err
};
if proto_info.local_port == local_port
&& proto_info.remote_port == 19_923
&& proto_info.local_addr == std::net::Ipv4Addr::LOCALHOST
&& proto_info.remote_addr == std::net::Ipv4Addr::LOCALHOST
&& proto_info.remote_port == app_config.listen_port
&& proto_info.local_addr == app_config.listen_addr
&& proto_info.remote_addr == app_config.listen_addr
{
return Ok(sock_info.associated_pids)
}

View File

@ -1,32 +1,79 @@
use std::net::Ipv4Addr;
use std::path::PathBuf;
use serde::{Serialize, Deserialize};
use sqlx::SqlitePool;
use crate::errors::*;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AppConfig {
pub db_path: PathBuf,
pub listen_addr: Ipv4Addr,
pub listen_port: u16,
pub rehide_ms: u64,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DbAppConfig {
listen_addr: Option<Ipv4Addr>,
listen_port: Option<u16>,
rehide_ms: Option<u64>,
}
impl Default for AppConfig {
fn default() -> Self {
let listen_port = if cfg!(debug_assertions) {
12_345
}
else {
19_923
};
AppConfig {
db_path: get_or_create_db_path(),
listen_addr: Ipv4Addr::LOCALHOST,
listen_port,
listen_port: listen_port(),
rehide_ms: 1000,
}
}
}
fn get_or_create_db_path() -> PathBuf {
impl From<DbAppConfig> for AppConfig {
fn from(db_config: DbAppConfig) -> Self {
AppConfig {
db_path: get_or_create_db_path(),
listen_addr: db_config.listen_addr.unwrap_or(Ipv4Addr::LOCALHOST),
listen_port: db_config.listen_port.unwrap_or_else(|| listen_port()),
rehide_ms: db_config.rehide_ms.unwrap_or(1000),
}
}
}
pub async fn load(pool: &SqlitePool) -> Result<AppConfig, SetupError> {
let res = sqlx::query!("SELECT * from config where name = 'main'")
.fetch_optional(pool)
.await?;
let row = match res {
Some(row) => row,
None => return Ok(AppConfig::default()),
};
let db_config: DbAppConfig = serde_json::from_str(&row.data)?;
Ok(AppConfig::from(db_config))
}
fn listen_port() -> u16 {
if cfg!(debug_assertions) {
12_345
}
else {
19_923
}
}
pub fn get_or_create_db_path() -> PathBuf {
if cfg!(debug_assertions) {
return PathBuf::from("./creddy.db");
}

View File

@ -18,7 +18,9 @@ pub enum SetupError {
#[error("Error from database: {0}")]
DbError(#[from] SqlxError),
#[error("Error running migrations: {0}")]
MigrationError(#[from] MigrateError)
MigrationError(#[from] MigrateError),
#[error("Error parsing configuration from database")]
ConfigParseError(#[from] serde_json::Error),
}

View File

@ -1,25 +1,26 @@
use serde::{Serialize, Deserialize};
use tauri::State;
use crate::config::AppConfig;
use crate::clientinfo::Client;
use crate::state::{AppState, Session, Credentials};
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Request {
pub id: u64,
pub clients: Vec<Option<Client>>,
}
#[derive(Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct RequestResponse {
pub id: u64,
pub approval: Approval,
}
#[derive(Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub enum Approval {
Approved,
Denied,
@ -62,3 +63,10 @@ pub async fn save_credentials(
.await
.map_err(|e| {eprintln!("{e:?}"); e.to_string()})
}
#[tauri::command]
pub fn get_config(app_state: State<'_, AppState>) -> AppConfig {
let config = app_state.config.read().unwrap();
config.clone()
}

View File

@ -3,7 +3,8 @@
windows_subsystem = "windows"
)]
use tauri::Manager;
use tauri::{AppHandle, Manager};
use once_cell::sync::OnceCell;
mod config;
mod errors;
@ -16,6 +17,8 @@ mod tray;
use state::AppState;
pub static APP: OnceCell<AppHandle> = OnceCell::new();
fn main() {
let initial_state = match state::AppState::new() {
Ok(state) => state,
@ -31,8 +34,10 @@ fn main() {
ipc::respond,
ipc::get_session_status,
ipc::save_credentials,
ipc::get_config,
])
.setup(|app| {
APP.set(app.handle()).unwrap();
let state = app.state::<AppState>();
let config = state.config.read().unwrap();
let addr = std::net::SocketAddrV4::new(config.listen_addr, config.listen_port);
@ -52,3 +57,37 @@ fn main() {
_ => ()
})
}
macro_rules! get_state {
($prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap(); // as long as the app is running, this is fine
let state = app.state::<crate::state::AppState>();
let $name = state.$prop.read().unwrap(); // only panics if another thread has already panicked
};
(config.$prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let config = state.config.read().unwrap();
let $name = config.$prop;
};
(mut $prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let $name = state.$prop.write().unwrap();
};
(mut config.$prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let config = state.config.write().unwrap();
let $name = config.$prop;
}
}
pub(crate) use get_state;

View File

@ -52,7 +52,7 @@ impl Handler {
let req = Request {id: self.request_id, clients};
self.app.emit_all("credentials-request", &req)?;
let starting_visibility = self.ensure_visible()?;
let starting_visibility = self.show_window()?;
match self.wait_for_response().await? {
Approval::Approved => self.send_credentials().await?,
@ -68,11 +68,13 @@ impl Handler {
// and b) there are no other pending requests
let state = self.app.state::<AppState>();
if !starting_visibility && state.req_count() == 0 {
let handle = self.app.app_handle();
tauri::async_runtime::spawn(async move {
sleep(Duration::from_secs(3)).await;
let _ = handle.get_window("main").map(|w| w.hide());
});
let delay = {
let config = state.config.read().unwrap();
Duration::from_millis(config.rehide_ms)
};
sleep(delay).await;
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?;
window.hide()?;
}
Ok(())
@ -104,7 +106,7 @@ impl Handler {
clients.iter().any(|c| state.is_banned(c))
}
fn ensure_visible(&self) -> Result<bool, RequestError> {
fn show_window(&self) -> Result<bool, RequestError> {
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?;
let starting_visibility = window.is_visible()?;
if !starting_visibility {

View File

@ -15,13 +15,13 @@ use sodiumoxide::crypto::{
use tauri::async_runtime as runtime;
use tauri::Manager;
use crate::config::AppConfig;
use crate::{config, config::AppConfig};
use crate::ipc;
use crate::clientinfo::Client;
use crate::errors::*;
#[derive(Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Credentials {
#[serde(rename_all = "PascalCase")]
@ -39,6 +39,7 @@ pub enum Credentials {
}
#[derive(Debug)]
pub struct LockedCredentials {
access_key_id: String,
secret_key_enc: Vec<u8>,
@ -47,6 +48,7 @@ pub struct LockedCredentials {
}
#[derive(Debug)]
pub enum Session {
Unlocked(Credentials),
Locked(LockedCredentials),
@ -54,6 +56,7 @@ pub enum Session {
}
#[derive(Debug)]
pub struct AppState {
pub config: RwLock<AppConfig>,
pub session: RwLock<Session>,
@ -65,16 +68,15 @@ pub struct AppState {
impl AppState {
pub fn new() -> Result<Self, SetupError> {
let conf = AppConfig::default();
let conn_opts = SqliteConnectOptions::new()
.filename(&conf.db_path)
.filename(config::get_or_create_db_path())
.create_if_missing(true);
let pool_opts = SqlitePoolOptions::new();
let pool: SqlitePool = runtime::block_on(pool_opts.connect_with(conn_opts))?;
runtime::block_on(sqlx::migrate!().run(&pool))?;
let creds = runtime::block_on(Self::load_creds(&pool))?;
let conf = runtime::block_on(config::load(&pool))?;
let state = AppState {
config: RwLock::new(conf),
@ -197,7 +199,7 @@ impl AppState {
pub async fn decrypt(&self, passphrase: &str) -> Result<(), UnlockError> {
let (key_id, secret) = {
// do this all in a block so rustc doesn't complain about holding a lock across an await
// do this all in a block so that we aren't holding a lock across an await
let session = self.session.read().unwrap();
let locked = match *session {
Session::Empty => {return Err(UnlockError::NoCredentials);},
@ -272,7 +274,9 @@ impl AppState {
expiration,
};
println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap());
if cfg!(debug_assertions) {
println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap());
}
*app_session = Session::Unlocked(session_creds);