package sushi.hardcore.aira.background_service

import android.util.Log
import net.i2p.crypto.eddsa.EdDSAEngine
import net.i2p.crypto.eddsa.EdDSAPublicKey
import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable
import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec
import org.whispersystems.curve25519.Curve25519
import sushi.hardcore.aira.AIRADatabase
import java.io.ByteArrayOutputStream
import java.io.OutputStream
import java.nio.ByteBuffer
import java.nio.channels.*
import java.nio.channels.spi.SelectorProvider
import java.security.MessageDigest
import java.security.SecureRandom
import javax.crypto.AEADBadTagException
import javax.crypto.BadPaddingException
import javax.crypto.Cipher
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.SecretKeySpec
import kotlin.experimental.xor

/**
 * An encrypted session running over [socket].
 *
 * [doHandshake] performs an ephemeral Curve25519 key agreement followed by an
 * AES-GCM-protected exchange of Ed25519 identity keys and signatures. After it
 * succeeds, [encrypt]/[receiveAndDecrypt] exchange length-prefixed AES-GCM frames.
 *
 * The class is itself a [SelectableChannel]; every selector-related operation
 * is delegated to the wrapped [socket].
 *
 * @property outgoing whether this session was initiated locally (stored only; not read in this class).
 */
class Session(private val socket: SocketChannel, val outgoing: Boolean): SelectableChannel() {
    // Native (JNI) helpers; their implementations are not visible from this file.
    private external fun deriveHandshakeKeys(sharedSecret: ByteArray, handshakeHash: ByteArray, iAmBob: Boolean): HandshakeKeys
    private external fun sign(input: ByteArray): ByteArray
    private external fun computeHandshakeFinished(localHandshakeTrafficSecret: ByteArray, handshakeHash: ByteArray): ByteArray
    private external fun verifyHandshakeFinished(peerHandshakeFinished: ByteArray, peerHandshakeTrafficSecret: ByteArray, handshakeHash: ByteArray): Boolean
    private external fun deriveApplicationKeys(handshakeSecret: ByteArray, handshakeHash: ByteArray, iAmBob: Boolean): ApplicationKeys

    private val prng = SecureRandom()
    private val peerCipher = Cipher.getInstance(CIPHER_TYPE)
    private val localCipher = Cipher.getInstance(CIPHER_TYPE)
    // Per-direction message counters, mixed into the IV to form unique GCM nonces.
    private var peerCounter = 0L
    private var localCounter = 0L
    private lateinit var applicationKeys: ApplicationKeys

    /** The peer's Ed25519 identity public key; set during [doHandshake]. */
    lateinit var peerPublicKey: ByteArray
    val ip: String = socket.socket().inetAddress.hostAddress

    /** Sends [buffer] and appends it to the sent-side handshake transcript. */
    private fun handshakeWrite(buffer: ByteArray, handshakeSentBuff: OutputStream) {
        writeAll(buffer)
        handshakeSentBuff.write(buffer)
    }

    /**
     * Fills [buffer] from the socket and, on success, appends its bytes to the
     * received-side handshake transcript.
     *
     * @return false if the socket reached EOF or was closed before [buffer] was full.
     */
    private fun handshakeRead(buffer: ByteBuffer, handshakeRecvBuff: OutputStream): Boolean {
        if (!readAll(buffer)) {
            return false
        }
        handshakeRecvBuff.write(buffer.array())
        return true
    }

    /**
     * SHA-384 of both transcripts, with Bob's bytes first. Both peers therefore
     * compute the same digest regardless of which side they are on.
     */
    private fun hashHandshake(iAmBob: Boolean, handshakeSentBuff: ByteArray, handshakeRecvBuff: ByteArray): ByteArray {
        val (first, second) = if (iAmBob) {
            handshakeSentBuff to handshakeRecvBuff
        } else {
            handshakeRecvBuff to handshakeSentBuff
        }
        return MessageDigest.getInstance("SHA-384").run {
            update(first)
            update(second)
            digest()
        }
    }

    /**
     * Decides roles without an extra round-trip. The side whose transcript is
     * lexicographically smaller (unsigned bytes) at the first differing byte is "Bob".
     *
     * Assumes both buffers have the same length; this holds in [doHandshake],
     * where both contain exactly one hello message.
     *
     * @throws SecurityException if the two transcripts are identical.
     */
    private fun amIBob(handshakeSentBuff: ByteArray, handshakeRecvBuff: ByteArray): Boolean {
        for (i in handshakeSentBuff.indices) {
            if (handshakeSentBuff[i] != handshakeRecvBuff[i]) {
                return (handshakeSentBuff[i].toInt() and 0xff) < (handshakeRecvBuff[i].toInt() and 0xff)
            }
        }
        throw SecurityException("Handshake buffers are identical")
    }

    /** Builds a GCM nonce: [iv] XOR (4 zero bytes || big-endian [counter]). */
    private fun ivToNonce(iv: ByteArray, counter: Long): ByteArray {
        val nonce = ByteArray(IV_LEN - Long.SIZE_BYTES) + ByteBuffer.allocate(Long.SIZE_BYTES).putLong(counter).array()
        for (i in nonce.indices) {
            nonce[i] = nonce[i] xor iv[i]
        }
        return nonce
    }

    /**
     * Runs the full handshake over the socket.
     *
     * 1. Cleartext exchange of random bytes and an ephemeral Curve25519 public key.
     * 2. AES-GCM-encrypted exchange (with the derived handshake keys) of random
     *    bytes, the Ed25519 identity public key and a signature over the sender's
     *    ephemeral public key.
     * 3. Exchange and verification of "handshake finished" values over the whole
     *    transcript. Application keys are derived only after this succeeds.
     *
     * @return true if the handshake completed; then [peerPublicKey] and the
     *   application keys are set. Returns false on EOF, decryption failure or
     *   verification failure.
     * @throws SecurityException if both hello messages are identical.
     */
    fun doHandshake(): Boolean {
        val handshakeSentBuff = ByteArrayOutputStream(HANDSHAKE_BUFFER_LEN)
        val handshakeRecvBuff = ByteArrayOutputStream(HANDSHAKE_BUFFER_LEN)
        val randomBuffer = ByteArray(RANDOM_LEN)
        prng.nextBytes(randomBuffer)
        val curve25519Cipher = Curve25519.getInstance(Curve25519.BEST)
        val keypair = curve25519Cipher.generateKeyPair()

        // Step 1: hello (random bytes || ephemeral public key), sent in cleartext.
        handshakeWrite(randomBuffer + keypair.publicKey, handshakeSentBuff)
        val helloBuffer = ByteBuffer.allocate(RANDOM_LEN + PUBLIC_KEY_LEN)
        if (!handshakeRead(helloBuffer, handshakeRecvBuff)) {
            return false
        }
        val peerEphemeralPublicKey = helloBuffer.array().sliceArray(RANDOM_LEN until helloBuffer.capacity())
        val sharedSecret = curve25519Cipher.calculateAgreement(peerEphemeralPublicKey, keypair.privateKey)
        val iAmBob = amIBob(handshakeSentBuff.toByteArray(), handshakeRecvBuff.toByteArray()) // mutual consensus for keys attribution
        val helloHash = hashHandshake(iAmBob, handshakeSentBuff.toByteArray(), handshakeRecvBuff.toByteArray())
        val handshakeKeys = deriveHandshakeKeys(sharedSecret, helloHash, iAmBob)

        // Step 2: encrypted (random bytes || identity public key || signature over our ephemeral key).
        prng.nextBytes(randomBuffer)
        val handshakeLocalCipher = Cipher.getInstance(CIPHER_TYPE)
        handshakeLocalCipher.init(Cipher.ENCRYPT_MODE, SecretKeySpec(handshakeKeys.localKey, "AES"), GCMParameterSpec(AES_TAG_LEN * 8, handshakeKeys.localIv))
        handshakeWrite(
            handshakeLocalCipher.doFinal(randomBuffer + AIRADatabase.getIdentityPublicKey() + sign(keypair.publicKey)),
            handshakeSentBuff,
        )
        val authBuffer = ByteBuffer.allocate(RANDOM_LEN + PUBLIC_KEY_LEN + SIGNATURE_LEN + AES_TAG_LEN)
        if (!handshakeRead(authBuffer, handshakeRecvBuff)) {
            return false
        }
        val handshakePeerCipher = Cipher.getInstance(CIPHER_TYPE)
        handshakePeerCipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(handshakeKeys.peerKey, "AES"), GCMParameterSpec(AES_TAG_LEN * 8, handshakeKeys.peerIv))
        val plainText = try {
            handshakePeerCipher.doFinal(authBuffer.array())
        } catch (e: AEADBadTagException) {
            // Must precede BadPaddingException: AEADBadTagException is its subclass,
            // so the reverse order would make this branch unreachable.
            Log.w("AEADBadTagException", ip)
            return false
        } catch (e: BadPaddingException) {
            Log.w("BadPaddingException", ip)
            return false
        }
        peerPublicKey = plainText.sliceArray(RANDOM_LEN until RANDOM_LEN + PUBLIC_KEY_LEN)
        val signature = plainText.sliceArray(RANDOM_LEN + PUBLIC_KEY_LEN until plainText.size)
        val edDSAEngine = EdDSAEngine().apply {
            initVerify(EdDSAPublicKey(EdDSAPublicKeySpec(peerPublicKey, EdDSANamedCurveTable.ED_25519_CURVE_SPEC)))
        }
        // The peer proves ownership of its identity key by signing its own ephemeral key.
        if (!edDSAEngine.verifyOneShot(peerEphemeralPublicKey, signature)) {
            Log.w("Handshake", "Signature verification failed")
            return false
        }

        // Step 3: "handshake finished" values over the complete transcript.
        val handshakeHash = hashHandshake(iAmBob, handshakeSentBuff.toByteArray(), handshakeRecvBuff.toByteArray())
        val handshakeFinished = computeHandshakeFinished(handshakeKeys.localHandshakeTrafficSecret, handshakeHash)
        writeAll(handshakeFinished)
        val peerHandshakeFinished = ByteBuffer.allocate(HASH_OUTPUT_LEN)
        // A single socket.read() may return fewer bytes than requested; readAll loops until full or EOF.
        if (!readAll(peerHandshakeFinished)) {
            return false
        }
        if (!verifyHandshakeFinished(peerHandshakeFinished.array(), handshakeKeys.peerHandshakeTrafficSecret, handshakeHash)) {
            Log.w("Handshake", "Final verification failed")
            return false
        }
        applicationKeys = deriveApplicationKeys(handshakeKeys.handshakeSecret, handshakeHash, iAmBob)
        return true
    }

    /**
     * Prefixes [input] with its big-endian 4-byte length.
     *
     * If [usePadding] is set, random bytes are appended so that the total
     * (length prefix + input + padding) equals the smallest 1000 * 2^k that is
     * >= the unpadded size. This hides the exact message length.
     */
    private fun pad(input: ByteArray, usePadding: Boolean): ByteArray {
        val encodedLen = ByteBuffer.allocate(MESSAGE_LEN_LEN).putInt(input.size).array()
        if (!usePadding) {
            return encodedLen + input
        }
        val msgLen = input.size + MESSAGE_LEN_LEN
        var len = 1000
        while (len < msgLen) {
            len *= 2
        }
        val padding = ByteArray(len - msgLen)
        prng.nextBytes(padding)
        return encodedLen + input + padding
    }

    /**
     * Reverses [pad]: reads the 4-byte length prefix and returns that many
     * bytes after it.
     *
     * @return the original message, or null if [input] is too short or the
     *   declared length is negative or exceeds the available bytes.
     */
    private fun unpad(input: ByteArray): ByteArray? {
        if (input.size < MESSAGE_LEN_LEN) {
            return null
        }
        // Read exactly the 4 prefix bytes. The former inclusive range 0..MESSAGE_LEN_LEN copied 5.
        val messageLen = ByteBuffer.wrap(input, 0, MESSAGE_LEN_LEN).int
        if (messageLen < 0 || messageLen > input.size - MESSAGE_LEN_LEN) {
            return null
        }
        return input.sliceArray(MESSAGE_LEN_LEN until MESSAGE_LEN_LEN + messageLen)
    }

    /**
     * Writes all of [buffer] to the socket, retrying until nothing remains.
     *
     * NOTE(review): on a non-blocking channel this spins while the socket's
     * send buffer is full.
     */
    fun writeAll(buffer: ByteArray) {
        val byteBuffer = ByteBuffer.wrap(buffer)
        while (byteBuffer.remaining() > 0) {
            socket.write(byteBuffer)
        }
    }

    /**
     * Builds one application frame: a 4-byte big-endian ciphertext length
     * (which is also used as GCM additional authenticated data) followed by
     * the AES-GCM ciphertext and tag of the padded plaintext.
     *
     * Each call consumes one local nonce counter value. The result must be
     * sent in the order produced.
     */
    fun encrypt(plainText: ByteArray, usePadding: Boolean): ByteArray {
        val padded = pad(plainText, usePadding)
        val rawMsgLen = ByteBuffer.allocate(MESSAGE_LEN_LEN).putInt(padded.size + AES_TAG_LEN).array()
        val nonce = ivToNonce(applicationKeys.localIv, localCounter)
        localCounter++
        localCipher.init(Cipher.ENCRYPT_MODE, SecretKeySpec(applicationKeys.localKey, "AES"), GCMParameterSpec(AES_TAG_LEN * 8, nonce))
        localCipher.updateAAD(rawMsgLen)
        return rawMsgLen + localCipher.doFinal(padded)
    }

    /** [encrypt]s [plainText] and writes the resulting frame to the socket. */
    fun encryptAndSend(plainText: ByteArray, usePadding: Boolean) {
        writeAll(encrypt(plainText, usePadding))
    }

    /** Lowercase hexadecimal representation, two characters per byte. */
    fun ByteArray.toHexString() = joinToString("") { "%02x".format(it) }

    /**
     * Reads from the socket until [buffer] is full.
     *
     * @return false on EOF or if the channel is closed before [buffer] is full.
     *
     * NOTE(review): on a non-blocking channel this spins while no data is available.
     */
    private fun readAll(buffer: ByteBuffer): Boolean {
        while (buffer.position() != buffer.capacity()) {
            try {
                if (socket.read(buffer) < 0) {
                    return false
                }
            } catch (e: ClosedChannelException) {
                return false
            }
        }
        return true
    }

    /**
     * Reads one frame produced by the peer's [encrypt], authenticates and
     * decrypts it, and strips the padding.
     *
     * @return the plaintext, or null if any of these happen:
     *   - the socket closed;
     *   - the declared frame length is outside 1..MAX_RECV_SIZE;
     *   - authentication failed;
     *   - the decrypted length prefix is inconsistent.
     */
    fun receiveAndDecrypt(): ByteArray? {
        val rawMessageLen = ByteBuffer.allocate(MESSAGE_LEN_LEN)
        if (!readAll(rawMessageLen)) {
            return null
        }
        rawMessageLen.position(0)
        val messageLen = rawMessageLen.int
        if (messageLen !in 1..MAX_RECV_SIZE) {
            Log.w("Message too large", "$messageLen from $ip")
            return null
        }
        val cipherText = ByteBuffer.allocate(messageLen)
        if (!readAll(cipherText)) {
            return null
        }
        val nonce = ivToNonce(applicationKeys.peerIv, peerCounter)
        peerCounter++
        peerCipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(applicationKeys.peerKey, "AES"), GCMParameterSpec(AES_TAG_LEN * 8, nonce))
        // The length prefix is authenticated as AAD, mirroring encrypt().
        rawMessageLen.position(0)
        peerCipher.updateAAD(rawMessageLen)
        val padded = try {
            peerCipher.doFinal(cipherText.array())
        } catch (e: AEADBadTagException) {
            Log.w("AEADBadTagException", ip)
            return null
        }
        val message = unpad(padded)
        if (message == null) {
            Log.w("Invalid padding", ip)
        }
        return message
    }

    override fun implCloseChannel() {
        socket.close()
    }

    override fun provider(): SelectorProvider {
        return socket.provider()
    }

    override fun validOps(): Int {
        return socket.validOps()
    }

    override fun isRegistered(): Boolean {
        return socket.isRegistered
    }

    override fun keyFor(sel: Selector?): SelectionKey {
        return socket.keyFor(sel)
    }

    override fun register(sel: Selector?, ops: Int, att: Any?): SelectionKey {
        return socket.register(sel, ops, att)
    }

    override fun configureBlocking(block: Boolean): SelectableChannel {
        return socket.configureBlocking(block)
    }

    override fun isBlocking(): Boolean {
        return socket.isBlocking
    }

    override fun blockingLock(): Any {
        return socket.blockingLock()
    }

    companion object {
        private const val RANDOM_LEN = 64
        private const val PUBLIC_KEY_LEN = 32
        private const val SIGNATURE_LEN = 64
        private const val AES_TAG_LEN = 16
        private const val IV_LEN = 12
        private const val HASH_OUTPUT_LEN = 48
        // Initial capacity for each transcript buffer.
        private const val HANDSHAKE_BUFFER_LEN = (2 * (RANDOM_LEN + PUBLIC_KEY_LEN)) + SIGNATURE_LEN + AES_TAG_LEN
        private const val CIPHER_TYPE = "AES/GCM/NoPadding"
        private const val MESSAGE_LEN_LEN = 4
        private const val PADDED_MAX_SIZE = 16384000
        private const val MAX_RECV_SIZE = PADDED_MAX_SIZE + AES_TAG_LEN
    }
}