start working on generalizing credential logic

This commit is contained in:
2024-06-16 07:08:10 -04:00
parent 0491cb5790
commit d0a2532c27
16 changed files with 1192 additions and 54 deletions

View File

@@ -0,0 +1,174 @@
use serde::{
Serialize,
Deserialize,
Serializer,
Deserializer,
};
use sqlx::SqlitePool;
use super::{Crypto, PersistentCredential};
use crate::errors::*;
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct AwsBaseCredential {
#[serde(default = "default_credentials_version")]
pub version: usize,
pub access_key_id: String,
pub secret_access_key: String,
}
impl AwsBaseCredential {
pub fn new(access_key_id: String, secret_access_key: String) -> Self {
Self {version: 1, access_key_id, secret_access_key}
}
}
impl PersistentCredential for AwsBaseCredential {
pub async fn save(&self, crypt: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
let (nonce, ciphertext) = crypto.encrypt(self.secret_access_key.as_bytes())?;
sqlx::query!(
"INSERT INTO aws_credentials (
name,
key_id,
secret_key_enc,
nonce,
updated_at
)
VALUES ('main', ?, ?, ? strftime('%s'))
ON CONFLICT DO UPDATE SET
key_id = excluded.key_id,
secret_key_enc = excluded.secret_key_enc,
nonce = excluded.nonce
updated_at = excluded.updated_at",
self.access_key_id,
ciphertext,
nonce,
).execute(pool).await?;
Ok(())
}
pub async fn load(crypt: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
let row = sqlx::query!("SELECT * FROM aws_credentials WHERE name = 'main'")
.fetch_optional(pool)
.await?
.ok_or(LoadCredentialsError::NoCredentials);
let secret_key = crypto.decrypt(&row.nonce, &row.secret_key_enc)?;
let creds = Self {
version: 1,
access_key_id: row.key_id,
secret_access_key: secret_key,
};
Ok(creds)
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct AwsSessionCredential {
#[serde(default = "default_credentials_version")]
pub version: usize,
pub access_key_id: String,
pub secret_access_key: String,
pub session_token: String,
#[serde(serialize_with = "serialize_expiration")]
#[serde(deserialize_with = "deserialize_expiration")]
pub expiration: DateTime,
}
impl AwsSessionCredential {
pub async fn from_base(base: &BaseCredentials) -> Result<Self, GetSessionError> {
let req_creds = aws_sdk_sts::Credentials::new(
&base.access_key_id,
&base.secret_access_key,
None, // token
None, //expiration
"Creddy", // "provider name" apparently
);
let config = aws_config::from_env()
.credentials_provider(req_creds)
.load()
.await;
let client = aws_sdk_sts::Client::new(&config);
let resp = client.get_session_token()
.duration_seconds(43_200)
.send()
.await?;
let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?;
let access_key_id = aws_session.access_key_id()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let secret_access_key = aws_session.secret_access_key()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let session_token = aws_session.session_token()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let expiration = aws_session.expiration()
.ok_or(GetSessionError::EmptyResponse)?
.clone();
let session_creds = SessionCredentials {
version: 1,
access_key_id,
secret_access_key,
session_token,
expiration,
};
#[cfg(debug_assertions)]
println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap());
Ok(session_creds)
}
pub fn is_expired(&self) -> bool {
let current_ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap() // doesn't panic because UNIX_EPOCH won't be later than now()
.as_secs();
let expire_ts = self.expiration.secs();
let remaining = expire_ts - (current_ts as i64);
remaining < 60
}
}
struct DateTimeVisitor;
impl<'de> Visitor<'de> for DateTimeVisitor {
type Value = DateTime;
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
write!(formatter, "an RFC 3339 UTC string, e.g. \"2014-01-05T10:17:34Z\"")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<DateTime, E> {
DateTime::from_str(v, Format::DateTime)
.map_err(|_| E::custom(format!("Invalid date/time: {v}")))
}
}
fn deserialize_expiration<'de, D>(deserializer: D) -> Result<DateTime, D::Error>
where D: Deserializer<'de>
{
deserializer.deserialize_str(DateTimeVisitor)
}
fn serialize_expiration<S>(exp: &DateTime, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
{
// this only fails if the d/t is out of range, which it can't be for this format
let time_str = exp.fmt(Format::DateTime).unwrap();
serializer.serialize_str(&time_str)
}

View File

@@ -0,0 +1,193 @@
use argon2::{
Argon2,
Algorithm,
Version,
ParamsBuilder,
password_hash::rand_core::{RngCore, OsRng},
};
use chacha20poly1305::{
XChaCha20Poly1305,
XNonce,
aead::{
Aead,
AeadCore,
KeyInit,
Error as AeadError,
generic_array::GenericArray,
},
};
use serde::{Serialize, Deserialize};
use sqlx::{FromRow, SqlitePool};
use crate::kv;
mod aws;
pub use aws::{AwsBaseCredential, AwsSessionCredential};
pub enum CredentialKind {
AwsBase,
AwsSession,
}
pub trait PersistentCredential {
async fn load(crypt: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError>;
async fn save(&self, crypt: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError>;
}
#[derive(Debug, Clone)]
pub enum AppSession {
Unlocked {
salt: [u8; 32],
crypto: Crypto,
},
Locked {
salt: [u8; 32],
verify_nonce: XNonce,
verify_blob: Vec<u8>
},
Empty,
}
impl AppSession {
pub fn new(passphrase: &str) -> Result<Self, CryptoError> {
let salt = Crypto::salt();
let crypto = Crypto::new(passphrase, &salt);
Ok(Self::Unlocked {salt, crypto})
}
pub fn unlock(self, passphrase: &str) -> Result<Self, UnlockError> {
let (salt, nonce, blob) = match self {
Self::Empty => return Err(UnlockError::NoCredentials),
Self::Unlocked => return Err(UnlockError::NotLocked),
Self::Locked {salt, verify_nonce, verify_blob} => (salt, verify_nonce, verify_blob),
};
let crypto = Crypto::new(passphrase, salt)
.map_err(|e| CryptoError::Argon2(e))?;
// if passphrase is incorrect, this will fail
let verify = crypto.decrypt(&nonce, &blob)?;
Ok(Self::Unlocked{crypto, salt})
}
pub async fn load(pool: &SqlitePool) -> Result<Self, LoadKvError> {
match kv::load_bytes_multi!(pool, "salt", "verify_nonce", "verify_blob").await? {
Some((salt, verify_nonce, verify_blob)) => {
Ok(Self::Locked {salt, verify_nonce, verify_blob}),
},
None => Ok(Self::Empty),
}
}
pub async fn save(&self, pool: &SqlitePool) -> Result<(), LockError> {
let (salt, nonce, blob) = match self {
Self::Unlocked {salt, crypto} => {
let (nonce, blob) = crypto.encrypt(b"correct horse battery staple")
.map_err(|e| CryptoError::Aead(e))?;
(salt, nonce, blob)
},
Self::Locked {salt, verify_nonce, verify_blob} => (salt, verify_nonce, verify_blob),
// "saving" an empty session just means doing nothing
Self::Empty => return Ok(()),
};
kv::save(pool, "salt", salt).await?;
kv::save(pool, "verify_nonce", nonce).await?;
kv::save(pool, "verify_blob", blob).await?;
Ok(())
}
pub fn try_encrypt(&self, data: &[u8]) -> Result<(XNonce, Vec<u8>), CryptoError> {
let crypto = match self {
Self::Empty => Err(GetCredentialsError::Empty),
Self::Locked => Err(GetCredentialsError::Locked),
Self::Unlocked {crypto, ..} => crypto,
}?;
let res = crypto.encrypt(data)?;
Ok(res)
}
pub fn try_decrypt(&self, nonce: XNonce, data: &[u8]) -> Result<Vec<u8>, CryptoError> {
let crypto = match self {
Self::Empty => Err(GetCredentialsError::Empty),
Self::Locked => Err(GetCredentialsError::Locked),
Self::Unlocked {crypto, ..} => crypto,
}?;
let res = crypto.decrypt(nonce, data)?;
Ok(res)
}
}
pub struct Crypto {
cipher: XChaCha20Poly1305,
}
impl Crypto {
/// Argon2 params rationale:
///
/// m_cost is measured in KiB, so 128 * 1024 gives us 128MiB.
/// This should roughly double the memory usage of the application
/// while deriving the key.
///
/// p_cost is irrelevant since (at present) there isn't any parallelism
/// implemented, so we leave it at 1.
///
/// With the above m_cost, t_cost = 8 results in about 800ms to derive
/// a key on my (somewhat older) CPU. This is probably overkill, but
/// given that it should only have to happen ~once a day for most
/// usage, it should be acceptable.
#[cfg(not(debug_assertions))]
const MEM_COST: u32 = 128 * 1024;
#[cfg(not(debug_assertions))]
const TIME_COST: u32 = 8;
/// But since this takes a million years without optimizations,
/// we turn it way down in debug builds.
#[cfg(debug_assertions)]
const MEM_COST: u32 = 48 * 1024;
#[cfg(debug_assertions)]
const TIME_COST: u32 = 1;
fn new(passphrase: &str, salt: &[u8]) -> argon2::Result<Crypto> {
let params = ParamsBuilder::new()
.m_cost(Self::MEM_COST)
.p_cost(1)
.t_cost(Self::TIME_COST)
.build()
.unwrap(); // only errors if the given params are invalid
let hasher = Argon2::new(
Algorithm::Argon2id,
Version::V0x13,
params,
);
let mut key = [0; 32];
hasher.hash_password_into(passphrase.as_bytes(), &salt, &mut key)?;
let cipher = XChaCha20Poly1305::new(GenericArray::from_slice(&key));
Ok(Crypto { cipher })
}
fn salt() -> [u8; 32] {
let mut salt = [0; 32];
OsRng.fill_bytes(&mut salt);
salt
}
fn encrypt(&self, data: &[u8]) -> Result<(XNonce, Vec<u8>), AeadError> {
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let ciphertext = self.cipher.encrypt(&nonce, data)?;
Ok((nonce, ciphertext))
}
fn decrypt(&self, nonce: &XNonce, data: &[u8]) -> Result<Vec<u8>, AeadError> {
self.cipher.decrypt(nonce, data)
}
}