use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use serde::{ Serialize, Deserialize, Serializer, Deserializer, }; use serde::de::{self, Visitor}; use sqlx::{ Error as SqlxError, FromRow, SqlitePool, types::Uuid, }; use tokio_stream::StreamExt; use crate::errors::*; use super::{ AwsBaseCredential, Credential, Crypto, PersistentCredential, SshKey, }; #[derive(Debug, Clone, FromRow)] #[allow(dead_code)] struct CredentialRow { id: Uuid, name: String, credential_type: String, is_default: bool, created_at: i64, } #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] pub struct CredentialRecord { #[serde(serialize_with = "serialize_uuid")] #[serde(deserialize_with = "deserialize_uuid")] pub id: Uuid, // UUID so it can be generated on the frontend pub name: String, // user-facing identifier so it can be changed pub is_default: bool, pub credential: Credential, } impl CredentialRecord { pub async fn save(&self, crypto: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> { let type_name = match &self.credential { Credential::AwsBase(_) => AwsBaseCredential::type_name(), Credential::Ssh(_) => SshKey::type_name(), _ => return Err(SaveCredentialsError::NotPersistent), }; // if the credential being saved is default, make sure it's the only default of its type let mut txn = pool.begin().await?; if self.is_default { sqlx::query!( "UPDATE credentials SET is_default = 0 WHERE credential_type = ?", type_name ).execute(&mut *txn).await?; } // save to parent credentials table let res = sqlx::query!( "INSERT INTO credentials (id, name, credential_type, is_default, created_at) VALUES (?, ?, ?, ?, strftime('%s')) ON CONFLICT(id) DO UPDATE SET name = excluded.name, credential_type = excluded.credential_type, is_default = excluded.is_default", self.id, self.name, type_name, self.is_default ).execute(&mut *txn).await; // if id is unique, but name is not, we will get an error // (if id is not unique, this becomes an upsert due to ON CONFLICT clause) match res { Err(SqlxError::Database(e)) if e.is_unique_violation() => Err(SaveCredentialsError::Duplicate), Err(e) => Err(SaveCredentialsError::DbError(e)), Ok(_) => Ok(()) }?; // save credential details to child table match &self.credential { Credential::AwsBase(b) => b.save_details(&self.id, crypto, &mut txn).await, Credential::Ssh(s) => s.save_details(&self.id, crypto, &mut txn).await, _ => Err(SaveCredentialsError::NotPersistent), }?; // make it real txn.commit().await?; Ok(()) } fn from_parts(row: CredentialRow, credential: Credential) -> Self { CredentialRecord { id: row.id, name: row.name, is_default: row.is_default, credential, } } async fn load_credential(row: CredentialRow, crypto: &Crypto, pool: &SqlitePool) -> Result { let credential = match row.credential_type.as_str() { "aws" => Credential::AwsBase(AwsBaseCredential::load(&row.id, crypto, pool).await?), _ => return Err(LoadCredentialsError::InvalidData), }; Ok(Self::from_parts(row, credential)) } #[cfg(test)] pub async fn load(id: &Uuid, crypto: &Crypto, pool: &SqlitePool) -> Result { let row: CredentialRow = sqlx::query_as("SELECT * FROM credentials WHERE id = ?") .bind(id) .fetch_optional(pool) .await? .ok_or(LoadCredentialsError::NoCredentials)?; Self::load_credential(row, crypto, pool).await } pub async fn load_by_name(name: &str, crypto: &Crypto, pool: &SqlitePool) -> Result { let row: CredentialRow = sqlx::query_as("SELECT * FROM credentials WHERE name = ?") .bind(name) .fetch_optional(pool) .await? .ok_or(LoadCredentialsError::NoCredentials)?; Self::load_credential(row, crypto, pool).await } pub async fn load_default(credential_type: &str, crypto: &Crypto, pool: &SqlitePool) -> Result { let row: CredentialRow = sqlx::query_as( "SELECT * FROM credentials WHERE credential_type = ? AND is_default = 1" ).bind(credential_type) .fetch_optional(pool) .await? .ok_or(LoadCredentialsError::NoCredentials)?; Self::load_credential(row, crypto, pool).await } pub async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result, LoadCredentialsError> { let mut parent_rows = sqlx::query_as::<_, CredentialRow>( "SELECT * FROM credentials" ).fetch(pool); let mut parent_map = HashMap::new(); while let Some(row) = parent_rows.try_next().await? { parent_map.insert(row.id, row); } let mut records = Vec::with_capacity(parent_map.len()); for (id, credential) in AwsBaseCredential::list(crypto, pool).await? { let parent = parent_map.remove(&id) .ok_or(LoadCredentialsError::InvalidData)?; records.push(Self::from_parts(parent, credential)); } for (id, credential) in SshKey::list(crypto, pool).await? { let parent = parent_map.remove(&id) .ok_or(LoadCredentialsError::InvalidData)?; records.push(Self::from_parts(parent, credential)); } Ok(records) } pub async fn rekey(old: &Crypto, new: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> { for record in Self::list(old, pool).await? { record.save(new, pool).await?; } Ok(()) } } fn serialize_uuid(u: &Uuid, s: S) -> Result { let mut buf = Uuid::encode_buffer(); s.serialize_str(u.as_hyphenated().encode_lower(&mut buf)) } struct UuidVisitor; impl<'de> Visitor<'de> for UuidVisitor { type Value = Uuid; fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { write!(formatter, "a hyphenated UUID") } fn visit_str(self, v: &str) -> Result { Uuid::try_parse(v) .map_err(|_| E::custom(format!("Could not interpret string as UUID: {v}"))) } } fn deserialize_uuid<'de, D: Deserializer<'de>>(ds: D) -> Result { ds.deserialize_str(UuidVisitor) } #[cfg(test)] mod tests { use super::*; use sqlx::types::uuid::uuid; fn aws_record() -> CredentialRecord { let id = uuid!("00000000-0000-0000-0000-000000000000"); let aws = AwsBaseCredential::new( "AKIAIOSFODNN7EXAMPLE".into(), "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".into(), ); CredentialRecord { id, name: "test".into(), is_default: true, credential: Credential::AwsBase(aws), } } fn aws_record_2() -> CredentialRecord { let id = uuid!("ffffffff-ffff-ffff-ffff-ffffffffffff"); let aws = AwsBaseCredential::new( "AKIAIOSFODNN7EXAMPL2".into(), "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKE2".into(), ); CredentialRecord { id, name: "test2".into(), is_default: false, credential: Credential::AwsBase(aws), } } fn random_uuid() -> Uuid { let bytes = Crypto::salt(); Uuid::from_slice(&bytes[..16]).unwrap() } #[sqlx::test(fixtures("aws_credentials"))] async fn test_load_aws(pool: SqlitePool) { let crypt = Crypto::fixed(); let id = uuid!("00000000-0000-0000-0000-000000000000"); let loaded = CredentialRecord::load(&id, &crypt, &pool).await .expect("Failed to load record"); assert_eq!(aws_record(), loaded); } #[sqlx::test(fixtures("aws_credentials"))] async fn test_load_aws_default(pool: SqlitePool) { let crypt = Crypto::fixed(); let loaded = CredentialRecord::load_default("aws", &crypt, &pool).await .expect("Failed to load record"); assert_eq!(aws_record(), loaded); } #[sqlx::test] async fn test_save_aws(pool: SqlitePool) { let crypt = Crypto::random(); let mut record = aws_record(); record.id = random_uuid(); aws_record().save(&crypt, &pool).await .expect("Failed to save record"); } #[sqlx::test] async fn test_save_load_aws(pool: SqlitePool) { let crypt = Crypto::random(); let mut record = aws_record(); record.id = random_uuid(); record.save(&crypt, &pool).await .expect("Failed to save record"); let loaded = CredentialRecord::load(&record.id, &crypt, &pool).await .expect("Failed to load record"); assert_eq!(record, loaded); } #[sqlx::test] async fn test_overwrite_aws(pool: SqlitePool) { let crypt = Crypto::fixed(); let original = aws_record(); original.save(&crypt, &pool).await .expect("Failed to save first record"); let mut updated = aws_record_2(); updated.id = original.id; updated.save(&crypt, &pool).await .expect("Failed to overwrite first record with second record"); // make sure update went through let loaded = CredentialRecord::load(&updated.id, &crypt, &pool).await.unwrap(); assert_eq!(updated, loaded); } #[sqlx::test(fixtures("aws_credentials"))] async fn test_duplicate_name(pool: SqlitePool) { let crypt = Crypto::random(); let mut record = aws_record(); record.id = random_uuid(); let resp = record.save(&crypt, &pool).await; if !matches!(resp, Err(SaveCredentialsError::Duplicate)) { panic!("Attempt to create duplicate entry returned {resp:?}") } } #[sqlx::test(fixtures("aws_credentials"))] async fn test_change_default(pool: SqlitePool) { let crypt = Crypto::fixed(); let id = uuid!("ffffffff-ffff-ffff-ffff-ffffffffffff"); // confirm that record as it currently exists in the database is not default let mut record = CredentialRecord::load(&id, &crypt, &pool).await .expect("Failed to load record"); assert!(!record.is_default); record.is_default = true; record.save(&crypt, &pool).await .expect("Failed to save record"); let loaded = CredentialRecord::load(&id, &crypt, &pool).await .expect("Failed to re-load record"); assert!(loaded.is_default); let other_id = uuid!("00000000-0000-0000-0000-000000000000"); let other_loaded = CredentialRecord::load(&other_id, &crypt, &pool).await .expect("Failed to load other credential"); assert!(!other_loaded.is_default); } #[sqlx::test(fixtures("aws_credentials"))] async fn test_list(pool: SqlitePool) { let crypt = Crypto::fixed(); let records = CredentialRecord::list(&crypt, &pool).await .expect("Failed to list credentials"); assert_eq!(aws_record(), records[0]); assert_eq!(aws_record_2(), records[1]); } #[sqlx::test(fixtures("aws_credentials"))] async fn test_rekey(pool: SqlitePool) { let old = Crypto::fixed(); let new = Crypto::random(); CredentialRecord::rekey(&old, &new, &pool).await .expect("Failed to rekey credentials"); let records = CredentialRecord::list(&new, &pool).await .expect("Failed to re-list credentials"); assert_eq!(aws_record(), records[0]); assert_eq!(aws_record_2(), records[1]); } } #[cfg(test)] mod uuid_tests { use super::*; use sqlx::types::uuid::uuid; #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] struct UuidWrapper { #[serde(serialize_with = "serialize_uuid")] #[serde(deserialize_with = "deserialize_uuid")] id: Uuid, } #[test] fn test_serialize_uuid() { let u = UuidWrapper { id: uuid!("693f84d2-4c1b-41e5-8483-cbe178324e04") }; let computed = serde_json::to_string(&u).unwrap(); assert_eq!( "{\"id\":\"693f84d2-4c1b-41e5-8483-cbe178324e04\"}", &computed, ); } #[test] fn test_deserialize_uuid() { let s = "{\"id\":\"045bd359-8630-4b76-9b7d-e4a86ed2222c\"}"; let computed = serde_json::from_str(s).unwrap(); let expected = UuidWrapper { id: uuid!("045bd359-8630-4b76-9b7d-e4a86ed2222c"), }; assert_eq!(expected, computed); } #[test] fn test_serialize_deserialize_uuid() { let buf = Crypto::salt(); let expected = UuidWrapper{ id: Uuid::from_slice(&buf[..16]).unwrap() }; let serialized = serde_json::to_string(&expected).unwrap(); let computed = serde_json::from_str(&serialized).unwrap(); assert_eq!(expected, computed) } }