switch to tokio RwLock instead of std
This commit is contained in:
parent
96bbc2dbc2
commit
ddf865d0b4
@ -1,9 +1,13 @@
|
||||
use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo};
|
||||
use tauri::Manager;
|
||||
use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt};
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::errors::*;
|
||||
use crate::get_state;
|
||||
use crate::{
|
||||
errors::*,
|
||||
config::AppConfig,
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
@ -13,13 +17,18 @@ pub struct Client {
|
||||
}
|
||||
|
||||
|
||||
fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
|
||||
async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
|
||||
let state = crate::APP.get().unwrap().state::<AppState>();
|
||||
let AppConfig {
|
||||
listen_addr: app_listen_addr,
|
||||
listen_port: app_listen_port,
|
||||
..
|
||||
} = *state.config.read().await;
|
||||
|
||||
let sockets_iter = netstat2::iterate_sockets_info(
|
||||
AddressFamilyFlags::IPV4,
|
||||
ProtocolFlags::TCP
|
||||
)?;
|
||||
|
||||
get_state!(config as app_config);
|
||||
for item in sockets_iter {
|
||||
let sock_info = item?;
|
||||
let proto_info = match sock_info.protocol_socket_info {
|
||||
@ -28,9 +37,9 @@ fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Err
|
||||
};
|
||||
|
||||
if proto_info.local_port == local_port
|
||||
&& proto_info.remote_port == app_config.listen_port
|
||||
&& proto_info.local_addr == app_config.listen_addr
|
||||
&& proto_info.remote_addr == app_config.listen_addr
|
||||
&& proto_info.remote_port == app_listen_port
|
||||
&& proto_info.local_addr == app_listen_addr
|
||||
&& proto_info.remote_addr == app_listen_addr
|
||||
{
|
||||
return Ok(sock_info.associated_pids)
|
||||
}
|
||||
@ -40,10 +49,10 @@ fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Err
|
||||
|
||||
|
||||
// Theoretically, on some systems, multiple processes can share a socket
|
||||
pub fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> {
|
||||
pub async fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> {
|
||||
let mut clients = Vec::new();
|
||||
let mut sys = System::new();
|
||||
for p in get_associated_pids(local_port)? {
|
||||
for p in get_associated_pids(local_port).await? {
|
||||
let pid = Pid::from_u32(p);
|
||||
sys.refresh_process(pid);
|
||||
let proc = sys.process(pid)
|
||||
|
@ -41,13 +41,14 @@ pub async fn unlock(passphrase: String, app_state: State<'_, AppState>) -> Resul
|
||||
|
||||
|
||||
#[tauri::command]
|
||||
pub fn get_session_status(app_state: State<'_, AppState>) -> String {
|
||||
let session = app_state.session.read().unwrap();
|
||||
match *session {
|
||||
pub async fn get_session_status(app_state: State<'_, AppState>) -> Result<String, ()> {
|
||||
let session = app_state.session.read().await;
|
||||
let status = match *session {
|
||||
Session::Locked(_) => "locked".into(),
|
||||
Session::Unlocked{..} => "unlocked".into(),
|
||||
Session::Empty => "empty".into()
|
||||
}
|
||||
};
|
||||
Ok(status)
|
||||
}
|
||||
|
||||
|
||||
@ -62,9 +63,9 @@ pub async fn save_credentials(
|
||||
|
||||
|
||||
#[tauri::command]
|
||||
pub fn get_config(app_state: State<'_, AppState>) -> AppConfig {
|
||||
let config = app_state.config.read().unwrap();
|
||||
config.clone()
|
||||
pub async fn get_config(app_state: State<'_, AppState>) -> Result<AppConfig, ()> {
|
||||
let config = app_state.config.read().await;
|
||||
Ok(config.clone())
|
||||
}
|
||||
|
||||
|
||||
|
@ -97,37 +97,3 @@ fn run() -> tauri::Result<()> {
|
||||
fn main() {
|
||||
run().error_popup("Creddy failed to start");
|
||||
}
|
||||
|
||||
|
||||
macro_rules! get_state {
|
||||
($prop:ident as $name:ident) => {
|
||||
use tauri::Manager;
|
||||
let app = crate::APP.get().unwrap(); // as long as the app is running, this is fine
|
||||
let state = app.state::<crate::state::AppState>();
|
||||
let $name = state.$prop.read().unwrap(); // only panics if another thread has already panicked
|
||||
};
|
||||
(config.$prop:ident as $name:ident) => {
|
||||
use tauri::Manager;
|
||||
let app = crate::APP.get().unwrap();
|
||||
let state = app.state::<crate::state::AppState>();
|
||||
let config = state.config.read().unwrap();
|
||||
let $name = config.$prop;
|
||||
};
|
||||
|
||||
(mut $prop:ident as $name:ident) => {
|
||||
use tauri::Manager;
|
||||
let app = crate::APP.get().unwrap();
|
||||
let state = app.state::<crate::state::AppState>();
|
||||
let $name = state.$prop.write().unwrap();
|
||||
};
|
||||
(mut config.$prop:ident as $name:ident) => {
|
||||
use tauri::Manager;
|
||||
let app = crate::APP.get().unwrap();
|
||||
let state = app.state::<crate::state::AppState>();
|
||||
let config = state.config.write().unwrap();
|
||||
let $name = config.$prop;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub(crate) use get_state;
|
||||
|
@ -4,7 +4,6 @@ use std::net::{
|
||||
Ipv4Addr,
|
||||
SocketAddr,
|
||||
SocketAddrV4,
|
||||
TcpListener as StdTcpListener,
|
||||
};
|
||||
use tokio::net::{
|
||||
TcpListener,
|
||||
@ -32,10 +31,10 @@ struct Handler {
|
||||
}
|
||||
|
||||
impl Handler {
|
||||
fn new(stream: TcpStream, app: AppHandle) -> Self {
|
||||
async fn new(stream: TcpStream, app: AppHandle) -> Self {
|
||||
let state = app.state::<AppState>();
|
||||
let (chan_send, chan_recv) = oneshot::channel();
|
||||
let request_id = state.register_request(chan_send);
|
||||
let request_id = state.register_request(chan_send).await;
|
||||
Handler {
|
||||
request_id,
|
||||
stream,
|
||||
@ -49,13 +48,13 @@ impl Handler {
|
||||
eprintln!("{e}");
|
||||
}
|
||||
let state = self.app.state::<AppState>();
|
||||
state.unregister_request(self.request_id);
|
||||
state.unregister_request(self.request_id).await;
|
||||
}
|
||||
|
||||
async fn try_handle(&mut self) -> Result<(), RequestError> {
|
||||
let _ = self.recv_request().await?;
|
||||
let clients = self.get_clients()?;
|
||||
if self.includes_banned(&clients) {
|
||||
let clients = self.get_clients().await?;
|
||||
if self.includes_banned(&clients).await {
|
||||
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
|
||||
return Ok(())
|
||||
}
|
||||
@ -69,7 +68,7 @@ impl Handler {
|
||||
Approval::Denied => {
|
||||
let state = self.app.state::<AppState>();
|
||||
for client in req.clients {
|
||||
state.add_ban(client, self.app.clone());
|
||||
state.add_ban(client).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -78,12 +77,12 @@ impl Handler {
|
||||
// and b) there are no other pending requests
|
||||
let state = self.app.state::<AppState>();
|
||||
let delay = {
|
||||
let config = state.config.read().unwrap();
|
||||
let config = state.config.read().await;
|
||||
Duration::from_millis(config.rehide_ms)
|
||||
};
|
||||
sleep(delay).await;
|
||||
|
||||
if !starting_visibility && state.req_count() == 0 {
|
||||
if !starting_visibility && state.req_count().await == 0 {
|
||||
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?;
|
||||
window.hide()?;
|
||||
}
|
||||
@ -107,18 +106,23 @@ impl Handler {
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> {
|
||||
async fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> {
|
||||
let peer_addr = match self.stream.peer_addr()? {
|
||||
SocketAddr::V4(addr) => addr,
|
||||
_ => unreachable!(), // we only listen on IPv4
|
||||
};
|
||||
let clients = clientinfo::get_clients(peer_addr.port())?;
|
||||
let clients = clientinfo::get_clients(peer_addr.port()).await?;
|
||||
Ok(clients)
|
||||
}
|
||||
|
||||
fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
|
||||
async fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
|
||||
let state = self.app.state::<AppState>();
|
||||
clients.iter().any(|c| state.is_banned(c))
|
||||
for client in clients {
|
||||
if state.is_banned(client).await {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn show_window(&self) -> Result<bool, RequestError> {
|
||||
@ -157,7 +161,7 @@ impl Handler {
|
||||
|
||||
async fn send_credentials(&mut self) -> Result<(), RequestError> {
|
||||
let state = self.app.state::<AppState>();
|
||||
let creds = state.serialize_session_creds()?;
|
||||
let creds = state.serialize_session_creds().await?;
|
||||
|
||||
self.stream.write(b"\r\nContent-Length: ").await?;
|
||||
self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?;
|
||||
@ -180,39 +184,39 @@ pub struct Server {
|
||||
|
||||
impl Server {
|
||||
pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> {
|
||||
// construct the listener before passing it to the task so that we know if it fails
|
||||
let sock_addr = SocketAddrV4::new(addr, port);
|
||||
let listener = TcpListener::bind(&sock_addr).await?;
|
||||
let task = rt::spawn(Self::serve(listener, app_handle.app_handle()));
|
||||
|
||||
let task = Self::start_server(addr, port, app_handle.app_handle()).await?;
|
||||
Ok(Server { addr, port, app_handle, task})
|
||||
}
|
||||
|
||||
// this is blocking because it's too much of a paint to juggle mutexes otherwise
|
||||
pub fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
|
||||
pub async fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
|
||||
if addr == self.addr && port == self.port {
|
||||
return Ok(())
|
||||
}
|
||||
|
||||
let sock_addr = SocketAddrV4::new(addr, port);
|
||||
let std_listener = StdTcpListener::bind(&sock_addr)?;
|
||||
std_listener.set_nonblocking(true)?;
|
||||
let async_listener = TcpListener::from_std(std_listener)?;
|
||||
let new_task = rt::spawn(Self::serve(async_listener, self.app_handle.app_handle()));
|
||||
let new_task = Self::start_server(addr, port, self.app_handle.app_handle()).await?;
|
||||
self.task.abort();
|
||||
|
||||
self.addr = addr;
|
||||
self.port = port;
|
||||
self.task = new_task;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// construct the listener before spawning the task so that we can return early if it fails
|
||||
async fn start_server(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<JoinHandle<()>> {
|
||||
let sock_addr = SocketAddrV4::new(addr, port);
|
||||
let listener = TcpListener::bind(&sock_addr).await?;
|
||||
let task = rt::spawn(
|
||||
Self::serve(listener, app_handle.app_handle())
|
||||
);
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
async fn serve(listener: TcpListener, app_handle: AppHandle) {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, _)) => {
|
||||
let handler = Handler::new(stream, app_handle.app_handle());
|
||||
let handler = Handler::new(stream, app_handle.app_handle()).await;
|
||||
rt::spawn(handler.handle());
|
||||
},
|
||||
Err(e) => {
|
||||
|
@ -1,5 +1,4 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::RwLock;
|
||||
use std::time::{
|
||||
Duration,
|
||||
SystemTime,
|
||||
@ -12,8 +11,11 @@ use aws_smithy_types::date_time::{
|
||||
Format as AwsDateTimeFormat,
|
||||
};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use tokio::sync::oneshot::Sender;
|
||||
use tokio::time::sleep;
|
||||
use tokio::{
|
||||
sync::oneshot::Sender,
|
||||
sync::RwLock,
|
||||
time::sleep,
|
||||
};
|
||||
use sqlx::SqlitePool;
|
||||
use sodiumoxide::crypto::{
|
||||
pwhash,
|
||||
@ -163,51 +165,49 @@ impl AppState {
|
||||
}
|
||||
|
||||
pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> {
|
||||
{
|
||||
let orig_config = self.config.read().unwrap();
|
||||
if new_config.start_on_login != orig_config.start_on_login {
|
||||
let mut live_config = self.config.write().await;
|
||||
|
||||
if new_config.start_on_login != live_config.start_on_login {
|
||||
config::set_auto_launch(new_config.start_on_login)?;
|
||||
}
|
||||
if new_config.listen_addr != orig_config.listen_addr
|
||||
|| new_config.listen_port != orig_config.listen_port
|
||||
if new_config.listen_addr != live_config.listen_addr
|
||||
|| new_config.listen_port != live_config.listen_port
|
||||
{
|
||||
let mut sv = self.server.write().unwrap();
|
||||
sv.rebind(new_config.listen_addr, new_config.listen_port)?;
|
||||
}
|
||||
let mut sv = self.server.write().await;
|
||||
sv.rebind(new_config.listen_addr, new_config.listen_port).await?;
|
||||
}
|
||||
|
||||
new_config.save(&self.pool).await?;
|
||||
let mut live_config = self.config.write().unwrap();
|
||||
*live_config = new_config;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 {
|
||||
pub async fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 {
|
||||
let count = {
|
||||
let mut c = self.request_count.write().unwrap();
|
||||
let mut c = self.request_count.write().await;
|
||||
*c += 1;
|
||||
c
|
||||
};
|
||||
|
||||
let mut open_requests = self.open_requests.write().unwrap();
|
||||
let mut open_requests = self.open_requests.write().await;
|
||||
open_requests.insert(*count, chan); // `count` is the request id
|
||||
*count
|
||||
}
|
||||
|
||||
pub fn unregister_request(&self, id: u64) {
|
||||
let mut open_requests = self.open_requests.write().unwrap();
|
||||
pub async fn unregister_request(&self, id: u64) {
|
||||
let mut open_requests = self.open_requests.write().await;
|
||||
open_requests.remove(&id);
|
||||
}
|
||||
|
||||
pub fn req_count(&self) -> usize {
|
||||
let open_requests = self.open_requests.read().unwrap();
|
||||
pub async fn req_count(&self) -> usize {
|
||||
let open_requests = self.open_requests.read().await;
|
||||
open_requests.len()
|
||||
}
|
||||
|
||||
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
|
||||
self.renew_session_if_expired().await?;
|
||||
|
||||
let mut open_requests = self.open_requests.write().unwrap();
|
||||
let mut open_requests = self.open_requests.write().await;
|
||||
let chan = open_requests
|
||||
.remove(&response.id)
|
||||
.ok_or(SendResponseError::NotFound)
|
||||
@ -217,27 +217,31 @@ impl AppState {
|
||||
.map_err(|_e| SendResponseError::Abandoned)
|
||||
}
|
||||
|
||||
pub fn add_ban(&self, client: Option<Client>, app: tauri::AppHandle) {
|
||||
let mut bans = self.bans.write().unwrap();
|
||||
pub async fn add_ban(&self, client: Option<Client>) {
|
||||
let mut bans = self.bans.write().await;
|
||||
bans.insert(client.clone());
|
||||
|
||||
runtime::spawn(async move {
|
||||
sleep(Duration::from_secs(5)).await;
|
||||
let app = crate::APP.get().unwrap();
|
||||
let state = app.state::<AppState>();
|
||||
let mut bans = state.bans.write().unwrap();
|
||||
let mut bans = state.bans.write().await;
|
||||
bans.remove(&client);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn is_banned(&self, client: &Option<Client>) -> bool {
|
||||
self.bans.read().unwrap().contains(&client)
|
||||
pub async fn is_banned(&self, client: &Option<Client>) -> bool {
|
||||
self.bans.read().await.contains(&client)
|
||||
}
|
||||
|
||||
pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> {
|
||||
let (access_key_id, secret_access_key) = {
|
||||
// do this all in a block so that we aren't holding a lock across an await
|
||||
let session = self.session.read().unwrap();
|
||||
let locked = match *session {
|
||||
let mut session = self.session.write().await;
|
||||
let LockedCredentials {
|
||||
access_key_id,
|
||||
secret_key_enc,
|
||||
salt,
|
||||
nonce
|
||||
} = match *session {
|
||||
Session::Empty => {return Err(UnlockError::NoCredentials);},
|
||||
Session::Unlocked{..} => {return Err(UnlockError::NotLocked);},
|
||||
Session::Locked(ref c) => c,
|
||||
@ -245,26 +249,26 @@ impl AppState {
|
||||
|
||||
let mut key_buf = [0; secretbox::KEYBYTES];
|
||||
// pretty sure this only fails if we're out of memory
|
||||
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &locked.salt).unwrap();
|
||||
let decrypted = secretbox::open(&locked.secret_key_enc, &locked.nonce, &Key(key_buf))
|
||||
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), salt).unwrap();
|
||||
let decrypted = secretbox::open(secret_key_enc, nonce, &Key(key_buf))
|
||||
.map_err(|_e| UnlockError::BadPassphrase)?;
|
||||
|
||||
let secret_str = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
|
||||
(locked.access_key_id.clone(), secret_str)
|
||||
};
|
||||
let secret_access_key = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
|
||||
|
||||
let session_creds = self.new_session(&access_key_id, &secret_access_key).await?;
|
||||
let mut app_session = self.session.write().unwrap();
|
||||
*app_session = Session::Unlocked {
|
||||
base: BaseCredentials {access_key_id, secret_access_key},
|
||||
let session_creds = self.new_session(access_key_id, &secret_access_key).await?;
|
||||
*session = Session::Unlocked {
|
||||
base: BaseCredentials {
|
||||
access_key_id: access_key_id.clone(),
|
||||
secret_access_key,
|
||||
},
|
||||
session: session_creds
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// pub fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
|
||||
// let session = self.session.read().unwrap();
|
||||
// pub async fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
|
||||
// let session = self.session.read().await;
|
||||
// match *session {
|
||||
// Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()),
|
||||
// Session::Locked(_) => Err(GetCredentialsError::Locked),
|
||||
@ -272,8 +276,8 @@ impl AppState {
|
||||
// }
|
||||
// }
|
||||
|
||||
pub fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
|
||||
let session = self.session.read().unwrap();
|
||||
pub async fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
|
||||
let session = self.session.read().await;
|
||||
match *session {
|
||||
Session::Unlocked{ref session, ..} => Ok(serde_json::to_string(session).unwrap()),
|
||||
Session::Locked(_) => Err(GetCredentialsError::Locked),
|
||||
@ -329,41 +333,21 @@ impl AppState {
|
||||
}
|
||||
|
||||
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
|
||||
let base = {
|
||||
let session = self.session.read().unwrap();
|
||||
match *session {
|
||||
Session::Unlocked{ref base, ..} => base.clone(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
};
|
||||
let new_session = self.new_session(&base.access_key_id, &base.secret_access_key).await?;
|
||||
match *self.session.write().unwrap() {
|
||||
Session::Unlocked{ref mut session, ..} => {
|
||||
match *self.session.write().await {
|
||||
Session::Unlocked{ref base, ref mut session} => {
|
||||
if !session.is_expired() {
|
||||
return Ok(false);
|
||||
}
|
||||
let new_session = self.new_session(
|
||||
&base.access_key_id,
|
||||
&base.secret_access_key
|
||||
).await?;
|
||||
*session = new_session;
|
||||
Ok(true)
|
||||
},
|
||||
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
|
||||
Session::Empty => Err(GetSessionError::CredentialsEmpty),
|
||||
}
|
||||
|
||||
// match *self.session.write().unwrap() {
|
||||
// Session::Unlocked{ref base, ref mut session} => {
|
||||
// if !session.is_expired() {
|
||||
// return Ok(false);
|
||||
// }
|
||||
// let new_session = self.new_session(
|
||||
// &base.access_key_id,
|
||||
// &base.secret_access_key
|
||||
// ).await?;
|
||||
// *session = new_session;
|
||||
// Ok(true)
|
||||
// },
|
||||
// Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
|
||||
// Session::Empty => Err(GetSessionError::CredentialsEmpty),
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user