AIRA/src/session_manager/mod.rs

406 lines
17 KiB
Rust

mod session;
pub mod protocol;
use std::{collections::HashMap, io::ErrorKind, net::{SocketAddr, TcpStream}, sync::{Arc, Mutex, MutexGuard, RwLock}, thread, thread::sleep, time::Duration};
use socket2::{Socket, Domain, Type};
use session::Session;
use strum_macros::Display;
use ed25519_dalek::PUBLIC_KEY_LENGTH;
use uuid::Uuid;
use zeroize::Zeroize;
use crate::{constants, identity::{Contact, Identity}, print_error};
use crate::ui_interface::UiConnection;
/// Errors produced while connecting to, handshaking with, or exchanging data with peers.
///
/// `Display` is derived through `strum_macros`.
#[derive(Display, Debug, PartialEq, Eq)]
pub enum SessionError {
/// Presumably a read timeout reported by `Session` (TODO confirm); the receiver loop ignores it without printing.
SocketTimeout,
/// Not raised in this module — presumably reported by `Session`; TODO confirm exact meaning.
ConnectionReset,
/// The receiver loop treats this as a dead connection: the UI is notified and the session is removed.
BrokenPipe,
/// Presumably malformed/undecryptable data from the peer (TODO confirm where `Session` raises it).
TransmissionCorrupted,
/// Not raised in this module — presumably reported by `Session` for oversized messages; TODO confirm.
BufferTooLarge,
/// `TcpStream::connect` failed in `connect_to`.
ConnectFailed,
/// No session is registered under the given index (`send_to`, `ask_name_to`).
InvalidSessionId,
/// Neither the IPv6 nor the IPv4 listening socket could be bound (`start_listener`).
BindFailed,
/// A session with the same IP (`connect_to`) or the same peer public key (handshake) already exists.
AlreadyConnected,
/// The handshake peer presented our own identity's public key.
IsUs,
/// Not raised in this module; presumably a catch-all used by `Session` — TODO confirm.
Unknown
}
/// Shared state for peer sessions, the active identity, its contacts and the UI link.
///
/// Every field sits behind a lock because the manager is shared through `Arc` with the
/// threads spawned by `start_receiver_loop` and `start_listener`.
pub struct SessionManager {
/// Next index handed out to a session whose peer is not a loaded contact; contacts take
/// indices from this same counter when they are loaded in `set_identity`.
session_counter: RwLock<usize>,
/// Established (handshaken) sessions keyed by index.
sessions: RwLock<HashMap<usize, Session>>,
/// The active identity; `None` until `set_identity(Some(..))` and after `stop()`.
identity: RwLock<Option<Identity>>,
/// Connection used to notify the UI; `None` until `set_ui_connection` and after `stop()`.
ui_connection: Mutex<Option<Arc<Mutex<UiConnection>>>>,
/// Contacts of the active identity, keyed in the same index space as `sessions`.
loaded_contacts: RwLock<HashMap<usize, Contact>>,
/// `(session index, message)` pairs received while they could be neither stored as a
/// contact message nor delivered to a valid UI connection.
msg_queue: Mutex<Vec<(usize, Vec<u8>)>>,
/// Set by `stop()`; polled by the listener and receiver threads to know when to exit.
is_stopping: RwLock<bool>
}
impl SessionManager {
fn with_ui_connection<F>(&self, f: F) -> bool where F: Fn(MutexGuard<UiConnection>) {
let mut ui_connection_opt = self.ui_connection.lock().unwrap();
let ui_connection = ui_connection_opt.as_mut().unwrap().lock().unwrap();
if ui_connection.is_valid {
f(ui_connection);
true
} else {
false
}
}
fn do_handshake_then_add(&self, mut session: Session) -> Result<usize, SessionError>{
let identity_opt = self.identity.read().unwrap();
let identity = identity_opt.as_ref().unwrap();
session.do_handshake(identity)?;
let peer_public_key = session.peer_public_key.unwrap();
if identity.get_public_key() == peer_public_key { //did handshake with the same Identity
return Err(SessionError::IsUs);
}
let mut sessions = self.sessions.write().unwrap();
for (_, registered_session) in sessions.iter() {
if registered_session.peer_public_key.unwrap() == peer_public_key { //already connected with a different addr
return Err(SessionError::AlreadyConnected)
}
}
for (index, contact) in self.loaded_contacts.read().unwrap().iter() {
if contact.public_key == peer_public_key { //session is a known contact. Assign the contact index to it
sessions.insert(*index, session);
return Ok(*index)
}
}
//if not a contact, increment the session_counter
let mut session_counter = self.session_counter.write().unwrap();
sessions.insert(*session_counter, session);
let r = *session_counter;
*session_counter += 1;
Ok(r)
}
pub fn connect_to(&self, ip: &str) -> Result<usize, SessionError> {
let sessions = self.sessions.read().unwrap();
for (_, s) in sessions.iter() {
if s.get_ip() == ip {
return Err(SessionError::AlreadyConnected)
}
}
drop(sessions); //release mutex
match TcpStream::connect((ip, constants::PORT.parse().unwrap())) {
Ok(stream) => {
let session = Session::new(stream);
self.do_handshake_then_add(session)
}
Err(_) => Err(SessionError::ConnectFailed)
}
}
pub fn send_to(&self, index: &usize, message: &[u8]) -> Result<(), SessionError> {
let mut sessions = self.sessions.write().unwrap();
match sessions.get_mut(index) {
Some(session) => session.encrypt_and_send(message),
None => Err(SessionError::InvalidSessionId)
}
}
pub fn start_receiver_loop(session_manager: &Arc<SessionManager>) {
let session_manager_clone = Arc::clone(session_manager);
thread::spawn(move || {
loop {
let mut dead_sessions = Vec::new();
let mut sessions = session_manager_clone.sessions.write().unwrap();
for (index, session) in sessions.iter_mut() {
let mut dead_session = false;
match session.receive_and_decrypt() {
Ok(buffer) => {
if buffer[0] == protocol::Headers::ASK_NAME {
session.encrypt_and_send(&protocol::tell_name(&session_manager_clone.identity.read().unwrap().as_ref().unwrap().name)).unwrap();
} else {
let buffer = if buffer[0] == protocol::Headers::FILE {
let file_name_len = u16::from_be_bytes([buffer[1], buffer[2]]) as usize;
let file_name = &buffer[3..3+file_name_len];
match session_manager_clone.store_file(index, &buffer[3+file_name_len..]) {
Ok(file_uuid) => {
Some([&[protocol::Headers::FILE][..], file_uuid.as_bytes(), file_name].concat())
}
Err(e) => {
print_error(e);
None
}
}
} else {
Some(buffer)
};
if buffer.is_some() {
let mut msg_saved = false;
if session_manager_clone.is_contact(index) && buffer.as_ref().unwrap()[0] != protocol::Headers::TELL_NAME {
match session_manager_clone.store_msg(&index, false, &buffer.as_ref().unwrap()) {
Ok(_) => msg_saved = true,
Err(e) => print_error(e)
}
}
let ui_connection_valid = session_manager_clone.with_ui_connection(|mut ui_connection| {
ui_connection.on_received(index, &buffer.as_ref().unwrap());
});
if !ui_connection_valid && !msg_saved {
session_manager_clone.msg_queue.lock().unwrap().push((*index, buffer.unwrap()));
}
}
}
}
Err(e) => {
if e == SessionError::BrokenPipe {
dead_session = true
} else if e != SessionError::SocketTimeout {
print_error(e);
}
}
}
if dead_session {
session_manager_clone.with_ui_connection(|mut ui_connection| {
ui_connection.on_disconnected(*index);
});
dead_sessions.push(*index);
}
}
dead_sessions.into_iter().for_each(|index| {
sessions.remove(&index);
});
drop(sessions); //release mutex
if *session_manager_clone.is_stopping.read().unwrap() {
break;
}
sleep(Duration::from_millis(constants::MUTEX_RELEASE_DELAY_MS));
}
println!("Stopping receiver thread");
});
}
pub fn start_listener(session_manager: &Arc<SessionManager>) -> Result<(), SessionError> {
let socket_v6 = Socket::new(Domain::ipv6(), Type::stream(), None).unwrap();
let socket_v4 = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
socket_v4.set_reuse_address(true).unwrap();
socket_v6.set_reuse_address(true).unwrap();
let addr_v6 = "[::1]:".to_owned()+constants::PORT;
let addr_v4 = "0.0.0.0:".to_owned()+constants::PORT;
let mut sockets = Vec::new();
match socket_v6.bind(&addr_v6.parse::<SocketAddr>().unwrap().into()) {
Ok(_) => sockets.push(socket_v6),
Err(e) => println!("Unable to bind on IPv6: {}", e)
};
match socket_v4.bind(&addr_v4.parse::<SocketAddr>().unwrap().into()) {
Ok(_) => sockets.push(socket_v4),
Err(e) => println!("Unable to bind on IPv4: {}", e)
}
if sockets.len() > 0 {
println!("Listening on port {}...", constants::PORT);
for socket in sockets {
socket.listen(256).unwrap();
socket.set_read_timeout(Some(Duration::from_millis(100))).unwrap();
let session_manager_clone = Arc::clone(session_manager);
thread::spawn(move ||{
for stream in socket.into_tcp_listener().incoming() {
match stream {
Ok(stream) => {
let session = Session::new(stream);
match session_manager_clone.do_handshake_then_add(session) {
Ok(index) => {
session_manager_clone.with_ui_connection(|mut ui_connection| {
ui_connection.on_new_session(index);
session_manager_clone.handle_new_session(&index, ui_connection);
});
}
Err(e) => {
if e != SessionError::AlreadyConnected && e != SessionError::IsUs {
print_error(e);
}
}
}
}
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
print_error(e);
}
}
}
if *session_manager_clone.is_stopping.read().unwrap() {
break;
}
}
println!("Stopping listener thread");
});
}
Ok(())
} else {
Err(SessionError::BindFailed)
}
}
pub fn handle_new_session(&self, index: &usize, mut ui_connection: MutexGuard<UiConnection>) {
if self.is_contact(index) {
match self.load_msgs(index) {
Some(msgs) => {
ui_connection.load_msgs(index, msgs);
}
None => {}
}
} else {
match self.ask_name_to(&index) {
Ok(_) => {}
Err(e) => print_error(e)
}
}
}
pub fn list_sessions(&self) -> Vec<usize> {
let sessions = self.sessions.read().unwrap();
sessions.iter().map(|t| *t.0).collect()
}
pub fn list_contacts(&self) -> Vec<(usize, String, bool)> {
self.loaded_contacts.read().unwrap().iter().map(|c| (*c.0, c.1.name.clone(), c.1.verified)).collect()
}
pub fn get_identity_uuid(&self) -> Option<Uuid> {
Some(self.identity.read().unwrap().as_ref()?.uuid)
}
pub fn get_saved_msgs(&self) -> Vec<(usize, Vec<u8>)> {
let mut msgs = Vec::new();
let mut msg_queue = self.msg_queue.lock().unwrap();
let sessions = self.sessions.read().unwrap();
for i in 0..msg_queue.len() {
let mut entry = msg_queue.remove(i);
if sessions.contains_key(&entry.0) {
msgs.push(entry);
} else {
entry.1.zeroize();
}
};
msgs
}
fn ask_name_to(&self, index: &usize) -> Result<(), SessionError> {
let mut sessions = self.sessions.write().unwrap();
match sessions.get_mut(index) {
Some(session) => {
session.encrypt_and_send(&protocol::ask_name())
},
None => Err(SessionError::InvalidSessionId)
}
}
pub fn get_peer_public_key(&self, index: &usize) -> Option<[u8; PUBLIC_KEY_LENGTH]> {
let sessions = self.sessions.read().unwrap();
let session = sessions.get(index)?;
session.peer_public_key
}
pub fn add_contact(&self, index: usize, name: String) -> Result<(), rusqlite::Error> {
let contact = self.identity.read().unwrap().as_ref().unwrap().add_contact(name, self.get_peer_public_key(&index).unwrap())?;
self.loaded_contacts.write().unwrap().insert(index, contact);
Ok(())
}
pub fn remove_contact(&self, index: &usize) -> Result<(), rusqlite::Error> {
let mut loaded_contacts = self.loaded_contacts.write().unwrap();
let result = self.identity.read().unwrap().as_ref().unwrap().remove_contact(&loaded_contacts.get(index).unwrap().uuid);
if result.is_ok() {
loaded_contacts.remove(index);
}
result
}
pub fn set_verified(&self, index: &usize) -> Result<(), rusqlite::Error> {
let mut loaded_contacts = self.loaded_contacts.write().unwrap();
let contact = loaded_contacts.get_mut(index).unwrap();
let result = self.identity.read().unwrap().as_ref().unwrap().set_verified(&contact.uuid);
if result.is_ok() {
contact.verified = true;
}
result
}
pub fn is_contact(&self, index: &usize) -> bool {
self.loaded_contacts.read().unwrap().contains_key(index)
}
pub fn load_file(&self, uuid: Uuid) -> Option<Vec<u8>> {
self.identity.read().unwrap().as_ref().unwrap().load_file(uuid)
}
pub fn store_file(&self, index: &usize, data: &[u8]) -> Result<Uuid, rusqlite::Error> {
self.identity.read().unwrap().as_ref().unwrap().store_file(match self.loaded_contacts.read().unwrap().get(index) {
Some(contact) => Some(contact.uuid),
None => None
}, data)
}
pub fn store_msg(&self, index: &usize, outgoing: bool, data: &[u8]) -> Result<(), rusqlite::Error> {
self.identity.read().unwrap().as_ref().unwrap().store_msg(&self.loaded_contacts.read().unwrap().get(index).unwrap().uuid, outgoing, data)
}
pub fn load_msgs(&self, index: &usize) -> Option<Vec<(bool, Vec<u8>)>> {
self.identity.read().unwrap().as_ref().unwrap().load_msgs(&self.loaded_contacts.read().unwrap().get(index).unwrap().uuid)
}
pub fn get_public_keys(&self, index: &usize) -> ([u8; PUBLIC_KEY_LENGTH], [u8; PUBLIC_KEY_LENGTH]) {
(self.identity.read().unwrap().as_ref().unwrap().get_public_key(), self.loaded_contacts.read().unwrap().get(index).unwrap().public_key)
}
fn clear_identity_related_data(&self){
self.loaded_contacts.write().unwrap().clear();
let mut msg_queue = self.msg_queue.lock().unwrap();
msg_queue.iter_mut().for_each(|m| m.1.zeroize());
msg_queue.clear();
}
pub fn stop(&self) {
*self.is_stopping.write().unwrap() = true;
self.set_identity(None);
*self.ui_connection.lock().unwrap() = None;
}
pub fn set_identity(&self, identity: Option<Identity>) {
let mut identity_guard = self.identity.write().unwrap();
match identity_guard.as_mut() {
Some(previous_identity) => {
previous_identity.zeroize();
self.sessions.write().unwrap().clear();
*self.session_counter.write().unwrap() = 0;
self.clear_identity_related_data();
}
None => {}
}
*identity_guard = identity;
if identity_guard.is_some() {
match identity_guard.as_ref().unwrap().load_contacts() {
Some(contacts) => {
let mut loaded_contacts = self.loaded_contacts.write().unwrap();
let mut session_counter = self.session_counter.write().unwrap();
contacts.into_iter().for_each(|contact|{
loaded_contacts.insert(*session_counter, contact);
*session_counter += 1;
})
}
None => {}
}
}
}
pub fn set_ui_connection(&self, ui_connection: &Arc<Mutex<UiConnection>>){
*self.ui_connection.lock().unwrap() = Some(ui_connection.clone());
}
pub fn new() -> SessionManager {
SessionManager {
session_counter: RwLock::new(0),
sessions: RwLock::new(HashMap::new()),
identity: RwLock::new(None),
ui_connection: Mutex::new(None),
loaded_contacts: RwLock::new(HashMap::new()),
msg_queue: Mutex::new(Vec::new()),
is_stopping: RwLock::new(false)
}
}
}