X-Git-Url: http://tinc-vpn.org/git/browse?a=blobdiff_plain;f=src%2Fsptps.c;h=b36079d3d73b33472c4294ee9d5bda394cc0cc19;hb=refs%2Fheads%2Ffeature%2Falt-ciphersuite;hp=a5987682695b76c04b2761b5b496c73ad6779f6d;hpb=266afc6c63d3d02584feb24b69063f97057daac8;p=tinc diff --git a/src/sptps.c b/src/sptps.c index a5987682..b36079d3 100644 --- a/src/sptps.c +++ b/src/sptps.c @@ -1,6 +1,6 @@ /* sptps.c -- Simple Peer-to-Peer Security - Copyright (C) 2011-2014 Guus Sliepen , + Copyright (C) 2011-2021 Guus Sliepen , 2010 Brandon L. Black This program is free software; you can redistribute it and/or modify @@ -20,13 +20,19 @@ #include "system.h" -#include "chacha-poly1305/chacha-poly1305.h" -#include "crypto.h" +#include "chacha-poly1305/chachapoly.h" #include "ecdh.h" #include "ecdsa.h" -#include "logger.h" #include "prf.h" #include "sptps.h" +#include "random.h" +#include "xalloc.h" + +#ifdef HAVE_OPENSSL +#include +#endif + +#define CIPHER_KEYLEN 64 unsigned int sptps_replaywin = 16; @@ -51,9 +57,16 @@ unsigned int sptps_replaywin = 16; */ void sptps_log_quiet(sptps_t *s, int s_errno, const char *format, va_list ap) { + (void)s; + (void)s_errno; + (void)format; + (void)ap; } void sptps_log_stderr(sptps_t *s, int s_errno, const char *format, va_list ap) { + (void)s; + (void)s_errno; + vfprintf(stderr, format, ap); fputc('\n', stderr); } @@ -61,7 +74,11 @@ void sptps_log_stderr(sptps_t *s, int s_errno, const char *format, va_list ap) { void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap) = sptps_log_stderr; // Log an error message. +static bool error(sptps_t *s, int s_errno, const char *format, ...) ATTR_FORMAT(printf, 3, 4); static bool error(sptps_t *s, int s_errno, const char *format, ...) { + (void)s; + (void)s_errno; + if(format) { va_list ap; va_start(ap, format); @@ -73,6 +90,7 @@ static bool error(sptps_t *s, int s_errno, const char *format, ...) { return false; } +static void warning(sptps_t *s, const char *format, ...) ATTR_FORMAT(printf, 2, 3); static void warning(sptps_t *s, const char *format, ...) { va_list ap; va_start(ap, format); @@ -80,37 +98,273 @@ static void warning(sptps_t *s, const char *format, ...) { va_end(ap); } +static sptps_kex_t *new_sptps_kex(void) { + return xzalloc(sizeof(sptps_kex_t)); +} + +static void free_sptps_kex(sptps_kex_t *kex) { + xzfree(kex, sizeof(sptps_kex_t)); +} + +static sptps_key_t *new_sptps_key(void) { + return xzalloc(sizeof(sptps_key_t)); +} + +static void free_sptps_key(sptps_key_t *key) { + xzfree(key, sizeof(sptps_key_t)); +} + +static bool cipher_init(uint8_t suite, void **ctx, const sptps_key_t *keys, bool key_half) { + const uint8_t *key = key_half ? keys->key1 : keys->key0; + + switch(suite) { +#ifndef HAVE_OPENSSL + + case SPTPS_CHACHA_POLY1305: + *ctx = malloc(sizeof(struct chachapoly_ctx)); + return *ctx && chachapoly_init(*ctx, key, 256) == CHACHAPOLY_OK; + +#else + + case SPTPS_CHACHA_POLY1305: +#ifdef EVP_F_EVP_AEAD_CTX_INIT + *ctx = malloc(sizeof(EVP_AEAD_CTX)); + + return *ctx && EVP_AEAD_CTX_init(*ctx, EVP_aead_chacha20_poly1305(), key + (key_half ? CIPHER_KEYLEN : 0), 32, 16, NULL); +#else + *ctx = EVP_CIPHER_CTX_new(); + + return *ctx + && EVP_EncryptInit_ex(*ctx, EVP_chacha20_poly1305(), NULL, NULL, NULL) + && EVP_CIPHER_CTX_ctrl(*ctx, EVP_CTRL_GCM_SET_IVLEN, 12, NULL) + && EVP_EncryptInit_ex(*ctx, NULL, NULL, key, key + 32); +#endif + + case SPTPS_AES256_GCM: + *ctx = EVP_CIPHER_CTX_new(); + + return *ctx + && EVP_EncryptInit_ex(*ctx, EVP_aes_256_gcm(), NULL, NULL, NULL) + && EVP_CIPHER_CTX_ctrl(*ctx, EVP_CTRL_GCM_SET_IVLEN, 12, NULL) + && EVP_EncryptInit_ex(*ctx, NULL, NULL, key, key + 32); +#endif + + default: + return false; + } +} + +static void cipher_exit(uint8_t suite, void *ctx) { + switch(suite) { +#ifndef HAVE_OPENSSL + + case SPTPS_CHACHA_POLY1305: + free(ctx); + break; + +#else + + case SPTPS_CHACHA_POLY1305: +#ifdef EVP_F_EVP_AEAD_CTX_INIT + EVP_AEAD_CTX_cleanup(ctx); + free(ctx); + break; +#endif + + case SPTPS_AES256_GCM: + EVP_CIPHER_CTX_free(ctx); + break; +#endif + + default: + break; + } +} + +static bool cipher_encrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { + switch(suite) { +#ifndef HAVE_OPENSSL + + case SPTPS_CHACHA_POLY1305: { + if(chachapoly_crypt(ctx, nonce, (void *)in, inlen, out, out + inlen, 16, 1) != CHACHAPOLY_OK) { + return false; + } + + if(outlen) { + *outlen = inlen + 16; + } + + return true; + } + +#else + + case SPTPS_CHACHA_POLY1305: +#ifdef EVP_F_EVP_AEAD_CTX_INIT + { + size_t outlen1; + + if(!EVP_AEAD_CTX_seal(ctx, out, &outlen1, inlen + 16, nonce, sizeof(nonce), in, inlen, NULL, 0)) { + return false; + } + + if(outlen) { + *outlen = outlen1; + } + + return true; + } + +#endif + + case SPTPS_AES256_GCM: { + uint8_t nonce[12] = {seqno, seqno >> 8, seqno >> 16, seqno >> 24}; + + if(!EVP_EncryptInit_ex(ctx, NULL, NULL, NULL, nonce)) { + return false; + } + + int outlen1 = 0, outlen2 = 0; + + if(!EVP_EncryptUpdate(ctx, out, &outlen1, in, (int)inlen)) { + return false; + } + + if(!EVP_EncryptFinal_ex(ctx, out + outlen1, &outlen2)) { + return false; + } + + outlen1 += outlen2; + + if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, 16, out + outlen1)) { + return false; + } + + outlen1 += 16; + + if(outlen) { + *outlen = outlen1; + } + + return true; + } + +#endif + + default: + return false; + } +} + +static bool cipher_decrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { + if(inlen < 16) { + return false; + } + + inlen -= 16; + + switch(suite) { +#ifndef HAVE_OPENSSL + + case SPTPS_CHACHA_POLY1305: + if(chachapoly_crypt(ctx, nonce, (void *)in, inlen, out, (void *)(in + inlen), 16, 0) != CHACHAPOLY_OK) { + return false; + } + + if(outlen) { + *outlen = inlen; + } + + return true; + +#else + + case SPTPS_CHACHA_POLY1305: +#ifdef EVP_F_EVP_AEAD_CTX_INIT + { + size_t outlen1; + + if(!EVP_AEAD_CTX_open(ctx, out, &outlen1, inlen, nonce, sizeof(nonce), in, inlen + 16, NULL, 0)) { + return false; + } + + if(outlen) { + *outlen = outlen1; + } + + return true; + } + +#endif + + case SPTPS_AES256_GCM: { + uint8_t nonce[12] = {seqno, seqno >> 8, seqno >> 16, seqno >> 24}; + + if(!EVP_DecryptInit_ex(ctx, NULL, NULL, NULL, nonce)) { + return false; + } + + int outlen1 = 0, outlen2 = 0; + + if(!EVP_DecryptUpdate(ctx, out, &outlen1, in, (int)inlen)) { + return false; + } + + if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, 16, (void *)(in + inlen))) { + return false; + } + + if(!EVP_DecryptFinal_ex(ctx, out + outlen1, &outlen2)) { + return false; + } + + if(outlen) { + *outlen = outlen1 + outlen2; + } + + return true; + } + +#endif + + default: + return false; + } +} + // Send a record (datagram version, accepts all record types, handles encryption and authentication). static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const void *data, uint16_t len) { - char buffer[len + 21UL]; - + uint8_t *buffer = alloca(len + SPTPS_DATAGRAM_OVERHEAD); // Create header with sequence number, length and record type uint32_t seqno = s->outseqno++; - uint32_t netseqno = ntohl(seqno); - memcpy(buffer, &netseqno, 4); + memcpy(buffer, &seqno, 4); buffer[4] = type; memcpy(buffer + 5, data, len); if(s->outstate) { // If first handshake has finished, encrypt and HMAC - chacha_poly1305_encrypt(s->outcipher, seqno, buffer + 4, len + 1, buffer + 4, NULL); - return s->send_data(s->handle, type, buffer, len + 21UL); + if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 4, len + 1, buffer + 4, NULL)) { + return error(s, EINVAL, "Failed to encrypt message"); + } + + return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_OVERHEAD); } else { // Otherwise send as plaintext - return s->send_data(s->handle, type, buffer, len + 5UL); + return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_HEADER); } } // Send a record (private version, accepts all record types, handles encryption and authentication). static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_t len) { - if(s->datagram) + if(s->datagram) { return send_record_priv_datagram(s, type, data, len); + } - char buffer[len + 19UL]; + uint8_t *buffer = alloca(len + SPTPS_OVERHEAD); // Create header with sequence number, length and record type uint32_t seqno = s->outseqno++; - uint16_t netlen = htons(len); + uint16_t netlen = len; memcpy(buffer, &netlen, 2); buffer[2] = type; @@ -118,11 +372,14 @@ static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_ if(s->outstate) { // If first handshake has finished, encrypt and HMAC - chacha_poly1305_encrypt(s->outcipher, seqno, buffer + 2, len + 1, buffer + 2, NULL); - return s->send_data(s->handle, type, buffer, len + 19UL); + if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 2, len + 1, buffer + 2, NULL)) { + return error(s, EINVAL, "Failed to encrypt message"); + } + + return s->send_data(s->handle, type, buffer, len + SPTPS_OVERHEAD); } else { // Otherwise send as plaintext - return s->send_data(s->handle, type, buffer, len + 3UL); + return s->send_data(s->handle, type, buffer, len + SPTPS_HEADER); } } @@ -130,93 +387,98 @@ static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_ bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16_t len) { // Sanity checks: application cannot send data before handshake is finished, // and only record types 0..127 are allowed. - if(!s->outstate) + if(!s->outstate) { return error(s, EINVAL, "Handshake phase not finished yet"); + } - if(type >= SPTPS_HANDSHAKE) + if(type >= SPTPS_HANDSHAKE) { return error(s, EINVAL, "Invalid application record type"); + } return send_record_priv(s, type, data, len); } // Send a Key EXchange record, containing a random nonce and an ECDHE public key. static bool send_kex(sptps_t *s) { - size_t keylen = ECDH_SIZE; - // Make room for our KEX message, which we will keep around since send_sig() needs it. - if(s->mykex) + if(s->mykex) { return false; - s->mykex = realloc(s->mykex, 1 + 32 + keylen); - if(!s->mykex) - return error(s, errno, strerror(errno)); + } + + s->mykex = new_sptps_kex(); // Set version byte to zero. - s->mykex[0] = SPTPS_VERSION; + s->mykex->version = SPTPS_VERSION; + s->mykex->preferred_suite = s->preferred_suite; + s->mykex->cipher_suites = s->cipher_suites; // Create a random nonce. - randomize(s->mykex + 1, 32); + randomize(s->mykex->nonce, ECDH_SIZE); // Create a new ECDH public key. - if(!(s->ecdh = ecdh_generate_public(s->mykex + 1 + 32))) + if(!(s->ecdh = ecdh_generate_public(s->mykex->pubkey))) { return error(s, EINVAL, "Failed to generate ECDH public key"); + } + + return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, sizeof(sptps_kex_t)); +} + +static size_t sigmsg_len(size_t labellen) { + return 1 + 2 * sizeof(sptps_kex_t) + labellen; +} - return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, 1 + 32 + keylen); +static void fill_msg(uint8_t *msg, bool initiator, const sptps_kex_t *kex0, const sptps_kex_t *kex1, const sptps_t *s) { + *msg = initiator, msg++; + memcpy(msg, kex0, sizeof(*kex0)), msg += sizeof(*kex0); + memcpy(msg, kex1, sizeof(*kex1)), msg += sizeof(*kex1); + memcpy(msg, s->label, s->labellen); } // Send a SIGnature record, containing an Ed25519 signature over both KEX records. static bool send_sig(sptps_t *s) { - size_t keylen = ECDH_SIZE; - size_t siglen = ecdsa_size(s->mykey); - // Concatenate both KEX messages, plus tag indicating if it is from the connection originator, plus label - char msg[(1 + 32 + keylen) * 2 + 1 + s->labellen]; - char sig[siglen]; - - msg[0] = s->initiator; - memcpy(msg + 1, s->mykex, 1 + 32 + keylen); - memcpy(msg + 1 + 33 + keylen, s->hiskex, 1 + 32 + keylen); - memcpy(msg + 1 + 2 * (33 + keylen), s->label, s->labellen); + size_t msglen = sigmsg_len(s->labellen); + uint8_t *msg = alloca(msglen); + fill_msg(msg, s->initiator, s->mykex, s->hiskex, s); // Sign the result. - if(!ecdsa_sign(s->mykey, msg, sizeof msg, sig)) + size_t siglen = ecdsa_size(s->mykey); + uint8_t *sig = alloca(siglen); + + if(!ecdsa_sign(s->mykey, msg, msglen, sig)) { return error(s, EINVAL, "Failed to sign SIG record"); + } // Send the SIG exchange record. - return send_record_priv(s, SPTPS_HANDSHAKE, sig, sizeof sig); + return send_record_priv(s, SPTPS_HANDSHAKE, sig, siglen); } // Generate key material from the shared secret created from the ECDHE key exchange. -static bool generate_key_material(sptps_t *s, const char *shared, size_t len) { - // Initialise cipher and digest structures if necessary - if(!s->outstate) { - s->incipher = chacha_poly1305_init(); - s->outcipher = chacha_poly1305_init(); - if(!s->incipher || !s->outcipher) - return error(s, EINVAL, "Failed to open cipher"); - } - +static bool generate_key_material(sptps_t *s, const uint8_t *shared, size_t len) { // Allocate memory for key material - size_t keylen = 2 * CHACHA_POLY1305_KEYLEN; - - s->key = realloc(s->key, keylen); - if(!s->key) - return error(s, errno, strerror(errno)); + s->key = new_sptps_key(); // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce - char seed[s->labellen + 64 + 13]; - strcpy(seed, "key expansion"); - if(s->initiator) { - memcpy(seed + 13, s->mykex + 1, 32); - memcpy(seed + 45, s->hiskex + 1, 32); - } else { - memcpy(seed + 13, s->hiskex + 1, 32); - memcpy(seed + 45, s->mykex + 1, 32); - } - memcpy(seed + 77, s->label, s->labellen); + const size_t msglen = sizeof("key expansion") - 1; + const size_t seedlen = msglen + s->labellen + ECDH_SIZE * 2; + uint8_t *seed = alloca(seedlen); + + uint8_t *ptr = seed; + memcpy(ptr, "key expansion", msglen); + ptr += msglen; + + memcpy(ptr, (s->initiator ? s->mykex : s->hiskex)->nonce, ECDH_SIZE); + ptr += ECDH_SIZE; + + memcpy(ptr, (s->initiator ? s->hiskex : s->mykex)->nonce, ECDH_SIZE); + ptr += ECDH_SIZE; + + memcpy(ptr, s->label, s->labellen); // Use PRF to generate the key material - if(!prf(shared, len, seed, s->labellen + 64 + 13, s->key, keylen)) + if(!prf(shared, len, seed, seedlen, s->key->both, sizeof(sptps_key_t))) { return error(s, EINVAL, "Failed to generate key material"); + } return true; } @@ -227,93 +489,138 @@ static bool send_ack(sptps_t *s) { } // Receive an ACKnowledgement record. -static bool receive_ack(sptps_t *s, const char *data, uint16_t len) { - if(len) +static bool receive_ack(sptps_t *s, const uint8_t *data, uint16_t len) { + (void)data; + + if(len) { return error(s, EIO, "Invalid ACK record length"); + } - if(s->initiator) { - if(!chacha_poly1305_set_key(s->incipher, s->key)) - return error(s, EINVAL, "Failed to set counter"); - } else { - if(!chacha_poly1305_set_key(s->incipher, s->key + CHACHA_POLY1305_KEYLEN)) - return error(s, EINVAL, "Failed to set counter"); + if(!cipher_init(s->cipher_suite, &s->incipher, s->key, s->initiator)) { + return error(s, EINVAL, "Failed to initialize cipher"); } - free(s->key); + free_sptps_key(s->key); s->key = NULL; s->instate = true; return true; } +static uint8_t select_cipher_suite(uint16_t mask, uint8_t pref1, uint8_t pref2) { + // Check if there is a viable preference, if so select the lowest one + uint8_t selection = 255; + + if(mask & (1U << pref1)) { + selection = pref1; + } + + if(pref2 < selection && (mask & (1U << pref2))) { + selection = pref2; + } + + // Otherwise, select the lowest cipher suite both sides support + if(selection == 255) { + selection = 0; + + while(!(mask & 1U)) { + selection++; + mask >>= 1; + } + } + + return selection; +} + // Receive a Key EXchange record, respond by sending a SIG record. -static bool receive_kex(sptps_t *s, const char *data, uint16_t len) { +static bool receive_kex(sptps_t *s, const uint8_t *data, uint16_t len) { // Verify length of the HELLO record - if(len != 1 + 32 + ECDH_SIZE) + + if(len != sizeof(sptps_kex_t)) { return error(s, EIO, "Invalid KEX record length"); + } - // Ignore version number for now. + if(*data != SPTPS_VERSION) { + return error(s, EINVAL, "Received incorrect version %d", *data); + } + + uint16_t suites; + memcpy(&suites, data + 2, 2); + suites &= s->cipher_suites; + + if(!suites) { + return error(s, EIO, "No matching cipher suites"); + } + + s->cipher_suite = select_cipher_suite(suites, s->preferred_suite, data[1] & 0xf); // Make a copy of the KEX message, send_sig() and receive_sig() need it - if(s->hiskex) + if(s->hiskex) { return error(s, EINVAL, "Received a second KEX message before first has been processed"); - s->hiskex = realloc(s->hiskex, len); - if(!s->hiskex) - return error(s, errno, strerror(errno)); + } - memcpy(s->hiskex, data, len); + s->hiskex = new_sptps_kex(); + memcpy(s->hiskex, data, sizeof(sptps_kex_t)); - return send_sig(s); + if(s->initiator) { + return send_sig(s); + } else { + return true; + } } // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys. -static bool receive_sig(sptps_t *s, const char *data, uint16_t len) { - size_t keylen = ECDH_SIZE; - size_t siglen = ecdsa_size(s->hiskey); - +static bool receive_sig(sptps_t *s, const uint8_t *data, uint16_t len) { // Verify length of KEX record. - if(len != siglen) + if(len != ecdsa_size(s->hiskey)) { return error(s, EIO, "Invalid KEX record length"); + } // Concatenate both KEX messages, plus tag indicating if it is from the connection originator - char msg[(1 + 32 + keylen) * 2 + 1 + s->labellen]; - - msg[0] = !s->initiator; - memcpy(msg + 1, s->hiskex, 1 + 32 + keylen); - memcpy(msg + 1 + 33 + keylen, s->mykex, 1 + 32 + keylen); - memcpy(msg + 1 + 2 * (33 + keylen), s->label, s->labellen); + const size_t msglen = sigmsg_len(s->labellen); + uint8_t *msg = alloca(msglen); + fill_msg(msg, !s->initiator, s->hiskex, s->mykex, s); // Verify signature. - if(!ecdsa_verify(s->hiskey, msg, sizeof msg, data)) + if(!ecdsa_verify(s->hiskey, msg, msglen, data)) { return error(s, EIO, "Failed to verify SIG record"); + } // Compute shared secret. - char shared[ECDH_SHARED_SIZE]; - if(!ecdh_compute_shared(s->ecdh, s->hiskex + 1 + 32, shared)) + uint8_t shared[ECDH_SHARED_SIZE]; + + if(!ecdh_compute_shared(s->ecdh, s->hiskex->pubkey, shared)) { + memzero(shared, sizeof(shared)); return error(s, EINVAL, "Failed to compute ECDH shared secret"); + } + s->ecdh = NULL; // Generate key material from shared secret. - if(!generate_key_material(s, shared, sizeof shared)) + bool generated = generate_key_material(s, shared, sizeof(shared)); + memzero(shared, sizeof(shared)); + + if(!generated) { return false; + } - free(s->mykex); - free(s->hiskex); + if(!s->initiator && !send_sig(s)) { + return false; + } + free_sptps_kex(s->mykex); s->mykex = NULL; + + free_sptps_kex(s->hiskex); s->hiskex = NULL; // Send cipher change record - if(s->outstate && !send_ack(s)) + if(s->outstate && !send_ack(s)) { return false; + } - // TODO: only set new keys after ACK has been set/received - if(s->initiator) { - if(!chacha_poly1305_set_key(s->outcipher, s->key + CHACHA_POLY1305_KEYLEN)) - return error(s, EINVAL, "Failed to set key"); - } else { - if(!chacha_poly1305_set_key(s->outcipher, s->key)) - return error(s, EINVAL, "Failed to set key"); + if(!cipher_init(s->cipher_suite, &s->outcipher, s->key, !s->initiator)) { + return error(s, EINVAL, "Failed to initialize cipher"); } return true; @@ -321,52 +628,72 @@ static bool receive_sig(sptps_t *s, const char *data, uint16_t len) { // Force another Key EXchange (for testing purposes). bool sptps_force_kex(sptps_t *s) { - if(!s->outstate || s->state != SPTPS_SECONDARY_KEX) + if(!s->outstate || s->state != SPTPS_SECONDARY_KEX) { return error(s, EINVAL, "Cannot force KEX in current state"); + } s->state = SPTPS_KEX; return send_kex(s); } // Receive a handshake record. -static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) { +static bool receive_handshake(sptps_t *s, const uint8_t *data, uint16_t len) { // Only a few states to deal with handshaking. switch(s->state) { - case SPTPS_SECONDARY_KEX: - // We receive a secondary KEX request, first respond by sending our own. - if(!send_kex(s)) - return false; - case SPTPS_KEX: - // We have sent our KEX request, we expect our peer to sent one as well. - if(!receive_kex(s, data, len)) - return false; - s->state = SPTPS_SIG; - return true; - case SPTPS_SIG: - // If we already sent our secondary public ECDH key, we expect the peer to send his. - if(!receive_sig(s, data, len)) + case SPTPS_SECONDARY_KEX: + + // We receive a secondary KEX request, first respond by sending our own. + if(!send_kex(s)) { + return false; + } + + // Fall through + case SPTPS_KEX: + + // We have sent our KEX request, we expect our peer to sent one as well. + if(!receive_kex(s, data, len)) { + return false; + } + + s->state = SPTPS_SIG; + return true; + + case SPTPS_SIG: + + // If we already sent our secondary public ECDH key, we expect the peer to send his. + if(!receive_sig(s, data, len)) { + return false; + } + + if(s->outstate) { + s->state = SPTPS_ACK; + } else { + s->outstate = true; + + if(!receive_ack(s, NULL, 0)) { return false; - if(s->outstate) - s->state = SPTPS_ACK; - else { - s->outstate = true; - if(!receive_ack(s, NULL, 0)) - return false; - s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0); - s->state = SPTPS_SECONDARY_KEX; } - return true; - case SPTPS_ACK: - // We expect a handshake message to indicate transition to the new keys. - if(!receive_ack(s, data, len)) - return false; s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0); s->state = SPTPS_SECONDARY_KEX; - return true; - // TODO: split ACK into a VERify and ACK? - default: - return error(s, EIO, "Invalid session state %d", s->state); + } + + return true; + + case SPTPS_ACK: + + // We expect a handshake message to indicate transition to the new keys. + if(!receive_ack(s, data, len)) { + return false; + } + + s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0); + s->state = SPTPS_SECONDARY_KEX; + return true; + + // TODO: split ACK into a VERify and ACK? + default: + return error(s, EIO, "Invalid session state %d", s->state); } } @@ -381,112 +708,139 @@ static bool sptps_check_seqno(sptps_t *s, uint32_t seqno, bool update_state) { if(seqno >= s->inseqno + s->replaywin * 8) { // Prevent packets that jump far ahead of the queue from causing many others to be dropped. bool farfuture = s->farfuture < s->replaywin >> 2; - if (update_state) + + if(update_state) { s->farfuture++; - if(farfuture) - return error(s, EIO, "Packet is %d seqs in the future, dropped (%u)\n", seqno - s->inseqno, s->farfuture); + } + + if(farfuture) { + return update_state ? error(s, EIO, "Packet is %d seqs in the future, dropped (%u)\n", seqno - s->inseqno, s->farfuture) : false; + } // Unless we have seen lots of them, in which case we consider the others lost. - warning(s, "Lost %d packets\n", seqno - s->inseqno); - if (update_state) { + if(update_state) { + warning(s, "Lost %d packets\n", seqno - s->inseqno); + } + + if(update_state) { // Mark all packets in the replay window as being late. memset(s->late, 255, s->replaywin); } - } else if (seqno < s->inseqno) { + } else if(seqno < s->inseqno) { // If the sequence number is farther in the past than the bitmap goes, or if the packet was already received, drop it. - if((s->inseqno >= s->replaywin * 8 && seqno < s->inseqno - s->replaywin * 8) || !(s->late[(seqno / 8) % s->replaywin] & (1 << seqno % 8))) - return error(s, EIO, "Received late or replayed packet, seqno %d, last received %d\n", seqno, s->inseqno); - } else if (update_state) { + if((s->inseqno >= s->replaywin * 8 && seqno < s->inseqno - s->replaywin * 8) || !(s->late[(seqno / 8) % s->replaywin] & (1 << seqno % 8))) { + return update_state ? error(s, EIO, "Received late or replayed packet, seqno %d, last received %d\n", seqno, s->inseqno) : false; + } + } else if(update_state) { // We missed some packets. Mark them in the bitmap as being late. - for(int i = s->inseqno; i < seqno; i++) + for(uint32_t i = s->inseqno; i < seqno; i++) { s->late[(i / 8) % s->replaywin] |= 1 << i % 8; + } } } - if (update_state) { + if(update_state) { // Mark the current packet as not being late. s->late[(seqno / 8) % s->replaywin] &= ~(1 << seqno % 8); s->farfuture = 0; } } - if (update_state) { - if(seqno >= s->inseqno) + if(update_state) { + if(seqno >= s->inseqno) { s->inseqno = seqno + 1; + } - if(!s->inseqno) + if(!s->inseqno) { s->received = 0; - else + } else { s->received++; + } } return true; } // Check datagram for valid HMAC -bool sptps_verify_datagram(sptps_t *s, const void *data, size_t len) { - if(!s->instate || len < 21) +bool sptps_verify_datagram(sptps_t *s, const void *vdata, size_t len) { + if(!s->instate || len < 21) { return error(s, EIO, "Received short packet"); + } + const uint8_t *data = vdata; uint32_t seqno; memcpy(&seqno, data, 4); - seqno = ntohl(seqno); - if (!sptps_check_seqno(s, seqno, false)) + + if(!sptps_check_seqno(s, seqno, false)) { return false; + } - char buffer[len]; - size_t outlen; - return chacha_poly1305_decrypt(s->incipher, seqno, data + 4, len - 4, buffer, &outlen); + uint8_t *buffer = alloca(len); + return cipher_decrypt(s->cipher_suite, s->incipher, seqno, data + 4, len - 4, buffer, NULL); } // Receive incoming data, datagram version. -static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len) { - if(len < (s->instate ? 21 : 5)) +static bool sptps_receive_data_datagram(sptps_t *s, const uint8_t *data, size_t len) { + if(len < (s->instate ? 21 : 5)) { return error(s, EIO, "Received short packet"); + } uint32_t seqno; memcpy(&seqno, data, 4); - seqno = ntohl(seqno); + data += 4; + len -= 4; if(!s->instate) { - if(seqno != s->inseqno) + if(seqno != s->inseqno) { return error(s, EIO, "Invalid packet seqno: %d != %d", seqno, s->inseqno); + } s->inseqno = seqno + 1; - uint8_t type = data[4]; + uint8_t type = *(data++); + len--; - if(type != SPTPS_HANDSHAKE) + if(type != SPTPS_HANDSHAKE) { return error(s, EIO, "Application record received before handshake finished"); + } - return receive_handshake(s, data + 5, len - 5); + return receive_handshake(s, data, len); } // Decrypt - char buffer[len]; - + uint8_t *buffer = alloca(len); size_t outlen; - if(!chacha_poly1305_decrypt(s->incipher, seqno, data + 4, len - 4, buffer, &outlen)) + if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, data, len, buffer, &outlen)) { return error(s, EIO, "Failed to decrypt and verify packet"); + } - if(!sptps_check_seqno(s, seqno, true)) + if(!sptps_check_seqno(s, seqno, true)) { return false; + } // Append a NULL byte for safety. - buffer[len - 20] = 0; + buffer[outlen] = 0; + + data = buffer; + len = outlen; - uint8_t type = buffer[0]; + uint8_t type = *(data++); + len--; if(type < SPTPS_HANDSHAKE) { - if(!s->instate) + if(!s->instate) { return error(s, EIO, "Application record received before handshake finished"); - if(!s->receive_record(s->handle, type, buffer + 1, len - 21)) - abort(); + } + + if(!s->receive_record(s->handle, type, data, len)) { + return false; + } } else if(type == SPTPS_HANDSHAKE) { - if(!receive_handshake(s, buffer + 1, len - 21)) - abort(); + if(!receive_handshake(s, data, len)) { + return false; + } } else { return error(s, EIO, "Invalid record type %d", type); } @@ -495,126 +849,154 @@ static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len } // Receive incoming data. Check if it contains a complete record, if so, handle it. -bool sptps_receive_data(sptps_t *s, const void *data, size_t len) { - if(!s->state) +size_t sptps_receive_data(sptps_t *s, const void *vdata, size_t len) { + const uint8_t *data = vdata; + size_t total_read = 0; + + if(!s->state) { return error(s, EIO, "Invalid session state zero"); + } + + if(s->datagram) { + return sptps_receive_data_datagram(s, data, len) ? len : false; + } - if(s->datagram) - return sptps_receive_data_datagram(s, data, len); + // First read the 2 length bytes. + if(s->buflen < 2) { + size_t toread = 2 - s->buflen; - while(len) { - // First read the 2 length bytes. - if(s->buflen < 2) { - size_t toread = 2 - s->buflen; - if(toread > len) - toread = len; + if(toread > len) { + toread = len; + } - memcpy(s->inbuf + s->buflen, data, toread); + memcpy(s->inbuf + s->buflen, data, toread); - s->buflen += toread; - len -= toread; - data += toread; + total_read += toread; + s->buflen += toread; + len -= toread; + data += toread; - // Exit early if we don't have the full length. - if(s->buflen < 2) - return true; + // Exit early if we don't have the full length. + if(s->buflen < 2) { + return total_read; + } - // Get the length bytes + // Get the length bytes - memcpy(&s->reclen, s->inbuf, 2); - s->reclen = ntohs(s->reclen); + memcpy(&s->reclen, s->inbuf, 2); - // If we have the length bytes, ensure our buffer can hold the whole request. - s->inbuf = realloc(s->inbuf, s->reclen + 19UL); - if(!s->inbuf) - return error(s, errno, strerror(errno)); + // If we have the length bytes, ensure our buffer can hold the whole request. + s->inbuf = realloc(s->inbuf, s->reclen + SPTPS_OVERHEAD); - // Exit early if we have no more data to process. - if(!len) - return true; + if(!s->inbuf) { + return error(s, errno, "%s", strerror(errno)); } - // Read up to the end of the record. - size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen; - if(toread > len) - toread = len; + // Exit early if we have no more data to process. + if(!len) { + return total_read; + } + } - memcpy(s->inbuf + s->buflen, data, toread); - s->buflen += toread; - len -= toread; - data += toread; + // Read up to the end of the record. + size_t toread = s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER) - s->buflen; - // If we don't have a whole record, exit. - if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL)) - return true; + if(toread > len) { + toread = len; + } + + memcpy(s->inbuf + s->buflen, data, toread); + total_read += toread; + s->buflen += toread; - // Update sequence number. + // If we don't have a whole record, exit. + if(s->buflen < s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER)) { + return total_read; + } + + // Update sequence number. - uint32_t seqno = s->inseqno++; + uint32_t seqno = s->inseqno++; - // Check HMAC and decrypt. - if(s->instate) { - if(!chacha_poly1305_decrypt(s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) - return error(s, EINVAL, "Failed to decrypt and verify record"); + // Check HMAC and decrypt. + if(s->instate) { + if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) { + return error(s, EINVAL, "Failed to decrypt and verify record"); } + } - // Append a NULL byte for safety. - s->inbuf[s->reclen + 3UL] = 0; + // Append a NULL byte for safety. + s->inbuf[s->reclen + SPTPS_HEADER] = 0; - uint8_t type = s->inbuf[2]; + uint8_t type = s->inbuf[2]; - if(type < SPTPS_HANDSHAKE) { - if(!s->instate) - return error(s, EIO, "Application record received before handshake finished"); - if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen)) - return false; - } else if(type == SPTPS_HANDSHAKE) { - if(!receive_handshake(s, s->inbuf + 3, s->reclen)) - return false; - } else { - return error(s, EIO, "Invalid record type %d", type); + if(type < SPTPS_HANDSHAKE) { + if(!s->instate) { + return error(s, EIO, "Application record received before handshake finished"); } - s->buflen = 0; + if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen)) { + return false; + } + } else if(type == SPTPS_HANDSHAKE) { + if(!receive_handshake(s, s->inbuf + 3, s->reclen)) { + return false; + } + } else { + return error(s, EIO, "Invalid record type %d", type); } - return true; + s->buflen = 0; + + return total_read; } // Start a SPTPS session. -bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t *mykey, ecdsa_t *hiskey, const void *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) { +bool sptps_start(sptps_t *s, const sptps_params_t *params) { // Initialise struct sptps - memset(s, 0, sizeof *s); + memset(s, 0, sizeof(*s)); - s->handle = handle; - s->initiator = initiator; - s->datagram = datagram; - s->mykey = mykey; - s->hiskey = hiskey; + s->handle = params->handle; + s->initiator = params->initiator; + s->datagram = params->datagram; + s->mykey = params->mykey; + s->hiskey = params->hiskey; s->replaywin = sptps_replaywin; + s->cipher_suites = params->cipher_suites ? params->cipher_suites & SPTPS_ALL_CIPHER_SUITES : SPTPS_ALL_CIPHER_SUITES; + s->preferred_suite = params->preferred_suite; + if(s->replaywin) { s->late = malloc(s->replaywin); - if(!s->late) - return error(s, errno, strerror(errno)); + + if(!s->late) { + return error(s, errno, "%s", strerror(errno)); + } + memset(s->late, 0, s->replaywin); } - s->label = malloc(labellen); - if(!s->label) - return error(s, errno, strerror(errno)); + s->labellen = params->labellen ? params->labellen : strlen(params->label); + s->label = malloc(s->labellen); + + if(!s->label) { + return error(s, errno, "%s", strerror(errno)); + } - if(!datagram) { + memcpy(s->label, params->label, s->labellen); + + if(!s->datagram) { s->inbuf = malloc(7); - if(!s->inbuf) - return error(s, errno, strerror(errno)); + + if(!s->inbuf) { + return error(s, errno, "%s", strerror(errno)); + } + s->buflen = 0; } - memcpy(s->label, label, labellen); - s->labellen = labellen; - s->send_data = send_data; - s->receive_record = receive_record; + s->send_data = params->send_data; + s->receive_record = params->receive_record; // Do first KEX immediately s->state = SPTPS_KEX; @@ -624,15 +1006,15 @@ bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_ // Stop a SPTPS session. bool sptps_stop(sptps_t *s) { // Clean up any resources. - chacha_poly1305_exit(s->incipher); - chacha_poly1305_exit(s->outcipher); + cipher_exit(s->cipher_suite, s->incipher); + cipher_exit(s->cipher_suite, s->outcipher); ecdh_free(s->ecdh); free(s->inbuf); - free(s->mykex); - free(s->hiskex); - free(s->key); + free_sptps_kex(s->mykex); + free_sptps_kex(s->hiskex); + free_sptps_key(s->key); free(s->label); free(s->late); - memset(s, 0, sizeof *s); + memset(s, 0, sizeof(*s)); return true; }