use std::fmt::{self, Formatter}; use chacha20poly1305::XNonce; use serde::{ Deserialize, Deserializer, Serialize, Serializer, }; use serde::ser::{ Error as SerError, SerializeStruct, }; use serde::de::{self, Visitor}; use sqlx::{ FromRow, Sqlite, SqlitePool, Transaction, types::Uuid, }; use ssh_agent_lib::proto::message::Identity; use ssh_key::{ Algorithm, LineEnding, private::PrivateKey, public::PublicKey, }; use tokio_stream::StreamExt; use crate::errors::*; use super::{ Credential, Crypto, PersistentCredential, }; #[derive(Debug, Clone, FromRow)] pub struct SshRow { id: Uuid, algorithm: String, comment: String, public_key: Vec, private_key_enc: Vec, nonce: Vec, } #[derive(Debug, Clone, Eq, PartialEq, Deserialize)] pub struct SshKey { #[serde(deserialize_with = "deserialize_algorithm")] pub algorithm: Algorithm, pub comment: String, #[serde(deserialize_with = "deserialize_pubkey")] pub public_key: PublicKey, #[serde(deserialize_with = "deserialize_privkey")] pub private_key: PrivateKey, } impl SshKey { pub fn from_file(path: &str, passphrase: &str) -> Result { let mut privkey = PrivateKey::read_openssh_file(path.as_ref())?; if privkey.is_encrypted() { privkey = privkey.decrypt(passphrase) .map_err(|_| LoadSshKeyError::InvalidPassphrase)?; } Ok(SshKey { algorithm: privkey.algorithm(), comment: privkey.comment().into(), public_key: privkey.public_key().clone(), private_key: privkey, }) } pub fn from_private_key(private_key: &str, passphrase: &str) -> Result { let mut privkey = PrivateKey::from_openssh(private_key)?; if privkey.is_encrypted() { privkey = privkey.decrypt(passphrase) .map_err(|_| LoadSshKeyError::InvalidPassphrase)?; } Ok(SshKey { algorithm: privkey.algorithm(), comment: privkey.comment().into(), public_key: privkey.public_key().clone(), private_key: privkey, }) } pub async fn name_from_pubkey(pubkey: &[u8], pool: &SqlitePool) -> Result { let row = sqlx::query!( "SELECT c.name FROM credentials c JOIN ssh_credentials s ON s.id = c.id WHERE s.public_key = ?", pubkey ).fetch_optional(pool) .await? .ok_or(LoadCredentialsError::NoCredentials)?; Ok(row.name) } pub async fn list_identities(pool: &SqlitePool) -> Result, LoadCredentialsError> { let mut rows = sqlx::query!( "SELECT public_key, comment FROM ssh_credentials" ).fetch(pool); let mut identities = Vec::new(); while let Some(row) = rows.try_next().await? { identities.push(Identity { pubkey_blob: row.public_key, comment: row.comment, }); } Ok(identities) } } impl PersistentCredential for SshKey { type Row = SshRow; fn type_name() -> &'static str { "ssh" } fn into_credential(self) -> Credential { Credential::Ssh(self) } fn row_id(row: &SshRow) -> Uuid { row.id } fn from_row(row: SshRow, crypto: &Crypto) -> Result { let nonce = XNonce::clone_from_slice(&row.nonce); let privkey_bytes = crypto.decrypt(&nonce, &row.private_key_enc)?; let algorithm = Algorithm::new(&row.algorithm) .map_err(|_| LoadCredentialsError::InvalidData)?; let public_key = PublicKey::from_bytes(&row.public_key) .map_err(|_| LoadCredentialsError::InvalidData)?; let private_key = PrivateKey::from_bytes(&privkey_bytes) .map_err(|_| LoadCredentialsError::InvalidData)?; Ok(SshKey { algorithm, comment: row.comment, public_key, private_key, }) } async fn save_details(&self, id: &Uuid, crypto: &Crypto, txn: &mut Transaction<'_, Sqlite>) -> Result<(), SaveCredentialsError> { let alg = self.algorithm.as_str(); let pubkey_bytes = self.public_key.to_bytes()?; let privkey_bytes = self.private_key.to_bytes()?; let (nonce, ciphertext) = crypto.encrypt(privkey_bytes.as_ref())?; let nonce_bytes = nonce.as_slice(); sqlx::query!( "INSERT OR REPLACE INTO ssh_credentials ( id, algorithm, comment, public_key, private_key_enc, nonce ) VALUES (?, ?, ?, ?, ?, ?)", id, alg, self.comment, pubkey_bytes, ciphertext, nonce_bytes, ).execute(&mut **txn).await?; Ok(()) } } impl Serialize for SshKey { fn serialize(&self, s: S) -> Result { let mut key = s.serialize_struct("SshKey", 5)?; key.serialize_field("algorithm", self.algorithm.as_str())?; key.serialize_field("comment", &self.comment)?; let pubkey_str = self.public_key.to_openssh() .map_err(|e| S::Error::custom(format!("Failed to encode SSH public key: {e}")))?; key.serialize_field("public_key", &pubkey_str)?; let privkey_str = self.private_key.to_openssh(LineEnding::LF) .map_err(|e| S::Error::custom(format!("Failed to encode SSH private key: {e}")))?; key.serialize_field::("private_key", privkey_str.as_ref())?; key.end() } } struct PubkeyVisitor; impl<'de> Visitor<'de> for PubkeyVisitor { type Value = PublicKey; fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { write!(formatter, "an OpenSSH-encoded public key, e.g. `ssh-rsa ...`") } fn visit_str(self, v: &str) -> Result { PublicKey::from_openssh(v) .map_err(|e| E::custom(format!("{e}"))) } } fn deserialize_pubkey<'de, D>(deserializer: D) -> Result where D: Deserializer<'de> { deserializer.deserialize_str(PubkeyVisitor) } struct PrivkeyVisitor; impl<'de> Visitor<'de> for PrivkeyVisitor { type Value = PrivateKey; fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { write!(formatter, "an OpenSSH-encoded private key") } fn visit_str(self, v: &str) -> Result { PrivateKey::from_openssh(v) .map_err(|e| E::custom(format!("{e}"))) } } fn deserialize_privkey<'de, D>(deserializer: D) -> Result where D: Deserializer<'de> { deserializer.deserialize_str(PrivkeyVisitor) } struct AlgorithmVisitor; impl<'de> Visitor<'de> for AlgorithmVisitor { type Value = Algorithm; fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { write!(formatter, "an SSH key algorithm identifier, e.g. `ssh-rsa`") } fn visit_str(self, v: &str) -> Result { Algorithm::new(v) .map_err(|e| E::custom(format!("{e}"))) } } fn deserialize_algorithm<'de, D>(deserializer: D) -> Result where D: Deserializer<'de> { deserializer.deserialize_str(AlgorithmVisitor) } #[cfg(test)] mod tests { use std::fs::{self, File}; use ssh_key::Fingerprint; use sqlx::types::uuid::uuid; use super::*; fn path(name: &str) -> String { format!("./src/credentials/fixtures/{name}") } fn random_uuid() -> Uuid { let bytes = Crypto::salt(); Uuid::from_slice(&bytes[..16]).unwrap() } fn rsa_plain() -> SshKey { SshKey::from_file(&path("ssh_rsa_plain"), "") .expect("Failed to load SSH key") } fn rsa_enc() -> SshKey { SshKey::from_file( &path("ssh_rsa_enc"), "correct horse battery staple" ).expect("Failed to load SSH key") } fn ed25519_plain() -> SshKey { SshKey::from_file(&path("ssh_ed25519_plain"), "") .expect("Failed to load SSH key") } fn ed25519_enc() -> SshKey { SshKey::from_file( &path("ssh_ed25519_enc"), "correct horse battery staple" ).expect("Failed to load SSH key") } #[test] fn test_from_file_rsa_plain() { let k = rsa_plain(); assert_eq!(k.algorithm.as_str(), "ssh-rsa"); assert_eq!(&k.comment, "hello world"); assert_eq!( k.public_key.fingerprint(Default::default()), k.private_key.fingerprint(Default::default()), ); assert_eq!( k.private_key.fingerprint(Default::default()).as_bytes(), [90,162,92,235,160,164,88,179,144,234,84,135,1,249,9,206, 201,172,233,129,82,11,145,191,186,144,209,43,81,119,197,18], ); } #[test] fn test_from_file_rsa_enc() { let k = rsa_enc(); assert_eq!(k.algorithm.as_str(), "ssh-rsa"); assert_eq!(&k.comment, "hello world"); assert_eq!( k.public_key.fingerprint(Default::default()), k.private_key.fingerprint(Default::default()), ); assert_eq!( k.private_key.fingerprint(Default::default()).as_bytes(), [254,147,219,185,96,234,125,190,195,128,37,243,214,193,8,162, 34,237,126,199,241,91,195,251,232,84,144,120,25,63,224,157], ); } #[test] fn test_from_file_ed25519_plain() { let k = ed25519_plain(); assert_eq!(k.algorithm.as_str(),"ssh-ed25519"); assert_eq!(&k.comment, "hello world"); assert_eq!( k.public_key.fingerprint(Default::default()), k.private_key.fingerprint(Default::default()), ); assert_eq!( k.private_key.fingerprint(Default::default()).as_bytes(), [29,30,193,72,239,167,35,89,1,206,126,186,123,112,78,187, 240,59,1,15,107,189,72,30,44,64,114,216,32,195,22,201], ); } #[test] fn test_from_file_ed25519_enc() { let k = ed25519_enc(); assert_eq!(k.algorithm.as_str(), "ssh-ed25519"); assert_eq!(&k.comment, "hello world"); assert_eq!( k.public_key.fingerprint(Default::default()), k.private_key.fingerprint(Default::default()), ); assert_eq!( k.private_key.fingerprint(Default::default()).as_bytes(), [87,233,161,170,18,47,245,116,30,177,120,211,248,54,65,255, 41,45,113,107,182,221,189,167,110,9,245,254,44,6,118,141], ); } #[test] fn test_serialize() { let expected = fs::read_to_string(path("ssh_ed25519_plain.json")).unwrap(); let k = ed25519_plain(); let computed = serde_json::to_string(&k) .expect("Failed to serialize SshKey"); assert_eq!(expected, computed); } #[test] fn test_deserialize() { let expected = ed25519_plain(); let json_file = File::open(path("ssh_ed25519_plain.json")).unwrap(); let computed = serde_json::from_reader(json_file) .expect("Failed to deserialize json file"); assert_eq!(expected, computed); } #[sqlx::test] async fn test_save_db(pool: SqlitePool) { let crypto = Crypto::random(); let k = rsa_plain(); let mut txn = pool.begin().await.unwrap(); k.save_details(&random_uuid(), &crypto, &mut txn).await .expect("Failed to save SSH key to database"); txn.commit().await.expect("Failed to finalize transaction"); } #[sqlx::test(fixtures("ssh_credentials"))] async fn test_load_db(pool: SqlitePool) { let crypto = Crypto::fixed(); let id = uuid!("11111111-1111-1111-1111-111111111111"); let k = SshKey::load(&id, &crypto, &pool).await .expect("Failed to load SSH key from database"); } #[sqlx::test] async fn test_save_load_db(pool: SqlitePool) { let crypto = Crypto::random(); let id = uuid!("7bc994dd-113a-4841-bcf7-b47c2fffdd25"); let known = ed25519_plain(); let mut txn = pool.begin().await.unwrap(); known.save_details(&id, &crypto, &mut txn).await.unwrap(); txn.commit().await.unwrap(); let loaded = SshKey::load(&id, &crypto, &pool).await.unwrap(); assert_eq!(known.algorithm, loaded.algorithm); assert_eq!(known.comment, loaded.comment); // comment gets stripped by saving as bytes, so we just compare raw key data assert_eq!(known.public_key.key_data(), loaded.public_key.key_data()); assert_eq!(known.private_key, loaded.private_key); } }