add show/exec commands and refactor AppState

This commit is contained in:
Joseph Montanaro 2023-05-06 12:01:56 -07:00
parent e8b8dc2976
commit 616600687d
14 changed files with 715 additions and 345 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ src-tauri/target/
# just in case # just in case
credentials* credentials*
!credentials.rs

81
src-tauri/Cargo.lock generated
View File

@ -67,6 +67,7 @@ dependencies = [
"aws-sdk-sts", "aws-sdk-sts",
"aws-smithy-types", "aws-smithy-types",
"aws-types", "aws-types",
"clap",
"dirs 5.0.1", "dirs 5.0.1",
"netstat2", "netstat2",
"once_cell", "once_cell",
@ -227,6 +228,17 @@ version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1181e1e0d1fce796a03db1ae795d67167da795f9cf4a39c37589e85ef57f26d3" checksum = "1181e1e0d1fce796a03db1ae795d67167da795f9cf4a39c37589e85ef57f26d3"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi",
]
[[package]] [[package]]
name = "auto-launch" name = "auto-launch"
version = "0.4.0" version = "0.4.0"
@ -714,6 +726,45 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "clap"
version = "3.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123"
dependencies = [
"atty",
"bitflags",
"clap_derive",
"clap_lex",
"indexmap",
"once_cell",
"strsim",
"termcolor",
"textwrap",
]
[[package]]
name = "clap_derive"
version = "3.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae6371b8bdc8b7d3959e9cf7b22d4435ef3e79e138688421ec654acf8c81b008"
dependencies = [
"heck 0.4.1",
"proc-macro-error",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "clap_lex"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
dependencies = [
"os_str_bytes",
]
[[package]] [[package]]
name = "cocoa" name = "cocoa"
version = "0.24.1" version = "0.24.1"
@ -1735,6 +1786,15 @@ dependencies = [
"unicode-segmentation", "unicode-segmentation",
] ]
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "hermit-abi" name = "hermit-abi"
version = "0.2.6" version = "0.2.6"
@ -2511,6 +2571,12 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "os_str_bytes"
version = "6.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ceedf44fb00f2d1984b0bc98102627ce622e083e49a5bacdb3e514fa4238e267"
[[package]] [[package]]
name = "outref" name = "outref"
version = "0.1.0" version = "0.1.0"
@ -4020,6 +4086,21 @@ dependencies = [
"utf-8", "utf-8",
] ]
[[package]]
name = "termcolor"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6"
dependencies = [
"winapi-util",
]
[[package]]
name = "textwrap"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d"
[[package]] [[package]]
name = "thin-slice" name = "thin-slice"
version = "0.1.1" version = "0.1.1"

View File

@ -34,6 +34,7 @@ strum = "0.24"
strum_macros = "0.24" strum_macros = "0.24"
auto-launch = "0.4.0" auto-launch = "0.4.0"
dirs = "5.0" dirs = "5.0"
clap = { version = "3.2.23", features = ["derive"] }
[features] [features]
# by default Tauri runs in production mode # by default Tauri runs in production mode

92
src-tauri/src/app.rs Normal file
View File

@ -0,0 +1,92 @@
#![cfg_attr(
all(not(debug_assertions), target_os = "windows"),
windows_subsystem = "windows"
)]
use std::error::Error;
use once_cell::sync::OnceCell;
use sqlx::{
SqlitePool,
sqlite::SqlitePoolOptions,
sqlite::SqliteConnectOptions,
};
use tauri::{
App,
AppHandle,
Manager,
async_runtime as rt,
};
use crate::{
config::{self, AppConfig},
credentials::Session,
ipc,
server::Server,
errors::*,
state::AppState,
tray,
};
pub static APP: OnceCell<AppHandle> = OnceCell::new();
async fn setup(app: &mut App) -> Result<(), Box<dyn Error>> {
APP.set(app.handle()).unwrap();
let conn_opts = SqliteConnectOptions::new()
.filename(config::get_or_create_db_path()?)
.create_if_missing(true);
let pool_opts = SqlitePoolOptions::new();
let pool: SqlitePool = pool_opts.connect_with(conn_opts).await?;
sqlx::migrate!().run(&pool).await?;
let conf = AppConfig::load(&pool).await?;
let session = Session::load(&pool).await?;
let srv = Server::new(conf.listen_addr, conf.listen_port, app.handle()).await?;
config::set_auto_launch(conf.start_on_login)?;
if !conf.start_minimized {
app.get_window("main")
.ok_or(HandlerError::NoMainWindow)?
.show()?;
}
let state = AppState::new(conf, session, srv, pool);
app.manage(state);
Ok(())
}
pub fn run() -> tauri::Result<()> {
tauri::Builder::default()
.plugin(tauri_plugin_single_instance::init(|app, _argv, _cwd| {
app.get_window("main")
.map(|w| w.show().error_popup("Failed to show main window"));
}))
.system_tray(tray::create())
.on_system_tray_event(tray::handle_event)
.invoke_handler(tauri::generate_handler![
ipc::unlock,
ipc::respond,
ipc::get_session_status,
ipc::save_credentials,
ipc::get_config,
ipc::save_config,
])
.setup(|app| rt::block_on(setup(app)))
.build(tauri::generate_context!())?
.run(|app, run_event| match run_event {
tauri::RunEvent::WindowEvent { label, event, .. } => match event {
tauri::WindowEvent::CloseRequested { api, .. } => {
let _ = app.get_window(&label).map(|w| w.hide());
api.prevent_close();
}
_ => ()
}
_ => ()
});
Ok(())
}

135
src-tauri/src/cli.rs Normal file
View File

@ -0,0 +1,135 @@
use std::io::{Read, Write};
use std::process::Command as ChildCommand;
#[cfg(unix)]
use std::os::unix::process::CommandExt;
use clap::{
Command,
Arg,
ArgMatches,
ArgAction
};
use std::net::TcpStream;
use crate::credentials::{BaseCredentials, SessionCredentials};
use crate::errors::*;
pub fn parser() -> Command<'static> {
Command::new("creddy")
.about("A friendly AWS credentials manager")
.subcommand(
Command::new("run")
.about("Launch Creddy")
)
.subcommand(
Command::new("show")
.about("Fetch and display AWS credentials")
.arg(
Arg::new("base")
.short('b')
.long("base")
.action(ArgAction::SetTrue)
.help("Use base credentials instead of session credentials")
)
)
.subcommand(
Command::new("exec")
.about("Inject AWS credentials into the environment of another command")
.trailing_var_arg(true)
.arg(
Arg::new("base")
.short('b')
.long("base")
.action(ArgAction::SetTrue)
.help("Use base credentials instead of session credentials")
)
.arg(
Arg::new("command")
.multiple_values(true)
)
)
}
pub fn show(args: &ArgMatches) -> Result<(), CliError> {
let base = args.get_one("base").unwrap_or(&false);
let creds = get_credentials(*base)?;
println!("{creds}");
Ok(())
}
pub fn exec(args: &ArgMatches) -> Result<(), CliError> {
let base = *args.get_one("base").unwrap_or(&false);
let mut cmd_line = args.get_many("command")
.ok_or(ExecError::NoCommand)?;
let cmd_name: &String = cmd_line.next().unwrap(); // Clap guarantees that there will be at least one
let mut cmd = ChildCommand::new(cmd_name);
cmd.args(cmd_line);
if base {
let creds: BaseCredentials = serde_json::from_str(&get_credentials(base)?)
.map_err(|_| RequestError::InvalidJson)?;
cmd.env("AWS_ACCESS_KEY_ID", creds.access_key_id);
cmd.env("AWS_SECRET_ACCESS_KEY", creds.secret_access_key);
}
else {
let creds: SessionCredentials = serde_json::from_str(&get_credentials(base)?)
.map_err(|_| RequestError::InvalidJson)?;
cmd.env("AWS_ACCESS_KEY_ID", creds.access_key_id);
cmd.env("AWS_SECRET_ACCESS_KEY", creds.secret_access_key);
cmd.env("AWS_SESSION_TOKEN", creds.token);
}
#[cfg(unix)]
cmd.exec().map_err(|e| ExecError::ExecutionFailed(e))?;
#[cfg(windows)]
{
let mut child = cmd.spawn()
.map_err(|e| ExecError::ExecutionFailed(e))?;
let status = child.wait()
.map_err(|e| ExecError::ExecutionFailed(e))?;
std::process::exit(status.code().unwrap_or(1));
};
}
fn get_credentials(base: bool) -> Result<String, RequestError> {
let path = if base {"/creddy/base-credentials"} else {"/"};
let mut stream = TcpStream::connect("127.0.0.1:12345")?;
let req = format!("GET {path} HTTP/1.0\r\n\r\n");
stream.write_all(req.as_bytes())?;
// some day we'll have a proper HTTP parser
let mut buf = vec![0; 8192];
stream.read_to_end(&mut buf)?;
let status = buf.split(|&c| &[c] == b" ")
.skip(1)
.next()
.ok_or(RequestError::MalformedHttpResponse)?;
if status != b"200" {
let s = String::from_utf8_lossy(status).to_string();
return Err(RequestError::Failed(s));
}
let break_idx = buf.windows(4)
.position(|w| w == b"\r\n\r\n")
.ok_or(RequestError::MalformedHttpResponse)?;
let body = &buf[(break_idx + 4)..];
let creds_str = std::str::from_utf8(body)
.map_err(|_| RequestError::MalformedHttpResponse)?
.to_string();
if creds_str == "Denied!" {
return Err(RequestError::Rejected);
}
Ok(creds_str)
}

View File

@ -1,9 +1,12 @@
use std::path::PathBuf;
use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo}; use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo};
use tauri::Manager; use tauri::Manager;
use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt}; use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::{ use crate::{
app::APP,
errors::*, errors::*,
config::AppConfig, config::AppConfig,
state::AppState, state::AppState,
@ -13,12 +16,12 @@ use crate::{
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)] #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct Client { pub struct Client {
pub pid: u32, pub pid: u32,
pub exe: String, pub exe: PathBuf,
} }
async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> { async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
let state = crate::APP.get().unwrap().state::<AppState>(); let state = APP.get().unwrap().state::<AppState>();
let AppConfig { let AppConfig {
listen_addr: app_listen_addr, listen_addr: app_listen_addr,
listen_port: app_listen_port, listen_port: app_listen_port,
@ -60,7 +63,7 @@ pub async fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientI
let client = Client { let client = Client {
pid: p, pid: p,
exe: proc.exe().to_string_lossy().into_owned(), exe: proc.exe().to_path_buf(),
}; };
clients.push(Some(client)); clients.push(Some(client));
} }

View File

@ -0,0 +1,244 @@
use std::fmt::{self, Formatter};
use std::time::{SystemTime, UNIX_EPOCH};
use aws_smithy_types::date_time::{DateTime, Format};
use serde::{
Serialize,
Deserialize,
Serializer,
Deserializer,
};
use serde::de::{self, Visitor};
use sqlx::SqlitePool;
use sodiumoxide::crypto::{
pwhash,
pwhash::Salt,
secretbox,
secretbox::{Nonce, Key}
};
use crate::errors::*;
#[derive(Clone, Debug)]
pub enum Session {
Unlocked{
base: BaseCredentials,
session: SessionCredentials,
},
Locked(LockedCredentials),
Empty,
}
impl Session {
pub async fn load(pool: &SqlitePool) -> Result<Self, SetupError> {
let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc")
.fetch_optional(pool)
.await?;
let row = match res {
Some(r) => r,
None => {return Ok(Session::Empty);}
};
let salt_buf: [u8; 32] = row.salt
.try_into()
.map_err(|_e| SetupError::InvalidRecord)?;
let nonce_buf: [u8; 24] = row.nonce
.try_into()
.map_err(|_e| SetupError::InvalidRecord)?;
let creds = LockedCredentials {
access_key_id: row.access_key_id,
secret_key_enc: row.secret_key_enc,
salt: Salt(salt_buf),
nonce: Nonce(nonce_buf),
};
Ok(Session::Locked(creds))
}
pub async fn renew_if_expired(&mut self) -> Result<bool, GetSessionError> {
match self {
Session::Unlocked{ref base, ref mut session} => {
if !session.is_expired() {
return Ok(false);
}
*session = SessionCredentials::from_base(base).await?;
Ok(true)
},
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
Session::Empty => Err(GetSessionError::CredentialsEmpty),
}
}
}
#[derive(Clone, Debug)]
pub struct LockedCredentials {
pub access_key_id: String,
pub secret_key_enc: Vec<u8>,
pub salt: Salt,
pub nonce: Nonce,
}
impl LockedCredentials {
pub async fn save(&self, pool: &SqlitePool) -> Result<(), sqlx::Error> {
sqlx::query(
"INSERT INTO credentials (access_key_id, secret_key_enc, salt, nonce, created_at)
VALUES (?, ?, ?, ?, strftime('%s'))"
)
.bind(&self.access_key_id)
.bind(&self.secret_key_enc)
.bind(&self.salt.0[0..])
.bind(&self.nonce.0[0..])
.execute(pool)
.await?;
Ok(())
}
pub fn decrypt(&self, passphrase: &str) -> Result<BaseCredentials, UnlockError> {
let mut key_buf = [0; secretbox::KEYBYTES];
// pretty sure this only fails if we're out of memory
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &self.salt).unwrap();
let decrypted = secretbox::open(&self.secret_key_enc, &self.nonce, &Key(key_buf))
.map_err(|_| UnlockError::BadPassphrase)?;
let secret_access_key = String::from_utf8(decrypted)
.map_err(|_| UnlockError::InvalidUtf8)?;
let creds = BaseCredentials {
access_key_id: self.access_key_id.clone(),
secret_access_key,
};
Ok(creds)
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct BaseCredentials {
pub access_key_id: String,
pub secret_access_key: String,
}
impl BaseCredentials {
pub fn encrypt(&self, passphrase: &str) -> LockedCredentials {
let salt = pwhash::gen_salt();
let mut key_buf = [0; secretbox::KEYBYTES];
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &salt).unwrap();
let key = Key(key_buf);
let nonce = secretbox::gen_nonce();
let secret_key_enc = secretbox::seal(self.secret_access_key.as_bytes(), &nonce, &key);
LockedCredentials {
access_key_id: self.access_key_id.clone(),
secret_key_enc,
salt,
nonce,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct SessionCredentials {
pub access_key_id: String,
pub secret_access_key: String,
pub token: String,
#[serde(serialize_with = "serialize_expiration")]
#[serde(deserialize_with = "deserialize_expiration")]
pub expiration: DateTime,
}
impl SessionCredentials {
pub async fn from_base(base: &BaseCredentials) -> Result<Self, GetSessionError> {
let req_creds = aws_sdk_sts::Credentials::new(
&base.access_key_id,
&base.secret_access_key,
None, // token
None, //expiration
"Creddy", // "provider name" apparently
);
let config = aws_config::from_env()
.credentials_provider(req_creds)
.load()
.await;
let client = aws_sdk_sts::Client::new(&config);
let resp = client.get_session_token()
.duration_seconds(43_200)
.send()
.await?;
let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?;
let access_key_id = aws_session.access_key_id()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let secret_access_key = aws_session.secret_access_key()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let token = aws_session.session_token()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let expiration = aws_session.expiration()
.ok_or(GetSessionError::EmptyResponse)?
.clone();
let session_creds = SessionCredentials {
access_key_id,
secret_access_key,
token,
expiration,
};
#[cfg(debug_assertions)]
println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap());
Ok(session_creds)
}
pub fn is_expired(&self) -> bool {
let current_ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap() // doesn't panic because UNIX_EPOCH won't be later than now()
.as_secs();
let expire_ts = self.expiration.secs();
let remaining = expire_ts - (current_ts as i64);
remaining < 60
}
}
fn serialize_expiration<S>(exp: &DateTime, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
{
// this only fails if the d/t is out of range, which it can't be for this format
let time_str = exp.fmt(Format::DateTime).unwrap();
serializer.serialize_str(&time_str)
}
struct DateTimeVisitor;
impl<'de> Visitor<'de> for DateTimeVisitor {
type Value = DateTime;
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
write!(formatter, "an RFC 3339 UTC string, e.g. \"2014-01-05T10:17:34Z\"")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<DateTime, E> {
DateTime::from_str(v, Format::DateTime)
.map_err(|_| E::custom(format!("Invalid date/time: {v}")))
}
}
fn deserialize_expiration<'de, D>(deserializer: D) -> Result<DateTime, D::Error>
where D: Deserializer<'de>
{
deserializer.deserialize_str(DateTimeVisitor)
}

View File

@ -116,12 +116,13 @@ pub enum SendResponseError {
// errors encountered while handling an HTTP request // errors encountered while handling an HTTP request
#[derive(Debug, ThisError, AsRefStr)] #[derive(Debug, ThisError, AsRefStr)]
pub enum RequestError { pub enum HandlerError {
#[error("Error writing to stream: {0}")] #[error("Error writing to stream: {0}")]
StreamIOError(#[from] std::io::Error), StreamIOError(#[from] std::io::Error),
// #[error("Received invalid UTF-8 in request")] // #[error("Received invalid UTF-8 in request")]
// InvalidUtf8, // InvalidUtf8,
// MalformedHttpRequest, #[error("HTTP request malformed")]
BadRequest(Vec<u8>),
#[error("HTTP request too large")] #[error("HTTP request too large")]
RequestTooLarge, RequestTooLarge,
#[error("Error accessing credentials: {0}")] #[error("Error accessing credentials: {0}")]
@ -184,6 +185,41 @@ pub enum ClientInfoError {
} }
// Errors encountered while requesting credentials via CLI (creddy show, creddy exec)
#[derive(Debug, ThisError, AsRefStr)]
pub enum RequestError {
#[error("Credentials request failed: HTTP {0}")]
Failed(String),
#[error("Credentials request was rejected")]
Rejected,
#[error("Couldn't interpret the server's response")]
MalformedHttpResponse,
#[error("The server did not respond with valid JSON")]
InvalidJson,
#[error("Error reading/writing stream: {0}")]
StreamIOError(#[from] std::io::Error),
}
// Errors encountered while running a subprocess (via creddy exec)
#[derive(Debug, ThisError, AsRefStr)]
pub enum ExecError {
#[error("Please specify a command")]
NoCommand,
#[error("Failed to execute command: {0}")]
ExecutionFailed(#[from] std::io::Error)
}
#[derive(Debug, ThisError, AsRefStr)]
pub enum CliError {
#[error(transparent)]
Request(#[from] RequestError),
#[error(transparent)]
Exec(#[from] ExecError),
}
// ========================= // =========================
// Serialize implementations // Serialize implementations
// ========================= // =========================
@ -209,15 +245,15 @@ impl_serialize_basic!(GetCredentialsError);
impl_serialize_basic!(ClientInfoError); impl_serialize_basic!(ClientInfoError);
impl Serialize for RequestError { impl Serialize for HandlerError {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut map = serializer.serialize_map(None)?; let mut map = serializer.serialize_map(None)?;
map.serialize_entry("code", self.as_ref())?; map.serialize_entry("code", self.as_ref())?;
map.serialize_entry("msg", &format!("{self}"))?; map.serialize_entry("msg", &format!("{self}"))?;
match self { match self {
RequestError::NoCredentials(src) => map.serialize_entry("source", &src)?, HandlerError::NoCredentials(src) => map.serialize_entry("source", &src)?,
RequestError::ClientInfo(src) => map.serialize_entry("source", &src)?, HandlerError::ClientInfo(src) => map.serialize_entry("source", &src)?,
_ => serialize_upstream_err(self, &mut map)?, _ => serialize_upstream_err(self, &mut map)?,
} }

View File

@ -1,16 +1,18 @@
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use tauri::State; use tauri::State;
use crate::errors::*;
use crate::config::AppConfig; use crate::config::AppConfig;
use crate::credentials::{Session,BaseCredentials};
use crate::errors::*;
use crate::clientinfo::Client; use crate::clientinfo::Client;
use crate::state::{AppState, Session, BaseCredentials}; use crate::state::AppState;
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Request { pub struct Request {
pub id: u64, pub id: u64,
pub clients: Vec<Option<Client>>, pub clients: Vec<Option<Client>>,
pub base: bool,
} }
@ -58,7 +60,7 @@ pub async fn save_credentials(
passphrase: String, passphrase: String,
app_state: State<'_, AppState> app_state: State<'_, AppState>
) -> Result<(), UnlockError> { ) -> Result<(), UnlockError> {
app_state.save_creds(credentials, &passphrase).await app_state.new_creds(credentials, &passphrase).await
} }

View File

@ -1,23 +1,7 @@
#![cfg_attr( mod app;
all(not(debug_assertions), target_os = "windows"), mod cli;
windows_subsystem = "windows"
)]
use std::error::Error;
use once_cell::sync::OnceCell;
use sqlx::{
SqlitePool,
sqlite::SqlitePoolOptions,
sqlite::SqliteConnectOptions,
};
use tauri::{
App,
AppHandle,
Manager,
async_runtime as rt,
};
mod config; mod config;
mod credentials;
mod errors; mod errors;
mod clientinfo; mod clientinfo;
mod ipc; mod ipc;
@ -25,75 +9,22 @@ mod state;
mod server; mod server;
mod tray; mod tray;
use config::AppConfig;
use server::Server;
use errors::*;
use state::AppState;
use crate::errors::ErrorPopup;
pub static APP: OnceCell<AppHandle> = OnceCell::new();
async fn setup(app: &mut App) -> Result<(), Box<dyn Error>> {
APP.set(app.handle()).unwrap();
let conn_opts = SqliteConnectOptions::new()
.filename(config::get_or_create_db_path()?)
.create_if_missing(true);
let pool_opts = SqlitePoolOptions::new();
let pool: SqlitePool = pool_opts.connect_with(conn_opts).await?;
sqlx::migrate!().run(&pool).await?;
let conf = AppConfig::load(&pool).await?;
let session = AppState::load_creds(&pool).await?;
let srv = Server::new(conf.listen_addr, conf.listen_port, app.handle()).await?;
config::set_auto_launch(conf.start_on_login)?;
if !conf.start_minimized {
app.get_window("main")
.ok_or(RequestError::NoMainWindow)?
.show()?;
}
let state = AppState::new(conf, session, srv, pool);
app.manage(state);
Ok(())
}
fn run() -> tauri::Result<()> {
tauri::Builder::default()
.plugin(tauri_plugin_single_instance::init(|app, _argv, _cwd| {
app.get_window("main")
.map(|w| w.show().error_popup("Failed to show main window"));
}))
.system_tray(tray::create())
.on_system_tray_event(tray::handle_event)
.invoke_handler(tauri::generate_handler![
ipc::unlock,
ipc::respond,
ipc::get_session_status,
ipc::save_credentials,
ipc::get_config,
ipc::save_config,
])
.setup(|app| rt::block_on(setup(app)))
.build(tauri::generate_context!())?
.run(|app, run_event| match run_event {
tauri::RunEvent::WindowEvent { label, event, .. } => match event {
tauri::WindowEvent::CloseRequested { api, .. } => {
let _ = app.get_window(&label).map(|w| w.hide());
api.prevent_close();
}
_ => ()
}
_ => ()
});
Ok(())
}
fn main() { fn main() {
run().error_popup("Creddy failed to start"); let res = match cli::parser().get_matches().subcommand() {
None | Some(("run", _)) => {
app::run().error_popup("Creddy failed to start");
Ok(())
},
Some(("show", m)) => cli::show(m),
Some(("exec", m)) => cli::exec(m),
_ => unreachable!(),
};
if let Err(e) = res {
eprintln!("Error: {e}");
}
} }

View File

@ -51,25 +51,46 @@ impl Handler {
state.unregister_request(self.request_id).await; state.unregister_request(self.request_id).await;
} }
async fn try_handle(&mut self) -> Result<(), RequestError> { async fn try_handle(&mut self) -> Result<(), HandlerError> {
let _ = self.recv_request().await?; let req_path = self.recv_request().await?;
let clients = self.get_clients().await?; let clients = self.get_clients().await?;
if self.includes_banned(&clients).await { if self.includes_banned(&clients).await {
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?; self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
return Ok(()) return Ok(())
} }
let base = req_path == b"/creddy/base-credentials";
if base {
if clients.len() != 1
|| clients[0].is_none()
|| clients[0].as_ref().unwrap().exe != std::env::current_exe()?
{
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
return Ok(())
}
}
let req = Request {id: self.request_id, clients}; let req = Request {id: self.request_id, clients, base};
self.app.emit_all("credentials-request", &req)?; self.app.emit_all("credentials-request", &req)?;
let starting_visibility = self.show_window()?; let starting_visibility = self.show_window()?;
match self.wait_for_response().await? { match self.wait_for_response().await? {
Approval::Approved => self.send_credentials().await?, Approval::Approved => {
let state = self.app.state::<AppState>();
let creds = if base {
state.serialize_base_creds().await?
}
else {
state.serialize_session_creds().await?
};
self.send_body(creds.as_bytes()).await?;
},
Approval::Denied => { Approval::Denied => {
let state = self.app.state::<AppState>(); let state = self.app.state::<AppState>();
for client in req.clients { for client in req.clients {
state.add_ban(client).await; state.add_ban(client).await;
} }
self.send_body(b"Denied!").await?;
self.stream.shutdown().await?;
} }
} }
@ -83,30 +104,36 @@ impl Handler {
sleep(delay).await; sleep(delay).await;
if !starting_visibility && state.req_count().await == 0 { if !starting_visibility && state.req_count().await == 0 {
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?; let window = self.app.get_window("main").ok_or(HandlerError::NoMainWindow)?;
window.hide()?; window.hide()?;
} }
Ok(()) Ok(())
} }
async fn recv_request(&mut self) -> Result<Vec<u8>, RequestError> { async fn recv_request(&mut self) -> Result<Vec<u8>, HandlerError> {
let mut buf = vec![0; 8192]; // it's what tokio's BufReader uses let mut buf = vec![0; 8192]; // it's what tokio's BufReader uses
let mut n = 0; let mut n = 0;
loop { loop {
n += self.stream.read(&mut buf[n..]).await?; n += self.stream.read(&mut buf[n..]).await?;
if n >= 4 && &buf[(n - 4)..n] == b"\r\n\r\n" {break;} if n >= 4 && &buf[(n - 4)..n] == b"\r\n\r\n" {break;}
if n == buf.len() {return Err(RequestError::RequestTooLarge);} if n == buf.len() {return Err(HandlerError::RequestTooLarge);}
} }
if cfg!(debug_assertions) { let path = buf.split(|&c| &[c] == b" ")
.skip(1)
.next()
.ok_or(HandlerError::BadRequest(buf.clone()))?;
#[cfg(debug_assertions)] {
println!("Path: {}", std::str::from_utf8(&path).unwrap());
println!("{}", std::str::from_utf8(&buf).unwrap()); println!("{}", std::str::from_utf8(&buf).unwrap());
} }
Ok(buf) Ok(path.into())
} }
async fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> { async fn get_clients(&self) -> Result<Vec<Option<Client>>, HandlerError> {
let peer_addr = match self.stream.peer_addr()? { let peer_addr = match self.stream.peer_addr()? {
SocketAddr::V4(addr) => addr, SocketAddr::V4(addr) => addr,
_ => unreachable!(), // we only listen on IPv4 _ => unreachable!(), // we only listen on IPv4
@ -125,8 +152,8 @@ impl Handler {
false false
} }
fn show_window(&self) -> Result<bool, RequestError> { fn show_window(&self) -> Result<bool, HandlerError> {
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?; let window = self.app.get_window("main").ok_or(HandlerError::NoMainWindow)?;
let starting_visibility = window.is_visible()?; let starting_visibility = window.is_visible()?;
if !starting_visibility { if !starting_visibility {
window.unminimize()?; window.unminimize()?;
@ -136,7 +163,7 @@ impl Handler {
Ok(starting_visibility) Ok(starting_visibility)
} }
async fn wait_for_response(&mut self) -> Result<Approval, RequestError> { async fn wait_for_response(&mut self) -> Result<Approval, HandlerError> {
self.stream.write(b"HTTP/1.0 200 OK\r\n").await?; self.stream.write(b"HTTP/1.0 200 OK\r\n").await?;
self.stream.write(b"Content-Type: application/json\r\n").await?; self.stream.write(b"Content-Type: application/json\r\n").await?;
self.stream.write(b"X-Creddy-delaying-tactic: ").await?; self.stream.write(b"X-Creddy-delaying-tactic: ").await?;
@ -159,15 +186,12 @@ impl Handler {
} }
} }
async fn send_credentials(&mut self) -> Result<(), RequestError> { async fn send_body(&mut self, body: &[u8]) -> Result<(), HandlerError> {
let state = self.app.state::<AppState>();
let creds = state.serialize_session_creds().await?;
self.stream.write(b"\r\nContent-Length: ").await?; self.stream.write(b"\r\nContent-Length: ").await?;
self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?; self.stream.write(body.len().to_string().as_bytes()).await?;
self.stream.write(b"\r\n\r\n").await?;
self.stream.write(creds.as_bytes()).await?;
self.stream.write(b"\r\n\r\n").await?; self.stream.write(b"\r\n\r\n").await?;
self.stream.write(body).await?;
self.stream.shutdown().await?;
Ok(()) Ok(())
} }
} }

View File

@ -1,90 +1,28 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::time::{ use std::time::Duration;
Duration,
SystemTime,
UNIX_EPOCH
};
use aws_smithy_types::date_time::{
DateTime as AwsDateTime,
Format as AwsDateTimeFormat,
};
use serde::{Serialize, Deserialize};
use tokio::{ use tokio::{
sync::oneshot::Sender, sync::oneshot::Sender,
sync::RwLock, sync::RwLock,
time::sleep, time::sleep,
}; };
use sqlx::SqlitePool; use sqlx::SqlitePool;
use sodiumoxide::crypto::{
pwhash,
pwhash::Salt,
secretbox,
secretbox::{Nonce, Key}
};
use tauri::async_runtime as runtime; use tauri::async_runtime as runtime;
use tauri::Manager; use tauri::Manager;
use serde::Serializer;
use crate::app::APP;
use crate::credentials::{
Session,
BaseCredentials,
SessionCredentials,
};
use crate::{config, config::AppConfig}; use crate::{config, config::AppConfig};
use crate::ipc; use crate::ipc::{self, Approval};
use crate::clientinfo::Client; use crate::clientinfo::Client;
use crate::errors::*; use crate::errors::*;
use crate::server::Server; use crate::server::Server;
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct BaseCredentials {
access_key_id: String,
secret_access_key: String,
}
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "PascalCase")]
pub struct SessionCredentials {
access_key_id: String,
secret_access_key: String,
token: String,
#[serde(serialize_with = "serialize_expiration")]
expiration: AwsDateTime,
}
impl SessionCredentials {
fn is_expired(&self) -> bool {
let current_ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap() // doesn't panic because UNIX_EPOCH won't be later than now()
.as_secs();
let expire_ts = self.expiration.secs();
let remaining = expire_ts - (current_ts as i64);
remaining < 60
}
}
#[derive(Clone, Debug)]
pub struct LockedCredentials {
access_key_id: String,
secret_key_enc: Vec<u8>,
salt: Salt,
nonce: Nonce,
}
#[derive(Clone, Debug)]
pub enum Session {
Unlocked{
base: BaseCredentials,
session: SessionCredentials,
},
Locked(LockedCredentials),
Empty,
}
#[derive(Debug)] #[derive(Debug)]
pub struct AppState { pub struct AppState {
pub config: RwLock<AppConfig>, pub config: RwLock<AppConfig>,
@ -109,57 +47,11 @@ impl AppState {
} }
} }
pub async fn load_creds(pool: &SqlitePool) -> Result<Session, SetupError> { pub async fn new_creds(&self, base_creds: BaseCredentials, passphrase: &str) -> Result<(), UnlockError> {
let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc") let locked = base_creds.encrypt(passphrase);
.fetch_optional(pool)
.await?;
let row = match res {
Some(r) => r,
None => {return Ok(Session::Empty);}
};
let salt_buf: [u8; 32] = row.salt
.try_into()
.map_err(|_e| SetupError::InvalidRecord)?;
let nonce_buf: [u8; 24] = row.nonce
.try_into()
.map_err(|_e| SetupError::InvalidRecord)?;
let creds = LockedCredentials {
access_key_id: row.access_key_id,
secret_key_enc: row.secret_key_enc,
salt: Salt(salt_buf),
nonce: Nonce(nonce_buf),
};
Ok(Session::Locked(creds))
}
pub async fn save_creds(&self, creds: BaseCredentials, passphrase: &str) -> Result<(), UnlockError> {
let BaseCredentials {access_key_id, secret_access_key} = creds;
// do this first so that if it fails we don't save bad credentials // do this first so that if it fails we don't save bad credentials
self.new_session(&access_key_id, &secret_access_key).await?; self.new_session(base_creds).await?;
locked.save(&self.pool).await?;
let salt = pwhash::gen_salt();
let mut key_buf = [0; secretbox::KEYBYTES];
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &salt).unwrap();
let key = Key(key_buf);
// not sure we need both salt AND nonce given that we generate a
// fresh salt every time we encrypt, but better safe than sorry
let nonce = secretbox::gen_nonce();
let secret_key_enc = secretbox::seal(secret_access_key.as_bytes(), &nonce, &key);
sqlx::query(
"INSERT INTO credentials (access_key_id, secret_key_enc, salt, nonce, created_at)
VALUES (?, ?, ?, ?, strftime('%s'))"
)
.bind(&access_key_id)
.bind(&secret_key_enc)
.bind(&salt.0[0..])
.bind(&nonce.0[0..])
.execute(&self.pool)
.await?;
Ok(()) Ok(())
} }
@ -205,7 +97,10 @@ impl AppState {
} }
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> { pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
self.renew_session_if_expired().await?; if let Approval::Approved = response.approval {
let mut session = self.session.write().await;
session.renew_if_expired().await?;
}
let mut open_requests = self.open_requests.write().await; let mut open_requests = self.open_requests.write().await;
let chan = open_requests let chan = open_requests
@ -223,7 +118,7 @@ impl AppState {
runtime::spawn(async move { runtime::spawn(async move {
sleep(Duration::from_secs(5)).await; sleep(Duration::from_secs(5)).await;
let app = crate::APP.get().unwrap(); let app = APP.get().unwrap();
let state = app.state::<AppState>(); let state = app.state::<AppState>();
let mut bans = state.bans.write().await; let mut bans = state.bans.write().await;
bans.remove(&client); bans.remove(&client);
@ -235,46 +130,25 @@ impl AppState {
} }
pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> { pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> {
let mut session = self.session.write().await; let base_creds = match *self.session.read().await {
let LockedCredentials {
access_key_id,
secret_key_enc,
salt,
nonce
} = match *session {
Session::Empty => {return Err(UnlockError::NoCredentials);}, Session::Empty => {return Err(UnlockError::NoCredentials);},
Session::Unlocked{..} => {return Err(UnlockError::NotLocked);}, Session::Unlocked{..} => {return Err(UnlockError::NotLocked);},
Session::Locked(ref c) => c, Session::Locked(ref locked) => locked.decrypt(passphrase)?,
};
let mut key_buf = [0; secretbox::KEYBYTES];
// pretty sure this only fails if we're out of memory
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), salt).unwrap();
let decrypted = secretbox::open(secret_key_enc, nonce, &Key(key_buf))
.map_err(|_e| UnlockError::BadPassphrase)?;
let secret_access_key = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
let session_creds = self.new_session(access_key_id, &secret_access_key).await?;
*session = Session::Unlocked {
base: BaseCredentials {
access_key_id: access_key_id.clone(),
secret_access_key,
},
session: session_creds
}; };
// Read lock is dropped here, so this doesn't deadlock
self.new_session(base_creds).await?;
Ok(()) Ok(())
} }
// pub async fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> { pub async fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
// let session = self.session.read().await; let session = self.session.read().await;
// match *session { match *session {
// Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()), Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()),
// Session::Locked(_) => Err(GetCredentialsError::Locked), Session::Locked(_) => Err(GetCredentialsError::Locked),
// Session::Empty => Err(GetCredentialsError::Empty), Session::Empty => Err(GetCredentialsError::Empty),
// } }
// } }
pub async fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> { pub async fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
let session = self.session.read().await; let session = self.session.read().await;
@ -285,77 +159,10 @@ impl AppState {
} }
} }
async fn new_session(&self, key_id: &str, secret_key: &str) -> Result<SessionCredentials, GetSessionError> { async fn new_session(&self, base: BaseCredentials) -> Result<(), GetSessionError> {
let creds = aws_sdk_sts::Credentials::new( let session = SessionCredentials::from_base(&base).await?;
key_id, let mut app_session = self.session.write().await;
secret_key, *app_session = Session::Unlocked {base, session};
None, // token Ok(())
None, // expiration
"creddy", // "provider name" apparently
);
let config = aws_config::from_env()
.credentials_provider(creds)
.load()
.await;
let client = aws_sdk_sts::Client::new(&config);
let resp = client.get_session_token()
.duration_seconds(43_200)
.send()
.await?;
let aws_session = resp.credentials().ok_or(GetSessionError::EmptyResponse)?;
let access_key_id = aws_session.access_key_id()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let secret_access_key = aws_session.secret_access_key()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let token = aws_session.session_token()
.ok_or(GetSessionError::EmptyResponse)?
.to_string();
let expiration = aws_session.expiration()
.ok_or(GetSessionError::EmptyResponse)?
.clone();
let session_creds = SessionCredentials {
access_key_id,
secret_access_key,
token,
expiration,
};
#[cfg(debug_assertions)]
println!("Got new session:\n{}", serde_json::to_string(&session_creds).unwrap());
Ok(session_creds)
}
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
match *self.session.write().await {
Session::Unlocked{ref base, ref mut session} => {
if !session.is_expired() {
return Ok(false);
}
let new_session = self.new_session(
&base.access_key_id,
&base.secret_access_key
).await?;
*session = new_session;
Ok(true)
},
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
Session::Empty => Err(GetSessionError::CredentialsEmpty),
}
} }
} }
fn serialize_expiration<S>(exp: &AwsDateTime, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
{
// this only fails if the d/t is out of range, which it can't be for this format
let time_str = exp.fmt(AwsDateTimeFormat::DateTime).unwrap();
serializer.serialize_str(&time_str)
}

View File

@ -15,6 +15,7 @@ invoke('get_config').then(config => $appState.config = config);
listen('credentials-request', (tauriEvent) => { listen('credentials-request', (tauriEvent) => {
$appState.pendingRequests.put(tauriEvent.payload); $appState.pendingRequests.put(tauriEvent.payload);
}); });
window.state = $appState;
acceptRequest(); acceptRequest();
</script> </script>

View File

@ -68,7 +68,7 @@
<!-- Don't render at all if we're just going to immediately proceed to the next screen --> <!-- Don't render at all if we're just going to immediately proceed to the next screen -->
{#if !$appState.currentRequest.approval} {#if error || !$appState.currentRequest.approval}
<div class="flex flex-col space-y-4 p-4 m-auto max-w-xl h-screen items-center justify-center"> <div class="flex flex-col space-y-4 p-4 m-auto max-w-xl h-screen items-center justify-center">
{#if error} {#if error}
<ErrorAlert bind:this={alert}> <ErrorAlert bind:this={alert}>
@ -80,6 +80,18 @@
</ErrorAlert> </ErrorAlert>
{/if} {/if}
{#if $appState.currentRequest.base}
<div class="alert alert-warning shadow-lg">
<div>
<svg xmlns="http://www.w3.org/2000/svg" class="stroke-current flex-shrink-0 h-6 w-6" fill="none" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" /></svg>
<span>
WARNING: This application is requesting your base (long-lived) AWS credentials.
These crednetials are less secure than session credentials, since they don't expire automatically.
</span>
</div>
</div>
{/if}
<div class="space-y-1 mb-4"> <div class="space-y-1 mb-4">
<h2 class="text-xl font-bold">{appName ? `"${appName}"` : 'An appplication'} would like to access your AWS credentials.</h2> <h2 class="text-xl font-bold">{appName ? `"${appName}"` : 'An appplication'} would like to access your AWS credentials.</h2>