use std::fmt::{self, Formatter}; use std::time::{SystemTime, UNIX_EPOCH}; use aws_config::BehaviorVersion; use aws_smithy_types::date_time::{DateTime, Format}; use chacha20poly1305::XNonce; use serde::{ Serialize, Deserialize, Serializer, Deserializer, }; use serde::de::{self, Visitor}; use sqlx::{ FromRow, Sqlite, Transaction, types::Uuid, }; use super::{Credential, Crypto, PersistentCredential}; use crate::errors::*; #[derive(Debug, Clone, FromRow)] pub struct AwsRow { id: Uuid, access_key_id: String, secret_key_enc: Vec, nonce: Vec, } #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "PascalCase")] pub struct AwsBaseCredential { #[serde(default = "default_credentials_version")] pub version: usize, pub access_key_id: String, pub secret_access_key: String, } impl AwsBaseCredential { pub fn new(access_key_id: String, secret_access_key: String) -> Self { Self {version: 1, access_key_id, secret_access_key} } } impl PersistentCredential for AwsBaseCredential { type Row = AwsRow; fn type_name() -> &'static str { "aws" } fn into_credential(self) -> Credential { Credential::AwsBase(self) } fn row_id(row: &AwsRow) -> Uuid { row.id } fn from_row(row: AwsRow, crypto: &Crypto) -> Result { let nonce = XNonce::clone_from_slice(&row.nonce); let secret_key_bytes = crypto.decrypt(&nonce, &row.secret_key_enc)?; let secret_key = String::from_utf8(secret_key_bytes) .map_err(|_| LoadCredentialsError::InvalidData)?; Ok(Self::new(row.access_key_id, secret_key)) } async fn save_details(&self, id: &Uuid, crypto: &Crypto, txn: &mut Transaction<'_, Sqlite>) -> Result<(), SaveCredentialsError> { let (nonce, ciphertext) = crypto.encrypt(self.secret_access_key.as_bytes())?; let nonce_bytes = &nonce.as_slice(); sqlx::query!( "INSERT OR REPLACE INTO aws_credentials ( id, access_key_id, secret_key_enc, nonce ) VALUES (?, ?, ?, ?);", id, self.access_key_id, ciphertext, nonce_bytes, ).execute(&mut **txn).await?; Ok(()) } } #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "PascalCase")] pub struct AwsSessionCredential { #[serde(default = "default_credentials_version")] pub version: usize, pub access_key_id: String, pub secret_access_key: String, pub session_token: String, #[serde(serialize_with = "serialize_expiration")] #[serde(deserialize_with = "deserialize_expiration")] pub expiration: DateTime, } impl AwsSessionCredential { pub async fn from_base(base: &AwsBaseCredential) -> Result { let req_creds = aws_sdk_sts::config::Credentials::new( &base.access_key_id, &base.secret_access_key, None, // token None, //expiration "Creddy", // "provider name" apparently ); let config = aws_config::defaults(BehaviorVersion::latest()) .credentials_provider(req_creds) .load() .await; let client = aws_sdk_sts::Client::new(&config); let resp = client.get_session_token() .duration_seconds(43_200) .send() .await?; let aws_session = resp.credentials.ok_or(GetSessionError::EmptyResponse)?; let session_creds = AwsSessionCredential { version: 1, access_key_id: aws_session.access_key_id, secret_access_key: aws_session.secret_access_key, session_token: aws_session.session_token, expiration: aws_session.expiration, }; #[cfg(debug_assertions)] println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap()); Ok(session_creds) } pub fn is_expired(&self) -> bool { let current_ts = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() // doesn't panic because UNIX_EPOCH won't be later than now() .as_secs(); let expire_ts = self.expiration.secs(); let remaining = expire_ts - (current_ts as i64); remaining < 60 } } fn default_credentials_version() -> usize { 1 } struct DateTimeVisitor; impl<'de> Visitor<'de> for DateTimeVisitor { type Value = DateTime; fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { write!(formatter, "an RFC 3339 UTC string, e.g. \"2014-01-05T10:17:34Z\"") } fn visit_str(self, v: &str) -> Result { DateTime::from_str(v, Format::DateTime) .map_err(|_| E::custom(format!("Invalid date/time: {v}"))) } } fn deserialize_expiration<'de, D>(deserializer: D) -> Result where D: Deserializer<'de> { deserializer.deserialize_str(DateTimeVisitor) } fn serialize_expiration(exp: &DateTime, serializer: S) -> Result where S: Serializer { // this only fails if the d/t is out of range, which it can't be for this format let time_str = exp.fmt(Format::DateTime).unwrap(); serializer.serialize_str(&time_str) } #[cfg(test)] mod tests { use super::*; use sqlx::SqlitePool; use sqlx::types::uuid::uuid; fn creds() -> AwsBaseCredential { AwsBaseCredential::new( "AKIAIOSFODNN7EXAMPLE".into(), "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".into(), ) } fn creds_2() -> AwsBaseCredential { AwsBaseCredential::new( "AKIAIOSFODNN7EXAMPL2".into(), "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKE2".into(), ) } fn test_uuid() -> Uuid { Uuid::try_parse("00000000-0000-0000-0000-000000000000").unwrap() } fn test_uuid_2() -> Uuid { Uuid::try_parse("ffffffff-ffff-ffff-ffff-ffffffffffff").unwrap() } fn test_uuid_random() -> Uuid { let bytes = Crypto::salt(); Uuid::from_slice(&bytes[..16]).unwrap() } #[sqlx::test(fixtures("aws_credentials"))] async fn test_load(pool: SqlitePool) { let crypt = Crypto::fixed(); let id = uuid!("00000000-0000-0000-0000-000000000000"); let loaded = AwsBaseCredential::load(&id, &crypt, &pool).await.unwrap(); assert_eq!(creds(), loaded); } #[sqlx::test(fixtures("aws_credentials"))] async fn test_load_by_name(pool: SqlitePool) { let crypt = Crypto::fixed(); let loaded = AwsBaseCredential::load_by_name("test2", &crypt, &pool).await.unwrap(); assert_eq!(creds_2(), loaded); } #[sqlx::test(fixtures("aws_credentials"))] async fn test_load_default(pool: SqlitePool) { let crypt = Crypto::fixed(); let loaded = AwsBaseCredential::load_default(&crypt, &pool).await.unwrap(); assert_eq!(creds(), loaded) } #[sqlx::test(fixtures("aws_credentials"))] async fn test_list(pool: SqlitePool) { let crypt = Crypto::fixed(); let list: Vec<_> = AwsBaseCredential::list(&crypt, &pool) .await .expect("Failed to load credentials") .into_iter() .map(|(_, cred)| cred) .collect(); assert_eq!(&creds().into_credential(), &list[0]); assert_eq!(&creds_2().into_credential(), &list[1]); } }