start refactoring for default credentials

This commit is contained in:
2024-06-26 11:10:50 -04:00
parent 8c668e51a6
commit 37b44ddb2e
21 changed files with 708 additions and 632 deletions

View File

@@ -24,8 +24,15 @@ use serde::{
Deserializer,
};
use serde::de::{self, Visitor};
use sqlx::SqlitePool;
use sqlx::types::Uuid;
use sqlx::{
Error as SqlxError,
FromRow,
Sqlite,
SqlitePool,
sqlite::SqliteRow,
Transaction,
types::Uuid,
};
use crate::errors::*;
use crate::kv;
@@ -35,6 +42,7 @@ pub use aws::{AwsBaseCredential, AwsSessionCredential};
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Credential {
AwsBase(AwsBaseCredential),
AwsSession(AwsSessionCredential),
@@ -43,27 +51,74 @@ pub enum Credential {
// we need a special type for listing structs because
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct SaveCredential {
pub struct CredentialRecord {
#[serde(serialize_with = "serialize_uuid")]
#[serde(deserialize_with = "deserialize_uuid")]
id: Uuid, // UUID so it can be generated on the frontend
name: String, // user-facing identifier so it can be changed
is_default: bool,
credential: Credential,
}
impl SaveCredential {
pub async fn save(&self, crypt: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
let cred = match &self.credential {
Credential::AwsBase(b) => b,
Credential::AwsSession(_) => return Err(SaveCredentialsError::NotPersistent),
impl CredentialRecord {
pub async fn save(&self, crypto: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
let type_name = match &self.credential {
Credential::AwsBase(_) => AwsBaseCredential::type_name(),
_ => return Err(SaveCredentialsError::NotPersistent),
};
cred.save(&self.id, &self.name, crypt, pool).await
// if the credential being saved is default, make sure it's the only default of its type
let mut txn = pool.begin().await?;
if self.is_default {
sqlx::query!(
"UPDATE credentials SET is_default = 0 WHERE type = ?",
type_name
).execute(&mut *txn).await?;
}
// save to parent credentials table
let res = sqlx::query!(
"INSERT INTO credentials (id, name, type, is_default)
VALUES (?, ?, ?, ?)
ON CONFLICT DO UPDATE SET
name = excluded.name,
type = excluded.type,
is_default = excluded.is_default",
self.id, self.name, type_name, self.is_default
).execute(&mut *txn).await;
// if id is unique, but name is not, we will get an error
// (if id is not unique, this becomes an upsert due to ON CONFLICT clause)
match res {
Err(SqlxError::Database(e)) if e.is_unique_violation() => Err(SaveCredentialsError::Duplicate),
Err(e) => Err(SaveCredentialsError::DbError(e)),
Ok(_) => Ok(())
}?;
// save credential details to child table
match &self.credential {
Credential::AwsBase(b) => b.save_details(&self.id, crypto, &mut txn).await,
_ => Err(SaveCredentialsError::NotPersistent),
}?;
// make it real
txn.commit().await?;
Ok(())
}
#[allow(unused_variables)]
pub async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result<Vec<Self>, LoadCredentialsError> {
todo!()
}
#[allow(unused_variables)]
pub async fn rekey(old: &Crypto, new: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
todo!()
}
}
fn serialize_uuid<S: Serializer>(u: &Uuid, s: S) -> Result<S::Ok, S::Error> {
let mut buf = Vec::new();
let mut buf = Uuid::encode_buffer();
s.serialize_str(u.as_hyphenated().encode_lower(&mut buf))
}
@@ -88,15 +143,57 @@ fn deserialize_uuid<'de, D: Deserializer<'de>>(ds: D) -> Result<Uuid, D::Error>
pub trait PersistentCredential: for<'a> Deserialize<'a> + Sized {
async fn load(name: &str, crypt: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError>;
async fn list(crypt: &Crypto, pool: &SqlitePool) -> Result<Vec<SaveCredential>, LoadCredentialsError>;
async fn save(&self, id: &Uuid, name: &str, crypt: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError>;
type Row: Send + Unpin + for<'r> FromRow<'r, SqliteRow>;
async fn rekey(old: &Crypto, new: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
for cred in Self::list(old, pool).await? {
cred.save(new, pool).await?;
}
Ok(())
fn type_name() -> &'static str;
fn from_row(row: Self::Row, crypto: &Crypto) -> Result<Self, LoadCredentialsError>;
// save_details needs to be implemented per-type because we don't know the number of parameters in advance
async fn save_details(&self, id: &Uuid, crypto: &Crypto, txn: &mut Transaction<'_, Sqlite>) -> Result<(), SaveCredentialsError>;
fn table_name() -> String {
format!("{}_credentials", Self::type_name())
}
async fn load(id: &Uuid, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
let q = format!("SELECT * FROM {} WHERE id = ?", Self::table_name());
let row: Self::Row = sqlx::query_as(&q)
.bind(id)
.fetch_optional(pool)
.await?
.ok_or(LoadCredentialsError::NoCredentials)?;
Self::from_row(row, crypto)
}
async fn load_by_name(name: &str, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
let q = format!(
"SELECT * FROM {} WHERE id = (SELECT id FROM credentials WHERE name = ?)",
Self::table_name(),
);
let row: Self::Row = sqlx::query_as(&q)
.bind(name)
.fetch_optional(pool)
.await?
.ok_or(LoadCredentialsError::NoCredentials)?;
Self::from_row(row, crypto)
}
async fn load_default(crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
let q = format!(
"SELECT details.*
FROM {} details
JOIN credentials c
ON c.id = details.id
AND c.is_default = 1",
Self::table_name(),
);
let row: Self::Row = sqlx::query_as(&q)
.fetch_optional(pool)
.await?
.ok_or(LoadCredentialsError::NoCredentials)?;
Self::from_row(row, crypto)
}
}
@@ -300,14 +397,48 @@ impl Debug for Crypto {
}
// #[cfg(test)]
// mod tests {
// use super::*;
#[cfg(test)]
mod tests {
use super::*;
// #[sqlx::test(fixtures("uuid_test"))]
// async fn save_uuid(pool: SqlitePool) {
// let u = Uuid::try_parse("7140b90c-bfbd-4394-9008-01b94f94ecf8").unwrap();
// sqlx::query!("INSERT INTO uuids (uuid) VALUES (?)", u).execute(pool).unwrap();
// panic!("done, go check db");
// }
// }
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
struct UuidWrapper {
#[serde(serialize_with = "serialize_uuid")]
#[serde(deserialize_with = "deserialize_uuid")]
id: Uuid,
}
#[test]
fn test_serialize_uuid() {
let u = UuidWrapper {
id: Uuid::try_parse("693f84d2-4c1b-41e5-8483-cbe178324e04").unwrap()
};
let computed = serde_json::to_string(&u).unwrap();
assert_eq!(
"{\"id\":\"693f84d2-4c1b-41e5-8483-cbe178324e04\"}",
&computed,
);
}
#[test]
fn test_deserialize_uuid() {
let s = "{\"id\":\"045bd359-8630-4b76-9b7d-e4a86ed2222c\"}";
let computed = serde_json::from_str(s).unwrap();
let expected = UuidWrapper {
id: Uuid::try_parse("045bd359-8630-4b76-9b7d-e4a86ed2222c").unwrap(),
};
assert_eq!(expected, computed);
}
#[test]
fn test_serialize_deserialize_uuid() {
let buf = Crypto::salt();
let expected = UuidWrapper{
id: Uuid::from_slice(&buf[..16]).unwrap()
};
let serialized = serde_json::to_string(&expected).unwrap();
let computed = serde_json::from_str(&serialized).unwrap();
assert_eq!(expected, computed)
}
}