From a988c537e2ee84c12468a4f2bc1635d92683f1ac Mon Sep 17 00:00:00 2001 From: Hardcore Sushi Date: Wed, 5 May 2021 12:30:37 +0200 Subject: [PATCH] Session split --- src/session_manager/mod.rs | 35 +++-- src/session_manager/session.rs | 245 ++++++++++++++++++++------------- 2 files changed, 174 insertions(+), 106 deletions(-) diff --git a/src/session_manager/mod.rs b/src/session_manager/mod.rs index 1fd07f5..fd892ca 100644 --- a/src/session_manager/mod.rs +++ b/src/session_manager/mod.rs @@ -12,6 +12,8 @@ use platform_dirs::UserDirs; use crate::{constants, crypto, discovery, identity::{Contact, Identity}, print_error, utils::{get_unix_timestamp, get_not_used_path}}; use crate::ui_interface::UiConnection; +use self::session::SessionWrite; + #[derive(Display, Debug, PartialEq, Eq)] pub enum SessionError { ConnectionReset, @@ -173,8 +175,8 @@ impl SessionManager { self.not_seen.write().unwrap().retain(|x| x != session_id); } - async fn send_msg(&self, session_id: usize, session: &mut Session, buff: &[u8], aborted: &mut bool, file_ack_sender: Option<&Sender>) -> Result<(), SessionError> { - session.encrypt_and_send(&buff).await?; + async fn send_msg(&self, session_id: usize, session_write: &mut SessionWrite, buff: &[u8], aborted: &mut bool, file_ack_sender: Option<&Sender>) -> Result<(), SessionError> { + session_write.encrypt_and_send(&buff).await?; if buff[0] == protocol::Headers::ACCEPT_LARGE_FILE { self.sessions.write().unwrap().get_mut(&session_id).unwrap().file_download.as_mut().unwrap().state = FileState::ACCEPTED; } else if buff[0] == protocol::Headers::ABORT_FILE_TRANSFER { @@ -192,7 +194,7 @@ impl SessionManager { Ok(()) } - async fn session_worker(&self, session_id: usize, mut receiver: Receiver, mut session: Session) { + async fn session_worker(&self, session_id: usize, mut receiver: Receiver, session: Session) { //used when we receive large file let mut local_file_path = None; let mut local_file_handle = None; @@ -201,18 +203,23 @@ impl SessionManager { let mut file_ack_sender: Option> = None; let mut msg_queue = Vec::new(); let mut aborted = false; + + let (session_read, mut session_write) = session.into_spit().unwrap(); + let receiving = session_read.receive_and_decrypt(); + tokio::pin!(receiving); loop { tokio::select! { - buffer = session.receive_and_decrypt() => { - match buffer { - Ok(buffer) => { + result = &mut receiving => { + match result { + Ok((session_read, buffer)) => { + receiving.set(session_read.receive_and_decrypt()); match buffer[0] { protocol::Headers::ASK_NAME => { let name = { self.identity.read().unwrap().as_ref().and_then(|identity| Some(identity.name.clone())) }; if name.is_some() { //can be None if we log out just before locking the identity mutex - if let Err(e) = session.encrypt_and_send(&protocol::tell_name(&name.unwrap())).await { + if let Err(e) = session_write.encrypt_and_send(&protocol::tell_name(&name.unwrap())).await { print_error!(e); break; } @@ -254,7 +261,7 @@ impl SessionManager { ui_connection.on_ask_large_file(&session_id, file_size, &file_name, download_dir.to_str().unwrap()); }) } - } else if let Err(e) = session.encrypt_and_send(&[protocol::Headers::ABORT_FILE_TRANSFER]).await { + } else if let Err(e) = session_write.encrypt_and_send(&[protocol::Headers::ABORT_FILE_TRANSFER]).await { print_error!(e); break; } @@ -303,7 +310,7 @@ impl SessionManager { local_file_handle = None; } } - if let Err(e) = session.encrypt_and_send(&[protocol::Headers::ACK_CHUNK]).await { + if let Err(e) = session_write.encrypt_and_send(&[protocol::Headers::ACK_CHUNK]).await { print_error!(e); break; } @@ -319,7 +326,7 @@ impl SessionManager { self.sessions.write().unwrap().get_mut(&session_id).unwrap().file_download = None; local_file_path = None; local_file_handle = None; - if let Err(e) = session.encrypt_and_send(&[protocol::Headers::ABORT_FILE_TRANSFER]).await { + if let Err(e) = session_write.encrypt_and_send(&[protocol::Headers::ABORT_FILE_TRANSFER]).await { print_error!(e); break; } @@ -405,7 +412,7 @@ impl SessionManager { SessionCommand::Send { buff } => { //don't send msg if we already encrypted a file chunk (keep PSEC nonces synchronized) if next_chunk.is_none() || aborted { - if let Err(e) = self.send_msg(session_id, &mut session, &buff, &mut aborted, file_ack_sender.as_ref()).await { + if let Err(e) = self.send_msg(session_id, &mut session_write, &buff, &mut aborted, file_ack_sender.as_ref()).await { print_error!(e); break; } @@ -413,16 +420,16 @@ impl SessionManager { msg_queue.push(buff); } } - SessionCommand::EncryptFileChunk { plain_text } => next_chunk = Some(session.encrypt(&plain_text)), + SessionCommand::EncryptFileChunk { plain_text } => next_chunk = Some(session_write.encrypt(&plain_text)), SessionCommand::SendEncryptedFileChunk { sender } => { if let Some(chunk) = next_chunk.as_ref() { - match session.socket_write(chunk).await { + match session_write.socket_write(chunk).await { Ok(_) => { file_ack_sender = Some(sender); //once the pre-encrypted chunk is sent, we can send the pending messages while msg_queue.len() > 0 { let msg = msg_queue.remove(0); - if let Err(e) = self.send_msg(session_id, &mut session, &msg, &mut aborted, file_ack_sender.as_ref()).await { + if let Err(e) = self.send_msg(session_id, &mut session_write, &msg, &mut aborted, file_ack_sender.as_ref()).await { print_error!(e); break; } diff --git a/src/session_manager/session.rs b/src/session_manager/session.rs index 2a4888b..26626bc 100644 --- a/src/session_manager/session.rs +++ b/src/session_manager/session.rs @@ -1,5 +1,5 @@ use std::{convert::TryInto, io::ErrorKind, net::IpAddr}; -use tokio::{net::TcpStream, io::{AsyncReadExt, AsyncWriteExt}}; +use tokio::{io::{AsyncReadExt, AsyncWriteExt}, net::{TcpStream, tcp::{OwnedReadHalf, OwnedWriteHalf}}}; use ed25519_dalek; use ed25519_dalek::{ed25519::signature::Signature, Verifier, PUBLIC_KEY_LENGTH, SIGNATURE_LENGTH}; use x25519_dalek; @@ -16,6 +16,135 @@ const RANDOM_LEN: usize = 64; const MESSAGE_LEN_LEN: usize = 4; type MessageLenType = u32; +async fn socket_read(reader: &mut T, buff: &mut [u8]) -> Result { + match reader.read(buff).await { + Ok(read) => { + if read > 0 { + Ok(read) + } else { + Err(SessionError::BrokenPipe) + } + } + Err(e) => { + match e.kind() { + ErrorKind::ConnectionReset => Err(SessionError::ConnectionReset), + _ => { + print_error!("Receive error ({:?}): {}", e.kind(), e); + Err(SessionError::Unknown) + } + } + } + } +} + +async fn socket_write(writer: &mut T, buff: &[u8]) -> Result<(), SessionError> { + match writer.write_all(buff).await { + Ok(_) => Ok(()), + Err(e) => Err(match e.kind() { + ErrorKind::BrokenPipe => SessionError::BrokenPipe, + ErrorKind::ConnectionReset => SessionError::ConnectionReset, + _ => { + print_error!("Send error ({:?}): {}", e.kind(), e); + SessionError::Unknown + } + }) + } +} + +fn pad(plain_text: &[u8]) -> Vec { + let encoded_msg_len = (plain_text.len() as MessageLenType).to_be_bytes(); + let msg_len = plain_text.len()+encoded_msg_len.len(); + let mut len = 1000; + while len < msg_len { + len *= 2; + } + let mut output = Vec::from(encoded_msg_len); + output.reserve(len); + output.extend(plain_text); + output.resize(len, 0); + OsRng.fill_bytes(&mut output[msg_len..]); + output +} + +fn unpad(input: Vec) -> Vec { + let msg_len = MessageLenType::from_be_bytes(input[0..MESSAGE_LEN_LEN].try_into().unwrap()) as usize; + Vec::from(&input[MESSAGE_LEN_LEN..MESSAGE_LEN_LEN+msg_len]) +} + +fn encrypt(local_cipher: &Aes128Gcm, local_iv: &[u8], local_counter: &mut usize, plain_text: &[u8]) -> Vec { + let padded_msg = pad(plain_text); + let cipher_len = (padded_msg.len() as MessageLenType).to_be_bytes(); + let payload = Payload { + msg: &padded_msg, + aad: &cipher_len + }; + let nonce = iv_to_nonce(local_iv, local_counter); + let cipher_text = local_cipher.encrypt(Nonce::from_slice(&nonce), payload).unwrap(); + [&cipher_len, cipher_text.as_slice()].concat() +} + +pub async fn encrypt_and_send(writer: &mut T, local_cipher: &Aes128Gcm, local_iv: &[u8], local_counter: &mut usize, plain_text: &[u8]) -> Result<(), SessionError> { + let cipher_text = encrypt(local_cipher, local_iv, local_counter, plain_text); + socket_write(writer, &cipher_text).await +} + +pub struct SessionRead { + read_half: OwnedReadHalf, + peer_cipher: Aes128Gcm, + peer_iv: [u8; IV_LEN], + peer_counter: usize, +} + +impl SessionRead { + async fn socket_read(&mut self, buff: &mut [u8]) -> Result { + socket_read(&mut self.read_half, buff).await + } + + pub async fn receive_and_decrypt(mut self) -> Result<(SessionRead, Vec), SessionError> { + let mut message_len = [0; MESSAGE_LEN_LEN]; + self.socket_read(&mut message_len).await?; + let recv_len = MessageLenType::from_be_bytes(message_len) as usize + AES_TAG_LEN; + if recv_len <= Session::MAX_RECV_SIZE { + let mut cipher_text = vec![0; recv_len]; + let mut read = 0; + while read < recv_len { + read += self.socket_read(&mut cipher_text[read..]).await?; + } + let peer_nonce = iv_to_nonce(&self.peer_iv, &mut self.peer_counter); + let payload = Payload { + msg: &cipher_text, + aad: &message_len + }; + match self.peer_cipher.decrypt(Nonce::from_slice(&peer_nonce), payload) { + Ok(plain_text) => Ok((self, unpad(plain_text))), + Err(_) => Err(SessionError::TransmissionCorrupted) + } + } else { + print_error!("Buffer too large: {} B", recv_len); + Err(SessionError::BufferTooLarge) + } + } +} + +pub struct SessionWrite { + write_half: OwnedWriteHalf, + local_cipher: Aes128Gcm, + local_iv: [u8; IV_LEN], + local_counter: usize, +} + +impl SessionWrite { + pub async fn encrypt_and_send(&mut self, plain_text: &[u8]) -> Result<(), SessionError> { + encrypt_and_send(&mut self.write_half, &self.local_cipher, &self.local_iv, &mut self.local_counter, plain_text).await + } + pub fn encrypt(&mut self, plain_text: &[u8]) -> Vec { + encrypt(&self.local_cipher, &self.local_iv, &mut self.local_counter, plain_text) + } + pub async fn socket_write(&mut self, cipher_text: &[u8]) -> Result<(), SessionError> { + socket_write(&mut self.write_half, cipher_text).await + } +} + pub struct Session { stream: TcpStream, handshake_sent_buff: Vec, @@ -48,43 +177,34 @@ impl Session { } } + pub fn into_spit(self) -> Option<(SessionRead, SessionWrite)> { + let (read_half, write_half) = self.stream.into_split(); + Some(( + SessionRead { + read_half, + peer_cipher: self.peer_cipher?, + peer_iv: self.peer_iv?, + peer_counter: self.peer_counter, + }, + SessionWrite { + write_half, + local_cipher: self.local_cipher?, + local_iv: self.local_iv?, + local_counter: self.local_counter, + } + )) + } + pub fn get_ip(&self) -> IpAddr { self.stream.peer_addr().unwrap().ip() } async fn socket_read(&mut self, buff: &mut [u8]) -> Result { - match self.stream.read(buff).await { - Ok(read) => { - if read > 0 { - Ok(read) - } else { - Err(SessionError::BrokenPipe) - } - } - Err(e) => { - match e.kind() { - ErrorKind::ConnectionReset => Err(SessionError::ConnectionReset), - _ => { - print_error!("Receive error ({:?}): {}", e.kind(), e); - Err(SessionError::Unknown) - } - } - } - } + socket_read(&mut self.stream, buff).await } pub async fn socket_write(&mut self, buff: &[u8]) -> Result<(), SessionError> { - match self.stream.write_all(buff).await { - Ok(_) => Ok(()), - Err(e) => Err(match e.kind() { - ErrorKind::BrokenPipe => SessionError::BrokenPipe, - ErrorKind::ConnectionReset => SessionError::ConnectionReset, - _ => { - print_error!("Send error ({:?}): {}", e.kind(), e); - SessionError::Unknown - } - }) - } + socket_write(&mut self.stream, buff).await } async fn handshake_read(&mut self, buff: &mut [u8]) -> Result<(), SessionError> { @@ -193,67 +313,8 @@ impl Session { } Err(SessionError::TransmissionCorrupted) } - - fn random_pad(message: &[u8]) -> Vec { - let encoded_msg_len = (message.len() as MessageLenType).to_be_bytes(); - let msg_len = message.len()+encoded_msg_len.len(); - let mut len = 1000; - while len < msg_len { - len *= 2; - } - let mut output = Vec::from(encoded_msg_len); - output.reserve(len); - output.extend(message); - output.resize(len, 0); - OsRng.fill_bytes(&mut output[msg_len..]); - output - } - - fn unpad(input: Vec) -> Vec { - let msg_len = MessageLenType::from_be_bytes(input[0..MESSAGE_LEN_LEN].try_into().unwrap()) as usize; - Vec::from(&input[MESSAGE_LEN_LEN..MESSAGE_LEN_LEN+msg_len]) - } - - pub fn encrypt(&mut self, message: &[u8]) -> Vec { - let padded_msg = Session::random_pad(message); - let cipher_len = (padded_msg.len() as MessageLenType).to_be_bytes(); - let payload = Payload { - msg: &padded_msg, - aad: &cipher_len - }; - let nonce = iv_to_nonce(&self.local_iv.unwrap(), &mut self.local_counter); - let cipher_text = self.local_cipher.as_ref().unwrap().encrypt(Nonce::from_slice(&nonce), payload).unwrap(); - [&cipher_len, cipher_text.as_slice()].concat() - } - pub async fn encrypt_and_send(&mut self, message: &[u8]) -> Result<(), SessionError> { - let cipher_text = self.encrypt(message); - self.socket_write(&cipher_text).await + pub async fn encrypt_and_send(&mut self, plain_text: &[u8]) -> Result<(), SessionError> { + encrypt_and_send(&mut self.stream, self.local_cipher.as_ref().unwrap(), self.local_iv.as_ref().unwrap(), &mut self.local_counter, plain_text).await } - - pub async fn receive_and_decrypt(&mut self) -> Result, SessionError> { - let mut message_len = [0; MESSAGE_LEN_LEN]; - self.socket_read(&mut message_len).await?; - let recv_len = MessageLenType::from_be_bytes(message_len) as usize + AES_TAG_LEN; - if recv_len <= Session::MAX_RECV_SIZE { - let mut cipher_text = vec![0; recv_len]; - let mut read = 0; - while read < recv_len { - read += self.socket_read(&mut cipher_text[read..]).await?; - } - let peer_nonce = iv_to_nonce(&self.peer_iv.unwrap(), &mut self.peer_counter); - let peer_cipher = self.peer_cipher.as_ref().unwrap(); - let payload = Payload { - msg: &cipher_text, - aad: &message_len - }; - match peer_cipher.decrypt(Nonce::from_slice(&peer_nonce), payload) { - Ok(plain_text) => Ok(Session::unpad(plain_text)), - Err(_) => Err(SessionError::TransmissionCorrupted) - } - } else { - print_error!("Buffer too large: {} B", recv_len); - Err(SessionError::BufferTooLarge) - } - } -} +} \ No newline at end of file