From 5db9e834f2c371b7711f105b524ce682280cdd9d Mon Sep 17 00:00:00 2001 From: Rhys Weatherley Date: Mon, 18 Jun 2018 05:24:27 +1000 Subject: [PATCH] Test cases for the transport phase of Noise sessions --- host/Crypto/noise/test-vector.cpp | 66 +++++++------------ .../src/NoiseCipherState_AESGCM.cpp | 8 +-- .../src/NoiseCipherState_ChaChaPoly.cpp | 7 +- .../NoiseProtocol/src/NoiseHandshakeState.cpp | 2 +- .../src/NoiseSymmetricState_AESGCM_SHA256.cpp | 7 +- ...NoiseSymmetricState_ChaChaPoly_BLAKE2s.cpp | 7 +- .../NoiseSymmetricState_ChaChaPoly_SHA256.cpp | 7 +- 7 files changed, 42 insertions(+), 62 deletions(-) diff --git a/host/Crypto/noise/test-vector.cpp b/host/Crypto/noise/test-vector.cpp index 920dd4b3..ebb5ad7c 100644 --- a/host/Crypto/noise/test-vector.cpp +++ b/host/Crypto/noise/test-vector.cpp @@ -226,23 +226,21 @@ static NoiseHandshakeState *create_handshake(const char *protocol) */ static void test_connection(const TestVector *vec) { - NoiseHandshakeState *initiator = 0; - NoiseHandshakeState *responder = 0; + NoiseHandshakeState *initiator; + NoiseHandshakeState *responder; NoiseHandshakeState *send; NoiseHandshakeState *recv; -#if 0 NoiseCipherState *c1init; NoiseCipherState *c2init; NoiseCipherState *c1resp; NoiseCipherState *c2resp; NoiseCipherState *csend; NoiseCipherState *crecv; -#endif uint8_t message[MAX_MESSAGE_SIZE]; uint8_t payload[MAX_MESSAGE_SIZE]; int result; size_t index; - //size_t mac_len; + size_t mac_len; Noise::Party role; /* Create the two ends of the connection */ @@ -338,72 +336,58 @@ static void test_connection(const TestVector *vec) vec->messages[index].payload_len); } -#if 0 /* Handshake finished. Check the handshake hash values */ -#if 0 if (vec->handshake_hash_len) { memset(payload, 0xAA, sizeof(payload)); - compare(noise_handshakestate_get_handshake_hash - (initiator, payload, vec->handshake_hash_len), - NOISE_ERROR_NONE); + verify(initiator->getHandshakeHash(payload, vec->handshake_hash_len)); compare_blocks("handshake_hash", payload, vec->handshake_hash_len, vec->handshake_hash, vec->handshake_hash_len); memset(payload, 0xAA, sizeof(payload)); - compare(noise_handshakestate_get_handshake_hash - (responder, payload, vec->handshake_hash_len), - NOISE_ERROR_NONE); + verify(responder->getHandshakeHash(payload, vec->handshake_hash_len)); compare_blocks("handshake_hash", payload, vec->handshake_hash_len, vec->handshake_hash, vec->handshake_hash_len); } -#endif /* Now handle the data transport */ - compare(noise_handshakestate_split(initiator, &c1init, &c2init), - NOISE_ERROR_NONE); - compare(noise_handshakestate_split(responder, &c2resp, &c1resp), - NOISE_ERROR_NONE); - mac_len = noise_cipherstate_get_mac_length(c1init); + verify(initiator->split(&c1init, &c2init)); + verify(responder->split(&c2resp, &c1resp)); + mac_len = 16; for (; index < vec->num_messages; ++index) { - if (role == NOISE_ROLE_INITIATOR) { + if (role == Noise::Initiator) { /* Send on the initiator, receive on the responder */ csend = c1init; crecv = c1resp; - if (!is_one_way) - role = NOISE_ROLE_RESPONDER; + role = Noise::Responder; } else { /* Send on the responder, receive on the initiator */ csend = c2resp; crecv = c2init; - role = NOISE_ROLE_INITIATOR; + role = Noise::Initiator; } verify(sizeof(message) >= (vec->messages[index].payload_len + mac_len)); - memcpy(message, vec->messages[index].payload, - vec->messages[index].payload_len); - noise_buffer_set_inout(mbuf, message, vec->messages[index].payload_len, - sizeof(message)); - compare(noise_cipherstate_encrypt(csend, &mbuf), - NOISE_ERROR_NONE); - compare_blocks("ciphertext", mbuf.data, mbuf.size, + verify(sizeof(payload) >= (vec->messages[index].payload_len + mac_len)); + result = csend->encryptPacket + (message, sizeof(message), + vec->messages[index].payload, vec->messages[index].payload_len); + verify(result >= 0); + compare_blocks("ciphertext", message, (size_t)result, vec->messages[index].ciphertext, vec->messages[index].ciphertext_len); - compare(noise_cipherstate_decrypt(crecv, &mbuf), - NOISE_ERROR_NONE); - compare_blocks("plaintext", mbuf.data, mbuf.size, + result = crecv->decryptPacket + (payload, sizeof(payload), message, result); + verify(result >= 0); + compare_blocks("plaintext", payload, (size_t)result, vec->messages[index].payload, vec->messages[index].payload_len); } -#endif /* Clean up */ delete initiator; delete responder; - -#if 0 - compare(noise_cipherstate_free(c1init), NOISE_ERROR_NONE); - compare(noise_cipherstate_free(c2init), NOISE_ERROR_NONE); - compare(noise_cipherstate_free(c1resp), NOISE_ERROR_NONE); - compare(noise_cipherstate_free(c2resp), NOISE_ERROR_NONE); -#endif + delete c1init; + delete c2init; + delete c1resp; + delete c2resp; } /** diff --git a/libraries/NoiseProtocol/src/NoiseCipherState_AESGCM.cpp b/libraries/NoiseProtocol/src/NoiseCipherState_AESGCM.cpp index 5344ef00..8f0402e2 100644 --- a/libraries/NoiseProtocol/src/NoiseCipherState_AESGCM.cpp +++ b/libraries/NoiseProtocol/src/NoiseCipherState_AESGCM.cpp @@ -77,11 +77,11 @@ int NoiseCipherState_AESGCM::decryptPacket return -1; uint8_t iv[12]; noiseAESGCMFormatIV(iv, n); - cipher.setIV((const uint8_t *)&iv, sizeof(iv)); - cipher.decrypt((uint8_t *)output, (const uint8_t *)input, outputSize); - if (cipher.checkTag(((const uint8_t *)input) + outputSize, 16)) { + cipher.setIV(iv, sizeof(iv)); + cipher.decrypt((uint8_t *)output, (const uint8_t *)input, inputSize - 16); + if (cipher.checkTag(((const uint8_t *)input) + inputSize - 16, 16)) { ++n; - return outputSize; + return inputSize - 16; } memset(output, 0, outputSize); // Destroy the output if the tag is invalid. return -1; diff --git a/libraries/NoiseProtocol/src/NoiseCipherState_ChaChaPoly.cpp b/libraries/NoiseProtocol/src/NoiseCipherState_ChaChaPoly.cpp index 2a600a6b..07e2476f 100644 --- a/libraries/NoiseProtocol/src/NoiseCipherState_ChaChaPoly.cpp +++ b/libraries/NoiseProtocol/src/NoiseCipherState_ChaChaPoly.cpp @@ -73,12 +73,11 @@ int NoiseCipherState_ChaChaPoly::decryptPacket if (inputSize < 16 || outputSize < (inputSize - 16)) return -1; uint64_t iv = htole64(n); - outputSize = inputSize - 16; cipher.setIV((const uint8_t *)&iv, sizeof(iv)); - cipher.decrypt((uint8_t *)output, (const uint8_t *)input, outputSize); - if (cipher.checkTag(((const uint8_t *)input) + outputSize, 16)) { + cipher.decrypt((uint8_t *)output, (const uint8_t *)input, inputSize - 16); + if (cipher.checkTag(((const uint8_t *)input) + inputSize - 16, 16)) { ++n; - return outputSize; + return inputSize - 16; } memset(output, 0, outputSize); // Destroy the output if the tag is invalid. return -1; diff --git a/libraries/NoiseProtocol/src/NoiseHandshakeState.cpp b/libraries/NoiseProtocol/src/NoiseHandshakeState.cpp index e0c05ec2..e7423214 100644 --- a/libraries/NoiseProtocol/src/NoiseHandshakeState.cpp +++ b/libraries/NoiseProtocol/src/NoiseHandshakeState.cpp @@ -353,7 +353,7 @@ int NoiseHandshakeState::read * \return Returns true if the cipher objects were split out, or false if * state() is not NoiseHandshakeState::Split. * - * If \a tx or \a rx are NULL, the the respective cipher object will not + * If \a tx or \a rx are NULL, then the respective cipher object will not * be created. This is useful for one-way patterns. * * The application is responsible for destroying the \a tx and \a rx diff --git a/libraries/NoiseProtocol/src/NoiseSymmetricState_AESGCM_SHA256.cpp b/libraries/NoiseProtocol/src/NoiseSymmetricState_AESGCM_SHA256.cpp index bc03b421..f451f95e 100644 --- a/libraries/NoiseProtocol/src/NoiseSymmetricState_AESGCM_SHA256.cpp +++ b/libraries/NoiseProtocol/src/NoiseSymmetricState_AESGCM_SHA256.cpp @@ -174,16 +174,15 @@ int NoiseSymmetricState_AESGCM_SHA256::decryptAndHash if (st.hasKey) { if (inputSize < 16 || outputSize < (inputSize - 16)) return -1; - outputSize = inputSize - 16; uint8_t iv[12]; noiseAESGCMFormatIV(iv, st.n); cipher.setIV(iv, sizeof(iv)); cipher.addAuthData(st.h, sizeof(st.h)); mixHash(input, inputSize); - cipher.decrypt(output, input, outputSize); - if (cipher.checkTag(input + outputSize, 16)) { + cipher.decrypt(output, input, inputSize -16); + if (cipher.checkTag(input + inputSize - 16, 16)) { ++st.n; - return outputSize; + return inputSize -16; } memset(output, 0, outputSize); // Destroy output if tag is incorrect. return -1; diff --git a/libraries/NoiseProtocol/src/NoiseSymmetricState_ChaChaPoly_BLAKE2s.cpp b/libraries/NoiseProtocol/src/NoiseSymmetricState_ChaChaPoly_BLAKE2s.cpp index f3231a09..f3f15d79 100644 --- a/libraries/NoiseProtocol/src/NoiseSymmetricState_ChaChaPoly_BLAKE2s.cpp +++ b/libraries/NoiseProtocol/src/NoiseSymmetricState_ChaChaPoly_BLAKE2s.cpp @@ -148,17 +148,16 @@ int NoiseSymmetricState_ChaChaPoly_BLAKE2s::decryptAndHash if (st.hasKey) { if (inputSize < 16 || outputSize < (inputSize - 16)) return -1; - outputSize = inputSize - 16; ChaChaPoly cipher; uint64_t iv = htole64(st.n); cipher.setKey(st.key, 32); cipher.setIV((const uint8_t *)&iv, sizeof(iv)); cipher.addAuthData(st.h, sizeof(st.h)); mixHash(input, inputSize); - cipher.decrypt(output, input, outputSize); - if (cipher.checkTag(input + outputSize, 16)) { + cipher.decrypt(output, input, inputSize - 16); + if (cipher.checkTag(input + inputSize - 16, 16)) { ++st.n; - return outputSize; + return inputSize - 16; } memset(output, 0, outputSize); // Destroy output if tag is incorrect. return -1; diff --git a/libraries/NoiseProtocol/src/NoiseSymmetricState_ChaChaPoly_SHA256.cpp b/libraries/NoiseProtocol/src/NoiseSymmetricState_ChaChaPoly_SHA256.cpp index 1c926ad9..f19bd3e9 100644 --- a/libraries/NoiseProtocol/src/NoiseSymmetricState_ChaChaPoly_SHA256.cpp +++ b/libraries/NoiseProtocol/src/NoiseSymmetricState_ChaChaPoly_SHA256.cpp @@ -148,17 +148,16 @@ int NoiseSymmetricState_ChaChaPoly_SHA256::decryptAndHash if (st.hasKey) { if (inputSize < 16 || outputSize < (inputSize - 16)) return -1; - outputSize = inputSize - 16; ChaChaPoly cipher; uint64_t iv = htole64(st.n); cipher.setKey(st.key, 32); cipher.setIV((const uint8_t *)&iv, sizeof(iv)); cipher.addAuthData(st.h, sizeof(st.h)); mixHash(input, inputSize); - cipher.decrypt(output, input, outputSize); - if (cipher.checkTag(input + outputSize, 16)) { + cipher.decrypt(output, input, inputSize - 16); + if (cipher.checkTag(input + inputSize - 16, 16)) { ++st.n; - return outputSize; + return inputSize - 16; } memset(output, 0, outputSize); // Destroy output if tag is incorrect. return -1;