use std::fmt::{self, Formatter}; use std::time::{SystemTime, UNIX_EPOCH}; use aws_smithy_types::date_time::{DateTime, Format}; use argon2::{ Argon2, Algorithm, Version, ParamsBuilder, password_hash::rand_core::{RngCore, OsRng}, }; use chacha20poly1305::{ XChaCha20Poly1305, XNonce, aead::{ Aead, AeadCore, KeyInit, Error as AeadError, generic_array::GenericArray, }, }; use serde::{ Serialize, Deserialize, Serializer, Deserializer, }; use serde::de::{self, Visitor}; use sqlx::SqlitePool; use crate::errors::*; #[derive(Clone, Debug)] pub enum Session { Unlocked{ base: BaseCredentials, session: SessionCredentials, }, Locked(LockedCredentials), Empty, } impl Session { pub async fn load(pool: &SqlitePool) -> Result { let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc") .fetch_optional(pool) .await?; let row = match res { Some(r) => r, None => {return Ok(Session::Empty);} }; let salt: [u8; 32] = row.salt .try_into() .map_err(|_e| SetupError::InvalidRecord)?; let nonce = XNonce::from_exact_iter(row.nonce.into_iter()) .ok_or(SetupError::InvalidRecord)?; let creds = LockedCredentials { access_key_id: row.access_key_id, secret_key_enc: row.secret_key_enc, salt, nonce, }; Ok(Session::Locked(creds)) } pub async fn renew_if_expired(&mut self) -> Result { match self { Session::Unlocked{ref base, ref mut session} => { if !session.is_expired() { return Ok(false); } *session = SessionCredentials::from_base(base).await?; Ok(true) }, Session::Locked(_) => Err(GetSessionError::CredentialsLocked), Session::Empty => Err(GetSessionError::CredentialsEmpty), } } pub fn try_get( &self ) -> Result<(&BaseCredentials, &SessionCredentials), GetCredentialsError> { match self { Self::Empty => Err(GetCredentialsError::Empty), Self::Locked(_) => Err(GetCredentialsError::Locked), Self::Unlocked{ ref base, ref session } => Ok((base, session)) } } } #[derive(Clone, Debug)] pub struct LockedCredentials { pub access_key_id: String, pub secret_key_enc: Vec, pub salt: [u8; 32], pub nonce: XNonce, } impl LockedCredentials { pub async fn save(&self, pool: &SqlitePool) -> Result<(), sqlx::Error> { sqlx::query( "INSERT INTO credentials (access_key_id, secret_key_enc, salt, nonce, created_at) VALUES (?, ?, ?, ?, strftime('%s'))" ) .bind(&self.access_key_id) .bind(&self.secret_key_enc) .bind(&self.salt[..]) .bind(&self.nonce[..]) .execute(pool) .await?; Ok(()) } pub fn decrypt(&self, passphrase: &str) -> Result { let crypto = Crypto::new(passphrase, &self.salt) .map_err(|e| CryptoError::Argon2(e))?; let decrypted = crypto.decrypt(&self.nonce, &self.secret_key_enc) .map_err(|e| CryptoError::Aead(e))?; let secret_access_key = String::from_utf8(decrypted) .map_err(|_| UnlockError::InvalidUtf8)?; let creds = BaseCredentials { access_key_id: self.access_key_id.clone(), secret_access_key, }; Ok(creds) } } #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "PascalCase")] pub struct BaseCredentials { pub access_key_id: String, pub secret_access_key: String, } impl BaseCredentials { pub fn encrypt(&self, passphrase: &str) -> Result { let salt = Crypto::salt(); let crypto = Crypto::new(passphrase, &salt)?; let (nonce, secret_key_enc) = crypto.encrypt(self.secret_access_key.as_bytes())?; let locked = LockedCredentials { access_key_id: self.access_key_id.clone(), secret_key_enc, salt, nonce, }; Ok(locked) } } #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "PascalCase")] pub struct SessionCredentials { pub version: usize, pub access_key_id: String, pub secret_access_key: String, pub session_token: String, #[serde(serialize_with = "serialize_expiration")] #[serde(deserialize_with = "deserialize_expiration")] pub expiration: DateTime, } impl SessionCredentials { pub async fn from_base(base: &BaseCredentials) -> Result { let req_creds = aws_sdk_sts::Credentials::new( &base.access_key_id, &base.secret_access_key, None, // token None, //expiration "Creddy", // "provider name" apparently ); let config = aws_config::from_env() .credentials_provider(req_creds) .load() .await; let client = aws_sdk_sts::Client::new(&config); let resp = client.get_session_token() .duration_seconds(43_200) .send() .await?; let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?; let access_key_id = aws_session.access_key_id() .ok_or(GetSessionError::EmptyResponse)? .to_string(); let secret_access_key = aws_session.secret_access_key() .ok_or(GetSessionError::EmptyResponse)? .to_string(); let session_token = aws_session.session_token() .ok_or(GetSessionError::EmptyResponse)? .to_string(); let expiration = aws_session.expiration() .ok_or(GetSessionError::EmptyResponse)? .clone(); let session_creds = SessionCredentials { version: 1, access_key_id, secret_access_key, session_token, expiration, }; #[cfg(debug_assertions)] println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap()); Ok(session_creds) } pub fn is_expired(&self) -> bool { let current_ts = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() // doesn't panic because UNIX_EPOCH won't be later than now() .as_secs(); let expire_ts = self.expiration.secs(); let remaining = expire_ts - (current_ts as i64); remaining < 60 } } #[derive(Debug, Serialize, Deserialize)] pub enum Credentials { Base(BaseCredentials), Session(SessionCredentials), } fn serialize_expiration(exp: &DateTime, serializer: S) -> Result where S: Serializer { // this only fails if the d/t is out of range, which it can't be for this format let time_str = exp.fmt(Format::DateTime).unwrap(); serializer.serialize_str(&time_str) } struct DateTimeVisitor; impl<'de> Visitor<'de> for DateTimeVisitor { type Value = DateTime; fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { write!(formatter, "an RFC 3339 UTC string, e.g. \"2014-01-05T10:17:34Z\"") } fn visit_str(self, v: &str) -> Result { DateTime::from_str(v, Format::DateTime) .map_err(|_| E::custom(format!("Invalid date/time: {v}"))) } } fn deserialize_expiration<'de, D>(deserializer: D) -> Result where D: Deserializer<'de> { deserializer.deserialize_str(DateTimeVisitor) } struct Crypto { cipher: XChaCha20Poly1305, } impl Crypto { /// Argon2 params rationale: /// /// m_cost is measured in KiB, so 128 * 1024 gives us 128MiB. /// This should roughly double the memory usage of the application /// while deriving the key. /// /// p_cost is irrelevant since (at present) there isn't any parallelism /// implemented, so we leave it at 1. /// /// With the above m_cost, t_cost = 8 results in about 800ms to derive /// a key on my (somewhat older) CPU. This is probably overkill, but /// given that it should only have to happen ~once a day for most /// usage, it should be acceptable. #[cfg(not(debug_assertions))] const MEM_COST: u32 = 128 * 1024; #[cfg(not(debug_assertions))] const TIME_COST: u32 = 8; /// But since this takes a million years without optimizations, /// we turn it way down in debug builds. #[cfg(debug_assertions)] const MEM_COST: u32 = 48 * 1024; #[cfg(debug_assertions)] const TIME_COST: u32 = 1; fn new(passphrase: &str, salt: &[u8]) -> argon2::Result { let params = ParamsBuilder::new() .m_cost(Self::MEM_COST) .p_cost(1) .t_cost(Self::TIME_COST) .build() .unwrap(); // only errors if the given params are invalid let hasher = Argon2::new( Algorithm::Argon2id, Version::V0x13, params, ); let mut key = [0; 32]; hasher.hash_password_into(passphrase.as_bytes(), &salt, &mut key)?; let cipher = XChaCha20Poly1305::new(GenericArray::from_slice(&key)); Ok(Crypto { cipher }) } fn salt() -> [u8; 32] { let mut salt = [0; 32]; OsRng.fill_bytes(&mut salt); salt } fn encrypt(&self, data: &[u8]) -> Result<(XNonce, Vec), AeadError> { let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng); let ciphertext = self.cipher.encrypt(&nonce, data)?; Ok((nonce, ciphertext)) } fn decrypt(&self, nonce: &XNonce, data: &[u8]) -> Result, AeadError> { self.cipher.decrypt(nonce, data) } }