de4ef6ee68f51e0d94ac976618e63fdcd8015919
[tinc] / src / sptps.c
1 /*
2     sptps.c -- Simple Peer-to-Peer Security
3     Copyright (C) 2011-2021 Guus Sliepen <guus@tinc-vpn.org>,
4                   2010      Brandon L. Black <blblack@gmail.com>
5
6     This program is free software; you can redistribute it and/or modify
7     it under the terms of the GNU General Public License as published by
8     the Free Software Foundation; either version 2 of the License, or
9     (at your option) any later version.
10
11     This program is distributed in the hope that it will be useful,
12     but WITHOUT ANY WARRANTY; without even the implied warranty of
13     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14     GNU General Public License for more details.
15
16     You should have received a copy of the GNU General Public License along
17     with this program; if not, write to the Free Software Foundation, Inc.,
18     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
19 */
20
21 #include "system.h"
22
23 #include "chacha-poly1305/chachapoly.h"
24 #include "ecdh.h"
25 #include "ecdsa.h"
26 #include "prf.h"
27 #include "sptps.h"
28 #include "random.h"
29 #include "xalloc.h"
30
31 #ifdef HAVE_OPENSSL
32 #include <openssl/evp.h>
33 #endif
34
35 #define CIPHER_KEYLEN 64
36
37 unsigned int sptps_replaywin = 16;
38
39 /*
40    Nonce MUST be exchanged first (done)
41    Signatures MUST be done over both nonces, to guarantee the signature is fresh
42    Otherwise: if ECDHE key of one side is compromised, it can be reused!
43
44    Add explicit tag to beginning of structure to distinguish the client and server when signing. (done)
45
46    Sign all handshake messages up to ECDHE kex with long-term public keys. (done)
47
48    HMACed KEX finished message to prevent downgrade attacks and prove you have the right key material (done by virtue of Ed25519 over the whole ECDHE exchange?)
49
50    Explicit close message needs to be added.
51
52    Maybe do add some alert messages to give helpful error messages? Not more than TLS sends.
53
54    Use counter mode instead of OFB. (done)
55
56    Make sure ECC operations are fixed time (aka prevent side-channel attacks).
57 */
58
59 void sptps_log_quiet(sptps_t *s, int s_errno, const char *format, va_list ap) {
60         (void)s;
61         (void)s_errno;
62         (void)format;
63         (void)ap;
64 }
65
66 void sptps_log_stderr(sptps_t *s, int s_errno, const char *format, va_list ap) {
67         (void)s;
68         (void)s_errno;
69
70         vfprintf(stderr, format, ap);
71         fputc('\n', stderr);
72 }
73
74 void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap) = sptps_log_stderr;
75
76 // Log an error message.
77 static bool error(sptps_t *s, int s_errno, const char *format, ...) ATTR_FORMAT(printf, 3, 4);
78 static bool error(sptps_t *s, int s_errno, const char *format, ...) {
79         (void)s;
80         (void)s_errno;
81
82         if(format) {
83                 va_list ap;
84                 va_start(ap, format);
85                 sptps_log(s, s_errno, format, ap);
86                 va_end(ap);
87         }
88
89         errno = s_errno;
90         return false;
91 }
92
93 static void warning(sptps_t *s, const char *format, ...) ATTR_FORMAT(printf, 2, 3);
94 static void warning(sptps_t *s, const char *format, ...) {
95         va_list ap;
96         va_start(ap, format);
97         sptps_log(s, 0, format, ap);
98         va_end(ap);
99 }
100
101 static sptps_kex_t *new_sptps_kex(void) {
102         return xzalloc(sizeof(sptps_kex_t));
103 }
104
105 static void free_sptps_kex(sptps_kex_t *kex) {
106         xzfree(kex, sizeof(sptps_kex_t));
107 }
108
109 static sptps_key_t *new_sptps_key(void) {
110         return xzalloc(sizeof(sptps_key_t));
111 }
112
113 static void free_sptps_key(sptps_key_t *key) {
114         xzfree(key, sizeof(sptps_key_t));
115 }
116
117 static bool cipher_init(uint8_t suite, void **ctx, const sptps_key_t *keys, bool key_half) {
118         const uint8_t *key = key_half ? keys->key1 : keys->key0;
119
120         switch(suite) {
121 #ifndef HAVE_OPENSSL
122
123         case SPTPS_CHACHA_POLY1305:
124                 *ctx = malloc(sizeof(struct chachapoly_ctx));
125                 return *ctx && chachapoly_init(*ctx, key, 256) == CHACHAPOLY_OK;
126
127 #else
128
129         case SPTPS_CHACHA_POLY1305:
130                 *ctx = EVP_CIPHER_CTX_new();
131
132                 if(!ctx) {
133                         return false;
134                 }
135
136                 return EVP_EncryptInit_ex(*ctx, EVP_chacha20_poly1305(), NULL, NULL, NULL)
137                        && EVP_CIPHER_CTX_ctrl(*ctx, EVP_CTRL_AEAD_SET_IVLEN, 12, NULL)
138                        && EVP_EncryptInit_ex(*ctx, NULL, NULL, key, key + 32);
139
140         case SPTPS_AES256_GCM:
141                 *ctx = EVP_CIPHER_CTX_new();
142
143                 if(!ctx) {
144                         return false;
145                 }
146
147                 return EVP_EncryptInit_ex(*ctx, EVP_aes_256_gcm(), NULL, NULL, NULL)
148                        && EVP_CIPHER_CTX_ctrl(*ctx, EVP_CTRL_AEAD_SET_IVLEN, 12, NULL)
149                        && EVP_EncryptInit_ex(*ctx, NULL, NULL, key, key + 32);
150 #endif
151
152         default:
153                 return false;
154         }
155 }
156
157 static void cipher_exit(uint8_t suite, void *ctx) {
158         switch(suite) {
159 #ifndef HAVE_OPENSSL
160
161         case SPTPS_CHACHA_POLY1305:
162                 free(ctx);
163                 break;
164
165 #else
166
167         case SPTPS_CHACHA_POLY1305:
168         case SPTPS_AES256_GCM:
169                 EVP_CIPHER_CTX_free(ctx);
170                 break;
171 #endif
172
173         default:
174                 break;
175         }
176 }
177
178 static bool cipher_encrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) {
179         switch(suite) {
180 #ifndef HAVE_OPENSSL
181
182         case SPTPS_CHACHA_POLY1305: {
183                 if(chachapoly_crypt(ctx, nonce, NULL, 0, (void *)in, inlen, out, out + inlen, 16, 1) != CHACHAPOLY_OK) {
184                         return false;
185                 }
186
187                 if(outlen) {
188                         *outlen = inlen + 16;
189                 }
190
191                 return true;
192         }
193
194 #else
195
196         case SPTPS_CHACHA_POLY1305:
197         case SPTPS_AES256_GCM: {
198                 uint8_t nonce[12] = {seqno, seqno >> 8, seqno >> 16, seqno >> 24};
199
200                 if(!EVP_EncryptInit_ex(ctx, NULL, NULL, NULL, nonce)) {
201                         return false;
202                 }
203
204                 int outlen1 = 0, outlen2 = 0;
205
206                 if(!EVP_EncryptUpdate(ctx, out, &outlen1, in, (int)inlen)) {
207                         return false;
208                 }
209
210                 if(!EVP_EncryptFinal_ex(ctx, out + outlen1, &outlen2)) {
211                         return false;
212                 }
213
214                 outlen1 += outlen2;
215
216                 if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, 16, out + outlen1)) {
217                         return false;
218                 }
219
220                 outlen1 += 16;
221
222                 if(outlen) {
223                         *outlen = outlen1;
224                 }
225
226                 return true;
227         }
228
229 #endif
230
231         default:
232                 return false;
233         }
234 }
235
236 static bool cipher_decrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) {
237         if(inlen < 16) {
238                 return false;
239         }
240
241         inlen -= 16;
242
243         switch(suite) {
244 #ifndef HAVE_OPENSSL
245
246         case SPTPS_CHACHA_POLY1305:
247                 if(chachapoly_crypt(ctx, nonce, NULL, 0, (void *)in, inlen, out, (void *)(in + inlen), 16, 0) != CHACHAPOLY_OK) {
248                         return false;
249                 }
250
251                 if(outlen) {
252                         *outlen = inlen;
253                 }
254
255                 return true;
256
257 #else
258
259         case SPTPS_CHACHA_POLY1305:
260         case SPTPS_AES256_GCM: {
261                 uint8_t nonce[12] = {seqno, seqno >> 8, seqno >> 16, seqno >> 24};
262
263                 if(!EVP_DecryptInit_ex(ctx, NULL, NULL, NULL, nonce)) {
264                         return false;
265                 }
266
267                 int outlen1 = 0, outlen2 = 0;
268
269                 if(!EVP_DecryptUpdate(ctx, out, &outlen1, in, (int)inlen)) {
270                         return false;
271                 }
272
273                 if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, 16, (void *)(in + inlen))) {
274                         return false;
275                 }
276
277                 if(!EVP_DecryptFinal_ex(ctx, out + outlen1, &outlen2)) {
278                         return false;
279                 }
280
281                 if(outlen) {
282                         *outlen = outlen1 + outlen2;
283                 }
284
285                 return true;
286         }
287
288 #endif
289
290         default:
291                 return false;
292         }
293 }
294
295 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
296 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
297         uint8_t *buffer = alloca(len + SPTPS_DATAGRAM_OVERHEAD);
298         // Create header with sequence number, length and record type
299         uint32_t seqno = s->outseqno++;
300
301         memcpy(buffer, &seqno, 4);
302         buffer[4] = type;
303         memcpy(buffer + 5, data, len);
304
305         if(s->outstate) {
306                 // If first handshake has finished, encrypt and HMAC
307                 if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 4, len + 1, buffer + 4, NULL)) {
308                         return error(s, EINVAL, "Failed to encrypt message");
309                 }
310
311                 return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_OVERHEAD);
312         } else {
313                 // Otherwise send as plaintext
314                 return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_HEADER);
315         }
316 }
317 // Send a record (private version, accepts all record types, handles encryption and authentication).
318 static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
319         if(s->datagram) {
320                 return send_record_priv_datagram(s, type, data, len);
321         }
322
323         uint8_t *buffer = alloca(len + SPTPS_OVERHEAD);
324
325         // Create header with sequence number, length and record type
326         uint32_t seqno = s->outseqno++;
327         uint16_t netlen = len;
328
329         memcpy(buffer, &netlen, 2);
330         buffer[2] = type;
331         memcpy(buffer + 3, data, len);
332
333         if(s->outstate) {
334                 // If first handshake has finished, encrypt and HMAC
335                 if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 2, len + 1, buffer + 2, NULL)) {
336                         return error(s, EINVAL, "Failed to encrypt message");
337                 }
338
339                 return s->send_data(s->handle, type, buffer, len + SPTPS_OVERHEAD);
340         } else {
341                 // Otherwise send as plaintext
342                 return s->send_data(s->handle, type, buffer, len + SPTPS_HEADER);
343         }
344 }
345
346 // Send an application record.
347 bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
348         // Sanity checks: application cannot send data before handshake is finished,
349         // and only record types 0..127 are allowed.
350         if(!s->outstate) {
351                 return error(s, EINVAL, "Handshake phase not finished yet");
352         }
353
354         if(type >= SPTPS_HANDSHAKE) {
355                 return error(s, EINVAL, "Invalid application record type");
356         }
357
358         return send_record_priv(s, type, data, len);
359 }
360
361 // Send a Key EXchange record, containing a random nonce and an ECDHE public key.
362 static bool send_kex(sptps_t *s) {
363         // Make room for our KEX message, which we will keep around since send_sig() needs it.
364         if(s->mykex) {
365                 return false;
366         }
367
368         s->mykex = new_sptps_kex();
369
370         // Set version byte to zero.
371         s->mykex->version = SPTPS_VERSION;
372         s->mykex->preferred_suite = s->preferred_suite;
373         s->mykex->cipher_suites = s->cipher_suites;
374
375         // Create a random nonce.
376         randomize(s->mykex->nonce, ECDH_SIZE);
377
378         // Create a new ECDH public key.
379         if(!(s->ecdh = ecdh_generate_public(s->mykex->pubkey))) {
380                 return error(s, EINVAL, "Failed to generate ECDH public key");
381         }
382
383         return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, sizeof(sptps_kex_t));
384 }
385
386 static size_t sigmsg_len(size_t labellen) {
387         return 1 + 2 * sizeof(sptps_kex_t) + labellen;
388 }
389
390 static void fill_msg(uint8_t *msg, bool initiator, const sptps_kex_t *kex0, const sptps_kex_t *kex1, const sptps_t *s) {
391         *msg = initiator, msg++;
392         memcpy(msg, kex0, sizeof(*kex0)), msg += sizeof(*kex0);
393         memcpy(msg, kex1, sizeof(*kex1)), msg += sizeof(*kex1);
394         memcpy(msg, s->label, s->labellen);
395 }
396
397 // Send a SIGnature record, containing an Ed25519 signature over both KEX records.
398 static bool send_sig(sptps_t *s) {
399         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator, plus label
400         size_t msglen = sigmsg_len(s->labellen);
401         uint8_t *msg = alloca(msglen);
402         fill_msg(msg, s->initiator, s->mykex, s->hiskex, s);
403
404         // Sign the result.
405         size_t siglen = ecdsa_size(s->mykey);
406         uint8_t *sig = alloca(siglen);
407
408         if(!ecdsa_sign(s->mykey, msg, msglen, sig)) {
409                 return error(s, EINVAL, "Failed to sign SIG record");
410         }
411
412         // Send the SIG exchange record.
413         return send_record_priv(s, SPTPS_HANDSHAKE, sig, siglen);
414 }
415
416 // Generate key material from the shared secret created from the ECDHE key exchange.
417 static bool generate_key_material(sptps_t *s, const uint8_t *shared, size_t len) {
418         // Allocate memory for key material
419         s->key = new_sptps_key();
420
421         // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce
422         const size_t msglen = sizeof("key expansion") - 1;
423         const size_t seedlen = msglen + s->labellen + ECDH_SIZE * 2;
424         uint8_t *seed = alloca(seedlen);
425
426         uint8_t *ptr = seed;
427         memcpy(ptr, "key expansion", msglen);
428         ptr += msglen;
429
430         memcpy(ptr, (s->initiator ? s->mykex : s->hiskex)->nonce, ECDH_SIZE);
431         ptr += ECDH_SIZE;
432
433         memcpy(ptr, (s->initiator ? s->hiskex : s->mykex)->nonce, ECDH_SIZE);
434         ptr += ECDH_SIZE;
435
436         memcpy(ptr, s->label, s->labellen);
437
438         // Use PRF to generate the key material
439         if(!prf(shared, len, seed, seedlen, s->key->both, sizeof(sptps_key_t))) {
440                 return error(s, EINVAL, "Failed to generate key material");
441         }
442
443         return true;
444 }
445
446 // Send an ACKnowledgement record.
447 static bool send_ack(sptps_t *s) {
448         return send_record_priv(s, SPTPS_HANDSHAKE, "", 0);
449 }
450
451 // Receive an ACKnowledgement record.
452 static bool receive_ack(sptps_t *s, const uint8_t *data, uint16_t len) {
453         (void)data;
454
455         if(len) {
456                 return error(s, EIO, "Invalid ACK record length");
457         }
458
459         if(!cipher_init(s->cipher_suite, &s->incipher, s->key, s->initiator)) {
460                 return error(s, EINVAL, "Failed to initialize cipher");
461         }
462
463         free_sptps_key(s->key);
464         s->key = NULL;
465         s->instate = true;
466
467         return true;
468 }
469
470 static uint8_t select_cipher_suite(uint16_t mask, uint8_t pref1, uint8_t pref2) {
471         // Check if there is a viable preference, if so select the lowest one
472         uint8_t selection = 255;
473
474         if(mask & (1U << pref1)) {
475                 selection = pref1;
476         }
477
478         if(pref2 < selection && (mask & (1U << pref2))) {
479                 selection = pref2;
480         }
481
482         // Otherwise, select the lowest cipher suite both sides support
483         if(selection == 255) {
484                 selection = 0;
485
486                 while(!(mask & 1U)) {
487                         selection++;
488                         mask >>= 1;
489                 }
490         }
491
492         return selection;
493 }
494
495 // Receive a Key EXchange record, respond by sending a SIG record.
496 static bool receive_kex(sptps_t *s, const uint8_t *data, uint16_t len) {
497         // Verify length of the HELLO record
498
499         if(len != sizeof(sptps_kex_t)) {
500                 return error(s, EIO, "Invalid KEX record length");
501         }
502
503         if(*data != SPTPS_VERSION) {
504                 return error(s, EINVAL, "Received incorrect version %d", *data);
505         }
506
507         uint16_t suites;
508         memcpy(&suites, data + 2, 2);
509         suites &= s->cipher_suites;
510
511         if(!suites) {
512                 return error(s, EIO, "No matching cipher suites");
513         }
514
515         s->cipher_suite = select_cipher_suite(suites, s->preferred_suite, data[1] & 0xf);
516
517         // Make a copy of the KEX message, send_sig() and receive_sig() need it
518         if(s->hiskex) {
519                 return error(s, EINVAL, "Received a second KEX message before first has been processed");
520         }
521
522         s->hiskex = new_sptps_kex();
523         memcpy(s->hiskex, data, sizeof(sptps_kex_t));
524
525         if(s->initiator) {
526                 return send_sig(s);
527         } else {
528                 return true;
529         }
530 }
531
532 // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys.
533 static bool receive_sig(sptps_t *s, const uint8_t *data, uint16_t len) {
534         // Verify length of KEX record.
535         if(len != ecdsa_size(s->hiskey)) {
536                 return error(s, EIO, "Invalid KEX record length");
537         }
538
539         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
540         const size_t msglen = sigmsg_len(s->labellen);
541         uint8_t *msg = alloca(msglen);
542         fill_msg(msg, !s->initiator, s->hiskex, s->mykex, s);
543
544         // Verify signature.
545         if(!ecdsa_verify(s->hiskey, msg, msglen, data)) {
546                 return error(s, EIO, "Failed to verify SIG record");
547         }
548
549         // Compute shared secret.
550         uint8_t shared[ECDH_SHARED_SIZE];
551
552         if(!ecdh_compute_shared(s->ecdh, s->hiskex->pubkey, shared)) {
553                 memzero(shared, sizeof(shared));
554                 return error(s, EINVAL, "Failed to compute ECDH shared secret");
555         }
556
557         s->ecdh = NULL;
558
559         // Generate key material from shared secret.
560         bool generated = generate_key_material(s, shared, sizeof(shared));
561         memzero(shared, sizeof(shared));
562
563         if(!generated) {
564                 return false;
565         }
566
567         if(!s->initiator && !send_sig(s)) {
568                 return false;
569         }
570
571         free_sptps_kex(s->mykex);
572         s->mykex = NULL;
573
574         free_sptps_kex(s->hiskex);
575         s->hiskex = NULL;
576
577         // Send cipher change record
578         if(s->outstate && !send_ack(s)) {
579                 return false;
580         }
581
582         if(!cipher_init(s->cipher_suite, &s->outcipher, s->key, !s->initiator)) {
583                 return error(s, EINVAL, "Failed to initialize cipher");
584         }
585
586         return true;
587 }
588
589 // Force another Key EXchange (for testing purposes).
590 bool sptps_force_kex(sptps_t *s) {
591         if(!s->outstate || s->state != SPTPS_SECONDARY_KEX) {
592                 return error(s, EINVAL, "Cannot force KEX in current state");
593         }
594
595         s->state = SPTPS_KEX;
596         return send_kex(s);
597 }
598
599 // Receive a handshake record.
600 static bool receive_handshake(sptps_t *s, const uint8_t *data, uint16_t len) {
601         // Only a few states to deal with handshaking.
602         switch(s->state) {
603         case SPTPS_SECONDARY_KEX:
604
605                 // We receive a secondary KEX request, first respond by sending our own.
606                 if(!send_kex(s)) {
607                         return false;
608                 }
609
610         // Fall through
611         case SPTPS_KEX:
612
613                 // We have sent our KEX request, we expect our peer to sent one as well.
614                 if(!receive_kex(s, data, len)) {
615                         return false;
616                 }
617
618                 s->state = SPTPS_SIG;
619                 return true;
620
621         case SPTPS_SIG:
622
623                 // If we already sent our secondary public ECDH key, we expect the peer to send his.
624                 if(!receive_sig(s, data, len)) {
625                         return false;
626                 }
627
628                 if(s->outstate) {
629                         s->state = SPTPS_ACK;
630                 } else {
631                         s->outstate = true;
632
633                         if(!receive_ack(s, NULL, 0)) {
634                                 return false;
635                         }
636
637                         s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
638                         s->state = SPTPS_SECONDARY_KEX;
639                 }
640
641                 return true;
642
643         case SPTPS_ACK:
644
645                 // We expect a handshake message to indicate transition to the new keys.
646                 if(!receive_ack(s, data, len)) {
647                         return false;
648                 }
649
650                 s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
651                 s->state = SPTPS_SECONDARY_KEX;
652                 return true;
653
654         // TODO: split ACK into a VERify and ACK?
655         default:
656                 return error(s, EIO, "Invalid session state %d", s->state);
657         }
658 }
659
660 static bool sptps_check_seqno(sptps_t *s, uint32_t seqno, bool update_state) {
661         // Replay protection using a sliding window of configurable size.
662         // s->inseqno is expected sequence number
663         // seqno is received sequence number
664         // s->late[] is a circular buffer, a 1 bit means a packet has not been received yet
665         // The circular buffer contains bits for sequence numbers from s->inseqno - s->replaywin * 8 to (but excluding) s->inseqno.
666         if(s->replaywin) {
667                 if(seqno != s->inseqno) {
668                         if(seqno >= s->inseqno + s->replaywin * 8) {
669                                 // Prevent packets that jump far ahead of the queue from causing many others to be dropped.
670                                 bool farfuture = s->farfuture < s->replaywin >> 2;
671
672                                 if(update_state) {
673                                         s->farfuture++;
674                                 }
675
676                                 if(farfuture) {
677                                         return update_state ? error(s, EIO, "Packet is %d seqs in the future, dropped (%u)\n", seqno - s->inseqno, s->farfuture) : false;
678                                 }
679
680                                 // Unless we have seen lots of them, in which case we consider the others lost.
681                                 if(update_state) {
682                                         warning(s, "Lost %d packets\n", seqno - s->inseqno);
683                                 }
684
685                                 if(update_state) {
686                                         // Mark all packets in the replay window as being late.
687                                         memset(s->late, 255, s->replaywin);
688                                 }
689                         } else if(seqno < s->inseqno) {
690                                 // If the sequence number is farther in the past than the bitmap goes, or if the packet was already received, drop it.
691                                 if((s->inseqno >= s->replaywin * 8 && seqno < s->inseqno - s->replaywin * 8) || !(s->late[(seqno / 8) % s->replaywin] & (1 << seqno % 8))) {
692                                         return update_state ? error(s, EIO, "Received late or replayed packet, seqno %d, last received %d\n", seqno, s->inseqno) : false;
693                                 }
694                         } else if(update_state) {
695                                 // We missed some packets. Mark them in the bitmap as being late.
696                                 for(uint32_t i = s->inseqno; i < seqno; i++) {
697                                         s->late[(i / 8) % s->replaywin] |= 1 << i % 8;
698                                 }
699                         }
700                 }
701
702                 if(update_state) {
703                         // Mark the current packet as not being late.
704                         s->late[(seqno / 8) % s->replaywin] &= ~(1 << seqno % 8);
705                         s->farfuture = 0;
706                 }
707         }
708
709         if(update_state) {
710                 if(seqno >= s->inseqno) {
711                         s->inseqno = seqno + 1;
712                 }
713
714                 if(!s->inseqno) {
715                         s->received = 0;
716                 } else {
717                         s->received++;
718                 }
719         }
720
721         return true;
722 }
723
724 // Check datagram for valid HMAC
725 bool sptps_verify_datagram(sptps_t *s, const void *vdata, size_t len) {
726         if(!s->instate || len < 21) {
727                 return error(s, EIO, "Received short packet");
728         }
729
730         const uint8_t *data = vdata;
731         uint32_t seqno;
732         memcpy(&seqno, data, 4);
733
734         if(!sptps_check_seqno(s, seqno, false)) {
735                 return false;
736         }
737
738         uint8_t *buffer = alloca(len);
739         return cipher_decrypt(s->cipher_suite, s->incipher, seqno, data + 4, len - 4, buffer, NULL);
740 }
741
742 // Receive incoming data, datagram version.
743 static bool sptps_receive_data_datagram(sptps_t *s, const uint8_t *data, size_t len) {
744         if(len < (s->instate ? 21 : 5)) {
745                 return error(s, EIO, "Received short packet");
746         }
747
748         uint32_t seqno;
749         memcpy(&seqno, data, 4);
750         data += 4;
751         len -= 4;
752
753         if(!s->instate) {
754                 if(seqno != s->inseqno) {
755                         return error(s, EIO, "Invalid packet seqno: %d != %d", seqno, s->inseqno);
756                 }
757
758                 s->inseqno = seqno + 1;
759
760                 uint8_t type = *(data++);
761                 len--;
762
763                 if(type != SPTPS_HANDSHAKE) {
764                         return error(s, EIO, "Application record received before handshake finished");
765                 }
766
767                 return receive_handshake(s, data, len);
768         }
769
770         // Decrypt
771
772         uint8_t *buffer = alloca(len);
773         size_t outlen;
774
775         if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, data, len, buffer, &outlen)) {
776                 return error(s, EIO, "Failed to decrypt and verify packet");
777         }
778
779         if(!sptps_check_seqno(s, seqno, true)) {
780                 return false;
781         }
782
783         // Append a NULL byte for safety.
784         buffer[outlen] = 0;
785
786         data = buffer;
787         len = outlen;
788
789         uint8_t type = *(data++);
790         len--;
791
792         if(type < SPTPS_HANDSHAKE) {
793                 if(!s->instate) {
794                         return error(s, EIO, "Application record received before handshake finished");
795                 }
796
797                 if(!s->receive_record(s->handle, type, data, len)) {
798                         return false;
799                 }
800         } else if(type == SPTPS_HANDSHAKE) {
801                 if(!receive_handshake(s, data, len)) {
802                         return false;
803                 }
804         } else {
805                 return error(s, EIO, "Invalid record type %d", type);
806         }
807
808         return true;
809 }
810
811 // Receive incoming data. Check if it contains a complete record, if so, handle it.
812 size_t sptps_receive_data(sptps_t *s, const void *vdata, size_t len) {
813         const uint8_t *data = vdata;
814         size_t total_read = 0;
815
816         if(!s->state) {
817                 return error(s, EIO, "Invalid session state zero");
818         }
819
820         if(s->datagram) {
821                 return sptps_receive_data_datagram(s, data, len) ? len : false;
822         }
823
824         // First read the 2 length bytes.
825         if(s->buflen < 2) {
826                 size_t toread = 2 - s->buflen;
827
828                 if(toread > len) {
829                         toread = len;
830                 }
831
832                 memcpy(s->inbuf + s->buflen, data, toread);
833
834                 total_read += toread;
835                 s->buflen += toread;
836                 len -= toread;
837                 data += toread;
838
839                 // Exit early if we don't have the full length.
840                 if(s->buflen < 2) {
841                         return total_read;
842                 }
843
844                 // Get the length bytes
845
846                 memcpy(&s->reclen, s->inbuf, 2);
847
848                 // If we have the length bytes, ensure our buffer can hold the whole request.
849                 s->inbuf = realloc(s->inbuf, s->reclen + SPTPS_OVERHEAD);
850
851                 if(!s->inbuf) {
852                         return error(s, errno, "%s", strerror(errno));
853                 }
854
855                 // Exit early if we have no more data to process.
856                 if(!len) {
857                         return total_read;
858                 }
859         }
860
861         // Read up to the end of the record.
862         size_t toread = s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER) - s->buflen;
863
864         if(toread > len) {
865                 toread = len;
866         }
867
868         memcpy(s->inbuf + s->buflen, data, toread);
869         total_read += toread;
870         s->buflen += toread;
871
872         // If we don't have a whole record, exit.
873         if(s->buflen < s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER)) {
874                 return total_read;
875         }
876
877         // Update sequence number.
878
879         uint32_t seqno = s->inseqno++;
880
881         // Check HMAC and decrypt.
882         if(s->instate) {
883                 if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) {
884                         return error(s, EINVAL, "Failed to decrypt and verify record");
885                 }
886         }
887
888         // Append a NULL byte for safety.
889         s->inbuf[s->reclen + SPTPS_HEADER] = 0;
890
891         uint8_t type = s->inbuf[2];
892
893         if(type < SPTPS_HANDSHAKE) {
894                 if(!s->instate) {
895                         return error(s, EIO, "Application record received before handshake finished");
896                 }
897
898                 if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen)) {
899                         return false;
900                 }
901         } else if(type == SPTPS_HANDSHAKE) {
902                 if(!receive_handshake(s, s->inbuf + 3, s->reclen)) {
903                         return false;
904                 }
905         } else {
906                 return error(s, EIO, "Invalid record type %d", type);
907         }
908
909         s->buflen = 0;
910
911         return total_read;
912 }
913
914 // Start a SPTPS session.
915 bool sptps_start(sptps_t *s, const sptps_params_t *params) {
916         // Initialise struct sptps
917         memset(s, 0, sizeof(*s));
918
919         s->handle = params->handle;
920         s->initiator = params->initiator;
921         s->datagram = params->datagram;
922         s->mykey = params->mykey;
923         s->hiskey = params->hiskey;
924         s->replaywin = sptps_replaywin;
925         s->cipher_suites = params->cipher_suites ? params->cipher_suites & SPTPS_ALL_CIPHER_SUITES : SPTPS_ALL_CIPHER_SUITES;
926         s->preferred_suite = params->preferred_suite;
927
928         if(s->replaywin) {
929                 s->late = malloc(s->replaywin);
930
931                 if(!s->late) {
932                         return error(s, errno, "%s", strerror(errno));
933                 }
934
935                 memset(s->late, 0, s->replaywin);
936         }
937
938         s->labellen = params->labellen ? params->labellen : strlen(params->label);
939         s->label = malloc(s->labellen);
940
941         if(!s->label) {
942                 return error(s, errno, "%s", strerror(errno));
943         }
944
945         memcpy(s->label, params->label, s->labellen);
946
947         if(!s->datagram) {
948                 s->inbuf = malloc(7);
949
950                 if(!s->inbuf) {
951                         return error(s, errno, "%s", strerror(errno));
952                 }
953
954                 s->buflen = 0;
955         }
956
957
958         s->send_data = params->send_data;
959         s->receive_record = params->receive_record;
960
961         // Do first KEX immediately
962         s->state = SPTPS_KEX;
963         return send_kex(s);
964 }
965
966 // Stop a SPTPS session.
967 bool sptps_stop(sptps_t *s) {
968         // Clean up any resources.
969         cipher_exit(s->cipher_suite, s->incipher);
970         cipher_exit(s->cipher_suite, s->outcipher);
971         ecdh_free(s->ecdh);
972         free(s->inbuf);
973         free_sptps_kex(s->mykex);
974         free_sptps_kex(s->hiskex);
975         free_sptps_key(s->key);
976         free(s->label);
977         free(s->late);
978         memset(s, 0, sizeof(*s));
979         return true;
980 }