creddy/src-tauri/src/credentials.rs

244 lines
7.2 KiB
Rust

use std::fmt::{self, Formatter};
use std::time::{SystemTime, UNIX_EPOCH};
use aws_smithy_types::date_time::{DateTime, Format};
use serde::{
Serialize,
Deserialize,
Serializer,
Deserializer,
};
use serde::de::{self, Visitor};
use sqlx::SqlitePool;
use sodiumoxide::crypto::{
pwhash,
pwhash::Salt,
secretbox,
secretbox::{Nonce, Key}
};
use crate::errors::*;
/// State of the credentials currently held by the application.
#[derive(Debug, Clone)]
pub enum Session {
    /// Decrypted base credentials together with session credentials obtained from them.
    Unlocked {
        base: BaseCredentials,
        session: SessionCredentials,
    },
    /// A stored record whose secret key is still encrypted.
    Locked(LockedCredentials),
    /// No stored credentials were found.
    Empty,
}
impl Session {
    /// Reads the first row of the `credentials` table, ordered by `created_at`
    /// descending.
    ///
    /// Returns [`Session::Empty`] when the query yields no row, otherwise
    /// [`Session::Locked`] wrapping the stored record.
    ///
    /// # Errors
    /// Propagates query failures via `?`, and returns `SetupError::InvalidRecord`
    /// when the stored salt is not exactly 32 bytes or the stored nonce is not
    /// exactly 24 bytes.
    pub async fn load(pool: &SqlitePool) -> Result<Self, SetupError> {
        let newest = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc")
            .fetch_optional(pool)
            .await?;

        let record = match newest {
            None => return Ok(Session::Empty),
            Some(found) => found,
        };

        // The database hands back variable-length blobs; the crypto types need
        // fixed-size arrays, so a length mismatch marks the record as invalid.
        let salt_bytes =
            <[u8; 32]>::try_from(record.salt).map_err(|_| SetupError::InvalidRecord)?;
        let nonce_bytes =
            <[u8; 24]>::try_from(record.nonce).map_err(|_| SetupError::InvalidRecord)?;

        Ok(Session::Locked(LockedCredentials {
            access_key_id: record.access_key_id,
            secret_key_enc: record.secret_key_enc,
            salt: Salt(salt_bytes),
            nonce: Nonce(nonce_bytes),
        }))
    }

    /// Replaces the session credentials with fresh ones if they report as expired.
    ///
    /// Returns `Ok(true)` when a new session was obtained and `Ok(false)` when the
    /// existing one is still usable.
    ///
    /// # Errors
    /// `GetSessionError::CredentialsLocked` / `GetSessionError::CredentialsEmpty`
    /// when the session is not unlocked, plus any error from
    /// `SessionCredentials::from_base`.
    pub async fn renew_if_expired(&mut self) -> Result<bool, GetSessionError> {
        let (base, session) = match self {
            Session::Locked(_) => return Err(GetSessionError::CredentialsLocked),
            Session::Empty => return Err(GetSessionError::CredentialsEmpty),
            Session::Unlocked { base, session } => (base, session),
        };

        if session.is_expired() {
            *session = SessionCredentials::from_base(base).await?;
            Ok(true)
        } else {
            Ok(false)
        }
    }
}
/// Credentials as stored: the access key id in the clear, plus the encrypted
/// secret key and the parameters needed to decrypt it with a passphrase.
#[derive(Debug, Clone)]
pub struct LockedCredentials {
    /// Access key id, kept unencrypted.
    pub access_key_id: String,
    /// Secret access key as sealed by `secretbox`.
    pub secret_key_enc: Vec<u8>,
    /// Salt used when deriving the encryption key from the passphrase.
    pub salt: Salt,
    /// Nonce used when sealing the secret key.
    pub nonce: Nonce,
}
impl LockedCredentials {
    /// Inserts this record as a new row of the `credentials` table, letting the
    /// database fill in `created_at`.
    ///
    /// # Errors
    /// Returns the `sqlx::Error` reported while executing the statement.
    pub async fn save(&self, pool: &SqlitePool) -> Result<(), sqlx::Error> {
        // The literal's continuation line is left unindented so the SQL text
        // sent to the database is unchanged.
        let statement = sqlx::query(
            "INSERT INTO credentials (access_key_id, secret_key_enc, salt, nonce, created_at)
VALUES (?, ?, ?, ?, strftime('%s'))",
        )
        .bind(&self.access_key_id)
        .bind(&self.secret_key_enc)
        .bind(&self.salt.0[..])
        .bind(&self.nonce.0[..]);

        statement.execute(pool).await.map(|_| ())
    }

    /// Derives a key from `passphrase` and the stored salt, then opens the
    /// encrypted secret key with it.
    ///
    /// # Errors
    /// `UnlockError::BadPassphrase` when the box cannot be opened, and
    /// `UnlockError::InvalidUtf8` when the decrypted bytes are not valid UTF-8.
    ///
    /// # Panics
    /// Panics if key derivation fails (per the original author's note, presumably
    /// only on memory exhaustion).
    pub fn decrypt(&self, passphrase: &str) -> Result<BaseCredentials, UnlockError> {
        let mut derived = [0; secretbox::KEYBYTES];
        pwhash::derive_key_interactive(&mut derived, passphrase.as_bytes(), &self.salt).unwrap();
        let key = Key(derived);

        let plaintext = secretbox::open(&self.secret_key_enc, &self.nonce, &key)
            .map_err(|_| UnlockError::BadPassphrase)?;

        match String::from_utf8(plaintext) {
            Ok(secret_access_key) => Ok(BaseCredentials {
                access_key_id: self.access_key_id.clone(),
                secret_access_key,
            }),
            Err(_) => Err(UnlockError::InvalidUtf8),
        }
    }
}
/// Decrypted access key pair. Serialized with PascalCase field names
/// (`AccessKeyId`, `SecretAccessKey`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct BaseCredentials {
    /// Access key id.
    pub access_key_id: String,
    /// Secret access key in plaintext.
    pub secret_access_key: String,
}
impl BaseCredentials {
    /// Encrypts the secret access key under a key derived from `passphrase`.
    ///
    /// A fresh salt and nonce are generated on every call and returned inside
    /// the resulting [`LockedCredentials`] alongside the unencrypted access key id.
    ///
    /// # Panics
    /// Panics if key derivation fails.
    pub fn encrypt(&self, passphrase: &str) -> LockedCredentials {
        let salt = pwhash::gen_salt();
        let nonce = secretbox::gen_nonce();

        let mut derived = [0; secretbox::KEYBYTES];
        pwhash::derive_key_interactive(&mut derived, passphrase.as_bytes(), &salt).unwrap();

        let sealed = secretbox::seal(self.secret_access_key.as_bytes(), &nonce, &Key(derived));

        LockedCredentials {
            access_key_id: self.access_key_id.clone(),
            secret_key_enc: sealed,
            salt,
            nonce,
        }
    }
}
/// Temporary credentials: an access key pair, a session token and the moment
/// they expire. Serialized with PascalCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct SessionCredentials {
    /// Access key id of the session.
    pub access_key_id: String,
    /// Secret access key of the session.
    pub secret_access_key: String,
    /// Session token.
    pub token: String,
    /// Expiry time; (de)serialized as a string by the two helper functions named
    /// in the attribute below.
    #[serde(
        serialize_with = "serialize_expiration",
        deserialize_with = "deserialize_expiration"
    )]
    pub expiration: DateTime,
}
impl SessionCredentials {
    /// Obtains temporary session credentials from STS using the given base
    /// credentials.
    ///
    /// The AWS configuration is built with `aws_config::from_env()` with the
    /// credentials provider overridden by `base`; the session is requested with a
    /// duration of 43,200 seconds (12 hours).
    ///
    /// # Errors
    /// Propagates a failed `get_session_token` request via `?`, and returns
    /// `GetSessionError::EmptyResponse` when the response lacks credentials or
    /// any of: access key id, secret access key, session token, expiration.
    pub async fn from_base(base: &BaseCredentials) -> Result<Self, GetSessionError> {
        let req_creds = aws_sdk_sts::Credentials::new(
            &base.access_key_id,
            &base.secret_access_key,
            None, // token
            None, // expiration
            "Creddy", // "provider name" apparently
        );
        let config = aws_config::from_env()
            .credentials_provider(req_creds)
            .load()
            .await;
        let client = aws_sdk_sts::Client::new(&config);

        let resp = client
            .get_session_token()
            .duration_seconds(43_200)
            .send()
            .await?;

        // Every piece of the response is optional in the SDK; a missing one is
        // reported uniformly as an empty response.
        let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?;
        let access_key_id = aws_session
            .access_key_id()
            .ok_or(GetSessionError::EmptyResponse)?
            .to_string();
        let secret_access_key = aws_session
            .secret_access_key()
            .ok_or(GetSessionError::EmptyResponse)?
            .to_string();
        let token = aws_session
            .session_token()
            .ok_or(GetSessionError::EmptyResponse)?
            .to_string();
        let expiration = aws_session
            .expiration()
            .ok_or(GetSessionError::EmptyResponse)?
            .clone();

        let session_creds = SessionCredentials {
            access_key_id,
            secret_access_key,
            token,
            expiration,
        };

        // NOTE(review): debug builds print the full session, including the secret
        // access key and token, to stdout.
        #[cfg(debug_assertions)]
        println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap());

        Ok(session_creds)
    }

    /// Reports whether the session should be treated as expired.
    ///
    /// Returns `true` when fewer than 60 seconds remain before `expiration`
    /// (which includes sessions whose expiration is already in the past).
    ///
    /// Never panics: a system clock set earlier than the Unix epoch is handled
    /// as a negative "now" instead of unwrapping the resulting error, and the
    /// arithmetic saturates rather than overflowing.
    pub fn is_expired(&self) -> bool {
        // Signed seconds since the Unix epoch; negative if the clock is before 1970.
        let now_secs: i64 = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => i64::try_from(since_epoch.as_secs()).unwrap_or(i64::MAX),
            Err(before_epoch) => i64::try_from(before_epoch.duration().as_secs())
                .map(|secs| -secs)
                .unwrap_or(i64::MIN),
        };

        let remaining = self.expiration.secs().saturating_sub(now_secs);
        remaining < 60
    }
}
/// Serde `serialize_with` helper: writes `exp` as a string rendered with
/// `Format::DateTime`.
///
/// # Panics
/// Panics if the date/time cannot be rendered in that format; per the original
/// author this only happens for out-of-range values, which are not expected here.
fn serialize_expiration<S>(exp: &DateTime, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let rendered: String = exp.fmt(Format::DateTime).unwrap();
    serializer.serialize_str(rendered.as_str())
}
struct DateTimeVisitor;
impl<'de> Visitor<'de> for DateTimeVisitor {
type Value = DateTime;
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
write!(formatter, "an RFC 3339 UTC string, e.g. \"2014-01-05T10:17:34Z\"")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<DateTime, E> {
DateTime::from_str(v, Format::DateTime)
.map_err(|_| E::custom(format!("Invalid date/time: {v}")))
}
}
/// Serde `deserialize_with` helper: asks the deserializer for a string and hands
/// it to `DateTimeVisitor` for parsing.
fn deserialize_expiration<'de, D>(deserializer: D) -> Result<DateTime, D::Error>
where
    D: Deserializer<'de>,
{
    let visitor = DateTimeVisitor;
    deserializer.deserialize_str(visitor)
}