411 lines
12 KiB
Rust
411 lines
12 KiB
Rust
use std::collections::HashMap;
|
|
use std::fmt::{self, Debug, Formatter};
|
|
use serde::{
|
|
Serialize,
|
|
Deserialize,
|
|
Serializer,
|
|
Deserializer,
|
|
};
|
|
use serde::de::{self, Visitor};
|
|
use sqlx::{
|
|
Error as SqlxError,
|
|
FromRow,
|
|
SqlitePool,
|
|
types::Uuid,
|
|
};
|
|
use tokio_stream::StreamExt;
|
|
|
|
use crate::errors::*;
|
|
use super::{
|
|
AwsBaseCredential,
|
|
Credential,
|
|
Crypto,
|
|
PersistentCredential,
|
|
};
|
|
|
|
|
|
#[derive(Debug, Clone, FromRow)]
|
|
struct CredentialRow {
|
|
id: Uuid,
|
|
name: String,
|
|
credential_type: String,
|
|
is_default: bool,
|
|
created_at: i64,
|
|
}
|
|
|
|
|
|
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
|
|
pub struct CredentialRecord {
|
|
#[serde(serialize_with = "serialize_uuid")]
|
|
#[serde(deserialize_with = "deserialize_uuid")]
|
|
pub id: Uuid, // UUID so it can be generated on the frontend
|
|
pub name: String, // user-facing identifier so it can be changed
|
|
pub is_default: bool,
|
|
pub credential: Credential,
|
|
}
|
|
|
|
impl CredentialRecord {
|
|
pub async fn save(&self, crypto: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
|
|
let type_name = match &self.credential {
|
|
Credential::AwsBase(_) => AwsBaseCredential::type_name(),
|
|
_ => return Err(SaveCredentialsError::NotPersistent),
|
|
};
|
|
|
|
// if the credential being saved is default, make sure it's the only default of its type
|
|
let mut txn = pool.begin().await?;
|
|
if self.is_default {
|
|
sqlx::query!(
|
|
"UPDATE credentials SET is_default = 0 WHERE credential_type = ?",
|
|
type_name
|
|
).execute(&mut *txn).await?;
|
|
}
|
|
|
|
// save to parent credentials table
|
|
let res = sqlx::query!(
|
|
"INSERT INTO credentials (id, name, credential_type, is_default, created_at)
|
|
VALUES (?, ?, ?, ?, strftime('%s'))
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
name = excluded.name,
|
|
credential_type = excluded.credential_type,
|
|
is_default = excluded.is_default",
|
|
self.id, self.name, type_name, self.is_default
|
|
).execute(&mut *txn).await;
|
|
|
|
// if id is unique, but name is not, we will get an error
|
|
// (if id is not unique, this becomes an upsert due to ON CONFLICT clause)
|
|
match res {
|
|
Err(SqlxError::Database(e)) if e.is_unique_violation() => Err(SaveCredentialsError::Duplicate),
|
|
Err(e) => Err(SaveCredentialsError::DbError(e)),
|
|
Ok(_) => Ok(())
|
|
}?;
|
|
|
|
// save credential details to child table
|
|
match &self.credential {
|
|
Credential::AwsBase(b) => b.save_details(&self.id, crypto, &mut txn).await,
|
|
_ => Err(SaveCredentialsError::NotPersistent),
|
|
}?;
|
|
|
|
// make it real
|
|
txn.commit().await?;
|
|
Ok(())
|
|
}
|
|
|
|
fn from_parts(row: CredentialRow, credential: Credential) -> Self {
|
|
CredentialRecord {
|
|
id: row.id,
|
|
name: row.name,
|
|
is_default: row.is_default,
|
|
credential,
|
|
}
|
|
}
|
|
|
|
async fn load_credential(row: CredentialRow, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
|
|
let credential = match row.credential_type.as_str() {
|
|
"aws" => Credential::AwsBase(AwsBaseCredential::load(&row.id, crypto, pool).await?),
|
|
_ => return Err(LoadCredentialsError::InvalidData),
|
|
};
|
|
|
|
Ok(Self::from_parts(row, credential))
|
|
}
|
|
|
|
pub async fn load(id: &Uuid, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
|
|
let row: CredentialRow = sqlx::query_as("SELECT * FROM credentials WHERE id = ?")
|
|
.bind(id)
|
|
.fetch_optional(pool)
|
|
.await?
|
|
.ok_or(LoadCredentialsError::NoCredentials)?;
|
|
|
|
Self::load_credential(row, crypto, pool).await
|
|
}
|
|
|
|
pub async fn load_default(credential_type: &str, crypto: &Crypto, pool: &SqlitePool) -> Result<Self, LoadCredentialsError> {
|
|
let row: CredentialRow = sqlx::query_as(
|
|
"SELECT * FROM credentials
|
|
WHERE credential_type = ? AND is_default = 1"
|
|
).bind(credential_type)
|
|
.fetch_optional(pool)
|
|
.await?
|
|
.ok_or(LoadCredentialsError::NoCredentials)?;
|
|
|
|
Self::load_credential(row, crypto, pool).await
|
|
}
|
|
|
|
pub async fn list(crypto: &Crypto, pool: &SqlitePool) -> Result<Vec<Self>, LoadCredentialsError> {
|
|
let mut parent_rows = sqlx::query_as::<_, CredentialRow>(
|
|
"SELECT * FROM credentials"
|
|
).fetch(pool);
|
|
|
|
let mut parent_map = HashMap::new();
|
|
while let Some(row) = parent_rows.try_next().await? {
|
|
parent_map.insert(row.id, row);
|
|
}
|
|
|
|
let mut records = Vec::with_capacity(parent_map.len());
|
|
|
|
for (id, credential) in AwsBaseCredential::list(crypto, pool).await? {
|
|
let parent = parent_map.remove(&id)
|
|
.ok_or(LoadCredentialsError::InvalidData)?;
|
|
records.push(Self::from_parts(parent, credential));
|
|
}
|
|
|
|
Ok(records)
|
|
}
|
|
|
|
pub async fn rekey(old: &Crypto, new: &Crypto, pool: &SqlitePool) -> Result<(), SaveCredentialsError> {
|
|
for record in Self::list(old, pool).await? {
|
|
record.save(new, pool).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
|
|
fn serialize_uuid<S: Serializer>(u: &Uuid, s: S) -> Result<S::Ok, S::Error> {
|
|
let mut buf = Uuid::encode_buffer();
|
|
s.serialize_str(u.as_hyphenated().encode_lower(&mut buf))
|
|
}
|
|
|
|
struct UuidVisitor;
|
|
|
|
impl<'de> Visitor<'de> for UuidVisitor {
|
|
type Value = Uuid;
|
|
|
|
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
|
|
write!(formatter, "a hyphenated UUID")
|
|
}
|
|
|
|
fn visit_str<E: de::Error>(self, v: &str) -> Result<Uuid, E> {
|
|
Uuid::try_parse(v)
|
|
.map_err(|_| E::custom(format!("Could not interpret string as UUID: {v}")))
|
|
}
|
|
}
|
|
|
|
fn deserialize_uuid<'de, D: Deserializer<'de>>(ds: D) -> Result<Uuid, D::Error> {
|
|
ds.deserialize_str(UuidVisitor)
|
|
}
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::types::uuid::uuid;

    /// The default AWS credential that the `aws_credentials` fixture stores.
    fn aws_record() -> CredentialRecord {
        let id = uuid!("00000000-0000-0000-0000-000000000000");
        let aws = AwsBaseCredential::new(
            "AKIAIOSFODNN7EXAMPLE".into(),
            "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".into(),
        );

        CredentialRecord {
            id,
            name: "test".into(),
            is_default: true,
            credential: Credential::AwsBase(aws),
        }
    }

    /// The non-default AWS credential that the `aws_credentials` fixture stores.
    fn aws_record_2() -> CredentialRecord {
        let id = uuid!("ffffffff-ffff-ffff-ffff-ffffffffffff");
        let aws = AwsBaseCredential::new(
            "AKIAIOSFODNN7EXAMPL2".into(),
            "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKE2".into(),
        );

        CredentialRecord {
            id,
            name: "test2".into(),
            is_default: false,
            credential: Credential::AwsBase(aws),
        }
    }

    /// Builds a random UUID from the first 16 bytes of a fresh salt.
    fn random_uuid() -> Uuid {
        let bytes = Crypto::salt();
        Uuid::from_slice(&bytes[..16]).unwrap()
    }

    /// Returns the record with `id`, panicking if `list` did not return it.
    ///
    /// `CredentialRecord::list` makes no ordering guarantee, so tests look records up by
    /// id instead of by index.
    fn find_by_id(records: &[CredentialRecord], id: Uuid) -> &CredentialRecord {
        records
            .iter()
            .find(|r| r.id == id)
            .unwrap_or_else(|| panic!("credential {id} missing from list"))
    }

    #[sqlx::test(fixtures("aws_credentials"))]
    async fn test_load_aws(pool: SqlitePool) {
        let crypt = Crypto::fixed();
        let id = uuid!("00000000-0000-0000-0000-000000000000");
        let loaded = CredentialRecord::load(&id, &crypt, &pool).await
            .expect("Failed to load record");

        assert_eq!(aws_record(), loaded);
    }

    #[sqlx::test(fixtures("aws_credentials"))]
    async fn test_load_aws_default(pool: SqlitePool) {
        let crypt = Crypto::fixed();
        let loaded = CredentialRecord::load_default("aws", &crypt, &pool).await
            .expect("Failed to load record");

        assert_eq!(aws_record(), loaded);
    }

    #[sqlx::test]
    async fn test_save_aws(pool: SqlitePool) {
        let crypt = Crypto::random();
        let mut record = aws_record();
        record.id = random_uuid();

        // save the record built above (previously a fresh aws_record() was saved instead,
        // leaving `record` unused)
        record.save(&crypt, &pool).await
            .expect("Failed to save record");
    }

    #[sqlx::test]
    async fn test_save_load(pool: SqlitePool) {
        let crypt = Crypto::random();
        let mut record = aws_record();
        record.id = random_uuid();

        record.save(&crypt, &pool).await
            .expect("Failed to save record");
        let loaded = CredentialRecord::load(&record.id, &crypt, &pool).await
            .expect("Failed to load record");

        assert_eq!(record, loaded);
    }

    #[sqlx::test]
    async fn test_overwrite_aws(pool: SqlitePool) {
        let crypt = Crypto::fixed();

        let original = aws_record();
        original.save(&crypt, &pool).await
            .expect("Failed to save first record");

        let mut updated = aws_record_2();
        updated.id = original.id;
        updated.save(&crypt, &pool).await
            .expect("Failed to overwrite first record with second record");

        // make sure update went through
        let loaded = CredentialRecord::load(&updated.id, &crypt, &pool).await.unwrap();
        assert_eq!(updated, loaded);
    }

    #[sqlx::test(fixtures("aws_credentials"))]
    async fn test_duplicate_name(pool: SqlitePool) {
        let crypt = Crypto::random();

        let mut record = aws_record();
        record.id = random_uuid();
        let resp = record.save(&crypt, &pool).await;

        if !matches!(resp, Err(SaveCredentialsError::Duplicate)) {
            panic!("Attempt to create duplicate entry returned {resp:?}")
        }
    }

    #[sqlx::test(fixtures("aws_credentials"))]
    async fn test_change_default(pool: SqlitePool) {
        let crypt = Crypto::fixed();
        let id = uuid!("ffffffff-ffff-ffff-ffff-ffffffffffff");

        // confirm that record as it currently exists in the database is not default
        let mut record = CredentialRecord::load(&id, &crypt, &pool).await
            .expect("Failed to load record");
        assert!(!record.is_default);

        record.is_default = true;
        record.save(&crypt, &pool).await
            .expect("Failed to save record");

        let loaded = CredentialRecord::load(&id, &crypt, &pool).await
            .expect("Failed to re-load record");
        assert!(loaded.is_default);

        let other_id = uuid!("00000000-0000-0000-0000-000000000000");
        let other_loaded = CredentialRecord::load(&other_id, &crypt, &pool).await
            .expect("Failed to load other credential");
        assert!(!other_loaded.is_default);
    }

    #[sqlx::test(fixtures("aws_credentials"))]
    async fn test_list(pool: SqlitePool) {
        let crypt = Crypto::fixed();

        let records = CredentialRecord::list(&crypt, &pool).await
            .expect("Failed to list credentials");

        for expected in [aws_record(), aws_record_2()] {
            assert_eq!(&expected, find_by_id(&records, expected.id));
        }
    }

    #[sqlx::test(fixtures("aws_credentials"))]
    async fn test_rekey(pool: SqlitePool) {
        let old = Crypto::fixed();
        let new = Crypto::random();

        CredentialRecord::rekey(&old, &new, &pool).await
            .expect("Failed to rekey credentials");

        let records = CredentialRecord::list(&new, &pool).await
            .expect("Failed to re-list credentials");

        for expected in [aws_record(), aws_record_2()] {
            assert_eq!(&expected, find_by_id(&records, expected.id));
        }
    }
}
|
|
|
|
|
|
#[cfg(test)]
mod uuid_tests {
    use super::*;
    use sqlx::types::uuid::uuid;

    /// Minimal wrapper so the custom UUID (de)serializers can be driven through serde_json.
    #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct UuidWrapper {
        #[serde(serialize_with = "serialize_uuid", deserialize_with = "deserialize_uuid")]
        id: Uuid,
    }

    #[test]
    fn test_serialize_uuid() {
        let wrapper = UuidWrapper {
            id: uuid!("693f84d2-4c1b-41e5-8483-cbe178324e04"),
        };

        let json = serde_json::to_string(&wrapper).unwrap();

        assert_eq!(r#"{"id":"693f84d2-4c1b-41e5-8483-cbe178324e04"}"#, json);
    }

    #[test]
    fn test_deserialize_uuid() {
        let json = r#"{"id":"045bd359-8630-4b76-9b7d-e4a86ed2222c"}"#;

        let parsed: UuidWrapper = serde_json::from_str(json).unwrap();

        assert_eq!(
            UuidWrapper { id: uuid!("045bd359-8630-4b76-9b7d-e4a86ed2222c") },
            parsed,
        );
    }

    #[test]
    fn test_serialize_deserialize_uuid() {
        let salt = Crypto::salt();
        let original = UuidWrapper {
            id: Uuid::from_slice(&salt[..16]).unwrap(),
        };

        let json = serde_json::to_string(&original).unwrap();
        let round_tripped: UuidWrapper = serde_json::from_str(&json).unwrap();

        assert_eq!(original, round_tripped);
    }
}
|