use std::collections::HashMap; use std::collections::hash_map::Entry; use std::time::Duration; use time::OffsetDateTime; use tokio::{ sync::{RwLock, RwLockReadGuard}, sync::oneshot::{self, Sender}, }; use ssh_agent_lib::proto::message::Identity; use sqlx::SqlitePool; use sqlx::types::Uuid; use tauri::{ Manager, async_runtime as rt, }; use crate::app; use crate::credentials::{ AppSession, AwsSessionCredential, SshKey, }; use crate::{config, config::AppConfig}; use crate::credentials::{ AwsBaseCredential, Credential, CredentialRecord, PersistentCredential }; use crate::ipc::{self, RequestResponse}; use crate::errors::*; use crate::shortcuts; #[derive(Debug)] struct Visibility { leases: usize, original: Option, } impl Visibility { fn new() -> Self { Visibility { leases: 0, original: None } } fn acquire(&mut self, delay_ms: u64) -> Result { let app = crate::app::APP.get().unwrap(); let window = app.get_webview_window("main") .ok_or(WindowError::NoMainWindow)?; self.leases += 1; // `original` represents the visibility of the window before any leases were acquired // None means we don't know, Some(false) means it was previously hidden, // Some(true) means it was previously visible let is_visible = window.is_visible()?; if self.original.is_none() { self.original = Some(is_visible); } let state = app.state::(); if is_visible && state.desktop_is_gnome { // Gnome has a really annoying "focus-stealing prevention" behavior means we // can't just pop up when the window is already visible, so to work around it // we hide and then immediately unhide the window window.hide()?; } app::show_main_window(&app)?; window.set_focus()?; let (tx, rx) = oneshot::channel(); let lease = VisibilityLease { notify: tx }; let delay = Duration::from_millis(delay_ms); rt::spawn(async move { // We don't care if it's an error; lease being dropped should be handled identically let _ = rx.await; tokio::time::sleep(delay).await; // we can't use `self` here because we would have to move it into the async block let state = app.state::(); let mut visibility = state.visibility.write().await; visibility.leases -= 1; if visibility.leases == 0 { if let Some(false) = visibility.original { app::hide_main_window(app).error_print(); } visibility.original = None; } }); Ok(lease) } } pub struct VisibilityLease { notify: Sender<()>, } impl VisibilityLease { pub fn release(self) { rt::spawn(async move { if let Err(_) = self.notify.send(()) { eprintln!("Error releasing visibility lease") } }); } } #[derive(Debug)] pub struct AppState { pub config: RwLock, pub app_session: RwLock, // session cache is keyed on id rather than name because names can change pub aws_sessions: RwLock>, pub last_activity: RwLock, pub request_count: RwLock, pub waiting_requests: RwLock>>, pub pending_terminal_request: RwLock, // these are never modified and so don't need to be wrapped in RwLocks pub setup_errors: Vec, pub desktop_is_gnome: bool, pool: sqlx::SqlitePool, visibility: RwLock, } impl AppState { pub fn new( config: AppConfig, app_session: AppSession, pool: SqlitePool, setup_errors: Vec, desktop_is_gnome: bool, ) -> AppState { AppState { config: RwLock::new(config), app_session: RwLock::new(app_session), aws_sessions: RwLock::new(HashMap::new()), last_activity: RwLock::new(OffsetDateTime::now_utc()), request_count: RwLock::new(0), waiting_requests: RwLock::new(HashMap::new()), pending_terminal_request: RwLock::new(false), setup_errors, desktop_is_gnome, pool, visibility: RwLock::new(Visibility::new()), } } pub async fn save_credential(&self, record: CredentialRecord) -> Result<(), SaveCredentialsError> { let session = self.app_session.read().await; let crypto = session.try_get_crypto()?; record.save(crypto, &self.pool).await } pub async fn delete_credential(&self, id: &Uuid) -> Result<(), SaveCredentialsError> { sqlx::query!("DELETE FROM credentials WHERE id = ?", id) .execute(&self.pool) .await?; Ok(()) } pub async fn list_credentials(&self) -> Result, GetCredentialsError> { let session = self.app_session.read().await; let crypto = session.try_get_crypto()?; let list = CredentialRecord::list(crypto, &self.pool).await?; Ok(list) } pub async fn list_ssh_identities(&self) -> Result, GetCredentialsError> { Ok(SshKey::list_identities(&self.pool).await?) } pub async fn set_passphrase(&self, passphrase: &str) -> Result<(), SaveCredentialsError> { let mut cur_session = self.app_session.write().await; if let AppSession::Locked {..} = *cur_session { return Err(SaveCredentialsError::Locked); } let new_session = AppSession::new(passphrase)?; if let AppSession::Unlocked {salt: _, ref crypto} = *cur_session { CredentialRecord::rekey( crypto, new_session.try_get_crypto().expect("AppSession::new() should always return Unlocked"), &self.pool, ).await?; } new_session.save(&self.pool).await?; *cur_session = new_session; Ok(()) } pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> { let mut live_config = self.config.write().await; // update autostart if necessary if new_config.start_on_login != live_config.start_on_login { config::set_auto_launch(new_config.start_on_login)?; } // re-register hotkeys if necessary if new_config.hotkeys.show_window != live_config.hotkeys.show_window || new_config.hotkeys.launch_terminal != live_config.hotkeys.launch_terminal { shortcuts::register_hotkeys(&new_config.hotkeys)?; } new_config.save(&self.pool).await?; *live_config = new_config; Ok(()) } pub async fn register_request(&self, sender: Sender) -> u64 { let count = { let mut c = self.request_count.write().await; *c += 1; c }; let mut waiting_requests = self.waiting_requests.write().await; waiting_requests.insert(*count, sender); // `count` is the request id *count } pub async fn unregister_request(&self, id: u64) { let mut waiting_requests = self.waiting_requests.write().await; waiting_requests.remove(&id); } pub async fn acquire_visibility_lease(&self, delay: u64) -> Result { let mut visibility = self.visibility.write().await; visibility.acquire(delay) } pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { let mut waiting_requests = self.waiting_requests.write().await; waiting_requests .remove(&response.id) .ok_or(SendResponseError::NotFound)? .send(response) .map_err(|_| SendResponseError::Abandoned) } pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> { let mut session = self.app_session.write().await; session.unlock(passphrase) } pub async fn lock(&self) -> Result<(), LockError> { let mut session = self.app_session.write().await; match *session { AppSession::Empty => Err(LockError::NotUnlocked), AppSession::Locked{..} => Err(LockError::NotUnlocked), AppSession::Unlocked{..} => { *session = AppSession::load(&self.pool).await?; let app_handle = app::APP.get().unwrap(); app_handle.emit("locked", None::)?; Ok(()) } } } pub async fn reset_session(&self) -> Result<(), SaveCredentialsError> { let mut session = self.app_session.write().await; session.reset(&self.pool).await?; sqlx::query!("DELETE FROM credentials").execute(&self.pool).await?; Ok(()) } pub async fn get_aws_base(&self, name: Option) -> Result { let app_session = self.app_session.read().await; let crypto = app_session.try_get_crypto()?; let creds = match name { Some(n) => AwsBaseCredential::load_by_name(&n, crypto, &self.pool).await?, None => AwsBaseCredential::load_default(crypto, &self.pool).await?, }; Ok(creds) } pub async fn get_aws_session(&self, name: Option) -> Result, GetCredentialsError> { let app_session = self.app_session.read().await; let crypto = app_session.try_get_crypto()?; let record = match name { Some(n) => CredentialRecord::load_by_name(&n, crypto, &self.pool).await?, None => CredentialRecord::load_default("aws", crypto, &self.pool).await?, }; let base = match &record.credential { Credential::AwsBase(b) => Ok(b), _ => Err(LoadCredentialsError::NoCredentials) }?; { let mut aws_sessions = self.aws_sessions.write().await; match aws_sessions.entry(record.id) { Entry::Vacant(e) => { e.insert(AwsSessionCredential::from_base(&base).await?); }, Entry::Occupied(mut e) if e.get().is_expired() => { *(e.get_mut()) = AwsSessionCredential::from_base(&base).await?; }, _ => () } } // we know the unwrap is safe, because we just made sure of it let s = RwLockReadGuard::map(self.aws_sessions.read().await, |map| map.get(&record.id).unwrap()); Ok(s) } pub async fn ssh_name_from_pubkey(&self, pubkey: &[u8]) -> Result { let k = SshKey::name_from_pubkey(pubkey, &self.pool).await?; Ok(k) } pub async fn sshkey_by_name(&self, name: &str) -> Result { let app_session = self.app_session.read().await; let crypto = app_session.try_get_crypto()?; let k = SshKey::load_by_name(name, crypto, &self.pool).await?; Ok(k) } pub async fn signal_activity(&self) { let mut last_activity = self.last_activity.write().await; *last_activity = OffsetDateTime::now_utc(); } pub async fn should_auto_lock(&self) -> bool { let config = self.config.read().await; if !config.auto_lock || self.is_locked().await { return false; } let last_activity = self.last_activity.read().await; let elapsed = OffsetDateTime::now_utc() - *last_activity; elapsed >= config.lock_after } pub async fn is_locked(&self) -> bool { let session = self.app_session.read().await; matches!(*session, AppSession::Locked {..}) } pub async fn register_terminal_request(&self) -> Result<(), ()> { let mut req = self.pending_terminal_request.write().await; if *req { // if a request is already pending, we can't register a new one Err(()) } else { *req = true; Ok(()) } } pub async fn unregister_terminal_request(&self) { let mut req = self.pending_terminal_request.write().await; *req = false; } } #[cfg(test)] mod tests { use super::*; use crate::credentials::Crypto; use sqlx::types::Uuid; fn test_state(pool: SqlitePool) -> AppState { let salt = [0u8; 32]; let crypto = Crypto::fixed(); AppState::new( AppConfig::default(), AppSession::Unlocked { salt, crypto }, pool, vec![], false, ) } #[sqlx::test(fixtures("./credentials/fixtures/aws_credentials.sql"))] fn test_delete_credential(pool: SqlitePool) { let state = test_state(pool); let id = Uuid::try_parse("00000000-0000-0000-0000-000000000000").unwrap(); state.delete_credential(&id).await.unwrap(); // ensure delete-cascade went through correctly let res = AwsBaseCredential::load(&id, &Crypto::fixed(), &state.pool).await; assert!(matches!(res, Err(LoadCredentialsError::NoCredentials))); } }