From 9fd355b68e673a9fec349c5b05beced24f460391 Mon Sep 17 00:00:00 2001 From: Joseph Montanaro Date: Wed, 3 Jul 2024 14:54:10 -0400 Subject: [PATCH] finish SSH key support --- src-tauri/src/app.rs | 6 +- src-tauri/src/cli.rs | 31 +++- src-tauri/src/credentials/record.rs | 10 ++ src-tauri/src/errors.rs | 4 +- src-tauri/src/ipc.rs | 5 +- src-tauri/src/lib.rs | 2 +- src-tauri/src/server/_ssh_agent.rs | 77 -------- src-tauri/src/server/server_unix.rs | 58 ------ src-tauri/src/server/server_win.rs | 74 -------- .../src/{server/ssh_agent.rs => srv/agent.rs} | 106 ++++------- .../{server/mod.rs => srv/creddy_server.rs} | 99 ++++------ src-tauri/src/srv/mod.rs | 170 ++++++++++++++++++ src-tauri/src/state.rs | 19 +- src-tauri/src/terminal.rs | 4 +- src/ui/PassphraseInput.svelte | 6 + src/views/Unlock.svelte | 9 +- src/views/approve/CollectResponse.svelte | 6 +- 17 files changed, 314 insertions(+), 372 deletions(-) delete mode 100644 src-tauri/src/server/_ssh_agent.rs delete mode 100644 src-tauri/src/server/server_unix.rs delete mode 100644 src-tauri/src/server/server_win.rs rename src-tauri/src/{server/ssh_agent.rs => srv/agent.rs} (53%) rename src-tauri/src/{server/mod.rs => srv/creddy_server.rs} (67%) create mode 100644 src-tauri/src/srv/mod.rs diff --git a/src-tauri/src/app.rs b/src-tauri/src/app.rs index 5beb4e2..2954f1e 100644 --- a/src-tauri/src/app.rs +++ b/src-tauri/src/app.rs @@ -21,7 +21,7 @@ use crate::{ config::{self, AppConfig}, credentials::AppSession, ipc, - server::{Server, Agent}, + srv::{creddy_server, agent}, errors::*, shortcuts, state::AppState, @@ -105,8 +105,8 @@ async fn setup(app: &mut App) -> Result<(), Box> { }; let app_session = AppSession::load(&pool).await?; - Server::start(app.handle().clone())?; - Agent::start(app.handle().clone())?; + creddy_server::serve(app.handle().clone())?; + agent::serve(app.handle().clone())?; config::set_auto_launch(conf.start_on_login)?; if let Err(_e) = config::set_auto_launch(conf.start_on_login) { diff --git a/src-tauri/src/cli.rs b/src-tauri/src/cli.rs index a8c6aad..ce5cf60 100644 --- a/src-tauri/src/cli.rs +++ b/src-tauri/src/cli.rs @@ -13,7 +13,11 @@ use clap::{ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::errors::*; -use crate::server::{Request, Response}; +use crate::srv::{ + self, + Request, + Response +}; use crate::shortcuts::ShortcutAction; #[cfg(unix)] @@ -47,6 +51,10 @@ pub fn parser() -> Command<'static> { .action(ArgAction::SetTrue) .help("Use base credentials instead of session credentials") ) + .arg( + Arg::new("name") + .help("If unspecified, use default credentials") + ) ) .subcommand( Command::new("exec") @@ -59,6 +67,12 @@ pub fn parser() -> Command<'static> { .action(ArgAction::SetTrue) .help("Use base credentials instead of session credentials") ) + .arg( + Arg::new("name") + .short('n') + .long("name") + .help("If unspecified, use default credentials") + ) .arg( Arg::new("command") .multiple_values(true) @@ -78,8 +92,10 @@ pub fn parser() -> Command<'static> { pub fn get(args: &ArgMatches) -> Result<(), CliError> { - let base = args.get_one("base").unwrap_or(&false); - let output = match make_request(&Request::GetAwsCredentials { base: *base })? { + let name = args.get_one("name").cloned(); + let base = *args.get_one("base").unwrap_or(&false); + + let output = match make_request(&Request::GetAwsCredentials { name, base })? { Response::AwsBase(creds) => serde_json::to_string(&creds).unwrap(), Response::AwsSession(creds) => serde_json::to_string(&creds).unwrap(), r => return Err(RequestError::Unexpected(r).into()), @@ -90,6 +106,7 @@ pub fn get(args: &ArgMatches) -> Result<(), CliError> { pub fn exec(args: &ArgMatches) -> Result<(), CliError> { + let name = args.get_one("name").cloned(); let base = *args.get_one("base").unwrap_or(&false); let mut cmd_line = args.get_many("command") .ok_or(ExecError::NoCommand)?; @@ -98,7 +115,7 @@ pub fn exec(args: &ArgMatches) -> Result<(), CliError> { let mut cmd = ChildCommand::new(cmd_name); cmd.args(cmd_line); - match make_request(&Request::GetAwsCredentials { base })? { + match make_request(&Request::GetAwsCredentials { name, base })? { Response::AwsBase(creds) => { cmd.env("AWS_ACCESS_KEY_ID", creds.access_key_id); cmd.env("AWS_SECRET_ACCESS_KEY", creds.secret_access_key); @@ -178,7 +195,8 @@ async fn make_request(req: &Request) -> Result { async fn connect() -> Result { // apparently attempting to connect can fail if there's already a client connected loop { - match ClientOptions::new().open(r"\\.\pipe\creddy-requests") { + let addr = srv::addr("creddy-server"); + match ClientOptions::new().open(&addr) { Ok(stream) => return Ok(stream), Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY.0 as i32) => (), Err(e) => return Err(e), @@ -190,5 +208,6 @@ async fn connect() -> Result { #[cfg(unix)] async fn connect() -> Result { - UnixStream::connect("/tmp/creddy.sock").await + let path = srv::addr("creddy-server"); + UnixStream::connect(&path).await } diff --git a/src-tauri/src/credentials/record.rs b/src-tauri/src/credentials/record.rs index 21fdd39..7a2a1d0 100644 --- a/src-tauri/src/credentials/record.rs +++ b/src-tauri/src/credentials/record.rs @@ -122,6 +122,16 @@ impl CredentialRecord { // Self::load_credential(row, crypto, pool).await // } + pub async fn load_by_name(name: &str, crypto: &Crypto, pool: &SqlitePool) -> Result { + let row: CredentialRow = sqlx::query_as("SELECT * FROM credentials WHERE name = ?") + .bind(name) + .fetch_optional(pool) + .await? + .ok_or(LoadCredentialsError::NoCredentials)?; + + Self::load_credential(row, crypto, pool).await + } + pub async fn load_default(credential_type: &str, crypto: &Crypto, pool: &SqlitePool) -> Result { let row: CredentialRow = sqlx::query_as( "SELECT * FROM credentials diff --git a/src-tauri/src/errors.rs b/src-tauri/src/errors.rs index 602def2..9f4407b 100644 --- a/src-tauri/src/errors.rs +++ b/src-tauri/src/errors.rs @@ -338,6 +338,8 @@ pub enum ClientInfoError { #[cfg(windows)] #[error("Could not determine PID of connected client")] WindowsError(#[from] windows::core::Error), + #[error("Could not determine PID of connected client")] + PidNotFound, #[error(transparent)] Io(#[from] std::io::Error), } @@ -364,7 +366,7 @@ pub enum RequestError { #[error("Error response from server: {0}")] Server(ServerError), #[error("Unexpected response from server")] - Unexpected(crate::server::Response), + Unexpected(crate::srv::Response), #[error("The server did not respond with valid JSON")] InvalidJson(#[from] serde_json::Error), #[error("Error reading/writing stream: {0}")] diff --git a/src-tauri/src/ipc.rs b/src-tauri/src/ipc.rs index 28aec08..a79fb05 100644 --- a/src-tauri/src/ipc.rs +++ b/src-tauri/src/ipc.rs @@ -18,6 +18,7 @@ use crate::terminal; pub struct AwsRequestNotification { pub id: u64, pub client: Client, + pub name: Option, pub base: bool, } @@ -38,8 +39,8 @@ pub enum RequestNotification { } impl RequestNotification { - pub fn new_aws(id: u64, client: Client, base: bool) -> Self { - Self::Aws(AwsRequestNotification {id, client, base}) + pub fn new_aws(id: u64, client: Client, name: Option, base: bool) -> Self { + Self::Aws(AwsRequestNotification {id, client, name, base}) } pub fn new_ssh(id: u64, client: Client, key_name: String) -> Self { diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index ef09f40..6674c23 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -7,7 +7,7 @@ mod clientinfo; mod ipc; mod kv; mod state; -pub mod server; +mod srv; mod shortcuts; mod terminal; mod tray; diff --git a/src-tauri/src/server/_ssh_agent.rs b/src-tauri/src/server/_ssh_agent.rs deleted file mode 100644 index f1d21ac..0000000 --- a/src-tauri/src/server/_ssh_agent.rs +++ /dev/null @@ -1,77 +0,0 @@ -use signature::Signer; -use ssh_agent_lib::agent::{Agent, Session}; -use ssh_agent_lib::proto::message::Message; -use ssh_key::public::PublicKey; -use ssh_key::private::PrivateKey; -use tokio::net::UnixListener; - - -struct SshAgent; - -impl std::default::Default for SshAgent { - fn default() -> Self { - SshAgent {} - } -} - -#[ssh_agent_lib::async_trait] -impl Session for SshAgent { - async fn handle(&mut self, message: Message) -> Result> { - println!("Received message"); - match message { - Message::RequestIdentities => { - let p = std::path::PathBuf::from("/home/joe/.ssh/id_ed25519.pub"); - let pubkey = PublicKey::read_openssh_file(&p).unwrap(); - let id = ssh_agent_lib::proto::message::Identity { - pubkey_blob: pubkey.to_bytes().unwrap(), - comment: pubkey.comment().to_owned(), - }; - Ok(Message::IdentitiesAnswer(vec![id])) - }, - Message::SignRequest(req) => { - println!("Received sign request"); - let mut req_bytes = vec![13]; - encode_string(&mut req_bytes, &req.pubkey_blob); - encode_string(&mut req_bytes, &req.data); - req_bytes.extend(req.flags.to_be_bytes()); - std::fs::File::create("/tmp/signreq").unwrap().write(&req_bytes).unwrap(); - - let p = std::path::PathBuf::from("/home/joe/.ssh/id_ed25519"); - let passphrase = std::env::var("PRIVKEY_PASSPHRASE").unwrap(); - let privkey = PrivateKey::read_openssh_file(&p) - .unwrap() - .decrypt(passphrase.as_bytes()) - .unwrap(); - - - - let sig = Signer::sign(&privkey, &req.data); - use std::io::Write; - std::fs::File::create("/tmp/sig").unwrap().write(sig.as_bytes()).unwrap(); - - let mut payload = Vec::with_capacity(128); - encode_string(&mut payload, "ssh-ed25519".as_bytes()); - encode_string(&mut payload, sig.as_bytes()); - println!("Payload length: {}", payload.len()); - std::fs::File::create("/tmp/payload").unwrap().write(&payload).unwrap(); - Ok(Message::SignResponse(payload)) - }, - _ => Ok(Message::Failure), - } - } -} - - -fn encode_string(buf: &mut Vec, s: &[u8]) { - let len = s.len() as u32; - buf.extend(len.to_be_bytes()); - buf.extend(s); -} - - -pub async fn run() { - let socket = "/tmp/creddy-agent.sock"; - let _ = std::fs::remove_file(socket); - let listener = UnixListener::bind(socket).unwrap(); - SshAgent.listen(listener).await.unwrap(); -} diff --git a/src-tauri/src/server/server_unix.rs b/src-tauri/src/server/server_unix.rs deleted file mode 100644 index 86d3014..0000000 --- a/src-tauri/src/server/server_unix.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::io::ErrorKind; -use tokio::net::{UnixListener, UnixStream}; -use tauri::{ - AppHandle, - async_runtime as rt, -}; - -use crate::errors::*; - - -pub type Stream = UnixStream; - - -pub struct Server { - listener: UnixListener, - app_handle: AppHandle, -} - -impl Server { - pub fn start(app_handle: AppHandle) -> std::io::Result<()> { - match std::fs::remove_file("/tmp/creddy.sock") { - Ok(_) => (), - Err(e) if e.kind() == ErrorKind::NotFound => (), - Err(e) => return Err(e), - } - - let listener = UnixListener::bind("/tmp/creddy.sock")?; - let srv = Server { listener, app_handle }; - rt::spawn(srv.serve()); - Ok(()) - } - - async fn serve(self) { - loop { - self.try_serve() - .await - .error_print_prefix("Error accepting request: "); - } - } - - async fn try_serve(&self) -> Result<(), HandlerError> { - let (stream, _addr) = self.listener.accept().await?; - let new_handle = self.app_handle.clone(); - let client_pid = get_client_pid(&stream)?; - rt::spawn(async move { - super::handle(stream, new_handle, client_pid) - .await - .error_print_prefix("Error responding to request: "); - }); - Ok(()) - } -} - - -fn get_client_pid(stream: &UnixStream) -> std::io::Result { - let cred = stream.peer_cred()?; - Ok(cred.pid().unwrap() as u32) -} diff --git a/src-tauri/src/server/server_win.rs b/src-tauri/src/server/server_win.rs deleted file mode 100644 index bcb31ff..0000000 --- a/src-tauri/src/server/server_win.rs +++ /dev/null @@ -1,74 +0,0 @@ -use tokio::net::windows::named_pipe::{ - NamedPipeServer, - ServerOptions, -}; - -use tauri::{AppHandle, Manager}; - -use windows::Win32:: { - Foundation::HANDLE, - System::Pipes::GetNamedPipeClientProcessId, -}; - -use std::os::windows::io::AsRawHandle; - -use tauri::async_runtime as rt; - -use crate::errors::*; - - -// used by parent module -pub type Stream = NamedPipeServer; - - -pub struct Server { - listener: NamedPipeServer, - app_handle: AppHandle, -} - -impl Server { - pub fn start(app_handle: AppHandle) -> std::io::Result<()> { - let listener = ServerOptions::new() - .first_pipe_instance(true) - .create(r"\\.\pipe\creddy-requests")?; - - let srv = Server {listener, app_handle}; - rt::spawn(srv.serve()); - Ok(()) - } - - async fn serve(mut self) { - loop { - if let Err(e) = self.try_serve().await { - eprintln!("Error accepting connection: {e}"); - } - } - } - - async fn try_serve(&mut self) -> Result<(), HandlerError> { - // connect() just waits for a client to connect, it doesn't return anything - self.listener.connect().await?; - - // create a new pipe instance to listen for the next client, and swap it in - let new_listener = ServerOptions::new().create(r"\\.\pipe\creddy-requests")?; - let stream = std::mem::replace(&mut self.listener, new_listener); - let new_handle = self.app_handle.clone(); - let client_pid = get_client_pid(&stream)?; - rt::spawn(async move { - super::handle(stream, new_handle, client_pid) - .await - .error_print_prefix("Error responding to request: "); - }); - - Ok(()) - } -} - - -fn get_client_pid(pipe: &NamedPipeServer) -> Result { - let raw_handle = pipe.as_raw_handle(); - let mut pid = 0u32; - let handle = HANDLE(raw_handle as _); - unsafe { GetNamedPipeClientProcessId(handle, &mut pid as *mut u32)? }; - Ok(pid) -} diff --git a/src-tauri/src/server/ssh_agent.rs b/src-tauri/src/srv/agent.rs similarity index 53% rename from src-tauri/src/server/ssh_agent.rs rename to src-tauri/src/srv/agent.rs index a41613c..70f67a7 100644 --- a/src-tauri/src/server/ssh_agent.rs +++ b/src-tauri/src/srv/agent.rs @@ -1,98 +1,68 @@ -use std::io::ErrorKind; - use futures::SinkExt; use signature::Signer; use ssh_agent_lib::agent::MessageCodec; use ssh_agent_lib::proto::message::{ Message, - Identity, SignRequest, }; -use tokio::net::{UnixListener, UnixStream}; -use tauri::{ - AppHandle, - Manager, - async_runtime as rt, -}; -use tokio_util::codec::Framed; +use tauri::{AppHandle, Manager}; use tokio_stream::StreamExt; use tokio::sync::oneshot; +use tokio_util::codec::Framed; use crate::clientinfo; use crate::errors::*; use crate::ipc::{Approval, RequestNotification}; use crate::state::AppState; +use super::{CloseWaiter, Stream}; -pub struct Agent { - listener: UnixListener, - app_handle: AppHandle, -} -impl Agent { - pub fn start(app_handle: AppHandle) -> std::io::Result<()> { - match std::fs::remove_file("/tmp/creddy-agent.sock") { - Ok(_) => (), - Err(e) if e.kind() == ErrorKind::NotFound => (), - Err(e) => return Err(e), - } - - let listener = UnixListener::bind("/tmp/creddy-agent.sock")?; - let srv = Agent { listener, app_handle }; - rt::spawn(srv.serve()); - Ok(()) - } - - async fn serve(self) { - loop { - self.try_serve() - .await - .error_print_prefix("Error accepting request: "); - } - } - - async fn try_serve(&self) -> Result<(), HandlerError> { - let (stream, _addr) = self.listener.accept().await?; - let new_handle = self.app_handle.clone(); - let client_pid = get_client_pid(&stream)?; - rt::spawn(async move { - let adapter = Framed::new(stream, MessageCodec); - handle_framed(adapter, new_handle, client_pid) - .await - .error_print_prefix("Error responding to request: "); - }); - Ok(()) - } +pub fn serve(app_handle: AppHandle) -> std::io::Result<()> { + super::serve("creddy-agent", app_handle, handle) } -async fn handle_framed( - mut adapter: Framed, +async fn handle( + stream: Stream, app_handle: AppHandle, - client_pid: u32, + client_pid: u32 ) -> Result<(), HandlerError> { + let mut adapter = Framed::new(stream, MessageCodec); while let Some(message) = adapter.try_next().await? { - let resp = match message { - Message::RequestIdentities => list_identities(app_handle.clone()).await?, - Message::SignRequest(req) => sign_request(req, app_handle.clone(), client_pid).await?, - _ => Message::Failure, + match message { + Message::RequestIdentities => { + let resp = list_identities(app_handle.clone()).await?; + adapter.send(resp).await?; + }, + Message::SignRequest(req) => { + // CloseWaiter could corrupt the framing, but this doesn't matter + // since we don't plan to pull any more frames out of the stream + let waiter = CloseWaiter { stream: adapter.get_mut() }; + let resp = sign_request(req, app_handle.clone(), client_pid, waiter).await?; + adapter.send(resp).await?; + break; + }, + _ => adapter.send(Message::Failure).await?, }; - - adapter.send(resp).await?; } - Ok(()) } async fn list_identities(app_handle: AppHandle) -> Result { let state = app_handle.state::(); - let identities: Vec = state.list_ssh_identities().await?; + let identities = state.list_ssh_identities().await?; Ok(Message::IdentitiesAnswer(identities)) } -async fn sign_request(req: SignRequest, app_handle: AppHandle, client_pid: u32) -> Result { +async fn sign_request( + req: SignRequest, + app_handle: AppHandle, + client_pid: u32, + mut waiter: CloseWaiter<'_>, +) -> Result { let state = app_handle.state::(); let rehide_ms = { let config = state.config.read().await; @@ -110,7 +80,14 @@ async fn sign_request(req: SignRequest, app_handle: AppHandle, client_pid: u32) let notification = RequestNotification::new_ssh(request_id, client, key_name.clone()); app_handle.emit("credential-request", ¬ification)?; - let response = chan_recv.await?; + let response = tokio::select! { + r = chan_recv => r?, + _ = waiter.wait_for_close() => { + app_handle.emit("request-cancelled", request_id)?; + return Err(HandlerError::Abandoned); + }, + }; + if let Approval::Denied = response.approval { return Ok(Message::Failure); } @@ -137,13 +114,6 @@ async fn sign_request(req: SignRequest, app_handle: AppHandle, client_pid: u32) } - -fn get_client_pid(stream: &UnixStream) -> std::io::Result { - let cred = stream.peer_cred()?; - Ok(cred.pid().unwrap() as u32) -} - - fn encode_string(buf: &mut Vec, s: &[u8]) { let len = s.len() as u32; buf.extend(len.to_be_bytes()); diff --git a/src-tauri/src/server/mod.rs b/src-tauri/src/srv/creddy_server.rs similarity index 67% rename from src-tauri/src/server/mod.rs rename to src-tauri/src/srv/creddy_server.rs index 9c196d5..08d9547 100644 --- a/src-tauri/src/server/mod.rs +++ b/src-tauri/src/srv/creddy_server.rs @@ -1,75 +1,30 @@ +use tauri::{AppHandle, Manager}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::oneshot; -use serde::{Serialize, Deserialize}; - -use tauri::{AppHandle, Manager}; - -use crate::errors::*; use crate::clientinfo::{self, Client}; -use crate::credentials::{ - AwsBaseCredential, - AwsSessionCredential, -}; +use crate::errors::*; use crate::ipc::{Approval, RequestNotification}; -use crate::state::AppState; use crate::shortcuts::{self, ShortcutAction}; - -#[cfg(windows)] -mod server_win; -#[cfg(windows)] -pub use server_win::Server; -#[cfg(windows)] -use server_win::Stream; - -#[cfg(unix)] -mod server_unix; -#[cfg(unix)] -pub use server_unix::Server; -#[cfg(unix)] -use server_unix::Stream; - -pub mod ssh_agent; -pub use ssh_agent::Agent; +use crate::state::AppState; +use super::{ + CloseWaiter, + Request, + Response, + Stream, +}; -#[derive(Serialize, Deserialize)] -pub enum Request { - GetAwsCredentials{ - base: bool, - }, - InvokeShortcut(ShortcutAction), +pub fn serve(app_handle: AppHandle) -> std::io::Result<()> { + super::serve("creddy-server", app_handle, handle) } -#[derive(Debug, Serialize, Deserialize)] -pub enum Response { - AwsBase(AwsBaseCredential), - AwsSession(AwsSessionCredential), - Empty, -} - - -struct CloseWaiter<'s> { - stream: &'s mut Stream, -} - -impl<'s> CloseWaiter<'s> { - async fn wait_for_close(&mut self) -> std::io::Result<()> { - let mut buf = [0u8; 8]; - loop { - match self.stream.read(&mut buf).await { - Ok(0) => break Ok(()), - Ok(_) => (), - Err(e) => break Err(e), - } - } - } -} - - -async fn handle(mut stream: Stream, app_handle: AppHandle, client_pid: u32) -> Result<(), HandlerError> -{ +async fn handle( + mut stream: Stream, + app_handle: AppHandle, + client_pid: u32 +) -> Result<(), HandlerError> { // read from stream until delimiter is reached let mut buf: Vec = Vec::with_capacity(1024); // requests are small, 1KiB is more than enough let mut n = 0; @@ -78,7 +33,8 @@ async fn handle(mut stream: Stream, app_handle: AppHandle, client_pid: u32) -> R if let Some(&b'\n') = buf.last() { break; } - else if n >= 1024 { + // sanity check, no request should ever be within a mile of 1MB + else if n >= (1024 * 1024) { return Err(HandlerError::RequestTooLarge); } } @@ -86,12 +42,14 @@ async fn handle(mut stream: Stream, app_handle: AppHandle, client_pid: u32) -> R let client = clientinfo::get_client(client_pid, true)?; let waiter = CloseWaiter { stream: &mut stream }; + let req: Request = serde_json::from_slice(&buf)?; let res = match req { - Request::GetAwsCredentials{ base } => get_aws_credentials( - base, client, app_handle, waiter + Request::GetAwsCredentials { name, base } => get_aws_credentials( + name, base, client, app_handle, waiter ).await, Request::InvokeShortcut(action) => invoke_shortcut(action).await, + Request::GetSshSignature(_) => return Err(HandlerError::Denied), }; // doesn't make sense to send the error to the client if the client has already left @@ -112,6 +70,7 @@ async fn invoke_shortcut(action: ShortcutAction) -> Result, base: bool, client: Client, app_handle: AppHandle, @@ -132,7 +91,9 @@ async fn get_aws_credentials( // but ? returns immediately, and we want to unregister the request before returning // so we bundle it all up in an async block and return a Result so we can handle errors let proceed = async { - let notification = RequestNotification::new_aws(request_id, client, base); + let notification = RequestNotification::new_aws( + request_id, client, name.clone(), base + ); app_handle.emit("credential-request", ¬ification)?; let response = tokio::select! { @@ -146,11 +107,11 @@ async fn get_aws_credentials( match response.approval { Approval::Approved => { if response.base { - let creds = state.get_aws_default().await?; + let creds = state.get_aws_base(name).await?; Ok(Response::AwsBase(creds)) } else { - let creds = state.get_aws_default_session().await?; + let creds = state.get_aws_session(name).await?; Ok(Response::AwsSession(creds.clone())) } }, @@ -163,9 +124,9 @@ async fn get_aws_credentials( Err(e) => { state.unregister_request(request_id).await; Err(e) - } + }, }; lease.release(); result -} +} \ No newline at end of file diff --git a/src-tauri/src/srv/mod.rs b/src-tauri/src/srv/mod.rs new file mode 100644 index 0000000..e5bdd40 --- /dev/null +++ b/src-tauri/src/srv/mod.rs @@ -0,0 +1,170 @@ +use std::future::Future; + +use tauri::{ + AppHandle, + async_runtime as rt, +}; +use tokio::io::AsyncReadExt; +use serde::{Serialize, Deserialize}; +use ssh_agent_lib::proto::message::SignRequest; + +use crate::credentials::{AwsBaseCredential, AwsSessionCredential}; +use crate::errors::*; +use crate::shortcuts::ShortcutAction; + +pub mod creddy_server; +pub mod agent; +use platform::Stream; +pub use platform::addr; + + +#[derive(Debug, Serialize, Deserialize)] +pub enum Request { + GetAwsCredentials { + name: Option, + base: bool, + }, + GetSshSignature(SignRequest), + InvokeShortcut(ShortcutAction), +} + + +#[derive(Debug, Serialize, Deserialize)] +pub enum Response { + AwsBase(AwsBaseCredential), + AwsSession(AwsSessionCredential), + Empty, +} + + +struct CloseWaiter<'s> { + stream: &'s mut Stream, +} + +impl<'s> CloseWaiter<'s> { + async fn wait_for_close(&mut self) -> std::io::Result<()> { + let mut buf = [0u8; 8]; + loop { + match self.stream.read(&mut buf).await { + Ok(0) => break Ok(()), + Ok(_) => (), + Err(e) => break Err(e), + } + } + } +} + + +fn serve(sock_name: &str, app_handle: AppHandle, handler: H) -> std::io::Result<()> + where H: Copy + Send + Fn(Stream, AppHandle, u32) -> F + 'static, + F: Send + Future>, +{ + let (mut listener, addr) = platform::bind(sock_name)?; + rt::spawn(async move { + loop { + let (stream, client_pid) = match platform::accept(&mut listener, &addr).await { + Ok((s, c)) => (s, c), + Err(e) => { + eprintln!("Error accepting request: {e}"); + continue; + }, + }; + let new_handle = app_handle.clone(); + rt::spawn(async move { + handler(stream, new_handle, client_pid) + .await + .error_print_prefix("Error responding to request: "); + }); + } + }); + Ok(()) +} + + +#[cfg(unix)] +mod platform { + use std::io::ErrorKind; + use std::path::PathBuf; + use tokio::net::{UnixListener, UnixStream}; + use super::*; + + + pub type Stream = UnixStream; + + pub fn bind(sock_name: &str) -> std::io::Result<(UnixListener, PathBuf)> { + let path = addr(sock_name); + match std::fs::remove_file(&path) { + Ok(_) => (), + Err(e) if e.kind() == ErrorKind::NotFound => (), + Err(e) => return Err(e), + } + + let listener = UnixListener::bind(&path)?; + Ok((listener, path)) + } + + pub async fn accept(listener: &mut UnixListener, _addr: &PathBuf) -> Result<(UnixStream, u32), HandlerError> { + let (stream, _addr) = listener.accept().await?; + let pid = stream.peer_cred()? + .pid() + .ok_or(ClientInfoError::PidNotFound)? + as u32; + + Ok((stream, pid)) + } + + + pub fn addr(sock_name: &str) -> PathBuf { + let mut path = dirs::runtime_dir() + .unwrap_or_else(|| PathBuf::from("/tmp")); + path.push(format!("{sock_name}.sock")); + path + } +} + + +#[cfg(windows)] +mod platform { + use std::os::windows::io::AsRawHandle; + use tokio::net::windows::named_pipe::{ + NamedPipeServer, + ServerOptions, + }; + use windows::Win32::{ + Foundation::HANDLE, + System::Pipes::GetNamedPipeClientProcessId, + }; + use super::*; + + + pub type Stream = NamedPipeServer; + + pub fn bind(sock_name: &str) -> std::io::Result<(String, NamedPipeServer)> { + let addr = addr(sock_name); + let listener = ServerOptions::new() + .first_pipe_instance(true) + .create(&addr)?; + Ok((listener, addr)) + } + + pub async fn accept(listener: &mut NamedPipeServer, addr: &String) -> Result<(NamedPipeServer, u32), HandlerError> { + // connect() just waits for a client to connect, it doesn't return anything + listener.connect().await?; + + // unlike Unix sockets, a Windows NamedPipeServer *becomes* the open stream + // once a client connects. If we want to keep listening, we have to construct + // a new server and swap it in. + let new_listener = ServerOptions::new().create(addr)?; + let stream = std::mem::replace(listener, new_listener); + + let raw_handle = stream.as_raw_handle(); + let mut pid = 0u32; + let handle = HANDLE(raw_handle as _); + unsafe { GetNamedPipeClientProcessId(handle, &mut pid as *mut u32)? }; + Ok((stream, pid)) + } + + pub fn addr(sock_name: &str) -> String { + format!(r"\\.\pipe\{sock_name}") + } +} diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index 1fb1513..5e5f535 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -270,22 +270,23 @@ impl AppState { Ok(()) } - pub async fn get_aws_default(&self) -> Result { + pub async fn get_aws_base(&self, name: Option) -> Result { let app_session = self.app_session.read().await; let crypto = app_session.try_get_crypto()?; - let creds = AwsBaseCredential::load_default(crypto, &self.pool).await?; - // let record = CredentialRecord::load_default("aws", crypto, &self.pool).await?; - // let creds = match record.credential { - // Credential::AwsBase(b) => Ok(b), - // _ => Err(LoadCredentialsError::NoCredentials) - // }?; + let creds = match name { + Some(n) => AwsBaseCredential::load_by_name(&n, crypto, &self.pool).await?, + None => AwsBaseCredential::load_default(crypto, &self.pool).await?, + }; Ok(creds) } - pub async fn get_aws_default_session(&self) -> Result, GetCredentialsError> { + pub async fn get_aws_session(&self, name: Option) -> Result, GetCredentialsError> { let app_session = self.app_session.read().await; let crypto = app_session.try_get_crypto()?; - let record = CredentialRecord::load_default("aws", crypto, &self.pool).await?; + let record = match name { + Some(n) => CredentialRecord::load_by_name(&n, crypto, &self.pool).await?, + None => CredentialRecord::load_default("aws", crypto, &self.pool).await?, + }; let base = match &record.credential { Credential::AwsBase(b) => Ok(b), _ => Err(LoadCredentialsError::NoCredentials) diff --git a/src-tauri/src/terminal.rs b/src-tauri/src/terminal.rs index 69e5336..436c7b9 100644 --- a/src-tauri/src/terminal.rs +++ b/src-tauri/src/terminal.rs @@ -63,12 +63,12 @@ async fn do_launch(app: &AppHandle, use_base: bool) -> Result<(), LaunchTerminal // (i.e. lies about unlocking) we could end up here with a locked session // this will result in an error popup to the user (see main hotkey handler) if use_base { - let base_creds = state.get_aws_default().await?; + let base_creds = state.get_aws_base(None).await?; cmd.env("AWS_ACCESS_KEY_ID", &base_creds.access_key_id); cmd.env("AWS_SECRET_ACCESS_KEY", &base_creds.secret_access_key); } else { - let session_creds = state.get_aws_default_session().await?; + let session_creds = state.get_aws_session(None).await?; cmd.env("AWS_ACCESS_KEY_ID", &session_creds.access_key_id); cmd.env("AWS_SECRET_ACCESS_KEY", &session_creds.secret_access_key); cmd.env("AWS_SESSION_TOKEN", &session_creds.session_token); diff --git a/src/ui/PassphraseInput.svelte b/src/ui/PassphraseInput.svelte index 47f3934..9c3c0d1 100644 --- a/src/ui/PassphraseInput.svelte +++ b/src/ui/PassphraseInput.svelte @@ -8,6 +8,11 @@ export {classes as class}; let show = false; + let input; + + export function focus() { + input.focus(); + } @@ -21,6 +26,7 @@
value = e.target.value} diff --git a/src/views/Unlock.svelte b/src/views/Unlock.svelte index 37372e7..c62d8da 100644 --- a/src/views/Unlock.svelte +++ b/src/views/Unlock.svelte @@ -34,6 +34,9 @@ } } + + let input; + onMount(() => input.focus()); @@ -52,7 +55,11 @@ - +