cancel approval flow on frontend when request is abandoned by client

This commit is contained in:
Joseph Montanaro 2024-01-21 13:46:39 -08:00
parent 7fdb336c79
commit 1df849442e
14 changed files with 104 additions and 29 deletions

View File

@ -1,7 +1,7 @@
## Definitely ## Definitely
* ~~Switch to "process" provider for AWS credentials (much less hacky)~~ * ~~Switch to "process" provider for AWS credentials (much less hacky)~~
* Frontend needs to react when request is cancelled from backend * ~~Frontend needs to react when request is cancelled from backend~~
* Session timeout (plain duration, or activity-based?) * Session timeout (plain duration, or activity-based?)
* ~~Fix rehide behavior when new request comes in while old one is still being resolved~~ * ~~Fix rehide behavior when new request comes in while old one is still being resolved~~
* Additional hotkey configuration (approve/deny at the very least) * Additional hotkey configuration (approve/deny at the very least)

View File

@ -1,6 +1,6 @@
{ {
"name": "creddy", "name": "creddy",
"version": "0.4.3", "version": "0.4.4",
"scripts": { "scripts": {
"dev": "vite", "dev": "vite",
"build": "vite build", "build": "vite build",

2
src-tauri/Cargo.lock generated
View File

@ -1035,7 +1035,7 @@ dependencies = [
[[package]] [[package]]
name = "creddy" name = "creddy"
version = "0.4.3" version = "0.4.4"
dependencies = [ dependencies = [
"argon2", "argon2",
"auto-launch", "auto-launch",

View File

@ -1,6 +1,6 @@
[package] [package]
name = "creddy" name = "creddy"
version = "0.4.3" version = "0.4.4"
description = "A friendly AWS credentials manager" description = "A friendly AWS credentials manager"
authors = ["Joseph Montanaro"] authors = ["Joseph Montanaro"]
license = "" license = ""

View File

@ -14,7 +14,6 @@ pub struct Client {
pub fn get_process_parent_info(pid: u32) -> Result<Client, ClientInfoError> { pub fn get_process_parent_info(pid: u32) -> Result<Client, ClientInfoError> {
dbg!(pid);
let sys_pid = Pid::from_u32(pid); let sys_pid = Pid::from_u32(pid);
let mut sys = System::new(); let mut sys = System::new();
sys.refresh_process(sys_pid); sys.refresh_process(sys_pid);

View File

@ -83,15 +83,33 @@ where
} }
struct SerializeUpstream<E>(pub E);
impl<E: Error> Serialize for SerializeUpstream<E> {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let msg = format!("{}", self.0);
let mut map = serializer.serialize_map(None)?;
map.serialize_entry("msg", &msg)?;
map.serialize_entry("code", &None::<&str>)?;
map.serialize_entry("source", &None::<&str>)?;
map.end()
}
}
fn serialize_upstream_err<E, M>(err: &E, map: &mut M) -> Result<(), M::Error> fn serialize_upstream_err<E, M>(err: &E, map: &mut M) -> Result<(), M::Error>
where where
E: Error, E: Error,
M: serde::ser::SerializeMap, M: serde::ser::SerializeMap,
{ {
let msg = err.source().map(|s| format!("{s}")); // let msg = err.source().map(|s| format!("{s}"));
map.serialize_entry("msg", &msg)?; // map.serialize_entry("msg", &msg)?;
map.serialize_entry("code", &None::<&str>)?; // map.serialize_entry("code", &None::<&str>)?;
map.serialize_entry("source", &None::<&str>)?; // map.serialize_entry("source", &None::<&str>)?;
match err.source() {
Some(src) => map.serialize_entry("source", &SerializeUpstream(src))?,
None => map.serialize_entry("source", &None::<&str>)?,
}
Ok(()) Ok(())
} }
@ -153,7 +171,7 @@ pub enum SendResponseError {
} }
// errors encountered while handling an HTTP request // errors encountered while handling a client request
#[derive(Debug, ThisError, AsRefStr)] #[derive(Debug, ThisError, AsRefStr)]
pub enum HandlerError { pub enum HandlerError {
#[error("Error writing to stream: {0}")] #[error("Error writing to stream: {0}")]
@ -164,6 +182,8 @@ pub enum HandlerError {
BadRequest(#[from] serde_json::Error), BadRequest(#[from] serde_json::Error),
#[error("HTTP request too large")] #[error("HTTP request too large")]
RequestTooLarge, RequestTooLarge,
#[error("Connection closed early by client")]
Abandoned,
#[error("Internal server error")] #[error("Internal server error")]
Internal(#[from] RecvError), Internal(#[from] RecvError),
#[error("Error accessing credentials: {0}")] #[error("Error accessing credentials: {0}")]
@ -345,7 +365,6 @@ impl Serialize for SerializeWrapper<&GetSessionTokenError> {
} }
impl_serialize_basic!(SetupError); impl_serialize_basic!(SetupError);
impl_serialize_basic!(GetCredentialsError); impl_serialize_basic!(GetCredentialsError);
impl_serialize_basic!(ClientInfoError); impl_serialize_basic!(ClientInfoError);

View File

@ -43,6 +43,24 @@ pub enum Response {
} }
struct CloseWaiter<'s> {
stream: &'s mut Stream,
}
impl<'s> CloseWaiter<'s> {
async fn wait_for_close(&mut self) -> std::io::Result<()> {
let mut buf = [0u8; 8];
loop {
match self.stream.read(&mut buf).await {
Ok(0) => break Ok(()),
Ok(_) => (),
Err(e) => break Err(e),
}
}
}
}
async fn handle(mut stream: Stream, app_handle: AppHandle, client_pid: u32) -> Result<(), HandlerError> async fn handle(mut stream: Stream, app_handle: AppHandle, client_pid: u32) -> Result<(), HandlerError>
{ {
// read from stream until delimiter is reached // read from stream until delimiter is reached
@ -59,13 +77,21 @@ async fn handle(mut stream: Stream, app_handle: AppHandle, client_pid: u32) -> R
} }
let client = clientinfo::get_process_parent_info(client_pid)?; let client = clientinfo::get_process_parent_info(client_pid)?;
let waiter = CloseWaiter { stream: &mut stream };
let req: Request = serde_json::from_slice(&buf)?; let req: Request = serde_json::from_slice(&buf)?;
let res = match req { let res = match req {
Request::GetAwsCredentials{ base } => get_aws_credentials(base, client, app_handle).await, Request::GetAwsCredentials{ base } => get_aws_credentials(
base, client, app_handle, waiter
).await,
Request::InvokeShortcut(action) => invoke_shortcut(action).await, Request::InvokeShortcut(action) => invoke_shortcut(action).await,
}; };
// doesn't make sense to send the error to the client if the client has already left
if let Err(HandlerError::Abandoned) = res {
return Err(HandlerError::Abandoned);
}
let res = serde_json::to_vec(&res).unwrap(); let res = serde_json::to_vec(&res).unwrap();
stream.write_all(&res).await?; stream.write_all(&res).await?;
Ok(()) Ok(())
@ -78,7 +104,12 @@ async fn invoke_shortcut(action: ShortcutAction) -> Result<Response, HandlerErro
} }
async fn get_aws_credentials(base: bool, client: Client, app_handle: AppHandle) -> Result<Response, HandlerError> { async fn get_aws_credentials(
base: bool,
client: Client,
app_handle: AppHandle,
mut waiter: CloseWaiter<'_>,
) -> Result<Response, HandlerError> {
let state = app_handle.state::<AppState>(); let state = app_handle.state::<AppState>();
let rehide_ms = { let rehide_ms = {
let config = state.config.read().await; let config = state.config.read().await;
@ -97,7 +128,14 @@ async fn get_aws_credentials(base: bool, client: Client, app_handle: AppHandle)
let notification = AwsRequestNotification {id: request_id, client, base}; let notification = AwsRequestNotification {id: request_id, client, base};
app_handle.emit_all("credentials-request", &notification)?; app_handle.emit_all("credentials-request", &notification)?;
let response = chan_recv.await?; let response = tokio::select! {
r = chan_recv => r?,
_ = waiter.wait_for_close() => {
app_handle.emit_all("request-cancelled", request_id)?;
return Err(HandlerError::Abandoned);
},
};
match response.approval { match response.approval {
Approval::Approved => { Approval::Approved => {
if response.base { if response.base {

View File

@ -8,7 +8,7 @@
}, },
"package": { "package": {
"productName": "creddy", "productName": "creddy",
"version": "0.4.3" "version": "0.4.4"
}, },
"tauri": { "tauri": {
"allowlist": { "allowlist": {

View File

@ -3,7 +3,7 @@ import { onMount } from 'svelte';
import { listen } from '@tauri-apps/api/event'; import { listen } from '@tauri-apps/api/event';
import { invoke } from '@tauri-apps/api/tauri'; import { invoke } from '@tauri-apps/api/tauri';
import { appState, acceptRequest } from './lib/state.js'; import { appState, acceptRequest, cleanupRequest } from './lib/state.js';
import { views, currentView, navigate } from './lib/routing.js'; import { views, currentView, navigate } from './lib/routing.js';
@ -16,6 +16,16 @@ listen('credentials-request', (tauriEvent) => {
$appState.pendingRequests.put(tauriEvent.payload); $appState.pendingRequests.put(tauriEvent.payload);
}); });
listen('request-cancelled', (tauriEvent) => {
const id = tauriEvent.payload;
if (id === $appState.currentRequest?.id) {
cleanupRequest()
}
else {
const found = $appState.pendingRequests.find_remove(r => r.id === id);
}
});
listen('launch-terminal-request', async (tauriEvent) => { listen('launch-terminal-request', async (tauriEvent) => {
if ($appState.currentRequest === null) { if ($appState.currentRequest === null) {
let status = await invoke('get_session_status'); let status = await invoke('get_session_status');

View File

@ -30,5 +30,15 @@ export default function() {
return this.items.shift(); return this.items.shift();
}, },
find_remove(pred) {
for (let i=0; i<this.items.length; i++) {
if (pred(this.items[i])) {
this.items.splice(i, 1);
return true;
}
}
return false;
},
} }
} }

View File

@ -23,7 +23,7 @@ export async function acceptRequest() {
} }
export function completeRequest() { export function cleanupRequest() {
appState.update($appState => { appState.update($appState => {
$appState.currentRequest = null; $appState.currentRequest = null;
return $appState; return $appState;

View File

@ -30,7 +30,6 @@
&& alt === event.altKey && alt === event.altKey
&& shift === event.shiftKey && shift === event.shiftKey
) { ) {
console.log({hotkey, ctrl, alt, shift});
click(); click();
} }
} }

View File

@ -3,7 +3,7 @@
import { invoke } from '@tauri-apps/api/tauri'; import { invoke } from '@tauri-apps/api/tauri';
import { navigate } from '../lib/routing.js'; import { navigate } from '../lib/routing.js';
import { appState, completeRequest } from '../lib/state.js'; import { appState, cleanupRequest } from '../lib/state.js';
import ErrorAlert from '../ui/ErrorAlert.svelte'; import ErrorAlert from '../ui/ErrorAlert.svelte';
import Link from '../ui/Link.svelte'; import Link from '../ui/Link.svelte';
import KeyCombo from '../ui/KeyCombo.svelte'; import KeyCombo from '../ui/KeyCombo.svelte';
@ -71,29 +71,29 @@
if ($appState.currentRequest.response) { if ($appState.currentRequest.response) {
await respond(); await respond();
} }
}) });
</script> </script>
<!-- Don't render at all if we're just going to immediately proceed to the next screen --> <!-- Don't render at all if we're just going to immediately proceed to the next screen -->
{#if error || !$appState.currentRequest.response} {#if error || !$appState.currentRequest?.response}
<div class="flex flex-col space-y-4 p-4 m-auto max-w-xl h-screen items-center justify-center"> <div class="flex flex-col space-y-4 p-4 m-auto max-w-xl h-screen items-center justify-center">
{#if error} {#if error}
<ErrorAlert bind:this={alert}> <ErrorAlert bind:this={alert}>
{error} {error.msg}
<svelte:fragment slot="buttons"> <svelte:fragment slot="buttons">
<button class="btn btn-sm btn-alert-error" on:click={completeRequest}>Cancel</button> <button class="btn btn-sm btn-alert-error" on:click={cleanupRequest}>Cancel</button>
<button class="btn btn-sm btn-alert-error" on:click={respond}>Retry</button> <button class="btn btn-sm btn-alert-error" on:click={respond}>Retry</button>
</svelte:fragment> </svelte:fragment>
</ErrorAlert> </ErrorAlert>
{/if} {/if}
{#if $appState.currentRequest.base} {#if $appState.currentRequest?.base}
<div class="alert alert-warning shadow-lg"> <div class="alert alert-warning shadow-lg">
<div> <div>
<svg xmlns="http://www.w3.org/2000/svg" class="stroke-current flex-shrink-0 h-6 w-6" fill="none" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" /></svg> <svg xmlns="http://www.w3.org/2000/svg" class="stroke-current flex-shrink-0 h-6 w-6" fill="none" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" /></svg>
<span> <span>
WARNING: This application is requesting your long-lived AWS credentials. WARNING: This application is requesting your base AWS credentials.
These credentials are less secure than session credentials, since they don't expire automatically. These credentials are less secure than session credentials, since they don't expire automatically.
</span> </span>
</div> </div>
@ -113,7 +113,7 @@
<div class="w-full grid grid-cols-[1fr_auto] items-center gap-y-6"> <div class="w-full grid grid-cols-[1fr_auto] items-center gap-y-6">
<!-- Don't display the option to approve with session credentials if base was specifically requested --> <!-- Don't display the option to approve with session credentials if base was specifically requested -->
{#if !$appState.currentRequest.base} {#if !$appState.currentRequest?.base}
<h3 class="font-semibold"> <h3 class="font-semibold">
Approve with session credentials Approve with session credentials
</h3> </h3>
@ -126,7 +126,7 @@
<h3 class="font-semibold"> <h3 class="font-semibold">
<span class="mr-2"> <span class="mr-2">
{#if $appState.currentRequest.base} {#if $appState.currentRequest?.base}
Approve Approve
{:else} {:else}
Approve with base credentials Approve with base credentials

View File

@ -2,7 +2,7 @@
import { onMount } from 'svelte'; import { onMount } from 'svelte';
import { draw, fade } from 'svelte/transition'; import { draw, fade } from 'svelte/transition';
import { appState, completeRequest } from '../lib/state.js'; import { appState, cleanupRequest } from '../lib/state.js';
let success = false; let success = false;
let error = null; let error = null;
@ -13,7 +13,7 @@
onMount(() => { onMount(() => {
window.setTimeout( window.setTimeout(
completeRequest, cleanupRequest,
// Extra 50ms so the window can finish disappearing before the redraw // Extra 50ms so the window can finish disappearing before the redraw
Math.min(5000, $appState.config.rehide_ms + 50), Math.min(5000, $appState.config.rehide_ms + 50),
) )