use std::collections::{HashMap, HashSet}; use std::time::{ Duration, SystemTime, UNIX_EPOCH }; use aws_smithy_types::date_time::{ DateTime as AwsDateTime, Format as AwsDateTimeFormat, }; use serde::{Serialize, Deserialize}; use tokio::{ sync::oneshot::Sender, sync::RwLock, time::sleep, }; use sqlx::SqlitePool; use sodiumoxide::crypto::{ pwhash, pwhash::Salt, secretbox, secretbox::{Nonce, Key} }; use tauri::async_runtime as runtime; use tauri::Manager; use serde::Serializer; use crate::{config, config::AppConfig}; use crate::ipc; use crate::clientinfo::Client; use crate::errors::*; use crate::server::Server; #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "PascalCase")] pub struct BaseCredentials { access_key_id: String, secret_access_key: String, } #[derive(Clone, Debug, Serialize)] #[serde(rename_all = "PascalCase")] pub struct SessionCredentials { access_key_id: String, secret_access_key: String, token: String, #[serde(serialize_with = "serialize_expiration")] expiration: AwsDateTime, } impl SessionCredentials { fn is_expired(&self) -> bool { let current_ts = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() // doesn't panic because UNIX_EPOCH won't be later than now() .as_secs(); let expire_ts = self.expiration.secs(); let remaining = expire_ts - (current_ts as i64); remaining < 60 } } #[derive(Clone, Debug)] pub struct LockedCredentials { access_key_id: String, secret_key_enc: Vec, salt: Salt, nonce: Nonce, } #[derive(Clone, Debug)] pub enum Session { Unlocked{ base: BaseCredentials, session: SessionCredentials, }, Locked(LockedCredentials), Empty, } #[derive(Debug)] pub struct AppState { pub config: RwLock, pub session: RwLock, pub request_count: RwLock, pub open_requests: RwLock>>, pub bans: RwLock>>, server: RwLock, pool: sqlx::SqlitePool, } impl AppState { pub fn new(config: AppConfig, session: Session, server: Server, pool: SqlitePool) -> AppState { AppState { config: RwLock::new(config), session: RwLock::new(session), request_count: RwLock::new(0), open_requests: RwLock::new(HashMap::new()), bans: RwLock::new(HashSet::new()), server: RwLock::new(server), pool, } } pub async fn load_creds(pool: &SqlitePool) -> Result { let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc") .fetch_optional(pool) .await?; let row = match res { Some(r) => r, None => {return Ok(Session::Empty);} }; let salt_buf: [u8; 32] = row.salt .try_into() .map_err(|_e| SetupError::InvalidRecord)?; let nonce_buf: [u8; 24] = row.nonce .try_into() .map_err(|_e| SetupError::InvalidRecord)?; let creds = LockedCredentials { access_key_id: row.access_key_id, secret_key_enc: row.secret_key_enc, salt: Salt(salt_buf), nonce: Nonce(nonce_buf), }; Ok(Session::Locked(creds)) } pub async fn save_creds(&self, creds: BaseCredentials, passphrase: &str) -> Result<(), UnlockError> { let BaseCredentials {access_key_id, secret_access_key} = creds; // do this first so that if it fails we don't save bad credentials self.new_session(&access_key_id, &secret_access_key).await?; let salt = pwhash::gen_salt(); let mut key_buf = [0; secretbox::KEYBYTES]; pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &salt).unwrap(); let key = Key(key_buf); // not sure we need both salt AND nonce given that we generate a // fresh salt every time we encrypt, but better safe than sorry let nonce = secretbox::gen_nonce(); let secret_key_enc = secretbox::seal(secret_access_key.as_bytes(), &nonce, &key); sqlx::query( "INSERT INTO credentials (access_key_id, secret_key_enc, salt, nonce, created_at) VALUES (?, ?, ?, ?, strftime('%s'))" ) .bind(&access_key_id) .bind(&secret_key_enc) .bind(&salt.0[0..]) .bind(&nonce.0[0..]) .execute(&self.pool) .await?; Ok(()) } pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> { let mut live_config = self.config.write().await; if new_config.start_on_login != live_config.start_on_login { config::set_auto_launch(new_config.start_on_login)?; } if new_config.listen_addr != live_config.listen_addr || new_config.listen_port != live_config.listen_port { let mut sv = self.server.write().await; sv.rebind(new_config.listen_addr, new_config.listen_port).await?; } new_config.save(&self.pool).await?; *live_config = new_config; Ok(()) } pub async fn register_request(&self, chan: Sender) -> u64 { let count = { let mut c = self.request_count.write().await; *c += 1; c }; let mut open_requests = self.open_requests.write().await; open_requests.insert(*count, chan); // `count` is the request id *count } pub async fn unregister_request(&self, id: u64) { let mut open_requests = self.open_requests.write().await; open_requests.remove(&id); } pub async fn req_count(&self) -> usize { let open_requests = self.open_requests.read().await; open_requests.len() } pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { self.renew_session_if_expired().await?; let mut open_requests = self.open_requests.write().await; let chan = open_requests .remove(&response.id) .ok_or(SendResponseError::NotFound) ?; chan.send(response.approval) .map_err(|_e| SendResponseError::Abandoned) } pub async fn add_ban(&self, client: Option) { let mut bans = self.bans.write().await; bans.insert(client.clone()); runtime::spawn(async move { sleep(Duration::from_secs(5)).await; let app = crate::APP.get().unwrap(); let state = app.state::(); let mut bans = state.bans.write().await; bans.remove(&client); }); } pub async fn is_banned(&self, client: &Option) -> bool { self.bans.read().await.contains(&client) } pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> { let mut session = self.session.write().await; let LockedCredentials { access_key_id, secret_key_enc, salt, nonce } = match *session { Session::Empty => {return Err(UnlockError::NoCredentials);}, Session::Unlocked{..} => {return Err(UnlockError::NotLocked);}, Session::Locked(ref c) => c, }; let mut key_buf = [0; secretbox::KEYBYTES]; // pretty sure this only fails if we're out of memory pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), salt).unwrap(); let decrypted = secretbox::open(secret_key_enc, nonce, &Key(key_buf)) .map_err(|_e| UnlockError::BadPassphrase)?; let secret_access_key = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?; let session_creds = self.new_session(access_key_id, &secret_access_key).await?; *session = Session::Unlocked { base: BaseCredentials { access_key_id: access_key_id.clone(), secret_access_key, }, session: session_creds }; Ok(()) } // pub async fn serialize_base_creds(&self) -> Result { // let session = self.session.read().await; // match *session { // Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()), // Session::Locked(_) => Err(GetCredentialsError::Locked), // Session::Empty => Err(GetCredentialsError::Empty), // } // } pub async fn serialize_session_creds(&self) -> Result { let session = self.session.read().await; match *session { Session::Unlocked{ref session, ..} => Ok(serde_json::to_string(session).unwrap()), Session::Locked(_) => Err(GetCredentialsError::Locked), Session::Empty => Err(GetCredentialsError::Empty), } } async fn new_session(&self, key_id: &str, secret_key: &str) -> Result { let creds = aws_sdk_sts::Credentials::new( key_id, secret_key, None, // token None, // expiration "creddy", // "provider name" apparently ); let config = aws_config::from_env() .credentials_provider(creds) .load() .await; let client = aws_sdk_sts::Client::new(&config); let resp = client.get_session_token() .duration_seconds(43_200) .send() .await?; let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?; let access_key_id = aws_session.access_key_id() .ok_or(GetSessionError::EmptyResponse)? .to_string(); let secret_access_key = aws_session.secret_access_key() .ok_or(GetSessionError::EmptyResponse)? .to_string(); let token = aws_session.session_token() .ok_or(GetSessionError::EmptyResponse)? .to_string(); let expiration = aws_session.expiration() .ok_or(GetSessionError::EmptyResponse)? .clone(); let session_creds = SessionCredentials { access_key_id, secret_access_key, token, expiration, }; #[cfg(debug_assertions)] println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap()); Ok(session_creds) } pub async fn renew_session_if_expired(&self) -> Result { match *self.session.write().await { Session::Unlocked{ref base, ref mut session} => { if !session.is_expired() { return Ok(false); } let new_session = self.new_session( &base.access_key_id, &base.secret_access_key ).await?; *session = new_session; Ok(true) }, Session::Locked(_) => Err(GetSessionError::CredentialsLocked), Session::Empty => Err(GetSessionError::CredentialsEmpty), } } } fn serialize_expiration(exp: &AwsDateTime, serializer: S) -> Result where S: Serializer { // this only fails if the d/t is out of range, which it can't be for this format let time_str = exp.fmt(AwsDateTimeFormat::DateTime).unwrap(); serializer.serialize_str(&time_str) }