440 lines
13 KiB
Rust

use std::fmt::{self, Formatter};
use chacha20poly1305::XNonce;
use serde::{
Deserialize,
Deserializer,
Serialize,
Serializer,
};
use serde::ser::{
Error as SerError,
SerializeStruct,
};
use serde::de::{self, Visitor};
use sqlx::{
FromRow,
Sqlite,
SqlitePool,
Transaction,
types::Uuid,
};
use ssh_agent_lib::proto::message::Identity;
use ssh_key::{
Algorithm,
LineEnding,
private::PrivateKey,
public::PublicKey,
};
use tokio_stream::StreamExt;
use crate::errors::*;
use super::{
Credential,
Crypto,
PersistentCredential,
};
/// Raw database representation of an SSH credential, as read from a row of
/// the `ssh_credentials` table.
///
/// Converted into an [`SshKey`] by `PersistentCredential::from_row`, which
/// parses the textual/binary columns and decrypts the private key.
#[derive(Debug, Clone, FromRow)]
pub struct SshRow {
/// Identifier of the credential this row belongs to.
id: Uuid,
/// Key algorithm identifier in string form (parsed with `Algorithm::new`).
algorithm: String,
/// Free-form comment stored alongside the key.
comment: String,
/// Encoded public key bytes (parsed with `PublicKey::from_bytes`).
public_key: Vec<u8>,
/// Encrypted private key bytes; decrypted with `Crypto::decrypt` using `nonce`.
private_key_enc: Vec<u8>,
/// Nonce that was used when encrypting `private_key_enc`.
nonce: Vec<u8>,
}
/// A decrypted, in-memory SSH key pair together with its algorithm and comment.
///
/// Deserialization goes through the custom `deserialize_*` helpers in this
/// file, which expect the algorithm identifier and the OpenSSH text encodings
/// of the public and private keys as strings.
#[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
pub struct SshKey {
/// Key algorithm, deserialized from its string identifier.
#[serde(deserialize_with = "deserialize_algorithm")]
pub algorithm: Algorithm,
/// Comment attached to the key.
pub comment: String,
/// Public half of the key pair, deserialized from its OpenSSH text encoding.
#[serde(deserialize_with = "deserialize_pubkey")]
pub public_key: PublicKey,
/// Private half of the key pair, deserialized from its OpenSSH text encoding.
#[serde(deserialize_with = "deserialize_privkey")]
pub private_key: PrivateKey,
}
impl SshKey {
pub fn from_file(path: &str, passphrase: &str) -> Result<SshKey, LoadSshKeyError> {
let mut privkey = PrivateKey::read_openssh_file(path.as_ref())?;
if privkey.is_encrypted() {
privkey = privkey.decrypt(passphrase)
.map_err(|_| LoadSshKeyError::InvalidPassphrase)?;
}
Ok(SshKey {
algorithm: privkey.algorithm(),
comment: privkey.comment().into(),
public_key: privkey.public_key().clone(),
private_key: privkey,
})
}
pub fn from_private_key(private_key: &str, passphrase: &str) -> Result<SshKey, LoadSshKeyError> {
let mut privkey = PrivateKey::from_openssh(private_key)?;
if privkey.is_encrypted() {
privkey = privkey.decrypt(passphrase)
.map_err(|_| LoadSshKeyError::InvalidPassphrase)?;
}
Ok(SshKey {
algorithm: privkey.algorithm(),
comment: privkey.comment().into(),
public_key: privkey.public_key().clone(),
private_key: privkey,
})
}
pub async fn name_from_pubkey(pubkey: &[u8], pool: &SqlitePool) -> Result<String, LoadCredentialsError> {
let row = sqlx::query!(
"SELECT c.name
FROM credentials c
JOIN ssh_credentials s
ON s.id = c.id
WHERE s.public_key = ?",
pubkey
).fetch_optional(pool)
.await?
.ok_or(LoadCredentialsError::NoCredentials)?;
Ok(row.name)
}
pub async fn list_identities(pool: &SqlitePool) -> Result<Vec<Identity>, LoadCredentialsError> {
let mut rows = sqlx::query!(
"SELECT public_key, comment FROM ssh_credentials"
).fetch(pool);
let mut identities = Vec::new();
while let Some(row) = rows.try_next().await? {
identities.push(Identity {
pubkey_blob: row.public_key,
comment: row.comment,
});
}
Ok(identities)
}
}
impl PersistentCredential for SshKey {
    type Row = SshRow;

    /// Type tag used for SSH credentials.
    fn type_name() -> &'static str {
        "ssh"
    }

    /// Wraps this key in the [`Credential::Ssh`] variant.
    fn into_credential(self) -> Credential {
        Credential::Ssh(self)
    }

    /// Returns the credential id stored in `row`.
    fn row_id(row: &SshRow) -> Uuid {
        row.id
    }

    /// Rebuilds an [`SshKey`] from a database row, decrypting the private key
    /// with `crypto`.
    ///
    /// # Errors
    ///
    /// Returns [`LoadCredentialsError::InvalidData`] if the stored nonce has
    /// the wrong length, or if the algorithm, public key or decrypted private
    /// key cannot be parsed. Decryption failures are converted into
    /// [`LoadCredentialsError`] via `?`.
    fn from_row(row: SshRow, crypto: &Crypto) -> Result<Self, LoadCredentialsError> {
        // `clone_from_slice` panics when the slice length differs from the
        // nonce size, so a corrupted row must be rejected before calling it.
        if row.nonce.len() != XNonce::default().len() {
            return Err(LoadCredentialsError::InvalidData);
        }
        let nonce = XNonce::clone_from_slice(&row.nonce);
        let privkey_bytes = crypto.decrypt(&nonce, &row.private_key_enc)?;

        let algorithm = Algorithm::new(&row.algorithm)
            .map_err(|_| LoadCredentialsError::InvalidData)?;
        let public_key = PublicKey::from_bytes(&row.public_key)
            .map_err(|_| LoadCredentialsError::InvalidData)?;
        let private_key = PrivateKey::from_bytes(&privkey_bytes)
            .map_err(|_| LoadCredentialsError::InvalidData)?;

        Ok(SshKey {
            algorithm,
            comment: row.comment,
            public_key,
            private_key,
        })
    }

    /// Writes this key into `ssh_credentials` under `id`, replacing any
    /// existing row with that id.
    ///
    /// The private key is encrypted with `crypto` before being stored; the
    /// nonce returned by the encryption is stored next to the ciphertext.
    ///
    /// # Errors
    ///
    /// Encoding, encryption and database failures are converted into
    /// [`SaveCredentialsError`] via `?`.
    async fn save_details(
        &self,
        id: &Uuid,
        crypto: &Crypto,
        txn: &mut Transaction<'_, Sqlite>,
    ) -> Result<(), SaveCredentialsError> {
        let alg = self.algorithm.as_str();
        let pubkey_bytes = self.public_key.to_bytes()?;
        let privkey_bytes = self.private_key.to_bytes()?;
        let (nonce, ciphertext) = crypto.encrypt(privkey_bytes.as_ref())?;
        let nonce_bytes = nonce.as_slice();

        sqlx::query!(
            "INSERT OR REPLACE INTO ssh_credentials (
id,
algorithm,
comment,
public_key,
private_key_enc,
nonce
)
VALUES (?, ?, ?, ?, ?, ?)",
            id, alg, self.comment, pubkey_bytes, ciphertext, nonce_bytes,
        )
        .execute(&mut **txn)
        .await?;

        Ok(())
    }
}
impl Serialize for SshKey {
    /// Serializes the key as a four-field struct: `algorithm`, `comment`,
    /// `public_key` and `private_key`, with both keys in their OpenSSH text
    /// encoding (the private key uses LF line endings).
    ///
    /// # Errors
    ///
    /// Returns a custom serializer error if either key cannot be encoded.
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        // The declared length must equal the number of fields actually
        // written below; length-aware formats rely on it.
        let mut key = s.serialize_struct("SshKey", 4)?;
        key.serialize_field("algorithm", self.algorithm.as_str())?;
        key.serialize_field("comment", &self.comment)?;

        let pubkey_str = self.public_key.to_openssh()
            .map_err(|e| S::Error::custom(format!("Failed to encode SSH public key: {e}")))?;
        key.serialize_field("public_key", &pubkey_str)?;

        let privkey_str = self.private_key.to_openssh(LineEnding::LF)
            .map_err(|e| S::Error::custom(format!("Failed to encode SSH private key: {e}")))?;
        key.serialize_field::<str>("private_key", privkey_str.as_ref())?;

        key.end()
    }
}
/// Serde visitor that turns an OpenSSH-encoded string into a [`PublicKey`].
struct PubkeyVisitor;

impl<'de> Visitor<'de> for PubkeyVisitor {
    type Value = PublicKey;

    /// Describes the accepted input for serde's error messages.
    fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
        formatter.write_str("an OpenSSH-encoded public key, e.g. `ssh-rsa ...`")
    }

    /// Parses `v` as an OpenSSH public key; parse errors are forwarded as
    /// custom deserialization errors carrying the parser's message.
    fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
        match PublicKey::from_openssh(v) {
            Ok(parsed) => Ok(parsed),
            Err(err) => Err(E::custom(err.to_string())),
        }
    }
}
fn deserialize_pubkey<'de, D>(deserializer: D) -> Result<PublicKey, D::Error>
where D: Deserializer<'de>
{
deserializer.deserialize_str(PubkeyVisitor)
}
/// Serde visitor that turns an OpenSSH-encoded string into a [`PrivateKey`].
struct PrivkeyVisitor;

impl<'de> Visitor<'de> for PrivkeyVisitor {
    type Value = PrivateKey;

    /// Describes the accepted input for serde's error messages.
    fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
        formatter.write_str("an OpenSSH-encoded private key")
    }

    /// Parses `v` as an OpenSSH private key; parse errors are forwarded as
    /// custom deserialization errors carrying the parser's message.
    fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
        match PrivateKey::from_openssh(v) {
            Ok(parsed) => Ok(parsed),
            Err(err) => Err(E::custom(err.to_string())),
        }
    }
}
fn deserialize_privkey<'de, D>(deserializer: D) -> Result<PrivateKey, D::Error>
where D: Deserializer<'de>
{
deserializer.deserialize_str(PrivkeyVisitor)
}
/// Serde visitor that turns an algorithm identifier string into an
/// [`Algorithm`].
struct AlgorithmVisitor;

impl<'de> Visitor<'de> for AlgorithmVisitor {
    type Value = Algorithm;

    /// Describes the accepted input for serde's error messages.
    fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
        formatter.write_str("an SSH key algorithm identifier, e.g. `ssh-rsa`")
    }

    /// Parses `v` as an algorithm identifier; parse errors are forwarded as
    /// custom deserialization errors carrying the parser's message.
    fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
        match Algorithm::new(v) {
            Ok(parsed) => Ok(parsed),
            Err(err) => Err(E::custom(err.to_string())),
        }
    }
}
fn deserialize_algorithm<'de, D>(deserializer: D) -> Result<Algorithm, D::Error>
where D: Deserializer<'de>
{
deserializer.deserialize_str(AlgorithmVisitor)
}
#[cfg(test)]
mod tests {
    use std::fs::{self, File};
    use ssh_key::Fingerprint;
    use sqlx::types::uuid::uuid;
    use super::*;

    /// Builds the path of a fixture file, relative to the crate root.
    fn path(name: &str) -> String {
        format!("./src/credentials/fixtures/{name}")
    }

    /// Builds a UUID from the first 16 bytes produced by `Crypto::salt()`.
    fn random_uuid() -> Uuid {
        let bytes = Crypto::salt();
        Uuid::from_slice(&bytes[..16]).unwrap()
    }

    /// Loads the unencrypted RSA fixture key.
    fn rsa_plain() -> SshKey {
        SshKey::from_file(&path("ssh_rsa_plain"), "")
            .expect("Failed to load SSH key")
    }

    /// Loads the passphrase-protected RSA fixture key.
    fn rsa_enc() -> SshKey {
        SshKey::from_file(
            &path("ssh_rsa_enc"),
            "correct horse battery staple"
        ).expect("Failed to load SSH key")
    }

    /// Loads the unencrypted Ed25519 fixture key.
    fn ed25519_plain() -> SshKey {
        SshKey::from_file(&path("ssh_ed25519_plain"), "")
            .expect("Failed to load SSH key")
    }

    /// Loads the passphrase-protected Ed25519 fixture key.
    fn ed25519_enc() -> SshKey {
        SshKey::from_file(
            &path("ssh_ed25519_enc"),
            "correct horse battery staple"
        ).expect("Failed to load SSH key")
    }

    /// Checks the properties every fixture key is expected to have: the given
    /// algorithm identifier, the comment `hello world`, matching public and
    /// private fingerprints, and the given fingerprint bytes (computed with
    /// the default hash algorithm).
    ///
    /// `#[track_caller]` makes assertion failures point at the calling test.
    #[track_caller]
    fn assert_fixture_key(k: &SshKey, algorithm: &str, fingerprint: &[u8]) {
        assert_eq!(k.algorithm.as_str(), algorithm);
        assert_eq!(&k.comment, "hello world");
        assert_eq!(
            k.public_key.fingerprint(Default::default()),
            k.private_key.fingerprint(Default::default()),
        );
        assert_eq!(
            k.private_key.fingerprint(Default::default()).as_bytes(),
            fingerprint,
        );
    }

    #[test]
    fn test_from_file_rsa_plain() {
        assert_fixture_key(
            &rsa_plain(),
            "ssh-rsa",
            &[90,162,92,235,160,164,88,179,144,234,84,135,1,249,9,206,
              201,172,233,129,82,11,145,191,186,144,209,43,81,119,197,18],
        );
    }

    #[test]
    fn test_from_file_rsa_enc() {
        assert_fixture_key(
            &rsa_enc(),
            "ssh-rsa",
            &[254,147,219,185,96,234,125,190,195,128,37,243,214,193,8,162,
              34,237,126,199,241,91,195,251,232,84,144,120,25,63,224,157],
        );
    }

    #[test]
    fn test_from_file_ed25519_plain() {
        assert_fixture_key(
            &ed25519_plain(),
            "ssh-ed25519",
            &[29,30,193,72,239,167,35,89,1,206,126,186,123,112,78,187,
              240,59,1,15,107,189,72,30,44,64,114,216,32,195,22,201],
        );
    }

    #[test]
    fn test_from_file_ed25519_enc() {
        assert_fixture_key(
            &ed25519_enc(),
            "ssh-ed25519",
            &[87,233,161,170,18,47,245,116,30,177,120,211,248,54,65,255,
              41,45,113,107,182,221,189,167,110,9,245,254,44,6,118,141],
        );
    }

    #[test]
    fn test_serialize() {
        let expected = fs::read_to_string(path("ssh_ed25519_plain.json")).unwrap();
        let k = ed25519_plain();
        let computed = serde_json::to_string(&k)
            .expect("Failed to serialize SshKey");
        assert_eq!(expected, computed);
    }

    #[test]
    fn test_deserialize() {
        let expected = ed25519_plain();
        let json_file = File::open(path("ssh_ed25519_plain.json")).unwrap();
        let computed = serde_json::from_reader(json_file)
            .expect("Failed to deserialize json file");
        assert_eq!(expected, computed);
    }

    #[sqlx::test]
    async fn test_save_db(pool: SqlitePool) {
        let crypto = Crypto::random();
        let k = rsa_plain();
        let mut txn = pool.begin().await.unwrap();
        k.save_details(&random_uuid(), &crypto, &mut txn).await
            .expect("Failed to save SSH key to database");
        txn.commit().await.expect("Failed to finalize transaction");
    }

    #[sqlx::test(fixtures("ssh_credentials"))]
    async fn test_load_db(pool: SqlitePool) {
        let crypto = Crypto::fixed();
        let id = uuid!("11111111-1111-1111-1111-111111111111");
        // Only success of the load is checked here, so the loaded key is
        // deliberately discarded rather than bound to an unused variable.
        SshKey::load(&id, &crypto, &pool).await
            .expect("Failed to load SSH key from database");
    }

    #[sqlx::test]
    async fn test_save_load_db(pool: SqlitePool) {
        let crypto = Crypto::random();
        let id = uuid!("7bc994dd-113a-4841-bcf7-b47c2fffdd25");
        let known = ed25519_plain();

        let mut txn = pool.begin().await.unwrap();
        known.save_details(&id, &crypto, &mut txn).await.unwrap();
        txn.commit().await.unwrap();

        let loaded = SshKey::load(&id, &crypto, &pool).await.unwrap();
        assert_eq!(known.algorithm, loaded.algorithm);
        assert_eq!(known.comment, loaded.comment);
        // The public key's comment is stripped when it is saved as bytes, so
        // only the raw key data is compared.
        assert_eq!(known.public_key.key_data(), loaded.public_key.key_data());
        assert_eq!(known.private_key, loaded.private_key);
    }
}