Add AES-256-GCM support to SPTPS.
[tinc] / src / sptps.c
index 35c68bc..7be8cd8 100644 (file)
 #include "random.h"
 #include "xalloc.h"
 
+#ifdef HAVE_OPENSSL
+#include <openssl/evp.h>
+#endif
+
 unsigned int sptps_replaywin = 16;
 
 /*
@@ -108,25 +112,160 @@ static void free_sptps_key(sptps_key_t *key) {
        xzfree(key, sizeof(sptps_key_t));
 }
 
+static bool cipher_init(uint8_t suite, void **ctx, const sptps_key_t *keys, bool key_half) {
+        const uint8_t *key = key_half ? keys->key1 : keys->key0;
+
+       switch(suite) {
+       case SPTPS_CHACHA_POLY1305:
+               *ctx = chacha_poly1305_init();
+               return ctx && chacha_poly1305_set_key(*ctx, key);
+
+       case SPTPS_AES256_GCM:
+#ifdef HAVE_OPENSSL
+               *ctx = EVP_CIPHER_CTX_new();
+
+               if(!ctx) {
+                       return false;
+               }
+
+               return EVP_EncryptInit_ex(*ctx, EVP_aes_256_gcm(), NULL, NULL, NULL)
+                      && EVP_CIPHER_CTX_ctrl(*ctx, EVP_CTRL_AEAD_SET_IVLEN, 4, NULL)
+                      && EVP_EncryptInit_ex(*ctx, NULL, NULL, key, key + 32);
+#endif
+
+       default:
+               return false;
+       }
+}
+
+static void cipher_exit(uint8_t suite, void *ctx) {
+       switch(suite) {
+       case SPTPS_CHACHA_POLY1305:
+               chacha_poly1305_exit(ctx);
+               break;
+
+       case SPTPS_AES256_GCM:
+#ifdef HAVE_OPENSSL
+               EVP_CIPHER_CTX_free(ctx);
+               break;
+#endif
+
+       default:
+               break;
+       }
+}
+
+static bool cipher_encrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) {
+       switch(suite) {
+       case SPTPS_CHACHA_POLY1305:
+               chacha_poly1305_encrypt(ctx, seqno, in, inlen, out, outlen);
+               return true;
+
+       case SPTPS_AES256_GCM:
+#ifdef HAVE_OPENSSL
+               {
+                       if(!EVP_EncryptInit_ex(ctx, NULL, NULL, NULL, (uint8_t *)&seqno)) {
+                               return false;
+                       }
+
+                       int outlen1 = 0, outlen2 = 0;
+
+                       if(!EVP_EncryptUpdate(ctx, out, &outlen1, in, (int)inlen)) {
+                               return false;
+                       }
+
+                       if(!EVP_EncryptFinal_ex(ctx, out + outlen1, &outlen2)) {
+                               return false;
+                       }
+
+                       outlen1 += outlen2;
+
+                       if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, 16, out + outlen1)) {
+                               return false;
+                       }
+
+                       outlen1 += 16;
+
+                       if(outlen) {
+                               *outlen = outlen1;
+                       }
+
+                       return true;
+               }
+
+#endif
+
+       default:
+               return false;
+       }
+}
+
+static bool cipher_decrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) {
+       switch(suite) {
+       case SPTPS_CHACHA_POLY1305:
+               return chacha_poly1305_decrypt(ctx, seqno, in, inlen, out, outlen);
+
+       case SPTPS_AES256_GCM:
+#ifdef HAVE_OPENSSL
+               {
+                       if(inlen < 16) {
+                               return false;
+                       }
+
+                       inlen -= 16;
+
+                       if(!EVP_DecryptInit_ex(ctx, NULL, NULL, NULL, (uint8_t *)&seqno)) {
+                               return false;
+                       }
+
+                       int outlen1 = 0, outlen2 = 0;
+
+                       if(!EVP_DecryptUpdate(ctx, out, &outlen1, in, (int)inlen)) {
+                               return false;
+                       }
+
+                       if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, 16, (void *)(in + inlen))) {
+                               return false;
+                       }
+
+                       if(!EVP_DecryptFinal_ex(ctx, out + outlen1, &outlen2)) {
+                               return false;
+                       }
+
+                       if(outlen) {
+                               *outlen = outlen1 + outlen2;
+                       }
+
+                       return true;
+               }
+
+#endif
+
+       default:
+               return false;
+       }
+}
+
 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
-       uint8_t *buffer = alloca(len + 21UL);
-
+       uint8_t *buffer = alloca(len + SPTPS_DATAGRAM_OVERHEAD);
        // Create header with sequence number, length and record type
        uint32_t seqno = s->outseqno++;
-       uint32_t netseqno = ntohl(seqno);
 
-       memcpy(buffer, &netseqno, 4);
+       memcpy(buffer, &seqno, 4);
        buffer[4] = type;
        memcpy(buffer + 5, data, len);
 
        if(s->outstate) {
                // If first handshake has finished, encrypt and HMAC
-               chacha_poly1305_encrypt(s->outcipher, seqno, buffer + 4, len + 1, buffer + 4, NULL);
-               return s->send_data(s->handle, type, buffer, len + 21UL);
+               if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 4, len + 1, buffer + 4, NULL)) {
+                       return error(s, EINVAL, "Failed to encrypt message");
+               }
+
+               return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_OVERHEAD);
        } else {
                // Otherwise send as plaintext
-               return s->send_data(s->handle, type, buffer, len + 5UL);
+               return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_HEADER);
        }
 }
 // Send a record (private version, accepts all record types, handles encryption and authentication).
@@ -135,11 +274,11 @@ static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_
                return send_record_priv_datagram(s, type, data, len);
        }
 
-       uint8_t *buffer = alloca(len + 19UL);
+       uint8_t *buffer = alloca(len + SPTPS_OVERHEAD);
 
        // Create header with sequence number, length and record type
        uint32_t seqno = s->outseqno++;
-       uint16_t netlen = htons(len);
+       uint16_t netlen = len;
 
        memcpy(buffer, &netlen, 2);
        buffer[2] = type;
@@ -147,11 +286,14 @@ static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_
 
        if(s->outstate) {
                // If first handshake has finished, encrypt and HMAC
-               chacha_poly1305_encrypt(s->outcipher, seqno, buffer + 2, len + 1, buffer + 2, NULL);
-               return s->send_data(s->handle, type, buffer, len + 19UL);
+               if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 2, len + 1, buffer + 2, NULL)) {
+                       return error(s, EINVAL, "Failed to encrypt message");
+               }
+
+               return s->send_data(s->handle, type, buffer, len + SPTPS_OVERHEAD);
        } else {
                // Otherwise send as plaintext
-               return s->send_data(s->handle, type, buffer, len + 3UL);
+               return s->send_data(s->handle, type, buffer, len + SPTPS_HEADER);
        }
 }
 
@@ -181,6 +323,8 @@ static bool send_kex(sptps_t *s) {
 
        // Set version byte to zero.
        s->mykex->version = SPTPS_VERSION;
+       s->mykex->preferred_suite = s->preferred_suite;
+       s->mykex->cipher_suites = s->cipher_suites;
 
        // Create a random nonce.
        randomize(s->mykex->nonce, ECDH_SIZE);
@@ -225,16 +369,6 @@ static bool send_sig(sptps_t *s) {
 
 // Generate key material from the shared secret created from the ECDHE key exchange.
 static bool generate_key_material(sptps_t *s, const uint8_t *shared, size_t len) {
-       // Initialise cipher and digest structures if necessary
-       if(!s->outstate) {
-               s->incipher = chacha_poly1305_init();
-               s->outcipher = chacha_poly1305_init();
-
-               if(!s->incipher || !s->outcipher) {
-                       return error(s, EINVAL, "Failed to open cipher");
-               }
-       }
-
        // Allocate memory for key material
        s->key = new_sptps_key();
 
@@ -276,10 +410,8 @@ static bool receive_ack(sptps_t *s, const uint8_t *data, uint16_t len) {
                return error(s, EIO, "Invalid ACK record length");
        }
 
-       uint8_t *key = s->initiator ? s->key->key0 : s->key->key1;
-
-       if(!chacha_poly1305_set_key(s->incipher, key)) {
-               return error(s, EINVAL, "Failed to set counter");
+       if(!cipher_init(s->cipher_suite, &s->incipher, s->key, s->initiator)) {
+               return error(s, EINVAL, "Failed to initialize cipher");
        }
 
        free_sptps_key(s->key);
@@ -289,9 +421,35 @@ static bool receive_ack(sptps_t *s, const uint8_t *data, uint16_t len) {
        return true;
 }
 
+static uint8_t select_cipher_suite(uint16_t mask, uint8_t pref1, uint8_t pref2) {
+       // Check if there is a viable preference, if so select the lowest one
+       uint8_t selection = 255;
+
+       if(mask & (1U << pref1)) {
+               selection = pref1;
+       }
+
+       if(pref2 < selection && (mask & (1U << pref2))) {
+               selection = pref2;
+       }
+
+       // Otherwise, select the lowest cipher suite both sides support
+       if(selection == 255) {
+               selection = 0;
+
+               while(!(mask & 1U)) {
+                       selection++;
+                       mask >>= 1;
+               }
+       }
+
+       return selection;
+}
+
 // Receive a Key EXchange record, respond by sending a SIG record.
 static bool receive_kex(sptps_t *s, const uint8_t *data, uint16_t len) {
        // Verify length of the HELLO record
+
        if(len != sizeof(sptps_kex_t)) {
                return error(s, EIO, "Invalid KEX record length");
        }
@@ -300,6 +458,16 @@ static bool receive_kex(sptps_t *s, const uint8_t *data, uint16_t len) {
                return error(s, EINVAL, "Received incorrect version %d", *data);
        }
 
+       uint16_t suites;
+       memcpy(&suites, data + 2, 2);
+       suites &= s->cipher_suites;
+
+       if(!suites) {
+               return error(s, EIO, "No matching cipher suites");
+       }
+
+       s->cipher_suite = select_cipher_suite(suites, s->preferred_suite, data[1] & 0xf);
+
        // Make a copy of the KEX message, send_sig() and receive_sig() need it
        if(s->hiskex) {
                return error(s, EINVAL, "Received a second KEX message before first has been processed");
@@ -365,11 +533,8 @@ static bool receive_sig(sptps_t *s, const uint8_t *data, uint16_t len) {
                return false;
        }
 
-       // TODO: only set new keys after ACK has been set/received
-       uint8_t *key = s->initiator ? s->key->key1 : s->key->key0;
-
-       if(!chacha_poly1305_set_key(s->outcipher, key)) {
-               return error(s, EINVAL, "Failed to set key");
+       if(!cipher_init(s->cipher_suite, &s->outcipher, s->key, !s->initiator)) {
+               return error(s, EINVAL, "Failed to initialize cipher");
        }
 
        return true;
@@ -519,15 +684,13 @@ bool sptps_verify_datagram(sptps_t *s, const void *vdata, size_t len) {
        const uint8_t *data = vdata;
        uint32_t seqno;
        memcpy(&seqno, data, 4);
-       seqno = ntohl(seqno);
 
        if(!sptps_check_seqno(s, seqno, false)) {
                return false;
        }
 
        uint8_t *buffer = alloca(len);
-       size_t outlen;
-       return chacha_poly1305_decrypt(s->incipher, seqno, data + 4, len - 4, buffer, &outlen);
+       return cipher_decrypt(s->cipher_suite, s->incipher, seqno, data + 4, len - 4, buffer, NULL);
 }
 
 // Receive incoming data, datagram version.
@@ -538,7 +701,6 @@ static bool sptps_receive_data_datagram(sptps_t *s, const uint8_t *data, size_t
 
        uint32_t seqno;
        memcpy(&seqno, data, 4);
-       seqno = ntohl(seqno);
        data += 4;
        len -= 4;
 
@@ -564,7 +726,7 @@ static bool sptps_receive_data_datagram(sptps_t *s, const uint8_t *data, size_t
        uint8_t *buffer = alloca(len);
        size_t outlen;
 
-       if(!chacha_poly1305_decrypt(s->incipher, seqno, data, len, buffer, &outlen)) {
+       if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, data, len, buffer, &outlen)) {
                return error(s, EIO, "Failed to decrypt and verify packet");
        }
 
@@ -636,10 +798,9 @@ size_t sptps_receive_data(sptps_t *s, const void *vdata, size_t len) {
                // Get the length bytes
 
                memcpy(&s->reclen, s->inbuf, 2);
-               s->reclen = ntohs(s->reclen);
 
                // If we have the length bytes, ensure our buffer can hold the whole request.
-               s->inbuf = realloc(s->inbuf, s->reclen + 19UL);
+               s->inbuf = realloc(s->inbuf, s->reclen + SPTPS_OVERHEAD);
 
                if(!s->inbuf) {
                        return error(s, errno, "%s", strerror(errno));
@@ -652,7 +813,7 @@ size_t sptps_receive_data(sptps_t *s, const void *vdata, size_t len) {
        }
 
        // Read up to the end of the record.
-       size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen;
+       size_t toread = s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER) - s->buflen;
 
        if(toread > len) {
                toread = len;
@@ -663,7 +824,7 @@ size_t sptps_receive_data(sptps_t *s, const void *vdata, size_t len) {
        s->buflen += toread;
 
        // If we don't have a whole record, exit.
-       if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL)) {
+       if(s->buflen < s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER)) {
                return total_read;
        }
 
@@ -673,13 +834,13 @@ size_t sptps_receive_data(sptps_t *s, const void *vdata, size_t len) {
 
        // Check HMAC and decrypt.
        if(s->instate) {
-               if(!chacha_poly1305_decrypt(s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) {
+               if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) {
                        return error(s, EINVAL, "Failed to decrypt and verify record");
                }
        }
 
        // Append a NULL byte for safety.
-       s->inbuf[s->reclen + 3UL] = 0;
+       s->inbuf[s->reclen + SPTPS_HEADER] = 0;
 
        uint8_t type = s->inbuf[2];
 
@@ -705,16 +866,18 @@ size_t sptps_receive_data(sptps_t *s, const void *vdata, size_t len) {
 }
 
 // Start a SPTPS session.
-bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t *mykey, ecdsa_t *hiskey, const void *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
+bool sptps_start(sptps_t *s, const sptps_params_t *params) {
        // Initialise struct sptps
        memset(s, 0, sizeof(*s));
 
-       s->handle = handle;
-       s->initiator = initiator;
-       s->datagram = datagram;
-       s->mykey = mykey;
-       s->hiskey = hiskey;
+       s->handle = params->handle;
+       s->initiator = params->initiator;
+       s->datagram = params->datagram;
+       s->mykey = params->mykey;
+       s->hiskey = params->hiskey;
        s->replaywin = sptps_replaywin;
+       s->cipher_suites = params->cipher_suites ? params->cipher_suites & SPTPS_ALL_CIPHER_SUITES : SPTPS_ALL_CIPHER_SUITES;
+       s->preferred_suite = params->preferred_suite;
 
        if(s->replaywin) {
                s->late = malloc(s->replaywin);
@@ -726,13 +889,16 @@ bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_
                memset(s->late, 0, s->replaywin);
        }
 
-       s->label = malloc(labellen);
+       s->labellen = params->labellen ? params->labellen : strlen(params->label);
+       s->label = malloc(s->labellen);
 
        if(!s->label) {
                return error(s, errno, "%s", strerror(errno));
        }
 
-       if(!datagram) {
+       memcpy(s->label, params->label, s->labellen);
+
+       if(!s->datagram) {
                s->inbuf = malloc(7);
 
                if(!s->inbuf) {
@@ -742,11 +908,9 @@ bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_
                s->buflen = 0;
        }
 
-       memcpy(s->label, label, labellen);
-       s->labellen = labellen;
 
-       s->send_data = send_data;
-       s->receive_record = receive_record;
+       s->send_data = params->send_data;
+       s->receive_record = params->receive_record;
 
        // Do first KEX immediately
        s->state = SPTPS_KEX;
@@ -756,8 +920,8 @@ bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_
 // Stop a SPTPS session.
 bool sptps_stop(sptps_t *s) {
        // Clean up any resources.
-       chacha_poly1305_exit(s->incipher);
-       chacha_poly1305_exit(s->outcipher);
+       cipher_exit(s->cipher_suite, s->incipher);
+       cipher_exit(s->cipher_suite, s->outcipher);
        ecdh_free(s->ecdh);
        free(s->inbuf);
        free_sptps_kex(s->mykex);