diff --git a/app/src/main/java/sushi/hardcore/aira/background_service/Session.kt b/app/src/main/java/sushi/hardcore/aira/background_service/Session.kt index 073dc70..653d93d 100644 --- a/app/src/main/java/sushi/hardcore/aira/background_service/Session.kt +++ b/app/src/main/java/sushi/hardcore/aira/background_service/Session.kt @@ -8,6 +8,7 @@ import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec import org.whispersystems.curve25519.Curve25519 import sushi.hardcore.aira.AIRADatabase import java.io.ByteArrayOutputStream +import java.io.OutputStream import java.nio.ByteBuffer import java.nio.channels.* import java.nio.channels.spi.SelectorProvider @@ -42,8 +43,6 @@ class Session(private val socket: SocketChannel, val outgoing: Boolean): Selecta } private val prng = SecureRandom() - private val handshakeSentBuff = ByteArrayOutputStream(handshakeBufferLen) - private val handshakeRecvBuff = ByteArrayOutputStream(handshakeBufferLen) private val peerCipher = Cipher.getInstance(CIPHER_TYPE) private val localCipher = Cipher.getInstance(CIPHER_TYPE) private var peerCounter = 0L @@ -52,42 +51,37 @@ class Session(private val socket: SocketChannel, val outgoing: Boolean): Selecta lateinit var peerPublicKey: ByteArray val ip: String = socket.socket().inetAddress.hostAddress - private fun handshakeWrite(buffer: ByteArray) { + private fun handshakeWrite(buffer: ByteArray, handshakeSentBuff: OutputStream) { writeAll(buffer) handshakeSentBuff.write(buffer) } - private fun handshakeRead(buffer: ByteBuffer): Boolean { - return if (socket.read(buffer) == buffer.position()) { + private fun handshakeRead(buffer: ByteBuffer, handshakeRecvBuff: OutputStream): Boolean { + return if (readAll(buffer)) { handshakeRecvBuff.write(buffer.array()) true } else { false } } - private fun handshakeRead(buffer: ByteArray): Boolean { - return handshakeRead(ByteBuffer.wrap(buffer)) - } - private fun hashHandshake(iAmBob: Boolean): ByteArray { + private fun hashHandshake(iAmBob: Boolean, handshakeSentBuff: ByteArray, handshakeRecvBuff: ByteArray): ByteArray { MessageDigest.getInstance("SHA-384").apply { if (iAmBob) { - update(handshakeSentBuff.toByteArray()) - update(handshakeRecvBuff.toByteArray()) + update(handshakeSentBuff) + update(handshakeRecvBuff) } else { - update(handshakeRecvBuff.toByteArray()) - update(handshakeSentBuff.toByteArray()) + update(handshakeRecvBuff) + update(handshakeSentBuff) } return digest() } } - private fun amIBob(): Boolean { - val s = handshakeSentBuff.toByteArray() - val r = handshakeRecvBuff.toByteArray() - for (i in s.indices) { - if (s[i] != r[i]) { - return s[i].toInt() and 0xff < r[i].toInt() and 0xff + private fun amIBob(handshakeSentBuff: ByteArray, handshakeRecvBuff: ByteArray): Boolean { + for (i in handshakeSentBuff.indices) { + if (handshakeSentBuff[i] != handshakeRecvBuff[i]) { + return handshakeSentBuff[i].toInt() and 0xff < handshakeRecvBuff[i].toInt() and 0xff } } throw SecurityException("Handshake buffers are identical") @@ -102,64 +96,62 @@ class Session(private val socket: SocketChannel, val outgoing: Boolean): Selecta } fun doHandshake(): Boolean { + val handshakeSentBuff = ByteArrayOutputStream(handshakeBufferLen) + val handshakeRecvBuff = ByteArrayOutputStream(handshakeBufferLen) + val randomBuffer = ByteArray(RANDOM_LEN) prng.nextBytes(randomBuffer) val curve25519Cipher = Curve25519.getInstance(Curve25519.BEST) val keypair = curve25519Cipher.generateKeyPair() - handshakeWrite(randomBuffer+keypair.publicKey) + handshakeWrite(randomBuffer+keypair.publicKey, handshakeSentBuff) var recvBuffer = ByteBuffer.allocate(RANDOM_LEN+PUBLIC_KEY_LEN) - if (handshakeRead(recvBuffer)) { + if (handshakeRead(recvBuffer, handshakeRecvBuff)) { val peerEphemeralPublicKey = recvBuffer.array().sliceArray(RANDOM_LEN until recvBuffer.capacity()) val sharedSecret = curve25519Cipher.calculateAgreement(peerEphemeralPublicKey, keypair.privateKey) - val iAmBob = amIBob() //mutual consensus for keys attribution - var handshakeHash = hashHandshake(iAmBob) + val iAmBob = amIBob(handshakeSentBuff.toByteArray(), handshakeRecvBuff.toByteArray()) //mutual consensus for keys attribution + var handshakeHash = hashHandshake(iAmBob, handshakeSentBuff.toByteArray(), handshakeRecvBuff.toByteArray()) val handshakeKeys = deriveHandshakeKeys(sharedSecret, handshakeHash, iAmBob) prng.nextBytes(randomBuffer) - handshakeWrite(randomBuffer) - if (handshakeRead(randomBuffer)) { - val localCipher = Cipher.getInstance(CIPHER_TYPE) - localCipher.init(Cipher.ENCRYPT_MODE, SecretKeySpec(handshakeKeys.localKey, "AES"), GCMParameterSpec(AES_TAG_LEN*8, ivToNonce(handshakeKeys.localIv, 0))) - handshakeWrite(localCipher.doFinal(AIRADatabase.getIdentityPublicKey()+sign(keypair.publicKey))) + val localCipher = Cipher.getInstance(CIPHER_TYPE) + localCipher.init(Cipher.ENCRYPT_MODE, SecretKeySpec(handshakeKeys.localKey, "AES"), GCMParameterSpec(AES_TAG_LEN*8, handshakeKeys.localIv)) + handshakeWrite(localCipher.doFinal(randomBuffer+AIRADatabase.getIdentityPublicKey()+sign(keypair.publicKey)), handshakeSentBuff) - recvBuffer = ByteBuffer.allocate(PUBLIC_KEY_LEN+SIGNATURE_LEN+AES_TAG_LEN) - if (handshakeRead(recvBuffer)) { - val peerCipher = Cipher.getInstance(CIPHER_TYPE) - peerCipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(handshakeKeys.peerKey, "AES"), GCMParameterSpec(AES_TAG_LEN*8, ivToNonce(handshakeKeys.peerIv, 0))) - val plainText: ByteArray - try { - plainText = peerCipher.doFinal(recvBuffer.array()) - } catch (e: BadPaddingException) { - Log.w("BadPaddingException", ip) - return false - } catch (e: AEADBadTagException) { - Log.w("AEADBadTagException", ip) - return false - } - peerPublicKey = plainText.sliceArray(0 until PUBLIC_KEY_LEN) - val signature = plainText.sliceArray(PUBLIC_KEY_LEN until plainText.size) + recvBuffer = ByteBuffer.allocate(RANDOM_LEN+PUBLIC_KEY_LEN+SIGNATURE_LEN+AES_TAG_LEN) + if (handshakeRead(recvBuffer, handshakeRecvBuff)) { + val peerCipher = Cipher.getInstance(CIPHER_TYPE) + peerCipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(handshakeKeys.peerKey, "AES"), GCMParameterSpec(AES_TAG_LEN*8, handshakeKeys.peerIv)) + val plainText: ByteArray + try { + plainText = peerCipher.doFinal(recvBuffer.array()) + } catch (e: BadPaddingException) { + Log.w("BadPaddingException", ip) + return false + } catch (e: AEADBadTagException) { + Log.w("AEADBadTagException", ip) + return false + } + peerPublicKey = plainText.sliceArray(RANDOM_LEN until RANDOM_LEN+PUBLIC_KEY_LEN) + val signature = plainText.sliceArray(RANDOM_LEN+PUBLIC_KEY_LEN until plainText.size) - val edDSAEngine = EdDSAEngine().apply { - initVerify(EdDSAPublicKey(EdDSAPublicKeySpec(peerPublicKey, EdDSANamedCurveTable.ED_25519_CURVE_SPEC))) - } - if (edDSAEngine.verifyOneShot(peerEphemeralPublicKey, signature)) { - handshakeHash = hashHandshake(iAmBob) - val handshakeFinished = computeHandshakeFinished(handshakeKeys.localHandshakeTrafficSecret, handshakeHash) - writeAll(handshakeFinished) - val peerHandshakeFinished = ByteBuffer.allocate(HASH_OUTPUT_LEN) - socket.read(peerHandshakeFinished) - if (verifyHandshakeFinished(peerHandshakeFinished.array(), handshakeKeys.peerHandshakeTrafficSecret, handshakeHash)){ - applicationKeys = deriveApplicationKeys(handshakeKeys.handshakeSecret, handshakeHash, iAmBob) - handshakeSentBuff.reset() - handshakeRecvBuff.reset() - return true - } else { - Log.w("Handshake", "Final verification failed") - } + val edDSAEngine = EdDSAEngine().apply { + initVerify(EdDSAPublicKey(EdDSAPublicKeySpec(peerPublicKey, EdDSANamedCurveTable.ED_25519_CURVE_SPEC))) + } + if (edDSAEngine.verifyOneShot(peerEphemeralPublicKey, signature)) { + handshakeHash = hashHandshake(iAmBob, handshakeSentBuff.toByteArray(), handshakeRecvBuff.toByteArray()) + val handshakeFinished = computeHandshakeFinished(handshakeKeys.localHandshakeTrafficSecret, handshakeHash) + writeAll(handshakeFinished) + val peerHandshakeFinished = ByteBuffer.allocate(HASH_OUTPUT_LEN) + socket.read(peerHandshakeFinished) + if (verifyHandshakeFinished(peerHandshakeFinished.array(), handshakeKeys.peerHandshakeTrafficSecret, handshakeHash)){ + applicationKeys = deriveApplicationKeys(handshakeKeys.handshakeSecret, handshakeHash, iAmBob) + return true } else { - Log.w("Handshake", "Signature verification failed") + Log.w("Handshake", "Final verification failed") } + } else { + Log.w("Handshake", "Signature verification failed") } } }