session renewal
This commit is contained in:
parent
161148d1f6
commit
96bbc2dbc2
@ -106,9 +106,11 @@ pub enum DataDirError {
|
|||||||
#[derive(Debug, ThisError, AsRefStr)]
|
#[derive(Debug, ThisError, AsRefStr)]
|
||||||
pub enum SendResponseError {
|
pub enum SendResponseError {
|
||||||
#[error("The specified credentials request was not found")]
|
#[error("The specified credentials request was not found")]
|
||||||
NotFound, // no request with the given id
|
NotFound,
|
||||||
#[error("The specified request was already closed by the client")]
|
#[error("The specified request was already closed by the client")]
|
||||||
Abandoned, // request has already been closed by client
|
Abandoned,
|
||||||
|
#[error("Could not renew AWS sesssion: {0}")]
|
||||||
|
SessionRenew(#[from] GetSessionError),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -145,9 +147,13 @@ pub enum GetCredentialsError {
|
|||||||
#[derive(Debug, ThisError, AsRefStr)]
|
#[derive(Debug, ThisError, AsRefStr)]
|
||||||
pub enum GetSessionError {
|
pub enum GetSessionError {
|
||||||
#[error("Request completed successfully but no credentials were returned")]
|
#[error("Request completed successfully but no credentials were returned")]
|
||||||
NoCredentials, // SDK returned successfully but credentials are None
|
EmptyResponse, // SDK returned successfully but credentials are None
|
||||||
#[error("Error response from AWS SDK: {0}")]
|
#[error("Error response from AWS SDK: {0}")]
|
||||||
SdkError(#[from] AwsSdkError<GetSessionTokenError>),
|
SdkError(#[from] AwsSdkError<GetSessionTokenError>),
|
||||||
|
#[error("Could not construt session: credentials are locked")]
|
||||||
|
CredentialsLocked,
|
||||||
|
#[error("Could not construct session: no credentials are known")]
|
||||||
|
CredentialsEmpty,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -199,7 +205,6 @@ impl Serialize for SerializeWrapper<&GetSessionTokenError> {
|
|||||||
|
|
||||||
|
|
||||||
impl_serialize_basic!(SetupError);
|
impl_serialize_basic!(SetupError);
|
||||||
impl_serialize_basic!(SendResponseError);
|
|
||||||
impl_serialize_basic!(GetCredentialsError);
|
impl_serialize_basic!(GetCredentialsError);
|
||||||
impl_serialize_basic!(ClientInfoError);
|
impl_serialize_basic!(ClientInfoError);
|
||||||
|
|
||||||
@ -221,6 +226,22 @@ impl Serialize for RequestError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl Serialize for SendResponseError {
|
||||||
|
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||||
|
let mut map = serializer.serialize_map(None)?;
|
||||||
|
map.serialize_entry("code", self.as_ref())?;
|
||||||
|
map.serialize_entry("msg", &format!("{self}"))?;
|
||||||
|
|
||||||
|
match self {
|
||||||
|
SendResponseError::SessionRenew(src) => map.serialize_entry("source", &src)?,
|
||||||
|
_ => serialize_upstream_err(self, &mut map)?,
|
||||||
|
}
|
||||||
|
|
||||||
|
map.end()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
impl Serialize for GetSessionError {
|
impl Serialize for GetSessionError {
|
||||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||||
let mut map = serializer.serialize_map(None)?;
|
let mut map = serializer.serialize_map(None)?;
|
||||||
|
@ -29,9 +29,8 @@ pub enum Approval {
|
|||||||
|
|
||||||
|
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn respond(response: RequestResponse, app_state: State<'_, AppState>) -> Result<(), String> {
|
pub async fn respond(response: RequestResponse, app_state: State<'_, AppState>) -> Result<(), SendResponseError> {
|
||||||
app_state.send_response(response)
|
app_state.send_response(response).await
|
||||||
.map_err(|e| format!("Error responding to request: {e}"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,16 @@
|
|||||||
use core::time::Duration;
|
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
use std::time::{
|
||||||
|
Duration,
|
||||||
|
SystemTime,
|
||||||
|
UNIX_EPOCH
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
use aws_smithy_types::date_time::{
|
||||||
|
DateTime as AwsDateTime,
|
||||||
|
Format as AwsDateTimeFormat,
|
||||||
|
};
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
use tokio::sync::oneshot::Sender;
|
use tokio::sync::oneshot::Sender;
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
@ -14,6 +23,7 @@ use sodiumoxide::crypto::{
|
|||||||
};
|
};
|
||||||
use tauri::async_runtime as runtime;
|
use tauri::async_runtime as runtime;
|
||||||
use tauri::Manager;
|
use tauri::Manager;
|
||||||
|
use serde::Serializer;
|
||||||
|
|
||||||
use crate::{config, config::AppConfig};
|
use crate::{config, config::AppConfig};
|
||||||
use crate::ipc;
|
use crate::ipc;
|
||||||
@ -22,24 +32,38 @@ use crate::errors::*;
|
|||||||
use crate::server::Server;
|
use crate::server::Server;
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "PascalCase")]
|
#[serde(rename_all = "PascalCase")]
|
||||||
pub struct BaseCredentials {
|
pub struct BaseCredentials {
|
||||||
access_key_id: String,
|
access_key_id: String,
|
||||||
secret_access_key: String,
|
secret_access_key: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize)]
|
||||||
#[serde(rename_all = "PascalCase")]
|
#[serde(rename_all = "PascalCase")]
|
||||||
pub struct SessionCredentials {
|
pub struct SessionCredentials {
|
||||||
access_key_id: String,
|
access_key_id: String,
|
||||||
secret_access_key: String,
|
secret_access_key: String,
|
||||||
token: String,
|
token: String,
|
||||||
expiration: String,
|
#[serde(serialize_with = "serialize_expiration")]
|
||||||
|
expiration: AwsDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionCredentials {
|
||||||
|
fn is_expired(&self) -> bool {
|
||||||
|
let current_ts = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap() // doesn't panic because UNIX_EPOCH won't be later than now()
|
||||||
|
.as_secs();
|
||||||
|
|
||||||
|
let expire_ts = self.expiration.secs();
|
||||||
|
let remaining = expire_ts - (current_ts as i64);
|
||||||
|
remaining < 60
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct LockedCredentials {
|
pub struct LockedCredentials {
|
||||||
access_key_id: String,
|
access_key_id: String,
|
||||||
secret_key_enc: Vec<u8>,
|
secret_key_enc: Vec<u8>,
|
||||||
@ -48,7 +72,7 @@ pub struct LockedCredentials {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub enum Session {
|
pub enum Session {
|
||||||
Unlocked{
|
Unlocked{
|
||||||
base: BaseCredentials,
|
base: BaseCredentials,
|
||||||
@ -180,7 +204,9 @@ impl AppState {
|
|||||||
open_requests.len()
|
open_requests.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
|
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
|
||||||
|
self.renew_session_if_expired().await?;
|
||||||
|
|
||||||
let mut open_requests = self.open_requests.write().unwrap();
|
let mut open_requests = self.open_requests.write().unwrap();
|
||||||
let chan = open_requests
|
let chan = open_requests
|
||||||
.remove(&response.id)
|
.remove(&response.id)
|
||||||
@ -274,21 +300,20 @@ impl AppState {
|
|||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let aws_session = resp.credentials().ok_or(GetSessionError::NoCredentials)?;
|
let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?;
|
||||||
|
|
||||||
let access_key_id = aws_session.access_key_id()
|
let access_key_id = aws_session.access_key_id()
|
||||||
.ok_or(GetSessionError::NoCredentials)?
|
.ok_or(GetSessionError::EmptyResponse)?
|
||||||
.to_string();
|
.to_string();
|
||||||
let secret_access_key = aws_session.secret_access_key()
|
let secret_access_key = aws_session.secret_access_key()
|
||||||
.ok_or(GetSessionError::NoCredentials)?
|
.ok_or(GetSessionError::EmptyResponse)?
|
||||||
.to_string();
|
.to_string();
|
||||||
let token = aws_session.session_token()
|
let token = aws_session.session_token()
|
||||||
.ok_or(GetSessionError::NoCredentials)?
|
.ok_or(GetSessionError::EmptyResponse)?
|
||||||
.to_string();
|
.to_string();
|
||||||
let expiration = aws_session.expiration()
|
let expiration = aws_session.expiration()
|
||||||
.ok_or(GetSessionError::NoCredentials)?
|
.ok_or(GetSessionError::EmptyResponse)?
|
||||||
.fmt(aws_smithy_types::date_time::Format::DateTime)
|
.clone();
|
||||||
.unwrap(); // only fails if the d/t is out of range, which it can't be for this format
|
|
||||||
|
|
||||||
let session_creds = SessionCredentials {
|
let session_creds = SessionCredentials {
|
||||||
access_key_id,
|
access_key_id,
|
||||||
@ -302,4 +327,51 @@ impl AppState {
|
|||||||
|
|
||||||
Ok(session_creds)
|
Ok(session_creds)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
|
||||||
|
let base = {
|
||||||
|
let session = self.session.read().unwrap();
|
||||||
|
match *session {
|
||||||
|
Session::Unlocked{ref base, ..} => base.clone(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let new_session = self.new_session(&base.access_key_id, &base.secret_access_key).await?;
|
||||||
|
match *self.session.write().unwrap() {
|
||||||
|
Session::Unlocked{ref mut session, ..} => {
|
||||||
|
if !session.is_expired() {
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
*session = new_session;
|
||||||
|
Ok(true)
|
||||||
|
},
|
||||||
|
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
|
||||||
|
Session::Empty => Err(GetSessionError::CredentialsEmpty),
|
||||||
|
}
|
||||||
|
|
||||||
|
// match *self.session.write().unwrap() {
|
||||||
|
// Session::Unlocked{ref base, ref mut session} => {
|
||||||
|
// if !session.is_expired() {
|
||||||
|
// return Ok(false);
|
||||||
|
// }
|
||||||
|
// let new_session = self.new_session(
|
||||||
|
// &base.access_key_id,
|
||||||
|
// &base.secret_access_key
|
||||||
|
// ).await?;
|
||||||
|
// *session = new_session;
|
||||||
|
// Ok(true)
|
||||||
|
// },
|
||||||
|
// Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
|
||||||
|
// Session::Empty => Err(GetSessionError::CredentialsEmpty),
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn serialize_expiration<S>(exp: &AwsDateTime, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where S: Serializer
|
||||||
|
{
|
||||||
|
// this only fails if the d/t is out of range, which it can't be for this format
|
||||||
|
let time_str = exp.fmt(AwsDateTimeFormat::DateTime).unwrap();
|
||||||
|
serializer.serialize_str(&time_str)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user