creddy/src-tauri/src/state.rs

378 lines
12 KiB
Rust
Raw Normal View History

2022-12-19 16:20:46 -08:00
use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
2023-05-02 11:33:18 -07:00
use std::time::{
Duration,
SystemTime,
UNIX_EPOCH
};
2023-05-02 11:33:18 -07:00
use aws_smithy_types::date_time::{
DateTime as AwsDateTime,
Format as AwsDateTimeFormat,
};
use serde::{Serialize, Deserialize};
use tokio::sync::oneshot::Sender;
2022-12-19 16:20:46 -08:00
use tokio::time::sleep;
2023-04-28 14:33:04 -07:00
use sqlx::SqlitePool;
2022-12-03 21:47:09 -08:00
use sodiumoxide::crypto::{
pwhash,
pwhash::Salt,
secretbox,
secretbox::{Nonce, Key}
};
use tauri::async_runtime as runtime;
2022-12-19 16:20:46 -08:00
use tauri::Manager;
2023-05-02 11:33:18 -07:00
use serde::Serializer;
use crate::{config, config::AppConfig};
use crate::ipc;
2022-12-20 16:11:49 -08:00
use crate::clientinfo::Client;
2022-11-28 16:16:33 -08:00
use crate::errors::*;
2023-04-28 14:33:04 -07:00
use crate::server::Server;
2023-05-02 11:33:18 -07:00
/// Long-lived AWS access key pair.
///
/// Persisted (encrypted) by `AppState::save_creds` and exchanged for temporary
/// credentials by `AppState::new_session`. Serialized with PascalCase keys
/// (`AccessKeyId`, `SecretAccessKey`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "PascalCase")]
pub struct BaseCredentials {
    access_key_id: String,
    secret_access_key: String,
}
2023-05-02 11:33:18 -07:00
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "PascalCase")]
pub struct SessionCredentials {
access_key_id: String,
secret_access_key: String,
token: String,
2023-05-02 11:33:18 -07:00
#[serde(serialize_with = "serialize_expiration")]
expiration: AwsDateTime,
}
2023-05-02 11:33:18 -07:00
impl SessionCredentials {
fn is_expired(&self) -> bool {
let current_ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap() // doesn't panic because UNIX_EPOCH won't be later than now()
.as_secs();
2023-05-02 11:33:18 -07:00
let expire_ts = self.expiration.secs();
let remaining = expire_ts - (current_ts as i64);
remaining < 60
}
}
#[derive(Clone, Debug)]
2022-12-03 21:47:09 -08:00
pub struct LockedCredentials {
access_key_id: String,
secret_key_enc: Vec<u8>,
salt: Salt,
nonce: Nonce,
}
2023-05-02 11:33:18 -07:00
#[derive(Clone, Debug)]
2022-12-03 21:47:09 -08:00
pub enum Session {
Unlocked{
base: BaseCredentials,
session: SessionCredentials,
},
2022-12-03 21:47:09 -08:00
Locked(LockedCredentials),
Empty,
}
#[derive(Debug)]
pub struct AppState {
pub config: RwLock<AppConfig>,
2022-12-03 21:47:09 -08:00
pub session: RwLock<Session>,
pub request_count: RwLock<u64>,
pub open_requests: RwLock<HashMap<u64, Sender<ipc::Approval>>>,
2022-12-20 16:11:49 -08:00
pub bans: RwLock<std::collections::HashSet<Option<Client>>>,
2023-04-28 14:33:04 -07:00
server: RwLock<Server>,
pool: sqlx::SqlitePool,
}
impl AppState {
2023-04-29 10:01:45 -07:00
pub fn new(config: AppConfig, session: Session, server: Server, pool: SqlitePool) -> AppState {
AppState {
config: RwLock::new(config),
session: RwLock::new(session),
request_count: RwLock::new(0),
open_requests: RwLock::new(HashMap::new()),
bans: RwLock::new(HashSet::new()),
server: RwLock::new(server),
pool,
}
2022-12-02 22:59:13 -08:00
}
2023-04-28 14:33:04 -07:00
pub async fn load_creds(pool: &SqlitePool) -> Result<Session, SetupError> {
2023-04-25 22:10:14 -07:00
let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc")
2022-12-03 21:47:09 -08:00
.fetch_optional(pool)
2022-12-02 22:59:13 -08:00
.await?;
2022-12-03 21:47:09 -08:00
let row = match res {
Some(r) => r,
None => {return Ok(Session::Empty);}
};
2022-12-02 22:59:13 -08:00
2022-12-03 21:47:09 -08:00
let salt_buf: [u8; 32] = row.salt
.try_into()
.map_err(|_e| SetupError::InvalidRecord)?;
let nonce_buf: [u8; 24] = row.nonce
.try_into()
.map_err(|_e| SetupError::InvalidRecord)?;
let creds = LockedCredentials {
access_key_id: row.access_key_id,
secret_key_enc: row.secret_key_enc,
salt: Salt(salt_buf),
nonce: Nonce(nonce_buf),
};
Ok(Session::Locked(creds))
}
pub async fn save_creds(&self, creds: BaseCredentials, passphrase: &str) -> Result<(), UnlockError> {
let BaseCredentials {access_key_id, secret_access_key} = creds;
2023-04-25 22:10:14 -07:00
// do this first so that if it fails we don't save bad credentials
self.new_session(&access_key_id, &secret_access_key).await?;
2023-04-25 22:10:14 -07:00
2022-12-03 21:47:09 -08:00
let salt = pwhash::gen_salt();
let mut key_buf = [0; secretbox::KEYBYTES];
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &salt).unwrap();
let key = Key(key_buf);
// not sure we need both salt AND nonce given that we generate a
// fresh salt every time we encrypt, but better safe than sorry
let nonce = secretbox::gen_nonce();
let secret_key_enc = secretbox::seal(secret_access_key.as_bytes(), &nonce, &key);
2022-12-14 14:52:16 -08:00
2022-12-19 15:26:44 -08:00
sqlx::query(
2023-04-25 22:10:14 -07:00
"INSERT INTO credentials (access_key_id, secret_key_enc, salt, nonce, created_at)
VALUES (?, ?, ?, ?, strftime('%s'))"
2022-12-19 15:26:44 -08:00
)
.bind(&access_key_id)
2022-12-19 15:26:44 -08:00
.bind(&secret_key_enc)
.bind(&salt.0[0..])
.bind(&nonce.0[0..])
.execute(&self.pool)
.await?;
2022-12-02 22:59:13 -08:00
Ok(())
}
2023-04-27 14:24:08 -07:00
pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> {
2023-04-28 14:33:04 -07:00
{
let orig_config = self.config.read().unwrap();
if new_config.start_on_login != orig_config.start_on_login {
config::set_auto_launch(new_config.start_on_login)?;
}
if new_config.listen_addr != orig_config.listen_addr
|| new_config.listen_port != orig_config.listen_port
{
let mut sv = self.server.write().unwrap();
sv.rebind(new_config.listen_addr, new_config.listen_port)?;
}
}
2023-04-27 14:24:08 -07:00
2023-04-28 14:33:04 -07:00
new_config.save(&self.pool).await?;
2023-04-27 14:24:08 -07:00
let mut live_config = self.config.write().unwrap();
*live_config = new_config;
2023-04-26 15:49:08 -07:00
Ok(())
}
2022-11-28 16:16:33 -08:00
pub fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 {
let count = {
2022-11-29 16:13:09 -08:00
let mut c = self.request_count.write().unwrap();
*c += 1;
c
};
2022-11-29 16:13:09 -08:00
let mut open_requests = self.open_requests.write().unwrap();
2022-12-03 21:47:09 -08:00
open_requests.insert(*count, chan); // `count` is the request id
2022-11-28 16:16:33 -08:00
*count
}
2022-12-20 16:11:49 -08:00
pub fn unregister_request(&self, id: u64) {
let mut open_requests = self.open_requests.write().unwrap();
open_requests.remove(&id);
}
2022-12-21 16:04:12 -08:00
pub fn req_count(&self) -> usize {
let open_requests = self.open_requests.read().unwrap();
open_requests.len()
}
2023-05-02 11:33:18 -07:00
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
self.renew_session_if_expired().await?;
let mut open_requests = self.open_requests.write().unwrap();
2022-11-28 16:16:33 -08:00
let chan = open_requests
.remove(&response.id)
.ok_or(SendResponseError::NotFound)
?;
2022-11-28 16:16:33 -08:00
chan.send(response.approval)
.map_err(|_e| SendResponseError::Abandoned)
}
2022-12-20 16:11:49 -08:00
pub fn add_ban(&self, client: Option<Client>, app: tauri::AppHandle) {
2022-12-19 16:20:46 -08:00
let mut bans = self.bans.write().unwrap();
2022-12-20 16:11:49 -08:00
bans.insert(client.clone());
2022-12-19 16:20:46 -08:00
runtime::spawn(async move {
sleep(Duration::from_secs(5)).await;
let state = app.state::<AppState>();
let mut bans = state.bans.write().unwrap();
2022-12-20 16:11:49 -08:00
bans.remove(&client);
2022-12-19 16:20:46 -08:00
});
}
2022-12-20 16:11:49 -08:00
pub fn is_banned(&self, client: &Option<Client>) -> bool {
self.bans.read().unwrap().contains(&client)
2022-12-19 16:20:46 -08:00
}
pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> {
let (access_key_id, secret_access_key) = {
// do this all in a block so that we aren't holding a lock across an await
2022-12-19 15:26:44 -08:00
let session = self.session.read().unwrap();
let locked = match *session {
Session::Empty => {return Err(UnlockError::NoCredentials);},
Session::Unlocked{..} => {return Err(UnlockError::NotLocked);},
2022-12-19 15:26:44 -08:00
Session::Locked(ref c) => c,
};
let mut key_buf = [0; secretbox::KEYBYTES];
// pretty sure this only fails if we're out of memory
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &locked.salt).unwrap();
let decrypted = secretbox::open(&locked.secret_key_enc, &locked.nonce, &Key(key_buf))
.map_err(|_e| UnlockError::BadPassphrase)?;
let secret_str = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
(locked.access_key_id.clone(), secret_str)
2022-12-03 21:47:09 -08:00
};
let session_creds = self.new_session(&access_key_id, &secret_access_key).await?;
let mut app_session = self.session.write().unwrap();
*app_session = Session::Unlocked {
base: BaseCredentials {access_key_id, secret_access_key},
session: session_creds
};
2022-12-19 15:26:44 -08:00
2022-12-03 21:47:09 -08:00
Ok(())
2022-11-29 16:13:09 -08:00
}
// pub fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
// let session = self.session.read().unwrap();
// match *session {
// Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()),
// Session::Locked(_) => Err(GetCredentialsError::Locked),
// Session::Empty => Err(GetCredentialsError::Empty),
// }
// }
pub fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
2022-12-03 21:47:09 -08:00
let session = self.session.read().unwrap();
match *session {
Session::Unlocked{ref session, ..} => Ok(serde_json::to_string(session).unwrap()),
2022-12-03 21:47:09 -08:00
Session::Locked(_) => Err(GetCredentialsError::Locked),
Session::Empty => Err(GetCredentialsError::Empty),
}
2022-11-28 16:16:33 -08:00
}
2022-12-19 15:26:44 -08:00
async fn new_session(&self, key_id: &str, secret_key: &str) -> Result<SessionCredentials, GetSessionError> {
2022-12-19 15:26:44 -08:00
let creds = aws_sdk_sts::Credentials::new(
key_id,
secret_key,
None, // token
None, // expiration
"creddy", // "provider name" apparently
);
let config = aws_config::from_env()
.credentials_provider(creds)
.load()
.await;
let client = aws_sdk_sts::Client::new(&config);
let resp = client.get_session_token()
.duration_seconds(43_200)
.send()
.await?;
2023-05-02 11:33:18 -07:00
let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?;
2022-12-19 15:26:44 -08:00
let access_key_id = aws_session.access_key_id()
2023-05-02 11:33:18 -07:00
.ok_or(GetSessionError::EmptyResponse)?
2022-12-19 15:26:44 -08:00
.to_string();
let secret_access_key = aws_session.secret_access_key()
2023-05-02 11:33:18 -07:00
.ok_or(GetSessionError::EmptyResponse)?
2022-12-19 15:26:44 -08:00
.to_string();
let token = aws_session.session_token()
2023-05-02 11:33:18 -07:00
.ok_or(GetSessionError::EmptyResponse)?
2022-12-19 15:26:44 -08:00
.to_string();
let expiration = aws_session.expiration()
2023-05-02 11:33:18 -07:00
.ok_or(GetSessionError::EmptyResponse)?
.clone();
2022-12-19 15:26:44 -08:00
let session_creds = SessionCredentials {
2022-12-19 15:26:44 -08:00
access_key_id,
secret_access_key,
token,
expiration,
};
2023-04-29 10:01:45 -07:00
#[cfg(debug_assertions)]
println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap());
2022-12-19 15:26:44 -08:00
Ok(session_creds)
2022-12-19 15:26:44 -08:00
}
2023-05-02 11:33:18 -07:00
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
let base = {
let session = self.session.read().unwrap();
match *session {
Session::Unlocked{ref base, ..} => base.clone(),
_ => unreachable!(),
}
};
let new_session = self.new_session(&base.access_key_id, &base.secret_access_key).await?;
match *self.session.write().unwrap() {
Session::Unlocked{ref mut session, ..} => {
if !session.is_expired() {
return Ok(false);
}
*session = new_session;
Ok(true)
},
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
Session::Empty => Err(GetSessionError::CredentialsEmpty),
}
// match *self.session.write().unwrap() {
// Session::Unlocked{ref base, ref mut session} => {
// if !session.is_expired() {
// return Ok(false);
// }
// let new_session = self.new_session(
// &base.access_key_id,
// &base.secret_access_key
// ).await?;
// *session = new_session;
// Ok(true)
// },
// Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
// Session::Empty => Err(GetSessionError::CredentialsEmpty),
// }
}
}
/// Serde `serialize_with` hook for `SessionCredentials::expiration`.
///
/// Writes the timestamp as a string in the Smithy `Format::DateTime` textual form.
fn serialize_expiration<S: Serializer>(
    exp: &AwsDateTime,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    // `fmt` only errors for timestamps that are out of range for the chosen
    // format; the expiration is assumed to be in range, hence the unwrap.
    let rendered = exp.fmt(AwsDateTimeFormat::DateTime).unwrap();
    serializer.serialize_str(rendered.as_str())
}