406 lines
17 KiB
Rust
406 lines
17 KiB
Rust
mod session;
|
|
pub mod protocol;
|
|
|
|
use std::{collections::HashMap, io::ErrorKind, net::{SocketAddr, TcpStream}, sync::{Arc, Mutex, MutexGuard, RwLock}, thread, thread::sleep, time::Duration};
|
|
use socket2::{Socket, Domain, Type};
|
|
use session::Session;
|
|
use strum_macros::Display;
|
|
use ed25519_dalek::PUBLIC_KEY_LENGTH;
|
|
use uuid::Uuid;
|
|
use zeroize::Zeroize;
|
|
use crate::{constants, identity::{Contact, Identity}, print_error};
|
|
use crate::ui_interface::UiConnection;
|
|
|
|
#[derive(Display, Debug, PartialEq, Eq)]
|
|
pub enum SessionError {
|
|
SocketTimeout,
|
|
ConnectionReset,
|
|
BrokenPipe,
|
|
TransmissionCorrupted,
|
|
BufferTooLarge,
|
|
ConnectFailed,
|
|
InvalidSessionId,
|
|
BindFailed,
|
|
AlreadyConnected,
|
|
IsUs,
|
|
Unknown
|
|
}
|
|
|
|
pub struct SessionManager {
|
|
session_counter: RwLock<usize>,
|
|
sessions: RwLock<HashMap<usize, Session>>,
|
|
identity: RwLock<Option<Identity>>,
|
|
ui_connection: Mutex<Option<Arc<Mutex<UiConnection>>>>,
|
|
loaded_contacts: RwLock<HashMap<usize, Contact>>,
|
|
msg_queue: Mutex<Vec<(usize, Vec<u8>)>>,
|
|
is_stopping: RwLock<bool>
|
|
}
|
|
|
|
impl SessionManager {
|
|
|
|
fn with_ui_connection<F>(&self, f: F) -> bool where F: Fn(MutexGuard<UiConnection>) {
|
|
let mut ui_connection_opt = self.ui_connection.lock().unwrap();
|
|
let ui_connection = ui_connection_opt.as_mut().unwrap().lock().unwrap();
|
|
if ui_connection.is_valid {
|
|
f(ui_connection);
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
fn do_handshake_then_add(&self, mut session: Session) -> Result<usize, SessionError>{
|
|
let identity_opt = self.identity.read().unwrap();
|
|
let identity = identity_opt.as_ref().unwrap();
|
|
session.do_handshake(identity)?;
|
|
let peer_public_key = session.peer_public_key.unwrap();
|
|
if identity.get_public_key() == peer_public_key { //did handshake with the same Identity
|
|
return Err(SessionError::IsUs);
|
|
}
|
|
let mut sessions = self.sessions.write().unwrap();
|
|
for (_, registered_session) in sessions.iter() {
|
|
if registered_session.peer_public_key.unwrap() == peer_public_key { //already connected with a different addr
|
|
return Err(SessionError::AlreadyConnected)
|
|
}
|
|
}
|
|
for (index, contact) in self.loaded_contacts.read().unwrap().iter() {
|
|
if contact.public_key == peer_public_key { //session is a known contact. Assign the contact index to it
|
|
sessions.insert(*index, session);
|
|
return Ok(*index)
|
|
}
|
|
}
|
|
//if not a contact, increment the session_counter
|
|
let mut session_counter = self.session_counter.write().unwrap();
|
|
sessions.insert(*session_counter, session);
|
|
let r = *session_counter;
|
|
*session_counter += 1;
|
|
Ok(r)
|
|
}
|
|
|
|
pub fn connect_to(&self, ip: &str) -> Result<usize, SessionError> {
|
|
let sessions = self.sessions.read().unwrap();
|
|
for (_, s) in sessions.iter() {
|
|
if s.get_ip() == ip {
|
|
return Err(SessionError::AlreadyConnected)
|
|
}
|
|
}
|
|
drop(sessions); //release mutex
|
|
match TcpStream::connect((ip, constants::PORT.parse().unwrap())) {
|
|
Ok(stream) => {
|
|
let session = Session::new(stream);
|
|
self.do_handshake_then_add(session)
|
|
}
|
|
Err(_) => Err(SessionError::ConnectFailed)
|
|
}
|
|
}
|
|
|
|
pub fn send_to(&self, index: &usize, message: &[u8]) -> Result<(), SessionError> {
|
|
let mut sessions = self.sessions.write().unwrap();
|
|
match sessions.get_mut(index) {
|
|
Some(session) => session.encrypt_and_send(message),
|
|
None => Err(SessionError::InvalidSessionId)
|
|
}
|
|
}
|
|
|
|
pub fn start_receiver_loop(session_manager: &Arc<SessionManager>) {
|
|
let session_manager_clone = Arc::clone(session_manager);
|
|
thread::spawn(move || {
|
|
loop {
|
|
let mut dead_sessions = Vec::new();
|
|
let mut sessions = session_manager_clone.sessions.write().unwrap();
|
|
for (index, session) in sessions.iter_mut() {
|
|
let mut dead_session = false;
|
|
match session.receive_and_decrypt() {
|
|
Ok(buffer) => {
|
|
if buffer[0] == protocol::Headers::ASK_NAME {
|
|
session.encrypt_and_send(&protocol::tell_name(&session_manager_clone.identity.read().unwrap().as_ref().unwrap().name)).unwrap();
|
|
} else {
|
|
let buffer = if buffer[0] == protocol::Headers::FILE {
|
|
let file_name_len = u16::from_be_bytes([buffer[1], buffer[2]]) as usize;
|
|
let file_name = &buffer[3..3+file_name_len];
|
|
match session_manager_clone.store_file(index, &buffer[3+file_name_len..]) {
|
|
Ok(file_uuid) => {
|
|
Some([&[protocol::Headers::FILE][..], file_uuid.as_bytes(), file_name].concat())
|
|
}
|
|
Err(e) => {
|
|
print_error(e);
|
|
None
|
|
}
|
|
}
|
|
} else {
|
|
Some(buffer)
|
|
};
|
|
if buffer.is_some() {
|
|
let mut msg_saved = false;
|
|
if session_manager_clone.is_contact(index) && buffer.as_ref().unwrap()[0] != protocol::Headers::TELL_NAME {
|
|
match session_manager_clone.store_msg(&index, false, &buffer.as_ref().unwrap()) {
|
|
Ok(_) => msg_saved = true,
|
|
Err(e) => print_error(e)
|
|
}
|
|
}
|
|
let ui_connection_valid = session_manager_clone.with_ui_connection(|mut ui_connection| {
|
|
ui_connection.on_received(index, &buffer.as_ref().unwrap());
|
|
});
|
|
if !ui_connection_valid && !msg_saved {
|
|
session_manager_clone.msg_queue.lock().unwrap().push((*index, buffer.unwrap()));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
if e == SessionError::BrokenPipe {
|
|
dead_session = true
|
|
} else if e != SessionError::SocketTimeout {
|
|
print_error(e);
|
|
}
|
|
}
|
|
}
|
|
if dead_session {
|
|
session_manager_clone.with_ui_connection(|mut ui_connection| {
|
|
ui_connection.on_disconnected(*index);
|
|
});
|
|
dead_sessions.push(*index);
|
|
}
|
|
}
|
|
dead_sessions.into_iter().for_each(|index| {
|
|
sessions.remove(&index);
|
|
});
|
|
drop(sessions); //release mutex
|
|
if *session_manager_clone.is_stopping.read().unwrap() {
|
|
break;
|
|
}
|
|
sleep(Duration::from_millis(constants::MUTEX_RELEASE_DELAY_MS));
|
|
}
|
|
println!("Stopping receiver thread");
|
|
});
|
|
}
|
|
|
|
pub fn start_listener(session_manager: &Arc<SessionManager>) -> Result<(), SessionError> {
|
|
let socket_v6 = Socket::new(Domain::ipv6(), Type::stream(), None).unwrap();
|
|
let socket_v4 = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
|
|
socket_v4.set_reuse_address(true).unwrap();
|
|
socket_v6.set_reuse_address(true).unwrap();
|
|
let addr_v6 = "[::1]:".to_owned()+constants::PORT;
|
|
let addr_v4 = "0.0.0.0:".to_owned()+constants::PORT;
|
|
let mut sockets = Vec::new();
|
|
match socket_v6.bind(&addr_v6.parse::<SocketAddr>().unwrap().into()) {
|
|
Ok(_) => sockets.push(socket_v6),
|
|
Err(e) => println!("Unable to bind on IPv6: {}", e)
|
|
};
|
|
match socket_v4.bind(&addr_v4.parse::<SocketAddr>().unwrap().into()) {
|
|
Ok(_) => sockets.push(socket_v4),
|
|
Err(e) => println!("Unable to bind on IPv4: {}", e)
|
|
}
|
|
if sockets.len() > 0 {
|
|
println!("Listening on port {}...", constants::PORT);
|
|
for socket in sockets {
|
|
socket.listen(256).unwrap();
|
|
socket.set_read_timeout(Some(Duration::from_millis(100))).unwrap();
|
|
let session_manager_clone = Arc::clone(session_manager);
|
|
thread::spawn(move ||{
|
|
for stream in socket.into_tcp_listener().incoming() {
|
|
match stream {
|
|
Ok(stream) => {
|
|
let session = Session::new(stream);
|
|
match session_manager_clone.do_handshake_then_add(session) {
|
|
Ok(index) => {
|
|
session_manager_clone.with_ui_connection(|mut ui_connection| {
|
|
ui_connection.on_new_session(index);
|
|
session_manager_clone.handle_new_session(&index, ui_connection);
|
|
});
|
|
}
|
|
Err(e) => {
|
|
if e != SessionError::AlreadyConnected && e != SessionError::IsUs {
|
|
print_error(e);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
if e.kind() != ErrorKind::WouldBlock {
|
|
print_error(e);
|
|
}
|
|
}
|
|
}
|
|
if *session_manager_clone.is_stopping.read().unwrap() {
|
|
break;
|
|
}
|
|
}
|
|
println!("Stopping listener thread");
|
|
});
|
|
}
|
|
Ok(())
|
|
} else {
|
|
Err(SessionError::BindFailed)
|
|
}
|
|
}
|
|
|
|
pub fn handle_new_session(&self, index: &usize, mut ui_connection: MutexGuard<UiConnection>) {
|
|
if self.is_contact(index) {
|
|
match self.load_msgs(index) {
|
|
Some(msgs) => {
|
|
ui_connection.load_msgs(index, msgs);
|
|
}
|
|
None => {}
|
|
}
|
|
} else {
|
|
match self.ask_name_to(&index) {
|
|
Ok(_) => {}
|
|
Err(e) => print_error(e)
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn list_sessions(&self) -> Vec<usize> {
|
|
let sessions = self.sessions.read().unwrap();
|
|
sessions.iter().map(|t| *t.0).collect()
|
|
}
|
|
|
|
pub fn list_contacts(&self) -> Vec<(usize, String, bool)> {
|
|
self.loaded_contacts.read().unwrap().iter().map(|c| (*c.0, c.1.name.clone(), c.1.verified)).collect()
|
|
}
|
|
|
|
pub fn get_identity_uuid(&self) -> Option<Uuid> {
|
|
Some(self.identity.read().unwrap().as_ref()?.uuid)
|
|
}
|
|
|
|
pub fn get_saved_msgs(&self) -> Vec<(usize, Vec<u8>)> {
|
|
let mut msgs = Vec::new();
|
|
let mut msg_queue = self.msg_queue.lock().unwrap();
|
|
let sessions = self.sessions.read().unwrap();
|
|
for i in 0..msg_queue.len() {
|
|
let mut entry = msg_queue.remove(i);
|
|
if sessions.contains_key(&entry.0) {
|
|
msgs.push(entry);
|
|
} else {
|
|
entry.1.zeroize();
|
|
}
|
|
};
|
|
msgs
|
|
}
|
|
|
|
fn ask_name_to(&self, index: &usize) -> Result<(), SessionError> {
|
|
let mut sessions = self.sessions.write().unwrap();
|
|
match sessions.get_mut(index) {
|
|
Some(session) => {
|
|
session.encrypt_and_send(&protocol::ask_name())
|
|
},
|
|
None => Err(SessionError::InvalidSessionId)
|
|
}
|
|
}
|
|
|
|
pub fn get_peer_public_key(&self, index: &usize) -> Option<[u8; PUBLIC_KEY_LENGTH]> {
|
|
let sessions = self.sessions.read().unwrap();
|
|
let session = sessions.get(index)?;
|
|
session.peer_public_key
|
|
}
|
|
|
|
pub fn add_contact(&self, index: usize, name: String) -> Result<(), rusqlite::Error> {
|
|
let contact = self.identity.read().unwrap().as_ref().unwrap().add_contact(name, self.get_peer_public_key(&index).unwrap())?;
|
|
self.loaded_contacts.write().unwrap().insert(index, contact);
|
|
Ok(())
|
|
}
|
|
|
|
pub fn remove_contact(&self, index: &usize) -> Result<(), rusqlite::Error> {
|
|
let mut loaded_contacts = self.loaded_contacts.write().unwrap();
|
|
let result = self.identity.read().unwrap().as_ref().unwrap().remove_contact(&loaded_contacts.get(index).unwrap().uuid);
|
|
if result.is_ok() {
|
|
loaded_contacts.remove(index);
|
|
}
|
|
result
|
|
}
|
|
|
|
pub fn set_verified(&self, index: &usize) -> Result<(), rusqlite::Error> {
|
|
let mut loaded_contacts = self.loaded_contacts.write().unwrap();
|
|
let contact = loaded_contacts.get_mut(index).unwrap();
|
|
let result = self.identity.read().unwrap().as_ref().unwrap().set_verified(&contact.uuid);
|
|
if result.is_ok() {
|
|
contact.verified = true;
|
|
}
|
|
result
|
|
}
|
|
|
|
pub fn is_contact(&self, index: &usize) -> bool {
|
|
self.loaded_contacts.read().unwrap().contains_key(index)
|
|
}
|
|
|
|
pub fn load_file(&self, uuid: Uuid) -> Option<Vec<u8>> {
|
|
self.identity.read().unwrap().as_ref().unwrap().load_file(uuid)
|
|
}
|
|
|
|
pub fn store_file(&self, index: &usize, data: &[u8]) -> Result<Uuid, rusqlite::Error> {
|
|
self.identity.read().unwrap().as_ref().unwrap().store_file(match self.loaded_contacts.read().unwrap().get(index) {
|
|
Some(contact) => Some(contact.uuid),
|
|
None => None
|
|
}, data)
|
|
}
|
|
|
|
pub fn store_msg(&self, index: &usize, outgoing: bool, data: &[u8]) -> Result<(), rusqlite::Error> {
|
|
self.identity.read().unwrap().as_ref().unwrap().store_msg(&self.loaded_contacts.read().unwrap().get(index).unwrap().uuid, outgoing, data)
|
|
}
|
|
|
|
pub fn load_msgs(&self, index: &usize) -> Option<Vec<(bool, Vec<u8>)>> {
|
|
self.identity.read().unwrap().as_ref().unwrap().load_msgs(&self.loaded_contacts.read().unwrap().get(index).unwrap().uuid)
|
|
}
|
|
|
|
pub fn get_public_keys(&self, index: &usize) -> ([u8; PUBLIC_KEY_LENGTH], [u8; PUBLIC_KEY_LENGTH]) {
|
|
(self.identity.read().unwrap().as_ref().unwrap().get_public_key(), self.loaded_contacts.read().unwrap().get(index).unwrap().public_key)
|
|
}
|
|
|
|
fn clear_identity_related_data(&self){
|
|
self.loaded_contacts.write().unwrap().clear();
|
|
let mut msg_queue = self.msg_queue.lock().unwrap();
|
|
msg_queue.iter_mut().for_each(|m| m.1.zeroize());
|
|
msg_queue.clear();
|
|
}
|
|
|
|
pub fn stop(&self) {
|
|
*self.is_stopping.write().unwrap() = true;
|
|
self.set_identity(None);
|
|
*self.ui_connection.lock().unwrap() = None;
|
|
}
|
|
|
|
pub fn set_identity(&self, identity: Option<Identity>) {
|
|
let mut identity_guard = self.identity.write().unwrap();
|
|
match identity_guard.as_mut() {
|
|
Some(previous_identity) => {
|
|
previous_identity.zeroize();
|
|
self.sessions.write().unwrap().clear();
|
|
*self.session_counter.write().unwrap() = 0;
|
|
self.clear_identity_related_data();
|
|
}
|
|
None => {}
|
|
}
|
|
*identity_guard = identity;
|
|
if identity_guard.is_some() {
|
|
match identity_guard.as_ref().unwrap().load_contacts() {
|
|
Some(contacts) => {
|
|
let mut loaded_contacts = self.loaded_contacts.write().unwrap();
|
|
let mut session_counter = self.session_counter.write().unwrap();
|
|
contacts.into_iter().for_each(|contact|{
|
|
loaded_contacts.insert(*session_counter, contact);
|
|
*session_counter += 1;
|
|
})
|
|
}
|
|
None => {}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn set_ui_connection(&self, ui_connection: &Arc<Mutex<UiConnection>>){
|
|
*self.ui_connection.lock().unwrap() = Some(ui_connection.clone());
|
|
}
|
|
|
|
pub fn new() -> SessionManager {
|
|
SessionManager {
|
|
session_counter: RwLock::new(0),
|
|
sessions: RwLock::new(HashMap::new()),
|
|
identity: RwLock::new(None),
|
|
ui_connection: Mutex::new(None),
|
|
loaded_contacts: RwLock::new(HashMap::new()),
|
|
msg_queue: Mutex::new(Vec::new()),
|
|
is_stopping: RwLock::new(false)
|
|
}
|
|
}
|
|
}
|