still in progress

This commit is contained in:
2024-06-25 15:19:29 -04:00
parent 9928996fab
commit 8c668e51a6
12 changed files with 1620 additions and 1138 deletions

View File

@@ -1,6 +1,7 @@
use std::fmt::{self, Formatter};
use std::time::{SystemTime, UNIX_EPOCH};
use aws_config::BehaviorVersion;
use aws_smithy_types::date_time::{DateTime, Format};
use chacha20poly1305::XNonce;
use serde::{
@@ -10,14 +11,21 @@ use serde::{
Deserializer,
};
use serde::de::{self, Visitor};
use sqlx::SqlitePool;
use sqlx::{
SqlitePool,
types::Uuid,
};
use sqlx::error::{
Error as SqlxError,
};
use tokio_stream::StreamExt;
use super::{Crypto, PersistentCredential};
use super::{Credential, Crypto, SaveCredential, PersistentCredential};
use crate::errors::*;
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct AwsBaseCredential {
#[serde(default = "default_credentials_version")]
@@ -26,6 +34,7 @@ pub struct AwsBaseCredential {
pub secret_access_key: String,
}
impl AwsBaseCredential {
pub fn new(access_key_id: String, secret_access_key: String) -> Self {
Self {version: 1, access_key_id, secret_access_key}
@@ -33,54 +42,89 @@ impl AwsBaseCredential {
}
impl PersistentCredential for AwsBaseCredential {
async fn save(&self, crypto: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
async fn save(&self, id: &Uuid, name: &str, crypto: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
let (nonce, ciphertext) = crypto.encrypt(self.secret_access_key.as_bytes())?;
let nonce_bytes = &nonce.as_slice();
sqlx::query!(
"INSERT INTO aws_credentials (
name,
let res = sqlx::query!(
"INSERT INTO credentials (id, name, type, created_at)
VALUES (?, ?, 'aws', strftime('%s'))
ON CONFLICT(id) DO UPDATE SET
name = excluded.name,
type = excluded.type,
created_at = excluded.created_at;
INSERT OR REPLACE INTO aws_credentials (
id,
access_key_id,
secret_key_enc,
nonce,
created_at
nonce
)
VALUES ('default', ?, ?, ?, strftime('%s'))
ON CONFLICT DO UPDATE SET
access_key_id = excluded.access_key_id,
secret_key_enc = excluded.secret_key_enc,
nonce = excluded.nonce,
created_at = excluded.created_at",
VALUES (?, ?, ?, ?);",
id,
name,
id, // for the second query
self.access_key_id,
ciphertext,
nonce_bytes,
).execute(pool).await?;
).execute(pool).await;
Ok(())
match res {
Err(SqlxError::Database(e)) if e.code().as_deref() == Some("2067") => Err(SaveCredentialsError::Duplicate),
Err(e) => Err(SaveCredentialsError::DbError(e)),
Ok(_) => Ok(())
}
}
async fn load(crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
let row = sqlx::query!("SELECT * FROM aws_credentials WHERE name = 'default'")
.fetch_optional(pool)
async fn load(name: &str, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
let row = sqlx::query!(
"SELECT c.name, a.access_key_id, a.secret_key_enc, a.nonce
FROM credentials c JOIN aws_credentials a ON a.id = c.id
WHERE c.name = ?",
name
).fetch_optional(pool)
.await?
.ok_or(LoadCredentialsError::NoCredentials)?;
// note: switch to try_from eventually
let nonce = XNonce::clone_from_slice(&row.nonce);
let secret_key_bytes = crypto.decrypt(&nonce, &row.secret_key_enc)?;
let secret_key = String::from_utf8(secret_key_bytes)
.map_err(|_| LoadCredentialsError::InvalidData)?;
let creds = Self {
version: 1,
access_key_id: row.access_key_id,
secret_access_key: secret_key,
};
Ok(AwsBaseCredential::new(row.access_key_id, secret_key))
}
async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result<Vec<SaveCredential>, LoadCredentialsError> {
let mut rows = sqlx::query!(
"SELECT c.id, c.name, a.access_key_id, a.secret_key_enc, a.nonce
FROM credentials c JOIN aws_credentials a ON a.id = c.id"
).fetch(pool);
let mut creds = Vec::new();
while let Some(row) = rows.try_next().await? {
let nonce = XNonce::clone_from_slice(&row.nonce);
let secret_key_bytes = crypto.decrypt(&nonce, &row.secret_key_enc)?;
let secret_key = String::from_utf8(secret_key_bytes)
.map_err(|_| LoadCredentialsError::InvalidData)?;
let aws = AwsBaseCredential::new(row.access_key_id, secret_key);
let id = Uuid::from_slice(&row.id)
.map_err(|_| LoadCredentialsError::InvalidData)?;
let cred = SaveCredential {
id,
name: row.name,
credential: Credential::AwsBase(aws),
};
creds.push(cred);
}
Ok(creds)
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct AwsSessionCredential {
#[serde(default = "default_credentials_version")]
@@ -95,14 +139,14 @@ pub struct AwsSessionCredential {
impl AwsSessionCredential {
pub async fn from_base(base: &AwsBaseCredential) -> Result<Self, GetSessionError> {
let req_creds = aws_sdk_sts::Credentials::new(
let req_creds = aws_sdk_sts::config::Credentials::new(
&base.access_key_id,
&base.secret_access_key,
None, // token
None, //expiration
"Creddy", // "provider name" apparently
);
let config = aws_config::from_env()
let config = aws_config::defaults(BehaviorVersion::latest())
.credentials_provider(req_creds)
.load()
.await;
@@ -113,27 +157,14 @@ impl AwsSessionCredential {
.send()
.await?;
let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?;
let access_key_id = aws_session.access_key_id()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let secret_access_key = aws_session.secret_access_key()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let session_token = aws_session.session_token()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let expiration = aws_session.expiration()
.ok_or(GetSessionError::EmptyResponse)?
.clone();
let aws_session = resp.credentials.ok_or(GetSessionError::EmptyResponse)?;
let session_creds = AwsSessionCredential {
version: 1,
access_key_id,
secret_access_key,
session_token,
expiration,
access_key_id: aws_session.access_key_id,
secret_access_key: aws_session.secret_access_key,
session_token: aws_session.session_token,
expiration: aws_session.expiration,
};
#[cfg(debug_assertions)]
@@ -187,3 +218,128 @@ where S: Serializer
let time_str = exp.fmt(Format::DateTime).unwrap();
serializer.serialize_str(&time_str)
}
#[cfg(test)]
mod tests {
use super::*;
fn test_creds() -> AwsBaseCredential {
AwsBaseCredential::new(
"AKIAIOSFODNN7EXAMPLE".into(),
"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".into(),
)
}
fn test_creds_2() -> AwsBaseCredential {
AwsBaseCredential::new(
"AKIAIOSFODNN7EXAMPL2".into(),
"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKE2".into(),
)
}
fn test_uuid() -> Uuid {
Uuid::try_parse("00000000-0000-0000-0000-000000000000").unwrap()
}
fn test_uuid_2() -> Uuid {
Uuid::try_parse("ffffffff-ffff-ffff-ffff-ffffffffffff").unwrap()
}
fn test_uuid_random() -> Uuid {
let bytes = Crypto::salt();
Uuid::from_slice(&bytes[..16]).unwrap()
}
#[sqlx::test]
async fn test_save(pool: SqlitePool) {
let crypt = Crypto::random();
test_creds().save(&test_uuid_random(), "test", &crypt, &pool).await
.expect("Failed to save AWS credentials");
}
#[sqlx::test(fixtures("aws_credentials"))]
async fn test_overwrite(pool: SqlitePool) {
let crypt = Crypto::fixed();
let creds = test_creds_2();
// overwite original creds with different test data
creds.save(&test_uuid(), "test", &crypt, &pool).await
.expect("Failed to update AWS credentials");
// make sure update went through
let loaded = AwsBaseCredential::load("test", &crypt, &pool).await.unwrap();
assert_eq!(creds, loaded);
}
#[sqlx::test(fixtures("aws_credentials"))]
async fn test_duplicate_name(pool: SqlitePool) {
let crypt = Crypto::random();
let id = test_uuid_random();
let resp = test_creds().save(&id, "test", &crypt, &pool).await;
if !matches!(resp, Err(SaveCredentialsError::Duplicate)) {
panic!("Attempt to create duplicate entry returned {resp:?}")
}
}
#[sqlx::test(fixtures("aws_credentials"))]
async fn test_load(pool: SqlitePool) {
let crypt = Crypto::fixed();
let loaded = AwsBaseCredential::load("test", &crypt, &pool).await.unwrap();
assert_eq!(test_creds(), loaded);
}
#[sqlx::test]
async fn test_save_load(pool: SqlitePool) {
let crypt = Crypto::random();
let creds = test_creds();
creds.save(&test_uuid_random(), "test", &crypt, &pool).await.unwrap();
let loaded = AwsBaseCredential::load("test", &crypt, &pool).await.unwrap();
assert_eq!(creds, loaded);
}
#[sqlx::test(fixtures("aws_credentials"))]
async fn test_list(pool: SqlitePool) {
let crypt = Crypto::fixed();
let list = AwsBaseCredential::list(&crypt, &pool).await
.expect("Failed to list AWS credentials");
let first = SaveCredential {
id: test_uuid(),
name: "test".into(),
credential: Credential::AwsBase(test_creds()),
};
assert_eq!(&first, &list[0]);
let second = SaveCredential {
id: test_uuid_2(),
name: "test2".into(),
credential: Credential::AwsBase(test_creds_2()),
};
assert_eq!(&second, &list[1]);
}
#[sqlx::test(fixtures("aws_credentials"))]
async fn test_rekey(pool: SqlitePool) {
let old_crypt = Crypto::fixed();
let orig = AwsBaseCredential::list(&old_crypt, &pool).await.unwrap();
let new_crypt = Crypto::random();
AwsBaseCredential::rekey(&old_crypt, &new_crypt, &pool).await
.expect("Failed to re-key AWS credentials");
let rekeyed = AwsBaseCredential::list(&new_crypt, &pool).await.unwrap();
for (before, after) in orig.iter().zip(rekeyed.iter()) {
assert_eq!(before, after);
}
}
}

View File

@@ -0,0 +1,19 @@
INSERT INTO credentials (id, name, type, created_at)
VALUES
(X'00000000000000000000000000000000', 'test', 'aws', strftime('%s')),
(X'ffffffffffffffffffffffffffffffff', 'test2', 'aws', strftime('%s'));
INSERT INTO aws_credentials (id, access_key_id, secret_key_enc, nonce)
VALUES
(
X'00000000000000000000000000000000',
'AKIAIOSFODNN7EXAMPLE',
X'B09ACDADD07E295A3FD9146D8D3672FA5C2518BFB15CF039E68820C42EFD3BC3BE3156ACF438C2C957EC113EF8617DBC71790EAFE39B3DE8',
X'DB777F2C6315DC0E12ADF322256E69D09D7FB586AAE614A6'
),
(
X'ffffffffffffffffffffffffffffffff',
'AKIAIOSFODNN7EXAMPL2',
X'ED6125FF40EF6F61929DF5FFD7141CD2B5A302A51C20152156477F8CC77980C614AB1B212AC06983F3CED35C4F3C54D4EE38964859930FBF',
X'962396B78DAA98DFDCC0AC0C9B7D688EC121F5759EBA790A'
);

View File

@@ -1,4 +1,4 @@
use std::fmt::{Debug, Formatter};
use std::fmt::{self, Debug, Formatter};
use argon2::{
Argon2,
@@ -17,8 +17,15 @@ use chacha20poly1305::{
generic_array::GenericArray,
},
};
use serde::Deserialize;
use serde::{
Serialize,
Deserialize,
Serializer,
Deserializer,
};
use serde::de::{self, Visitor};
use sqlx::SqlitePool;
use sqlx::types::Uuid;
use crate::errors::*;
use crate::kv;
@@ -27,11 +34,73 @@ mod aws;
pub use aws::{AwsBaseCredential, AwsSessionCredential};
pub trait PersistentCredential: for<'a> Deserialize<'a> + Sized {
async fn load(crypt: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError>;
async fn save(&self, crypt: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError>;
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub enum Credential {
AwsBase(AwsBaseCredential),
AwsSession(AwsSessionCredential),
}
// we need a special type for listing structs because
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct SaveCredential {
#[serde(serialize_with = "serialize_uuid")]
#[serde(deserialize_with = "deserialize_uuid")]
id: Uuid, // UUID so it can be generated on the frontend
name: String, // user-facing identifier so it can be changed
credential: Credential,
}
impl SaveCredential {
pub async fn save(&self, crypt: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
let cred = match &self.credential {
Credential::AwsBase(b) => b,
Credential::AwsSession(_) => return Err(SaveCredentialsError::NotPersistent),
};
cred.save(&self.id, &self.name, crypt, pool).await
}
}
fn serialize_uuid<S: Serializer>(u: &Uuid, s: S) -> Result<S::Ok, S::Error> {
let mut buf = Vec::new();
s.serialize_str(u.as_hyphenated().encode_lower(&mut buf))
}
struct UuidVisitor;
impl<'de> Visitor<'de> for UuidVisitor {
type Value = Uuid;
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
write!(formatter, "a hyphenated UUID")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Uuid, E> {
Uuid::try_parse(v)
.map_err(|_| E::custom(format!("Could not interpret string as UUID: {v}")))
}
}
fn deserialize_uuid<'de, D: Deserializer<'de>>(ds: D) -> Result<Uuid, D::Error> {
ds.deserialize_str(UuidVisitor)
}
pub trait PersistentCredential: for<'a> Deserialize<'a> + Sized {
async fn load(name: &str, crypt: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError>;
async fn list(crypt: &Crypto, pool: &SqlitePool) -> Result<Vec<SaveCredential>, LoadCredentialsError>;
async fn save(&self, id: &Uuid, name: &str, crypt: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError>;
async fn rekey(old: &Crypto, new: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
for cred in Self::list(old, pool).await? {
cred.save(new, pool).await?;
}
Ok(())
}
}
#[derive(Clone, Debug)]
pub enum AppSession {
Unlocked {
@@ -89,14 +158,14 @@ impl AppSession {
match self {
Self::Unlocked {salt, crypto} => {
let (nonce, blob) = crypto.encrypt(b"correct horse battery staple")?;
kv::save(pool, "salt", salt).await?;
kv::save(pool, "verify_nonce", &nonce.as_slice()).await?;
kv::save(pool, "verify_blob", &blob).await?;
kv::save_bytes(pool, "salt", salt).await?;
kv::save_bytes(pool, "verify_nonce", &nonce.as_slice()).await?;
kv::save_bytes(pool, "verify_blob", &blob).await?;
},
Self::Locked {salt, verify_nonce, verify_blob} => {
kv::save(pool, "salt", salt).await?;
kv::save(pool, "verify_nonce", &verify_nonce.as_slice()).await?;
kv::save(pool, "verify_blob", verify_blob).await?;
kv::save_bytes(pool, "salt", salt).await?;
kv::save_bytes(pool, "verify_nonce", &verify_nonce.as_slice()).await?;
kv::save_bytes(pool, "verify_blob", verify_blob).await?;
},
// "saving" an empty session just means doing nothing
Self::Empty => (),
@@ -187,6 +256,25 @@ impl Crypto {
Ok(Crypto { cipher })
}
#[cfg(test)]
pub fn random() -> Crypto {
// salt and key are the same length, so we can just use this
let key = Crypto::salt();
let cipher = XChaCha20Poly1305::new(GenericArray::from_slice(&key));
Crypto { cipher }
}
#[cfg(test)]
pub fn fixed() -> Crypto {
let key = [
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
];
let cipher = XChaCha20Poly1305::new(GenericArray::from_slice(&key));
Crypto { cipher }
}
fn salt() -> [u8; 32] {
let mut salt = [0; 32];
OsRng.fill_bytes(&mut salt);
@@ -210,3 +298,16 @@ impl Debug for Crypto {
write!(f, "Crypto {{ [...] }}")
}
}
// #[cfg(test)]
// mod tests {
// use super::*;
// #[sqlx::test(fixtures("uuid_test"))]
// async fn save_uuid(pool: SqlitePool) {
// let u = Uuid::try_parse("7140b90c-bfbd-4394-9008-01b94f94ecf8").unwrap();
// sqlx::query!("INSERT INTO uuids (uuid) VALUES (?)", u).execute(pool).unwrap();
// panic!("done, go check db");
// }
// }