diff --git a/src/lib.rs b/src/lib.rs index da5345d..be9828e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,9 +60,6 @@ const RANDOM_LEN: usize = 64; const MESSAGE_LEN_LEN: usize = 4; type MessageLenType = u32; -const DEFAULT_PADDED_MAX_SIZE: usize = 32768000; -const DEFAULT_MAX_RECV_SIZE: usize = MESSAGE_LEN_LEN + DEFAULT_PADDED_MAX_SIZE + AES_TAG_LEN; - /// The length of a PSEC public key, in bytes. pub const PUBLIC_KEY_LENGTH: usize = ed25519_dalek::PUBLIC_KEY_LENGTH; @@ -187,27 +184,28 @@ async fn encrypt_and_send(writer: &mut T, local_cipher send(writer, &cipher_text).await } -async fn receive_and_decrypt(reader: &mut T, peer_cipher: &Aes128Gcm, peer_iv: &[u8], peer_counter: &mut usize, max_recv_size: usize) -> Result, PsecError> { +async fn receive_and_decrypt(reader: &mut T, peer_cipher: &Aes128Gcm, peer_iv: &[u8], peer_counter: &mut usize, max_recv_size: Option) -> Result, PsecError> { let mut message_len = [0; MESSAGE_LEN_LEN]; receive(reader, &mut message_len).await?; let recv_len = MessageLenType::from_be_bytes(message_len) as usize + AES_TAG_LEN; - if recv_len <= max_recv_size { - let mut cipher_text = vec![0; recv_len]; - let mut read = 0; - while read < recv_len { - read += receive(reader, &mut cipher_text[read..]).await?; + if let Some(max_recv_size) = max_recv_size { + if recv_len > max_recv_size { + return Err(PsecError::BufferTooLarge); } - let peer_nonce = crypto::iv_to_nonce(peer_iv, peer_counter); - let payload = Payload { - msg: &cipher_text, - aad: &message_len - }; - match peer_cipher.decrypt(Nonce::from_slice(&peer_nonce), payload) { - Ok(plain_text) => unpad(plain_text), - Err(_) => Err(PsecError::TransmissionCorrupted) - } - } else { - Err(PsecError::BufferTooLarge) + } + let mut cipher_text = vec![0; recv_len]; + let mut read = 0; + while read < recv_len { + read += receive(reader, &mut cipher_text[read..]).await?; + } + let peer_nonce = crypto::iv_to_nonce(peer_iv, peer_counter); + let payload = Payload { + msg: &cipher_text, + aad: &message_len + }; + match peer_cipher.decrypt(Nonce::from_slice(&peer_nonce), payload) { + Ok(plain_text) => unpad(plain_text), + Err(_) => Err(PsecError::TransmissionCorrupted) } } @@ -231,9 +229,7 @@ pub trait PsecReader { Any received buffer larger than this value will be discarded and a [`BufferTooLarge`](PsecError::BufferTooLarge) error will be returned. Then, the PSEC session should be closed to prevent any DOS attacks. - If `is_raw_size` is set to `true`, the specified `size` will correspond to the maximum encrypted buffer size, including potential padding. Otherwise, the maximum buffer size will correspond to the length of a message of this size after applying padding and encryption. - - The default value is 32 768 020, which allows to receive any messages under 32 768 000 bytes.*/ + If `is_raw_size` is set to `true`, the specified `size` will correspond to the maximum encrypted buffer size, including potential padding. Otherwise, the maximum buffer size will correspond to the length of a message of this size after applying padding and encryption.*/ fn set_max_recv_size(&mut self, size: usize, is_raw_size: bool); /** Read then decrypt from a PSEC session. @@ -282,7 +278,7 @@ pub trait PsecReader { pub trait PsecWriter { /** Encrypt then send through a PSEC session. - `use_padding` specifies whether or not the plain text length should be obfuscated with padding. Enabling padding will use more network bandwidth: all messages below 1KB will be padded to 1KB and then the padded length doubles at each step (2KB, 4KB, 8KB...). When sending a buffer of 17MB, it will padded to 32MB. + `use_padding` specifies whether or not the plain text length should be obfuscated with padding. Enabling padding will use more network bandwidth: all messages below 1KB will be padded to 1KB and then the padded length doubles at each step (2KB, 4KB, 8KB...). For example, a buffer of 1.5KB will be padded to 2KB, and a buffer of 3KB will be padded to 4KB. # Panic Panics if the PSEC handshake is not finished and successful. @@ -325,7 +321,7 @@ pub struct SessionReadHalf { peer_cipher: Aes128Gcm, peer_iv: [u8; crypto::IV_LEN], peer_counter: usize, - max_recv_size: usize, + max_recv_size: Option, } #[cfg(feature = "split")] @@ -344,7 +340,7 @@ impl Debug for SessionReadHalf { #[async_trait] impl PsecReader for SessionReadHalf { fn set_max_recv_size(&mut self, size: usize, is_raw_size: bool) { - self.max_recv_size = compute_max_recv_size(size, is_raw_size) + self.max_recv_size = Some(compute_max_recv_size(size, is_raw_size)); } async fn receive_and_decrypt(&mut self) -> Result, PsecError> { receive_and_decrypt(&mut self.read_half, &self.peer_cipher, &self.peer_iv, &mut self.peer_counter, self.max_recv_size).await @@ -397,7 +393,7 @@ pub struct Session { peer_cipher: Option, peer_iv: Option<[u8; crypto::IV_LEN]>, peer_counter: usize, - max_recv_size: usize, + max_recv_size: Option, /** The public key of the remote peer. It is `None` before the PSEC handshake was performed. After a successful call to [`do_handshake`](Session::do_handshake), the field is `Some`. If the handshake was not successful, the field can be either `Some` or `None` depending on where the handshake failed. @@ -611,7 +607,7 @@ impl PsecWriter for Session { #[async_trait] impl PsecReader for Session { fn set_max_recv_size(&mut self, size: usize, is_raw_size: bool) { - self.max_recv_size = compute_max_recv_size(size, is_raw_size); + self.max_recv_size = Some(compute_max_recv_size(size, is_raw_size)); } async fn receive_and_decrypt(&mut self) -> Result, PsecError> { receive_and_decrypt(&mut self.stream, &self.peer_cipher.as_ref().unwrap(), &self.peer_iv.unwrap(), &mut self.peer_counter, self.max_recv_size).await @@ -632,7 +628,7 @@ impl From for Session { peer_iv: None, peer_counter: 0, peer_public_key: None, - max_recv_size: DEFAULT_MAX_RECV_SIZE, + max_recv_size: None, } } } diff --git a/tests/psec_test.rs b/tests/psec_test.rs index f6b15fa..77a1922 100644 --- a/tests/psec_test.rs +++ b/tests/psec_test.rs @@ -11,7 +11,7 @@ async fn tokio_main() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let bind_addr = listener.local_addr().unwrap(); - tokio::spawn(async move { + let handle = tokio::spawn(async move { let (stream, addr) = listener.accept().await.unwrap(); let mut session = Session::from(stream); @@ -23,7 +23,7 @@ async fn tokio_main() { session.encrypt_and_send(b"Hello I'm Bob", true).await.unwrap(); assert_eq!(session.receive_and_decrypt().await.unwrap(), b"Hello I'm Alice"); - session.encrypt_and_send("!".repeat(997).as_bytes(), false).await.unwrap(); + session.encrypt_and_send("!".repeat(997).as_bytes(), true).await.unwrap(); assert_eq!(session.receive_and_decrypt().await, Err(PsecError::TransmissionCorrupted)); }); @@ -36,13 +36,16 @@ async fn tokio_main() { session.do_handshake(&client_keypair).await.unwrap(); assert_eq!(session.peer_public_key.unwrap(), server_public_key); + session.set_max_recv_size(996, false); + session.encrypt_and_send(b"Hello I'm Alice", true).await.unwrap(); assert_eq!(session.receive_and_decrypt().await.unwrap(), b"Hello I'm Bob"); - session.set_max_recv_size(1, false); assert_eq!(session.receive_and_decrypt().await, Err(PsecError::BufferTooLarge)); session.send(b"\x00\x00\x00\x00not encrypted data").await.unwrap(); + + handle.await.unwrap(); } #[test]