diff --git a/.gitignore b/.gitignore index d34ab6f..e1b2934 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ src-tauri/target/ # just in case credentials* +!credentials.rs diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 1868737..7f28c17 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -67,6 +67,7 @@ dependencies = [ "aws-sdk-sts", "aws-smithy-types", "aws-types", + "clap", "dirs 5.0.1", "netstat2", "once_cell", @@ -227,6 +228,17 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1181e1e0d1fce796a03db1ae795d67167da795f9cf4a39c37589e85ef57f26d3" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "auto-launch" version = "0.4.0" @@ -714,6 +726,45 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "clap" +version = "3.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123" +dependencies = [ + "atty", + "bitflags", + "clap_derive", + "clap_lex", + "indexmap", + "once_cell", + "strsim", + "termcolor", + "textwrap", +] + +[[package]] +name = "clap_derive" +version = "3.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae6371b8bdc8b7d3959e9cf7b22d4435ef3e79e138688421ec654acf8c81b008" +dependencies = [ + "heck 0.4.1", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "clap_lex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5" +dependencies = [ + "os_str_bytes", +] + [[package]] name = "cocoa" version = "0.24.1" @@ -1735,6 +1786,15 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.2.6" @@ -2511,6 +2571,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "os_str_bytes" +version = "6.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ceedf44fb00f2d1984b0bc98102627ce622e083e49a5bacdb3e514fa4238e267" + [[package]] name = "outref" version = "0.1.0" @@ -4020,6 +4086,21 @@ dependencies = [ "utf-8", ] +[[package]] +name = "termcolor" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "textwrap" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" + [[package]] name = "thin-slice" version = "0.1.1" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 0193894..9369712 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -34,6 +34,7 @@ strum = "0.24" strum_macros = "0.24" auto-launch = "0.4.0" dirs = "5.0" +clap = { version = "3.2.23", features = ["derive"] } [features] # by default Tauri runs in production mode diff --git a/src-tauri/src/app.rs b/src-tauri/src/app.rs new file mode 100644 index 0000000..417c338 --- /dev/null +++ b/src-tauri/src/app.rs @@ -0,0 +1,92 @@ +#![cfg_attr( + all(not(debug_assertions), target_os = "windows"), + windows_subsystem = "windows" +)] + +use std::error::Error; + +use once_cell::sync::OnceCell; +use sqlx::{ + SqlitePool, + sqlite::SqlitePoolOptions, + sqlite::SqliteConnectOptions, +}; +use tauri::{ + App, + AppHandle, + Manager, + async_runtime as rt, +}; + +use crate::{ + config::{self, AppConfig}, + credentials::Session, + ipc, + server::Server, + errors::*, + state::AppState, + tray, +}; + + +pub static APP: OnceCell = OnceCell::new(); + + +async fn setup(app: &mut App) -> Result<(), Box> { + APP.set(app.handle()).unwrap(); + + let conn_opts = SqliteConnectOptions::new() + .filename(config::get_or_create_db_path()?) + .create_if_missing(true); + let pool_opts = SqlitePoolOptions::new(); + let pool: SqlitePool = pool_opts.connect_with(conn_opts).await?; + sqlx::migrate!().run(&pool).await?; + + let conf = AppConfig::load(&pool).await?; + let session = Session::load(&pool).await?; + let srv = Server::new(conf.listen_addr, conf.listen_port, app.handle()).await?; + + config::set_auto_launch(conf.start_on_login)?; + if !conf.start_minimized { + app.get_window("main") + .ok_or(HandlerError::NoMainWindow)? + .show()?; + } + + let state = AppState::new(conf, session, srv, pool); + app.manage(state); + Ok(()) +} + + +pub fn run() -> tauri::Result<()> { + tauri::Builder::default() + .plugin(tauri_plugin_single_instance::init(|app, _argv, _cwd| { + app.get_window("main") + .map(|w| w.show().error_popup("Failed to show main window")); + })) + .system_tray(tray::create()) + .on_system_tray_event(tray::handle_event) + .invoke_handler(tauri::generate_handler![ + ipc::unlock, + ipc::respond, + ipc::get_session_status, + ipc::save_credentials, + ipc::get_config, + ipc::save_config, + ]) + .setup(|app| rt::block_on(setup(app))) + .build(tauri::generate_context!())? + .run(|app, run_event| match run_event { + tauri::RunEvent::WindowEvent { label, event, .. } => match event { + tauri::WindowEvent::CloseRequested { api, .. } => { + let _ = app.get_window(&label).map(|w| w.hide()); + api.prevent_close(); + } + _ => () + } + _ => () + }); + + Ok(()) +} \ No newline at end of file diff --git a/src-tauri/src/cli.rs b/src-tauri/src/cli.rs new file mode 100644 index 0000000..80b94a1 --- /dev/null +++ b/src-tauri/src/cli.rs @@ -0,0 +1,135 @@ +use std::io::{Read, Write}; +use std::process::Command as ChildCommand; +#[cfg(unix)] +use std::os::unix::process::CommandExt; + +use clap::{ + Command, + Arg, + ArgMatches, + ArgAction + }; +use std::net::TcpStream; + + +use crate::credentials::{BaseCredentials, SessionCredentials}; +use crate::errors::*; + + +pub fn parser() -> Command<'static> { + Command::new("creddy") + .about("A friendly AWS credentials manager") + .subcommand( + Command::new("run") + .about("Launch Creddy") + ) + .subcommand( + Command::new("show") + .about("Fetch and display AWS credentials") + .arg( + Arg::new("base") + .short('b') + .long("base") + .action(ArgAction::SetTrue) + .help("Use base credentials instead of session credentials") + ) + ) + .subcommand( + Command::new("exec") + .about("Inject AWS credentials into the environment of another command") + .trailing_var_arg(true) + .arg( + Arg::new("base") + .short('b') + .long("base") + .action(ArgAction::SetTrue) + .help("Use base credentials instead of session credentials") + ) + .arg( + Arg::new("command") + .multiple_values(true) + ) + ) +} + + +pub fn show(args: &ArgMatches) -> Result<(), CliError> { + let base = args.get_one("base").unwrap_or(&false); + let creds = get_credentials(*base)?; + println!("{creds}"); + Ok(()) +} + + +pub fn exec(args: &ArgMatches) -> Result<(), CliError> { + let base = *args.get_one("base").unwrap_or(&false); + let mut cmd_line = args.get_many("command") + .ok_or(ExecError::NoCommand)?; + + let cmd_name: &String = cmd_line.next().unwrap(); // Clap guarantees that there will be at least one + let mut cmd = ChildCommand::new(cmd_name); + cmd.args(cmd_line); + + if base { + let creds: BaseCredentials = serde_json::from_str(&get_credentials(base)?) + .map_err(|_| RequestError::InvalidJson)?; + cmd.env("AWS_ACCESS_KEY_ID", creds.access_key_id); + cmd.env("AWS_SECRET_ACCESS_KEY", creds.secret_access_key); + } + else { + let creds: SessionCredentials = serde_json::from_str(&get_credentials(base)?) + .map_err(|_| RequestError::InvalidJson)?; + cmd.env("AWS_ACCESS_KEY_ID", creds.access_key_id); + cmd.env("AWS_SECRET_ACCESS_KEY", creds.secret_access_key); + cmd.env("AWS_SESSION_TOKEN", creds.token); + } + + #[cfg(unix)] + cmd.exec().map_err(|e| ExecError::ExecutionFailed(e))?; + + #[cfg(windows)] + { + let mut child = cmd.spawn() + .map_err(|e| ExecError::ExecutionFailed(e))?; + let status = child.wait() + .map_err(|e| ExecError::ExecutionFailed(e))?; + std::process::exit(status.code().unwrap_or(1)); + }; +} + + +fn get_credentials(base: bool) -> Result { + let path = if base {"/creddy/base-credentials"} else {"/"}; + + let mut stream = TcpStream::connect("127.0.0.1:12345")?; + let req = format!("GET {path} HTTP/1.0\r\n\r\n"); + stream.write_all(req.as_bytes())?; + + // some day we'll have a proper HTTP parser + let mut buf = vec![0; 8192]; + stream.read_to_end(&mut buf)?; + + let status = buf.split(|&c| &[c] == b" ") + .skip(1) + .next() + .ok_or(RequestError::MalformedHttpResponse)?; + + if status != b"200" { + let s = String::from_utf8_lossy(status).to_string(); + return Err(RequestError::Failed(s)); + } + + let break_idx = buf.windows(4) + .position(|w| w == b"\r\n\r\n") + .ok_or(RequestError::MalformedHttpResponse)?; + let body = &buf[(break_idx + 4)..]; + + let creds_str = std::str::from_utf8(body) + .map_err(|_| RequestError::MalformedHttpResponse)? + .to_string(); + + if creds_str == "Denied!" { + return Err(RequestError::Rejected); + } + Ok(creds_str) +} diff --git a/src-tauri/src/clientinfo.rs b/src-tauri/src/clientinfo.rs index 6246fb0..f8d4b6f 100644 --- a/src-tauri/src/clientinfo.rs +++ b/src-tauri/src/clientinfo.rs @@ -1,9 +1,12 @@ +use std::path::PathBuf; + use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo}; use tauri::Manager; use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt}; use serde::{Serialize, Deserialize}; use crate::{ + app::APP, errors::*, config::AppConfig, state::AppState, @@ -13,12 +16,12 @@ use crate::{ #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)] pub struct Client { pub pid: u32, - pub exe: String, + pub exe: PathBuf, } async fn get_associated_pids(local_port: u16) -> Result, netstat2::error::Error> { - let state = crate::APP.get().unwrap().state::(); + let state = APP.get().unwrap().state::(); let AppConfig { listen_addr: app_listen_addr, listen_port: app_listen_port, @@ -60,7 +63,7 @@ pub async fn get_clients(local_port: u16) -> Result>, ClientI let client = Client { pid: p, - exe: proc.exe().to_string_lossy().into_owned(), + exe: proc.exe().to_path_buf(), }; clients.push(Some(client)); } diff --git a/src-tauri/src/credentials.rs b/src-tauri/src/credentials.rs new file mode 100644 index 0000000..491fd0d --- /dev/null +++ b/src-tauri/src/credentials.rs @@ -0,0 +1,244 @@ +use std::fmt::{self, Formatter}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use aws_smithy_types::date_time::{DateTime, Format}; +use serde::{ + Serialize, + Deserialize, + Serializer, + Deserializer, +}; +use serde::de::{self, Visitor}; +use sqlx::SqlitePool; +use sodiumoxide::crypto::{ + pwhash, + pwhash::Salt, + secretbox, + secretbox::{Nonce, Key} +}; + +use crate::errors::*; + + +#[derive(Clone, Debug)] +pub enum Session { + Unlocked{ + base: BaseCredentials, + session: SessionCredentials, + }, + Locked(LockedCredentials), + Empty, +} + +impl Session { + pub async fn load(pool: &SqlitePool) -> Result { + let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc") + .fetch_optional(pool) + .await?; + let row = match res { + Some(r) => r, + None => {return Ok(Session::Empty);} + }; + + let salt_buf: [u8; 32] = row.salt + .try_into() + .map_err(|_e| SetupError::InvalidRecord)?; + let nonce_buf: [u8; 24] = row.nonce + .try_into() + .map_err(|_e| SetupError::InvalidRecord)?; + + let creds = LockedCredentials { + access_key_id: row.access_key_id, + secret_key_enc: row.secret_key_enc, + salt: Salt(salt_buf), + nonce: Nonce(nonce_buf), + }; + Ok(Session::Locked(creds)) + } + + pub async fn renew_if_expired(&mut self) -> Result { + match self { + Session::Unlocked{ref base, ref mut session} => { + if !session.is_expired() { + return Ok(false); + } + *session = SessionCredentials::from_base(base).await?; + Ok(true) + }, + Session::Locked(_) => Err(GetSessionError::CredentialsLocked), + Session::Empty => Err(GetSessionError::CredentialsEmpty), + } + } +} + + +#[derive(Clone, Debug)] +pub struct LockedCredentials { + pub access_key_id: String, + pub secret_key_enc: Vec, + pub salt: Salt, + pub nonce: Nonce, +} + +impl LockedCredentials { + pub async fn save(&self, pool: &SqlitePool) -> Result<(), sqlx::Error> { + sqlx::query( + "INSERT INTO credentials (access_key_id, secret_key_enc, salt, nonce, created_at) + VALUES (?, ?, ?, ?, strftime('%s'))" + ) + .bind(&self.access_key_id) + .bind(&self.secret_key_enc) + .bind(&self.salt.0[0..]) + .bind(&self.nonce.0[0..]) + .execute(pool) + .await?; + + Ok(()) + } + + pub fn decrypt(&self, passphrase: &str) -> Result { + let mut key_buf = [0; secretbox::KEYBYTES]; + // pretty sure this only fails if we're out of memory + pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &self.salt).unwrap(); + let decrypted = secretbox::open(&self.secret_key_enc, &self.nonce, &Key(key_buf)) + .map_err(|_| UnlockError::BadPassphrase)?; + let secret_access_key = String::from_utf8(decrypted) + .map_err(|_| UnlockError::InvalidUtf8)?; + + let creds = BaseCredentials { + access_key_id: self.access_key_id.clone(), + secret_access_key, + }; + Ok(creds) + } +} + + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct BaseCredentials { + pub access_key_id: String, + pub secret_access_key: String, +} + +impl BaseCredentials { + pub fn encrypt(&self, passphrase: &str) -> LockedCredentials { + let salt = pwhash::gen_salt(); + let mut key_buf = [0; secretbox::KEYBYTES]; + pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &salt).unwrap(); + let key = Key(key_buf); + let nonce = secretbox::gen_nonce(); + + let secret_key_enc = secretbox::seal(self.secret_access_key.as_bytes(), &nonce, &key); + + LockedCredentials { + access_key_id: self.access_key_id.clone(), + secret_key_enc, + salt, + nonce, + } + } +} + + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct SessionCredentials { + pub access_key_id: String, + pub secret_access_key: String, + pub token: String, + #[serde(serialize_with = "serialize_expiration")] + #[serde(deserialize_with = "deserialize_expiration")] + pub expiration: DateTime, +} + +impl SessionCredentials { + pub async fn from_base(base: &BaseCredentials) -> Result { + let req_creds = aws_sdk_sts::Credentials::new( + &base.access_key_id, + &base.secret_access_key, + None, // token + None, //expiration + "Creddy", // "provider name" apparently + ); + let config = aws_config::from_env() + .credentials_provider(req_creds) + .load() + .await; + + let client = aws_sdk_sts::Client::new(&config); + let resp = client.get_session_token() + .duration_seconds(43_200) + .send() + .await?; + + let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?; + + let access_key_id = aws_session.access_key_id() + .ok_or(GetSessionError::EmptyResponse)? + .to_string(); + let secret_access_key = aws_session.secret_access_key() + .ok_or(GetSessionError::EmptyResponse)? + .to_string(); + let token = aws_session.session_token() + .ok_or(GetSessionError::EmptyResponse)? + .to_string(); + let expiration = aws_session.expiration() + .ok_or(GetSessionError::EmptyResponse)? + .clone(); + + let session_creds = SessionCredentials { + access_key_id, + secret_access_key, + token, + expiration, + }; + + #[cfg(debug_assertions)] + println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap()); + + Ok(session_creds) + } + + pub fn is_expired(&self) -> bool { + let current_ts = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() // doesn't panic because UNIX_EPOCH won't be later than now() + .as_secs(); + + let expire_ts = self.expiration.secs(); + let remaining = expire_ts - (current_ts as i64); + remaining < 60 + } +} + +fn serialize_expiration(exp: &DateTime, serializer: S) -> Result +where S: Serializer +{ + // this only fails if the d/t is out of range, which it can't be for this format + let time_str = exp.fmt(Format::DateTime).unwrap(); + serializer.serialize_str(&time_str) +} + + +struct DateTimeVisitor; + +impl<'de> Visitor<'de> for DateTimeVisitor { + type Value = DateTime; + + fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + write!(formatter, "an RFC 3339 UTC string, e.g. \"2014-01-05T10:17:34Z\"") + } + + fn visit_str(self, v: &str) -> Result { + DateTime::from_str(v, Format::DateTime) + .map_err(|_| E::custom(format!("Invalid date/time: {v}"))) + } +} + + +fn deserialize_expiration<'de, D>(deserializer: D) -> Result +where D: Deserializer<'de> +{ + deserializer.deserialize_str(DateTimeVisitor) +} \ No newline at end of file diff --git a/src-tauri/src/errors.rs b/src-tauri/src/errors.rs index 05d9776..1c370d5 100644 --- a/src-tauri/src/errors.rs +++ b/src-tauri/src/errors.rs @@ -116,13 +116,13 @@ pub enum SendResponseError { // errors encountered while handling an HTTP request #[derive(Debug, ThisError, AsRefStr)] -pub enum RequestError { +pub enum HandlerError { #[error("Error writing to stream: {0}")] StreamIOError(#[from] std::io::Error), // #[error("Received invalid UTF-8 in request")] // InvalidUtf8, #[error("HTTP request malformed")] - BadRequest, + BadRequest(Vec), #[error("HTTP request too large")] RequestTooLarge, #[error("Error accessing credentials: {0}")] @@ -185,6 +185,41 @@ pub enum ClientInfoError { } +// Errors encountered while requesting credentials via CLI (creddy show, creddy exec) +#[derive(Debug, ThisError, AsRefStr)] +pub enum RequestError { + #[error("Credentials request failed: HTTP {0}")] + Failed(String), + #[error("Credentials request was rejected")] + Rejected, + #[error("Couldn't interpret the server's response")] + MalformedHttpResponse, + #[error("The server did not respond with valid JSON")] + InvalidJson, + #[error("Error reading/writing stream: {0}")] + StreamIOError(#[from] std::io::Error), +} + + +// Errors encountered while running a subprocess (via creddy exec) +#[derive(Debug, ThisError, AsRefStr)] +pub enum ExecError { + #[error("Please specify a command")] + NoCommand, + #[error("Failed to execute command: {0}")] + ExecutionFailed(#[from] std::io::Error) +} + + +#[derive(Debug, ThisError, AsRefStr)] +pub enum CliError { + #[error(transparent)] + Request(#[from] RequestError), + #[error(transparent)] + Exec(#[from] ExecError), +} + + // ========================= // Serialize implementations // ========================= @@ -210,15 +245,15 @@ impl_serialize_basic!(GetCredentialsError); impl_serialize_basic!(ClientInfoError); -impl Serialize for RequestError { +impl Serialize for HandlerError { fn serialize(&self, serializer: S) -> Result { let mut map = serializer.serialize_map(None)?; map.serialize_entry("code", self.as_ref())?; map.serialize_entry("msg", &format!("{self}"))?; match self { - RequestError::NoCredentials(src) => map.serialize_entry("source", &src)?, - RequestError::ClientInfo(src) => map.serialize_entry("source", &src)?, + HandlerError::NoCredentials(src) => map.serialize_entry("source", &src)?, + HandlerError::ClientInfo(src) => map.serialize_entry("source", &src)?, _ => serialize_upstream_err(self, &mut map)?, } diff --git a/src-tauri/src/ipc.rs b/src-tauri/src/ipc.rs index 3336875..4ec8317 100644 --- a/src-tauri/src/ipc.rs +++ b/src-tauri/src/ipc.rs @@ -1,10 +1,11 @@ use serde::{Serialize, Deserialize}; use tauri::State; -use crate::errors::*; use crate::config::AppConfig; +use crate::credentials::{Session,BaseCredentials}; +use crate::errors::*; use crate::clientinfo::Client; -use crate::state::{AppState, Session, BaseCredentials}; +use crate::state::AppState; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -58,7 +59,7 @@ pub async fn save_credentials( passphrase: String, app_state: State<'_, AppState> ) -> Result<(), UnlockError> { - app_state.save_creds(credentials, &passphrase).await + app_state.new_creds(credentials, &passphrase).await } diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 7242143..fa14562 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -1,23 +1,7 @@ -#![cfg_attr( - all(not(debug_assertions), target_os = "windows"), - windows_subsystem = "windows" -)] -use std::error::Error; - -use once_cell::sync::OnceCell; -use sqlx::{ - SqlitePool, - sqlite::SqlitePoolOptions, - sqlite::SqliteConnectOptions, -}; -use tauri::{ - App, - AppHandle, - Manager, - async_runtime as rt, -}; - +mod app; +mod cli; mod config; +mod credentials; mod errors; mod clientinfo; mod ipc; @@ -25,75 +9,22 @@ mod state; mod server; mod tray; -use config::AppConfig; -use server::Server; -use errors::*; -use state::AppState; - -pub static APP: OnceCell = OnceCell::new(); - - -async fn setup(app: &mut App) -> Result<(), Box> { - APP.set(app.handle()).unwrap(); - - let conn_opts = SqliteConnectOptions::new() - .filename(config::get_or_create_db_path()?) - .create_if_missing(true); - let pool_opts = SqlitePoolOptions::new(); - let pool: SqlitePool = pool_opts.connect_with(conn_opts).await?; - sqlx::migrate!().run(&pool).await?; - - let conf = AppConfig::load(&pool).await?; - let session = AppState::load_creds(&pool).await?; - let srv = Server::new(conf.listen_addr, conf.listen_port, app.handle()).await?; - - config::set_auto_launch(conf.start_on_login)?; - if !conf.start_minimized { - app.get_window("main") - .ok_or(RequestError::NoMainWindow)? - .show()?; - } - - let state = AppState::new(conf, session, srv, pool); - app.manage(state); - Ok(()) -} - - -fn run() -> tauri::Result<()> { - tauri::Builder::default() - .plugin(tauri_plugin_single_instance::init(|app, _argv, _cwd| { - app.get_window("main") - .map(|w| w.show().error_popup("Failed to show main window")); - })) - .system_tray(tray::create()) - .on_system_tray_event(tray::handle_event) - .invoke_handler(tauri::generate_handler![ - ipc::unlock, - ipc::respond, - ipc::get_session_status, - ipc::save_credentials, - ipc::get_config, - ipc::save_config, - ]) - .setup(|app| rt::block_on(setup(app))) - .build(tauri::generate_context!())? - .run(|app, run_event| match run_event { - tauri::RunEvent::WindowEvent { label, event, .. } => match event { - tauri::WindowEvent::CloseRequested { api, .. } => { - let _ = app.get_window(&label).map(|w| w.hide()); - api.prevent_close(); - } - _ => () - } - _ => () - }); - - Ok(()) -} +use crate::errors::ErrorPopup; fn main() { - run().error_popup("Creddy failed to start"); + let res = match cli::parser().get_matches().subcommand() { + None | Some(("run", _)) => { + app::run().error_popup("Creddy failed to start"); + Ok(()) + }, + Some(("show", m)) => cli::show(m), + Some(("exec", m)) => cli::exec(m), + _ => unreachable!(), + }; + + if let Err(e) = res { + eprintln!("Error: {e}"); + } } diff --git a/src-tauri/src/server.rs b/src-tauri/src/server.rs index 15ccfd6..e9fd750 100644 --- a/src-tauri/src/server.rs +++ b/src-tauri/src/server.rs @@ -51,25 +51,46 @@ impl Handler { state.unregister_request(self.request_id).await; } - async fn try_handle(&mut self) -> Result<(), RequestError> { - let _ = self.recv_request().await?; + async fn try_handle(&mut self) -> Result<(), HandlerError> { + let req_path = self.recv_request().await?; let clients = self.get_clients().await?; if self.includes_banned(&clients).await { self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?; return Ok(()) } + // at present only the running exe should be permitted to access this route + if req_path == b"/creddy/base-credentials" { + if clients.len() != 1 + || clients[0].is_none() + || clients[0].as_ref().unwrap().exe != std::env::current_exe()? + { + self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?; + return Ok(()) + } + } let req = Request {id: self.request_id, clients}; self.app.emit_all("credentials-request", &req)?; let starting_visibility = self.show_window()?; match self.wait_for_response().await? { - Approval::Approved => self.send_credentials().await?, + Approval::Approved => { + let state = self.app.state::(); + let creds = if req_path == b"/creddy/base-credentials" { + state.serialize_base_creds().await? + } + else { + state.serialize_session_creds().await? + }; + self.send_body(creds.as_bytes()).await?; + }, Approval::Denied => { let state = self.app.state::(); for client in req.clients { state.add_ban(client).await; } + self.send_body(b"Denied!").await?; + self.stream.shutdown().await?; } } @@ -83,35 +104,36 @@ impl Handler { sleep(delay).await; if !starting_visibility && state.req_count().await == 0 { - let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?; + let window = self.app.get_window("main").ok_or(HandlerError::NoMainWindow)?; window.hide()?; } Ok(()) } - async fn recv_request(&mut self) -> Result, RequestError> { + async fn recv_request(&mut self) -> Result, HandlerError> { let mut buf = vec![0; 8192]; // it's what tokio's BufReader uses let mut n = 0; loop { n += self.stream.read(&mut buf[n..]).await?; if n >= 4 && &buf[(n - 4)..n] == b"\r\n\r\n" {break;} - if n == buf.len() {return Err(RequestError::RequestTooLarge);} - } - - if cfg!(debug_assertions) { - println!("{}", std::str::from_utf8(&buf).unwrap()); + if n == buf.len() {return Err(HandlerError::RequestTooLarge);} } let path = buf.split(|&c| &[c] == b" ") .skip(1) .next() - .ok_or(RequestError::BadRequest(buf))?; + .ok_or(HandlerError::BadRequest(buf.clone()))?; - Ok(buf) + #[cfg(debug_assertions)] { + println!("Path: {}", std::str::from_utf8(&path).unwrap()); + println!("{}", std::str::from_utf8(&buf).unwrap()); + } + + Ok(path.into()) } - async fn get_clients(&self) -> Result>, RequestError> { + async fn get_clients(&self) -> Result>, HandlerError> { let peer_addr = match self.stream.peer_addr()? { SocketAddr::V4(addr) => addr, _ => unreachable!(), // we only listen on IPv4 @@ -130,8 +152,8 @@ impl Handler { false } - fn show_window(&self) -> Result { - let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?; + fn show_window(&self) -> Result { + let window = self.app.get_window("main").ok_or(HandlerError::NoMainWindow)?; let starting_visibility = window.is_visible()?; if !starting_visibility { window.unminimize()?; @@ -141,7 +163,7 @@ impl Handler { Ok(starting_visibility) } - async fn wait_for_response(&mut self) -> Result { + async fn wait_for_response(&mut self) -> Result { self.stream.write(b"HTTP/1.0 200 OK\r\n").await?; self.stream.write(b"Content-Type: application/json\r\n").await?; self.stream.write(b"X-Creddy-delaying-tactic: ").await?; @@ -164,15 +186,12 @@ impl Handler { } } - async fn send_credentials(&mut self) -> Result<(), RequestError> { - let state = self.app.state::(); - let creds = state.serialize_session_creds().await?; - + async fn send_body(&mut self, body: &[u8]) -> Result<(), HandlerError> { self.stream.write(b"\r\nContent-Length: ").await?; - self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?; - self.stream.write(b"\r\n\r\n").await?; - self.stream.write(creds.as_bytes()).await?; + self.stream.write(body.len().to_string().as_bytes()).await?; self.stream.write(b"\r\n\r\n").await?; + self.stream.write(body).await?; + self.stream.shutdown().await?; Ok(()) } } diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index b6b55be..6ceae38 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -1,90 +1,28 @@ use std::collections::{HashMap, HashSet}; -use std::time::{ - Duration, - SystemTime, - UNIX_EPOCH -}; +use std::time::Duration; - -use aws_smithy_types::date_time::{ - DateTime as AwsDateTime, - Format as AwsDateTimeFormat, -}; -use serde::{Serialize, Deserialize}; use tokio::{ sync::oneshot::Sender, sync::RwLock, time::sleep, }; use sqlx::SqlitePool; -use sodiumoxide::crypto::{ - pwhash, - pwhash::Salt, - secretbox, - secretbox::{Nonce, Key} -}; use tauri::async_runtime as runtime; use tauri::Manager; -use serde::Serializer; +use crate::app::APP; +use crate::credentials::{ + Session, + BaseCredentials, + SessionCredentials, +}; use crate::{config, config::AppConfig}; -use crate::ipc; +use crate::ipc::{self, Approval}; use crate::clientinfo::Client; use crate::errors::*; use crate::server::Server; -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(rename_all = "PascalCase")] -pub struct BaseCredentials { - access_key_id: String, - secret_access_key: String, -} - -#[derive(Clone, Debug, Serialize)] -#[serde(rename_all = "PascalCase")] -pub struct SessionCredentials { - access_key_id: String, - secret_access_key: String, - token: String, - #[serde(serialize_with = "serialize_expiration")] - expiration: AwsDateTime, -} - -impl SessionCredentials { - fn is_expired(&self) -> bool { - let current_ts = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() // doesn't panic because UNIX_EPOCH won't be later than now() - .as_secs(); - - let expire_ts = self.expiration.secs(); - let remaining = expire_ts - (current_ts as i64); - remaining < 60 - } -} - - -#[derive(Clone, Debug)] -pub struct LockedCredentials { - access_key_id: String, - secret_key_enc: Vec, - salt: Salt, - nonce: Nonce, -} - - -#[derive(Clone, Debug)] -pub enum Session { - Unlocked{ - base: BaseCredentials, - session: SessionCredentials, - }, - Locked(LockedCredentials), - Empty, -} - - #[derive(Debug)] pub struct AppState { pub config: RwLock, @@ -109,57 +47,11 @@ impl AppState { } } - pub async fn load_creds(pool: &SqlitePool) -> Result { - let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc") - .fetch_optional(pool) - .await?; - let row = match res { - Some(r) => r, - None => {return Ok(Session::Empty);} - }; - - let salt_buf: [u8; 32] = row.salt - .try_into() - .map_err(|_e| SetupError::InvalidRecord)?; - let nonce_buf: [u8; 24] = row.nonce - .try_into() - .map_err(|_e| SetupError::InvalidRecord)?; - - let creds = LockedCredentials { - access_key_id: row.access_key_id, - secret_key_enc: row.secret_key_enc, - salt: Salt(salt_buf), - nonce: Nonce(nonce_buf), - }; - Ok(Session::Locked(creds)) - } - - pub async fn save_creds(&self, creds: BaseCredentials, passphrase: &str) -> Result<(), UnlockError> { - let BaseCredentials {access_key_id, secret_access_key} = creds; - + pub async fn new_creds(&self, base_creds: BaseCredentials, passphrase: &str) -> Result<(), UnlockError> { + let locked = base_creds.encrypt(passphrase); // do this first so that if it fails we don't save bad credentials - self.new_session(&access_key_id, &secret_access_key).await?; - - let salt = pwhash::gen_salt(); - let mut key_buf = [0; secretbox::KEYBYTES]; - pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &salt).unwrap(); - let key = Key(key_buf); - // not sure we need both salt AND nonce given that we generate a - // fresh salt every time we encrypt, but better safe than sorry - let nonce = secretbox::gen_nonce(); - let secret_key_enc = secretbox::seal(secret_access_key.as_bytes(), &nonce, &key); - - - sqlx::query( - "INSERT INTO credentials (access_key_id, secret_key_enc, salt, nonce, created_at) - VALUES (?, ?, ?, ?, strftime('%s'))" - ) - .bind(&access_key_id) - .bind(&secret_key_enc) - .bind(&salt.0[0..]) - .bind(&nonce.0[0..]) - .execute(&self.pool) - .await?; + self.new_session(base_creds).await?; + locked.save(&self.pool).await?; Ok(()) } @@ -205,7 +97,10 @@ impl AppState { } pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { - self.renew_session_if_expired().await?; + if let Approval::Approved = response.approval { + let mut session = self.session.write().await; + session.renew_if_expired().await?; + } let mut open_requests = self.open_requests.write().await; let chan = open_requests @@ -223,7 +118,7 @@ impl AppState { runtime::spawn(async move { sleep(Duration::from_secs(5)).await; - let app = crate::APP.get().unwrap(); + let app = APP.get().unwrap(); let state = app.state::(); let mut bans = state.bans.write().await; bans.remove(&client); @@ -235,46 +130,25 @@ impl AppState { } pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> { - let mut session = self.session.write().await; - let LockedCredentials { - access_key_id, - secret_key_enc, - salt, - nonce - } = match *session { + let base_creds = match *self.session.read().await { Session::Empty => {return Err(UnlockError::NoCredentials);}, Session::Unlocked{..} => {return Err(UnlockError::NotLocked);}, - Session::Locked(ref c) => c, - }; - - let mut key_buf = [0; secretbox::KEYBYTES]; - // pretty sure this only fails if we're out of memory - pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), salt).unwrap(); - let decrypted = secretbox::open(secret_key_enc, nonce, &Key(key_buf)) - .map_err(|_e| UnlockError::BadPassphrase)?; - - let secret_access_key = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?; - - let session_creds = self.new_session(access_key_id, &secret_access_key).await?; - *session = Session::Unlocked { - base: BaseCredentials { - access_key_id: access_key_id.clone(), - secret_access_key, - }, - session: session_creds + Session::Locked(ref locked) => locked.decrypt(passphrase)?, }; + // Read lock is dropped here, so this doesn't deadlock + self.new_session(base_creds).await?; Ok(()) } - // pub async fn serialize_base_creds(&self) -> Result { - // let session = self.session.read().await; - // match *session { - // Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()), - // Session::Locked(_) => Err(GetCredentialsError::Locked), - // Session::Empty => Err(GetCredentialsError::Empty), - // } - // } + pub async fn serialize_base_creds(&self) -> Result { + let session = self.session.read().await; + match *session { + Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()), + Session::Locked(_) => Err(GetCredentialsError::Locked), + Session::Empty => Err(GetCredentialsError::Empty), + } + } pub async fn serialize_session_creds(&self) -> Result { let session = self.session.read().await; @@ -285,77 +159,10 @@ impl AppState { } } - async fn new_session(&self, key_id: &str, secret_key: &str) -> Result { - let creds = aws_sdk_sts::Credentials::new( - key_id, - secret_key, - None, // token - None, // expiration - "creddy", // "provider name" apparently - ); - let config = aws_config::from_env() - .credentials_provider(creds) - .load() - .await; - - let client = aws_sdk_sts::Client::new(&config); - let resp = client.get_session_token() - .duration_seconds(43_200) - .send() - .await?; - - let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?; - - let access_key_id = aws_session.access_key_id() - .ok_or(GetSessionError::EmptyResponse)? - .to_string(); - let secret_access_key = aws_session.secret_access_key() - .ok_or(GetSessionError::EmptyResponse)? - .to_string(); - let token = aws_session.session_token() - .ok_or(GetSessionError::EmptyResponse)? - .to_string(); - let expiration = aws_session.expiration() - .ok_or(GetSessionError::EmptyResponse)? - .clone(); - - let session_creds = SessionCredentials { - access_key_id, - secret_access_key, - token, - expiration, - }; - - #[cfg(debug_assertions)] - println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap()); - - Ok(session_creds) - } - - pub async fn renew_session_if_expired(&self) -> Result { - match *self.session.write().await { - Session::Unlocked{ref base, ref mut session} => { - if !session.is_expired() { - return Ok(false); - } - let new_session = self.new_session( - &base.access_key_id, - &base.secret_access_key - ).await?; - *session = new_session; - Ok(true) - }, - Session::Locked(_) => Err(GetSessionError::CredentialsLocked), - Session::Empty => Err(GetSessionError::CredentialsEmpty), - } + async fn new_session(&self, base: BaseCredentials) -> Result<(), GetSessionError> { + let session = SessionCredentials::from_base(&base).await?; + let mut app_session = self.session.write().await; + *app_session = Session::Unlocked {base, session}; + Ok(()) } } - - -fn serialize_expiration(exp: &AwsDateTime, serializer: S) -> Result -where S: Serializer -{ - // this only fails if the d/t is out of range, which it can't be for this format - let time_str = exp.fmt(AwsDateTimeFormat::DateTime).unwrap(); - serializer.serialize_str(&time_str) -} diff --git a/src/views/Approve.svelte b/src/views/Approve.svelte index ea555a4..6ad1040 100644 --- a/src/views/Approve.svelte +++ b/src/views/Approve.svelte @@ -68,7 +68,7 @@ -{#if !$appState.currentRequest.approval} +{#if error || !$appState.currentRequest.approval}
{#if error}