use std::collections::HashMap; use std::sync::RwLock; use serde::{Serialize, Deserialize}; use tokio::sync::oneshot::Sender; use sqlx::{SqlitePool, sqlite::SqlitePoolOptions, sqlite::SqliteConnectOptions}; use crate::ipc; use crate::errors::*; #[derive(Serialize, Deserialize)] #[serde(rename_all = "PascalCase")] #[serde(untagged)] pub enum Credentials { LongLived { access_key_id: String, secret_access_key: String, }, ShortLived { access_key_id: String, secret_access_key: String, token: String, expiration: String, }, } #[derive(Serialize, Deserialize)] pub enum SessionStatus { Unlocked, Locked, Empty, } pub struct AppState { status: RwLock, credentials: RwLock>, request_count: RwLock, open_requests: RwLock>>, pool: SqlitePool, } impl AppState { pub fn new(status: SessionStatus, creds: Option) -> Result { let conn_opts = SqliteConnectOptions::new() .filename("creddy.db") .create_if_missing(true); let pool_opts = SqlitePoolOptions::new(); let pool: SqlitePool = tauri::async_runtime::block_on(pool_opts.connect_with(conn_opts))?; tauri::async_runtime::block_on(sqlx::migrate!().run(&pool))?; let state = AppState { status: RwLock::new(status), credentials: RwLock::new(creds), request_count: RwLock::new(0), open_requests: RwLock::new(HashMap::new()), pool, }; Ok(state) } async fn _load_from_db(&self) -> Result<(), sqlx::error::Error> { let row: (i32,) = sqlx::query_as("SELECT COUNT(*) FROM credentials") .fetch_one(&self.pool) .await?; let mut status = self.status.write().unwrap(); if row.0 > 0 { *status = SessionStatus::Locked; } else { *status = SessionStatus::Empty; } Ok(()) } pub fn register_request(&self, chan: Sender) -> u64 { let count = { let mut c = self.request_count.write().unwrap(); *c += 1; c }; let mut open_requests = self.open_requests.write().unwrap(); open_requests.insert(*count, chan); *count } pub fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { let mut open_requests = self.open_requests.write().unwrap(); let chan = open_requests .remove(&response.id) .ok_or(SendResponseError::NotFound) ?; chan.send(response.approval) .map_err(|_e| SendResponseError::Abandoned) } pub fn set_creds(&self, new_creds: Credentials) { let mut current_creds = self.credentials.write().unwrap(); *current_creds = Some(new_creds); let mut status = self.status.write().unwrap(); *status = SessionStatus::Unlocked; } pub fn get_creds_serialized(&self) -> String { let creds_option = self.credentials.read().unwrap(); // fix this at some point let creds = creds_option.as_ref().unwrap(); serde_json::to_string(creds).unwrap() } }