session renewal
This commit is contained in:
parent
161148d1f6
commit
96bbc2dbc2
@ -106,9 +106,11 @@ pub enum DataDirError {
|
||||
#[derive(Debug, ThisError, AsRefStr)]
|
||||
pub enum SendResponseError {
|
||||
#[error("The specified credentials request was not found")]
|
||||
NotFound, // no request with the given id
|
||||
NotFound,
|
||||
#[error("The specified request was already closed by the client")]
|
||||
Abandoned, // request has already been closed by client
|
||||
Abandoned,
|
||||
#[error("Could not renew AWS sesssion: {0}")]
|
||||
SessionRenew(#[from] GetSessionError),
|
||||
}
|
||||
|
||||
|
||||
@ -145,9 +147,13 @@ pub enum GetCredentialsError {
|
||||
#[derive(Debug, ThisError, AsRefStr)]
|
||||
pub enum GetSessionError {
|
||||
#[error("Request completed successfully but no credentials were returned")]
|
||||
NoCredentials, // SDK returned successfully but credentials are None
|
||||
EmptyResponse, // SDK returned successfully but credentials are None
|
||||
#[error("Error response from AWS SDK: {0}")]
|
||||
SdkError(#[from] AwsSdkError<GetSessionTokenError>),
|
||||
#[error("Could not construt session: credentials are locked")]
|
||||
CredentialsLocked,
|
||||
#[error("Could not construct session: no credentials are known")]
|
||||
CredentialsEmpty,
|
||||
}
|
||||
|
||||
|
||||
@ -199,7 +205,6 @@ impl Serialize for SerializeWrapper<&GetSessionTokenError> {
|
||||
|
||||
|
||||
impl_serialize_basic!(SetupError);
|
||||
impl_serialize_basic!(SendResponseError);
|
||||
impl_serialize_basic!(GetCredentialsError);
|
||||
impl_serialize_basic!(ClientInfoError);
|
||||
|
||||
@ -221,6 +226,22 @@ impl Serialize for RequestError {
|
||||
}
|
||||
|
||||
|
||||
impl Serialize for SendResponseError {
|
||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
let mut map = serializer.serialize_map(None)?;
|
||||
map.serialize_entry("code", self.as_ref())?;
|
||||
map.serialize_entry("msg", &format!("{self}"))?;
|
||||
|
||||
match self {
|
||||
SendResponseError::SessionRenew(src) => map.serialize_entry("source", &src)?,
|
||||
_ => serialize_upstream_err(self, &mut map)?,
|
||||
}
|
||||
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl Serialize for GetSessionError {
|
||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
let mut map = serializer.serialize_map(None)?;
|
||||
|
@ -29,9 +29,8 @@ pub enum Approval {
|
||||
|
||||
|
||||
#[tauri::command]
|
||||
pub fn respond(response: RequestResponse, app_state: State<'_, AppState>) -> Result<(), String> {
|
||||
app_state.send_response(response)
|
||||
.map_err(|e| format!("Error responding to request: {e}"))
|
||||
pub async fn respond(response: RequestResponse, app_state: State<'_, AppState>) -> Result<(), SendResponseError> {
|
||||
app_state.send_response(response).await
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,7 +1,16 @@
|
||||
use core::time::Duration;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::RwLock;
|
||||
use std::time::{
|
||||
Duration,
|
||||
SystemTime,
|
||||
UNIX_EPOCH
|
||||
};
|
||||
|
||||
|
||||
use aws_smithy_types::date_time::{
|
||||
DateTime as AwsDateTime,
|
||||
Format as AwsDateTimeFormat,
|
||||
};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use tokio::sync::oneshot::Sender;
|
||||
use tokio::time::sleep;
|
||||
@ -14,6 +23,7 @@ use sodiumoxide::crypto::{
|
||||
};
|
||||
use tauri::async_runtime as runtime;
|
||||
use tauri::Manager;
|
||||
use serde::Serializer;
|
||||
|
||||
use crate::{config, config::AppConfig};
|
||||
use crate::ipc;
|
||||
@ -22,24 +32,38 @@ use crate::errors::*;
|
||||
use crate::server::Server;
|
||||
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub struct BaseCredentials {
|
||||
access_key_id: String,
|
||||
secret_access_key: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub struct SessionCredentials {
|
||||
access_key_id: String,
|
||||
secret_access_key: String,
|
||||
token: String,
|
||||
expiration: String,
|
||||
#[serde(serialize_with = "serialize_expiration")]
|
||||
expiration: AwsDateTime,
|
||||
}
|
||||
|
||||
impl SessionCredentials {
|
||||
fn is_expired(&self) -> bool {
|
||||
let current_ts = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap() // doesn't panic because UNIX_EPOCH won't be later than now()
|
||||
.as_secs();
|
||||
|
||||
let expire_ts = self.expiration.secs();
|
||||
let remaining = expire_ts - (current_ts as i64);
|
||||
remaining < 60
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct LockedCredentials {
|
||||
access_key_id: String,
|
||||
secret_key_enc: Vec<u8>,
|
||||
@ -48,7 +72,7 @@ pub struct LockedCredentials {
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum Session {
|
||||
Unlocked{
|
||||
base: BaseCredentials,
|
||||
@ -180,7 +204,9 @@ impl AppState {
|
||||
open_requests.len()
|
||||
}
|
||||
|
||||
pub fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
|
||||
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
|
||||
self.renew_session_if_expired().await?;
|
||||
|
||||
let mut open_requests = self.open_requests.write().unwrap();
|
||||
let chan = open_requests
|
||||
.remove(&response.id)
|
||||
@ -274,21 +300,20 @@ impl AppState {
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let aws_session = resp.credentials().ok_or(GetSessionError::NoCredentials)?;
|
||||
let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?;
|
||||
|
||||
let access_key_id = aws_session.access_key_id()
|
||||
.ok_or(GetSessionError::NoCredentials)?
|
||||
.ok_or(GetSessionError::EmptyResponse)?
|
||||
.to_string();
|
||||
let secret_access_key = aws_session.secret_access_key()
|
||||
.ok_or(GetSessionError::NoCredentials)?
|
||||
.ok_or(GetSessionError::EmptyResponse)?
|
||||
.to_string();
|
||||
let token = aws_session.session_token()
|
||||
.ok_or(GetSessionError::NoCredentials)?
|
||||
.ok_or(GetSessionError::EmptyResponse)?
|
||||
.to_string();
|
||||
let expiration = aws_session.expiration()
|
||||
.ok_or(GetSessionError::NoCredentials)?
|
||||
.fmt(aws_smithy_types::date_time::Format::DateTime)
|
||||
.unwrap(); // only fails if the d/t is out of range, which it can't be for this format
|
||||
.ok_or(GetSessionError::EmptyResponse)?
|
||||
.clone();
|
||||
|
||||
let session_creds = SessionCredentials {
|
||||
access_key_id,
|
||||
@ -302,4 +327,51 @@ impl AppState {
|
||||
|
||||
Ok(session_creds)
|
||||
}
|
||||
|
||||
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
|
||||
let base = {
|
||||
let session = self.session.read().unwrap();
|
||||
match *session {
|
||||
Session::Unlocked{ref base, ..} => base.clone(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
};
|
||||
let new_session = self.new_session(&base.access_key_id, &base.secret_access_key).await?;
|
||||
match *self.session.write().unwrap() {
|
||||
Session::Unlocked{ref mut session, ..} => {
|
||||
if !session.is_expired() {
|
||||
return Ok(false);
|
||||
}
|
||||
*session = new_session;
|
||||
Ok(true)
|
||||
},
|
||||
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
|
||||
Session::Empty => Err(GetSessionError::CredentialsEmpty),
|
||||
}
|
||||
|
||||
// match *self.session.write().unwrap() {
|
||||
// Session::Unlocked{ref base, ref mut session} => {
|
||||
// if !session.is_expired() {
|
||||
// return Ok(false);
|
||||
// }
|
||||
// let new_session = self.new_session(
|
||||
// &base.access_key_id,
|
||||
// &base.secret_access_key
|
||||
// ).await?;
|
||||
// *session = new_session;
|
||||
// Ok(true)
|
||||
// },
|
||||
// Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
|
||||
// Session::Empty => Err(GetSessionError::CredentialsEmpty),
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn serialize_expiration<S>(exp: &AwsDateTime, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where S: Serializer
|
||||
{
|
||||
// this only fails if the d/t is out of range, which it can't be for this format
|
||||
let time_str = exp.fmt(AwsDateTimeFormat::DateTime).unwrap();
|
||||
serializer.serialize_str(&time_str)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user