switch to tokio RwLock instead of std
This commit is contained in:
parent
96bbc2dbc2
commit
ddf865d0b4
@ -1,9 +1,13 @@
|
|||||||
use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo};
|
use netstat2::{AddressFamilyFlags, ProtocolFlags, ProtocolSocketInfo};
|
||||||
|
use tauri::Manager;
|
||||||
use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt};
|
use sysinfo::{System, SystemExt, Pid, PidExt, ProcessExt};
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::errors::*;
|
use crate::{
|
||||||
use crate::get_state;
|
errors::*,
|
||||||
|
config::AppConfig,
|
||||||
|
state::AppState,
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||||
@ -13,13 +17,18 @@ pub struct Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
|
async fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Error> {
|
||||||
|
let state = crate::APP.get().unwrap().state::<AppState>();
|
||||||
|
let AppConfig {
|
||||||
|
listen_addr: app_listen_addr,
|
||||||
|
listen_port: app_listen_port,
|
||||||
|
..
|
||||||
|
} = *state.config.read().await;
|
||||||
|
|
||||||
let sockets_iter = netstat2::iterate_sockets_info(
|
let sockets_iter = netstat2::iterate_sockets_info(
|
||||||
AddressFamilyFlags::IPV4,
|
AddressFamilyFlags::IPV4,
|
||||||
ProtocolFlags::TCP
|
ProtocolFlags::TCP
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
get_state!(config as app_config);
|
|
||||||
for item in sockets_iter {
|
for item in sockets_iter {
|
||||||
let sock_info = item?;
|
let sock_info = item?;
|
||||||
let proto_info = match sock_info.protocol_socket_info {
|
let proto_info = match sock_info.protocol_socket_info {
|
||||||
@ -28,9 +37,9 @@ fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Err
|
|||||||
};
|
};
|
||||||
|
|
||||||
if proto_info.local_port == local_port
|
if proto_info.local_port == local_port
|
||||||
&& proto_info.remote_port == app_config.listen_port
|
&& proto_info.remote_port == app_listen_port
|
||||||
&& proto_info.local_addr == app_config.listen_addr
|
&& proto_info.local_addr == app_listen_addr
|
||||||
&& proto_info.remote_addr == app_config.listen_addr
|
&& proto_info.remote_addr == app_listen_addr
|
||||||
{
|
{
|
||||||
return Ok(sock_info.associated_pids)
|
return Ok(sock_info.associated_pids)
|
||||||
}
|
}
|
||||||
@ -40,10 +49,10 @@ fn get_associated_pids(local_port: u16) -> Result<Vec<u32>, netstat2::error::Err
|
|||||||
|
|
||||||
|
|
||||||
// Theoretically, on some systems, multiple processes can share a socket
|
// Theoretically, on some systems, multiple processes can share a socket
|
||||||
pub fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> {
|
pub async fn get_clients(local_port: u16) -> Result<Vec<Option<Client>>, ClientInfoError> {
|
||||||
let mut clients = Vec::new();
|
let mut clients = Vec::new();
|
||||||
let mut sys = System::new();
|
let mut sys = System::new();
|
||||||
for p in get_associated_pids(local_port)? {
|
for p in get_associated_pids(local_port).await? {
|
||||||
let pid = Pid::from_u32(p);
|
let pid = Pid::from_u32(p);
|
||||||
sys.refresh_process(pid);
|
sys.refresh_process(pid);
|
||||||
let proc = sys.process(pid)
|
let proc = sys.process(pid)
|
||||||
|
@ -41,13 +41,14 @@ pub async fn unlock(passphrase: String, app_state: State<'_, AppState>) -> Resul
|
|||||||
|
|
||||||
|
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn get_session_status(app_state: State<'_, AppState>) -> String {
|
pub async fn get_session_status(app_state: State<'_, AppState>) -> Result<String, ()> {
|
||||||
let session = app_state.session.read().unwrap();
|
let session = app_state.session.read().await;
|
||||||
match *session {
|
let status = match *session {
|
||||||
Session::Locked(_) => "locked".into(),
|
Session::Locked(_) => "locked".into(),
|
||||||
Session::Unlocked{..} => "unlocked".into(),
|
Session::Unlocked{..} => "unlocked".into(),
|
||||||
Session::Empty => "empty".into()
|
Session::Empty => "empty".into()
|
||||||
}
|
};
|
||||||
|
Ok(status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -62,9 +63,9 @@ pub async fn save_credentials(
|
|||||||
|
|
||||||
|
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn get_config(app_state: State<'_, AppState>) -> AppConfig {
|
pub async fn get_config(app_state: State<'_, AppState>) -> Result<AppConfig, ()> {
|
||||||
let config = app_state.config.read().unwrap();
|
let config = app_state.config.read().await;
|
||||||
config.clone()
|
Ok(config.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -97,37 +97,3 @@ fn run() -> tauri::Result<()> {
|
|||||||
fn main() {
|
fn main() {
|
||||||
run().error_popup("Creddy failed to start");
|
run().error_popup("Creddy failed to start");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
macro_rules! get_state {
|
|
||||||
($prop:ident as $name:ident) => {
|
|
||||||
use tauri::Manager;
|
|
||||||
let app = crate::APP.get().unwrap(); // as long as the app is running, this is fine
|
|
||||||
let state = app.state::<crate::state::AppState>();
|
|
||||||
let $name = state.$prop.read().unwrap(); // only panics if another thread has already panicked
|
|
||||||
};
|
|
||||||
(config.$prop:ident as $name:ident) => {
|
|
||||||
use tauri::Manager;
|
|
||||||
let app = crate::APP.get().unwrap();
|
|
||||||
let state = app.state::<crate::state::AppState>();
|
|
||||||
let config = state.config.read().unwrap();
|
|
||||||
let $name = config.$prop;
|
|
||||||
};
|
|
||||||
|
|
||||||
(mut $prop:ident as $name:ident) => {
|
|
||||||
use tauri::Manager;
|
|
||||||
let app = crate::APP.get().unwrap();
|
|
||||||
let state = app.state::<crate::state::AppState>();
|
|
||||||
let $name = state.$prop.write().unwrap();
|
|
||||||
};
|
|
||||||
(mut config.$prop:ident as $name:ident) => {
|
|
||||||
use tauri::Manager;
|
|
||||||
let app = crate::APP.get().unwrap();
|
|
||||||
let state = app.state::<crate::state::AppState>();
|
|
||||||
let config = state.config.write().unwrap();
|
|
||||||
let $name = config.$prop;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
pub(crate) use get_state;
|
|
||||||
|
@ -4,7 +4,6 @@ use std::net::{
|
|||||||
Ipv4Addr,
|
Ipv4Addr,
|
||||||
SocketAddr,
|
SocketAddr,
|
||||||
SocketAddrV4,
|
SocketAddrV4,
|
||||||
TcpListener as StdTcpListener,
|
|
||||||
};
|
};
|
||||||
use tokio::net::{
|
use tokio::net::{
|
||||||
TcpListener,
|
TcpListener,
|
||||||
@ -32,10 +31,10 @@ struct Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Handler {
|
impl Handler {
|
||||||
fn new(stream: TcpStream, app: AppHandle) -> Self {
|
async fn new(stream: TcpStream, app: AppHandle) -> Self {
|
||||||
let state = app.state::<AppState>();
|
let state = app.state::<AppState>();
|
||||||
let (chan_send, chan_recv) = oneshot::channel();
|
let (chan_send, chan_recv) = oneshot::channel();
|
||||||
let request_id = state.register_request(chan_send);
|
let request_id = state.register_request(chan_send).await;
|
||||||
Handler {
|
Handler {
|
||||||
request_id,
|
request_id,
|
||||||
stream,
|
stream,
|
||||||
@ -49,13 +48,13 @@ impl Handler {
|
|||||||
eprintln!("{e}");
|
eprintln!("{e}");
|
||||||
}
|
}
|
||||||
let state = self.app.state::<AppState>();
|
let state = self.app.state::<AppState>();
|
||||||
state.unregister_request(self.request_id);
|
state.unregister_request(self.request_id).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn try_handle(&mut self) -> Result<(), RequestError> {
|
async fn try_handle(&mut self) -> Result<(), RequestError> {
|
||||||
let _ = self.recv_request().await?;
|
let _ = self.recv_request().await?;
|
||||||
let clients = self.get_clients()?;
|
let clients = self.get_clients().await?;
|
||||||
if self.includes_banned(&clients) {
|
if self.includes_banned(&clients).await {
|
||||||
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
|
self.stream.write(b"HTTP/1.0 403 Access Denied\r\n\r\n").await?;
|
||||||
return Ok(())
|
return Ok(())
|
||||||
}
|
}
|
||||||
@ -69,7 +68,7 @@ impl Handler {
|
|||||||
Approval::Denied => {
|
Approval::Denied => {
|
||||||
let state = self.app.state::<AppState>();
|
let state = self.app.state::<AppState>();
|
||||||
for client in req.clients {
|
for client in req.clients {
|
||||||
state.add_ban(client, self.app.clone());
|
state.add_ban(client).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -78,12 +77,12 @@ impl Handler {
|
|||||||
// and b) there are no other pending requests
|
// and b) there are no other pending requests
|
||||||
let state = self.app.state::<AppState>();
|
let state = self.app.state::<AppState>();
|
||||||
let delay = {
|
let delay = {
|
||||||
let config = state.config.read().unwrap();
|
let config = state.config.read().await;
|
||||||
Duration::from_millis(config.rehide_ms)
|
Duration::from_millis(config.rehide_ms)
|
||||||
};
|
};
|
||||||
sleep(delay).await;
|
sleep(delay).await;
|
||||||
|
|
||||||
if !starting_visibility && state.req_count() == 0 {
|
if !starting_visibility && state.req_count().await == 0 {
|
||||||
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?;
|
let window = self.app.get_window("main").ok_or(RequestError::NoMainWindow)?;
|
||||||
window.hide()?;
|
window.hide()?;
|
||||||
}
|
}
|
||||||
@ -107,18 +106,23 @@ impl Handler {
|
|||||||
Ok(buf)
|
Ok(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> {
|
async fn get_clients(&self) -> Result<Vec<Option<Client>>, RequestError> {
|
||||||
let peer_addr = match self.stream.peer_addr()? {
|
let peer_addr = match self.stream.peer_addr()? {
|
||||||
SocketAddr::V4(addr) => addr,
|
SocketAddr::V4(addr) => addr,
|
||||||
_ => unreachable!(), // we only listen on IPv4
|
_ => unreachable!(), // we only listen on IPv4
|
||||||
};
|
};
|
||||||
let clients = clientinfo::get_clients(peer_addr.port())?;
|
let clients = clientinfo::get_clients(peer_addr.port()).await?;
|
||||||
Ok(clients)
|
Ok(clients)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
|
async fn includes_banned(&self, clients: &Vec<Option<Client>>) -> bool {
|
||||||
let state = self.app.state::<AppState>();
|
let state = self.app.state::<AppState>();
|
||||||
clients.iter().any(|c| state.is_banned(c))
|
for client in clients {
|
||||||
|
if state.is_banned(client).await {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
fn show_window(&self) -> Result<bool, RequestError> {
|
fn show_window(&self) -> Result<bool, RequestError> {
|
||||||
@ -157,7 +161,7 @@ impl Handler {
|
|||||||
|
|
||||||
async fn send_credentials(&mut self) -> Result<(), RequestError> {
|
async fn send_credentials(&mut self) -> Result<(), RequestError> {
|
||||||
let state = self.app.state::<AppState>();
|
let state = self.app.state::<AppState>();
|
||||||
let creds = state.serialize_session_creds()?;
|
let creds = state.serialize_session_creds().await?;
|
||||||
|
|
||||||
self.stream.write(b"\r\nContent-Length: ").await?;
|
self.stream.write(b"\r\nContent-Length: ").await?;
|
||||||
self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?;
|
self.stream.write(creds.as_bytes().len().to_string().as_bytes()).await?;
|
||||||
@ -180,39 +184,39 @@ pub struct Server {
|
|||||||
|
|
||||||
impl Server {
|
impl Server {
|
||||||
pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> {
|
pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> {
|
||||||
// construct the listener before passing it to the task so that we know if it fails
|
let task = Self::start_server(addr, port, app_handle.app_handle()).await?;
|
||||||
let sock_addr = SocketAddrV4::new(addr, port);
|
|
||||||
let listener = TcpListener::bind(&sock_addr).await?;
|
|
||||||
let task = rt::spawn(Self::serve(listener, app_handle.app_handle()));
|
|
||||||
|
|
||||||
Ok(Server { addr, port, app_handle, task})
|
Ok(Server { addr, port, app_handle, task})
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is blocking because it's too much of a paint to juggle mutexes otherwise
|
pub async fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
|
||||||
pub fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
|
|
||||||
if addr == self.addr && port == self.port {
|
if addr == self.addr && port == self.port {
|
||||||
return Ok(())
|
return Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
let sock_addr = SocketAddrV4::new(addr, port);
|
let new_task = Self::start_server(addr, port, self.app_handle.app_handle()).await?;
|
||||||
let std_listener = StdTcpListener::bind(&sock_addr)?;
|
|
||||||
std_listener.set_nonblocking(true)?;
|
|
||||||
let async_listener = TcpListener::from_std(std_listener)?;
|
|
||||||
let new_task = rt::spawn(Self::serve(async_listener, self.app_handle.app_handle()));
|
|
||||||
self.task.abort();
|
self.task.abort();
|
||||||
|
|
||||||
self.addr = addr;
|
self.addr = addr;
|
||||||
self.port = port;
|
self.port = port;
|
||||||
self.task = new_task;
|
self.task = new_task;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// construct the listener before spawning the task so that we can return early if it fails
|
||||||
|
async fn start_server(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<JoinHandle<()>> {
|
||||||
|
let sock_addr = SocketAddrV4::new(addr, port);
|
||||||
|
let listener = TcpListener::bind(&sock_addr).await?;
|
||||||
|
let task = rt::spawn(
|
||||||
|
Self::serve(listener, app_handle.app_handle())
|
||||||
|
);
|
||||||
|
Ok(task)
|
||||||
|
}
|
||||||
|
|
||||||
async fn serve(listener: TcpListener, app_handle: AppHandle) {
|
async fn serve(listener: TcpListener, app_handle: AppHandle) {
|
||||||
loop {
|
loop {
|
||||||
match listener.accept().await {
|
match listener.accept().await {
|
||||||
Ok((stream, _)) => {
|
Ok((stream, _)) => {
|
||||||
let handler = Handler::new(stream, app_handle.app_handle());
|
let handler = Handler::new(stream, app_handle.app_handle()).await;
|
||||||
rt::spawn(handler.handle());
|
rt::spawn(handler.handle());
|
||||||
},
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::sync::RwLock;
|
|
||||||
use std::time::{
|
use std::time::{
|
||||||
Duration,
|
Duration,
|
||||||
SystemTime,
|
SystemTime,
|
||||||
@ -12,8 +11,11 @@ use aws_smithy_types::date_time::{
|
|||||||
Format as AwsDateTimeFormat,
|
Format as AwsDateTimeFormat,
|
||||||
};
|
};
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
use tokio::sync::oneshot::Sender;
|
use tokio::{
|
||||||
use tokio::time::sleep;
|
sync::oneshot::Sender,
|
||||||
|
sync::RwLock,
|
||||||
|
time::sleep,
|
||||||
|
};
|
||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use sodiumoxide::crypto::{
|
use sodiumoxide::crypto::{
|
||||||
pwhash,
|
pwhash,
|
||||||
@ -163,51 +165,49 @@ impl AppState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> {
|
pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> {
|
||||||
|
let mut live_config = self.config.write().await;
|
||||||
|
|
||||||
|
if new_config.start_on_login != live_config.start_on_login {
|
||||||
|
config::set_auto_launch(new_config.start_on_login)?;
|
||||||
|
}
|
||||||
|
if new_config.listen_addr != live_config.listen_addr
|
||||||
|
|| new_config.listen_port != live_config.listen_port
|
||||||
{
|
{
|
||||||
let orig_config = self.config.read().unwrap();
|
let mut sv = self.server.write().await;
|
||||||
if new_config.start_on_login != orig_config.start_on_login {
|
sv.rebind(new_config.listen_addr, new_config.listen_port).await?;
|
||||||
config::set_auto_launch(new_config.start_on_login)?;
|
|
||||||
}
|
|
||||||
if new_config.listen_addr != orig_config.listen_addr
|
|
||||||
|| new_config.listen_port != orig_config.listen_port
|
|
||||||
{
|
|
||||||
let mut sv = self.server.write().unwrap();
|
|
||||||
sv.rebind(new_config.listen_addr, new_config.listen_port)?;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
new_config.save(&self.pool).await?;
|
new_config.save(&self.pool).await?;
|
||||||
let mut live_config = self.config.write().unwrap();
|
|
||||||
*live_config = new_config;
|
*live_config = new_config;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 {
|
pub async fn register_request(&self, chan: Sender<ipc::Approval>) -> u64 {
|
||||||
let count = {
|
let count = {
|
||||||
let mut c = self.request_count.write().unwrap();
|
let mut c = self.request_count.write().await;
|
||||||
*c += 1;
|
*c += 1;
|
||||||
c
|
c
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut open_requests = self.open_requests.write().unwrap();
|
let mut open_requests = self.open_requests.write().await;
|
||||||
open_requests.insert(*count, chan); // `count` is the request id
|
open_requests.insert(*count, chan); // `count` is the request id
|
||||||
*count
|
*count
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unregister_request(&self, id: u64) {
|
pub async fn unregister_request(&self, id: u64) {
|
||||||
let mut open_requests = self.open_requests.write().unwrap();
|
let mut open_requests = self.open_requests.write().await;
|
||||||
open_requests.remove(&id);
|
open_requests.remove(&id);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn req_count(&self) -> usize {
|
pub async fn req_count(&self) -> usize {
|
||||||
let open_requests = self.open_requests.read().unwrap();
|
let open_requests = self.open_requests.read().await;
|
||||||
open_requests.len()
|
open_requests.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
|
pub async fn send_response(&self, response: ipc::RequestResponse) -> Result<(), SendResponseError> {
|
||||||
self.renew_session_if_expired().await?;
|
self.renew_session_if_expired().await?;
|
||||||
|
|
||||||
let mut open_requests = self.open_requests.write().unwrap();
|
let mut open_requests = self.open_requests.write().await;
|
||||||
let chan = open_requests
|
let chan = open_requests
|
||||||
.remove(&response.id)
|
.remove(&response.id)
|
||||||
.ok_or(SendResponseError::NotFound)
|
.ok_or(SendResponseError::NotFound)
|
||||||
@ -217,54 +217,58 @@ impl AppState {
|
|||||||
.map_err(|_e| SendResponseError::Abandoned)
|
.map_err(|_e| SendResponseError::Abandoned)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_ban(&self, client: Option<Client>, app: tauri::AppHandle) {
|
pub async fn add_ban(&self, client: Option<Client>) {
|
||||||
let mut bans = self.bans.write().unwrap();
|
let mut bans = self.bans.write().await;
|
||||||
bans.insert(client.clone());
|
bans.insert(client.clone());
|
||||||
|
|
||||||
runtime::spawn(async move {
|
runtime::spawn(async move {
|
||||||
sleep(Duration::from_secs(5)).await;
|
sleep(Duration::from_secs(5)).await;
|
||||||
|
let app = crate::APP.get().unwrap();
|
||||||
let state = app.state::<AppState>();
|
let state = app.state::<AppState>();
|
||||||
let mut bans = state.bans.write().unwrap();
|
let mut bans = state.bans.write().await;
|
||||||
bans.remove(&client);
|
bans.remove(&client);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_banned(&self, client: &Option<Client>) -> bool {
|
pub async fn is_banned(&self, client: &Option<Client>) -> bool {
|
||||||
self.bans.read().unwrap().contains(&client)
|
self.bans.read().await.contains(&client)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> {
|
pub async fn unlock(&self, passphrase: &str) -> Result<(), UnlockError> {
|
||||||
let (access_key_id, secret_access_key) = {
|
let mut session = self.session.write().await;
|
||||||
// do this all in a block so that we aren't holding a lock across an await
|
let LockedCredentials {
|
||||||
let session = self.session.read().unwrap();
|
access_key_id,
|
||||||
let locked = match *session {
|
secret_key_enc,
|
||||||
Session::Empty => {return Err(UnlockError::NoCredentials);},
|
salt,
|
||||||
Session::Unlocked{..} => {return Err(UnlockError::NotLocked);},
|
nonce
|
||||||
Session::Locked(ref c) => c,
|
} = match *session {
|
||||||
};
|
Session::Empty => {return Err(UnlockError::NoCredentials);},
|
||||||
|
Session::Unlocked{..} => {return Err(UnlockError::NotLocked);},
|
||||||
let mut key_buf = [0; secretbox::KEYBYTES];
|
Session::Locked(ref c) => c,
|
||||||
// pretty sure this only fails if we're out of memory
|
|
||||||
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), &locked.salt).unwrap();
|
|
||||||
let decrypted = secretbox::open(&locked.secret_key_enc, &locked.nonce, &Key(key_buf))
|
|
||||||
.map_err(|_e| UnlockError::BadPassphrase)?;
|
|
||||||
|
|
||||||
let secret_str = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
|
|
||||||
(locked.access_key_id.clone(), secret_str)
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let session_creds = self.new_session(&access_key_id, &secret_access_key).await?;
|
let mut key_buf = [0; secretbox::KEYBYTES];
|
||||||
let mut app_session = self.session.write().unwrap();
|
// pretty sure this only fails if we're out of memory
|
||||||
*app_session = Session::Unlocked {
|
pwhash::derive_key_interactive(&mut key_buf, passphrase.as_bytes(), salt).unwrap();
|
||||||
base: BaseCredentials {access_key_id, secret_access_key},
|
let decrypted = secretbox::open(secret_key_enc, nonce, &Key(key_buf))
|
||||||
|
.map_err(|_e| UnlockError::BadPassphrase)?;
|
||||||
|
|
||||||
|
let secret_access_key = String::from_utf8(decrypted).map_err(|_e| UnlockError::InvalidUtf8)?;
|
||||||
|
|
||||||
|
let session_creds = self.new_session(access_key_id, &secret_access_key).await?;
|
||||||
|
*session = Session::Unlocked {
|
||||||
|
base: BaseCredentials {
|
||||||
|
access_key_id: access_key_id.clone(),
|
||||||
|
secret_access_key,
|
||||||
|
},
|
||||||
session: session_creds
|
session: session_creds
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
|
// pub async fn serialize_base_creds(&self) -> Result<String, GetCredentialsError> {
|
||||||
// let session = self.session.read().unwrap();
|
// let session = self.session.read().await;
|
||||||
// match *session {
|
// match *session {
|
||||||
// Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()),
|
// Session::Unlocked{ref base, ..} => Ok(serde_json::to_string(base).unwrap()),
|
||||||
// Session::Locked(_) => Err(GetCredentialsError::Locked),
|
// Session::Locked(_) => Err(GetCredentialsError::Locked),
|
||||||
@ -272,8 +276,8 @@ impl AppState {
|
|||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
pub fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
|
pub async fn serialize_session_creds(&self) -> Result<String, GetCredentialsError> {
|
||||||
let session = self.session.read().unwrap();
|
let session = self.session.read().await;
|
||||||
match *session {
|
match *session {
|
||||||
Session::Unlocked{ref session, ..} => Ok(serde_json::to_string(session).unwrap()),
|
Session::Unlocked{ref session, ..} => Ok(serde_json::to_string(session).unwrap()),
|
||||||
Session::Locked(_) => Err(GetCredentialsError::Locked),
|
Session::Locked(_) => Err(GetCredentialsError::Locked),
|
||||||
@ -329,41 +333,21 @@ impl AppState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
|
pub async fn renew_session_if_expired(&self) -> Result<bool, GetSessionError> {
|
||||||
let base = {
|
match *self.session.write().await {
|
||||||
let session = self.session.read().unwrap();
|
Session::Unlocked{ref base, ref mut session} => {
|
||||||
match *session {
|
|
||||||
Session::Unlocked{ref base, ..} => base.clone(),
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let new_session = self.new_session(&base.access_key_id, &base.secret_access_key).await?;
|
|
||||||
match *self.session.write().unwrap() {
|
|
||||||
Session::Unlocked{ref mut session, ..} => {
|
|
||||||
if !session.is_expired() {
|
if !session.is_expired() {
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
let new_session = self.new_session(
|
||||||
|
&base.access_key_id,
|
||||||
|
&base.secret_access_key
|
||||||
|
).await?;
|
||||||
*session = new_session;
|
*session = new_session;
|
||||||
Ok(true)
|
Ok(true)
|
||||||
},
|
},
|
||||||
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
|
Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
|
||||||
Session::Empty => Err(GetSessionError::CredentialsEmpty),
|
Session::Empty => Err(GetSessionError::CredentialsEmpty),
|
||||||
}
|
}
|
||||||
|
|
||||||
// match *self.session.write().unwrap() {
|
|
||||||
// Session::Unlocked{ref base, ref mut session} => {
|
|
||||||
// if !session.is_expired() {
|
|
||||||
// return Ok(false);
|
|
||||||
// }
|
|
||||||
// let new_session = self.new_session(
|
|
||||||
// &base.access_key_id,
|
|
||||||
// &base.secret_access_key
|
|
||||||
// ).await?;
|
|
||||||
// *session = new_session;
|
|
||||||
// Ok(true)
|
|
||||||
// },
|
|
||||||
// Session::Locked(_) => Err(GetSessionError::CredentialsLocked),
|
|
||||||
// Session::Empty => Err(GetSessionError::CredentialsEmpty),
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user