259 lines
7.4 KiB
Rust
259 lines
7.4 KiB
Rust
use std::fmt::{self, Formatter};
|
|
use std::time::{SystemTime, UNIX_EPOCH};
|
|
|
|
use aws_config::BehaviorVersion;
|
|
use aws_smithy_types::date_time::{DateTime, Format};
|
|
use chacha20poly1305::XNonce;
|
|
use serde::{
|
|
Serialize,
|
|
Deserialize,
|
|
Serializer,
|
|
Deserializer,
|
|
};
|
|
use serde::de::{self, Visitor};
|
|
use sqlx::{
|
|
FromRow,
|
|
Sqlite,
|
|
Transaction,
|
|
types::Uuid,
|
|
};
|
|
|
|
use super::{Credential, Crypto, PersistentCredential};
|
|
|
|
use crate::errors::*;
|
|
|
|
|
|
#[derive(Debug, Clone, FromRow)]
|
|
pub struct AwsRow {
|
|
id: Uuid,
|
|
access_key_id: String,
|
|
secret_key_enc: Vec<u8>,
|
|
nonce: Vec<u8>,
|
|
}
|
|
|
|
|
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
|
#[serde(rename_all = "PascalCase")]
|
|
pub struct AwsBaseCredential {
|
|
#[serde(default = "default_credentials_version")]
|
|
pub version: usize,
|
|
pub access_key_id: String,
|
|
pub secret_access_key: String,
|
|
}
|
|
|
|
|
|
impl AwsBaseCredential {
|
|
pub fn new(access_key_id: String, secret_access_key: String) -> Self {
|
|
Self {version: 1, access_key_id, secret_access_key}
|
|
}
|
|
}
|
|
|
|
impl PersistentCredential for AwsBaseCredential {
|
|
type Row = AwsRow;
|
|
|
|
fn type_name() -> &'static str { "aws" }
|
|
|
|
fn into_credential(self) -> Credential { Credential::AwsBase(self) }
|
|
|
|
fn row_id(row: &AwsRow) -> Uuid { row.id }
|
|
|
|
fn from_row(row: AwsRow, crypto: &Crypto) -> Result<Self, LoadCredentialsError> {
|
|
let nonce = XNonce::clone_from_slice(&row.nonce);
|
|
let secret_key_bytes = crypto.decrypt(&nonce, &row.secret_key_enc)?;
|
|
let secret_key = String::from_utf8(secret_key_bytes)
|
|
.map_err(|_| LoadCredentialsError::InvalidData)?;
|
|
|
|
Ok(Self::new(row.access_key_id, secret_key))
|
|
}
|
|
|
|
async fn save_details(&self, id: &Uuid, crypto: &Crypto, txn: &mut Transaction<'_, Sqlite>) -> Result<(), SaveCredentialsError> {
|
|
let (nonce, ciphertext) = crypto.encrypt(self.secret_access_key.as_bytes())?;
|
|
let nonce_bytes = &nonce.as_slice();
|
|
|
|
sqlx::query!(
|
|
"INSERT OR REPLACE INTO aws_credentials (
|
|
id,
|
|
access_key_id,
|
|
secret_key_enc,
|
|
nonce
|
|
)
|
|
VALUES (?, ?, ?, ?);",
|
|
id, self.access_key_id, ciphertext, nonce_bytes,
|
|
).execute(&mut **txn).await?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
|
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
|
#[serde(rename_all = "PascalCase")]
|
|
pub struct AwsSessionCredential {
|
|
#[serde(default = "default_credentials_version")]
|
|
pub version: usize,
|
|
pub access_key_id: String,
|
|
pub secret_access_key: String,
|
|
pub session_token: String,
|
|
#[serde(serialize_with = "serialize_expiration")]
|
|
#[serde(deserialize_with = "deserialize_expiration")]
|
|
pub expiration: DateTime,
|
|
}
|
|
|
|
impl AwsSessionCredential {
|
|
pub async fn from_base(base: &AwsBaseCredential) -> Result<Self, GetSessionError> {
|
|
let req_creds = aws_sdk_sts::config::Credentials::new(
|
|
&base.access_key_id,
|
|
&base.secret_access_key,
|
|
None, // token
|
|
None, //expiration
|
|
"Creddy", // "provider name" apparently
|
|
);
|
|
let config = aws_config::defaults(BehaviorVersion::latest())
|
|
.credentials_provider(req_creds)
|
|
.load()
|
|
.await;
|
|
|
|
let client = aws_sdk_sts::Client::new(&config);
|
|
let resp = client.get_session_token()
|
|
.duration_seconds(43_200)
|
|
.send()
|
|
.await?;
|
|
|
|
let aws_session = resp.credentials.ok_or(GetSessionError::EmptyResponse)?;
|
|
|
|
let session_creds = AwsSessionCredential {
|
|
version: 1,
|
|
access_key_id: aws_session.access_key_id,
|
|
secret_access_key: aws_session.secret_access_key,
|
|
session_token: aws_session.session_token,
|
|
expiration: aws_session.expiration,
|
|
};
|
|
|
|
#[cfg(debug_assertions)]
|
|
println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap());
|
|
|
|
Ok(session_creds)
|
|
}
|
|
|
|
pub fn is_expired(&self) -> bool {
|
|
let current_ts = SystemTime::now()
|
|
.duration_since(UNIX_EPOCH)
|
|
.unwrap() // doesn't panic because UNIX_EPOCH won't be later than now()
|
|
.as_secs();
|
|
|
|
let expire_ts = self.expiration.secs();
|
|
let remaining = expire_ts - (current_ts as i64);
|
|
remaining < 60
|
|
}
|
|
}
|
|
|
|
|
|
/// Serde default for the `version` field when it is absent from input.
fn default_credentials_version() -> usize {
    1
}
|
|
|
|
|
|
struct DateTimeVisitor;
|
|
|
|
impl<'de> Visitor<'de> for DateTimeVisitor {
|
|
type Value = DateTime;
|
|
|
|
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
|
|
write!(formatter, "an RFC 3339 UTC string, e.g. \"2014-01-05T10:17:34Z\"")
|
|
}
|
|
|
|
fn visit_str<E: de::Error>(self, v: &str) -> Result<DateTime, E> {
|
|
DateTime::from_str(v, Format::DateTime)
|
|
.map_err(|_| E::custom(format!("Invalid date/time: {v}")))
|
|
}
|
|
}
|
|
|
|
|
|
fn deserialize_expiration<'de, D>(deserializer: D) -> Result<DateTime, D::Error>
|
|
where D: Deserializer<'de>
|
|
{
|
|
deserializer.deserialize_str(DateTimeVisitor)
|
|
}
|
|
|
|
fn serialize_expiration<S>(exp: &DateTime, serializer: S) -> Result<S::Ok, S::Error>
|
|
where S: Serializer
|
|
{
|
|
// this only fails if the d/t is out of range, which it can't be for this format
|
|
let time_str = exp.fmt(Format::DateTime).unwrap();
|
|
serializer.serialize_str(&time_str)
|
|
}
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::SqlitePool;

    // Example key pairs. The assertions below expect the `aws_credentials`
    // fixture to hold these (presumably in encrypted form under `Crypto::fixed()`).
    fn creds() -> AwsBaseCredential {
        let access_key_id = String::from("AKIAIOSFODNN7EXAMPLE");
        let secret_access_key = String::from("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY");
        AwsBaseCredential::new(access_key_id, secret_access_key)
    }

    fn creds_2() -> AwsBaseCredential {
        let access_key_id = String::from("AKIAIOSFODNN7EXAMPL2");
        let secret_access_key = String::from("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKE2");
        AwsBaseCredential::new(access_key_id, secret_access_key)
    }

    /// The all-zero UUID.
    fn test_uuid() -> Uuid {
        Uuid::try_parse("00000000-0000-0000-0000-000000000000").unwrap()
    }

    /// The all-ones UUID.
    fn test_uuid_2() -> Uuid {
        Uuid::try_parse("ffffffff-ffff-ffff-ffff-ffffffffffff").unwrap()
    }

    /// A UUID built from the first 16 bytes of a fresh `Crypto::salt()`.
    fn test_uuid_random() -> Uuid {
        let salt = Crypto::salt();
        Uuid::from_slice(&salt[..16]).unwrap()
    }

    #[sqlx::test(fixtures("aws_credentials"))]
    async fn test_load(pool: SqlitePool) {
        let crypto = Crypto::fixed();
        let loaded = AwsBaseCredential::load(&test_uuid(), &crypto, &pool).await.unwrap();
        assert_eq!(creds(), loaded);
    }

    #[sqlx::test(fixtures("aws_credentials"))]
    async fn test_load_by_name(pool: SqlitePool) {
        let crypto = Crypto::fixed();
        let loaded = AwsBaseCredential::load_by_name("test2", &crypto, &pool).await.unwrap();
        assert_eq!(creds_2(), loaded);
    }

    #[sqlx::test(fixtures("aws_credentials"))]
    async fn test_load_default(pool: SqlitePool) {
        let crypto = Crypto::fixed();
        let loaded = AwsBaseCredential::load_default(&crypto, &pool).await.unwrap();
        assert_eq!(creds(), loaded)
    }

    #[sqlx::test(fixtures("aws_credentials"))]
    async fn test_list(pool: SqlitePool) {
        let crypto = Crypto::fixed();
        let listed = AwsBaseCredential::list(&crypto, &pool)
            .await
            .expect("Failed to load credentials");
        let credentials: Vec<_> = listed.into_iter().map(|(_id, cred)| cred).collect();

        let first = creds().into_credential();
        let second = creds_2().into_credential();
        assert_eq!(&first, &credentials[0]);
        assert_eq!(&second, &credentials[1]);
    }
}
|