Inherit rehide flag from existing request if present
This commit is contained in:
		| @@ -98,7 +98,7 @@ pub fn exec(args: &ArgMatches) -> Result<(), CliError> { | ||||
|                 let name: OsString = cmd_name.into(); | ||||
|                 Err(ExecError::NotFound(name).into()) | ||||
|             } | ||||
|             e => Err(ExecError::ExecutionFailed(e).into()), | ||||
|             _ => Err(ExecError::ExecutionFailed(e).into()), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -10,7 +10,7 @@ use tokio::net::{ | ||||
|     TcpStream, | ||||
| }; | ||||
| use tokio::io::{AsyncReadExt, AsyncWriteExt}; | ||||
| use tokio::sync::oneshot; | ||||
| use tokio::sync::oneshot::{self, Sender, Receiver}; | ||||
| use tokio::time::sleep; | ||||
|  | ||||
| use tauri::{AppHandle, Manager}; | ||||
| @@ -23,24 +23,55 @@ use crate::ipc::{Request, Approval}; | ||||
| use crate::state::AppState; | ||||
|  | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct RequestWaiter { | ||||
|     pub rehide_after: bool, | ||||
|     pub sender: Option<Sender<Approval>>, | ||||
| } | ||||
|  | ||||
| impl RequestWaiter { | ||||
|     pub fn notify(&mut self, approval: Approval) -> Result<(), SendResponseError> { | ||||
|         let chan = self.sender | ||||
|             .take() | ||||
|             .ok_or(SendResponseError::Fulfilled)?; | ||||
|          | ||||
|         chan.send(approval) | ||||
|             .map_err(|_| SendResponseError::Abandoned) | ||||
|     } | ||||
| } | ||||
|  | ||||
|  | ||||
| struct Handler { | ||||
|     request_id: u64, | ||||
|     stream: TcpStream, | ||||
|     receiver: Option<oneshot::Receiver<Approval>>, | ||||
|     rehide_after: bool, | ||||
|     receiver: Option<Receiver<Approval>>, | ||||
|     app: AppHandle, | ||||
| } | ||||
|  | ||||
| impl Handler { | ||||
|     async fn new(stream: TcpStream, app: AppHandle) -> Self { | ||||
|     async fn new(stream: TcpStream, app: AppHandle) -> Result<Self, HandlerError> { | ||||
|         let state = app.state::<AppState>(); | ||||
|  | ||||
|         // determine whether we should re-hide the window after handling this request | ||||
|         let is_currently_visible = app.get_window("main") | ||||
|             .ok_or(HandlerError::NoMainWindow)? | ||||
|             .is_visible()?; | ||||
|         let rehide_after = state.current_rehide_status() | ||||
|             .await | ||||
|             .unwrap_or(!is_currently_visible); | ||||
|  | ||||
|         let (chan_send, chan_recv) = oneshot::channel(); | ||||
|         let request_id = state.register_request(chan_send).await; | ||||
|         Handler {  | ||||
|         let waiter = RequestWaiter {rehide_after, sender: Some(chan_send)}; | ||||
|         let request_id = state.register_request(waiter).await; | ||||
|         let  handler = Handler {  | ||||
|             request_id, | ||||
|             stream, | ||||
|             rehide_after, | ||||
|             receiver: Some(chan_recv), | ||||
|             app | ||||
|         } | ||||
|         }; | ||||
|         Ok(handler) | ||||
|     } | ||||
|  | ||||
|     async fn handle(mut self) { | ||||
| @@ -62,7 +93,7 @@ impl Handler { | ||||
|          | ||||
|         let req = Request {id: self.request_id, clients, base}; | ||||
|         self.app.emit_all("credentials-request", &req)?; | ||||
|         let starting_visibility = self.show_window()?; | ||||
|         self.show_window()?; | ||||
|  | ||||
|         match self.wait_for_response().await? { | ||||
|             Approval::Approved => { | ||||
| @@ -94,9 +125,11 @@ impl Handler { | ||||
|         }; | ||||
|         sleep(delay).await; | ||||
|  | ||||
|         if !starting_visibility && state.req_count().await == 0 { | ||||
|             let window = self.app.get_window("main").ok_or(HandlerError::NoMainWindow)?; | ||||
|             window.hide()?; | ||||
|         if self.rehide_after && state.req_count().await == 1 { | ||||
|             self.app | ||||
|                 .get_window("main") | ||||
|                 .ok_or(HandlerError::NoMainWindow)? | ||||
|                 .hide()?; | ||||
|         } | ||||
|  | ||||
|         Ok(()) | ||||
| @@ -143,15 +176,14 @@ impl Handler { | ||||
|         false | ||||
|     } | ||||
|  | ||||
|     fn show_window(&self) -> Result<bool, HandlerError> { | ||||
|     fn show_window(&self) -> Result<(), HandlerError> { | ||||
|         let window = self.app.get_window("main").ok_or(HandlerError::NoMainWindow)?; | ||||
|         let starting_visibility = window.is_visible()?; | ||||
|         if !starting_visibility { | ||||
|         if !window.is_visible()? { | ||||
|             window.unminimize()?; | ||||
|             window.show()?; | ||||
|         } | ||||
|         window.set_focus()?; | ||||
|         Ok(starting_visibility) | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     async fn wait_for_response(&mut self) -> Result<Approval, HandlerError> { | ||||
| @@ -231,12 +263,12 @@ impl Server { | ||||
|         loop { | ||||
|             match listener.accept().await { | ||||
|                 Ok((stream, _)) => { | ||||
|                     let handler = Handler::new(stream, app_handle.app_handle()).await; | ||||
|                     rt::spawn(handler.handle()); | ||||
|                 }, | ||||
|                 Err(e) => { | ||||
|                     eprintln!("Error accepting connection: {e}"); | ||||
|                     match Handler::new(stream, app_handle.app_handle()).await { | ||||
|                         Ok(handler) => { rt::spawn(handler.handle()); } | ||||
|                         Err(e) => { eprintln!("Error handling request: {e}"); } | ||||
|                     } | ||||
|                 }, | ||||
|                 Err(e) => { eprintln!("Error accepting connection: {e}"); } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|   | ||||
| @@ -2,7 +2,6 @@ use std::collections::{HashMap, HashSet}; | ||||
| use std::time::Duration; | ||||
|  | ||||
| use tokio::{ | ||||
|     sync::oneshot::Sender, | ||||
|     sync::RwLock, | ||||
|     time::sleep, | ||||
| }; | ||||
| @@ -20,7 +19,7 @@ use crate::{config, config::AppConfig}; | ||||
| use crate::ipc::{self, Approval}; | ||||
| use crate::clientinfo::Client; | ||||
| use crate::errors::*; | ||||
| use crate::server::Server; | ||||
| use crate::server::{Server, RequestWaiter}; | ||||
|  | ||||
|  | ||||
| #[derive(Debug)] | ||||
| @@ -28,7 +27,7 @@ pub struct AppState { | ||||
|     pub config: RwLock<AppConfig>, | ||||
|     pub session: RwLock<Session>, | ||||
|     pub request_count: RwLock<u64>, | ||||
|     pub open_requests: RwLock<HashMap<u64, Option<Sender<ipc::Approval>>>>, | ||||
|     pub waiting_requests: RwLock<HashMap<u64, RequestWaiter>>, | ||||
|     pub pending_terminal_request: RwLock<bool>, | ||||
|     pub bans: RwLock<std::collections::HashSet<Option<Client>>>, | ||||
|     server: RwLock<Server>, | ||||
| @@ -41,7 +40,7 @@ impl AppState { | ||||
|             config: RwLock::new(config), | ||||
|             session: RwLock::new(session), | ||||
|             request_count: RwLock::new(0), | ||||
|             open_requests: RwLock::new(HashMap::new()), | ||||
|             waiting_requests: RwLock::new(HashMap::new()), | ||||
|             pending_terminal_request: RwLock::new(false), | ||||
|             bans: RwLock::new(HashSet::new()), | ||||
|             server: RwLock::new(server), | ||||
| @@ -84,26 +83,33 @@ impl AppState { | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     pub async fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 { | ||||
|     pub async fn register_request(&self, waiter: RequestWaiter) -> u64 { | ||||
|         let count = { | ||||
|             let mut c = self.request_count.write().await; | ||||
|             *c += 1; | ||||
|             c | ||||
|         }; | ||||
|  | ||||
|         let mut open_requests = self.open_requests.write().await; | ||||
|         open_requests.insert(*count, Some(chan)); // `count` is the request id | ||||
|         let mut waiting_requests = self.waiting_requests.write().await; | ||||
|         waiting_requests.insert(*count, waiter); // `count` is the request id | ||||
|         *count | ||||
|     } | ||||
|  | ||||
|     pub async fn unregister_request(&self, id: u64) { | ||||
|         let mut open_requests = self.open_requests.write().await; | ||||
|         open_requests.remove(&id); | ||||
|         let mut waiting_requests = self.waiting_requests.write().await; | ||||
|         waiting_requests.remove(&id); | ||||
|     } | ||||
|  | ||||
|     pub async fn req_count(&self) -> usize { | ||||
|         let open_requests = self.open_requests.read().await; | ||||
|         open_requests.len() | ||||
|         let waiting_requests = self.waiting_requests.read().await; | ||||
|         waiting_requests.len() | ||||
|     } | ||||
|  | ||||
|     pub async fn current_rehide_status(&self) -> Option<bool> { | ||||
|         // since all requests that are pending at a given time should have the same | ||||
|         // value for rehide_after, it doesn't matter which one we use | ||||
|         let waiting_requests = self.waiting_requests.read().await; | ||||
|         waiting_requests.iter().next().map(|(_id, w)| w.rehide_after) | ||||
|     } | ||||
|  | ||||
|     pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { | ||||
| @@ -112,14 +118,11 @@ impl AppState { | ||||
|             session.renew_if_expired().await?; | ||||
|         } | ||||
|  | ||||
|         let mut open_requests = self.open_requests.write().await; | ||||
|         let req = open_requests | ||||
|         let mut waiting_requests = self.waiting_requests.write().await; | ||||
|         waiting_requests | ||||
|             .get_mut(&response.id) | ||||
|             .ok_or(SendResponseError::NotFound)?; | ||||
|  | ||||
|         let chan = req.take().ok_or(SendResponseError::Fulfilled)?; | ||||
|         chan.send(response.approval) | ||||
|             .map_err(|_e| SendResponseError::Abandoned) | ||||
|             .ok_or(SendResponseError::NotFound)? | ||||
|             .notify(response.approval) | ||||
|     } | ||||
|  | ||||
|     pub async fn add_ban(&self, client: Option<Client>) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user