Compare commits

..

No commits in common. "4881b90b0b524baa9b060c245e20d6cb3115277c" and "367a140e2ab3bc30c2073b40c395446f8d49a991" have entirely different histories.

27 changed files with 667 additions and 810 deletions

3
.gitignore vendored
View File

@ -2,9 +2,6 @@ dist
**/node_modules **/node_modules
src-tauri/target/ src-tauri/target/
**/creddy.db **/creddy.db
# .env is system-specific
.env
.vscode
# just in case # just in case
credentials* credentials*

View File

@ -16,4 +16,3 @@
* Generalize Request across both credentials and terminal launch? * Generalize Request across both credentials and terminal launch?
* Make hotkey configuration a little more tolerant of slight mistiming * Make hotkey configuration a little more tolerant of slight mistiming
* Distinguish between request that was denied and request that was canceled (e.g. due to error) * Distinguish between request that was denied and request that was canceled (e.g. due to error)
* Use atomic types for primitive state values instead of RwLock'd types

4
package-lock.json generated
View File

@ -1,12 +1,12 @@
{ {
"name": "creddy", "name": "creddy",
"version": "0.3.3", "version": "0.3.1",
"lockfileVersion": 2, "lockfileVersion": 2,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "creddy", "name": "creddy",
"version": "0.3.3", "version": "0.3.1",
"dependencies": { "dependencies": {
"@tauri-apps/api": "^1.0.2", "@tauri-apps/api": "^1.0.2",
"daisyui": "^2.51.5" "daisyui": "^2.51.5"

1
src-tauri/.env Normal file
View File

@ -0,0 +1 @@
DATABASE_URL=sqlite://C:/Users/Joe/AppData/Roaming/creddy/creddy.dev.db

46
src-tauri/Cargo.lock generated
View File

@ -1047,6 +1047,7 @@ dependencies = [
"clap", "clap",
"dirs 5.0.1", "dirs 5.0.1",
"is-terminal", "is-terminal",
"netstat2",
"once_cell", "once_cell",
"serde", "serde",
"serde_json", "serde_json",
@ -1061,7 +1062,6 @@ dependencies = [
"thiserror", "thiserror",
"tokio", "tokio",
"which", "which",
"windows 0.51.1",
] ]
[[package]] [[package]]
@ -2641,6 +2641,20 @@ dependencies = [
"jni-sys", "jni-sys",
] ]
[[package]]
name = "netstat2"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0faa3f4ad230fd2bf2a5dad71476ecbaeaed904b3c7e7e5b1f266c415c03761f"
dependencies = [
"bitflags 1.3.2",
"byteorder",
"libc",
"num-derive",
"num-traits",
"thiserror",
]
[[package]] [[package]]
name = "new_debug_unreachable" name = "new_debug_unreachable"
version = "1.0.4" version = "1.0.4"
@ -2694,6 +2708,17 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "num-derive"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "num-integer" name = "num-integer"
version = "0.1.45" version = "0.1.45"
@ -5236,16 +5261,6 @@ dependencies = [
"windows-targets 0.48.5", "windows-targets 0.48.5",
] ]
[[package]]
name = "windows"
version = "0.51.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca229916c5ee38c2f2bc1e9d8f04df975b4bd93f9955dc69fabb5d91270045c9"
dependencies = [
"windows-core",
"windows-targets 0.48.5",
]
[[package]] [[package]]
name = "windows-bindgen" name = "windows-bindgen"
version = "0.39.0" version = "0.39.0"
@ -5256,15 +5271,6 @@ dependencies = [
"windows-tokens", "windows-tokens",
] ]
[[package]]
name = "windows-core"
version = "0.51.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64"
dependencies = [
"windows-targets 0.48.5",
]
[[package]] [[package]]
name = "windows-implement" name = "windows-implement"
version = "0.39.0" version = "0.39.0"

View File

@ -30,6 +30,7 @@ tauri-plugin-single-instance = { git = "https://github.com/tauri-apps/plugins-wo
sodiumoxide = "0.2.7" sodiumoxide = "0.2.7"
tokio = { version = ">=1.19", features = ["full"] } tokio = { version = ">=1.19", features = ["full"] }
sqlx = { version = "0.6.2", features = ["sqlite", "runtime-tokio-rustls"] } sqlx = { version = "0.6.2", features = ["sqlite", "runtime-tokio-rustls"] }
netstat2 = "0.9.1"
sysinfo = "0.26.8" sysinfo = "0.26.8"
aws-types = "0.52.0" aws-types = "0.52.0"
aws-sdk-sts = "0.22.0" aws-sdk-sts = "0.22.0"
@ -46,7 +47,6 @@ is-terminal = "0.4.7"
argon2 = { version = "0.5.0", features = ["std"] } argon2 = { version = "0.5.0", features = ["std"] }
chacha20poly1305 = { version = "0.10.1", features = ["std"] } chacha20poly1305 = { version = "0.10.1", features = ["std"] }
which = "4.4.0" which = "4.4.0"
windows = { version = "0.51.1", features = ["Win32_Foundation", "Win32_System_Pipes"] }
[features] [features]
# by default Tauri runs in production mode # by default Tauri runs in production mode

View File

@ -19,7 +19,6 @@ use crate::{
ipc, ipc,
server::Server, server::Server,
errors::*, errors::*,
shortcuts,
state::AppState, state::AppState,
tray, tray,
}; };
@ -94,18 +93,16 @@ async fn setup(app: &mut App) -> Result<(), Box<dyn Error>> {
}; };
let session = Session::load(&pool).await?; let session = Session::load(&pool).await?;
Server::start(app.handle())?; let srv = Server::new(conf.listen_addr, conf.listen_port, app.handle()).await?;
config::set_auto_launch(conf.start_on_login)?; config::set_auto_launch(conf.start_on_login)?;
if let Err(_e) = config::set_auto_launch(conf.start_on_login) { if let Err(_e) = config::set_auto_launch(conf.start_on_login) {
setup_errors.push("Error: Failed to manage autolaunch.".into()); setup_errors.push("Error: Failed to manage autolaunch.".into());
} }
if let Err(e) = config::register_hotkeys(&conf.hotkeys) {
// if hotkeys fail to register, disable them so that this error doesn't have to keep showing up conf.hotkeys.show_window.enabled = false;
if let Err(_e) = shortcuts::register_hotkeys(&conf.hotkeys) { conf.hotkeys.launch_terminal.enabled = false;
conf.hotkeys.disable_all(); setup_errors.push(format!("{e}"));
conf.save(&pool).await?;
setup_errors.push("Failed to register hotkeys. Hotkey settings have been disabled.".into());
} }
// if session is empty, this is probably the first launch, so don't autohide // if session is empty, this is probably the first launch, so don't autohide
@ -115,7 +112,7 @@ async fn setup(app: &mut App) -> Result<(), Box<dyn Error>> {
.show()?; .show()?;
} }
let state = AppState::new(conf, session, pool, setup_errors); let state = AppState::new(conf, session, srv, pool, setup_errors);
app.manage(state); app.manage(state);
Ok(()) Ok(())
} }

View File

@ -19,15 +19,13 @@ fn main() {
let res = match args.subcommand() { let res = match args.subcommand() {
None | Some(("run", _)) => launch_gui(), None | Some(("run", _)) => launch_gui(),
Some(("get", m)) => cli::get(m), Some(("show", m)) => cli::show(m),
Some(("exec", m)) => cli::exec(m), Some(("exec", m)) => cli::exec(m),
Some(("shortcut", m)) => cli::invoke_shortcut(m), _ => unreachable!(),
_ => unreachable!("Unknown subcommand"),
}; };
if let Err(e) = res { if let Err(e) = res {
eprintln!("Error: {e}"); eprintln!("Error: {e}");
process::exit(1);
} }
} }

View File

@ -1,33 +1,24 @@
use std::ffi::OsString; use std::ffi::OsString;
use std::process::Command as ChildCommand; use std::process::Command as ChildCommand;
#[cfg(windows)] #[cfg(unix)]
use std::time::Duration; use std::os::unix::process::CommandExt;
use clap::{ use clap::{
Command, Command,
Arg, Arg,
ArgMatches, ArgMatches,
ArgAction, ArgAction
builder::PossibleValuesParser,
}; };
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::{
net::TcpStream,
io::{AsyncReadExt, AsyncWriteExt},
};
use crate::credentials::Credentials;
use crate::app;
use crate::config::AppConfig;
use crate::credentials::{BaseCredentials, SessionCredentials};
use crate::errors::*; use crate::errors::*;
use crate::server::{Request, Response};
use crate::shortcuts::ShortcutAction;
#[cfg(unix)]
use {
std::os::unix::process::CommandExt,
tokio::net::UnixStream,
};
#[cfg(windows)]
use {
tokio::net::windows::named_pipe::{NamedPipeClient, ClientOptions},
windows::Win32::Foundation::ERROR_PIPE_BUSY,
};
pub fn parser() -> Command<'static> { pub fn parser() -> Command<'static> {
@ -39,8 +30,8 @@ pub fn parser() -> Command<'static> {
.about("Launch Creddy") .about("Launch Creddy")
) )
.subcommand( .subcommand(
Command::new("get") Command::new("show")
.about("Request AWS credentials from Creddy and output to stdout") .about("Fetch and display AWS credentials")
.arg( .arg(
Arg::new("base") Arg::new("base")
.short('b') .short('b')
@ -65,26 +56,13 @@ pub fn parser() -> Command<'static> {
.multiple_values(true) .multiple_values(true)
) )
) )
.subcommand(
Command::new("shortcut")
.about("Invoke an action normally trigged by hotkey (e.g. launch terminal)")
.arg(
Arg::new("action")
.value_parser(
PossibleValuesParser::new(["show_window", "launch_terminal"])
)
)
)
} }
pub fn get(args: &ArgMatches) -> Result<(), CliError> { pub fn show(args: &ArgMatches) -> Result<(), CliError> {
let base = args.get_one("base").unwrap_or(&false); let base = args.get_one("base").unwrap_or(&false);
let output = match get_credentials(*base)? { let creds = get_credentials(*base)?;
Credentials::Base(creds) => serde_json::to_string(&creds).unwrap(), println!("{creds}");
Credentials::Session(creds) => serde_json::to_string(&creds).unwrap(),
};
println!("{output}");
Ok(()) Ok(())
} }
@ -98,16 +76,18 @@ pub fn exec(args: &ArgMatches) -> Result<(), CliError> {
let mut cmd = ChildCommand::new(cmd_name); let mut cmd = ChildCommand::new(cmd_name);
cmd.args(cmd_line); cmd.args(cmd_line);
match get_credentials(base)? { if base {
Credentials::Base(creds) => { let creds: BaseCredentials = serde_json::from_str(&get_credentials(base)?)
cmd.env("AWS_ACCESS_KEY_ID", creds.access_key_id); .map_err(|_| RequestError::InvalidJson)?;
cmd.env("AWS_SECRET_ACCESS_KEY", creds.secret_access_key); cmd.env("AWS_ACCESS_KEY_ID", creds.access_key_id);
}, cmd.env("AWS_SECRET_ACCESS_KEY", creds.secret_access_key);
Credentials::Session(creds) => { }
cmd.env("AWS_ACCESS_KEY_ID", creds.access_key_id); else {
cmd.env("AWS_SECRET_ACCESS_KEY", creds.secret_access_key); let creds: SessionCredentials = serde_json::from_str(&get_credentials(base)?)
cmd.env("AWS_SESSION_TOKEN", creds.session_token); .map_err(|_| RequestError::InvalidJson)?;
} cmd.env("AWS_ACCESS_KEY_ID", creds.access_key_id);
cmd.env("AWS_SECRET_ACCESS_KEY", creds.secret_access_key);
cmd.env("AWS_SESSION_TOKEN", creds.token);
} }
#[cfg(unix)] #[cfg(unix)]
@ -141,63 +121,41 @@ pub fn exec(args: &ArgMatches) -> Result<(), CliError> {
} }
pub fn invoke_shortcut(args: &ArgMatches) -> Result<(), CliError> {
let action = match args.get_one::<String>("action").map(|s| s.as_str()) {
Some("show_window") => ShortcutAction::ShowWindow,
Some("launch_terminal") => ShortcutAction::LaunchTerminal,
Some(&_) | None => unreachable!("Unknown shortcut action"), // guaranteed by clap
};
let req = Request::InvokeShortcut(action);
match make_request(&req) {
Ok(Response::Empty) => Ok(()),
Ok(r) => Err(RequestError::Unexpected(r).into()),
Err(e) => Err(e.into()),
}
}
fn get_credentials(base: bool) -> Result<Credentials, RequestError> {
let req = Request::GetAwsCredentials { base };
match make_request(&req) {
Ok(Response::Aws(creds)) => Ok(creds),
Ok(r) => Err(RequestError::Unexpected(r)),
Err(e) => Err(e),
}
}
#[tokio::main] #[tokio::main]
async fn make_request(req: &Request) -> Result<Response, RequestError> { async fn get_credentials(base: bool) -> Result<String, RequestError> {
let mut data = serde_json::to_string(req).unwrap(); let pool = app::connect_db().await?;
// server expects newline marking end of request let config = AppConfig::load(&pool).await?;
data.push('\n'); let path = if base {"/creddy/base-credentials"} else {"/"};
let mut stream = connect().await?; let mut stream = TcpStream::connect((config.listen_addr, config.listen_port)).await?;
stream.write_all(&data.as_bytes()).await?; let req = format!("GET {path} HTTP/1.0\r\n\r\n");
stream.write_all(req.as_bytes()).await?;
let mut buf = Vec::with_capacity(1024); // some day we'll have a proper HTTP parser
let mut buf = vec![0; 8192];
stream.read_to_end(&mut buf).await?; stream.read_to_end(&mut buf).await?;
let res: Result<Response, ServerError> = serde_json::from_slice(&buf)?;
Ok(res?)
}
let status = buf.split(|&c| &[c] == b" ")
.skip(1)
.next()
.ok_or(RequestError::MalformedHttpResponse)?;
#[cfg(windows)] if status != b"200" {
async fn connect() -> Result<NamedPipeClient, std::io::Error> { let s = String::from_utf8_lossy(status).to_string();
// apparently attempting to connect can fail if there's already a client connected return Err(RequestError::Failed(s));
loop {
match ClientOptions::new().open(r"\\.\pipe\creddy-requests") {
Ok(stream) => return Ok(stream),
Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY.0 as i32) => (),
Err(e) => return Err(e),
}
tokio::time::sleep(Duration::from_millis(10)).await;
} }
}
let break_idx = buf.windows(4)
.position(|w| w == b"\r\n\r\n")
.ok_or(RequestError::MalformedHttpResponse)?;
let body = &buf[(break_idx + 4)..];
#[cfg(unix)] let creds_str = std::str::from_utf8(body)
async fn connect() -> Result<UnixStream, std::io::Error> { .map_err(|_| RequestError::MalformedHttpResponse)?
UnixStream::connect("/tmp/creddy.sock").await .to_string();
if creds_str == "Denied!" {
return Err(RequestError::Rejected);
}
Ok(creds_str)
} }

View File

@ -1,92 +1,76 @@
use std::path::{Path, PathBuf}; use std::path::PathBuf;
use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo};
use tauri::Manager;
use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt}; use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::errors::*; use crate::{
app::APP,
errors::*,
config::AppConfig,
state::AppState,
};
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)] #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct Client { pub struct Client {
pub pid: u32, pub pid: u32,
pub exe: Option<PathBuf>, pub exe: PathBuf,
} }
pub fn get_process_parent_info(pid: u32) -> Result<Client, ClientInfoError> { async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
dbg!(pid); let state = APP.get().unwrap().state::<AppState>();
let sys_pid = Pid::from_u32(pid); let AppConfig {
let mut sys = System::new(); listen_addr: app_listen_addr,
sys.refresh_process(sys_pid); listen_port: app_listen_port,
let proc = sys.process(sys_pid) ..
.ok_or(ClientInfoError::ProcessNotFound)?; } = *state.config.read().await;
let parent_pid_sys = proc.parent() let sockets_iter = netstat2::iterate_sockets_info(
.ok_or(ClientInfoError::ParentPidNotFound)?; AddressFamilyFlags::IPV4,
sys.refresh_process(parent_pid_sys); ProtocolFlags::TCP
let parent = sys.process(parent_pid_sys) )?;
.ok_or(ClientInfoError::ParentProcessNotFound)?; for item in sockets_iter {
let sock_info = item?;
let proto_info = match sock_info.protocol_socket_info {
ProtocolSocketInfo::Tcp(tcp_info) => tcp_info,
ProtocolSocketInfo::Udp(_) => {continue;}
};
let exe = match parent.exe() { if proto_info.local_port == local_port
p if p == Path::new("") => None, && proto_info.remote_port == app_listen_port
p => Some(PathBuf::from(p)), && proto_info.local_addr == app_listen_addr
}; && proto_info.remote_addr == app_listen_addr
{
Ok(Client { pid: parent_pid_sys.as_u32(), exe }) return Ok(sock_info.associated_pids)
}
}
Ok(vec![])
} }
// async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
// let state = APP.get().unwrap().state::<AppState>();
// let AppConfig {
// listen_addr: app_listen_addr,
// listen_port: app_listen_port,
// ..
// } = *state.config.read().await;
// let sockets_iter = netstat2::iterate_sockets_info(
// AddressFamilyFlags::IPV4,
// ProtocolFlags::TCP
// )?;
// for item in sockets_iter {
// let sock_info = item?;
// let proto_info = match sock_info.protocol_socket_info {
// ProtocolSocketInfo::Tcp(tcp_info) => tcp_info,
// ProtocolSocketInfo::Udp(_) => {continue;}
// };
// if proto_info.local_port == local_port
// && proto_info.remote_port == app_listen_port
// && proto_info.local_addr == app_listen_addr
// && proto_info.remote_addr == app_listen_addr
// {
// return Ok(sock_info.associated_pids)
// }
// }
// Ok(vec![])
// }
// Theoretically, on some systems, multiple processes can share a socket // Theoretically, on some systems, multiple processes can share a socket
// pub async fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> { pub async fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> {
// let mut clients = Vec::new(); let mut clients = Vec::new();
// let mut sys = System::new(); let mut sys = System::new();
// for p in get_associated_pids(local_port).await? { for p in get_associated_pids(local_port).await? {
// let pid = Pid::from_u32(p); let pid = Pid::from_u32(p);
// sys.refresh_process(pid); sys.refresh_process(pid);
// let proc = sys.process(pid) let proc = sys.process(pid)
// .ok_or(ClientInfoError::ProcessNotFound)?; .ok_or(ClientInfoError::ProcessNotFound)?;
// let client = Client { let client = Client {
// pid: p, pid: p,
// exe: proc.exe().to_path_buf(), exe: proc.exe().to_path_buf(),
// }; };
// clients.push(Some(client)); clients.push(Some(client));
// } }
// if clients.is_empty() { if clients.is_empty() {
// clients.push(None); clients.push(None);
// } }
// Ok(clients) Ok(clients)
// } }

View File

@ -1,9 +1,15 @@
use std::net::Ipv4Addr;
use std::path::PathBuf; use std::path::PathBuf;
use auto_launch::AutoLaunchBuilder; use auto_launch::AutoLaunchBuilder;
use is_terminal::IsTerminal; use is_terminal::IsTerminal;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use tauri::{
Manager,
GlobalShortcutManager,
async_runtime as rt,
};
use crate::errors::*; use crate::errors::*;
@ -33,16 +39,13 @@ pub struct HotkeysConfig {
pub launch_terminal: Hotkey, pub launch_terminal: Hotkey,
} }
impl HotkeysConfig {
pub fn disable_all(&mut self) {
self.show_window.enabled = false;
self.launch_terminal.enabled = false;
}
}
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AppConfig { pub struct AppConfig {
#[serde(default = "default_listen_addr")]
pub listen_addr: Ipv4Addr,
#[serde(default = "default_listen_port")]
pub listen_port: u16,
#[serde(default = "default_rehide_ms")] #[serde(default = "default_rehide_ms")]
pub rehide_ms: u64, pub rehide_ms: u64,
#[serde(default = "default_start_minimized")] #[serde(default = "default_start_minimized")]
@ -59,6 +62,8 @@ pub struct AppConfig {
impl Default for AppConfig { impl Default for AppConfig {
fn default() -> Self { fn default() -> Self {
AppConfig { AppConfig {
listen_addr: default_listen_addr(),
listen_port: default_listen_port(),
rehide_ms: default_rehide_ms(), rehide_ms: default_rehide_ms(),
start_minimized: default_start_minimized(), start_minimized: default_start_minimized(),
start_on_login: default_start_on_login(), start_on_login: default_start_on_login(),
@ -139,6 +144,16 @@ pub fn get_or_create_db_path() -> Result<PathBuf, DataDirError> {
} }
fn default_listen_port() -> u16 {
if cfg!(debug_assertions) {
12_345
}
else {
19_923
}
}
fn default_term_config() -> TermConfig { fn default_term_config() -> TermConfig {
#[cfg(windows)] #[cfg(windows)]
{ {
@ -185,7 +200,52 @@ fn default_hotkey_config() -> HotkeysConfig {
} }
} }
// note: will panic if called before APP is set
pub fn register_hotkeys(hotkeys: &HotkeysConfig) -> tauri::Result<()> {
let app = crate::app::APP.get().unwrap();
let mut manager = app.global_shortcut_manager();
if let Err(_e) = manager.unregister_all() {
if !hotkeys.show_window.enabled && !hotkeys.launch_terminal.enabled {
// if both are disabled and we failed to unregister, then probably
// we also failed to register in the first place
return Ok(())
}
}
if hotkeys.show_window.enabled {
let handle = app.app_handle();
manager.register(
&hotkeys.show_window.keys,
move || {
handle.get_window("main")
.map(|w| w.show().error_popup("Failed to show"))
.ok_or(HandlerError::NoMainWindow)
.error_popup("No main window");
},
)?;
}
if hotkeys.launch_terminal.enabled {
// register() doesn't take an async fn, so we have to use spawn
manager.register(
&hotkeys.launch_terminal.keys,
|| {
rt::spawn(async {
crate::terminal::launch(false)
.await
.error_popup("Failed to launch");
});
}
)?;
}
Ok(())
}
fn default_listen_addr() -> Ipv4Addr { Ipv4Addr::LOCALHOST }
fn default_rehide_ms() -> u64 { 1000 } fn default_rehide_ms() -> u64 { 1000 }
// start minimized and on login only in production mode // start minimized and on login only in production mode
fn default_start_minimized() -> bool { !cfg!(debug_assertions) } fn default_start_minimized() -> bool { !cfg!(debug_assertions) }

View File

@ -162,10 +162,9 @@ impl BaseCredentials {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")] #[serde(rename_all = "PascalCase")]
pub struct SessionCredentials { pub struct SessionCredentials {
pub version: usize,
pub access_key_id: String, pub access_key_id: String,
pub secret_access_key: String, pub secret_access_key: String,
pub session_token: String, pub token: String,
#[serde(serialize_with = "serialize_expiration")] #[serde(serialize_with = "serialize_expiration")]
#[serde(deserialize_with = "deserialize_expiration")] #[serde(deserialize_with = "deserialize_expiration")]
pub expiration: DateTime, pub expiration: DateTime,
@ -199,7 +198,7 @@ impl SessionCredentials {
let secret_access_key = aws_session.secret_access_key() let secret_access_key = aws_session.secret_access_key()
.ok_or(GetSessionError::EmptyResponse)? .ok_or(GetSessionError::EmptyResponse)?
.to_string(); .to_string();
let session_token = aws_session.session_token() let token = aws_session.session_token()
.ok_or(GetSessionError::EmptyResponse)? .ok_or(GetSessionError::EmptyResponse)?
.to_string(); .to_string();
let expiration = aws_session.expiration() let expiration = aws_session.expiration()
@ -207,10 +206,9 @@ impl SessionCredentials {
.clone(); .clone();
let session_creds = SessionCredentials { let session_creds = SessionCredentials {
version: 1,
access_key_id, access_key_id,
secret_access_key, secret_access_key,
session_token, token,
expiration, expiration,
}; };
@ -232,14 +230,6 @@ impl SessionCredentials {
} }
} }
#[derive(Debug, Serialize, Deserialize)]
pub enum Credentials {
Base(BaseCredentials),
Session(SessionCredentials),
}
fn serialize_expiration<S>(exp: &DateTime, serializer: S) -> Result<S::Ok, S::Error> fn serialize_expiration<S>(exp: &DateTime, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer where S: Serializer
{ {

View File

@ -2,7 +2,6 @@ use std::error::Error;
use std::convert::AsRef; use std::convert::AsRef;
use std::ffi::OsString; use std::ffi::OsString;
use std::sync::mpsc; use std::sync::mpsc;
use std::string::FromUtf8Error;
use strum_macros::AsRefStr; use strum_macros::AsRefStr;
use thiserror::Error as ThisError; use thiserror::Error as ThisError;
@ -18,22 +17,15 @@ use tauri::api::dialog::{
MessageDialogBuilder, MessageDialogBuilder,
MessageDialogKind, MessageDialogKind,
}; };
use serde::{ use serde::{Serialize, Serializer, ser::SerializeMap};
Serialize,
Serializer,
ser::SerializeMap,
Deserialize,
};
pub trait ShowError { pub trait ErrorPopup {
fn error_popup(self, title: &str); fn error_popup(self, title: &str);
fn error_popup_nowait(self, title: &str); fn error_popup_nowait(self, title: &str);
fn error_print(self);
fn error_print_prefix(self, prefix: &str);
} }
impl<E: std::fmt::Display> ShowError for Result<(), E> { impl<E: std::fmt::Display> ErrorPopup for Result<(), E> {
fn error_popup(self, title: &str) { fn error_popup(self, title: &str) {
if let Err(e) = self { if let Err(e) = self {
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
@ -52,18 +44,6 @@ impl<E: std::fmt::Display> ShowError for Result<(), E> {
.show(|_| {}) .show(|_| {})
} }
} }
fn error_print(self) {
if let Err(e) = self {
eprintln!("{e}");
}
}
fn error_print_prefix(self, prefix: &str) {
if let Err(e) = self {
eprintln!("{prefix}: {e}");
}
}
} }
@ -157,14 +137,12 @@ pub enum SendResponseError {
pub enum HandlerError { pub enum HandlerError {
#[error("Error writing to stream: {0}")] #[error("Error writing to stream: {0}")]
StreamIOError(#[from] std::io::Error), StreamIOError(#[from] std::io::Error),
#[error("Received invalid UTF-8 in request")] // #[error("Received invalid UTF-8 in request")]
InvalidUtf8(#[from] FromUtf8Error), // InvalidUtf8,
#[error("HTTP request malformed")] #[error("HTTP request malformed")]
BadRequest(#[from] serde_json::Error), BadRequest(Vec<u8>),
#[error("HTTP request too large")] #[error("HTTP request too large")]
RequestTooLarge, RequestTooLarge,
#[error("Internal server error")]
Internal,
#[error("Error accessing credentials: {0}")] #[error("Error accessing credentials: {0}")]
NoCredentials(#[from] GetCredentialsError), NoCredentials(#[from] GetCredentialsError),
#[error("Error getting client details: {0}")] #[error("Error getting client details: {0}")]
@ -173,17 +151,6 @@ pub enum HandlerError {
Tauri(#[from] tauri::Error), Tauri(#[from] tauri::Error),
#[error("No main application window found")] #[error("No main application window found")]
NoMainWindow, NoMainWindow,
#[error("Request was denied")]
Denied,
}
#[derive(Debug, ThisError, AsRefStr)]
pub enum WindowError {
#[error("Failed to find main application window")]
NoMainWindow,
#[error(transparent)]
ManageFailure(#[from] tauri::Error),
} }
@ -240,50 +207,26 @@ pub enum CryptoError {
pub enum ClientInfoError { pub enum ClientInfoError {
#[error("Found PID for client socket, but no corresponding process")] #[error("Found PID for client socket, but no corresponding process")]
ProcessNotFound, ProcessNotFound,
#[error("Could not determine parent PID of connected client")] #[error("Couldn't get client socket details: {0}")]
ParentPidNotFound, NetstatError(#[from] netstat2::error::Error),
#[error("Found PID for parent process of client, but no corresponding process")]
ParentProcessNotFound,
#[cfg(windows)]
#[error("Could not determine PID of connected client")]
WindowsError(#[from] windows::core::Error),
#[error(transparent)]
Io(#[from] std::io::Error),
}
// Technically also an error, but formatted as a struct for easy deserialization
#[derive(Debug, Serialize, Deserialize)]
pub struct ServerError {
code: String,
msg: String,
}
impl std::fmt::Display for ServerError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
write!(f, "{} ({})", self.msg, self.code)?;
Ok(())
}
} }
// Errors encountered while requesting credentials via CLI (creddy show, creddy exec) // Errors encountered while requesting credentials via CLI (creddy show, creddy exec)
#[derive(Debug, ThisError, AsRefStr)] #[derive(Debug, ThisError, AsRefStr)]
pub enum RequestError { pub enum RequestError {
#[error("Error response from server: {0}")] #[error("Credentials request failed: HTTP {0}")]
Server(ServerError), Failed(String),
#[error("Unexpected response from server")] #[error("Credentials request was rejected")]
Unexpected(crate::server::Response), Rejected,
#[error("Couldn't interpret the server's response")]
MalformedHttpResponse,
#[error("The server did not respond with valid JSON")] #[error("The server did not respond with valid JSON")]
InvalidJson(#[from] serde_json::Error), InvalidJson,
#[error("Error reading/writing stream: {0}")] #[error("Error reading/writing stream: {0}")]
StreamIOError(#[from] std::io::Error), StreamIOError(#[from] std::io::Error),
} #[error("Error loading configuration data: {0}")]
Setup(#[from] SetupError),
impl From<ServerError> for RequestError {
fn from(s: ServerError) -> Self {
Self::Server(s)
}
} }
@ -348,7 +291,6 @@ impl Serialize for SerializeWrapper<&GetSessionTokenError> {
impl_serialize_basic!(SetupError); impl_serialize_basic!(SetupError);
impl_serialize_basic!(GetCredentialsError); impl_serialize_basic!(GetCredentialsError);
impl_serialize_basic!(ClientInfoError); impl_serialize_basic!(ClientInfoError);
impl_serialize_basic!(WindowError);
impl Serialize for HandlerError { impl Serialize for HandlerError {
@ -356,6 +298,13 @@ impl Serialize for HandlerError {
let mut map = serializer.serialize_map(None)?; let mut map = serializer.serialize_map(None)?;
map.serialize_entry("code", self.as_ref())?; map.serialize_entry("code", self.as_ref())?;
map.serialize_entry("msg", &format!("{self}"))?; map.serialize_entry("msg", &format!("{self}"))?;
match self {
HandlerError::NoCredentials(src) => map.serialize_entry("source", &src)?,
HandlerError::ClientInfo(src) => map.serialize_entry("source", &src)?,
_ => serialize_upstream_err(self, &mut map)?,
}
map.end() map.end()
} }
} }
@ -404,8 +353,6 @@ impl Serialize for UnlockError {
match self { match self {
UnlockError::GetSession(src) => map.serialize_entry("source", &src)?, UnlockError::GetSession(src) => map.serialize_entry("source", &src)?,
// The string representation of the AEAD error is not very helpful, so skip it
UnlockError::Crypto(_src) => map.serialize_entry("source", &None::<&str>)?,
_ => serialize_upstream_err(self, &mut map)?, _ => serialize_upstream_err(self, &mut map)?,
} }
map.end() map.end()

View File

@ -10,9 +10,9 @@ use crate::terminal;
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AwsRequestNotification { pub struct Request {
pub id: u64, pub id: u64,
pub client: Client, pub clients: Vec<Option<Client>>,
pub base: bool, pub base: bool,
} }

View File

@ -7,6 +7,5 @@ mod clientinfo;
mod ipc; mod ipc;
mod state; mod state;
mod server; mod server;
mod shortcuts;
mod terminal; mod terminal;
mod tray; mod tray;

View File

@ -6,7 +6,7 @@
use creddy::{ use creddy::{
app, app,
cli, cli,
errors::ShowError, errors::ErrorPopup,
}; };
@ -16,14 +16,12 @@ fn main() {
app::run().error_popup("Creddy failed to start"); app::run().error_popup("Creddy failed to start");
Ok(()) Ok(())
}, },
Some(("get", m)) => cli::get(m), Some(("show", m)) => cli::show(m),
Some(("exec", m)) => cli::exec(m), Some(("exec", m)) => cli::exec(m),
Some(("shortcut", m)) => cli::invoke_shortcut(m),
_ => unreachable!(), _ => unreachable!(),
}; };
if let Err(e) = res { if let Err(e) = res {
eprintln!("Error: {e}"); eprintln!("Error: {e}");
std::process::exit(1);
} }
} }

275
src-tauri/src/server.rs Normal file
View File

@ -0,0 +1,275 @@
use core::time::Duration;
use std::io;
use std::net::{
Ipv4Addr,
SocketAddr,
SocketAddrV4,
};
use tokio::net::{
TcpListener,
TcpStream,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::oneshot::{self, Sender, Receiver};
use tokio::time::sleep;
use tauri::{AppHandle, Manager};
use tauri::async_runtime as rt;
use tauri::async_runtime::JoinHandle;
use crate::{clientinfo, clientinfo::Client};
use crate::errors::*;
use crate::ipc::{Request, Approval};
use crate::state::AppState;
#[derive(Debug)]
pub struct RequestWaiter {
pub rehide_after: bool,
pub sender: Option<Sender<Approval>>,
}
impl RequestWaiter {
pub fn notify(&mut self, approval: Approval) -> Result<(), SendResponseError> {
let chan = self.sender
.take()
.ok_or(SendResponseError::Fulfilled)?;
chan.send(approval)
.map_err(|_| SendResponseError::Abandoned)
}
}
struct Handler {
request_id: u64,
stream: TcpStream,
rehide_after: bool,
receiver: Option<Receiver<Approval>>,
app: AppHandle,
}
impl Handler {
async fn new(stream: TcpStream, app: AppHandle) -> Result<Self, HandlerError> {
let state = app.state::<AppState>();
// determine whether we should re-hide the window after handling this request
let is_currently_visible = app.get_window("main")
.ok_or(HandlerError::NoMainWindow)?
.is_visible()?;
let rehide_after = state.current_rehide_status()
.await
.unwrap_or(!is_currently_visible);
let (chan_send, chan_recv) = oneshot::channel();
let waiter = RequestWaiter {rehide_after, sender: Some(chan_send)};
let request_id = state.register_request(waiter).await;
let handler = Handler {
request_id,
stream,
rehide_after,
receiver: Some(chan_recv),
app
};
Ok(handler)
}
async fn handle(mut self) {
if let Err(e) = self.try_handle().await {
eprintln!("{e}");
}
let state = self.app.state::<AppState>();
state.unregister_request(self.request_id).await;
}
async fn try_handle(&mut self) -> Result<(), HandlerError> {
let req_path = self.recv_request().await?;
let clients = self.get_clients().await?;
if self.includes_banned(&clients).await {
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
return Ok(())
}
let base = req_path == b"/creddy/base-credentials";
let req = Request {id: self.request_id, clients, base};
self.app.emit_all("credentials-request", &req)?;
self.show_window()?;
match self.wait_for_response().await? {
Approval::Approved => {
let state = self.app.state::<AppState>();
let creds = if base {
state.serialize_base_creds().await?
}
else {
state.serialize_session_creds().await?
};
self.send_body(creds.as_bytes()).await?;
},
Approval::Denied => {
let state = self.app.state::<AppState>();
for client in req.clients {
state.add_ban(client).await;
}
self.send_body(b"Denied!").await?;
self.stream.shutdown().await?;
}
}
// only hide the window if a) it was hidden to start with
// and b) there are no other pending requests
let state = self.app.state::<AppState>();
let delay = {
let config = state.config.read().await;
Duration::from_millis(config.rehide_ms)
};
sleep(delay).await;
if self.rehide_after && state.req_count().await == 1 {
self.app
.get_window("main")
.ok_or(HandlerError::NoMainWindow)?
.hide()?;
}
Ok(())
}
async fn recv_request(&mut self) -> Result<Vec<u8>, HandlerError> {
let mut buf = vec![0; 8192]; // it's what tokio's BufReader uses
let mut n = 0;
loop {
n += self.stream.read(&mut buf[n..]).await?;
if n >= 4 && &buf[(n - 4)..n] == b"\r\n\r\n" {break;}
if n == buf.len() {return Err(HandlerError::RequestTooLarge);}
}
let path = buf.split(|&c| &[c] == b" ")
.skip(1)
.next()
.ok_or(HandlerError::BadRequest(buf.clone()))?;
#[cfg(debug_assertions)] {
println!("Path: {}", std::str::from_utf8(&path).unwrap());
println!("{}", std::str::from_utf8(&buf).unwrap());
}
Ok(path.into())
}
async fn get_clients(&self) -> Result<Vec<Option<Client>>, HandlerError> {
let peer_addr = match self.stream.peer_addr()? {
SocketAddr::V4(addr) => addr,
_ => unreachable!(), // we only listen on IPv4
};
let clients = clientinfo::get_clients(peer_addr.port()).await?;
Ok(clients)
}
async fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
let state = self.app.state::<AppState>();
for client in clients {
if state.is_banned(client).await {
return true;
}
}
false
}
fn show_window(&self) -> Result<(), HandlerError> {
let window = self.app.get_window("main").ok_or(HandlerError::NoMainWindow)?;
if !window.is_visible()? {
window.unminimize()?;
window.show()?;
}
window.set_focus()?;
Ok(())
}
async fn wait_for_response(&mut self) -> Result<Approval, HandlerError> {
self.stream.write(b"HTTP/1.0 200 OK\r\n").await?;
self.stream.write(b"Content-Type: application/json\r\n").await?;
self.stream.write(b"X-Creddy-delaying-tactic: ").await?;
#[allow(unreachable_code)] // seems necessary for type inference
let stall = async {
let delay = std::time::Duration::from_secs(1);
loop {
tokio::time::sleep(delay).await;
self.stream.write(b"x").await?;
}
Ok(Approval::Denied)
};
// this is the only place we even read this field, so it's safe to unwrap
let receiver = self.receiver.take().unwrap();
tokio::select!{
r = receiver => Ok(r.unwrap()), // only panics if the sender is dropped without sending, which shouldn't be possible
e = stall => e,
}
}
async fn send_body(&mut self, body: &[u8]) -> Result<(), HandlerError> {
self.stream.write(b"\r\nContent-Length: ").await?;
self.stream.write(body.len().to_string().as_bytes()).await?;
self.stream.write(b"\r\n\r\n").await?;
self.stream.write(body).await?;
self.stream.shutdown().await?;
Ok(())
}
}
#[derive(Debug)]
pub struct Server {
addr: Ipv4Addr,
port: u16,
app_handle: AppHandle,
task: JoinHandle<()>,
}
impl Server {
pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> {
let task = Self::start_server(addr, port, app_handle.app_handle()).await?;
Ok(Server { addr, port, app_handle, task})
}
pub async fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
if addr == self.addr && port == self.port {
return Ok(())
}
let new_task = Self::start_server(addr, port, self.app_handle.app_handle()).await?;
self.task.abort();
self.addr = addr;
self.port = port;
self.task = new_task;
Ok(())
}
// construct the listener before spawning the task so that we can return early if it fails
async fn start_server(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<JoinHandle<()>> {
let sock_addr = SocketAddrV4::new(addr, port);
let listener = TcpListener::bind(&sock_addr).await?;
let task = rt::spawn(
Self::serve(listener, app_handle.app_handle())
);
Ok(task)
}
async fn serve(listener: TcpListener, app_handle: AppHandle) {
loop {
match listener.accept().await {
Ok((stream, _)) => {
match Handler::new(stream, app_handle.app_handle()).await {
Ok(handler) => { rt::spawn(handler.handle()); }
Err(e) => { eprintln!("Error handling request: {e}"); }
}
},
Err(e) => { eprintln!("Error accepting connection: {e}"); }
}
}
}
}

View File

@ -1,126 +0,0 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::oneshot;
use serde::{Serialize, Deserialize};
use tauri::{AppHandle, Manager};
use crate::errors::*;
use crate::clientinfo::{self, Client};
use crate::credentials::Credentials;
use crate::ipc::{Approval, AwsRequestNotification};
use crate::state::AppState;
use crate::shortcuts::{self, ShortcutAction};
#[cfg(windows)]
mod server_win;
#[cfg(windows)]
pub use server_win::Server;
#[cfg(windows)]
use server_win::Stream;
#[cfg(unix)]
mod server_unix;
#[cfg(unix)]
pub use server_unix::Server;
#[cfg(unix)]
use server_unix::Stream;
#[derive(Serialize, Deserialize)]
pub enum Request {
GetAwsCredentials{
base: bool,
},
InvokeShortcut(ShortcutAction),
}
#[derive(Debug, Serialize, Deserialize)]
pub enum Response {
Aws(Credentials),
Empty,
}
async fn handle(mut stream: Stream, app_handle: AppHandle, client_pid: u32) -> Result<(), HandlerError>
{
// read from stream until delimiter is reached
let mut buf: Vec<u8> = Vec::with_capacity(1024); // requests are small, 1KiB is more than enough
let mut n = 0;
loop {
n += stream.read_buf(&mut buf).await?;
if let Some(&b'\n') = buf.last() {
break;
}
else if n >= 1024 {
return Err(HandlerError::RequestTooLarge);
}
}
let client = clientinfo::get_process_parent_info(client_pid)?;
let req: Request = serde_json::from_slice(&buf)?;
let res = match req {
Request::GetAwsCredentials{ base } => get_aws_credentials(base, client, app_handle).await,
Request::InvokeShortcut(action) => invoke_shortcut(action).await,
};
let res = serde_json::to_vec(&res).unwrap();
stream.write_all(&res).await?;
Ok(())
}
async fn invoke_shortcut(action: ShortcutAction) -> Result<Response, HandlerError> {
shortcuts::exec_shortcut(action);
Ok(Response::Empty)
}
async fn get_aws_credentials(base: bool, client: Client, app_handle: AppHandle) -> Result<Response, HandlerError> {
let state = app_handle.state::<AppState>();
let rehide_ms = {
let config = state.config.read().await;
config.rehide_ms
};
let lease = state.acquire_visibility_lease(rehide_ms).await
.map_err(|_e| HandlerError::NoMainWindow)?; // automate this conversion eventually?
let (chan_send, chan_recv) = oneshot::channel();
let request_id = state.register_request(chan_send).await;
// if an error occurs in any of the following, we want to abort the operation
// but ? returns immediately, and we want to unregister the request before returning
// so we bundle it all up in an async block and return a Result so we can handle errors
let proceed = async {
let notification = AwsRequestNotification {id: request_id, client, base};
app_handle.emit_all("credentials-request", &notification)?;
match chan_recv.await {
Ok(Approval::Approved) => {
if base {
let creds = state.base_creds_cloned().await?;
Ok(Response::Aws(Credentials::Base(creds)))
}
else {
let creds = state.session_creds_cloned().await?;
Ok(Response::Aws(Credentials::Session(creds)))
}
},
Ok(Approval::Denied) => Err(HandlerError::Denied),
Err(_e) => Err(HandlerError::Internal),
}
};
let result = match proceed.await {
Ok(r) => Ok(r),
Err(e) => {
state.unregister_request(request_id).await;
Err(e)
}
};
lease.release();
result
}

View File

@ -1,59 +0,0 @@
use std::io::ErrorKind;
use tokio::net::{UnixListener, UnixStream};
use tauri::{
AppHandle,
Manager,
async_runtime as rt,
};
use crate::errors::*;
pub type Stream = UnixStream;
pub struct Server {
listener: UnixListener,
app_handle: AppHandle,
}
impl Server {
pub fn start(app_handle: AppHandle) -> std::io::Result<()> {
match std::fs::remove_file("/tmp/creddy.sock") {
Ok(_) => (),
Err(e) if e.kind() == ErrorKind::NotFound => (),
Err(e) => return Err(e),
}
let listener = UnixListener::bind("/tmp/creddy.sock")?;
let srv = Server { listener, app_handle };
rt::spawn(srv.serve());
Ok(())
}
async fn serve(self) {
loop {
self.try_serve()
.await
.error_print_prefix("Error accepting request: ");
}
}
async fn try_serve(&self) -> Result<(), HandlerError> {
let (stream, _addr) = self.listener.accept().await?;
let new_handle = self.app_handle.app_handle();
let client_pid = get_client_pid(&stream)?;
rt::spawn(async move {
super::handle(stream, new_handle, client_pid)
.await
.error_print_prefix("Error responding to request: ");
});
Ok(())
}
}
fn get_client_pid(stream: &UnixStream) -> std::io::Result<u32> {
let cred = stream.peer_cred()?;
Ok(cred.pid().unwrap() as u32)
}

View File

@ -1,75 +0,0 @@
use tokio::{
net::windows::named_pipe::{
NamedPipeServer,
ServerOptions,
},
sync::oneshot,
};
use windows::Win32:: {
Foundation::HANDLE,
System::Pipes::GetNamedPipeClientProcessId,
};
use std::os::windows::io::AsRawHandle;
use tauri::async_runtime as rt;
use crate::errors::*;
// used by parent module
pub type Stream = NamedPipeServer;
pub struct Server {
listener: NamedPipeServer,
app_handle: AppHandle,
}
impl Server {
pub fn start(app_handle: AppHandle) -> std::io::Result<()> {
let listener = ServerOptions::new()
.first_pipe_instance(true)
.create(r"\\.\pipe\creddy-requests")?;
let srv = Server {listener, app_handle};
rt::spawn(srv.serve());
Ok(())
}
async fn serve(mut self) {
loop {
if let Err(e) = self.try_serve().await {
eprintln!("Error accepting connection: {e}");
}
}
}
async fn try_serve(&mut self) -> Result<(), HandlerError> {
// connect() just waits for a client to connect, it doesn't return anything
self.listener.connect().await?;
// create a new pipe instance to listen for the next client, and swap it in
let new_listener = ServerOptions::new().create(r"\\.\pipe\creddy-requests")?;
let mut stream = std::mem::replace(&mut self.listener, new_listener);
let new_handle = self.app_handle.app_handle();
let client_pid = get_client_pid(&stream)?;
rt::spawn(async move {
super::handle(stream, app_handle)
.await
.error_print_prefix("Error responding to request: ");
});
Ok(())
}
}
fn get_client_pid(pipe: &NamedPipeServer) -> Result<u32, ClientInfoError> {
let raw_handle = pipe.as_raw_handle();
let mut pid = 0u32;
let handle = HANDLE(raw_handle as _);
unsafe { GetNamedPipeClientProcessId(handle, &mut pid as *mut u32)? };
pid
}

View File

@ -1,60 +0,0 @@
use serde::{Serialize, Deserialize};
use tauri::{
GlobalShortcutManager,
Manager,
async_runtime as rt,
};
use crate::app::APP;
use crate::config::HotkeysConfig;
use crate::errors::*;
use crate::terminal;
#[derive(Debug, Serialize, Deserialize)]
pub enum ShortcutAction {
ShowWindow,
LaunchTerminal,
}
pub fn exec_shortcut(action: ShortcutAction) {
match action {
ShortcutAction::ShowWindow => {
let app = APP.get().unwrap();
app.get_window("main")
.ok_or("Couldn't find application main window")
.map(|w| w.show().error_popup("Failed to show window"))
.error_popup("Failed to show window");
},
ShortcutAction::LaunchTerminal => {
rt::spawn(async {
terminal::launch(false).await.error_popup("Failed to launch terminal");
});
},
}
}
pub fn register_hotkeys(hotkeys: &HotkeysConfig) -> tauri::Result<()> {
let app = APP.get().unwrap();
let mut manager = app.global_shortcut_manager();
manager.unregister_all()?;
if hotkeys.show_window.enabled {
manager.register(
&hotkeys.show_window.keys,
|| exec_shortcut(ShortcutAction::ShowWindow)
)?;
}
if hotkeys.launch_terminal.enabled {
manager.register(
&hotkeys.launch_terminal.keys,
|| exec_shortcut(ShortcutAction::LaunchTerminal)
)?;
}
Ok(())
}

View File

@ -1,16 +1,15 @@
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::time::Duration; use std::time::Duration;
use tokio::{ use tokio::{
sync::RwLock, sync::RwLock,
sync::oneshot::{self, Sender}, time::sleep,
}; };
use sqlx::SqlitePool; use sqlx::SqlitePool;
use tauri::{ use tauri::async_runtime as runtime;
Manager, use tauri::Manager;
async_runtime as rt,
};
use crate::app::APP;
use crate::credentials::{ use crate::credentials::{
Session, Session,
BaseCredentials, BaseCredentials,
@ -18,74 +17,9 @@ use crate::credentials::{
}; };
use crate::{config, config::AppConfig}; use crate::{config, config::AppConfig};
use crate::ipc::{self, Approval}; use crate::ipc::{self, Approval};
use crate::clientinfo::Client;
use crate::errors::*; use crate::errors::*;
use crate::shortcuts; use crate::server::{Server, RequestWaiter};
#[derive(Debug)]
struct Visibility {
leases: usize,
original: Option<bool>,
}
impl Visibility {
fn new() -> Self {
Visibility { leases: 0, original: None }
}
fn acquire(&mut self, delay_ms: u64) -> Result<VisibilityLease, WindowError> {
let app = crate::app::APP.get().unwrap();
let window = app.get_window("main")
.ok_or(WindowError::NoMainWindow)?;
self.leases += 1;
if self.original.is_none() {
let is_visible = window.is_visible()?;
self.original = Some(is_visible);
if !is_visible {
window.show()?;
}
}
window.set_focus()?;
let (tx, rx) = oneshot::channel();
let lease = VisibilityLease { notify: tx };
let delay = Duration::from_millis(delay_ms);
let handle = app.app_handle();
rt::spawn(async move {
// We don't care if it's an error; lease being dropped should be handled identically
let _ = rx.await;
tokio::time::sleep(delay).await;
// we can't use `self` here because we would have to move it into the async block
let state = handle.state::<AppState>();
let mut visibility = state.visibility.write().await;
visibility.leases -= 1;
if visibility.leases == 0 {
if let Some(false) = visibility.original {
window.hide().error_print();
}
visibility.original = None;
}
});
Ok(lease)
}
}
pub struct VisibilityLease {
notify: Sender<()>,
}
impl VisibilityLease {
pub fn release(self) {
rt::spawn(async move {
if let Err(_) = self.notify.send(()) {
eprintln!("Error releasing visibility lease")
}
});
}
}
#[derive(Debug)] #[derive(Debug)]
@ -93,18 +27,20 @@ pub struct AppState {
pub config: RwLock<AppConfig>, pub config: RwLock<AppConfig>,
pub session: RwLock<Session>, pub session: RwLock<Session>,
pub request_count: RwLock<u64>, pub request_count: RwLock<u64>,
pub waiting_requests: RwLock<HashMap<u64, Sender<Approval>>>, pub waiting_requests: RwLock<HashMap<u64, RequestWaiter>>,
pub pending_terminal_request: RwLock<bool>, pub pending_terminal_request: RwLock<bool>,
pub bans: RwLock<std::collections::HashSet<Option<Client>>>,
// setup_errors is never modified and so doesn't need to be wrapped in RwLock // setup_errors is never modified and so doesn't need to be wrapped in RwLock
pub setup_errors: Vec<String>, pub setup_errors: Vec<String>,
server: RwLock<Server>,
pool: sqlx::SqlitePool, pool: sqlx::SqlitePool,
visibility: RwLock<Visibility>,
} }
impl AppState { impl AppState {
pub fn new( pub fn new(
config: AppConfig, config: AppConfig,
session: Session, session: Session,
server: Server,
pool: SqlitePool, pool: SqlitePool,
setup_errors: Vec<String>, setup_errors: Vec<String>,
) -> AppState { ) -> AppState {
@ -114,9 +50,10 @@ impl AppState {
request_count: RwLock::new(0), request_count: RwLock::new(0),
waiting_requests: RwLock::new(HashMap::new()), waiting_requests: RwLock::new(HashMap::new()),
pending_terminal_request: RwLock::new(false), pending_terminal_request: RwLock::new(false),
bans: RwLock::new(HashSet::new()),
setup_errors, setup_errors,
server: RwLock::new(server),
pool, pool,
visibility: RwLock::new(Visibility::new()),
} }
} }
@ -136,12 +73,18 @@ impl AppState {
if new_config.start_on_login != live_config.start_on_login { if new_config.start_on_login != live_config.start_on_login {
config::set_auto_launch(new_config.start_on_login)?; config::set_auto_launch(new_config.start_on_login)?;
} }
// rebind socket if necessary
if new_config.listen_addr != live_config.listen_addr
|| new_config.listen_port != live_config.listen_port
{
let mut sv = self.server.write().await;
sv.rebind(new_config.listen_addr, new_config.listen_port).await?;
}
// re-register hotkeys if necessary // re-register hotkeys if necessary
if new_config.hotkeys.show_window != live_config.hotkeys.show_window if new_config.hotkeys.show_window != live_config.hotkeys.show_window
|| new_config.hotkeys.launch_terminal != live_config.hotkeys.launch_terminal || new_config.hotkeys.launch_terminal != live_config.hotkeys.launch_terminal
{ {
shortcuts::register_hotkeys(&new_config.hotkeys)?; config::register_hotkeys(&new_config.hotkeys)?;
} }
new_config.save(&self.pool).await?; new_config.save(&self.pool).await?;
@ -149,7 +92,7 @@ impl AppState {
Ok(()) Ok(())
} }
pub async fn register_request(&self, sender: Sender<Approval>) -> u64 { pub async fn register_request(&self, waiter: RequestWaiter) -> u64 {
let count = { let count = {
let mut c = self.request_count.write().await; let mut c = self.request_count.write().await;
*c += 1; *c += 1;
@ -157,7 +100,7 @@ impl AppState {
}; };
let mut waiting_requests = self.waiting_requests.write().await; let mut waiting_requests = self.waiting_requests.write().await;
waiting_requests.insert(*count, sender); // `count` is the request id waiting_requests.insert(*count, waiter); // `count` is the request id
*count *count
} }
@ -166,9 +109,16 @@ impl AppState {
waiting_requests.remove(&id); waiting_requests.remove(&id);
} }
pub async fn acquire_visibility_lease(&self, delay: u64) -> Result<VisibilityLease, WindowError> { pub async fn req_count(&self) -> usize {
let mut visibility = self.visibility.write().await; let waiting_requests = self.waiting_requests.read().await;
visibility.acquire(delay) waiting_requests.len()
}
pub async fn current_rehide_status(&self) -> Option<bool> {
// since all requests that are pending at a given time should have the same
// value for rehide_after, it doesn't matter which one we use
let waiting_requests = self.waiting_requests.read().await;
waiting_requests.iter().next().map(|(_id, w)| w.rehide_after)
} }
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
@ -179,10 +129,26 @@ impl AppState {
let mut waiting_requests = self.waiting_requests.write().await; let mut waiting_requests = self.waiting_requests.write().await;
waiting_requests waiting_requests
.remove(&response.id) .get_mut(&response.id)
.ok_or(SendResponseError::NotFound)? .ok_or(SendResponseError::NotFound)?
.send(response.approval) .notify(response.approval)
.map_err(|_| SendResponseError::Abandoned) }
pub async fn add_ban(&self, client: Option<Client>) {
let mut bans = self.bans.write().await;
bans.insert(client.clone());
runtime::spawn(async move {
sleep(Duration::from_secs(5)).await;
let app = APP.get().unwrap();
let state = app.state::<AppState>();
let mut bans = state.bans.write().await;
bans.remove(&client);
});
}
pub async fn is_banned(&self, client: &Option<Client>) -> bool {
self.bans.read().await.contains(&client)
} }
pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> { pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> {
@ -202,16 +168,16 @@ impl AppState {
matches!(*session, Session::Unlocked{..}) matches!(*session, Session::Unlocked{..})
} }
pub async fn base_creds_cloned(&self) -> Result<BaseCredentials, GetCredentialsError> { pub async fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
let app_session = self.session.read().await; let app_session = self.session.read().await;
let (base, _session) = app_session.try_get()?; let (base, _session) = app_session.try_get()?;
Ok(base.clone()) Ok(serde_json::to_string(base).unwrap())
} }
pub async fn session_creds_cloned(&self) -> Result<SessionCredentials, GetCredentialsError> { pub async fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
let app_session = self.session.read().await; let app_session = self.session.read().await;
let (_bsae, session) = app_session.try_get()?; let (_bsae, session) = app_session.try_get()?;
Ok(session.clone()) Ok(serde_json::to_string(session).unwrap())
} }
async fn new_session(&self, base: BaseCredentials) -> Result<(), GetSessionError> { async fn new_session(&self, base: BaseCredentials) -> Result<(), GetSessionError> {

View File

@ -26,8 +26,13 @@ pub async fn launch(use_base: bool) -> Result<(), LaunchTerminalError> {
// if session is unlocked or empty, wait for credentials from frontend // if session is unlocked or empty, wait for credentials from frontend
if !state.is_unlocked().await { if !state.is_unlocked().await {
app.emit_all("launch-terminal-request", ())?; app.emit_all("launch-terminal-request", ())?;
let lease = state.acquire_visibility_lease(0).await let window = app.get_window("main")
.map_err(|_e| LaunchTerminalError::NoMainWindow)?; // automate conversion eventually? .ok_or(LaunchTerminalError::NoMainWindow)?;
if !window.is_visible()? {
window.unminimize()?;
window.show()?;
}
window.set_focus()?;
let (tx, rx) = tokio::sync::oneshot::channel(); let (tx, rx) = tokio::sync::oneshot::channel();
app.once_global("credentials-event", move |e| { app.once_global("credentials-event", move |e| {
@ -42,7 +47,6 @@ pub async fn launch(use_base: bool) -> Result<(), LaunchTerminalError> {
state.unregister_terminal_request().await; state.unregister_terminal_request().await;
return Ok(()); // request was canceled by user return Ok(()); // request was canceled by user
} }
lease.release();
} }
// more lock-management // more lock-management
@ -59,7 +63,7 @@ pub async fn launch(use_base: bool) -> Result<(), LaunchTerminalError> {
else { else {
cmd.env("AWS_ACCESS_KEY_ID", &session_creds.access_key_id); cmd.env("AWS_ACCESS_KEY_ID", &session_creds.access_key_id);
cmd.env("AWS_SECRET_ACCESS_KEY", &session_creds.secret_access_key); cmd.env("AWS_SECRET_ACCESS_KEY", &session_creds.secret_access_key);
cmd.env("AWS_SESSION_TOKEN", &session_creds.session_token); cmd.env("AWS_SESSION_TOKEN", &session_creds.token);
} }
} }

View File

@ -10,21 +10,15 @@
export let min = null; export let min = null;
export let max = null; export let max = null;
export let decimal = false; export let decimal = false;
export let debounceInterval = 0;
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
$: localValue = value.toString(); $: localValue = value.toString();
let lastInputTime = null; let lastInputTime = null;
function debounce(event) { function debounce(event) {
lastInputTime = Date.now();
localValue = localValue.replace(/[^-0-9.]/g, ''); localValue = localValue.replace(/[^-0-9.]/g, '');
if (debounceInterval === 0) {
updateValue(localValue);
return;
}
lastInputTime = Date.now();
const eventTime = lastInputTime; const eventTime = lastInputTime;
const pendingValue = localValue; const pendingValue = localValue;
window.setTimeout( window.setTimeout(
@ -34,7 +28,7 @@
updateValue(pendingValue); updateValue(pendingValue);
} }
}, },
debounceInterval, 500
) )
} }

View File

@ -47,13 +47,16 @@
} }
// Extract executable name from full path // Extract executable name from full path
const client = $appState.currentRequest.client; let appName = null;
const m = client.exe?.match(/\/([^/]+?$)|\\([^\\]+?$)/); if ($appState.currentRequest.clients.length === 1) {
const appName = m[1] || m[2]; let path = $appState.currentRequest.clients[0].exe;
let m = path.match(/\/([^/]+?$)|\\([^\\]+?$)/);
appName = m[1] || m[2];
}
// Executable paths can be long, so ensure they only break on \ or / // Executable paths can be long, so ensure they only break on \ or /
function breakPath(path) { function breakPath(client) {
return path.replace(/(\\|\/)/g, '$1<wbr>'); return client.exe.replace(/(\\|\/)/g, '$1<wbr>');
} }
// if the request has already been approved/denied, send response immediately // if the request has already been approved/denied, send response immediately
@ -94,10 +97,12 @@
<h2 class="text-xl font-bold">{appName ? `"${appName}"` : 'An appplication'} would like to access your AWS credentials.</h2> <h2 class="text-xl font-bold">{appName ? `"${appName}"` : 'An appplication'} would like to access your AWS credentials.</h2>
<div class="grid grid-cols-[auto_1fr] gap-x-3"> <div class="grid grid-cols-[auto_1fr] gap-x-3">
<div class="text-right">Path:</div> {#each $appState.currentRequest.clients as client}
<code class="">{@html client.exe ? breakPath(client.exe) : 'Unknown'}</code> <div class="text-right">Path:</div>
<div class="text-right">PID:</div> <code class="">{@html client ? breakPath(client) : 'Unknown'}</code>
<code>{client.pid}</code> <div class="text-right">PID:</div>
<code>{client ? client.pid : 'Unknown'}</code>
{/each}
</div> </div>
</div> </div>

View File

@ -1,7 +1,6 @@
<script> <script>
import { onMount } from 'svelte'; import { onMount } from 'svelte';
import { invoke } from '@tauri-apps/api/tauri'; import { invoke } from '@tauri-apps/api/tauri';
import { emit } from '@tauri-apps/api/event';
import { getRootCause } from '../lib/errors.js'; import { getRootCause } from '../lib/errors.js';
import { appState } from '../lib/state.js'; import { appState } from '../lib/state.js';

View File

@ -14,18 +14,15 @@
import { backInOut } from 'svelte/easing'; import { backInOut } from 'svelte/easing';
// make an independent copy so it can differ from the main config object
let config = JSON.parse(JSON.stringify($appState.config));
$: configModified = JSON.stringify(config) !== JSON.stringify($appState.config);
let error = null; let error = null;
async function save() { async function save() {
console.log('updating config');
try { try {
await invoke('save_config', {config}); await invoke('save_config', {config: $appState.config});
$appState.config = await invoke('get_config');
} }
catch (e) { catch (e) {
error = e; error = e;
$appState.config = await invoke('get_config');
} }
} }
@ -38,60 +35,74 @@
<h1 slot="title" class="text-2xl font-bold">Settings</h1> <h1 slot="title" class="text-2xl font-bold">Settings</h1>
</Nav> </Nav>
<div class="max-w-lg mx-auto mt-1.5 mb-24 p-4 space-y-16"> {#await invoke('get_config') then config}
<SettingsGroup name="General"> <div class="max-w-lg mx-auto mt-1.5 p-4 space-y-16">
<ToggleSetting title="Start on login" bind:value={config.start_on_login}> <SettingsGroup name="General">
<svelte:fragment slot="description"> <ToggleSetting title="Start on login" bind:value={$appState.config.start_on_login} on:update={save}>
Start Creddy when you log in to your computer. <svelte:fragment slot="description">
</svelte:fragment> Start Creddy when you log in to your computer.
</ToggleSetting> </svelte:fragment>
</ToggleSetting>
<ToggleSetting title="Start minimized" bind:value={config.start_minimized}> <ToggleSetting title="Start minimized" bind:value={$appState.config.start_minimized} on:update={save}>
<svelte:fragment slot="description"> <svelte:fragment slot="description">
Minimize to the system tray at startup. Minimize to the system tray at startup.
</svelte:fragment> </svelte:fragment>
</ToggleSetting> </ToggleSetting>
<NumericSetting title="Re-hide delay" bind:value={config.rehide_ms} min={0} unit="Milliseconds"> <NumericSetting title="Re-hide delay" bind:value={$appState.config.rehide_ms} min={0} unit="Milliseconds" on:update={save}>
<svelte:fragment slot="description"> <svelte:fragment slot="description">
How long to wait after a request is approved/denied before minimizing How long to wait after a request is approved/denied before minimizing
the window to tray. Only applicable if the window was minimized the window to tray. Only applicable if the window was minimized
to tray before the request was received. to tray before the request was received.
</svelte:fragment> </svelte:fragment>
</NumericSetting> </NumericSetting>
<Setting title="Update credentials"> <NumericSetting
<Link slot="input" target="EnterCredentials"> title="Listen port"
<button class="btn btn-sm btn-primary">Update</button> bind:value={$appState.config.listen_port}
</Link> min={osType === 'Windows_NT' ? 1 : 0}
<svelte:fragment slot="description"> on:update={save}
Update or re-enter your encrypted credentials. >
</svelte:fragment> <svelte:fragment slot="description">
</Setting> Listen for credentials requests on this port.
(Should be used with <code>$AWS_CONTAINER_CREDENTIALS_FULL_URI</code>)
</svelte:fragment>
</NumericSetting>
<FileSetting <Setting title="Update credentials">
title="Terminal emulator" <Link slot="input" target="EnterCredentials">
bind:value={config.terminal.exec} <button class="btn btn-sm btn-primary">Update</button>
</Link>
<svelte:fragment slot="description">
Update or re-enter your encrypted credentials.
</svelte:fragment>
</Setting>
> <FileSetting
<svelte:fragment slot="description"> title="Terminal emulator"
Choose your preferred terminal emulator (e.g. <code>gnome-terminal</code> or <code>wt.exe</code>.) May be an absolute path or an executable discoverable on <code>$PATH</code>. bind:value={$appState.config.terminal.exec}
</svelte:fragment> on:update={save}
</FileSetting> >
</SettingsGroup> <svelte:fragment slot="description">
Choose your preferred terminal emulator (e.g. <code>gnome-terminal</code> or <code>wt.exe</code>.) May be an absolute path or an executable discoverable on <code>$PATH</code>.
</svelte:fragment>
</FileSetting>
</SettingsGroup>
<SettingsGroup name="Hotkeys"> <SettingsGroup name="Hotkeys">
<div class="space-y-4"> <div class="space-y-4">
<p>Click on a keybinding to modify it. Use the checkbox to enable or disable a keybinding entirely.</p> <p>Click on a keybinding to modify it. Use the checkbox to enable or disable a keybinding entirely.</p>
<div class="grid grid-cols-[auto_1fr_auto] gap-y-3 items-center"> <div class="grid grid-cols-[auto_1fr_auto] gap-y-3 items-center">
<Keybind description="Show Creddy" bind:value={config.hotkeys.show_window} /> <Keybind description="Show Creddy" value={$appState.config.hotkeys.show_window} on:update={save} />
<Keybind description="Launch terminal" bind:value={config.hotkeys.launch_terminal} /> <Keybind description="Launch terminal" value={$appState.config.hotkeys.launch_terminal} on:update={save} />
</div>
</div> </div>
</div> </SettingsGroup>
</SettingsGroup>
</div> </div>
{/await}
{#if error} {#if error}
<div transition:fly={{y: 100, easing: backInOut, duration: 400}} class="toast"> <div transition:fly={{y: 100, easing: backInOut, duration: 400}} class="toast">
@ -105,15 +116,4 @@
</div> </div>
</div> </div>
</div> </div>
{:else if configModified}
<div transition:fly={{y: 100, easing: backInOut, duration: 400}} class="toast">
<div class="alert shadow-lg no-animation">
<span>You have unsaved changes.</span>
<div>
<!-- <button class="btn btn-sm btn-ghost">Cancel</button> -->
<buton class="btn btn-sm btn-primary" on:click={save}>Save</buton>
</div>
</div>
</div>
{/if} {/if}