// creddy/src-tauri/src/state.rs
//
// 286 lines
// 9.1 KiB
// Rust

use core::time::Duration;
use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use serde::{Serialize, Deserialize};
use tokio::sync::oneshot::Sender;
use tokio::time::sleep;
use sqlx::{SqlitePool, sqlite::SqlitePoolOptions, sqlite::SqliteConnectOptions};
use sodiumoxide::crypto::{
pwhash,
pwhash::Salt,
secretbox,
secretbox::{Nonce, Key}
};
use tauri::async_runtime as runtime;
use tauri::Manager;
use crate::{config, config::AppConfig};
use crate::ipc;
use crate::clientinfo::Client;
use crate::errors::*;
/// AWS credentials as handled by the application.
///
/// Serialized without a variant tag (`untagged`), with field names converted
/// to PascalCase (`AccessKeyId`, `SecretAccessKey`, ...).
// NOTE(review): with `untagged`, serde tries variants in declaration order
// when deserializing. Input carrying `Token`/`Expiration` would presumably
// still match `LongLived` first (extra fields ignored) -- confirm whether
// `ShortLived` is ever expected to be deserialized.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Credentials {
/// An access key pair with no session token; this is the form that
/// `AppState::save_creds` encrypts and stores.
#[serde(rename_all = "PascalCase")]
LongLived {
access_key_id: String,
secret_access_key: String,
},
/// Temporary session credentials; built by `AppState::new_session` from
/// the STS `get_session_token` response.
#[serde(rename_all = "PascalCase")]
ShortLived {
access_key_id: String,
secret_access_key: String,
token: String,
// Formatted date/time string, produced in `new_session`.
expiration: String,
},
}
/// Credentials in their encrypted form, as read from the `credentials` table
/// by `AppState::load_creds`.
#[derive(Debug)]
pub struct LockedCredentials {
// Stored unencrypted alongside the ciphertext.
access_key_id: String,
// `secretbox` ciphertext of the secret access key.
secret_key_enc: Vec<u8>,
// Salt fed to `pwhash` when deriving the encryption key from the passphrase.
salt: Salt,
// Nonce used with `secretbox` to seal/open `secret_key_enc`.
nonce: Nonce,
}
/// Current state of the application's credentials.
#[derive(Debug)]
pub enum Session {
/// Usable credentials are held in memory (set by `AppState::new_session`).
Unlocked(Credentials),
/// Encrypted credentials exist but have not been decrypted yet.
Locked(LockedCredentials),
/// No credentials row was found in the database.
Empty,
}
/// Shared application state: configuration, credential session, pending
/// requests, temporary bans, and the database pool.
#[derive(Debug)]
pub struct AppState {
/// Application configuration, loaded from the database in `load`.
pub config: RwLock<AppConfig>,
/// Current credential session (empty, locked, or unlocked).
pub session: RwLock<Session>,
/// Counter incremented by `register_request`; its value after the
/// increment is used as the request id.
pub request_count: RwLock<u64>,
/// Channels for requests that have been registered but not yet answered,
/// keyed by request id.
pub open_requests: RwLock<HashMap<u64, Sender<ipc::Approval>>>,
/// Clients that are currently banned; entries are inserted by `add_ban`
/// and removed again by the task it spawns.
pub bans: RwLock<std::collections::HashSet<Option<Client>>>,
// Connection pool for the application's SQLite database.
pool: SqlitePool,
}
impl AppState {
/// Opens the application database (creating the file if it is missing),
/// runs migrations, and assembles the initial `AppState`.
///
/// The session starts as whatever `load_creds` finds in the database; the
/// request counter, open-request map, and ban set all start empty.
///
/// # Errors
/// Returns a `SetupError` if connecting, migrating, reading the stored
/// credentials, or loading the configuration fails.
pub async fn load() -> Result<Self, SetupError> {
    let db_options = SqliteConnectOptions::new()
        .filename(config::get_or_create_db_path())
        .create_if_missing(true);
    let pool: SqlitePool = SqlitePoolOptions::new()
        .connect_with(db_options)
        .await?;
    sqlx::migrate!().run(&pool).await?;

    let session = Self::load_creds(&pool).await?;
    let app_config = config::load(&pool).await?;

    Ok(AppState {
        config: RwLock::new(app_config),
        session: RwLock::new(session),
        request_count: RwLock::new(0),
        open_requests: RwLock::new(HashMap::new()),
        bans: RwLock::new(HashSet::new()),
        pool,
    })
}
/// Reads the stored credentials (if any) from the `credentials` table.
///
/// Returns `Session::Empty` when the query yields no row, otherwise
/// `Session::Locked` wrapping the still-encrypted record.
///
/// # Errors
/// Returns `SetupError::InvalidRecord` if the stored salt is not exactly 32
/// bytes or the stored nonce is not exactly 24 bytes; database errors are
/// propagated via `?`.
async fn load_creds(pool: &SqlitePool) -> Result<Session, SetupError> {
    let stored = match sqlx::query!("SELECT * FROM credentials")
        .fetch_optional(pool)
        .await?
    {
        None => return Ok(Session::Empty),
        Some(record) => record,
    };

    // The database hands back variable-length blobs; the crypto types need
    // fixed-size arrays, so a length mismatch means the record is unusable.
    let salt_bytes: [u8; 32] = stored
        .salt
        .try_into()
        .map_err(|_| SetupError::InvalidRecord)?;
    let nonce_bytes: [u8; 24] = stored
        .nonce
        .try_into()
        .map_err(|_| SetupError::InvalidRecord)?;

    Ok(Session::Locked(LockedCredentials {
        access_key_id: stored.access_key_id,
        secret_key_enc: stored.secret_key_enc,
        salt: Salt(salt_bytes),
        nonce: Nonce(nonce_bytes),
    }))
}
/// Encrypts long-lived credentials with a key derived from `passphrase`,
/// inserts them into the `credentials` table, and then starts a new session
/// with them.
///
/// # Panics
/// Panics if `creds` is not `Credentials::LongLived`, or if key derivation
/// fails.
///
/// # Errors
/// Returns an `UnlockError` if the database insert or the subsequent
/// `new_session` call fails. Note that the insert happens first, so the row
/// is already written when `new_session` fails.
pub async fn save_creds(&self, creds: Credentials, passphrase: &str) -> Result<(), UnlockError> {
    let (access_key_id, secret_access_key) =
        if let Credentials::LongLived { access_key_id, secret_access_key } = creds {
            (access_key_id, secret_access_key)
        } else {
            unreachable!()
        };

    // A fresh salt goes into passphrase-based key derivation...
    let salt = pwhash::gen_salt();
    let mut key_bytes = [0; secretbox::KEYBYTES];
    pwhash::derive_key_interactive(&mut key_bytes, passphrase.as_bytes(), &salt).unwrap();
    let key = Key(key_bytes);

    // ...and a fresh nonce goes into the cipher itself. Both are stored
    // next to the ciphertext because `decrypt` needs them again.
    let nonce = secretbox::gen_nonce();
    let ciphertext = secretbox::seal(secret_access_key.as_bytes(), &nonce, &key);

    let insert = sqlx::query(
        "INSERT INTO credentials (access_key_id, secret_key_enc, salt, nonce)\nVALUES (?, ?, ?, ?)"
    )
        .bind(&access_key_id)
        .bind(&ciphertext)
        .bind(&salt.0[..])
        .bind(&nonce.0[..]);
    insert.execute(&self.pool).await?;

    self.new_session(&access_key_id, &secret_access_key).await?;
    Ok(())
}
/// Registers a pending request and returns its id.
///
/// Increments `request_count` and stores `chan` in `open_requests` under the
/// incremented value, which doubles as the request id.
///
/// # Panics
/// Panics if either the `request_count` or `open_requests` lock is poisoned.
pub fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 {
    // Copy the new id out of the guard so the counter's write lock is
    // released at the end of this block, rather than being held while the
    // `open_requests` lock is acquired below.
    let id = {
        let mut count = self.request_count.write().unwrap();
        *count += 1;
        *count
    };
    self.open_requests.write().unwrap().insert(id, chan);
    id
}
/// Drops the pending request with the given id, if there is one.
///
/// # Panics
/// Panics if the `open_requests` lock is poisoned.
pub fn unregister_request(&self, id: u64) {
    self.open_requests.write().unwrap().remove(&id);
}
/// Returns how many requests are currently open, i.e. the number of entries
/// in `open_requests` (not the running `request_count` total).
///
/// # Panics
/// Panics if the `open_requests` lock is poisoned.
pub fn req_count(&self) -> usize {
    self.open_requests.read().unwrap().len()
}
/// Delivers an approval decision to the request it answers.
///
/// The request's channel is removed from `open_requests` whether or not the
/// send succeeds.
///
/// # Errors
/// * `SendResponseError::NotFound` if no open request has `response.id`.
/// * `SendResponseError::Abandoned` if sending on the channel fails.
///
/// # Panics
/// Panics if the `open_requests` lock is poisoned.
pub fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
    let mut pending = self.open_requests.write().unwrap();
    match pending.remove(&response.id) {
        None => Err(SendResponseError::NotFound),
        Some(reply_to) => reply_to
            .send(response.approval)
            .map_err(|_| SendResponseError::Abandoned),
    }
}
/// Bans `client` and spawns a task that removes the ban again after a
/// five-second sleep.
///
/// The spawned task looks the state up through `app`, so the `AppState` must
/// be managed by that Tauri app handle.
///
/// # Panics
/// Panics if the `bans` lock is poisoned.
pub fn add_ban(&self, client: Option<Client>, app: tauri::AppHandle) {
    self.bans.write().unwrap().insert(client.clone());

    runtime::spawn(async move {
        sleep(Duration::from_secs(5)).await;
        let state = app.state::<AppState>();
        state.bans.write().unwrap().remove(&client);
    });
}
/// Returns `true` if `client` is currently in the ban set.
///
/// # Panics
/// Panics if the `bans` lock is poisoned.
pub fn is_banned(&self, client: &Option<Client>) -> bool {
    let banned = self.bans.read().unwrap();
    banned.contains(client)
}
/// Decrypts the locked credentials with `passphrase` and starts a new
/// session with them.
///
/// # Errors
/// * `UnlockError::NoCredentials` if the session is empty.
/// * `UnlockError::NotLocked` if the session is already unlocked.
/// * `UnlockError::BadPassphrase` if the ciphertext cannot be opened with the
///   derived key.
/// * `UnlockError::InvalidUtf8` if the decrypted secret is not valid UTF-8.
/// * Any error from `new_session`, converted via `?`.
///
/// # Panics
/// Panics if the `session` lock is poisoned or key derivation fails.
pub async fn decrypt(&self, passphrase: &str) -> Result<(), UnlockError> {
    // The read guard must be gone before the `.await` below, so all work
    // that needs it is confined to this block and only owned values leave it.
    let (access_key_id, secret_access_key) = {
        let guard = self.session.read().unwrap();
        let locked = match &*guard {
            Session::Locked(creds) => creds,
            Session::Unlocked(_) => return Err(UnlockError::NotLocked),
            Session::Empty => return Err(UnlockError::NoCredentials),
        };

        let mut key_bytes = [0; secretbox::KEYBYTES];
        // Believed to fail only when memory runs out, hence the unwrap.
        pwhash::derive_key_interactive(&mut key_bytes, passphrase.as_bytes(), &locked.salt).unwrap();
        let key = Key(key_bytes);

        let plaintext = secretbox::open(&locked.secret_key_enc, &locked.nonce, &key)
            .map_err(|_| UnlockError::BadPassphrase)?;
        let secret = String::from_utf8(plaintext)
            .map_err(|_| UnlockError::InvalidUtf8)?;

        (locked.access_key_id.clone(), secret)
    };

    self.new_session(&access_key_id, &secret_access_key).await?;
    Ok(())
}
/// Returns the current session's credentials as a JSON string.
///
/// # Errors
/// * `GetCredentialsError::Locked` if the credentials have not been decrypted.
/// * `GetCredentialsError::Empty` if there are no credentials at all.
///
/// # Panics
/// Panics if the `session` lock is poisoned or serialization fails.
pub fn get_creds_serialized(&self) -> Result<String, GetCredentialsError> {
    let guard = self.session.read().unwrap();
    match &*guard {
        Session::Empty => Err(GetCredentialsError::Empty),
        Session::Locked(_) => Err(GetCredentialsError::Locked),
        Session::Unlocked(creds) => Ok(serde_json::to_string(creds).unwrap()),
    }
}
/// Requests a session token from STS using the given key pair and stores the
/// returned credentials as the unlocked session.
///
/// The requested session duration is 43,200 seconds. In debug builds the new
/// credentials are also printed to stdout.
///
/// # Errors
/// Returns `GetSessionError::NoCredentials` if the STS response lacks
/// credentials or any of their fields; errors from the STS call itself are
/// converted via `?`.
///
/// # Panics
/// Panics if the `session` lock is poisoned.
async fn new_session(&self, key_id: &str, secret_key: &str) -> Result<(), GetSessionError> {
    let base_creds = aws_sdk_sts::Credentials::new(
        key_id,
        secret_key,
        None,     // no session token
        None,     // no expiration
        "creddy", // provider name
    );
    let sdk_config = aws_config::from_env()
        .credentials_provider(base_creds)
        .load()
        .await;
    let sts = aws_sdk_sts::Client::new(&sdk_config);

    let response = sts.get_session_token()
        .duration_seconds(43_200)
        .send()
        .await?;
    let issued = response.credentials().ok_or(GetSessionError::NoCredentials)?;

    // Every field is optional in the response type; a missing one is
    // treated the same as missing credentials.
    let session_creds = Credentials::ShortLived {
        access_key_id: issued.access_key_id()
            .ok_or(GetSessionError::NoCredentials)?
            .to_string(),
        secret_access_key: issued.secret_access_key()
            .ok_or(GetSessionError::NoCredentials)?
            .to_string(),
        token: issued.session_token()
            .ok_or(GetSessionError::NoCredentials)?
            .to_string(),
        expiration: issued.expiration()
            .ok_or(GetSessionError::NoCredentials)?
            .fmt(aws_smithy_types::date_time::Format::DateTime)
            .unwrap(), // formatting only fails for out-of-range dates, which this format rules out
    };

    // The write lock is taken only after the last `.await`, so it is never
    // held across a suspension point.
    let mut current = self.session.write().unwrap();
    if cfg!(debug_assertions) {
        println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap());
    }
    *current = Session::Unlocked(session_creds);
    Ok(())
}
}