switch to tokio RwLock instead of std

This commit is contained in:
Joseph Montanaro 2023-05-02 15:24:35 -07:00
parent 96bbc2dbc2
commit ddf865d0b4
5 changed files with 121 additions and 157 deletions

View File

@ -1,9 +1,13 @@
use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo}; use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo};
use tauri::Manager;
use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt}; use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::errors::*; use crate::{
use crate::get_state; errors::*,
config::AppConfig,
state::AppState,
};
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)] #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
@ -13,13 +17,18 @@ pub struct Client {
} }
fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> { async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
let state = crate::APP.get().unwrap().state::<AppState>();
let AppConfig {
listen_addr: app_listen_addr,
listen_port: app_listen_port,
..
} = *state.config.read().await;
let sockets_iter = netstat2::iterate_sockets_info( let sockets_iter = netstat2::iterate_sockets_info(
AddressFamilyFlags::IPV4, AddressFamilyFlags::IPV4,
ProtocolFlags::TCP ProtocolFlags::TCP
)?; )?;
get_state!(config as app_config);
for item in sockets_iter { for item in sockets_iter {
let sock_info = item?; let sock_info = item?;
let proto_info = match sock_info.protocol_socket_info { let proto_info = match sock_info.protocol_socket_info {
@ -28,9 +37,9 @@ fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Err
}; };
if proto_info.local_port == local_port if proto_info.local_port == local_port
&& proto_info.remote_port == app_config.listen_port && proto_info.remote_port == app_listen_port
&& proto_info.local_addr == app_config.listen_addr && proto_info.local_addr == app_listen_addr
&& proto_info.remote_addr == app_config.listen_addr && proto_info.remote_addr == app_listen_addr
{ {
return Ok(sock_info.associated_pids) return Ok(sock_info.associated_pids)
} }
@ -40,10 +49,10 @@ fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Err
// Theoretically, on some systems, multiple processes can share a socket // Theoretically, on some systems, multiple processes can share a socket
pub fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> { pub async fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> {
let mut clients = Vec::new(); let mut clients = Vec::new();
let mut sys = System::new(); let mut sys = System::new();
for p in get_associated_pids(local_port)? { for p in get_associated_pids(local_port).await? {
let pid = Pid::from_u32(p); let pid = Pid::from_u32(p);
sys.refresh_process(pid); sys.refresh_process(pid);
let proc = sys.process(pid) let proc = sys.process(pid)

View File

@ -41,13 +41,14 @@ pub async fn unlock(passphrase: String, app_state: State<'_, AppState>) -> Resul
#[tauri::command] #[tauri::command]
pub fn get_session_status(app_state: State<'_, AppState>) -> String { pub async fn get_session_status(app_state: State<'_, AppState>) -> Result<String, ()> {
let session = app_state.session.read().unwrap(); let session = app_state.session.read().await;
match *session { let status = match *session {
Session::Locked(_) => "locked".into(), Session::Locked(_) => "locked".into(),
Session::Unlocked{..} => "unlocked".into(), Session::Unlocked{..} => "unlocked".into(),
Session::Empty => "empty".into() Session::Empty => "empty".into()
} };
Ok(status)
} }
@ -62,9 +63,9 @@ pub async fn save_credentials(
#[tauri::command] #[tauri::command]
pub fn get_config(app_state: State<'_, AppState>) -> AppConfig { pub async fn get_config(app_state: State<'_, AppState>) -> Result<AppConfig, ()> {
let config = app_state.config.read().unwrap(); let config = app_state.config.read().await;
config.clone() Ok(config.clone())
} }

View File

@ -97,37 +97,3 @@ fn run() -> tauri::Result<()> {
fn main() { fn main() {
run().error_popup("Creddy failed to start"); run().error_popup("Creddy failed to start");
} }
macro_rules! get_state {
($prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap(); // as long as the app is running, this is fine
let state = app.state::<crate::state::AppState>();
let $name = state.$prop.read().unwrap(); // only panics if another thread has already panicked
};
(config.$prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let config = state.config.read().unwrap();
let $name = config.$prop;
};
(mut $prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let $name = state.$prop.write().unwrap();
};
(mut config.$prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let config = state.config.write().unwrap();
let $name = config.$prop;
}
}
pub(crate) use get_state;

View File

@ -4,7 +4,6 @@ use std::net::{
Ipv4Addr, Ipv4Addr,
SocketAddr, SocketAddr,
SocketAddrV4, SocketAddrV4,
TcpListener as StdTcpListener,
}; };
use tokio::net::{ use tokio::net::{
TcpListener, TcpListener,
@ -32,10 +31,10 @@ struct Handler {
} }
impl Handler { impl Handler {
fn new(stream: TcpStream, app: AppHandle) -> Self { async fn new(stream: TcpStream, app: AppHandle) -> Self {
let state = app.state::<AppState>(); let state = app.state::<AppState>();
let (chan_send, chan_recv) = oneshot::channel(); let (chan_send, chan_recv) = oneshot::channel();
let request_id = state.register_request(chan_send); let request_id = state.register_request(chan_send).await;
Handler { Handler {
request_id, request_id,
stream, stream,
@ -49,13 +48,13 @@ impl Handler {
eprintln!("{e}"); eprintln!("{e}");
} }
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
state.unregister_request(self.request_id); state.unregister_request(self.request_id).await;
} }
async fn try_handle(&mut self) -> Result<(), RequestError> { async fn try_handle(&mut self) -> Result<(), RequestError> {
let _ = self.recv_request().await?; let _ = self.recv_request().await?;
let clients = self.get_clients()?; let clients = self.get_clients().await?;
if self.includes_banned(&clients) { if self.includes_banned(&clients).await {
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?; self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
return Ok(()) return Ok(())
} }
@ -69,7 +68,7 @@ impl Handler {
Approval::Denied => { Approval::Denied => {
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
for client in req.clients { for client in req.clients {
state.add_ban(client, self.app.clone()); state.add_ban(client).await;
} }
} }
} }
@ -78,12 +77,12 @@ impl Handler {
// and b) there are no other pending requests // and b) there are no other pending requests
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
let delay = { let delay = {
let config = state.config.read().unwrap(); let config = state.config.read().await;
Duration::from_millis(config.rehide_ms) Duration::from_millis(config.rehide_ms)
}; };
sleep(delay).await; sleep(delay).await;
if !starting_visibility && state.req_count() == 0 { if !starting_visibility && state.req_count().await == 0 {
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?; let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?;
window.hide()?; window.hide()?;
} }
@ -107,18 +106,23 @@ impl Handler {
Ok(buf) Ok(buf)
} }
fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> { async fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> {
let peer_addr = match self.stream.peer_addr()? { let peer_addr = match self.stream.peer_addr()? {
SocketAddr::V4(addr) => addr, SocketAddr::V4(addr) => addr,
_ => unreachable!(), // we only listen on IPv4 _ => unreachable!(), // we only listen on IPv4
}; };
let clients = clientinfo::get_clients(peer_addr.port())?; let clients = clientinfo::get_clients(peer_addr.port()).await?;
Ok(clients) Ok(clients)
} }
fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool { async fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
clients.iter().any(|c| state.is_banned(c)) for client in clients {
if state.is_banned(client).await {
return true;
}
}
false
} }
fn show_window(&self) -> Result<bool, RequestError> { fn show_window(&self) -> Result<bool, RequestError> {
@ -157,7 +161,7 @@ impl Handler {
async fn send_credentials(&mut self) -> Result<(), RequestError> { async fn send_credentials(&mut self) -> Result<(), RequestError> {
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
let creds = state.serialize_session_creds()?; let creds = state.serialize_session_creds().await?;
self.stream.write(b"\r\nContent-Length: ").await?; self.stream.write(b"\r\nContent-Length: ").await?;
self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?; self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?;
@ -180,39 +184,39 @@ pub struct Server {
impl Server { impl Server {
pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> { pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> {
// construct the listener before passing it to the task so that we know if it fails let task = Self::start_server(addr, port, app_handle.app_handle()).await?;
let sock_addr = SocketAddrV4::new(addr, port);
let listener = TcpListener::bind(&sock_addr).await?;
let task = rt::spawn(Self::serve(listener, app_handle.app_handle()));
Ok(Server { addr, port, app_handle, task}) Ok(Server { addr, port, app_handle, task})
} }
// this is blocking because it's too much of a paint to juggle mutexes otherwise pub async fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
pub fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
if addr == self.addr && port == self.port { if addr == self.addr && port == self.port {
return Ok(()) return Ok(())
} }
let sock_addr = SocketAddrV4::new(addr, port); let new_task = Self::start_server(addr, port, self.app_handle.app_handle()).await?;
let std_listener = StdTcpListener::bind(&sock_addr)?;
std_listener.set_nonblocking(true)?;
let async_listener = TcpListener::from_std(std_listener)?;
let new_task = rt::spawn(Self::serve(async_listener, self.app_handle.app_handle()));
self.task.abort(); self.task.abort();
self.addr = addr; self.addr = addr;
self.port = port; self.port = port;
self.task = new_task; self.task = new_task;
Ok(()) Ok(())
} }
// construct the listener before spawning the task so that we can return early if it fails
async fn start_server(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<JoinHandle<()>> {
let sock_addr = SocketAddrV4::new(addr, port);
let listener = TcpListener::bind(&sock_addr).await?;
let task = rt::spawn(
Self::serve(listener, app_handle.app_handle())
);
Ok(task)
}
async fn serve(listener: TcpListener, app_handle: AppHandle) { async fn serve(listener: TcpListener, app_handle: AppHandle) {
loop { loop {
match listener.accept().await { match listener.accept().await {
Ok((stream, _)) => { Ok((stream, _)) => {
let handler = Handler::new(stream, app_handle.app_handle()); let handler = Handler::new(stream, app_handle.app_handle()).await;
rt::spawn(handler.handle()); rt::spawn(handler.handle());
}, },
Err(e) => { Err(e) => {

View File

@ -1,5 +1,4 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use std::time::{ use std::time::{
Duration, Duration,
SystemTime, SystemTime,
@ -12,8 +11,11 @@ use aws_smithy_types::date_time::{
Format as AwsDateTimeFormat, Format as AwsDateTimeFormat,
}; };
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use tokio::sync::oneshot::Sender; use tokio::{
use tokio::time::sleep; sync::oneshot::Sender,
sync::RwLock,
time::sleep,
};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use sodiumoxide::crypto::{ use sodiumoxide::crypto::{
pwhash, pwhash,
@ -163,51 +165,49 @@ impl AppState {
} }
pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> { pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> {
let mut live_config = self.config.write().await;
if new_config.start_on_login != live_config.start_on_login {
config::set_auto_launch(new_config.start_on_login)?;
}
if new_config.listen_addr != live_config.listen_addr
|| new_config.listen_port != live_config.listen_port
{ {
let orig_config = self.config.read().unwrap(); let mut sv = self.server.write().await;
if new_config.start_on_login != orig_config.start_on_login { sv.rebind(new_config.listen_addr, new_config.listen_port).await?;
config::set_auto_launch(new_config.start_on_login)?;
}
if new_config.listen_addr != orig_config.listen_addr
|| new_config.listen_port != orig_config.listen_port
{
let mut sv = self.server.write().unwrap();
sv.rebind(new_config.listen_addr, new_config.listen_port)?;
}
} }
new_config.save(&self.pool).await?; new_config.save(&self.pool).await?;
let mut live_config = self.config.write().unwrap();
*live_config = new_config; *live_config = new_config;
Ok(()) Ok(())
} }
pub fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 { pub async fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 {
let count = { let count = {
let mut c = self.request_count.write().unwrap(); let mut c = self.request_count.write().await;
*c += 1; *c += 1;
c c
}; };
let mut open_requests = self.open_requests.write().unwrap(); let mut open_requests = self.open_requests.write().await;
open_requests.insert(*count, chan); // `count` is the request id open_requests.insert(*count, chan); // `count` is the request id
*count *count
} }
pub fn unregister_request(&self, id: u64) { pub async fn unregister_request(&self, id: u64) {
let mut open_requests = self.open_requests.write().unwrap(); let mut open_requests = self.open_requests.write().await;
open_requests.remove(&id); open_requests.remove(&id);
} }
pub fn req_count(&self) -> usize { pub async fn req_count(&self) -> usize {
let open_requests = self.open_requests.read().unwrap(); let open_requests = self.open_requests.read().await;
open_requests.len() open_requests.len()
} }
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
self.renew_session_if_expired().await?; self.renew_session_if_expired().await?;
let mut open_requests = self.open_requests.write().unwrap(); let mut open_requests = self.open_requests.write().await;
let chan = open_requests let chan = open_requests
.remove(&response.id) .remove(&response.id)
.ok_or(SendResponseError::NotFound) .ok_or(SendResponseError::NotFound)
@ -217,54 +217,58 @@ impl AppState {
.map_err(|_e| SendResponseError::Abandoned) .map_err(|_e| SendResponseError::Abandoned)
} }
pub fn add_ban(&self, client: Option<Client>, app: tauri::AppHandle) { pub async fn add_ban(&self, client: Option<Client>) {
let mut bans = self.bans.write().unwrap(); let mut bans = self.bans.write().await;
bans.insert(client.clone()); bans.insert(client.clone());
runtime::spawn(async move { runtime::spawn(async move {
sleep(Duration::from_secs(5)).await; sleep(Duration::from_secs(5)).await;
let app = crate::APP.get().unwrap();
let state = app.state::<AppState>(); let state = app.state::<AppState>();
let mut bans = state.bans.write().unwrap(); let mut bans = state.bans.write().await;
bans.remove(&client); bans.remove(&client);
}); });
} }
pub fn is_banned(&self, client: &Option<Client>) -> bool { pub async fn is_banned(&self, client: &Option<Client>) -> bool {
self.bans.read().unwrap().contains(&client) self.bans.read().await.contains(&client)
} }
pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> { pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> {
let (access_key_id, secret_access_key) = { let mut session = self.session.write().await;
// do this all in a block so that we aren't holding a lock across an await let LockedCredentials {
let session = self.session.read().unwrap(); access_key_id,
let locked = match *session { secret_key_enc,
Session::Empty => {return Err(UnlockError::NoCredentials);}, salt,
Session::Unlocked{..} => {return Err(UnlockError::NotLocked);}, nonce
Session::Locked(ref c) => c, } = match *session {
}; Session::Empty => {return Err(UnlockError::NoCredentials);},
Session::Unlocked{..} => {return Err(UnlockError::NotLocked);},
let mut key_buf = [0; secretbox::KEYBYTES]; Session::Locked(ref c) => c,
// pretty sure this only fails if we're out of memory
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &locked.salt).unwrap();
let decrypted = secretbox::open(&locked.secret_key_enc, &locked.nonce, &Key(key_buf))
.map_err(|_e| UnlockError::BadPassphrase)?;
let secret_str = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
(locked.access_key_id.clone(), secret_str)
}; };
let session_creds = self.new_session(&access_key_id, &secret_access_key).await?; let mut key_buf = [0; secretbox::KEYBYTES];
let mut app_session = self.session.write().unwrap(); // pretty sure this only fails if we're out of memory
*app_session = Session::Unlocked { pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), salt).unwrap();
base: BaseCredentials {access_key_id, secret_access_key}, let decrypted = secretbox::open(secret_key_enc, nonce, &Key(key_buf))
.map_err(|_e| UnlockError::BadPassphrase)?;
let secret_access_key = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
let session_creds = self.new_session(access_key_id, &secret_access_key).await?;
*session = Session::Unlocked {
base: BaseCredentials {
access_key_id: access_key_id.clone(),
secret_access_key,
},
session: session_creds session: session_creds
}; };
Ok(()) Ok(())
} }
// pub fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> { // pub async fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
// let session = self.session.read().unwrap(); // let session = self.session.read().await;
// match *session { // match *session {
// Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()), // Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()),
// Session::Locked(_) => Err(GetCredentialsError::Locked), // Session::Locked(_) => Err(GetCredentialsError::Locked),
@ -272,8 +276,8 @@ impl AppState {
// } // }
// } // }
pub fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> { pub async fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
let session = self.session.read().unwrap(); let session = self.session.read().await;
match *session { match *session {
Session::Unlocked{ref session, ..} => Ok(serde_json::to_string(session).unwrap()), Session::Unlocked{ref session, ..} => Ok(serde_json::to_string(session).unwrap()),
Session::Locked(_) => Err(GetCredentialsError::Locked), Session::Locked(_) => Err(GetCredentialsError::Locked),
@ -329,41 +333,21 @@ impl AppState {
} }
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> { pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
let base = { match *self.session.write().await {
let session = self.session.read().unwrap(); Session::Unlocked{ref base, ref mut session} => {
match *session {
Session::Unlocked{ref base, ..} => base.clone(),
_ => unreachable!(),
}
};
let new_session = self.new_session(&base.access_key_id, &base.secret_access_key).await?;
match *self.session.write().unwrap() {
Session::Unlocked{ref mut session, ..} => {
if !session.is_expired() { if !session.is_expired() {
return Ok(false); return Ok(false);
} }
let new_session = self.new_session(
&base.access_key_id,
&base.secret_access_key
).await?;
*session = new_session; *session = new_session;
Ok(true) Ok(true)
}, },
Session::Locked(_) => Err(GetSessionError::CredentialsLocked), Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
Session::Empty => Err(GetSessionError::CredentialsEmpty), Session::Empty => Err(GetSessionError::CredentialsEmpty),
} }
// match *self.session.write().unwrap() {
// Session::Unlocked{ref base, ref mut session} => {
// if !session.is_expired() {
// return Ok(false);
// }
// let new_session = self.new_session(
// &base.access_key_id,
// &base.secret_access_key
// ).await?;
// *session = new_session;
// Ok(true)
// },
// Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
// Session::Empty => Err(GetSessionError::CredentialsEmpty),
// }
} }
} }