creddy/src-tauri/src/state.rs

241 lines
7.4 KiB
Rust

use std::collections::HashMap;
use std::time::Duration;
use tokio::{
sync::RwLock,
sync::oneshot::{self, Sender},
};
use sqlx::SqlitePool;
use tauri::{
Manager,
async_runtime as rt,
};
use crate::credentials::{
Session,
BaseCredentials,
SessionCredentials,
};
use crate::{config, config::AppConfig};
use crate::ipc::{self, Approval};
use crate::errors::*;
use crate::shortcuts;
/// Reference-counted record of who needs the main window to stay visible.
///
/// Every outstanding [`VisibilityLease`] is counted in `leases`. `original` remembers whether
/// the window was visible before the first of those leases was taken, so that once the last
/// lease is given back the window is hidden again only if it started out hidden.
#[derive(Debug)]
struct Visibility {
    /// Number of leases handed out and not yet returned.
    leases: usize,
    /// Visibility of the window when the first current lease was acquired;
    /// `None` whenever no lease is outstanding.
    original: Option<bool>,
}

impl Visibility {
    /// Creates a tracker with no outstanding leases.
    fn new() -> Self {
        Visibility { leases: 0, original: None }
    }

    /// Ensures the main window is shown and focused, and returns a lease that keeps it so.
    ///
    /// When the lease is released or dropped, a background task waits `delay_ms`
    /// milliseconds before returning it. When the last lease is returned, the window is
    /// hidden again if it was hidden before the first lease was taken.
    ///
    /// # Errors
    /// Returns `WindowError::NoMainWindow` if no window labelled `"main"` exists. Otherwise it
    /// returns the window error from querying, showing, or focusing the window. A failed call
    /// is not counted as a lease, so it can never prevent the window from being re-hidden later.
    ///
    /// # Panics
    /// Panics if the global `crate::app::APP` has not been initialized yet.
    fn acquire(&mut self, delay_ms: u64) -> Result<VisibilityLease, WindowError> {
        let app = crate::app::APP.get().unwrap();
        let window = app.get_window("main").ok_or(WindowError::NoMainWindow)?;

        if self.original.is_none() {
            let is_visible = window.is_visible()?;
            if !is_visible {
                window.show()?;
            }
            // Record the original state only once the window is actually showing, so a
            // failed `show` leaves the tracker untouched.
            self.original = Some(is_visible);
        }
        if let Err(e) = window.set_focus() {
            // No lease will be handed out. If nobody else holds one either, undo what this
            // call did so the tracker returns to its idle state.
            if self.leases == 0 {
                if let Some(false) = self.original.take() {
                    window.hide().error_print();
                }
            }
            return Err(e.into());
        }
        // Count the lease only now that every fallible step has succeeded. Counting it earlier
        // would leak it on error: no release task would exist to decrement it, so `leases`
        // could never reach zero again and the window would never be re-hidden.
        self.leases += 1;

        let (tx, rx) = oneshot::channel();
        let lease = VisibilityLease { notify: tx };
        let delay = Duration::from_millis(delay_ms);
        let handle = app.app_handle();
        rt::spawn(async move {
            // Release and drop are handled identically: either way the receiver completes,
            // so the error case is deliberately ignored.
            let _ = rx.await;
            tokio::time::sleep(delay).await;
            // `self` can't be moved into this task, so re-fetch the state through the app handle.
            let state = handle.state::<AppState>();
            let mut visibility = state.visibility.write().await;
            visibility.leases -= 1;
            if visibility.leases == 0 {
                if let Some(false) = visibility.original {
                    window.hide().error_print();
                }
                visibility.original = None;
            }
        });
        Ok(lease)
    }
}
pub struct VisibilityLease {
notify: Sender<()>,
}
impl VisibilityLease {
pub fn release(self) {
rt::spawn(async move {
if let Err(_) = self.notify.send(()) {
eprintln!("Error releasing visibility lease")
}
});
}
}
/// Shared application state, stored in Tauri's managed state and fetched with
/// `state::<AppState>()`.
///
/// Mutable fields are behind `tokio::sync::RwLock` because they are accessed from async code.
/// Some of these locks are held across `.await` points.
#[derive(Debug)]
pub struct AppState {
    /// Live application configuration.
    pub config: RwLock<AppConfig>,
    /// Current credential session: empty, locked, or unlocked.
    pub session: RwLock<Session>,
    /// Last request id handed out; incremented for each new request.
    pub request_count: RwLock<u64>,
    /// Requests awaiting an approval decision, keyed by request id.
    pub waiting_requests: RwLock<HashMap<u64, Sender<Approval>>>,
    /// Whether a terminal request is currently pending; at most one is allowed at a time.
    pub pending_terminal_request: RwLock<bool>,
    // setup_errors is never modified and so doesn't need to be wrapped in RwLock
    pub setup_errors: Vec<String>,
    /// Database pool used to persist encrypted credentials and configuration.
    pool: sqlx::SqlitePool,
    /// Bookkeeping for temporarily forcing the main window to be visible.
    visibility: RwLock<Visibility>,
}

impl AppState {
    /// Builds the initial state from an already-loaded configuration and session.
    ///
    /// The new state has handed out no request ids, has no waiting requests, has no
    /// pending terminal request, and holds no visibility leases.
    pub fn new(
        config: AppConfig,
        session: Session,
        pool: SqlitePool,
        setup_errors: Vec<String>,
    ) -> AppState {
        AppState {
            config: RwLock::new(config),
            session: RwLock::new(session),
            request_count: RwLock::new(0),
            waiting_requests: RwLock::new(HashMap::new()),
            pending_terminal_request: RwLock::new(false),
            setup_errors,
            pool,
            visibility: RwLock::new(Visibility::new()),
        }
    }

    /// Encrypts `base_creds` with `passphrase`, unlocks a session with them, and saves the
    /// encrypted credentials.
    ///
    /// # Errors
    /// Fails if encryption, session creation, or saving the encrypted credentials fails.
    pub async fn new_creds(&self, base_creds: BaseCredentials, passphrase: &str) -> Result<(), UnlockError> {
        let locked = base_creds.encrypt(passphrase)?;
        // do this first so that if it fails we don't save bad credentials
        self.new_session(base_creds).await?;
        locked.save(&self.pool).await?;
        Ok(())
    }

    /// Applies `new_config` and makes it the live configuration.
    ///
    /// Autostart and hotkey registration are updated only if the relevant settings changed.
    /// The config write lock is held throughout, so concurrent updates are serialized.
    ///
    /// # Errors
    /// Fails if updating autostart, re-registering hotkeys, or saving the config fails. In
    /// that case the live config is left unchanged.
    // NOTE(review): side effects from earlier steps (autostart, hotkeys) are not rolled back
    // when a later step fails.
    pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> {
        let mut live_config = self.config.write().await;
        // update autostart if necessary
        if new_config.start_on_login != live_config.start_on_login {
            config::set_auto_launch(new_config.start_on_login)?;
        }
        // re-register hotkeys if necessary
        if new_config.hotkeys.show_window != live_config.hotkeys.show_window
            || new_config.hotkeys.launch_terminal != live_config.hotkeys.launch_terminal
        {
            shortcuts::register_hotkeys(&new_config.hotkeys)?;
        }
        new_config.save(&self.pool).await?;
        *live_config = new_config;
        Ok(())
    }

    /// Allocates a fresh request id and registers `sender` to receive that request's
    /// approval decision. Returns the new id.
    pub async fn register_request(&self, sender: Sender<Approval>) -> u64 {
        let id = {
            let mut count = self.request_count.write().await;
            *count += 1;
            // Copy the value out (rather than returning the guard) so the counter lock is
            // released here instead of being held while waiting on `waiting_requests`.
            *count
        };
        let mut waiting_requests = self.waiting_requests.write().await;
        waiting_requests.insert(id, sender);
        id
    }

    /// Forgets the waiting request with the given id, if any.
    pub async fn unregister_request(&self, id: u64) {
        let mut waiting_requests = self.waiting_requests.write().await;
        waiting_requests.remove(&id);
    }

    /// Forces the main window to be visible until the returned lease is released.
    ///
    /// After release, the window waits `delay` milliseconds before it may be hidden again.
    pub async fn acquire_visibility_lease(&self, delay: u64) -> Result<VisibilityLease, WindowError> {
        let mut visibility = self.visibility.write().await;
        visibility.acquire(delay)
    }

    /// Delivers the user's decision for a waiting request.
    ///
    /// If the request was approved, the session is first renewed if it has expired.
    ///
    /// # Errors
    /// Returns `SendResponseError::NotFound` if no request with that id is waiting. Returns
    /// `SendResponseError::Abandoned` if the requester has stopped listening. Errors from
    /// renewing the session are also propagated.
    pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
        if let Approval::Approved = response.approval {
            let mut session = self.session.write().await;
            session.renew_if_expired().await?;
        }
        let mut waiting_requests = self.waiting_requests.write().await;
        waiting_requests
            .remove(&response.id)
            .ok_or(SendResponseError::NotFound)?
            .send(response.approval)
            .map_err(|_| SendResponseError::Abandoned)
    }

    /// Decrypts the stored credentials with `passphrase` and starts an unlocked session.
    ///
    /// # Errors
    /// Returns `UnlockError::NoCredentials` if the session is empty, or
    /// `UnlockError::NotLocked` if it is already unlocked. Decryption and session-creation
    /// errors are also propagated.
    pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> {
        let base_creds = match *self.session.read().await {
            Session::Empty => { return Err(UnlockError::NoCredentials); },
            Session::Unlocked { .. } => { return Err(UnlockError::NotLocked); },
            Session::Locked(ref locked) => locked.decrypt(passphrase)?,
        };
        // Read lock is dropped here, so this doesn't deadlock
        self.new_session(base_creds).await?;
        Ok(())
    }

    /// Returns `true` if the session is currently unlocked.
    pub async fn is_unlocked(&self) -> bool {
        let session = self.session.read().await;
        matches!(*session, Session::Unlocked { .. })
    }

    /// Returns a copy of the base credentials of the current session.
    ///
    /// # Errors
    /// Propagates the error from `Session::try_get` when no credentials are available.
    pub async fn base_creds_cloned(&self) -> Result<BaseCredentials, GetCredentialsError> {
        let app_session = self.session.read().await;
        let (base, _session) = app_session.try_get()?;
        Ok(base.clone())
    }

    /// Returns a copy of the session credentials of the current session.
    ///
    /// # Errors
    /// Propagates the error from `Session::try_get` when no credentials are available.
    pub async fn session_creds_cloned(&self) -> Result<SessionCredentials, GetCredentialsError> {
        let app_session = self.session.read().await;
        let (_base, session) = app_session.try_get()?;
        Ok(session.clone())
    }

    /// Derives session credentials from `base` and installs them as the unlocked session.
    ///
    /// The session credentials are obtained before the session lock is taken, so readers
    /// are not blocked while they are fetched.
    async fn new_session(&self, base: BaseCredentials) -> Result<(), GetSessionError> {
        let session = SessionCredentials::from_base(&base).await?;
        let mut app_session = self.session.write().await;
        *app_session = Session::Unlocked { base, session };
        Ok(())
    }

    /// Marks a terminal request as pending.
    ///
    /// Returns `Err(())` if one is already pending; only one may be outstanding at a time.
    pub async fn register_terminal_request(&self) -> Result<(), ()> {
        let mut req = self.pending_terminal_request.write().await;
        if *req {
            // if a request is already pending, we can't register a new one
            Err(())
        } else {
            *req = true;
            Ok(())
        }
    }

    /// Clears the pending-terminal-request flag.
    pub async fn unregister_terminal_request(&self) {
        let mut req = self.pending_terminal_request.write().await;
        *req = false;
    }
}