Compare commits

..

2 Commits

Author SHA1 Message Date
c5dcc2e50a handle errors on config update 2023-04-28 14:33:23 -07:00
70d71ce14e restart listener when config changes 2023-04-28 14:33:04 -07:00
11 changed files with 224 additions and 88 deletions

View File

@ -66,7 +66,7 @@ impl AppConfig {
} }
pub fn set_auto_launch(enable: bool) -> Result<(), SetupError> { pub fn set_auto_launch(is_configured: bool) -> Result<(), SetupError> {
let path_buf = std::env::current_exe() let path_buf = std::env::current_exe()
.map_err(|e| auto_launch::Error::Io(e))?; .map_err(|e| auto_launch::Error::Io(e))?;
let path = path_buf let path = path_buf
@ -75,13 +75,14 @@ pub fn set_auto_launch(enable: bool) -> Result<(), SetupError> {
let auto = AutoLaunchBuilder::new() let auto = AutoLaunchBuilder::new()
.set_app_name("Creddy") .set_app_name("Creddy")
.set_app_path(&path) .set_app_path(&path)
.build()?; .build().expect("Failed to build");
if enable { let is_enabled = auto.is_enabled()?;
auto.enable()?; if is_configured && !is_enabled {
auto.enable().expect("Failed to enable");
} }
else { else if !is_configured && is_enabled {
auto.disable()?; auto.disable().expect("Failed to disable");
} }
Ok(()) Ok(())

View File

@ -89,6 +89,8 @@ pub enum SetupError {
ConfigParseError(#[from] serde_json::Error), ConfigParseError(#[from] serde_json::Error),
#[error("Failed to set up start-on-login: {0}")] #[error("Failed to set up start-on-login: {0}")]
AutoLaunchError(#[from] auto_launch::Error), AutoLaunchError(#[from] auto_launch::Error),
#[error("Failed to start listener: {0}")]
ServerSetupError(#[from] std::io::Error),
} }

View File

@ -73,5 +73,6 @@ pub fn get_config(app_state: State<'_, AppState>) -> AppConfig {
pub async fn save_config(config: AppConfig, app_state: State<'_, AppState>) -> Result<(), String> { pub async fn save_config(config: AppConfig, app_state: State<'_, AppState>) -> Result<(), String> {
app_state.update_config(config) app_state.update_config(config)
.await .await
.map_err(|e| format!("Error saving config to database: {e}")) .map_err(|e| format!("Error saving config: {e}"))?;
Ok(())
} }

View File

@ -2,8 +2,19 @@
all(not(debug_assertions), target_os = "windows"), all(not(debug_assertions), target_os = "windows"),
windows_subsystem = "windows" windows_subsystem = "windows"
)] )]
use std::error::Error;
use tauri::{AppHandle, Manager, async_runtime as rt}; use sqlx::{
SqlitePool,
sqlite::SqlitePoolOptions,
sqlite::SqliteConnectOptions,
};
use tauri::{
App,
AppHandle,
Manager,
async_runtime as rt,
};
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
mod config; mod config;
@ -14,20 +25,44 @@ mod state;
mod server; mod server;
mod tray; mod tray;
use crate::errors::*; use config::AppConfig;
use server::Server;
use errors::*;
use state::AppState; use state::AppState;
pub static APP: OnceCell<AppHandle> = OnceCell::new(); pub static APP: OnceCell<AppHandle> = OnceCell::new();
fn main() {
let initial_state = match rt::block_on(AppState::load()) {
Ok(state) => state,
Err(e) => {eprintln!("{}", e); return;}
};
async fn setup(app: &mut App) -> Result<(), Box<dyn Error>> {
APP.set(app.handle()).unwrap();
let conn_opts = SqliteConnectOptions::new()
.filename(config::get_or_create_db_path())
.create_if_missing(true);
let pool_opts = SqlitePoolOptions::new();
let pool: SqlitePool = pool_opts.connect_with(conn_opts).await?;
sqlx::migrate!().run(&pool).await?;
let conf = AppConfig::load(&pool).await?;
let session = AppState::load_creds(&pool).await?;
let srv = Server::new(conf.listen_addr, conf.listen_port, app.handle()).await?;
config::set_auto_launch(conf.start_on_login)?;
if !conf.start_minimized {
app.get_window("main")
.ok_or(RequestError::NoMainWindow)?
.show()?;
}
let state = AppState::new(conf, session, srv, pool);
app.manage(state);
Ok(())
}
fn main() {
tauri::Builder::default() tauri::Builder::default()
.manage(initial_state)
.system_tray(tray::create()) .system_tray(tray::create())
.on_system_tray_event(tray::handle_event) .on_system_tray_event(tray::handle_event)
.invoke_handler(tauri::generate_handler![ .invoke_handler(tauri::generate_handler![
@ -38,22 +73,7 @@ fn main() {
ipc::get_config, ipc::get_config,
ipc::save_config, ipc::save_config,
]) ])
.setup(|app| { .setup(|app| rt::block_on(setup(app)))
APP.set(app.handle()).unwrap();
let state = app.state::<AppState>();
let config = state.config.read().unwrap();
config::set_auto_launch(config.start_on_login)?;
let addr = std::net::SocketAddrV4::new(config.listen_addr, config.listen_port);
tauri::async_runtime::spawn(server::serve(addr, app.handle()));
if !config.start_minimized {
app.get_window("main")
.ok_or(RequestError::NoMainWindow)?
.show()?;
}
Ok(())
})
.build(tauri::generate_context!()) .build(tauri::generate_context!())
.expect("error while running tauri application") .expect("error while running tauri application")
.run(|app, run_event| match run_event { .run(|app, run_event| match run_event {

View File

@ -1,12 +1,22 @@
use core::time::Duration; use core::time::Duration;
use std::io; use std::io;
use std::net::{SocketAddr, SocketAddrV4}; use std::net::{
use tokio::net::{TcpListener, TcpStream}; Ipv4Addr,
SocketAddr,
SocketAddrV4,
TcpListener as StdTcpListener,
};
use tokio::net::{
TcpListener,
TcpStream,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio::time::sleep; use tokio::time::sleep;
use tauri::{AppHandle, Manager}; use tauri::{AppHandle, Manager};
use tauri::async_runtime as rt;
use tauri::async_runtime::JoinHandle;
use crate::{clientinfo, clientinfo::Client}; use crate::{clientinfo, clientinfo::Client};
use crate::errors::*; use crate::errors::*;
@ -159,18 +169,56 @@ impl Handler {
} }
pub async fn serve(addr: SocketAddrV4, app_handle: AppHandle) -> io::Result<()> { #[derive(Debug)]
let listener = TcpListener::bind(&addr).await?; pub struct Server {
println!("Listening on {addr}"); addr: Ipv4Addr,
port: u16,
app_handle: AppHandle,
task: JoinHandle<()>,
}
impl Server {
pub async fn new(addr: Ipv4Addr, port: u16, app_handle: AppHandle) -> io::Result<Server> {
// construct the listener before passing it to the task so that we know if it fails
let sock_addr = SocketAddrV4::new(addr, port);
let listener = TcpListener::bind(&sock_addr).await?;
let task = rt::spawn(Self::serve(listener, app_handle.app_handle()));
Ok(Server { addr, port, app_handle, task})
}
// this is blocking because it's too much of a paint to juggle mutexes otherwise
pub fn rebind(&mut self, addr: Ipv4Addr, port: u16) -> io::Result<()> {
if addr == self.addr && port == self.port {
return Ok(())
}
let sock_addr = SocketAddrV4::new(addr, port);
let std_listener = StdTcpListener::bind(&sock_addr)?;
std_listener.set_nonblocking(true)?;
let async_listener = TcpListener::from_std(std_listener)?;
let new_task = rt::spawn(Self::serve(async_listener, self.app_handle.app_handle()));
self.task.abort();
self.addr = addr;
self.port = port;
self.task = new_task;
Ok(())
}
async fn serve(listener: TcpListener, app_handle: AppHandle) {
loop { loop {
match listener.accept().await { match listener.accept().await {
Ok((stream, _)) => { Ok((stream, _)) => {
let handler = Handler::new(stream, app_handle.app_handle()); let handler = Handler::new(stream, app_handle.app_handle());
tauri::async_runtime::spawn(handler.handle()); rt::spawn(handler.handle());
}, },
Err(e) => { Err(e) => {
eprintln!("Error accepting connection: {e}"); eprintln!("Error accepting connection: {e}");
} }
} }
} }
}
} }

View File

@ -5,7 +5,7 @@ use std::sync::RwLock;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use tokio::sync::oneshot::Sender; use tokio::sync::oneshot::Sender;
use tokio::time::sleep; use tokio::time::sleep;
use sqlx::{SqlitePool, sqlite::SqlitePoolOptions, sqlite::SqliteConnectOptions}; use sqlx::SqlitePool;
use sodiumoxide::crypto::{ use sodiumoxide::crypto::{
pwhash, pwhash,
pwhash::Salt, pwhash::Salt,
@ -19,6 +19,7 @@ use crate::{config, config::AppConfig};
use crate::ipc; use crate::ipc;
use crate::clientinfo::Client; use crate::clientinfo::Client;
use crate::errors::*; use crate::errors::*;
use crate::server::Server;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -63,34 +64,50 @@ pub struct AppState {
pub request_count: RwLock<u64>, pub request_count: RwLock<u64>,
pub open_requests: RwLock<HashMap<u64, Sender<ipc::Approval>>>, pub open_requests: RwLock<HashMap<u64, Sender<ipc::Approval>>>,
pub bans: RwLock<std::collections::HashSet<Option<Client>>>, pub bans: RwLock<std::collections::HashSet<Option<Client>>>,
pool: SqlitePool, server: RwLock<Server>,
pool: sqlx::SqlitePool,
} }
impl AppState { impl AppState {
pub async fn load() -> Result<Self, SetupError> { pub fn new(config: AppConfig, session: Session, server: Server, pool: SqlitePool) -> AppState {
let conn_opts = SqliteConnectOptions::new() AppState {
.filename(config::get_or_create_db_path()) config: RwLock::new(config),
.create_if_missing(true); session: RwLock::new(session),
let pool_opts = SqlitePoolOptions::new();
let pool: SqlitePool = pool_opts.connect_with(conn_opts).await?;
sqlx::migrate!().run(&pool).await?;
let creds = Self::load_creds(&pool).await?;
let conf = AppConfig::load(&pool).await?;
let state = AppState {
config: RwLock::new(conf),
session: RwLock::new(creds),
request_count: RwLock::new(0), request_count: RwLock::new(0),
open_requests: RwLock::new(HashMap::new()), open_requests: RwLock::new(HashMap::new()),
bans: RwLock::new(HashSet::new()), bans: RwLock::new(HashSet::new()),
server: RwLock::new(server),
pool, pool,
};
Ok(state)
} }
}
async fn load_creds(pool: &SqlitePool) -> Result<Session, SetupError> { // pub async fn load(app_handle: AppHandle) -> Result<Self, SetupError> {
// let conn_opts = SqliteConnectOptions::new()
// .filename(config::get_or_create_db_path())
// .create_if_missing(true);
// let pool_opts = SqlitePoolOptions::new();
// let pool: SqlitePool = pool_opts.connect_with(conn_opts).await?;
// sqlx::migrate!().run(&pool).await?;
// let creds = Self::load_creds(&pool).await?;
// let conf = AppConfig::load(&pool).await?;
// let server = Server::new(conf.listen_addr, conf.listen_port, app_handle)?;
// let state = AppState {
// config: RwLock::new(conf),
// session: RwLock::new(creds),
// request_count: RwLock::new(0),
// open_requests: RwLock::new(HashMap::new()),
// bans: RwLock::new(HashSet::new()),
// server: RwLock::new(server),
// pool,
// };
// Ok(state)
// }
pub async fn load_creds(pool: &SqlitePool) -> Result<Session, SetupError> {
let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc") let res = sqlx::query!("SELECT * FROM credentials ORDER BY created_at desc")
.fetch_optional(pool) .fetch_optional(pool)
.await?; .await?;
@ -151,14 +168,22 @@ impl AppState {
} }
pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> { pub async fn update_config(&self, new_config: AppConfig) -> Result<(), SetupError> {
new_config.save(&self.pool).await?; {
let orig_config = self.config.read().unwrap();
let mut live_config = self.config.write().unwrap(); if new_config.start_on_login != orig_config.start_on_login {
if new_config.start_on_login != live_config.start_on_login {
config::set_auto_launch(new_config.start_on_login)?; config::set_auto_launch(new_config.start_on_login)?;
} }
*live_config = new_config; if new_config.listen_addr != orig_config.listen_addr
|| new_config.listen_port != orig_config.listen_port
{
let mut sv = self.server.write().unwrap();
sv.rebind(new_config.listen_addr, new_config.listen_port)?;
}
}
new_config.save(&self.pool).await?;
let mut live_config = self.config.write().unwrap();
*live_config = new_config;
Ok(()) Ok(())
} }

View File

@ -1,3 +1,7 @@
@tailwind base; @tailwind base;
@tailwind components; @tailwind components;
@tailwind utilities; @tailwind utilities;
.btn-alert-error {
@apply bg-transparent hover:bg-[#cd5a5a] border border-error-content text-error-content
}

View File

@ -10,19 +10,37 @@
export let max = null; export let max = null;
export let decimal = false; export let decimal = false;
let error = null;
let localValue = value.toString();
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
function validate(event) {
$: localValue = value.toString();
let lastInputTime = null;
function debounce(event) {
lastInputTime = Date.now();
localValue = localValue.replace(/[^-0-9.]/g, ''); localValue = localValue.replace(/[^-0-9.]/g, '');
const eventTime = lastInputTime;
const pendingValue = localValue;
window.setTimeout(
() => {
// if no other inputs have occured since then
if (eventTime === lastInputTime) {
updateValue(pendingValue);
}
},
500
)
}
let error = null;
function updateValue(newValue) {
// Don't update the value, but also don't error, if it's empty // Don't update the value, but also don't error, if it's empty
// or if it could be the start of a negative or decimal number // or if it could be the start of a negative or decimal number
if (localValue.match(/^$|^-$|^\.$/) !== null) { if (newValue.match(/^$|^-$|^\.$/) !== null) {
error = null; error = null;
return; return;
} }
let num = parseFloat(localValue); const num = parseFloat(newValue);
if (num % 1 !== 0 && !decimal) { if (num % 1 !== 0 && !decimal) {
error = `${num} is not a whole number`; error = `${num} is not a whole number`;
} }
@ -53,7 +71,7 @@
size="{Math.max(5, localValue.length)}" size="{Math.max(5, localValue.length)}"
class:input-error={error} class:input-error={error}
bind:value={localValue} bind:value={localValue}
on:input="{validate}" on:input="{debounce}"
/> />
</div> </div>
</div> </div>

View File

@ -5,12 +5,19 @@
import Nav from '../ui/Nav.svelte'; import Nav from '../ui/Nav.svelte';
import Link from '../ui/Link.svelte'; import Link from '../ui/Link.svelte';
import ErrorAlert from '../ui/ErrorAlert.svelte'; import ErrorAlert from '../ui/ErrorAlert.svelte';
// import Setting from '../ui/settings/Setting.svelte';
import { Setting, ToggleSetting, NumericSetting } from '../ui/settings'; import { Setting, ToggleSetting, NumericSetting } from '../ui/settings';
let error = null;
async function save() { async function save() {
try {
await invoke('save_config', {config: $appState.config}); await invoke('save_config', {config: $appState.config});
} }
catch (e) {
error = e;
$appState.config = await invoke('get_config');
}
}
</script> </script>
@ -57,3 +64,17 @@
</Setting> </Setting>
</div> </div>
{/await} {/await}
{#if error}
<div class="toast">
<div class="alert alert-error">
<div>
<span>{error}</span>
</div>
<div>
<button class="btn btn-sm btn-alert-error" on:click={() => error = null}>Ok</button>
</div>
</div>
</div>
{/if}

View File

@ -49,9 +49,7 @@
{error} {error}
<svelte:fragment slot="buttons"> <svelte:fragment slot="buttons">
<Link target="Home"> <Link target="Home">
<button class="btn btn-sm bg-transparent hover:bg-[#cd5a5a] border border-error-content text-error-content"> <button class="btn btn-sm btn-alert-error" on:click={() => navigate('Home')}>Ok</button>
Ok
</button>
</Link> </Link>
</svelte:fragment> </svelte:fragment>
</ErrorAlert> </ErrorAlert>

View File

@ -38,9 +38,7 @@
{error} {error}
<svelte:fragment slot="buttons"> <svelte:fragment slot="buttons">
<Link target="Home"> <Link target="Home">
<button class="btn btn-sm bg-transparent hover:bg-[#cd5a5a] border border-error-content text-error-content" on:click="{() => navigate('Home')}"> <button class="btn btn-sm btn-alert-error" on:click={() => navigate('Home')}>Ok</button>
Ok
</button>
</Link> </Link>
</svelte:fragment> </svelte:fragment>
</ErrorAlert> </ErrorAlert>