From bb980c5eef40e36aa407dde00472472c90f7e664 Mon Sep 17 00:00:00 2001 From: Joseph Montanaro Date: Wed, 26 Jun 2024 22:24:44 -0400 Subject: [PATCH] continue working on default credentials --- src-tauri/src/credentials/aws.rs | 107 +++------------------ src-tauri/src/credentials/mod.rs | 33 +++++-- src-tauri/src/credentials/record.rs | 25 +++-- src/views/ManageCredentials.svelte | 8 +- src/views/credentials/AwsCredential.svelte | 19 ++-- 5 files changed, 74 insertions(+), 118 deletions(-) diff --git a/src-tauri/src/credentials/aws.rs b/src-tauri/src/credentials/aws.rs index c104c80..f937b4c 100644 --- a/src-tauri/src/credentials/aws.rs +++ b/src-tauri/src/credentials/aws.rs @@ -18,14 +18,14 @@ use sqlx::{ types::Uuid, }; -use super::{Crypto, PersistentCredential}; +use super::{Credential, Crypto, PersistentCredential}; use crate::errors::*; #[derive(Debug, Clone, FromRow)] pub struct AwsRow { - pub id: Uuid, + id: Uuid, access_key_id: String, secret_key_enc: Vec, nonce: Vec, @@ -53,6 +53,10 @@ impl PersistentCredential for AwsBaseCredential { fn type_name() -> &'static str { "aws" } + fn into_credential(self) -> Credential { Credential::AwsBase(self) } + + fn row_id(row: &AwsRow) -> Uuid { row.id } + fn from_row(row: AwsRow, crypto: &Crypto) -> Result { let nonce = XNonce::clone_from_slice(&row.nonce); let secret_key_bytes = crypto.decrypt(&nonce, &row.secret_key_enc)?; @@ -79,93 +83,6 @@ impl PersistentCredential for AwsBaseCredential { Ok(()) } - - // async fn save(&self, record: CredentialRecord, &Crypto, pool: &SqlitePool) -> Result<(), CredentialRecordsError> { - // let (nonce, ciphertext) = crypto.encrypt(self.secret_access_key.as_bytes())?; - // let nonce_bytes = &nonce.as_slice(); - - // let res = sqlx::query!( - // "INSERT INTO credentials (id, name, type, created_at) - // VALUES (?, ?, 'aws', strftime('%s')) - // ON CONFLICT(id) DO UPDATE SET - // name = excluded.name, - // type = excluded.type, - // created_at = excluded.created_at; - - // INSERT OR REPLACE INTO aws_credentials ( - // id, - // access_key_id, - // secret_key_enc, - // nonce - // ) - // VALUES (?, ?, ?, ?);", - // id, - // name, - // id, // for the second query - // self.access_key_id, - // ciphertext, - // nonce_bytes, - // ).execute(pool).await; - - // match res { - // Err(SqlxError::Database(e)) if e.code().as_deref() == Some("2067") => Err(CredentialRecordsError::Duplicate), - // Err(e) => Err(SaveCredentialsError::DbError(e)), - // Ok(_) => Ok(()) - // } - // } - - // async fn load(name: &str, crypto: &Crypto, pool: &SqlitePool) -> Result { - // let record: AwsRecord = sqlx::query_as( - // "SELECT c.id, c.name, c.is_default, a.access_key_id, a.secret_key_enc, a.nonce - // FROM credentials c JOIN aws_credentials a ON a.id = c.id - // WHERE c.name = ?" - // ).bind(name) - // .fetch_optional(pool) - // .await? - // .ok_or(LoadCredentialsError::NoCredentials)?; - - // let key = record.decrypt_key(crypto)?; - // let credential = AwsBaseCredential::new(record.access_key_id, key); - // Ok(credential) - // } - - // async fn load_default(crypto: &Crypto, pool: &SqlitePool) -> Result { - // let record: AwsRecord = sqlx::query_as( - // "SELECT c.id, c.name, c.is_default, a.access_key_id, a.secret_key_enc, a.nonce - // FROM credentials c JOIN aws_credentials a ON a.id = c.id - // WHERE c.type = 'aws' AND c.is_default = 1" - // ).fetch_optional(pool) - // .await? - // .ok_or(LoadCredentialsError::NoCredentials)?; - - // let key = record.decrypt_key(crypto)?; - // let credential = AwsBaseCredential::new(record.access_key_id, key); - // Ok(credential) - // } - - // async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result, LoadCredentialsError> { - // let mut rows = sqlx::query_as::<_, AwsRecord>( - // "SELECT c.id, c.name, c.is_default, a.access_key_id, a.secret_key_enc, a.nonce - // FROM credentials c JOIN aws_credentials a ON a.id = c.id" - // ).fetch(pool); - - // let mut creds = Vec::new(); - - // while let Some(record) = rows.try_next().await? { - // let key = record.decrypt_key(crypto)?; - // let aws = AwsBaseCredential::new(record.access_key_id, key); - - // let cred = SaveCredential { - // id: record.id, - // name: record.name, - // is_default: record.is_default, - // credential: Credential::AwsBase(aws), - // }; - // creds.push(cred); - // } - - // Ok(creds) - // } } @@ -269,6 +186,7 @@ where S: Serializer mod tests { use super::*; use sqlx::SqlitePool; + use sqlx::types::uuid::uuid; fn creds() -> AwsBaseCredential { @@ -302,7 +220,8 @@ mod tests { #[sqlx::test(fixtures("aws_credentials"))] async fn test_load(pool: SqlitePool) { let crypt = Crypto::fixed(); - let loaded = AwsBaseCredential::load(&test_uuid(), &crypt, &pool).await.unwrap(); + let id = uuid!("00000000-0000-0000-0000-000000000000"); + let loaded = AwsBaseCredential::load(&id, &crypt, &pool).await.unwrap(); assert_eq!(creds(), loaded); } @@ -326,14 +245,14 @@ mod tests { #[sqlx::test(fixtures("aws_credentials"))] async fn test_list(pool: SqlitePool) { let crypt = Crypto::fixed(); - let list: Vec<_> = AwsBaseCredential::list(&pool) + let list: Vec<_> = AwsBaseCredential::list(&crypt, &pool) .await .expect("Failed to load credentials") .into_iter() - .map(|r| AwsBaseCredential::from_row(r, &crypt).unwrap()) + .map(|(_, cred)| cred) .collect(); - assert_eq!(&creds(), &list[0]); - assert_eq!(&creds_2(), &list[1]); + assert_eq!(&creds().into_credential(), &list[0]); + assert_eq!(&creds_2().into_credential(), &list[1]); } } diff --git a/src-tauri/src/credentials/mod.rs b/src-tauri/src/credentials/mod.rs index b9f0096..4fcef3a 100644 --- a/src-tauri/src/credentials/mod.rs +++ b/src-tauri/src/credentials/mod.rs @@ -1,5 +1,3 @@ -use std::fmt::Formatter; - use serde::{Serialize, Deserialize}; use sqlx::{ FromRow, @@ -9,6 +7,7 @@ use sqlx::{ Transaction, types::Uuid, }; +use tokio_stream::StreamExt; use crate::errors::*; @@ -37,7 +36,13 @@ pub trait PersistentCredential: for<'a> Deserialize<'a> + Sized { type Row: Send + Unpin + for<'r> FromRow<'r, SqliteRow>; fn type_name() -> &'static str; + + fn into_credential(self) -> Credential; + + fn row_id(row: &Self::Row) -> Uuid; + fn from_row(row: Self::Row, crypto: &Crypto) -> Result; + // save_details needs to be implemented per-type because we don't know the number of parameters in advance async fn save_details(&self, id: &Uuid, crypto: &Crypto, txn: &mut Transaction<'_, Sqlite>) -> Result<(), SaveCredentialsError>; @@ -87,9 +92,25 @@ pub trait PersistentCredential: for<'a> Deserialize<'a> + Sized { Self::from_row(row, crypto) } - async fn list(pool: &SqlitePool) -> Result, LoadCredentialsError> { - let q = format!("SELECT * FROM {}", Self::table_name()); - let rows: Vec = sqlx::query_as(&q).fetch_all(pool).await?; - Ok(rows) + async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result, LoadCredentialsError> { + let q = format!( + "SELECT details.* + FROM + {} details + JOIN credentials c + ON c.id = details.id + ORDER BY c.created_at", + Self::table_name(), + ); + let mut rows = sqlx::query_as::<_, Self::Row>(&q).fetch(pool); + + let mut creds = Vec::new(); + while let Some(row) = rows.try_next().await? { + let id = Self::row_id(&row); + let cred = Self::from_row(row, crypto)?.into_credential(); + creds.push((id, cred)); + } + + Ok(creds) } } diff --git a/src-tauri/src/credentials/record.rs b/src-tauri/src/credentials/record.rs index 23c7e2c..b3f42dd 100644 --- a/src-tauri/src/credentials/record.rs +++ b/src-tauri/src/credentials/record.rs @@ -18,7 +18,6 @@ use tokio_stream::StreamExt; use crate::errors::*; use super::{ AwsBaseCredential, - aws::AwsRow, Credential, Crypto, PersistentCredential, @@ -100,7 +99,7 @@ impl CredentialRecord { } } - async fn load_details(row: CredentialRow, crypto: &Crypto, pool: &SqlitePool) -> Result { + async fn load_credential(row: CredentialRow, crypto: &Crypto, pool: &SqlitePool) -> Result { let credential = match row.credential_type.as_str() { "aws" => Credential::AwsBase(AwsBaseCredential::load(&row.id, crypto, pool).await?), _ => return Err(LoadCredentialsError::InvalidData), @@ -116,7 +115,7 @@ impl CredentialRecord { .await? .ok_or(LoadCredentialsError::NoCredentials)?; - Self::load_details(row, crypto, pool).await + Self::load_credential(row, crypto, pool).await } pub async fn load_default(credential_type: &str, crypto: &Crypto, pool: &SqlitePool) -> Result { @@ -128,7 +127,7 @@ impl CredentialRecord { .await? .ok_or(LoadCredentialsError::NoCredentials)?; - Self::load_details(row, crypto, pool).await + Self::load_credential(row, crypto, pool).await } pub async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result, LoadCredentialsError> { @@ -143,10 +142,9 @@ impl CredentialRecord { let mut records = Vec::with_capacity(parent_map.len()); - for row in AwsBaseCredential::list(&pool).await? { - let parent = parent_map.remove(&row.id) + for (id, credential) in AwsBaseCredential::list(crypto, pool).await? { + let parent = parent_map.remove(&id) .ok_or(LoadCredentialsError::InvalidData)?; - let credential = Credential::AwsBase(AwsBaseCredential::from_row(row, crypto)?); records.push(Self::from_parts(parent, credential)); } @@ -274,6 +272,7 @@ mod tests { } + #[sqlx::test] async fn test_overwrite_aws(pool: SqlitePool) { let crypt = Crypto::fixed(); @@ -329,6 +328,18 @@ mod tests { .expect("Failed to load other credential"); assert!(!other_loaded.is_default); } + + + #[sqlx::test(fixtures("aws_credentials"))] + async fn test_list(pool: SqlitePool) { + let crypt = Crypto::fixed(); + + let records = CredentialRecord::list(&crypt, &pool).await + .expect("Failed to list credentials"); + + assert_eq!(aws_record(), records[0]); + assert_eq!(aws_record_2(), records[1]); + } } diff --git a/src/views/ManageCredentials.svelte b/src/views/ManageCredentials.svelte index 8c1b247..fb2c232 100644 --- a/src/views/ManageCredentials.svelte +++ b/src/views/ManageCredentials.svelte @@ -1,6 +1,7 @@