almost finish refactoring PersistentCredential trait

This commit is contained in:
2024-06-26 15:01:07 -04:00
parent 37b44ddb2e
commit ce7d75f15a
21 changed files with 1287 additions and 459 deletions

View File

@@ -1,31 +1,7 @@
use std::fmt::{self, Debug, Formatter};
use std::fmt::Formatter;
use argon2::{
Argon2,
Algorithm,
Version,
ParamsBuilder,
password_hash::rand_core::{RngCore, OsRng},
};
use chacha20poly1305::{
XChaCha20Poly1305,
XNonce,
aead::{
Aead,
AeadCore,
KeyInit,
generic_array::GenericArray,
},
};
use serde::{
Serialize,
Deserialize,
Serializer,
Deserializer,
};
use serde::de::{self, Visitor};
use serde::{Serialize, Deserialize};
use sqlx::{
Error as SqlxError,
FromRow,
Sqlite,
SqlitePool,
@@ -35,11 +11,19 @@ use sqlx::{
};
use crate::errors::*;
use crate::kv;
mod aws;
pub use aws::{AwsBaseCredential, AwsSessionCredential};
mod record;
pub use record::CredentialRecord;
mod session;
pub use session::AppSession;
mod crypto;
pub use crypto::Crypto;
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type")]
@@ -49,99 +33,6 @@ pub enum Credential {
}
// we need a special type for listing structs because
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct CredentialRecord {
#[serde(serialize_with = "serialize_uuid")]
#[serde(deserialize_with = "deserialize_uuid")]
id: Uuid, // UUID so it can be generated on the frontend
name: String, // user-facing identifier so it can be changed
is_default: bool,
credential: Credential,
}
impl CredentialRecord {
pub async fn save(&self, crypto: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
let type_name = match &self.credential {
Credential::AwsBase(_) => AwsBaseCredential::type_name(),
_ => return Err(SaveCredentialsError::NotPersistent),
};
// if the credential being saved is default, make sure it's the only default of its type
let mut txn = pool.begin().await?;
if self.is_default {
sqlx::query!(
"UPDATE credentials SET is_default = 0 WHERE type = ?",
type_name
).execute(&mut *txn).await?;
}
// save to parent credentials table
let res = sqlx::query!(
"INSERT INTO credentials (id, name, type, is_default)
VALUES (?, ?, ?, ?)
ON CONFLICT DO UPDATE SET
name = excluded.name,
type = excluded.type,
is_default = excluded.is_default",
self.id, self.name, type_name, self.is_default
).execute(&mut *txn).await;
// if id is unique, but name is not, we will get an error
// (if id is not unique, this becomes an upsert due to ON CONFLICT clause)
match res {
Err(SqlxError::Database(e)) if e.is_unique_violation() => Err(SaveCredentialsError::Duplicate),
Err(e) => Err(SaveCredentialsError::DbError(e)),
Ok(_) => Ok(())
}?;
// save credential details to child table
match &self.credential {
Credential::AwsBase(b) => b.save_details(&self.id, crypto, &mut txn).await,
_ => Err(SaveCredentialsError::NotPersistent),
}?;
// make it real
txn.commit().await?;
Ok(())
}
#[allow(unused_variables)]
pub async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result<Vec<Self>, LoadCredentialsError> {
todo!()
}
#[allow(unused_variables)]
pub async fn rekey(old: &Crypto, new: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
todo!()
}
}
fn serialize_uuid<S: Serializer>(u: &Uuid, s: S) -> Result<S::Ok, S::Error> {
let mut buf = Uuid::encode_buffer();
s.serialize_str(u.as_hyphenated().encode_lower(&mut buf))
}
struct UuidVisitor;
impl<'de> Visitor<'de> for UuidVisitor {
type Value = Uuid;
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
write!(formatter, "a hyphenated UUID")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Uuid, E> {
Uuid::try_parse(v)
.map_err(|_| E::custom(format!("Could not interpret string as UUID: {v}")))
}
}
fn deserialize_uuid<'de, D: Deserializer<'de>>(ds: D) -> Result<Uuid, D::Error> {
ds.deserialize_str(UuidVisitor)
}
pub trait PersistentCredential: for<'a> Deserialize<'a> + Sized {
type Row: Send + Unpin + for<'r> FromRow<'r, SqliteRow>;
@@ -195,250 +86,10 @@ pub trait PersistentCredential: for<'a> Deserialize<'a> + Sized {
Self::from_row(row, crypto)
}
}
#[derive(Clone, Debug)]
pub enum AppSession {
Unlocked {
salt: [u8; 32],
crypto: Crypto,
},
Locked {
salt: [u8; 32],
verify_nonce: XNonce,
verify_blob: Vec<u8>
},
Empty,
}
impl AppSession {
pub fn new(passphrase: &str) -> Result<Self, CryptoError> {
let salt = Crypto::salt();
let crypto = Crypto::new(passphrase, &salt)?;
Ok(Self::Unlocked {salt, crypto})
}
pub fn unlock(&mut self, passphrase: &str) -> Result<(), UnlockError> {
let (salt, nonce, blob) = match self {
Self::Empty => return Err(UnlockError::NoCredentials),
Self::Unlocked {..} => return Err(UnlockError::NotLocked),
Self::Locked {salt, verify_nonce, verify_blob} => (salt, verify_nonce, verify_blob),
};
let crypto = Crypto::new(passphrase, salt)
.map_err(|e| CryptoError::Argon2(e))?;
// if passphrase is incorrect, this will fail
let _verify = crypto.decrypt(&nonce, &blob)?;
*self = Self::Unlocked {crypto, salt: *salt};
Ok(())
}
pub async fn load(pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
match kv::load_bytes_multi!(pool, "salt", "verify_nonce", "verify_blob").await? {
Some((salt, nonce, blob)) => {
Ok(Self::Locked {
salt: salt.try_into().map_err(|_| LoadCredentialsError::InvalidData)?,
// note: replace this with try_from at some point
verify_nonce: XNonce::clone_from_slice(&nonce),
verify_blob: blob,
})
},
None => Ok(Self::Empty),
}
}
pub async fn save(&self, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
match self {
Self::Unlocked {salt, crypto} => {
let (nonce, blob) = crypto.encrypt(b"correct horse battery staple")?;
kv::save_bytes(pool, "salt", salt).await?;
kv::save_bytes(pool, "verify_nonce", &nonce.as_slice()).await?;
kv::save_bytes(pool, "verify_blob", &blob).await?;
},
Self::Locked {salt, verify_nonce, verify_blob} => {
kv::save_bytes(pool, "salt", salt).await?;
kv::save_bytes(pool, "verify_nonce", &verify_nonce.as_slice()).await?;
kv::save_bytes(pool, "verify_blob", verify_blob).await?;
},
// "saving" an empty session just means doing nothing
Self::Empty => (),
};
Ok(())
}
pub fn try_get_crypto(&self) -> Result<&Crypto, GetCredentialsError> {
match self {
Self::Empty => Err(GetCredentialsError::Empty),
Self::Locked {..} => Err(GetCredentialsError::Locked),
Self::Unlocked {crypto, ..} => Ok(crypto),
}
}
pub fn try_encrypt(&self, data: &[u8]) -> Result<(XNonce, Vec<u8>), GetCredentialsError> {
let crypto = match self {
Self::Empty => return Err(GetCredentialsError::Empty),
Self::Locked {..} => return Err(GetCredentialsError::Locked),
Self::Unlocked {crypto, ..} => crypto,
};
let res = crypto.encrypt(data)?;
Ok(res)
}
pub fn try_decrypt(&self, nonce: XNonce, data: &[u8]) -> Result<Vec<u8>, GetCredentialsError> {
let crypto = match self {
Self::Empty => return Err(GetCredentialsError::Empty),
Self::Locked {..} => return Err(GetCredentialsError::Locked),
Self::Unlocked {crypto, ..} => crypto,
};
let res = crypto.decrypt(&nonce, data)?;
Ok(res)
}
}
#[derive(Clone)]
pub struct Crypto {
cipher: XChaCha20Poly1305,
}
impl Crypto {
/// Argon2 params rationale:
///
/// m_cost is measured in KiB, so 128 * 1024 gives us 128MiB.
/// This should roughly double the memory usage of the application
/// while deriving the key.
///
/// p_cost is irrelevant since (at present) there isn't any parallelism
/// implemented, so we leave it at 1.
///
/// With the above m_cost, t_cost = 8 results in about 800ms to derive
/// a key on my (somewhat older) CPU. This is probably overkill, but
/// given that it should only have to happen ~once a day for most
/// usage, it should be acceptable.
#[cfg(not(debug_assertions))]
const MEM_COST: u32 = 128 * 1024;
#[cfg(not(debug_assertions))]
const TIME_COST: u32 = 8;
/// But since this takes a million years without optimizations,
/// we turn it way down in debug builds.
#[cfg(debug_assertions)]
const MEM_COST: u32 = 48 * 1024;
#[cfg(debug_assertions)]
const TIME_COST: u32 = 1;
fn new(passphrase: &str, salt: &[u8]) -> argon2::Result<Crypto> {
let params = ParamsBuilder::new()
.m_cost(Self::MEM_COST)
.p_cost(1)
.t_cost(Self::TIME_COST)
.build()
.unwrap(); // only errors if the given params are invalid
let hasher = Argon2::new(
Algorithm::Argon2id,
Version::V0x13,
params,
);
let mut key = [0; 32];
hasher.hash_password_into(passphrase.as_bytes(), &salt, &mut key)?;
let cipher = XChaCha20Poly1305::new(GenericArray::from_slice(&key));
Ok(Crypto { cipher })
}
#[cfg(test)]
pub fn random() -> Crypto {
// salt and key are the same length, so we can just use this
let key = Crypto::salt();
let cipher = XChaCha20Poly1305::new(GenericArray::from_slice(&key));
Crypto { cipher }
}
#[cfg(test)]
pub fn fixed() -> Crypto {
let key = [
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
];
let cipher = XChaCha20Poly1305::new(GenericArray::from_slice(&key));
Crypto { cipher }
}
fn salt() -> [u8; 32] {
let mut salt = [0; 32];
OsRng.fill_bytes(&mut salt);
salt
}
fn encrypt(&self, data: &[u8]) -> Result<(XNonce, Vec<u8>), CryptoError> {
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let ciphertext = self.cipher.encrypt(&nonce, data)?;
Ok((nonce, ciphertext))
}
fn decrypt(&self, nonce: &XNonce, data: &[u8]) -> Result<Vec<u8>, CryptoError> {
let plaintext = self.cipher.decrypt(nonce, data)?;
Ok(plaintext)
}
}
impl Debug for Crypto {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(f, "Crypto {{ [...] }}")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
struct UuidWrapper {
#[serde(serialize_with = "serialize_uuid")]
#[serde(deserialize_with = "deserialize_uuid")]
id: Uuid,
}
#[test]
fn test_serialize_uuid() {
let u = UuidWrapper {
id: Uuid::try_parse("693f84d2-4c1b-41e5-8483-cbe178324e04").unwrap()
};
let computed = serde_json::to_string(&u).unwrap();
assert_eq!(
"{\"id\":\"693f84d2-4c1b-41e5-8483-cbe178324e04\"}",
&computed,
);
}
#[test]
fn test_deserialize_uuid() {
let s = "{\"id\":\"045bd359-8630-4b76-9b7d-e4a86ed2222c\"}";
let computed = serde_json::from_str(s).unwrap();
let expected = UuidWrapper {
id: Uuid::try_parse("045bd359-8630-4b76-9b7d-e4a86ed2222c").unwrap(),
};
assert_eq!(expected, computed);
}
#[test]
fn test_serialize_deserialize_uuid() {
let buf = Crypto::salt();
let expected = UuidWrapper{
id: Uuid::from_slice(&buf[..16]).unwrap()
};
let serialized = serde_json::to_string(&expected).unwrap();
let computed = serde_json::from_str(&serialized).unwrap();
assert_eq!(expected, computed)
async fn list(pool: &SqlitePool) -> Result<Vec<Self::Row>, LoadCredentialsError> {
let q = format!("SELECT * FROM {}", Self::table_name());
let rows: Vec<Self::Row> = sqlx::query_as(&q).fetch_all(pool).await?;
Ok(rows)
}
}