start refactoring for default credentials
This commit is contained in:
@@ -24,8 +24,15 @@ use serde::{
|
||||
Deserializer,
|
||||
};
|
||||
use serde::de::{self, Visitor};
|
||||
use sqlx::SqlitePool;
|
||||
use sqlx::types::Uuid;
|
||||
use sqlx::{
|
||||
Error as SqlxError,
|
||||
FromRow,
|
||||
Sqlite,
|
||||
SqlitePool,
|
||||
sqlite::SqliteRow,
|
||||
Transaction,
|
||||
types::Uuid,
|
||||
};
|
||||
|
||||
use crate::errors::*;
|
||||
use crate::kv;
|
||||
@@ -35,6 +42,7 @@ pub use aws::{AwsBaseCredential, AwsSessionCredential};
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum Credential {
|
||||
AwsBase(AwsBaseCredential),
|
||||
AwsSession(AwsSessionCredential),
|
||||
@@ -43,27 +51,74 @@ pub enum Credential {
|
||||
|
||||
// we need a special type for listing structs because
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SaveCredential {
|
||||
pub struct CredentialRecord {
|
||||
#[serde(serialize_with = "serialize_uuid")]
|
||||
#[serde(deserialize_with = "deserialize_uuid")]
|
||||
id: Uuid, // UUID so it can be generated on the frontend
|
||||
name: String, // user-facing identifier so it can be changed
|
||||
is_default: bool,
|
||||
credential: Credential,
|
||||
}
|
||||
|
||||
impl SaveCredential {
|
||||
pub async fn save(&self, crypt: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
|
||||
let cred = match &self.credential {
|
||||
Credential::AwsBase(b) => b,
|
||||
Credential::AwsSession(_) => return Err(SaveCredentialsError::NotPersistent),
|
||||
impl CredentialRecord {
|
||||
pub async fn save(&self, crypto: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
|
||||
let type_name = match &self.credential {
|
||||
Credential::AwsBase(_) => AwsBaseCredential::type_name(),
|
||||
_ => return Err(SaveCredentialsError::NotPersistent),
|
||||
};
|
||||
|
||||
cred.save(&self.id, &self.name, crypt, pool).await
|
||||
// if the credential being saved is default, make sure it's the only default of its type
|
||||
let mut txn = pool.begin().await?;
|
||||
if self.is_default {
|
||||
sqlx::query!(
|
||||
"UPDATE credentials SET is_default = 0 WHERE type = ?",
|
||||
type_name
|
||||
).execute(&mut *txn).await?;
|
||||
}
|
||||
|
||||
// save to parent credentials table
|
||||
let res = sqlx::query!(
|
||||
"INSERT INTO credentials (id, name, type, is_default)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT DO UPDATE SET
|
||||
name = excluded.name,
|
||||
type = excluded.type,
|
||||
is_default = excluded.is_default",
|
||||
self.id, self.name, type_name, self.is_default
|
||||
).execute(&mut *txn).await;
|
||||
|
||||
// if id is unique, but name is not, we will get an error
|
||||
// (if id is not unique, this becomes an upsert due to ON CONFLICT clause)
|
||||
match res {
|
||||
Err(SqlxError::Database(e)) if e.is_unique_violation() => Err(SaveCredentialsError::Duplicate),
|
||||
Err(e) => Err(SaveCredentialsError::DbError(e)),
|
||||
Ok(_) => Ok(())
|
||||
}?;
|
||||
|
||||
// save credential details to child table
|
||||
match &self.credential {
|
||||
Credential::AwsBase(b) => b.save_details(&self.id, crypto, &mut txn).await,
|
||||
_ => Err(SaveCredentialsError::NotPersistent),
|
||||
}?;
|
||||
|
||||
// make it real
|
||||
txn.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
pub async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result<Vec<Self>, LoadCredentialsError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
pub async fn rekey(old: &Crypto, new: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_uuid<S: Serializer>(u: &Uuid, s: S) -> Result<S::Ok, S::Error> {
|
||||
let mut buf = Vec::new();
|
||||
let mut buf = Uuid::encode_buffer();
|
||||
s.serialize_str(u.as_hyphenated().encode_lower(&mut buf))
|
||||
}
|
||||
|
||||
@@ -88,15 +143,57 @@ fn deserialize_uuid<'de, D: Deserializer<'de>>(ds: D) -> Result<Uuid, D::Error>
|
||||
|
||||
|
||||
pub trait PersistentCredential: for<'a> Deserialize<'a> + Sized {
|
||||
async fn load(name: &str, crypt: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError>;
|
||||
async fn list(crypt: &Crypto, pool: &SqlitePool) -> Result<Vec<SaveCredential>, LoadCredentialsError>;
|
||||
async fn save(&self, id: &Uuid, name: &str, crypt: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError>;
|
||||
type Row: Send + Unpin + for<'r> FromRow<'r, SqliteRow>;
|
||||
|
||||
async fn rekey(old: &Crypto, new: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
|
||||
for cred in Self::list(old, pool).await? {
|
||||
cred.save(new, pool).await?;
|
||||
}
|
||||
Ok(())
|
||||
fn type_name() -> &'static str;
|
||||
fn from_row(row: Self::Row, crypto: &Crypto) -> Result<Self, LoadCredentialsError>;
|
||||
// save_details needs to be implemented per-type because we don't know the number of parameters in advance
|
||||
async fn save_details(&self, id: &Uuid, crypto: &Crypto, txn: &mut Transaction<'_, Sqlite>) -> Result<(), SaveCredentialsError>;
|
||||
|
||||
fn table_name() -> String {
|
||||
format!("{}_credentials", Self::type_name())
|
||||
}
|
||||
|
||||
async fn load(id: &Uuid, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
|
||||
let q = format!("SELECT * FROM {} WHERE id = ?", Self::table_name());
|
||||
let row: Self::Row = sqlx::query_as(&q)
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.ok_or(LoadCredentialsError::NoCredentials)?;
|
||||
|
||||
Self::from_row(row, crypto)
|
||||
}
|
||||
|
||||
async fn load_by_name(name: &str, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
|
||||
let q = format!(
|
||||
"SELECT * FROM {} WHERE id = (SELECT id FROM credentials WHERE name = ?)",
|
||||
Self::table_name(),
|
||||
);
|
||||
let row: Self::Row = sqlx::query_as(&q)
|
||||
.bind(name)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.ok_or(LoadCredentialsError::NoCredentials)?;
|
||||
|
||||
Self::from_row(row, crypto)
|
||||
}
|
||||
|
||||
async fn load_default(crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
|
||||
let q = format!(
|
||||
"SELECT details.*
|
||||
FROM {} details
|
||||
JOIN credentials c
|
||||
ON c.id = details.id
|
||||
AND c.is_default = 1",
|
||||
Self::table_name(),
|
||||
);
|
||||
let row: Self::Row = sqlx::query_as(&q)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.ok_or(LoadCredentialsError::NoCredentials)?;
|
||||
|
||||
Self::from_row(row, crypto)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,14 +397,48 @@ impl Debug for Crypto {
|
||||
}
|
||||
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// #[sqlx::test(fixtures("uuid_test"))]
|
||||
// async fn save_uuid(pool: SqlitePool) {
|
||||
// let u = Uuid::try_parse("7140b90c-bfbd-4394-9008-01b94f94ecf8").unwrap();
|
||||
// sqlx::query!("INSERT INTO uuids (uuid) VALUES (?)", u).execute(pool).unwrap();
|
||||
// panic!("done, go check db");
|
||||
// }
|
||||
// }
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
struct UuidWrapper {
|
||||
#[serde(serialize_with = "serialize_uuid")]
|
||||
#[serde(deserialize_with = "deserialize_uuid")]
|
||||
id: Uuid,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialize_uuid() {
|
||||
let u = UuidWrapper {
|
||||
id: Uuid::try_parse("693f84d2-4c1b-41e5-8483-cbe178324e04").unwrap()
|
||||
};
|
||||
let computed = serde_json::to_string(&u).unwrap();
|
||||
assert_eq!(
|
||||
"{\"id\":\"693f84d2-4c1b-41e5-8483-cbe178324e04\"}",
|
||||
&computed,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_uuid() {
|
||||
let s = "{\"id\":\"045bd359-8630-4b76-9b7d-e4a86ed2222c\"}";
|
||||
let computed = serde_json::from_str(s).unwrap();
|
||||
let expected = UuidWrapper {
|
||||
id: Uuid::try_parse("045bd359-8630-4b76-9b7d-e4a86ed2222c").unwrap(),
|
||||
};
|
||||
assert_eq!(expected, computed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialize_deserialize_uuid() {
|
||||
let buf = Crypto::salt();
|
||||
let expected = UuidWrapper{
|
||||
id: Uuid::from_slice(&buf[..16]).unwrap()
|
||||
};
|
||||
let serialized = serde_json::to_string(&expected).unwrap();
|
||||
let computed = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(expected, computed)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user