Better concurrency management

This commit is contained in:
Matéo Duparc 2021-08-29 21:17:36 +02:00
parent 5ae61222f4
commit 6c5bbc3f64
Signed by: hardcoresushi
GPG Key ID: 007F84120107191E
4 changed files with 82 additions and 241 deletions

157
Cargo.lock generated
View File

@ -615,22 +615,6 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2df960f5d869b2dd8532793fde43eb5427cceb126c929747a26823ab0eeb536" checksum = "a2df960f5d869b2dd8532793fde43eb5427cceb126c929747a26823ab0eeb536"
[[package]]
name = "core-foundation"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a89e2ae426ea83155dccf10c0fa6b1463ef6d5fcb44cee0b224a408fa640a62"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea221b5284a47e40033bf9b66f35f984ec0ea2931eb03505246cd27a963f981b"
[[package]] [[package]]
name = "cow-utils" name = "cow-utils"
version = "0.1.2" version = "0.1.2"
@ -895,21 +879,6 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]] [[package]]
name = "form_urlencoded" name = "form_urlencoded"
version = "1.0.1" version = "1.0.1"
@ -1289,15 +1258,6 @@ dependencies = [
"hashbrown 0.9.1", "hashbrown 0.9.1",
] ]
[[package]]
name = "input_buffer"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413"
dependencies = [
"bytes 1.0.1",
]
[[package]] [[package]]
name = "instant" name = "instant"
version = "0.1.9" version = "0.1.9"
@ -1376,9 +1336,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.94" version = "0.2.101"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18794a8ad5b29321f790b55d93dfba91e125cb1a9edbd4f8e3150acc771c1a5e" checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21"
[[package]] [[package]]
name = "libmdns" name = "libmdns"
@ -1592,24 +1552,6 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "native-tls"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8d96b2e1c8da3957d58100b09f102c6d9cfdfced01b7ec5a8974044bb09dbd4"
dependencies = [
"lazy_static",
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]] [[package]]
name = "net2" name = "net2"
version = "0.2.37" version = "0.2.37"
@ -1704,39 +1646,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "openssl"
version = "0.10.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d7830286ad6a3973c0f1d9b73738f69c76b739301d0229c4b96501695cbe4c8"
dependencies = [
"bitflags",
"cfg-if 1.0.0",
"foreign-types",
"libc",
"once_cell",
"openssl-sys",
]
[[package]]
name = "openssl-probe"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a"
[[package]]
name = "openssl-sys"
version = "0.9.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6b0d6fb7d80f877617dfcb014e605e2b5ab2fb0afdf27935219bb6bd984cb98"
dependencies = [
"autocfg",
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.11.1" version = "0.11.1"
@ -2068,15 +1977,6 @@ version = "0.6.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]]
name = "remove_dir_all"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
dependencies = [
"winapi 0.3.9",
]
[[package]] [[package]]
name = "resolv-conf" name = "resolv-conf"
version = "0.7.0" version = "0.7.0"
@ -2136,16 +2036,6 @@ dependencies = [
"regex", "regex",
] ]
[[package]]
name = "schannel"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f05ba609c234e60bee0d547fe94a4c7e9da733d1c962cf6e59efa4cd9c8bc75"
dependencies = [
"lazy_static",
"winapi 0.3.9",
]
[[package]] [[package]]
name = "scoped_threadpool" name = "scoped_threadpool"
version = "0.1.9" version = "0.1.9"
@ -2172,29 +2062,6 @@ dependencies = [
"sha2", "sha2",
] ]
[[package]]
name = "security-framework"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3670b1d2fdf6084d192bc71ead7aabe6c06aa2ea3fbd9cc3ac111fa5c2b1bd84"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3676258fd3cfe2c9a0ec99ce3038798d847ce3e4bb17746373eb9f0f1ac16339"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "semver" name = "semver"
version = "0.9.0" version = "0.9.0"
@ -2420,20 +2287,6 @@ dependencies = [
"unicode-xid", "unicode-xid",
] ]
[[package]]
name = "tempfile"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22"
dependencies = [
"cfg-if 1.0.0",
"libc",
"rand 0.8.3",
"redox_syscall",
"remove_dir_all",
"winapi 0.3.9",
]
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.24" version = "1.0.24"
@ -2660,18 +2513,16 @@ dependencies = [
[[package]] [[package]]
name = "tungstenite" name = "tungstenite"
version = "0.13.0" version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fe8dada8c1a3aeca77d6b51a4f1314e0f4b8e438b7b1b71e3ddaca8080e4093" checksum = "983d40747bce878d2fb67d910dcb8bd3eca2b2358540c3cc1b98c027407a3ae3"
dependencies = [ dependencies = [
"base64", "base64",
"byteorder", "byteorder",
"bytes 1.0.1", "bytes 1.0.1",
"http", "http",
"httparse", "httparse",
"input_buffer",
"log", "log",
"native-tls",
"rand 0.8.3", "rand 0.8.3",
"sha-1", "sha-1",
"thiserror", "thiserror",

View File

@ -23,7 +23,7 @@ actix-web = "3"
actix-multipart = "0.3" actix-multipart = "0.3"
time = "0.2" #needed for actix cookies time = "0.2" #needed for actix cookies
futures = "0.3" futures = "0.3"
tungstenite = "0.13" #websocket tungstenite = "0.15" #websocket
serde = "1.0" #serialization serde = "1.0" #serialization
html-escape = "0.2" html-escape = "0.2"
sanitize-filename = "0.3" sanitize-filename = "0.3"

View File

@ -10,7 +10,8 @@ mod discovery;
use std::{env, fs, io, net::SocketAddr, str::{FromStr, from_utf8}, sync::{Arc, RwLock}, cmp::Ordering}; use std::{env, fs, io, net::SocketAddr, str::{FromStr, from_utf8}, sync::{Arc, RwLock}, cmp::Ordering};
use image::GenericImageView; use image::GenericImageView;
use tokio::{net::TcpListener, runtime::Handle, sync::mpsc}; use tokio::{net::TcpListener, runtime::Handle, sync::mpsc, task::JoinError};
use tungstenite::Message;
use actix_web::{App, HttpMessage, HttpRequest, HttpResponse, HttpServer, http::{header, CookieBuilder}, web, web::Data}; use actix_web::{App, HttpMessage, HttpRequest, HttpResponse, HttpServer, http::{header, CookieBuilder}, web, web::Data};
use actix_multipart::Multipart; use actix_multipart::Multipart;
use futures::{StreamExt, TryStreamExt}; use futures::{StreamExt, TryStreamExt};
@ -24,41 +25,36 @@ use identity::Identity;
use session_manager::{SessionManager, SessionCommand}; use session_manager::{SessionManager, SessionCommand};
use ui_interface::UiConnection; use ui_interface::UiConnection;
async fn start_websocket_server(global_vars: Arc<RwLock<GlobalVars>>) -> u16 { async fn start_websocket_server(ui_auth_token: Arc<RwLock<Option<String>>>, session_manager: Arc<SessionManager>) -> u16 {
let websocket_bind_addr = env::var("AIRA_WEBSOCKET_ADDR").unwrap_or_else(|_| "127.0.0.1".to_owned()); let websocket_bind_addr = env::var("AIRA_WEBSOCKET_ADDR").unwrap_or_else(|_| "127.0.0.1".to_owned());
let websocket_port = env::var("AIRA_WEBSOCKET_PORT").unwrap_or_else(|_| "0".to_owned()); let websocket_port = env::var("AIRA_WEBSOCKET_PORT").unwrap_or_else(|_| "0".to_owned());
let server = TcpListener::bind(websocket_bind_addr+":"+&websocket_port).await.unwrap(); let server = TcpListener::bind(websocket_bind_addr+":"+&websocket_port).await.unwrap();
let websocket_port = server.local_addr().unwrap().port(); let websocket_port = server.local_addr().unwrap().port();
tokio::spawn(async move { tokio::spawn(async move {
let worker_done = Arc::new(RwLock::new(true));
loop { loop {
let (stream, _addr) = server.accept().await.unwrap(); let (stream, _addr) = server.accept().await.unwrap();
if *worker_done.read().unwrap() { let ui_auth_token = {
let ui_auth_token = { ui_auth_token.read().unwrap().clone()
global_vars.clone().read().unwrap().ui_auth_token.clone() };
}; if let Some(ui_auth_token) = ui_auth_token {
if let Some(ui_auth_token) = ui_auth_token { let stream = stream.into_std().unwrap();
let stream = stream.into_std().unwrap(); stream.set_nonblocking(false).unwrap();
stream.set_nonblocking(false).unwrap(); match tungstenite::accept(stream) {
match tungstenite::accept(stream.try_clone().unwrap()) { Ok(mut websocket) => {
Ok(mut websocket) => { if let Ok(message) = websocket.read_message() { //waiting for auth token
if let Ok(message) = websocket.read_message() { //waiting for auth token match message.into_text() {
match message.into_text() { Ok(token) => {
Ok(token) => { if token == ui_auth_token {
if token == ui_auth_token { let ui_connection = UiConnection::new(websocket);
let ui_connection = UiConnection::new(websocket); session_manager.set_ui_connection(ui_connection.clone());
let global_vars = global_vars.clone(); websocket_worker(ui_connection, session_manager.clone()).await.unwrap();
global_vars.read().unwrap().session_manager.set_ui_connection(ui_connection.clone());
*worker_done.write().unwrap() = false;
websocket_worker(ui_connection, global_vars, worker_done.clone()).await;
}
} }
Err(e) => print_error!(e)
} }
Err(e) => print_error!(e)
} }
} }
Err(e) => print_error!(e)
} }
Err(e) => print_error!(e)
} }
} }
} }
@ -83,19 +79,18 @@ fn discover_peers(session_manager: Arc<SessionManager>) {
}); });
} }
fn load_msgs(session_manager: Arc<SessionManager>, ui_connection: &mut UiConnection, session_id: &usize) { fn load_msgs(session_manager: &SessionManager, ui_connection: &mut UiConnection, session_id: &usize) {
if let Some(msgs) = session_manager.load_msgs(session_id, constants::MSG_LOADING_COUNT) { if let Some(msgs) = session_manager.load_msgs(session_id, constants::MSG_LOADING_COUNT) {
ui_connection.load_msgs(session_id, &msgs); ui_connection.load_msgs(session_id, &msgs);
} }
} }
async fn websocket_worker(mut ui_connection: UiConnection, global_vars: Arc<RwLock<GlobalVars>>, worker_done: Arc<RwLock<bool>>) { async fn websocket_worker(mut ui_connection: UiConnection, session_manager: Arc<SessionManager>) -> Result<(), JoinError> {
let session_manager = global_vars.read().unwrap().session_manager.clone();
ui_connection.set_name(&session_manager.identity.read().unwrap().as_ref().unwrap().name); ui_connection.set_name(&session_manager.identity.read().unwrap().as_ref().unwrap().name);
session_manager.list_contacts().into_iter().for_each(|contact|{ session_manager.list_contacts().into_iter().for_each(|contact|{
ui_connection.set_as_contact(contact.0, &contact.1, contact.2, &crypto::generate_fingerprint(&contact.3)); ui_connection.set_as_contact(contact.0, &contact.1, contact.2, &crypto::generate_fingerprint(&contact.3));
session_manager.last_loaded_msg_offsets.write().unwrap().insert(contact.0, 0); session_manager.last_loaded_msg_offsets.write().unwrap().insert(contact.0, 0);
load_msgs(session_manager.clone(), &mut ui_connection, &contact.0); load_msgs(&session_manager, &mut ui_connection, &contact.0);
}); });
session_manager.sessions.read().unwrap().iter().for_each(|session| { session_manager.sessions.read().unwrap().iter().for_each(|session| {
ui_connection.on_new_session( ui_connection.on_new_session(
@ -110,7 +105,7 @@ async fn websocket_worker(mut ui_connection: UiConnection, global_vars: Arc<RwLo
{ {
let not_seen = session_manager.not_seen.read().unwrap(); let not_seen = session_manager.not_seen.read().unwrap();
if not_seen.len() > 0 { if not_seen.len() > 0 {
ui_connection.set_not_seen(not_seen.clone()); ui_connection.set_not_seen(&not_seen);
} }
} }
session_manager.get_saved_msgs().into_iter().for_each(|msgs| { session_manager.get_saved_msgs().into_iter().for_each(|msgs| {
@ -141,15 +136,15 @@ async fn websocket_worker(mut ui_connection: UiConnection, global_vars: Arc<RwLo
} }
Err(e) => print_error!(e) Err(e) => print_error!(e)
} }
ui_connection.set_local_ips(ips); ui_connection.set_local_ips(&ips);
discover_peers(session_manager.clone()); discover_peers(session_manager.clone());
let handle = Handle::current(); let handle = Handle::current();
std::thread::spawn(move || { //new thread needed to block on read_message() without blocking tokio tasks tokio::task::spawn_blocking(move || {
loop { loop {
match ui_connection.websocket.read_message() { match ui_connection.websocket.read_message() {
Ok(msg) => { Ok(msg) => {
if msg.is_ping() { if msg.is_ping() {
ui_connection.write_message(tungstenite::Message::Pong(Vec::new())); //not sure if I'm doing this right ui_connection.write_message(Message::Pong(Vec::new())); //not sure if I'm doing this right
} else if msg.is_text() { } else if msg.is_text() {
let msg = msg.into_text().unwrap(); let msg = msg.into_text().unwrap();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -212,7 +207,7 @@ async fn websocket_worker(mut ui_connection: UiConnection, global_vars: Arc<RwLo
} }
"load_msgs" => { "load_msgs" => {
let session_id: usize = args[1].parse().unwrap(); let session_id: usize = args[1].parse().unwrap();
load_msgs(session_manager.clone(), &mut ui_connection, &session_id); load_msgs(&session_manager, &mut ui_connection, &session_id);
} }
"contact" => { "contact" => {
let session_id: usize = args[1].parse().unwrap(); let session_id: usize = args[1].parse().unwrap();
@ -304,7 +299,6 @@ async fn websocket_worker(mut ui_connection: UiConnection, global_vars: Arc<RwLo
Err(e) => { Err(e) => {
match e { match e {
tungstenite::Error::ConnectionClosed => { tungstenite::Error::ConnectionClosed => {
*worker_done.write().unwrap() = true;
break; break;
} }
_ => print_error!(e) _ => print_error!(e)
@ -312,13 +306,13 @@ async fn websocket_worker(mut ui_connection: UiConnection, global_vars: Arc<RwLo
} }
} }
} }
}); }).await
} }
fn is_authenticated(req: &HttpRequest) -> bool { fn is_authenticated(req: &HttpRequest) -> bool {
if let Some(cookie) = req.cookie(constants::HTTP_COOKIE_NAME) { if let Some(cookie) = req.cookie(constants::HTTP_COOKIE_NAME) {
let global_vars = req.app_data::<Data<Arc<RwLock<GlobalVars>>>>().unwrap(); let global_vars = req.app_data::<Data<GlobalVars>>().unwrap();
if let Some(token) = &global_vars.read().unwrap().ui_auth_token { if let Some(token) = global_vars.ui_auth_token.read().unwrap().as_ref() {
return token == cookie.value(); return token == cookie.value();
} }
} }
@ -400,8 +394,8 @@ fn handle_avatar(req: HttpRequest) -> HttpResponse {
} }
} else if splits.len() == 3 && is_authenticated(&req) { } else if splits.len() == 3 && is_authenticated(&req) {
if let Ok(session_id) = splits[1].parse() { if let Ok(session_id) = splits[1].parse() {
let global_vars = req.app_data::<Data<Arc<RwLock<GlobalVars>>>>().unwrap(); let global_vars = req.app_data::<Data<GlobalVars>>().unwrap();
return reply_with_avatar(global_vars.read().unwrap().session_manager.get_avatar(&session_id), Some(splits[2])); return reply_with_avatar(global_vars.session_manager.get_avatar(&session_id), Some(splits[2]));
} }
} }
HttpResponse::BadRequest().finish() HttpResponse::BadRequest().finish()
@ -416,8 +410,8 @@ fn handle_load_file(req: HttpRequest, file_info: web::Query<FileInfo>) -> HttpRe
if is_authenticated(&req) { if is_authenticated(&req) {
match Uuid::from_str(&file_info.uuid) { match Uuid::from_str(&file_info.uuid) {
Ok(uuid) => { Ok(uuid) => {
let global_vars = req.app_data::<Data<Arc<RwLock<GlobalVars>>>>().unwrap(); let global_vars = req.app_data::<Data<GlobalVars>>().unwrap();
if let Some(buffer) = global_vars.read().unwrap().session_manager.identity.read().unwrap().as_ref().unwrap().load_file(uuid) { if let Some(buffer) = global_vars.session_manager.identity.read().unwrap().as_ref().unwrap().load_file(uuid) {
return HttpResponse::Ok().header("Content-Disposition", format!("attachment; filename=\"{}\"", escape_double_quote(html_escape::decode_html_entities(&file_info.file_name).to_string()))).content_type("application/octet-stream").body(buffer); return HttpResponse::Ok().header("Content-Disposition", format!("attachment; filename=\"{}\"", escape_double_quote(html_escape::decode_html_entities(&file_info.file_name).to_string()))).content_type("application/octet-stream").body(buffer);
} }
} }
@ -440,14 +434,13 @@ async fn handle_send_file(req: HttpRequest, mut payload: Multipart) -> HttpRespo
} else if session_id.is_some() { } else if session_id.is_some() {
let filename = content_disposition.get_filename().unwrap(); let filename = content_disposition.get_filename().unwrap();
let session_id = session_id.unwrap(); let session_id = session_id.unwrap();
let global_vars = req.app_data::<Data<Arc<RwLock<GlobalVars>>>>().unwrap(); let global_vars = req.app_data::<Data<GlobalVars>>().unwrap();
let global_vars_read = global_vars.read().unwrap();
if req.path() == "/send_file" { if req.path() == "/send_file" {
let mut buffer = Vec::new(); let mut buffer = Vec::new();
while let Some(Ok(chunk)) = field.next().await { while let Some(Ok(chunk)) = field.next().await {
buffer.extend(chunk); buffer.extend(chunk);
} }
if let Ok(sent) = global_vars_read.session_manager.send_or_add_to_pending(&session_id, protocol::file(filename, &buffer)).await { if let Ok(sent) = global_vars.session_manager.send_or_add_to_pending(&session_id, protocol::file(filename, &buffer)).await {
return if sent { return if sent {
HttpResponse::Ok().finish() HttpResponse::Ok().finish()
} else { } else {
@ -476,7 +469,7 @@ async fn handle_send_file(req: HttpRequest, mut payload: Multipart) -> HttpRespo
break; break;
} }
} }
if !global_vars_read.session_manager.send_command(&session_id, SessionCommand::EncryptFileChunk{ if !global_vars.session_manager.send_command(&session_id, SessionCommand::EncryptFileChunk{
plain_text: chunk_buffer.clone() plain_text: chunk_buffer.clone()
}).await { }).await {
return HttpResponse::InternalServerError().finish(); return HttpResponse::InternalServerError().finish();
@ -484,7 +477,7 @@ async fn handle_send_file(req: HttpRequest, mut payload: Multipart) -> HttpRespo
if !match ack_receiver.recv().await { if !match ack_receiver.recv().await {
Some(should_continue) => { Some(should_continue) => {
//send previous encrypted chunk even if transfert is aborted to keep PSEC nonces syncrhonized //send previous encrypted chunk even if transfert is aborted to keep PSEC nonces syncrhonized
if global_vars_read.session_manager.send_command(&session_id, SessionCommand::SendEncryptedFileChunk { if global_vars.session_manager.send_command(&session_id, SessionCommand::SendEncryptedFileChunk {
ack_sender: ack_sender.clone() ack_sender: ack_sender.clone()
}).await { }).await {
should_continue should_continue
@ -513,11 +506,10 @@ async fn handle_send_file(req: HttpRequest, mut payload: Multipart) -> HttpRespo
async fn handle_logout(req: HttpRequest) -> HttpResponse { async fn handle_logout(req: HttpRequest) -> HttpResponse {
if is_authenticated(&req) { if is_authenticated(&req) {
let global_vars = req.app_data::<Data<Arc<RwLock<GlobalVars>>>>().unwrap(); let global_vars = req.app_data::<Data<GlobalVars>>().unwrap();
let mut global_vars_write = global_vars.write().unwrap(); if global_vars.session_manager.is_identity_loaded() {
if global_vars_write.session_manager.is_identity_loaded() { *global_vars.ui_auth_token.write().unwrap() = None;
global_vars_write.ui_auth_token = None; global_vars.session_manager.stop().await;
global_vars_write.session_manager.stop().await;
} }
if Identity::is_protected().unwrap_or(true) { if Identity::is_protected().unwrap_or(true) {
HttpResponse::Found().header(header::LOCATION, "/").finish() HttpResponse::Found().header(header::LOCATION, "/").finish()
@ -529,13 +521,12 @@ async fn handle_logout(req: HttpRequest) -> HttpResponse {
} }
} }
fn login(identity: Identity, global_vars: &Arc<RwLock<GlobalVars>>) -> HttpResponse { fn login(identity: Identity, global_vars: &GlobalVars) -> HttpResponse {
let mut global_vars_write = global_vars.write().unwrap(); let session_manager = global_vars.session_manager.clone();
let session_manager = global_vars_write.session_manager.clone();
if !session_manager.is_identity_loaded() { if !session_manager.is_identity_loaded() {
global_vars_write.session_manager.set_identity(Some(identity)); session_manager.set_identity(Some(identity));
global_vars_write.tokio_handle.clone().spawn(async move { global_vars.tokio_handle.spawn(async move {
if SessionManager::start_listener(session_manager.clone()).await.is_err() { if SessionManager::start_listener(session_manager).await.is_err() {
print_error!("You won't be able to receive incomming connections from other peers."); print_error!("You won't be able to receive incomming connections from other peers.");
} }
}); });
@ -543,7 +534,7 @@ fn login(identity: Identity, global_vars: &Arc<RwLock<GlobalVars>>) -> HttpRespo
let mut raw_cookie = [0; 32]; let mut raw_cookie = [0; 32];
OsRng.fill_bytes(&mut raw_cookie); OsRng.fill_bytes(&mut raw_cookie);
let cookie_value = base64::encode(raw_cookie); let cookie_value = base64::encode(raw_cookie);
global_vars_write.ui_auth_token = Some(cookie_value.clone()); *global_vars.ui_auth_token.write().unwrap() = Some(cookie_value.clone());
let cookie = CookieBuilder::new(constants::HTTP_COOKIE_NAME, cookie_value).max_age(time::Duration::hours(4)).finish(); let cookie = CookieBuilder::new(constants::HTTP_COOKIE_NAME, cookie_value).max_age(time::Duration::hours(4)).finish();
HttpResponse::Found() HttpResponse::Found()
.header(header::LOCATION, "/") .header(header::LOCATION, "/")
@ -551,7 +542,7 @@ fn login(identity: Identity, global_vars: &Arc<RwLock<GlobalVars>>) -> HttpRespo
.finish() .finish()
} }
fn on_identity_loaded(identity: Identity, global_vars: &Arc<RwLock<GlobalVars>>) -> HttpResponse { fn on_identity_loaded(identity: Identity, global_vars: &Arc<GlobalVars>) -> HttpResponse {
match Identity::clear_cache() { match Identity::clear_cache() {
Ok(_) => {}, Ok(_) => {},
Err(e) => print_error!(e) Err(e) => print_error!(e)
@ -566,7 +557,7 @@ struct LoginParams {
fn handle_login(req: HttpRequest, mut params: web::Form<LoginParams>) -> HttpResponse { fn handle_login(req: HttpRequest, mut params: web::Form<LoginParams>) -> HttpResponse {
let response = match Identity::load_identity(Some(params.password.as_bytes())) { let response = match Identity::load_identity(Some(params.password.as_bytes())) {
Ok(identity) => { Ok(identity) => {
let global_vars = req.app_data::<Data<Arc<RwLock<GlobalVars>>>>().unwrap(); let global_vars = req.app_data::<Data<GlobalVars>>().unwrap();
on_identity_loaded(identity, global_vars) on_identity_loaded(identity, global_vars)
} }
Err(e) => generate_login_response(Some(&e)) Err(e) => generate_login_response(Some(&e))
@ -625,8 +616,8 @@ async fn handle_create(req: HttpRequest, mut params: web::Form<CreateParams>) ->
} }
) { ) {
Ok(identity) => { Ok(identity) => {
let global_vars = req.app_data::<Data<Arc<RwLock<GlobalVars>>>>().unwrap(); let global_vars = req.app_data::<Data<GlobalVars>>().unwrap();
login(identity, global_vars.get_ref()) login(identity, global_vars)
} }
Err(e) => { Err(e) => {
print_error!(e); print_error!(e);
@ -641,7 +632,7 @@ async fn handle_create(req: HttpRequest, mut params: web::Form<CreateParams>) ->
response response
} }
fn index_not_logged_in(global_vars: &Arc<RwLock<GlobalVars>>) -> HttpResponse { fn index_not_logged_in(global_vars: &Arc<GlobalVars>) -> HttpResponse {
if Identity::is_protected().unwrap_or(true) { if Identity::is_protected().unwrap_or(true) {
generate_login_response(None) generate_login_response(None)
} else { } else {
@ -653,22 +644,21 @@ fn index_not_logged_in(global_vars: &Arc<RwLock<GlobalVars>>) -> HttpResponse {
} }
async fn handle_index(req: HttpRequest) -> HttpResponse { async fn handle_index(req: HttpRequest) -> HttpResponse {
let global_vars = req.app_data::<Data<Arc<RwLock<GlobalVars>>>>().unwrap(); let global_vars = req.app_data::<Data<GlobalVars>>().unwrap();
if is_authenticated(&req) { if is_authenticated(&req) {
let global_vars_read = global_vars.read().unwrap();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
let html = fs::read_to_string("src/frontend/index.html").unwrap() let html = fs::read_to_string("src/frontend/index.html").unwrap()
.replace("AIRA_VERSION", env!("CARGO_PKG_VERSION")); .replace("AIRA_VERSION", env!("CARGO_PKG_VERSION"));
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
let html = include_str!(concat!(env!("OUT_DIR"), "/index.html")); let html = include_str!(concat!(env!("OUT_DIR"), "/index.html"));
let public_key = global_vars_read.session_manager.identity.read().unwrap().as_ref().unwrap().get_public_key(); let identity = global_vars.session_manager.identity.read().unwrap();
let use_padding = global_vars_read.session_manager.identity.read().unwrap().as_ref().unwrap().use_padding.to_string(); let identity = identity.as_ref().unwrap();
HttpResponse::Ok().body( HttpResponse::Ok().body(
html html
.replace("IDENTITY_FINGERPRINT", &crypto::generate_fingerprint(&public_key)) .replace("IDENTITY_FINGERPRINT", &crypto::generate_fingerprint(&identity.get_public_key()))
.replace("WEBSOCKET_PORT", &global_vars_read.websocket_port.to_string()) .replace("WEBSOCKET_PORT", &global_vars.websocket_port.to_string())
.replace("IS_IDENTITY_PROTECTED", &Identity::is_protected().unwrap().to_string()) .replace("IS_IDENTITY_PROTECTED", &Identity::is_protected().unwrap().to_string())
.replace("PSEC_PADDING", &use_padding) .replace("PSEC_PADDING", &identity.use_padding.to_string())
) )
} else { } else {
index_not_logged_in(global_vars) index_not_logged_in(global_vars)
@ -786,16 +776,15 @@ fn handle_static(req: HttpRequest) -> HttpResponse {
} }
#[actix_web::main] #[actix_web::main]
async fn start_http_server(global_vars: Arc<RwLock<GlobalVars>>) -> io::Result<()> { async fn start_http_server(global_vars: GlobalVars) -> io::Result<()> {
let http_addr = env::var("AIRA_HTTP_ADDR").unwrap_or_else(|_| "127.0.0.1".to_owned()).parse().expect("AIRA_HTTP_ADDR invalid"); let http_addr = env::var("AIRA_HTTP_ADDR").unwrap_or_else(|_| "127.0.0.1".to_owned()).parse().expect("AIRA_HTTP_ADDR invalid");
let http_port = match env::var("AIRA_HTTP_PORT") { let http_port = match env::var("AIRA_HTTP_PORT") {
Ok(port) => port.parse().expect("AIRA_HTTP_PORT invalid"), Ok(port) => port.parse().expect("AIRA_HTTP_PORT invalid"),
Err(_) => constants::UI_PORT Err(_) => constants::UI_PORT
}; };
let server = HttpServer::new(move || { let server = HttpServer::new(move || {
let global_vars_clone = global_vars.clone();
App::new() App::new()
.data(global_vars_clone) .data(global_vars.clone())
.service(web::resource("/") .service(web::resource("/")
.route(web::get().to(handle_index)) .route(web::get().to(handle_index))
.route(web::post().to(handle_create)) .route(web::post().to(handle_create))
@ -818,10 +807,11 @@ async fn start_http_server(global_vars: Arc<RwLock<GlobalVars>>) -> io::Result<(
server.run().await server.run().await
} }
#[derive(Clone)]
struct GlobalVars { struct GlobalVars {
session_manager: Arc<SessionManager>, session_manager: Arc<SessionManager>,
websocket_port: u16, websocket_port: u16,
ui_auth_token: Option<String>, ui_auth_token: Arc<RwLock<Option<String>>>,
tokio_handle: Handle, tokio_handle: Handle,
} }
@ -832,13 +822,13 @@ async fn main() {
print_error!(e); print_error!(e);
} }
} }
let global_vars = Arc::new(RwLock::new(GlobalVars { let ui_auth_token = Arc::new(RwLock::new(None));
session_manager: Arc::new(SessionManager::new()), let session_manager = Arc::new(SessionManager::new());
websocket_port: 0, let websocket_port = start_websocket_server(ui_auth_token.clone(), session_manager.clone()).await;
ui_auth_token: None, start_http_server(GlobalVars {
session_manager,
websocket_port,
ui_auth_token,
tokio_handle: Handle::current(), tokio_handle: Handle::current(),
})); }).unwrap();
let websocket_port = start_websocket_server(global_vars.clone()).await; }
global_vars.write().unwrap().websocket_port = websocket_port;
start_http_server(global_vars).unwrap();
}

View File

@ -25,8 +25,8 @@ impl UiConnection {
fn simple_event(&mut self, command: &str, session_id: &usize) { fn simple_event(&mut self, command: &str, session_id: &usize) {
self.write_message(format!("{} {}", command, session_id)); self.write_message(format!("{} {}", command, session_id));
} }
fn data_list<T: Display>(command: &str, data: Vec<T>) -> String { fn data_list<T: Display>(command: &str, data: &[T]) -> String {
command.to_string()+&data.into_iter().map(|i| { command.to_string()+&data.iter().map(|i| {
format!(" {}", i) format!(" {}", i)
}).collect::<String>() }).collect::<String>()
} }
@ -120,7 +120,7 @@ impl UiConnection {
}); });
self.write_message(s); self.write_message(s);
} }
pub fn set_not_seen(&mut self, session_ids: Vec<usize>) { pub fn set_not_seen(&mut self, session_ids: &[usize]) {
self.write_message(Self::data_list("not_seen", session_ids)); self.write_message(Self::data_list("not_seen", session_ids));
} }
pub fn new_pending_msg(&mut self, session_id: &usize, is_file: bool, data: &str) { pub fn new_pending_msg(&mut self, session_id: &usize, is_file: bool, data: &str) {
@ -132,7 +132,7 @@ impl UiConnection {
pub fn on_pending_msgs_sent(&mut self, session_id: &usize) { pub fn on_pending_msgs_sent(&mut self, session_id: &usize) {
self.simple_event("pending_msgs_sent", session_id); self.simple_event("pending_msgs_sent", session_id);
} }
pub fn set_local_ips(&mut self, ips: Vec<IpAddr>) { pub fn set_local_ips(&mut self, ips: &[IpAddr]) {
self.write_message(Self::data_list("local_ips", ips)); self.write_message(Self::data_list("local_ips", ips));
} }
pub fn set_name(&mut self, new_name: &str) { pub fn set_name(&mut self, new_name: &str) {