session renewal

This commit is contained in:
Joseph Montanaro 2023-05-02 11:33:18 -07:00
parent 161148d1f6
commit 96bbc2dbc2
3 changed files with 113 additions and 21 deletions

View File

@ -106,9 +106,11 @@ pub enum DataDirError {
#[derive(Debug, ThisError, AsRefStr)]
pub enum SendResponseError {
#[error("The specified credentials request was not found")]
NotFound, // no request with the given id
NotFound,
#[error("The specified request was already closed by the client")]
Abandoned, // request has already been closed by client
Abandoned,
#[error("Could not renew AWS sesssion: {0}")]
SessionRenew(#[from] GetSessionError),
}
@ -145,9 +147,13 @@ pub enum GetCredentialsError {
#[derive(Debug, ThisError, AsRefStr)]
pub enum GetSessionError {
#[error("Request completed successfully but no credentials were returned")]
NoCredentials, // SDK returned successfully but credentials are None
EmptyResponse, // SDK returned successfully but credentials are None
#[error("Error response from AWS SDK: {0}")]
SdkError(#[from] AwsSdkError<GetSessionTokenError>),
#[error("Could not construt session: credentials are locked")]
CredentialsLocked,
#[error("Could not construct session: no credentials are known")]
CredentialsEmpty,
}
@ -199,7 +205,6 @@ impl Serialize for SerializeWrapper<&GetSessionTokenError> {
impl_serialize_basic!(SetupError);
impl_serialize_basic!(SendResponseError);
impl_serialize_basic!(GetCredentialsError);
impl_serialize_basic!(ClientInfoError);
@ -221,6 +226,22 @@ impl Serialize for RequestError {
}
impl Serialize for SendResponseError {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut map = serializer.serialize_map(None)?;
map.serialize_entry("code", self.as_ref())?;
map.serialize_entry("msg", &format!("{self}"))?;
match self {
SendResponseError::SessionRenew(src) => map.serialize_entry("source", &src)?,
_ => serialize_upstream_err(self, &mut map)?,
}
map.end()
}
}
impl Serialize for GetSessionError {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut map = serializer.serialize_map(None)?;

View File

@ -29,9 +29,8 @@ pub enum Approval {
#[tauri::command]
pub fn respond(response: RequestResponse, app_state: State<'_, AppState>) -> Result<(), String> {
app_state.send_response(response)
.map_err(|e| format!("Error responding to request: {e}"))
pub async fn respond(response: RequestResponse, app_state: State<'_, AppState>) -> Result<(), SendResponseError> {
app_state.send_response(response).await
}

View File

@ -1,7 +1,16 @@
use core::time::Duration;
use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use std::time::{
Duration,
SystemTime,
UNIX_EPOCH
};
use aws_smithy_types::date_time::{
DateTime as AwsDateTime,
Format as AwsDateTimeFormat,
};
use serde::{Serialize, Deserialize};
use tokio::sync::oneshot::Sender;
use tokio::time::sleep;
@ -14,6 +23,7 @@ use sodiumoxide::crypto::{
};
use tauri::async_runtime as runtime;
use tauri::Manager;
use serde::Serializer;
use crate::{config, config::AppConfig};
use crate::ipc;
@ -22,24 +32,38 @@ use crate::errors::*;
use crate::server::Server;
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct BaseCredentials {
access_key_id: String,
secret_access_key: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "PascalCase")]
pub struct SessionCredentials {
access_key_id: String,
secret_access_key: String,
token: String,
expiration: String,
#[serde(serialize_with = "serialize_expiration")]
expiration: AwsDateTime,
}
impl SessionCredentials {
fn is_expired(&self) -> bool {
let current_ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap() // doesn't panic because UNIX_EPOCH won't be later than now()
.as_secs();
let expire_ts = self.expiration.secs();
let remaining = expire_ts - (current_ts as i64);
remaining < 60
}
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct LockedCredentials {
access_key_id: String,
secret_key_enc: Vec<u8>,
@ -48,7 +72,7 @@ pub struct LockedCredentials {
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub enum Session {
Unlocked{
base: BaseCredentials,
@ -180,7 +204,9 @@ impl AppState {
open_requests.len()
}
pub fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
self.renew_session_if_expired().await?;
let mut open_requests = self.open_requests.write().unwrap();
let chan = open_requests
.remove(&response.id)
@ -274,21 +300,20 @@ impl AppState {
.send()
.await?;
let aws_session = resp.credentials().ok_or(GetSessionError::NoCredentials)?;
let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?;
let access_key_id = aws_session.access_key_id()
.ok_or(GetSessionError::NoCredentials)?
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let secret_access_key = aws_session.secret_access_key()
.ok_or(GetSessionError::NoCredentials)?
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let token = aws_session.session_token()
.ok_or(GetSessionError::NoCredentials)?
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let expiration = aws_session.expiration()
.ok_or(GetSessionError::NoCredentials)?
.fmt(aws_smithy_types::date_time::Format::DateTime)
.unwrap(); // only fails if the d/t is out of range, which it can't be for this format
.ok_or(GetSessionError::EmptyResponse)?
.clone();
let session_creds = SessionCredentials {
access_key_id,
@ -302,4 +327,51 @@ impl AppState {
Ok(session_creds)
}
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
let base = {
let session = self.session.read().unwrap();
match *session {
Session::Unlocked{ref base, ..} => base.clone(),
_ => unreachable!(),
}
};
let new_session = self.new_session(&base.access_key_id, &base.secret_access_key).await?;
match *self.session.write().unwrap() {
Session::Unlocked{ref mut session, ..} => {
if !session.is_expired() {
return Ok(false);
}
*session = new_session;
Ok(true)
},
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
Session::Empty => Err(GetSessionError::CredentialsEmpty),
}
// match *self.session.write().unwrap() {
// Session::Unlocked{ref base, ref mut session} => {
// if !session.is_expired() {
// return Ok(false);
// }
// let new_session = self.new_session(
// &base.access_key_id,
// &base.secret_access_key
// ).await?;
// *session = new_session;
// Ok(true)
// },
// Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
// Session::Empty => Err(GetSessionError::CredentialsEmpty),
// }
}
}
fn serialize_expiration<S>(exp: &AwsDateTime, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
{
// this only fails if the d/t is out of range, which it can't be for this format
let time_str = exp.fmt(AwsDateTimeFormat::DateTime).unwrap();
serializer.serialize_str(&time_str)
}