351 lines
10 KiB
Rust
351 lines
10 KiB
Rust
use std::fmt::{self, Formatter};
|
|
use std::time::{SystemTime, UNIX_EPOCH};
|
|
|
|
use aws_smithy_types::date_time::{DateTime, Format};
|
|
use argon2::{
|
|
Argon2,
|
|
Algorithm,
|
|
Version,
|
|
ParamsBuilder,
|
|
password_hash::rand_core::{RngCore, OsRng},
|
|
};
|
|
use chacha20poly1305::{
|
|
XChaCha20Poly1305,
|
|
XNonce,
|
|
aead::{
|
|
Aead,
|
|
AeadCore,
|
|
KeyInit,
|
|
Error as AeadError,
|
|
generic_array::GenericArray,
|
|
},
|
|
};
|
|
use serde::{
|
|
Serialize,
|
|
Deserialize,
|
|
Serializer,
|
|
Deserializer,
|
|
};
|
|
use serde::de::{self, Visitor};
|
|
use sqlx::SqlitePool;
|
|
|
|
|
|
use crate::errors::*;
|
|
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub enum Session {
|
|
Unlocked{
|
|
base: BaseCredentials,
|
|
session: SessionCredentials,
|
|
},
|
|
Locked(LockedCredentials),
|
|
Empty,
|
|
}
|
|
|
|
impl Session {
|
|
pub async fn load(pool: &SqlitePool) -> Result<Self, SetupError> {
|
|
let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc")
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
let row = match res {
|
|
Some(r) => r,
|
|
None => {return Ok(Session::Empty);}
|
|
};
|
|
|
|
let salt: [u8; 32] = row.salt
|
|
.try_into()
|
|
.map_err(|_e| SetupError::InvalidRecord)?;
|
|
let nonce = XNonce::from_exact_iter(row.nonce.into_iter())
|
|
.ok_or(SetupError::InvalidRecord)?;
|
|
|
|
let creds = LockedCredentials {
|
|
access_key_id: row.access_key_id,
|
|
secret_key_enc: row.secret_key_enc,
|
|
salt,
|
|
nonce,
|
|
};
|
|
Ok(Session::Locked(creds))
|
|
}
|
|
|
|
pub async fn renew_if_expired(&mut self) -> Result<bool, GetSessionError> {
|
|
match self {
|
|
Session::Unlocked{ref base, ref mut session} => {
|
|
if !session.is_expired() {
|
|
return Ok(false);
|
|
}
|
|
*session = SessionCredentials::from_base(base).await?;
|
|
Ok(true)
|
|
},
|
|
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
|
|
Session::Empty => Err(GetSessionError::CredentialsEmpty),
|
|
}
|
|
}
|
|
|
|
pub fn try_get(
|
|
&self
|
|
) -> Result<(&BaseCredentials, &SessionCredentials), GetCredentialsError> {
|
|
match self {
|
|
Self::Empty => Err(GetCredentialsError::Empty),
|
|
Self::Locked(_) => Err(GetCredentialsError::Locked),
|
|
Self::Unlocked{ ref base, ref session } => Ok((base, session))
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct LockedCredentials {
|
|
pub access_key_id: String,
|
|
pub secret_key_enc: Vec<u8>,
|
|
pub salt: [u8; 32],
|
|
pub nonce: XNonce,
|
|
}
|
|
|
|
impl LockedCredentials {
|
|
pub async fn save(&self, pool: &SqlitePool) -> Result<(), sqlx::Error> {
|
|
sqlx::query(
|
|
"INSERT INTO credentials (access_key_id, secret_key_enc, salt, nonce, created_at)
|
|
VALUES (?, ?, ?, ?, strftime('%s'))"
|
|
)
|
|
.bind(&self.access_key_id)
|
|
.bind(&self.secret_key_enc)
|
|
.bind(&self.salt[..])
|
|
.bind(&self.nonce[..])
|
|
.execute(pool)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn decrypt(&self, passphrase: &str) -> Result<BaseCredentials, UnlockError> {
|
|
let crypto = Crypto::new(passphrase, &self.salt)
|
|
.map_err(|e| CryptoError::Argon2(e))?;
|
|
let decrypted = crypto.decrypt(&self.nonce, &self.secret_key_enc)
|
|
.map_err(|e| CryptoError::Aead(e))?;
|
|
let secret_access_key = String::from_utf8(decrypted)
|
|
.map_err(|_| UnlockError::InvalidUtf8)?;
|
|
|
|
let creds = BaseCredentials::new(
|
|
self.access_key_id.clone(),
|
|
secret_access_key,
|
|
);
|
|
Ok(creds)
|
|
}
|
|
}
|
|
|
|
|
|
/// Serde default for the `version` field: a serialized credentials record with
/// no `Version` key deserializes as version 1.
fn default_credentials_version() -> usize {
    const INITIAL_VERSION: usize = 1;
    INITIAL_VERSION
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
#[serde(rename_all = "PascalCase")]
|
|
pub struct BaseCredentials {
|
|
#[serde(default = "default_credentials_version")]
|
|
pub version: usize,
|
|
pub access_key_id: String,
|
|
pub secret_access_key: String,
|
|
}
|
|
|
|
impl BaseCredentials {
|
|
pub fn new(access_key_id: String, secret_access_key: String) -> Self {
|
|
Self {version: 1, access_key_id, secret_access_key}
|
|
}
|
|
|
|
pub fn encrypt(&self, passphrase: &str) -> Result<LockedCredentials, CryptoError> {
|
|
let salt = Crypto::salt();
|
|
let crypto = Crypto::new(passphrase, &salt)?;
|
|
let (nonce, secret_key_enc) = crypto.encrypt(self.secret_access_key.as_bytes())?;
|
|
|
|
let locked = LockedCredentials {
|
|
access_key_id: self.access_key_id.clone(),
|
|
secret_key_enc,
|
|
salt,
|
|
nonce,
|
|
};
|
|
Ok(locked)
|
|
}
|
|
}
|
|
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
#[serde(rename_all = "PascalCase")]
|
|
pub struct SessionCredentials {
|
|
#[serde(default = "default_credentials_version")]
|
|
pub version: usize,
|
|
pub access_key_id: String,
|
|
pub secret_access_key: String,
|
|
pub session_token: String,
|
|
#[serde(serialize_with = "serialize_expiration")]
|
|
#[serde(deserialize_with = "deserialize_expiration")]
|
|
pub expiration: DateTime,
|
|
}
|
|
|
|
impl SessionCredentials {
|
|
pub async fn from_base(base: &BaseCredentials) -> Result<Self, GetSessionError> {
|
|
let req_creds = aws_sdk_sts::Credentials::new(
|
|
&base.access_key_id,
|
|
&base.secret_access_key,
|
|
None, // token
|
|
None, //expiration
|
|
"Creddy", // "provider name" apparently
|
|
);
|
|
let config = aws_config::from_env()
|
|
.credentials_provider(req_creds)
|
|
.load()
|
|
.await;
|
|
|
|
let client = aws_sdk_sts::Client::new(&config);
|
|
let resp = client.get_session_token()
|
|
.duration_seconds(43_200)
|
|
.send()
|
|
.await?;
|
|
|
|
let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?;
|
|
|
|
let access_key_id = aws_session.access_key_id()
|
|
.ok_or(GetSessionError::EmptyResponse)?
|
|
.to_string();
|
|
let secret_access_key = aws_session.secret_access_key()
|
|
.ok_or(GetSessionError::EmptyResponse)?
|
|
.to_string();
|
|
let session_token = aws_session.session_token()
|
|
.ok_or(GetSessionError::EmptyResponse)?
|
|
.to_string();
|
|
let expiration = aws_session.expiration()
|
|
.ok_or(GetSessionError::EmptyResponse)?
|
|
.clone();
|
|
|
|
let session_creds = SessionCredentials {
|
|
version: 1,
|
|
access_key_id,
|
|
secret_access_key,
|
|
session_token,
|
|
expiration,
|
|
};
|
|
|
|
#[cfg(debug_assertions)]
|
|
println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap());
|
|
|
|
Ok(session_creds)
|
|
}
|
|
|
|
pub fn is_expired(&self) -> bool {
|
|
let current_ts = SystemTime::now()
|
|
.duration_since(UNIX_EPOCH)
|
|
.unwrap() // doesn't panic because UNIX_EPOCH won't be later than now()
|
|
.as_secs();
|
|
|
|
let expire_ts = self.expiration.secs();
|
|
let remaining = expire_ts - (current_ts as i64);
|
|
remaining < 60
|
|
}
|
|
}
|
|
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub enum Credentials {
|
|
Base(BaseCredentials),
|
|
Session(SessionCredentials),
|
|
}
|
|
|
|
|
|
fn serialize_expiration<S>(exp: &DateTime, serializer: S) -> Result<S::Ok, S::Error>
|
|
where S: Serializer
|
|
{
|
|
// this only fails if the d/t is out of range, which it can't be for this format
|
|
let time_str = exp.fmt(Format::DateTime).unwrap();
|
|
serializer.serialize_str(&time_str)
|
|
}
|
|
|
|
|
|
struct DateTimeVisitor;
|
|
|
|
impl<'de> Visitor<'de> for DateTimeVisitor {
|
|
type Value = DateTime;
|
|
|
|
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
|
|
write!(formatter, "an RFC 3339 UTC string, e.g. \"2014-01-05T10:17:34Z\"")
|
|
}
|
|
|
|
fn visit_str<E: de::Error>(self, v: &str) -> Result<DateTime, E> {
|
|
DateTime::from_str(v, Format::DateTime)
|
|
.map_err(|_| E::custom(format!("Invalid date/time: {v}")))
|
|
}
|
|
}
|
|
|
|
|
|
fn deserialize_expiration<'de, D>(deserializer: D) -> Result<DateTime, D::Error>
|
|
where D: Deserializer<'de>
|
|
{
|
|
deserializer.deserialize_str(DateTimeVisitor)
|
|
}
|
|
|
|
|
|
struct Crypto {
|
|
cipher: XChaCha20Poly1305,
|
|
}
|
|
|
|
impl Crypto {
|
|
/// Argon2 params rationale:
|
|
///
|
|
/// m_cost is measured in KiB, so 128 * 1024 gives us 128MiB.
|
|
/// This should roughly double the memory usage of the application
|
|
/// while deriving the key.
|
|
///
|
|
/// p_cost is irrelevant since (at present) there isn't any parallelism
|
|
/// implemented, so we leave it at 1.
|
|
///
|
|
/// With the above m_cost, t_cost = 8 results in about 800ms to derive
|
|
/// a key on my (somewhat older) CPU. This is probably overkill, but
|
|
/// given that it should only have to happen ~once a day for most
|
|
/// usage, it should be acceptable.
|
|
#[cfg(not(debug_assertions))]
|
|
const MEM_COST: u32 = 128 * 1024;
|
|
#[cfg(not(debug_assertions))]
|
|
const TIME_COST: u32 = 8;
|
|
|
|
/// But since this takes a million years without optimizations,
|
|
/// we turn it way down in debug builds.
|
|
#[cfg(debug_assertions)]
|
|
const MEM_COST: u32 = 48 * 1024;
|
|
#[cfg(debug_assertions)]
|
|
const TIME_COST: u32 = 1;
|
|
|
|
|
|
fn new(passphrase: &str, salt: &[u8]) -> argon2::Result<Crypto> {
|
|
let params = ParamsBuilder::new()
|
|
.m_cost(Self::MEM_COST)
|
|
.p_cost(1)
|
|
.t_cost(Self::TIME_COST)
|
|
.build()
|
|
.unwrap(); // only errors if the given params are invalid
|
|
|
|
let hasher = Argon2::new(
|
|
Algorithm::Argon2id,
|
|
Version::V0x13,
|
|
params,
|
|
);
|
|
|
|
let mut key = [0; 32];
|
|
hasher.hash_password_into(passphrase.as_bytes(), &salt, &mut key)?;
|
|
let cipher = XChaCha20Poly1305::new(GenericArray::from_slice(&key));
|
|
Ok(Crypto { cipher })
|
|
}
|
|
|
|
fn salt() -> [u8; 32] {
|
|
let mut salt = [0; 32];
|
|
OsRng.fill_bytes(&mut salt);
|
|
salt
|
|
}
|
|
|
|
fn encrypt(&self, data: &[u8]) -> Result<(XNonce, Vec<u8>), AeadError> {
|
|
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
|
|
let ciphertext = self.cipher.encrypt(&nonce, data)?;
|
|
Ok((nonce, ciphertext))
|
|
}
|
|
|
|
fn decrypt(&self, nonce: &XNonce, data: &[u8]) -> Result<Vec<u8>, AeadError> {
|
|
self.cipher.decrypt(nonce, data)
|
|
}
|
|
}
|