Compare commits

..

No commits in common. "e8b8dc29763dc790ff74d9164b6d3b259d44420d" and "96bbc2dbc22bc948571112981dbe210257869edc" have entirely different histories.

6 changed files with 800 additions and 782 deletions

1304
src-tauri/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,13 +1,9 @@
use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo}; use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo};
use tauri::Manager;
use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt}; use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::{ use crate::errors::*;
errors::*, use crate::get_state;
config::AppConfig,
state::AppState,
};
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)] #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
@ -17,18 +13,13 @@ pub struct Client {
} }
async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> { fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
let state = crate::APP.get().unwrap().state::<AppState>();
let AppConfig {
listen_addr: app_listen_addr,
listen_port: app_listen_port,
..
} = *state.config.read().await;
let sockets_iter = netstat2::iterate_sockets_info( let sockets_iter = netstat2::iterate_sockets_info(
AddressFamilyFlags::IPV4, AddressFamilyFlags::IPV4,
ProtocolFlags::TCP ProtocolFlags::TCP
)?; )?;
get_state!(config as app_config);
for item in sockets_iter { for item in sockets_iter {
let sock_info = item?; let sock_info = item?;
let proto_info = match sock_info.protocol_socket_info { let proto_info = match sock_info.protocol_socket_info {
@ -37,9 +28,9 @@ async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::erro
}; };
if proto_info.local_port == local_port if proto_info.local_port == local_port
&& proto_info.remote_port == app_listen_port && proto_info.remote_port == app_config.listen_port
&& proto_info.local_addr == app_listen_addr && proto_info.local_addr == app_config.listen_addr
&& proto_info.remote_addr == app_listen_addr && proto_info.remote_addr == app_config.listen_addr
{ {
return Ok(sock_info.associated_pids) return Ok(sock_info.associated_pids)
} }
@ -49,10 +40,10 @@ async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::erro
// Theoretically, on some systems, multiple processes can share a socket // Theoretically, on some systems, multiple processes can share a socket
pub async fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> { pub fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> {
let mut clients = Vec::new(); let mut clients = Vec::new();
let mut sys = System::new(); let mut sys = System::new();
for p in get_associated_pids(local_port).await? { for p in get_associated_pids(local_port)? {
let pid = Pid::from_u32(p); let pid = Pid::from_u32(p);
sys.refresh_process(pid); sys.refresh_process(pid);
let proc = sys.process(pid) let proc = sys.process(pid)

View File

@ -41,14 +41,13 @@ pub async fn unlock(passphrase: String, app_state: State<'_, AppState>) -> Resul
#[tauri::command] #[tauri::command]
pub async fn get_session_status(app_state: State<'_, AppState>) -> Result<String, ()> { pub fn get_session_status(app_state: State<'_, AppState>) -> String {
let session = app_state.session.read().await; let session = app_state.session.read().unwrap();
let status = match *session { match *session {
Session::Locked(_) => "locked".into(), Session::Locked(_) => "locked".into(),
Session::Unlocked{..} => "unlocked".into(), Session::Unlocked{..} => "unlocked".into(),
Session::Empty => "empty".into() Session::Empty => "empty".into()
}; }
Ok(status)
} }
@ -63,9 +62,9 @@ pub async fn save_credentials(
#[tauri::command] #[tauri::command]
pub async fn get_config(app_state: State<'_, AppState>) -> Result<AppConfig, ()> { pub fn get_config(app_state: State<'_, AppState>) -> AppConfig {
let config = app_state.config.read().await; let config = app_state.config.read().unwrap();
Ok(config.clone()) config.clone()
} }

View File

@ -97,3 +97,37 @@ fn run() -> tauri::Result<()> {
fn main() { fn main() {
run().error_popup("Creddy failed to start"); run().error_popup("Creddy failed to start");
} }
macro_rules! get_state {
($prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap(); // as long as the app is running, this is fine
let state = app.state::<crate::state::AppState>();
let $name = state.$prop.read().unwrap(); // only panics if another thread has already panicked
};
(config.$prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let config = state.config.read().unwrap();
let $name = config.$prop;
};
(mut $prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let $name = state.$prop.write().unwrap();
};
(mut config.$prop:ident as $name:ident) => {
use tauri::Manager;
let app = crate::APP.get().unwrap();
let state = app.state::<crate::state::AppState>();
let config = state.config.write().unwrap();
let $name = config.$prop;
}
}
pub(crate) use get_state;

View File

@ -4,6 +4,7 @@ use std::net::{
Ipv4Addr, Ipv4Addr,
SocketAddr, SocketAddr,
SocketAddrV4, SocketAddrV4,
TcpListener as StdTcpListener,
}; };
use tokio::net::{ use tokio::net::{
TcpListener, TcpListener,
@ -31,10 +32,10 @@ struct Handler {
} }
impl Handler { impl Handler {
async fn new(stream: TcpStream, app: AppHandle) -> Self { fn new(stream: TcpStream, app: AppHandle) -> Self {
let state = app.state::<AppState>(); let state = app.state::<AppState>();
let (chan_send, chan_recv) = oneshot::channel(); let (chan_send, chan_recv) = oneshot::channel();
let request_id = state.register_request(chan_send).await; let request_id = state.register_request(chan_send);
Handler { Handler {
request_id, request_id,
stream, stream,
@ -48,13 +49,13 @@ impl Handler {
eprintln!("{e}"); eprintln!("{e}");
} }
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
state.unregister_request(self.request_id).await; state.unregister_request(self.request_id);
} }
async fn try_handle(&mut self) -> Result<(), RequestError> { async fn try_handle(&mut self) -> Result<(), RequestError> {
let _ = self.recv_request().await?; let _ = self.recv_request().await?;
let clients = self.get_clients().await?; let clients = self.get_clients()?;
if self.includes_banned(&clients).await { if self.includes_banned(&clients) {
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?; self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
return Ok(()) return Ok(())
} }
@ -68,7 +69,7 @@ impl Handler {
Approval::Denied => { Approval::Denied => {
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
for client in req.clients { for client in req.clients {
state.add_ban(client).await; state.add_ban(client, self.app.clone());
} }
} }
} }
@ -77,12 +78,12 @@ impl Handler {
// and b) there are no other pending requests // and b) there are no other pending requests
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
let delay = { let delay = {
let config = state.config.read().await; let config = state.config.read().unwrap();
Duration::from_millis(config.rehide_ms) Duration::from_millis(config.rehide_ms)
}; };
sleep(delay).await; sleep(delay).await;
if !starting_visibility && state.req_count().await == 0 { if !starting_visibility && state.req_count() == 0 {
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?; let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?;
window.hide()?; window.hide()?;
} }
@ -106,23 +107,18 @@ impl Handler {
Ok(buf) Ok(buf)
} }
async fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> { fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> {
let peer_addr = match self.stream.peer_addr()? { let peer_addr = match self.stream.peer_addr()? {
SocketAddr::V4(addr) => addr, SocketAddr::V4(addr) => addr,
_ => unreachable!(), // we only listen on IPv4 _ => unreachable!(), // we only listen on IPv4
}; };
let clients = clientinfo::get_clients(peer_addr.port()).await?; let clients = clientinfo::get_clients(peer_addr.port())?;
Ok(clients) Ok(clients)
} }
async fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool { fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
for client in clients { clients.iter().any(|c| state.is_banned(c))
if state.is_banned(client).await {
return true;
}
}
false
} }
fn show_window(&self) -> Result<bool, RequestError> { fn show_window(&self) -> Result<bool, RequestError> {
@ -161,7 +157,7 @@ impl Handler {
async fn send_credentials(&mut self) -> Result<(), RequestError> { async fn send_credentials(&mut self) -> Result<(), RequestError> {
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
let creds = state.serialize_session_creds().await?; let creds = state.serialize_session_creds()?;
self.stream.write(b"\r\nContent-Length: ").await?; self.stream.write(b"\r\nContent-Length: ").await?;
self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?; self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?;
@ -184,39 +180,39 @@ pub struct Server {
impl Server { impl Server {
pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> { pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> {
let task = Self::start_server(addr, port, app_handle.app_handle()).await?; // construct the listener before passing it to the task so that we know if it fails
let sock_addr = SocketAddrV4::new(addr, port);
let listener = TcpListener::bind(&sock_addr).await?;
let task = rt::spawn(Self::serve(listener, app_handle.app_handle()));
Ok(Server { addr, port, app_handle, task}) Ok(Server { addr, port, app_handle, task})
} }
pub async fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> { // this is blocking because it's too much of a paint to juggle mutexes otherwise
pub fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
if addr == self.addr && port == self.port { if addr == self.addr && port == self.port {
return Ok(()) return Ok(())
} }
let new_task = Self::start_server(addr, port, self.app_handle.app_handle()).await?; let sock_addr = SocketAddrV4::new(addr, port);
let std_listener = StdTcpListener::bind(&sock_addr)?;
std_listener.set_nonblocking(true)?;
let async_listener = TcpListener::from_std(std_listener)?;
let new_task = rt::spawn(Self::serve(async_listener, self.app_handle.app_handle()));
self.task.abort(); self.task.abort();
self.addr = addr; self.addr = addr;
self.port = port; self.port = port;
self.task = new_task; self.task = new_task;
Ok(())
}
// construct the listener before spawning the task so that we can return early if it fails Ok(())
async fn start_server(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<JoinHandle<()>> {
let sock_addr = SocketAddrV4::new(addr, port);
let listener = TcpListener::bind(&sock_addr).await?;
let task = rt::spawn(
Self::serve(listener, app_handle.app_handle())
);
Ok(task)
} }
async fn serve(listener: TcpListener, app_handle: AppHandle) { async fn serve(listener: TcpListener, app_handle: AppHandle) {
loop { loop {
match listener.accept().await { match listener.accept().await {
Ok((stream, _)) => { Ok((stream, _)) => {
let handler = Handler::new(stream, app_handle.app_handle()).await; let handler = Handler::new(stream, app_handle.app_handle());
rt::spawn(handler.handle()); rt::spawn(handler.handle());
}, },
Err(e) => { Err(e) => {

View File

@ -1,4 +1,5 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use std::time::{ use std::time::{
Duration, Duration,
SystemTime, SystemTime,
@ -11,11 +12,8 @@ use aws_smithy_types::date_time::{
Format as AwsDateTimeFormat, Format as AwsDateTimeFormat,
}; };
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use tokio::{ use tokio::sync::oneshot::Sender;
sync::oneshot::Sender, use tokio::time::sleep;
sync::RwLock,
time::sleep,
};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use sodiumoxide::crypto::{ use sodiumoxide::crypto::{
pwhash, pwhash,
@ -165,49 +163,51 @@ impl AppState {
} }
pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> { pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> {
let mut live_config = self.config.write().await; {
let orig_config = self.config.read().unwrap();
if new_config.start_on_login != live_config.start_on_login { if new_config.start_on_login != orig_config.start_on_login {
config::set_auto_launch(new_config.start_on_login)?; config::set_auto_launch(new_config.start_on_login)?;
} }
if new_config.listen_addr != live_config.listen_addr if new_config.listen_addr != orig_config.listen_addr
|| new_config.listen_port != live_config.listen_port || new_config.listen_port != orig_config.listen_port
{ {
let mut sv = self.server.write().await; let mut sv = self.server.write().unwrap();
sv.rebind(new_config.listen_addr, new_config.listen_port).await?; sv.rebind(new_config.listen_addr, new_config.listen_port)?;
}
} }
new_config.save(&self.pool).await?; new_config.save(&self.pool).await?;
let mut live_config = self.config.write().unwrap();
*live_config = new_config; *live_config = new_config;
Ok(()) Ok(())
} }
pub async fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 { pub fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 {
let count = { let count = {
let mut c = self.request_count.write().await; let mut c = self.request_count.write().unwrap();
*c += 1; *c += 1;
c c
}; };
let mut open_requests = self.open_requests.write().await; let mut open_requests = self.open_requests.write().unwrap();
open_requests.insert(*count, chan); // `count` is the request id open_requests.insert(*count, chan); // `count` is the request id
*count *count
} }
pub async fn unregister_request(&self, id: u64) { pub fn unregister_request(&self, id: u64) {
let mut open_requests = self.open_requests.write().await; let mut open_requests = self.open_requests.write().unwrap();
open_requests.remove(&id); open_requests.remove(&id);
} }
pub async fn req_count(&self) -> usize { pub fn req_count(&self) -> usize {
let open_requests = self.open_requests.read().await; let open_requests = self.open_requests.read().unwrap();
open_requests.len() open_requests.len()
} }
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
self.renew_session_if_expired().await?; self.renew_session_if_expired().await?;
let mut open_requests = self.open_requests.write().await; let mut open_requests = self.open_requests.write().unwrap();
let chan = open_requests let chan = open_requests
.remove(&response.id) .remove(&response.id)
.ok_or(SendResponseError::NotFound) .ok_or(SendResponseError::NotFound)
@ -217,31 +217,27 @@ impl AppState {
.map_err(|_e| SendResponseError::Abandoned) .map_err(|_e| SendResponseError::Abandoned)
} }
pub async fn add_ban(&self, client: Option<Client>) { pub fn add_ban(&self, client: Option<Client>, app: tauri::AppHandle) {
let mut bans = self.bans.write().await; let mut bans = self.bans.write().unwrap();
bans.insert(client.clone()); bans.insert(client.clone());
runtime::spawn(async move { runtime::spawn(async move {
sleep(Duration::from_secs(5)).await; sleep(Duration::from_secs(5)).await;
let app = crate::APP.get().unwrap();
let state = app.state::<AppState>(); let state = app.state::<AppState>();
let mut bans = state.bans.write().await; let mut bans = state.bans.write().unwrap();
bans.remove(&client); bans.remove(&client);
}); });
} }
pub async fn is_banned(&self, client: &Option<Client>) -> bool { pub fn is_banned(&self, client: &Option<Client>) -> bool {
self.bans.read().await.contains(&client) self.bans.read().unwrap().contains(&client)
} }
pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> { pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> {
let mut session = self.session.write().await; let (access_key_id, secret_access_key) = {
let LockedCredentials { // do this all in a block so that we aren't holding a lock across an await
access_key_id, let session = self.session.read().unwrap();
secret_key_enc, let locked = match *session {
salt,
nonce
} = match *session {
Session::Empty => {return Err(UnlockError::NoCredentials);}, Session::Empty => {return Err(UnlockError::NoCredentials);},
Session::Unlocked{..} => {return Err(UnlockError::NotLocked);}, Session::Unlocked{..} => {return Err(UnlockError::NotLocked);},
Session::Locked(ref c) => c, Session::Locked(ref c) => c,
@ -249,26 +245,26 @@ impl AppState {
let mut key_buf = [0; secretbox::KEYBYTES]; let mut key_buf = [0; secretbox::KEYBYTES];
// pretty sure this only fails if we're out of memory // pretty sure this only fails if we're out of memory
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), salt).unwrap(); pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &locked.salt).unwrap();
let decrypted = secretbox::open(secret_key_enc, nonce, &Key(key_buf)) let decrypted = secretbox::open(&locked.secret_key_enc, &locked.nonce, &Key(key_buf))
.map_err(|_e| UnlockError::BadPassphrase)?; .map_err(|_e| UnlockError::BadPassphrase)?;
let secret_access_key = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?; let secret_str = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
(locked.access_key_id.clone(), secret_str)
};
let session_creds = self.new_session(access_key_id, &secret_access_key).await?; let session_creds = self.new_session(&access_key_id, &secret_access_key).await?;
*session = Session::Unlocked { let mut app_session = self.session.write().unwrap();
base: BaseCredentials { *app_session = Session::Unlocked {
access_key_id: access_key_id.clone(), base: BaseCredentials {access_key_id, secret_access_key},
secret_access_key,
},
session: session_creds session: session_creds
}; };
Ok(()) Ok(())
} }
// pub async fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> { // pub fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
// let session = self.session.read().await; // let session = self.session.read().unwrap();
// match *session { // match *session {
// Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()), // Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()),
// Session::Locked(_) => Err(GetCredentialsError::Locked), // Session::Locked(_) => Err(GetCredentialsError::Locked),
@ -276,8 +272,8 @@ impl AppState {
// } // }
// } // }
pub async fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> { pub fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
let session = self.session.read().await; let session = self.session.read().unwrap();
match *session { match *session {
Session::Unlocked{ref session, ..} => Ok(serde_json::to_string(session).unwrap()), Session::Unlocked{ref session, ..} => Ok(serde_json::to_string(session).unwrap()),
Session::Locked(_) => Err(GetCredentialsError::Locked), Session::Locked(_) => Err(GetCredentialsError::Locked),
@ -333,21 +329,41 @@ impl AppState {
} }
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> { pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
match *self.session.write().await { let base = {
Session::Unlocked{ref base, ref mut session} => { let session = self.session.read().unwrap();
match *session {
Session::Unlocked{ref base, ..} => base.clone(),
_ => unreachable!(),
}
};
let new_session = self.new_session(&base.access_key_id, &base.secret_access_key).await?;
match *self.session.write().unwrap() {
Session::Unlocked{ref mut session, ..} => {
if !session.is_expired() { if !session.is_expired() {
return Ok(false); return Ok(false);
} }
let new_session = self.new_session(
&base.access_key_id,
&base.secret_access_key
).await?;
*session = new_session; *session = new_session;
Ok(true) Ok(true)
}, },
Session::Locked(_) => Err(GetSessionError::CredentialsLocked), Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
Session::Empty => Err(GetSessionError::CredentialsEmpty), Session::Empty => Err(GetSessionError::CredentialsEmpty),
} }
// match *self.session.write().unwrap() {
// Session::Unlocked{ref base, ref mut session} => {
// if !session.is_expired() {
// return Ok(false);
// }
// let new_session = self.new_session(
// &base.access_key_id,
// &base.secret_access_key
// ).await?;
// *session = new_session;
// Ok(true)
// },
// Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
// Session::Empty => Err(GetSessionError::CredentialsEmpty),
// }
} }
} }