Make the ExperimentalProtocol option obsolete.
[tinc] / src / sptps.c
1 /*
2     sptps.c -- Simple Peer-to-Peer Security
3     Copyright (C) 2011-2021 Guus Sliepen <guus@tinc-vpn.org>,
4                   2010      Brandon L. Black <blblack@gmail.com>
5
6     This program is free software; you can redistribute it and/or modify
7     it under the terms of the GNU General Public License as published by
8     the Free Software Foundation; either version 2 of the License, or
9     (at your option) any later version.
10
11     This program is distributed in the hope that it will be useful,
12     but WITHOUT ANY WARRANTY; without even the implied warranty of
13     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14     GNU General Public License for more details.
15
16     You should have received a copy of the GNU General Public License along
17     with this program; if not, write to the Free Software Foundation, Inc.,
18     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
19 */
20
21 #include "system.h"
22
23 #include "chacha-poly1305/chachapoly.h"
24 #include "ecdh.h"
25 #include "ecdsa.h"
26 #include "prf.h"
27 #include "sptps.h"
28 #include "random.h"
29 #include "xalloc.h"
30
31 #ifdef HAVE_OPENSSL
32 #include <openssl/evp.h>
33 #endif
34
35 #define CIPHER_KEYLEN 64
36
37 unsigned int sptps_replaywin = 16;
38
39 /*
40    Nonce MUST be exchanged first (done)
41    Signatures MUST be done over both nonces, to guarantee the signature is fresh
42    Otherwise: if ECDHE key of one side is compromised, it can be reused!
43
44    Add explicit tag to beginning of structure to distinguish the client and server when signing. (done)
45
46    Sign all handshake messages up to ECDHE kex with long-term public keys. (done)
47
48    HMACed KEX finished message to prevent downgrade attacks and prove you have the right key material (done by virtue of Ed25519 over the whole ECDHE exchange?)
49
50    Explicit close message needs to be added.
51
52    Maybe do add some alert messages to give helpful error messages? Not more than TLS sends.
53
54    Use counter mode instead of OFB. (done)
55
56    Make sure ECC operations are fixed time (aka prevent side-channel attacks).
57 */
58
59 void sptps_log_quiet(sptps_t *s, int s_errno, const char *format, va_list ap) {
60         (void)s;
61         (void)s_errno;
62         (void)format;
63         (void)ap;
64 }
65
66 void sptps_log_stderr(sptps_t *s, int s_errno, const char *format, va_list ap) {
67         (void)s;
68         (void)s_errno;
69
70         vfprintf(stderr, format, ap);
71         fputc('\n', stderr);
72 }
73
74 void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap) = sptps_log_stderr;
75
76 // Log an error message.
77 static bool error(sptps_t *s, int s_errno, const char *format, ...) ATTR_FORMAT(printf, 3, 4);
78 static bool error(sptps_t *s, int s_errno, const char *format, ...) {
79         (void)s;
80         (void)s_errno;
81
82         if(format) {
83                 va_list ap;
84                 va_start(ap, format);
85                 sptps_log(s, s_errno, format, ap);
86                 va_end(ap);
87         }
88
89         errno = s_errno;
90         return false;
91 }
92
93 static void warning(sptps_t *s, const char *format, ...) ATTR_FORMAT(printf, 2, 3);
94 static void warning(sptps_t *s, const char *format, ...) {
95         va_list ap;
96         va_start(ap, format);
97         sptps_log(s, 0, format, ap);
98         va_end(ap);
99 }
100
101 static sptps_kex_t *new_sptps_kex(void) {
102         return xzalloc(sizeof(sptps_kex_t));
103 }
104
105 static void free_sptps_kex(sptps_kex_t *kex) {
106         xzfree(kex, sizeof(sptps_kex_t));
107 }
108
109 static sptps_key_t *new_sptps_key(void) {
110         return xzalloc(sizeof(sptps_key_t));
111 }
112
113 static void free_sptps_key(sptps_key_t *key) {
114         xzfree(key, sizeof(sptps_key_t));
115 }
116
117 static bool cipher_init(uint8_t suite, void **ctx, const sptps_key_t *keys, bool key_half) {
118         const uint8_t *key = key_half ? keys->key1 : keys->key0;
119
120         switch(suite) {
121 #ifndef HAVE_OPENSSL
122
123         case SPTPS_CHACHA_POLY1305:
124                 *ctx = malloc(sizeof(struct chachapoly_ctx));
125                 return *ctx && chachapoly_init(*ctx, key, 256) == CHACHAPOLY_OK;
126
127 #else
128
129         case SPTPS_CHACHA_POLY1305:
130 #ifdef EVP_F_EVP_AEAD_CTX_INIT
131                 *ctx = malloc(sizeof(EVP_AEAD_CTX));
132
133                 return *ctx && EVP_AEAD_CTX_init(*ctx, EVP_aead_chacha20_poly1305(), key + (key_half ? CIPHER_KEYLEN : 0), 32, 16, NULL);
134 #else
135                 *ctx = EVP_CIPHER_CTX_new();
136
137                 return *ctx
138                        && EVP_EncryptInit_ex(*ctx, EVP_chacha20_poly1305(), NULL, NULL, NULL)
139                        && EVP_CIPHER_CTX_ctrl(*ctx, EVP_CTRL_GCM_SET_IVLEN, 12, NULL)
140                        && EVP_EncryptInit_ex(*ctx, NULL, NULL, key, key + 32);
141 #endif
142
143         case SPTPS_AES256_GCM:
144                 *ctx = EVP_CIPHER_CTX_new();
145
146                 return *ctx
147                        && EVP_EncryptInit_ex(*ctx, EVP_aes_256_gcm(), NULL, NULL, NULL)
148                        && EVP_CIPHER_CTX_ctrl(*ctx, EVP_CTRL_GCM_SET_IVLEN, 12, NULL)
149                        && EVP_EncryptInit_ex(*ctx, NULL, NULL, key, key + 32);
150 #endif
151
152         default:
153                 return false;
154         }
155 }
156
157 static void cipher_exit(uint8_t suite, void *ctx) {
158         switch(suite) {
159 #ifndef HAVE_OPENSSL
160
161         case SPTPS_CHACHA_POLY1305:
162                 free(ctx);
163                 break;
164
165 #else
166
167         case SPTPS_CHACHA_POLY1305:
168 #ifdef EVP_F_EVP_AEAD_CTX_INIT
169                 EVP_AEAD_CTX_cleanup(ctx);
170                 free(ctx);
171                 break;
172 #endif
173
174         case SPTPS_AES256_GCM:
175                 EVP_CIPHER_CTX_free(ctx);
176                 break;
177 #endif
178
179         default:
180                 break;
181         }
182 }
183
184 static bool cipher_encrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) {
185         switch(suite) {
186 #ifndef HAVE_OPENSSL
187
188         case SPTPS_CHACHA_POLY1305: {
189                 if(chachapoly_crypt(ctx, nonce, NULL, 0, (void *)in, inlen, out, out + inlen, 16, 1) != CHACHAPOLY_OK) {
190                         return false;
191                 }
192
193                 if(outlen) {
194                         *outlen = inlen + 16;
195                 }
196
197                 return true;
198         }
199
200 #else
201
202         case SPTPS_CHACHA_POLY1305:
203 #ifdef EVP_F_EVP_AEAD_CTX_INIT
204                 {
205                         size_t outlen1;
206
207                         if(!EVP_AEAD_CTX_seal(ctx, out, &outlen1, inlen + 16, nonce, sizeof(nonce), in, inlen, NULL, 0)) {
208                                 return false;
209                         }
210
211                         if(outlen) {
212                                 *outlen = outlen1;
213                         }
214
215                         return true;
216                 }
217
218 #endif
219
220         case SPTPS_AES256_GCM: {
221                 uint8_t nonce[12] = {seqno, seqno >> 8, seqno >> 16, seqno >> 24};
222
223                 if(!EVP_EncryptInit_ex(ctx, NULL, NULL, NULL, nonce)) {
224                         return false;
225                 }
226
227                 int outlen1 = 0, outlen2 = 0;
228
229                 if(!EVP_EncryptUpdate(ctx, out, &outlen1, in, (int)inlen)) {
230                         return false;
231                 }
232
233                 if(!EVP_EncryptFinal_ex(ctx, out + outlen1, &outlen2)) {
234                         return false;
235                 }
236
237                 outlen1 += outlen2;
238
239                 if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, 16, out + outlen1)) {
240                         return false;
241                 }
242
243                 outlen1 += 16;
244
245                 if(outlen) {
246                         *outlen = outlen1;
247                 }
248
249                 return true;
250         }
251
252 #endif
253
254         default:
255                 return false;
256         }
257 }
258
259 static bool cipher_decrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) {
260         if(inlen < 16) {
261                 return false;
262         }
263
264         inlen -= 16;
265
266         switch(suite) {
267 #ifndef HAVE_OPENSSL
268
269         case SPTPS_CHACHA_POLY1305:
270                 if(chachapoly_crypt(ctx, nonce, NULL, 0, (void *)in, inlen, out, (void *)(in + inlen), 16, 0) != CHACHAPOLY_OK) {
271                         return false;
272                 }
273
274                 if(outlen) {
275                         *outlen = inlen;
276                 }
277
278                 return true;
279
280 #else
281
282         case SPTPS_CHACHA_POLY1305:
283 #ifdef EVP_F_EVP_AEAD_CTX_INIT
284                 {
285                         size_t outlen1;
286
287                         if(!EVP_AEAD_CTX_open(ctx, out, &outlen1, inlen, nonce, sizeof(nonce), in, inlen + 16, NULL, 0)) {
288                                 return false;
289                         }
290
291                         if(outlen) {
292                                 *outlen = outlen1;
293                         }
294
295                         return true;
296                 }
297
298 #endif
299
300         case SPTPS_AES256_GCM: {
301                 uint8_t nonce[12] = {seqno, seqno >> 8, seqno >> 16, seqno >> 24};
302
303                 if(!EVP_DecryptInit_ex(ctx, NULL, NULL, NULL, nonce)) {
304                         return false;
305                 }
306
307                 int outlen1 = 0, outlen2 = 0;
308
309                 if(!EVP_DecryptUpdate(ctx, out, &outlen1, in, (int)inlen)) {
310                         return false;
311                 }
312
313                 if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, 16, (void *)(in + inlen))) {
314                         return false;
315                 }
316
317                 if(!EVP_DecryptFinal_ex(ctx, out + outlen1, &outlen2)) {
318                         return false;
319                 }
320
321                 if(outlen) {
322                         *outlen = outlen1 + outlen2;
323                 }
324
325                 return true;
326         }
327
328 #endif
329
330         default:
331                 return false;
332         }
333 }
334
335 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
336 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
337         uint8_t *buffer = alloca(len + SPTPS_DATAGRAM_OVERHEAD);
338         // Create header with sequence number, length and record type
339         uint32_t seqno = s->outseqno++;
340
341         memcpy(buffer, &seqno, 4);
342         buffer[4] = type;
343         memcpy(buffer + 5, data, len);
344
345         if(s->outstate) {
346                 // If first handshake has finished, encrypt and HMAC
347                 if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 4, len + 1, buffer + 4, NULL)) {
348                         return error(s, EINVAL, "Failed to encrypt message");
349                 }
350
351                 return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_OVERHEAD);
352         } else {
353                 // Otherwise send as plaintext
354                 return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_HEADER);
355         }
356 }
357 // Send a record (private version, accepts all record types, handles encryption and authentication).
358 static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
359         if(s->datagram) {
360                 return send_record_priv_datagram(s, type, data, len);
361         }
362
363         uint8_t *buffer = alloca(len + SPTPS_OVERHEAD);
364
365         // Create header with sequence number, length and record type
366         uint32_t seqno = s->outseqno++;
367         uint16_t netlen = len;
368
369         memcpy(buffer, &netlen, 2);
370         buffer[2] = type;
371         memcpy(buffer + 3, data, len);
372
373         if(s->outstate) {
374                 // If first handshake has finished, encrypt and HMAC
375                 if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 2, len + 1, buffer + 2, NULL)) {
376                         return error(s, EINVAL, "Failed to encrypt message");
377                 }
378
379                 return s->send_data(s->handle, type, buffer, len + SPTPS_OVERHEAD);
380         } else {
381                 // Otherwise send as plaintext
382                 return s->send_data(s->handle, type, buffer, len + SPTPS_HEADER);
383         }
384 }
385
386 // Send an application record.
387 bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
388         // Sanity checks: application cannot send data before handshake is finished,
389         // and only record types 0..127 are allowed.
390         if(!s->outstate) {
391                 return error(s, EINVAL, "Handshake phase not finished yet");
392         }
393
394         if(type >= SPTPS_HANDSHAKE) {
395                 return error(s, EINVAL, "Invalid application record type");
396         }
397
398         return send_record_priv(s, type, data, len);
399 }
400
401 // Send a Key EXchange record, containing a random nonce and an ECDHE public key.
402 static bool send_kex(sptps_t *s) {
403         // Make room for our KEX message, which we will keep around since send_sig() needs it.
404         if(s->mykex) {
405                 return false;
406         }
407
408         s->mykex = new_sptps_kex();
409
410         // Set version byte to zero.
411         s->mykex->version = SPTPS_VERSION;
412         s->mykex->preferred_suite = s->preferred_suite;
413         s->mykex->cipher_suites = s->cipher_suites;
414
415         // Create a random nonce.
416         randomize(s->mykex->nonce, ECDH_SIZE);
417
418         // Create a new ECDH public key.
419         if(!(s->ecdh = ecdh_generate_public(s->mykex->pubkey))) {
420                 return error(s, EINVAL, "Failed to generate ECDH public key");
421         }
422
423         return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, sizeof(sptps_kex_t));
424 }
425
426 static size_t sigmsg_len(size_t labellen) {
427         return 1 + 2 * sizeof(sptps_kex_t) + labellen;
428 }
429
430 static void fill_msg(uint8_t *msg, bool initiator, const sptps_kex_t *kex0, const sptps_kex_t *kex1, const sptps_t *s) {
431         *msg = initiator, msg++;
432         memcpy(msg, kex0, sizeof(*kex0)), msg += sizeof(*kex0);
433         memcpy(msg, kex1, sizeof(*kex1)), msg += sizeof(*kex1);
434         memcpy(msg, s->label, s->labellen);
435 }
436
437 // Send a SIGnature record, containing an Ed25519 signature over both KEX records.
438 static bool send_sig(sptps_t *s) {
439         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator, plus label
440         size_t msglen = sigmsg_len(s->labellen);
441         uint8_t *msg = alloca(msglen);
442         fill_msg(msg, s->initiator, s->mykex, s->hiskex, s);
443
444         // Sign the result.
445         size_t siglen = ecdsa_size(s->mykey);
446         uint8_t *sig = alloca(siglen);
447
448         if(!ecdsa_sign(s->mykey, msg, msglen, sig)) {
449                 return error(s, EINVAL, "Failed to sign SIG record");
450         }
451
452         // Send the SIG exchange record.
453         return send_record_priv(s, SPTPS_HANDSHAKE, sig, siglen);
454 }
455
456 // Generate key material from the shared secret created from the ECDHE key exchange.
457 static bool generate_key_material(sptps_t *s, const uint8_t *shared, size_t len) {
458         // Allocate memory for key material
459         s->key = new_sptps_key();
460
461         // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce
462         const size_t msglen = sizeof("key expansion") - 1;
463         const size_t seedlen = msglen + s->labellen + ECDH_SIZE * 2;
464         uint8_t *seed = alloca(seedlen);
465
466         uint8_t *ptr = seed;
467         memcpy(ptr, "key expansion", msglen);
468         ptr += msglen;
469
470         memcpy(ptr, (s->initiator ? s->mykex : s->hiskex)->nonce, ECDH_SIZE);
471         ptr += ECDH_SIZE;
472
473         memcpy(ptr, (s->initiator ? s->hiskex : s->mykex)->nonce, ECDH_SIZE);
474         ptr += ECDH_SIZE;
475
476         memcpy(ptr, s->label, s->labellen);
477
478         // Use PRF to generate the key material
479         if(!prf(shared, len, seed, seedlen, s->key->both, sizeof(sptps_key_t))) {
480                 return error(s, EINVAL, "Failed to generate key material");
481         }
482
483         return true;
484 }
485
486 // Send an ACKnowledgement record.
487 static bool send_ack(sptps_t *s) {
488         return send_record_priv(s, SPTPS_HANDSHAKE, "", 0);
489 }
490
491 // Receive an ACKnowledgement record.
492 static bool receive_ack(sptps_t *s, const uint8_t *data, uint16_t len) {
493         (void)data;
494
495         if(len) {
496                 return error(s, EIO, "Invalid ACK record length");
497         }
498
499         if(!cipher_init(s->cipher_suite, &s->incipher, s->key, s->initiator)) {
500                 return error(s, EINVAL, "Failed to initialize cipher");
501         }
502
503         free_sptps_key(s->key);
504         s->key = NULL;
505         s->instate = true;
506
507         return true;
508 }
509
510 static uint8_t select_cipher_suite(uint16_t mask, uint8_t pref1, uint8_t pref2) {
511         // Check if there is a viable preference, if so select the lowest one
512         uint8_t selection = 255;
513
514         if(mask & (1U << pref1)) {
515                 selection = pref1;
516         }
517
518         if(pref2 < selection && (mask & (1U << pref2))) {
519                 selection = pref2;
520         }
521
522         // Otherwise, select the lowest cipher suite both sides support
523         if(selection == 255) {
524                 selection = 0;
525
526                 while(!(mask & 1U)) {
527                         selection++;
528                         mask >>= 1;
529                 }
530         }
531
532         return selection;
533 }
534
535 // Receive a Key EXchange record, respond by sending a SIG record.
536 static bool receive_kex(sptps_t *s, const uint8_t *data, uint16_t len) {
537         // Verify length of the HELLO record
538
539         if(len != sizeof(sptps_kex_t)) {
540                 return error(s, EIO, "Invalid KEX record length");
541         }
542
543         if(*data != SPTPS_VERSION) {
544                 return error(s, EINVAL, "Received incorrect version %d", *data);
545         }
546
547         uint16_t suites;
548         memcpy(&suites, data + 2, 2);
549         suites &= s->cipher_suites;
550
551         if(!suites) {
552                 return error(s, EIO, "No matching cipher suites");
553         }
554
555         s->cipher_suite = select_cipher_suite(suites, s->preferred_suite, data[1] & 0xf);
556
557         // Make a copy of the KEX message, send_sig() and receive_sig() need it
558         if(s->hiskex) {
559                 return error(s, EINVAL, "Received a second KEX message before first has been processed");
560         }
561
562         s->hiskex = new_sptps_kex();
563         memcpy(s->hiskex, data, sizeof(sptps_kex_t));
564
565         if(s->initiator) {
566                 return send_sig(s);
567         } else {
568                 return true;
569         }
570 }
571
572 // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys.
573 static bool receive_sig(sptps_t *s, const uint8_t *data, uint16_t len) {
574         // Verify length of KEX record.
575         if(len != ecdsa_size(s->hiskey)) {
576                 return error(s, EIO, "Invalid KEX record length");
577         }
578
579         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
580         const size_t msglen = sigmsg_len(s->labellen);
581         uint8_t *msg = alloca(msglen);
582         fill_msg(msg, !s->initiator, s->hiskex, s->mykex, s);
583
584         // Verify signature.
585         if(!ecdsa_verify(s->hiskey, msg, msglen, data)) {
586                 return error(s, EIO, "Failed to verify SIG record");
587         }
588
589         // Compute shared secret.
590         uint8_t shared[ECDH_SHARED_SIZE];
591
592         if(!ecdh_compute_shared(s->ecdh, s->hiskex->pubkey, shared)) {
593                 memzero(shared, sizeof(shared));
594                 return error(s, EINVAL, "Failed to compute ECDH shared secret");
595         }
596
597         s->ecdh = NULL;
598
599         // Generate key material from shared secret.
600         bool generated = generate_key_material(s, shared, sizeof(shared));
601         memzero(shared, sizeof(shared));
602
603         if(!generated) {
604                 return false;
605         }
606
607         if(!s->initiator && !send_sig(s)) {
608                 return false;
609         }
610
611         free_sptps_kex(s->mykex);
612         s->mykex = NULL;
613
614         free_sptps_kex(s->hiskex);
615         s->hiskex = NULL;
616
617         // Send cipher change record
618         if(s->outstate && !send_ack(s)) {
619                 return false;
620         }
621
622         if(!cipher_init(s->cipher_suite, &s->outcipher, s->key, !s->initiator)) {
623                 return error(s, EINVAL, "Failed to initialize cipher");
624         }
625
626         return true;
627 }
628
629 // Force another Key EXchange (for testing purposes).
630 bool sptps_force_kex(sptps_t *s) {
631         if(!s->outstate || s->state != SPTPS_SECONDARY_KEX) {
632                 return error(s, EINVAL, "Cannot force KEX in current state");
633         }
634
635         s->state = SPTPS_KEX;
636         return send_kex(s);
637 }
638
639 // Receive a handshake record.
640 static bool receive_handshake(sptps_t *s, const uint8_t *data, uint16_t len) {
641         // Only a few states to deal with handshaking.
642         switch(s->state) {
643         case SPTPS_SECONDARY_KEX:
644
645                 // We receive a secondary KEX request, first respond by sending our own.
646                 if(!send_kex(s)) {
647                         return false;
648                 }
649
650         // Fall through
651         case SPTPS_KEX:
652
653                 // We have sent our KEX request, we expect our peer to sent one as well.
654                 if(!receive_kex(s, data, len)) {
655                         return false;
656                 }
657
658                 s->state = SPTPS_SIG;
659                 return true;
660
661         case SPTPS_SIG:
662
663                 // If we already sent our secondary public ECDH key, we expect the peer to send his.
664                 if(!receive_sig(s, data, len)) {
665                         return false;
666                 }
667
668                 if(s->outstate) {
669                         s->state = SPTPS_ACK;
670                 } else {
671                         s->outstate = true;
672
673                         if(!receive_ack(s, NULL, 0)) {
674                                 return false;
675                         }
676
677                         s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
678                         s->state = SPTPS_SECONDARY_KEX;
679                 }
680
681                 return true;
682
683         case SPTPS_ACK:
684
685                 // We expect a handshake message to indicate transition to the new keys.
686                 if(!receive_ack(s, data, len)) {
687                         return false;
688                 }
689
690                 s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
691                 s->state = SPTPS_SECONDARY_KEX;
692                 return true;
693
694         // TODO: split ACK into a VERify and ACK?
695         default:
696                 return error(s, EIO, "Invalid session state %d", s->state);
697         }
698 }
699
700 static bool sptps_check_seqno(sptps_t *s, uint32_t seqno, bool update_state) {
701         // Replay protection using a sliding window of configurable size.
702         // s->inseqno is expected sequence number
703         // seqno is received sequence number
704         // s->late[] is a circular buffer, a 1 bit means a packet has not been received yet
705         // The circular buffer contains bits for sequence numbers from s->inseqno - s->replaywin * 8 to (but excluding) s->inseqno.
706         if(s->replaywin) {
707                 if(seqno != s->inseqno) {
708                         if(seqno >= s->inseqno + s->replaywin * 8) {
709                                 // Prevent packets that jump far ahead of the queue from causing many others to be dropped.
710                                 bool farfuture = s->farfuture < s->replaywin >> 2;
711
712                                 if(update_state) {
713                                         s->farfuture++;
714                                 }
715
716                                 if(farfuture) {
717                                         return update_state ? error(s, EIO, "Packet is %d seqs in the future, dropped (%u)\n", seqno - s->inseqno, s->farfuture) : false;
718                                 }
719
720                                 // Unless we have seen lots of them, in which case we consider the others lost.
721                                 if(update_state) {
722                                         warning(s, "Lost %d packets\n", seqno - s->inseqno);
723                                 }
724
725                                 if(update_state) {
726                                         // Mark all packets in the replay window as being late.
727                                         memset(s->late, 255, s->replaywin);
728                                 }
729                         } else if(seqno < s->inseqno) {
730                                 // If the sequence number is farther in the past than the bitmap goes, or if the packet was already received, drop it.
731                                 if((s->inseqno >= s->replaywin * 8 && seqno < s->inseqno - s->replaywin * 8) || !(s->late[(seqno / 8) % s->replaywin] & (1 << seqno % 8))) {
732                                         return update_state ? error(s, EIO, "Received late or replayed packet, seqno %d, last received %d\n", seqno, s->inseqno) : false;
733                                 }
734                         } else if(update_state) {
735                                 // We missed some packets. Mark them in the bitmap as being late.
736                                 for(uint32_t i = s->inseqno; i < seqno; i++) {
737                                         s->late[(i / 8) % s->replaywin] |= 1 << i % 8;
738                                 }
739                         }
740                 }
741
742                 if(update_state) {
743                         // Mark the current packet as not being late.
744                         s->late[(seqno / 8) % s->replaywin] &= ~(1 << seqno % 8);
745                         s->farfuture = 0;
746                 }
747         }
748
749         if(update_state) {
750                 if(seqno >= s->inseqno) {
751                         s->inseqno = seqno + 1;
752                 }
753
754                 if(!s->inseqno) {
755                         s->received = 0;
756                 } else {
757                         s->received++;
758                 }
759         }
760
761         return true;
762 }
763
764 // Check datagram for valid HMAC
765 bool sptps_verify_datagram(sptps_t *s, const void *vdata, size_t len) {
766         if(!s->instate || len < 21) {
767                 return error(s, EIO, "Received short packet");
768         }
769
770         const uint8_t *data = vdata;
771         uint32_t seqno;
772         memcpy(&seqno, data, 4);
773
774         if(!sptps_check_seqno(s, seqno, false)) {
775                 return false;
776         }
777
778         uint8_t *buffer = alloca(len);
779         return cipher_decrypt(s->cipher_suite, s->incipher, seqno, data + 4, len - 4, buffer, NULL);
780 }
781
782 // Receive incoming data, datagram version.
783 static bool sptps_receive_data_datagram(sptps_t *s, const uint8_t *data, size_t len) {
784         if(len < (s->instate ? 21 : 5)) {
785                 return error(s, EIO, "Received short packet");
786         }
787
788         uint32_t seqno;
789         memcpy(&seqno, data, 4);
790         data += 4;
791         len -= 4;
792
793         if(!s->instate) {
794                 if(seqno != s->inseqno) {
795                         return error(s, EIO, "Invalid packet seqno: %d != %d", seqno, s->inseqno);
796                 }
797
798                 s->inseqno = seqno + 1;
799
800                 uint8_t type = *(data++);
801                 len--;
802
803                 if(type != SPTPS_HANDSHAKE) {
804                         return error(s, EIO, "Application record received before handshake finished");
805                 }
806
807                 return receive_handshake(s, data, len);
808         }
809
810         // Decrypt
811
812         uint8_t *buffer = alloca(len);
813         size_t outlen;
814
815         if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, data, len, buffer, &outlen)) {
816                 return error(s, EIO, "Failed to decrypt and verify packet");
817         }
818
819         if(!sptps_check_seqno(s, seqno, true)) {
820                 return false;
821         }
822
823         // Append a NULL byte for safety.
824         buffer[outlen] = 0;
825
826         data = buffer;
827         len = outlen;
828
829         uint8_t type = *(data++);
830         len--;
831
832         if(type < SPTPS_HANDSHAKE) {
833                 if(!s->instate) {
834                         return error(s, EIO, "Application record received before handshake finished");
835                 }
836
837                 if(!s->receive_record(s->handle, type, data, len)) {
838                         return false;
839                 }
840         } else if(type == SPTPS_HANDSHAKE) {
841                 if(!receive_handshake(s, data, len)) {
842                         return false;
843                 }
844         } else {
845                 return error(s, EIO, "Invalid record type %d", type);
846         }
847
848         return true;
849 }
850
851 // Receive incoming data. Check if it contains a complete record, if so, handle it.
852 size_t sptps_receive_data(sptps_t *s, const void *vdata, size_t len) {
853         const uint8_t *data = vdata;
854         size_t total_read = 0;
855
856         if(!s->state) {
857                 return error(s, EIO, "Invalid session state zero");
858         }
859
860         if(s->datagram) {
861                 return sptps_receive_data_datagram(s, data, len) ? len : false;
862         }
863
864         // First read the 2 length bytes.
865         if(s->buflen < 2) {
866                 size_t toread = 2 - s->buflen;
867
868                 if(toread > len) {
869                         toread = len;
870                 }
871
872                 memcpy(s->inbuf + s->buflen, data, toread);
873
874                 total_read += toread;
875                 s->buflen += toread;
876                 len -= toread;
877                 data += toread;
878
879                 // Exit early if we don't have the full length.
880                 if(s->buflen < 2) {
881                         return total_read;
882                 }
883
884                 // Get the length bytes
885
886                 memcpy(&s->reclen, s->inbuf, 2);
887
888                 // If we have the length bytes, ensure our buffer can hold the whole request.
889                 s->inbuf = realloc(s->inbuf, s->reclen + SPTPS_OVERHEAD);
890
891                 if(!s->inbuf) {
892                         return error(s, errno, "%s", strerror(errno));
893                 }
894
895                 // Exit early if we have no more data to process.
896                 if(!len) {
897                         return total_read;
898                 }
899         }
900
901         // Read up to the end of the record.
902         size_t toread = s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER) - s->buflen;
903
904         if(toread > len) {
905                 toread = len;
906         }
907
908         memcpy(s->inbuf + s->buflen, data, toread);
909         total_read += toread;
910         s->buflen += toread;
911
912         // If we don't have a whole record, exit.
913         if(s->buflen < s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER)) {
914                 return total_read;
915         }
916
917         // Update sequence number.
918
919         uint32_t seqno = s->inseqno++;
920
921         // Check HMAC and decrypt.
922         if(s->instate) {
923                 if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) {
924                         return error(s, EINVAL, "Failed to decrypt and verify record");
925                 }
926         }
927
928         // Append a NULL byte for safety.
929         s->inbuf[s->reclen + SPTPS_HEADER] = 0;
930
931         uint8_t type = s->inbuf[2];
932
933         if(type < SPTPS_HANDSHAKE) {
934                 if(!s->instate) {
935                         return error(s, EIO, "Application record received before handshake finished");
936                 }
937
938                 if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen)) {
939                         return false;
940                 }
941         } else if(type == SPTPS_HANDSHAKE) {
942                 if(!receive_handshake(s, s->inbuf + 3, s->reclen)) {
943                         return false;
944                 }
945         } else {
946                 return error(s, EIO, "Invalid record type %d", type);
947         }
948
949         s->buflen = 0;
950
951         return total_read;
952 }
953
954 // Start a SPTPS session.
955 bool sptps_start(sptps_t *s, const sptps_params_t *params) {
956         // Initialise struct sptps
957         memset(s, 0, sizeof(*s));
958
959         s->handle = params->handle;
960         s->initiator = params->initiator;
961         s->datagram = params->datagram;
962         s->mykey = params->mykey;
963         s->hiskey = params->hiskey;
964         s->replaywin = sptps_replaywin;
965         s->cipher_suites = params->cipher_suites ? params->cipher_suites & SPTPS_ALL_CIPHER_SUITES : SPTPS_ALL_CIPHER_SUITES;
966         s->preferred_suite = params->preferred_suite;
967
968         if(s->replaywin) {
969                 s->late = malloc(s->replaywin);
970
971                 if(!s->late) {
972                         return error(s, errno, "%s", strerror(errno));
973                 }
974
975                 memset(s->late, 0, s->replaywin);
976         }
977
978         s->labellen = params->labellen ? params->labellen : strlen(params->label);
979         s->label = malloc(s->labellen);
980
981         if(!s->label) {
982                 return error(s, errno, "%s", strerror(errno));
983         }
984
985         memcpy(s->label, params->label, s->labellen);
986
987         if(!s->datagram) {
988                 s->inbuf = malloc(7);
989
990                 if(!s->inbuf) {
991                         return error(s, errno, "%s", strerror(errno));
992                 }
993
994                 s->buflen = 0;
995         }
996
997
998         s->send_data = params->send_data;
999         s->receive_record = params->receive_record;
1000
1001         // Do first KEX immediately
1002         s->state = SPTPS_KEX;
1003         return send_kex(s);
1004 }
1005
1006 // Stop a SPTPS session.
1007 bool sptps_stop(sptps_t *s) {
1008         // Clean up any resources.
1009         cipher_exit(s->cipher_suite, s->incipher);
1010         cipher_exit(s->cipher_suite, s->outcipher);
1011         ecdh_free(s->ecdh);
1012         free(s->inbuf);
1013         free_sptps_kex(s->mykex);
1014         free_sptps_kex(s->hiskex);
1015         free_sptps_key(s->key);
1016         free(s->label);
1017         free(s->late);
1018         memset(s, 0, sizeof(*s));
1019         return true;
1020 }