switch to tokio RwLock instead of std

This commit is contained in:
Joseph Montanaro 2023-05-02 15:24:35 -07:00
parent 96bbc2dbc2
commit ddf865d0b4
5 changed files with 121 additions and 157 deletions

View File

@ -1,9 +1,13 @@
use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo};
use tauri::Manager;
use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt};
use serde::{Serialize, Deserialize};
use crate::errors::*;
use crate::get_state;
use crate::{
errors::*,
config::AppConfig,
state::AppState,
};
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
@ -13,13 +17,18 @@ pub struct Client {
}
fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
let state = crate::APP.get().unwrap().state::<AppState>();
let AppConfig {
listen_addr: app_listen_addr,
listen_port: app_listen_port,
..
} = *state.config.read().await;
let sockets_iter = netstat2::iterate_sockets_info(
AddressFamilyFlags::IPV4,
ProtocolFlags::TCP
)?;
get_state!(config as app_config);
for item in sockets_iter {
let sock_info = item?;
let proto_info = match sock_info.protocol_socket_info {
@ -28,9 +37,9 @@ fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Err
};
if proto_info.local_port == local_port
&& proto_info.remote_port == app_config.listen_port
&& proto_info.local_addr == app_config.listen_addr
&& proto_info.remote_addr == app_config.listen_addr
&& proto_info.remote_port == app_listen_port
&& proto_info.local_addr == app_listen_addr
&& proto_info.remote_addr == app_listen_addr
{
return Ok(sock_info.associated_pids)
}
@ -40,10 +49,10 @@ fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Err
// Theoretically, on some systems, multiple processes can share a socket
pub fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> {
pub async fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> {
let mut clients = Vec::new();
let mut sys = System::new();
for p in get_associated_pids(local_port)? {
for p in get_associated_pids(local_port).await? {
let pid = Pid::from_u32(p);
sys.refresh_process(pid);
let proc = sys.process(pid)

View File

@ -41,13 +41,14 @@ pub async fn unlock(passphrase: String, app_state: State<'_, AppState>) -> Resul
#[tauri::command]
pub fn get_session_status(app_state: State<'_, AppState>) -> String {
let session = app_state.session.read().unwrap();
match *session {
pub async fn get_session_status(app_state: State<'_, AppState>) -> Result<String, ()> {
let session = app_state.session.read().await;
let status = match *session {
Session::Locked(_) => "locked".into(),
Session::Unlocked{..} => "unlocked".into(),
Session::Empty => "empty".into()
}
};
Ok(status)
}
@ -62,9 +63,9 @@ pub async fn save_credentials(
#[tauri::command]
pub fn get_config(app_state: State<'_, AppState>) -> AppConfig {
let config = app_state.config.read().unwrap();
config.clone()
pub async fn get_config(app_state: State<'_, AppState>) -> Result<AppConfig, ()> {
let config = app_state.config.read().await;
Ok(config.clone())
}

View File

@ -97,37 +97,3 @@ fn run() -> tauri::Result<()> {
fn main() {
run().error_popup("Creddy failed to start");
}
macro_rules! get_state {
($prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap(); // as long as the app is running, this is fine
let state = app.state::<crate::state::AppState>();
let $name = state.$prop.read().unwrap(); // only panics if another thread has already panicked
};
(config.$prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let config = state.config.read().unwrap();
let $name = config.$prop;
};
(mut $prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let $name = state.$prop.write().unwrap();
};
(mut config.$prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let config = state.config.write().unwrap();
let $name = config.$prop;
}
}
pub(crate) use get_state;

View File

@ -4,7 +4,6 @@ use std::net::{
Ipv4Addr,
SocketAddr,
SocketAddrV4,
TcpListener as StdTcpListener,
};
use tokio::net::{
TcpListener,
@ -32,10 +31,10 @@ struct Handler {
}
impl Handler {
fn new(stream: TcpStream, app: AppHandle) -> Self {
async fn new(stream: TcpStream, app: AppHandle) -> Self {
let state = app.state::<AppState>();
let (chan_send, chan_recv) = oneshot::channel();
let request_id = state.register_request(chan_send);
let request_id = state.register_request(chan_send).await;
Handler {
request_id,
stream,
@ -49,13 +48,13 @@ impl Handler {
eprintln!("{e}");
}
let state = self.app.state::<AppState>();
state.unregister_request(self.request_id);
state.unregister_request(self.request_id).await;
}
async fn try_handle(&mut self) -> Result<(), RequestError> {
let _ = self.recv_request().await?;
let clients = self.get_clients()?;
if self.includes_banned(&clients) {
let clients = self.get_clients().await?;
if self.includes_banned(&clients).await {
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
return Ok(())
}
@ -69,7 +68,7 @@ impl Handler {
Approval::Denied => {
let state = self.app.state::<AppState>();
for client in req.clients {
state.add_ban(client, self.app.clone());
state.add_ban(client).await;
}
}
}
@ -78,12 +77,12 @@ impl Handler {
// and b) there are no other pending requests
let state = self.app.state::<AppState>();
let delay = {
let config = state.config.read().unwrap();
let config = state.config.read().await;
Duration::from_millis(config.rehide_ms)
};
sleep(delay).await;
if !starting_visibility && state.req_count() == 0 {
if !starting_visibility && state.req_count().await == 0 {
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?;
window.hide()?;
}
@ -107,18 +106,23 @@ impl Handler {
Ok(buf)
}
fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> {
async fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> {
let peer_addr = match self.stream.peer_addr()? {
SocketAddr::V4(addr) => addr,
_ => unreachable!(), // we only listen on IPv4
};
let clients = clientinfo::get_clients(peer_addr.port())?;
let clients = clientinfo::get_clients(peer_addr.port()).await?;
Ok(clients)
}
fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
async fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
let state = self.app.state::<AppState>();
clients.iter().any(|c| state.is_banned(c))
for client in clients {
if state.is_banned(client).await {
return true;
}
}
false
}
fn show_window(&self) -> Result<bool, RequestError> {
@ -157,7 +161,7 @@ impl Handler {
async fn send_credentials(&mut self) -> Result<(), RequestError> {
let state = self.app.state::<AppState>();
let creds = state.serialize_session_creds()?;
let creds = state.serialize_session_creds().await?;
self.stream.write(b"\r\nContent-Length: ").await?;
self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?;
@ -180,39 +184,39 @@ pub struct Server {
impl Server {
pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> {
// construct the listener before passing it to the task so that we know if it fails
let sock_addr = SocketAddrV4::new(addr, port);
let listener = TcpListener::bind(&sock_addr).await?;
let task = rt::spawn(Self::serve(listener, app_handle.app_handle()));
let task = Self::start_server(addr, port, app_handle.app_handle()).await?;
Ok(Server { addr, port, app_handle, task})
}
// this is blocking because it's too much of a paint to juggle mutexes otherwise
pub fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
pub async fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
if addr == self.addr && port == self.port {
return Ok(())
}
let sock_addr = SocketAddrV4::new(addr, port);
let std_listener = StdTcpListener::bind(&sock_addr)?;
std_listener.set_nonblocking(true)?;
let async_listener = TcpListener::from_std(std_listener)?;
let new_task = rt::spawn(Self::serve(async_listener, self.app_handle.app_handle()));
let new_task = Self::start_server(addr, port, self.app_handle.app_handle()).await?;
self.task.abort();
self.addr = addr;
self.port = port;
self.task = new_task;
Ok(())
}
// construct the listener before spawning the task so that we can return early if it fails
async fn start_server(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<JoinHandle<()>> {
let sock_addr = SocketAddrV4::new(addr, port);
let listener = TcpListener::bind(&sock_addr).await?;
let task = rt::spawn(
Self::serve(listener, app_handle.app_handle())
);
Ok(task)
}
async fn serve(listener: TcpListener, app_handle: AppHandle) {
loop {
match listener.accept().await {
Ok((stream, _)) => {
let handler = Handler::new(stream, app_handle.app_handle());
let handler = Handler::new(stream, app_handle.app_handle()).await;
rt::spawn(handler.handle());
},
Err(e) => {

View File

@ -1,5 +1,4 @@
use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use std::time::{
Duration,
SystemTime,
@ -12,8 +11,11 @@ use aws_smithy_types::date_time::{
Format as AwsDateTimeFormat,
};
use serde::{Serialize, Deserialize};
use tokio::sync::oneshot::Sender;
use tokio::time::sleep;
use tokio::{
sync::oneshot::Sender,
sync::RwLock,
time::sleep,
};
use sqlx::SqlitePool;
use sodiumoxide::crypto::{
pwhash,
@ -163,51 +165,49 @@ impl AppState {
}
pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> {
{
let orig_config = self.config.read().unwrap();
if new_config.start_on_login != orig_config.start_on_login {
let mut live_config = self.config.write().await;
if new_config.start_on_login != live_config.start_on_login {
config::set_auto_launch(new_config.start_on_login)?;
}
if new_config.listen_addr != orig_config.listen_addr
|| new_config.listen_port != orig_config.listen_port
if new_config.listen_addr != live_config.listen_addr
|| new_config.listen_port != live_config.listen_port
{
let mut sv = self.server.write().unwrap();
sv.rebind(new_config.listen_addr, new_config.listen_port)?;
}
let mut sv = self.server.write().await;
sv.rebind(new_config.listen_addr, new_config.listen_port).await?;
}
new_config.save(&self.pool).await?;
let mut live_config = self.config.write().unwrap();
*live_config = new_config;
Ok(())
}
pub fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 {
pub async fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 {
let count = {
let mut c = self.request_count.write().unwrap();
let mut c = self.request_count.write().await;
*c += 1;
c
};
let mut open_requests = self.open_requests.write().unwrap();
let mut open_requests = self.open_requests.write().await;
open_requests.insert(*count, chan); // `count` is the request id
*count
}
pub fn unregister_request(&self, id: u64) {
let mut open_requests = self.open_requests.write().unwrap();
pub async fn unregister_request(&self, id: u64) {
let mut open_requests = self.open_requests.write().await;
open_requests.remove(&id);
}
pub fn req_count(&self) -> usize {
let open_requests = self.open_requests.read().unwrap();
pub async fn req_count(&self) -> usize {
let open_requests = self.open_requests.read().await;
open_requests.len()
}
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
self.renew_session_if_expired().await?;
let mut open_requests = self.open_requests.write().unwrap();
let mut open_requests = self.open_requests.write().await;
let chan = open_requests
.remove(&response.id)
.ok_or(SendResponseError::NotFound)
@ -217,27 +217,31 @@ impl AppState {
.map_err(|_e| SendResponseError::Abandoned)
}
pub fn add_ban(&self, client: Option<Client>, app: tauri::AppHandle) {
let mut bans = self.bans.write().unwrap();
pub async fn add_ban(&self, client: Option<Client>) {
let mut bans = self.bans.write().await;
bans.insert(client.clone());
runtime::spawn(async move {
sleep(Duration::from_secs(5)).await;
let app = crate::APP.get().unwrap();
let state = app.state::<AppState>();
let mut bans = state.bans.write().unwrap();
let mut bans = state.bans.write().await;
bans.remove(&client);
});
}
pub fn is_banned(&self, client: &Option<Client>) -> bool {
self.bans.read().unwrap().contains(&client)
pub async fn is_banned(&self, client: &Option<Client>) -> bool {
self.bans.read().await.contains(&client)
}
pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> {
let (access_key_id, secret_access_key) = {
// do this all in a block so that we aren't holding a lock across an await
let session = self.session.read().unwrap();
let locked = match *session {
let mut session = self.session.write().await;
let LockedCredentials {
access_key_id,
secret_key_enc,
salt,
nonce
} = match *session {
Session::Empty => {return Err(UnlockError::NoCredentials);},
Session::Unlocked{..} => {return Err(UnlockError::NotLocked);},
Session::Locked(ref c) => c,
@ -245,26 +249,26 @@ impl AppState {
let mut key_buf = [0; secretbox::KEYBYTES];
// pretty sure this only fails if we're out of memory
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &locked.salt).unwrap();
let decrypted = secretbox::open(&locked.secret_key_enc, &locked.nonce, &Key(key_buf))
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), salt).unwrap();
let decrypted = secretbox::open(secret_key_enc, nonce, &Key(key_buf))
.map_err(|_e| UnlockError::BadPassphrase)?;
let secret_str = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
(locked.access_key_id.clone(), secret_str)
};
let secret_access_key = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
let session_creds = self.new_session(&access_key_id, &secret_access_key).await?;
let mut app_session = self.session.write().unwrap();
*app_session = Session::Unlocked {
base: BaseCredentials {access_key_id, secret_access_key},
let session_creds = self.new_session(access_key_id, &secret_access_key).await?;
*session = Session::Unlocked {
base: BaseCredentials {
access_key_id: access_key_id.clone(),
secret_access_key,
},
session: session_creds
};
Ok(())
}
// pub fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
// let session = self.session.read().unwrap();
// pub async fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
// let session = self.session.read().await;
// match *session {
// Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()),
// Session::Locked(_) => Err(GetCredentialsError::Locked),
@ -272,8 +276,8 @@ impl AppState {
// }
// }
pub fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
let session = self.session.read().unwrap();
pub async fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
let session = self.session.read().await;
match *session {
Session::Unlocked{ref session, ..} => Ok(serde_json::to_string(session).unwrap()),
Session::Locked(_) => Err(GetCredentialsError::Locked),
@ -329,41 +333,21 @@ impl AppState {
}
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
let base = {
let session = self.session.read().unwrap();
match *session {
Session::Unlocked{ref base, ..} => base.clone(),
_ => unreachable!(),
}
};
let new_session = self.new_session(&base.access_key_id, &base.secret_access_key).await?;
match *self.session.write().unwrap() {
Session::Unlocked{ref mut session, ..} => {
match *self.session.write().await {
Session::Unlocked{ref base, ref mut session} => {
if !session.is_expired() {
return Ok(false);
}
let new_session = self.new_session(
&base.access_key_id,
&base.secret_access_key
).await?;
*session = new_session;
Ok(true)
},
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
Session::Empty => Err(GetSessionError::CredentialsEmpty),
}
// match *self.session.write().unwrap() {
// Session::Unlocked{ref base, ref mut session} => {
// if !session.is_expired() {
// return Ok(false);
// }
// let new_session = self.new_session(
// &base.access_key_id,
// &base.secret_access_key
// ).await?;
// *session = new_session;
// Ok(true)
// },
// Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
// Session::Empty => Err(GetSessionError::CredentialsEmpty),
// }
}
}