continue working on default credentials

This commit is contained in:
2024-06-26 22:24:44 -04:00
parent ce7d75f15a
commit bb980c5eef
5 changed files with 74 additions and 118 deletions

View File

@ -18,14 +18,14 @@ use sqlx::{
types::Uuid,
};
use super::{Crypto, PersistentCredential};
use super::{Credential, Crypto, PersistentCredential};
use crate::errors::*;
#[derive(Debug, Clone, FromRow)]
pub struct AwsRow {
pub id: Uuid,
id: Uuid,
access_key_id: String,
secret_key_enc: Vec<u8>,
nonce: Vec<u8>,
@ -53,6 +53,10 @@ impl PersistentCredential for AwsBaseCredential {
fn type_name() -> &'static str { "aws" }
fn into_credential(self) -> Credential { Credential::AwsBase(self) }
fn row_id(row: &AwsRow) -> Uuid { row.id }
fn from_row(row: AwsRow, crypto: &Crypto) -> Result<Self, LoadCredentialsError> {
let nonce = XNonce::clone_from_slice(&row.nonce);
let secret_key_bytes = crypto.decrypt(&nonce, &row.secret_key_enc)?;
@ -79,93 +83,6 @@ impl PersistentCredential for AwsBaseCredential {
Ok(())
}
// async fn save(&self, record: CredentialRecord, &Crypto, pool: &SqlitePool) -> Result<(), CredentialRecordsError> {
// let (nonce, ciphertext) = crypto.encrypt(self.secret_access_key.as_bytes())?;
// let nonce_bytes = &nonce.as_slice();
// let res = sqlx::query!(
// "INSERT INTO credentials (id, name, type, created_at)
// VALUES (?, ?, 'aws', strftime('%s'))
// ON CONFLICT(id) DO UPDATE SET
// name = excluded.name,
// type = excluded.type,
// created_at = excluded.created_at;
// INSERT OR REPLACE INTO aws_credentials (
// id,
// access_key_id,
// secret_key_enc,
// nonce
// )
// VALUES (?, ?, ?, ?);",
// id,
// name,
// id, // for the second query
// self.access_key_id,
// ciphertext,
// nonce_bytes,
// ).execute(pool).await;
// match res {
// Err(SqlxError::Database(e)) if e.code().as_deref() == Some("2067") => Err(CredentialRecordsError::Duplicate),
// Err(e) => Err(SaveCredentialsError::DbError(e)),
// Ok(_) => Ok(())
// }
// }
// async fn load(name: &str, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
// let record: AwsRecord = sqlx::query_as(
// "SELECT c.id, c.name, c.is_default, a.access_key_id, a.secret_key_enc, a.nonce
// FROM credentials c JOIN aws_credentials a ON a.id = c.id
// WHERE c.name = ?"
// ).bind(name)
// .fetch_optional(pool)
// .await?
// .ok_or(LoadCredentialsError::NoCredentials)?;
// let key = record.decrypt_key(crypto)?;
// let credential = AwsBaseCredential::new(record.access_key_id, key);
// Ok(credential)
// }
// async fn load_default(crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
// let record: AwsRecord = sqlx::query_as(
// "SELECT c.id, c.name, c.is_default, a.access_key_id, a.secret_key_enc, a.nonce
// FROM credentials c JOIN aws_credentials a ON a.id = c.id
// WHERE c.type = 'aws' AND c.is_default = 1"
// ).fetch_optional(pool)
// .await?
// .ok_or(LoadCredentialsError::NoCredentials)?;
// let key = record.decrypt_key(crypto)?;
// let credential = AwsBaseCredential::new(record.access_key_id, key);
// Ok(credential)
// }
// async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result<Vec<SaveCredential>, LoadCredentialsError> {
// let mut rows = sqlx::query_as::<_, AwsRecord>(
// "SELECT c.id, c.name, c.is_default, a.access_key_id, a.secret_key_enc, a.nonce
// FROM credentials c JOIN aws_credentials a ON a.id = c.id"
// ).fetch(pool);
// let mut creds = Vec::new();
// while let Some(record) = rows.try_next().await? {
// let key = record.decrypt_key(crypto)?;
// let aws = AwsBaseCredential::new(record.access_key_id, key);
// let cred = SaveCredential {
// id: record.id,
// name: record.name,
// is_default: record.is_default,
// credential: Credential::AwsBase(aws),
// };
// creds.push(cred);
// }
// Ok(creds)
// }
}
@ -269,6 +186,7 @@ where S: Serializer
mod tests {
use super::*;
use sqlx::SqlitePool;
use sqlx::types::uuid::uuid;
fn creds() -> AwsBaseCredential {
@ -302,7 +220,8 @@ mod tests {
#[sqlx::test(fixtures("aws_credentials"))]
async fn test_load(pool: SqlitePool) {
let crypt = Crypto::fixed();
let loaded = AwsBaseCredential::load(&test_uuid(), &crypt, &pool).await.unwrap();
let id = uuid!("00000000-0000-0000-0000-000000000000");
let loaded = AwsBaseCredential::load(&id, &crypt, &pool).await.unwrap();
assert_eq!(creds(), loaded);
}
@ -326,14 +245,14 @@ mod tests {
#[sqlx::test(fixtures("aws_credentials"))]
async fn test_list(pool: SqlitePool) {
let crypt = Crypto::fixed();
let list: Vec<_> = AwsBaseCredential::list(&pool)
let list: Vec<_> = AwsBaseCredential::list(&crypt, &pool)
.await
.expect("Failed to load credentials")
.into_iter()
.map(|r| AwsBaseCredential::from_row(r, &crypt).unwrap())
.map(|(_, cred)| cred)
.collect();
assert_eq!(&creds(), &list[0]);
assert_eq!(&creds_2(), &list[1]);
assert_eq!(&creds().into_credential(), &list[0]);
assert_eq!(&creds_2().into_credential(), &list[1]);
}
}

View File

@ -1,5 +1,3 @@
use std::fmt::Formatter;
use serde::{Serialize, Deserialize};
use sqlx::{
FromRow,
@ -9,6 +7,7 @@ use sqlx::{
Transaction,
types::Uuid,
};
use tokio_stream::StreamExt;
use crate::errors::*;
@ -37,7 +36,13 @@ pub trait PersistentCredential: for<'a> Deserialize<'a> + Sized {
type Row: Send + Unpin + for<'r> FromRow<'r, SqliteRow>;
fn type_name() -> &'static str;
fn into_credential(self) -> Credential;
fn row_id(row: &Self::Row) -> Uuid;
fn from_row(row: Self::Row, crypto: &Crypto) -> Result<Self, LoadCredentialsError>;
// save_details needs to be implemented per-type because we don't know the number of parameters in advance
async fn save_details(&self, id: &Uuid, crypto: &Crypto, txn: &mut Transaction<'_, Sqlite>) -> Result<(), SaveCredentialsError>;
@ -87,9 +92,25 @@ pub trait PersistentCredential: for<'a> Deserialize<'a> + Sized {
Self::from_row(row, crypto)
}
async fn list(pool: &SqlitePool) -> Result<Vec<Self::Row>, LoadCredentialsError> {
let q = format!("SELECT * FROM {}", Self::table_name());
let rows: Vec<Self::Row> = sqlx::query_as(&q).fetch_all(pool).await?;
Ok(rows)
async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result<Vec<(Uuid, Credential)>, LoadCredentialsError> {
let q = format!(
"SELECT details.*
FROM
{} details
JOIN credentials c
ON c.id = details.id
ORDER BY c.created_at",
Self::table_name(),
);
let mut rows = sqlx::query_as::<_, Self::Row>(&q).fetch(pool);
let mut creds = Vec::new();
while let Some(row) = rows.try_next().await? {
let id = Self::row_id(&row);
let cred = Self::from_row(row, crypto)?.into_credential();
creds.push((id, cred));
}
Ok(creds)
}
}

View File

@ -18,7 +18,6 @@ use tokio_stream::StreamExt;
use crate::errors::*;
use super::{
AwsBaseCredential,
aws::AwsRow,
Credential,
Crypto,
PersistentCredential,
@ -100,7 +99,7 @@ impl CredentialRecord {
}
}
async fn load_details(row: CredentialRow, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
async fn load_credential(row: CredentialRow, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
let credential = match row.credential_type.as_str() {
"aws" => Credential::AwsBase(AwsBaseCredential::load(&row.id, crypto, pool).await?),
_ => return Err(LoadCredentialsError::InvalidData),
@ -116,7 +115,7 @@ impl CredentialRecord {
.await?
.ok_or(LoadCredentialsError::NoCredentials)?;
Self::load_details(row, crypto, pool).await
Self::load_credential(row, crypto, pool).await
}
pub async fn load_default(credential_type: &str, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
@ -128,7 +127,7 @@ impl CredentialRecord {
.await?
.ok_or(LoadCredentialsError::NoCredentials)?;
Self::load_details(row, crypto, pool).await
Self::load_credential(row, crypto, pool).await
}
pub async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result<Vec<Self>, LoadCredentialsError> {
@ -143,10 +142,9 @@ impl CredentialRecord {
let mut records = Vec::with_capacity(parent_map.len());
for row in AwsBaseCredential::list(&pool).await? {
let parent = parent_map.remove(&row.id)
for (id, credential) in AwsBaseCredential::list(crypto, pool).await? {
let parent = parent_map.remove(&id)
.ok_or(LoadCredentialsError::InvalidData)?;
let credential = Credential::AwsBase(AwsBaseCredential::from_row(row, crypto)?);
records.push(Self::from_parts(parent, credential));
}
@ -274,6 +272,7 @@ mod tests {
}
#[sqlx::test]
async fn test_overwrite_aws(pool: SqlitePool) {
let crypt = Crypto::fixed();
@ -329,6 +328,18 @@ mod tests {
.expect("Failed to load other credential");
assert!(!other_loaded.is_default);
}
#[sqlx::test(fixtures("aws_credentials"))]
async fn test_list(pool: SqlitePool) {
let crypt = Crypto::fixed();
let records = CredentialRecord::list(&crypt, &pool).await
.expect("Failed to list credentials");
assert_eq!(aws_record(), records[0]);
assert_eq!(aws_record_2(), records[1]);
}
}