diff --git a/src-tauri/src/cli.rs b/src-tauri/src/cli.rs index d60a56c..9dec340 100644 --- a/src-tauri/src/cli.rs +++ b/src-tauri/src/cli.rs @@ -98,7 +98,7 @@ pub fn exec(args: &ArgMatches) -> Result<(), CliError> { let name: OsString = cmd_name.into(); Err(ExecError::NotFound(name).into()) } - e => Err(ExecError::ExecutionFailed(e).into()), + _ => Err(ExecError::ExecutionFailed(e).into()), } } diff --git a/src-tauri/src/server.rs b/src-tauri/src/server.rs index 85dc094..241afcf 100644 --- a/src-tauri/src/server.rs +++ b/src-tauri/src/server.rs @@ -10,7 +10,7 @@ use tokio::net::{ TcpStream, }; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::sync::oneshot; +use tokio::sync::oneshot::{self, Sender, Receiver}; use tokio::time::sleep; use tauri::{AppHandle, Manager}; @@ -23,24 +23,55 @@ use crate::ipc::{Request, Approval}; use crate::state::AppState; +#[derive(Debug)] +pub struct RequestWaiter { + pub rehide_after: bool, + pub sender: Option>, +} + +impl RequestWaiter { + pub fn notify(&mut self, approval: Approval) -> Result<(), SendResponseError> { + let chan = self.sender + .take() + .ok_or(SendResponseError::Fulfilled)?; + + chan.send(approval) + .map_err(|_| SendResponseError::Abandoned) + } +} + + struct Handler { request_id: u64, stream: TcpStream, - receiver: Option>, + rehide_after: bool, + receiver: Option>, app: AppHandle, } impl Handler { - async fn new(stream: TcpStream, app: AppHandle) -> Self { + async fn new(stream: TcpStream, app: AppHandle) -> Result { let state = app.state::(); + + // determine whether we should re-hide the window after handling this request + let is_currently_visible = app.get_window("main") + .ok_or(HandlerError::NoMainWindow)? + .is_visible()?; + let rehide_after = state.current_rehide_status() + .await + .unwrap_or(!is_currently_visible); + let (chan_send, chan_recv) = oneshot::channel(); - let request_id = state.register_request(chan_send).await; - Handler { + let waiter = RequestWaiter {rehide_after, sender: Some(chan_send)}; + let request_id = state.register_request(waiter).await; + let handler = Handler { request_id, stream, + rehide_after, receiver: Some(chan_recv), app - } + }; + Ok(handler) } async fn handle(mut self) { @@ -62,7 +93,7 @@ impl Handler { let req = Request {id: self.request_id, clients, base}; self.app.emit_all("credentials-request", &req)?; - let starting_visibility = self.show_window()?; + self.show_window()?; match self.wait_for_response().await? { Approval::Approved => { @@ -94,9 +125,11 @@ impl Handler { }; sleep(delay).await; - if !starting_visibility && state.req_count().await == 0 { - let window = self.app.get_window("main").ok_or(HandlerError::NoMainWindow)?; - window.hide()?; + if self.rehide_after && state.req_count().await == 1 { + self.app + .get_window("main") + .ok_or(HandlerError::NoMainWindow)? + .hide()?; } Ok(()) @@ -143,15 +176,14 @@ impl Handler { false } - fn show_window(&self) -> Result { + fn show_window(&self) -> Result<(), HandlerError> { let window = self.app.get_window("main").ok_or(HandlerError::NoMainWindow)?; - let starting_visibility = window.is_visible()?; - if !starting_visibility { + if !window.is_visible()? { window.unminimize()?; window.show()?; } window.set_focus()?; - Ok(starting_visibility) + Ok(()) } async fn wait_for_response(&mut self) -> Result { @@ -231,12 +263,12 @@ impl Server { loop { match listener.accept().await { Ok((stream, _)) => { - let handler = Handler::new(stream, app_handle.app_handle()).await; - rt::spawn(handler.handle()); + match Handler::new(stream, app_handle.app_handle()).await { + Ok(handler) => { rt::spawn(handler.handle()); } + Err(e) => { eprintln!("Error handling request: {e}"); } + } }, - Err(e) => { - eprintln!("Error accepting connection: {e}"); - } + Err(e) => { eprintln!("Error accepting connection: {e}"); } } } } diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index 8d1f3d7..90038cd 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -2,7 +2,6 @@ use std::collections::{HashMap, HashSet}; use std::time::Duration; use tokio::{ - sync::oneshot::Sender, sync::RwLock, time::sleep, }; @@ -20,7 +19,7 @@ use crate::{config, config::AppConfig}; use crate::ipc::{self, Approval}; use crate::clientinfo::Client; use crate::errors::*; -use crate::server::Server; +use crate::server::{Server, RequestWaiter}; #[derive(Debug)] @@ -28,7 +27,7 @@ pub struct AppState { pub config: RwLock, pub session: RwLock, pub request_count: RwLock, - pub open_requests: RwLock>>>, + pub waiting_requests: RwLock>, pub pending_terminal_request: RwLock, pub bans: RwLock>>, server: RwLock, @@ -41,7 +40,7 @@ impl AppState { config: RwLock::new(config), session: RwLock::new(session), request_count: RwLock::new(0), - open_requests: RwLock::new(HashMap::new()), + waiting_requests: RwLock::new(HashMap::new()), pending_terminal_request: RwLock::new(false), bans: RwLock::new(HashSet::new()), server: RwLock::new(server), @@ -84,26 +83,33 @@ impl AppState { Ok(()) } - pub async fn register_request(&self, chan: Sender) -> u64 { + pub async fn register_request(&self, waiter: RequestWaiter) -> u64 { let count = { let mut c = self.request_count.write().await; *c += 1; c }; - let mut open_requests = self.open_requests.write().await; - open_requests.insert(*count, Some(chan)); // `count` is the request id + let mut waiting_requests = self.waiting_requests.write().await; + waiting_requests.insert(*count, waiter); // `count` is the request id *count } pub async fn unregister_request(&self, id: u64) { - let mut open_requests = self.open_requests.write().await; - open_requests.remove(&id); + let mut waiting_requests = self.waiting_requests.write().await; + waiting_requests.remove(&id); } pub async fn req_count(&self) -> usize { - let open_requests = self.open_requests.read().await; - open_requests.len() + let waiting_requests = self.waiting_requests.read().await; + waiting_requests.len() + } + + pub async fn current_rehide_status(&self) -> Option { + // since all requests that are pending at a given time should have the same + // value for rehide_after, it doesn't matter which one we use + let waiting_requests = self.waiting_requests.read().await; + waiting_requests.iter().next().map(|(_id, w)| w.rehide_after) } pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { @@ -112,14 +118,11 @@ impl AppState { session.renew_if_expired().await?; } - let mut open_requests = self.open_requests.write().await; - let req = open_requests + let mut waiting_requests = self.waiting_requests.write().await; + waiting_requests .get_mut(&response.id) - .ok_or(SendResponseError::NotFound)?; - - let chan = req.take().ok_or(SendResponseError::Fulfilled)?; - chan.send(response.approval) - .map_err(|_e| SendResponseError::Abandoned) + .ok_or(SendResponseError::NotFound)? + .notify(response.approval) } pub async fn add_ban(&self, client: Option) {