Add AES-256-GCM support to SPTPS.
[tinc] / src / sptps.c
1 /*
2     sptps.c -- Simple Peer-to-Peer Security
3     Copyright (C) 2011-2015 Guus Sliepen <guus@tinc-vpn.org>,
4                   2010      Brandon L. Black <blblack@gmail.com>
5
6     This program is free software; you can redistribute it and/or modify
7     it under the terms of the GNU General Public License as published by
8     the Free Software Foundation; either version 2 of the License, or
9     (at your option) any later version.
10
11     This program is distributed in the hope that it will be useful,
12     but WITHOUT ANY WARRANTY; without even the implied warranty of
13     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14     GNU General Public License for more details.
15
16     You should have received a copy of the GNU General Public License along
17     with this program; if not, write to the Free Software Foundation, Inc.,
18     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
19 */
20
21 #include "system.h"
22
23 #include "chacha-poly1305/chacha-poly1305.h"
24 #include "ecdh.h"
25 #include "ecdsa.h"
26 #include "prf.h"
27 #include "sptps.h"
28 #include "random.h"
29 #include "xalloc.h"
30
31 #ifdef HAVE_OPENSSL
32 #include <openssl/evp.h>
33 #endif
34
35 unsigned int sptps_replaywin = 16;
36
37 /*
38    Nonce MUST be exchanged first (done)
39    Signatures MUST be done over both nonces, to guarantee the signature is fresh
40    Otherwise: if ECDHE key of one side is compromised, it can be reused!
41
42    Add explicit tag to beginning of structure to distinguish the client and server when signing. (done)
43
44    Sign all handshake messages up to ECDHE kex with long-term public keys. (done)
45
46    HMACed KEX finished message to prevent downgrade attacks and prove you have the right key material (done by virtue of Ed25519 over the whole ECDHE exchange?)
47
48    Explicit close message needs to be added.
49
50    Maybe do add some alert messages to give helpful error messages? Not more than TLS sends.
51
52    Use counter mode instead of OFB. (done)
53
54    Make sure ECC operations are fixed time (aka prevent side-channel attacks).
55 */
56
57 void sptps_log_quiet(sptps_t *s, int s_errno, const char *format, va_list ap) {
58         (void)s;
59         (void)s_errno;
60         (void)format;
61         (void)ap;
62 }
63
64 void sptps_log_stderr(sptps_t *s, int s_errno, const char *format, va_list ap) {
65         (void)s;
66         (void)s_errno;
67
68         vfprintf(stderr, format, ap);
69         fputc('\n', stderr);
70 }
71
72 void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap) = sptps_log_stderr;
73
74 // Log an error message.
75 static bool error(sptps_t *s, int s_errno, const char *format, ...) ATTR_FORMAT(printf, 3, 4);
76 static bool error(sptps_t *s, int s_errno, const char *format, ...) {
77         (void)s;
78         (void)s_errno;
79
80         if(format) {
81                 va_list ap;
82                 va_start(ap, format);
83                 sptps_log(s, s_errno, format, ap);
84                 va_end(ap);
85         }
86
87         errno = s_errno;
88         return false;
89 }
90
91 static void warning(sptps_t *s, const char *format, ...) ATTR_FORMAT(printf, 2, 3);
92 static void warning(sptps_t *s, const char *format, ...) {
93         va_list ap;
94         va_start(ap, format);
95         sptps_log(s, 0, format, ap);
96         va_end(ap);
97 }
98
99 static sptps_kex_t *new_sptps_kex(void) {
100         return xzalloc(sizeof(sptps_kex_t));
101 }
102
103 static void free_sptps_kex(sptps_kex_t *kex) {
104         xzfree(kex, sizeof(sptps_kex_t));
105 }
106
107 static sptps_key_t *new_sptps_key(void) {
108         return xzalloc(sizeof(sptps_key_t));
109 }
110
111 static void free_sptps_key(sptps_key_t *key) {
112         xzfree(key, sizeof(sptps_key_t));
113 }
114
115 static bool cipher_init(uint8_t suite, void **ctx, const sptps_key_t *keys, bool key_half) {
116         const uint8_t *key = key_half ? keys->key1 : keys->key0;
117
118         switch(suite) {
119         case SPTPS_CHACHA_POLY1305:
120                 *ctx = chacha_poly1305_init();
121                 return ctx && chacha_poly1305_set_key(*ctx, key);
122
123         case SPTPS_AES256_GCM:
124 #ifdef HAVE_OPENSSL
125                 *ctx = EVP_CIPHER_CTX_new();
126
127                 if(!ctx) {
128                         return false;
129                 }
130
131                 return EVP_EncryptInit_ex(*ctx, EVP_aes_256_gcm(), NULL, NULL, NULL)
132                        && EVP_CIPHER_CTX_ctrl(*ctx, EVP_CTRL_AEAD_SET_IVLEN, 4, NULL)
133                        && EVP_EncryptInit_ex(*ctx, NULL, NULL, key, key + 32);
134 #endif
135
136         default:
137                 return false;
138         }
139 }
140
141 static void cipher_exit(uint8_t suite, void *ctx) {
142         switch(suite) {
143         case SPTPS_CHACHA_POLY1305:
144                 chacha_poly1305_exit(ctx);
145                 break;
146
147         case SPTPS_AES256_GCM:
148 #ifdef HAVE_OPENSSL
149                 EVP_CIPHER_CTX_free(ctx);
150                 break;
151 #endif
152
153         default:
154                 break;
155         }
156 }
157
158 static bool cipher_encrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) {
159         switch(suite) {
160         case SPTPS_CHACHA_POLY1305:
161                 chacha_poly1305_encrypt(ctx, seqno, in, inlen, out, outlen);
162                 return true;
163
164         case SPTPS_AES256_GCM:
165 #ifdef HAVE_OPENSSL
166                 {
167                         if(!EVP_EncryptInit_ex(ctx, NULL, NULL, NULL, (uint8_t *)&seqno)) {
168                                 return false;
169                         }
170
171                         int outlen1 = 0, outlen2 = 0;
172
173                         if(!EVP_EncryptUpdate(ctx, out, &outlen1, in, (int)inlen)) {
174                                 return false;
175                         }
176
177                         if(!EVP_EncryptFinal_ex(ctx, out + outlen1, &outlen2)) {
178                                 return false;
179                         }
180
181                         outlen1 += outlen2;
182
183                         if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, 16, out + outlen1)) {
184                                 return false;
185                         }
186
187                         outlen1 += 16;
188
189                         if(outlen) {
190                                 *outlen = outlen1;
191                         }
192
193                         return true;
194                 }
195
196 #endif
197
198         default:
199                 return false;
200         }
201 }
202
203 static bool cipher_decrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) {
204         switch(suite) {
205         case SPTPS_CHACHA_POLY1305:
206                 return chacha_poly1305_decrypt(ctx, seqno, in, inlen, out, outlen);
207
208         case SPTPS_AES256_GCM:
209 #ifdef HAVE_OPENSSL
210                 {
211                         if(inlen < 16) {
212                                 return false;
213                         }
214
215                         inlen -= 16;
216
217                         if(!EVP_DecryptInit_ex(ctx, NULL, NULL, NULL, (uint8_t *)&seqno)) {
218                                 return false;
219                         }
220
221                         int outlen1 = 0, outlen2 = 0;
222
223                         if(!EVP_DecryptUpdate(ctx, out, &outlen1, in, (int)inlen)) {
224                                 return false;
225                         }
226
227                         if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, 16, (void *)(in + inlen))) {
228                                 return false;
229                         }
230
231                         if(!EVP_DecryptFinal_ex(ctx, out + outlen1, &outlen2)) {
232                                 return false;
233                         }
234
235                         if(outlen) {
236                                 *outlen = outlen1 + outlen2;
237                         }
238
239                         return true;
240                 }
241
242 #endif
243
244         default:
245                 return false;
246         }
247 }
248
249 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
250 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
251         uint8_t *buffer = alloca(len + SPTPS_DATAGRAM_OVERHEAD);
252         // Create header with sequence number, length and record type
253         uint32_t seqno = s->outseqno++;
254
255         memcpy(buffer, &seqno, 4);
256         buffer[4] = type;
257         memcpy(buffer + 5, data, len);
258
259         if(s->outstate) {
260                 // If first handshake has finished, encrypt and HMAC
261                 if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 4, len + 1, buffer + 4, NULL)) {
262                         return error(s, EINVAL, "Failed to encrypt message");
263                 }
264
265                 return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_OVERHEAD);
266         } else {
267                 // Otherwise send as plaintext
268                 return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_HEADER);
269         }
270 }
271 // Send a record (private version, accepts all record types, handles encryption and authentication).
272 static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
273         if(s->datagram) {
274                 return send_record_priv_datagram(s, type, data, len);
275         }
276
277         uint8_t *buffer = alloca(len + SPTPS_OVERHEAD);
278
279         // Create header with sequence number, length and record type
280         uint32_t seqno = s->outseqno++;
281         uint16_t netlen = len;
282
283         memcpy(buffer, &netlen, 2);
284         buffer[2] = type;
285         memcpy(buffer + 3, data, len);
286
287         if(s->outstate) {
288                 // If first handshake has finished, encrypt and HMAC
289                 if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 2, len + 1, buffer + 2, NULL)) {
290                         return error(s, EINVAL, "Failed to encrypt message");
291                 }
292
293                 return s->send_data(s->handle, type, buffer, len + SPTPS_OVERHEAD);
294         } else {
295                 // Otherwise send as plaintext
296                 return s->send_data(s->handle, type, buffer, len + SPTPS_HEADER);
297         }
298 }
299
300 // Send an application record.
301 bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
302         // Sanity checks: application cannot send data before handshake is finished,
303         // and only record types 0..127 are allowed.
304         if(!s->outstate) {
305                 return error(s, EINVAL, "Handshake phase not finished yet");
306         }
307
308         if(type >= SPTPS_HANDSHAKE) {
309                 return error(s, EINVAL, "Invalid application record type");
310         }
311
312         return send_record_priv(s, type, data, len);
313 }
314
315 // Send a Key EXchange record, containing a random nonce and an ECDHE public key.
316 static bool send_kex(sptps_t *s) {
317         // Make room for our KEX message, which we will keep around since send_sig() needs it.
318         if(s->mykex) {
319                 return false;
320         }
321
322         s->mykex = new_sptps_kex();
323
324         // Set version byte to zero.
325         s->mykex->version = SPTPS_VERSION;
326         s->mykex->preferred_suite = s->preferred_suite;
327         s->mykex->cipher_suites = s->cipher_suites;
328
329         // Create a random nonce.
330         randomize(s->mykex->nonce, ECDH_SIZE);
331
332         // Create a new ECDH public key.
333         if(!(s->ecdh = ecdh_generate_public(s->mykex->pubkey))) {
334                 return error(s, EINVAL, "Failed to generate ECDH public key");
335         }
336
337         return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, sizeof(sptps_kex_t));
338 }
339
340 static size_t sigmsg_len(size_t labellen) {
341         return 1 + 2 * sizeof(sptps_kex_t) + labellen;
342 }
343
344 static void fill_msg(uint8_t *msg, bool initiator, const sptps_kex_t *kex0, const sptps_kex_t *kex1, const sptps_t *s) {
345         *msg = initiator, msg++;
346         memcpy(msg, kex0, sizeof(*kex0)), msg += sizeof(*kex0);
347         memcpy(msg, kex1, sizeof(*kex1)), msg += sizeof(*kex1);
348         memcpy(msg, s->label, s->labellen);
349 }
350
351 // Send a SIGnature record, containing an Ed25519 signature over both KEX records.
352 static bool send_sig(sptps_t *s) {
353         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator, plus label
354         size_t msglen = sigmsg_len(s->labellen);
355         uint8_t *msg = alloca(msglen);
356         fill_msg(msg, s->initiator, s->mykex, s->hiskex, s);
357
358         // Sign the result.
359         size_t siglen = ecdsa_size(s->mykey);
360         uint8_t *sig = alloca(siglen);
361
362         if(!ecdsa_sign(s->mykey, msg, msglen, sig)) {
363                 return error(s, EINVAL, "Failed to sign SIG record");
364         }
365
366         // Send the SIG exchange record.
367         return send_record_priv(s, SPTPS_HANDSHAKE, sig, siglen);
368 }
369
370 // Generate key material from the shared secret created from the ECDHE key exchange.
371 static bool generate_key_material(sptps_t *s, const uint8_t *shared, size_t len) {
372         // Allocate memory for key material
373         s->key = new_sptps_key();
374
375         // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce
376         const size_t msglen = sizeof("key expansion") - 1;
377         const size_t seedlen = msglen + s->labellen + ECDH_SIZE * 2;
378         uint8_t *seed = alloca(seedlen);
379
380         uint8_t *ptr = seed;
381         memcpy(ptr, "key expansion", msglen);
382         ptr += msglen;
383
384         memcpy(ptr, (s->initiator ? s->mykex : s->hiskex)->nonce, ECDH_SIZE);
385         ptr += ECDH_SIZE;
386
387         memcpy(ptr, (s->initiator ? s->hiskex : s->mykex)->nonce, ECDH_SIZE);
388         ptr += ECDH_SIZE;
389
390         memcpy(ptr, s->label, s->labellen);
391
392         // Use PRF to generate the key material
393         if(!prf(shared, len, seed, seedlen, s->key->both, sizeof(sptps_key_t))) {
394                 return error(s, EINVAL, "Failed to generate key material");
395         }
396
397         return true;
398 }
399
400 // Send an ACKnowledgement record.
401 static bool send_ack(sptps_t *s) {
402         return send_record_priv(s, SPTPS_HANDSHAKE, "", 0);
403 }
404
405 // Receive an ACKnowledgement record.
406 static bool receive_ack(sptps_t *s, const uint8_t *data, uint16_t len) {
407         (void)data;
408
409         if(len) {
410                 return error(s, EIO, "Invalid ACK record length");
411         }
412
413         if(!cipher_init(s->cipher_suite, &s->incipher, s->key, s->initiator)) {
414                 return error(s, EINVAL, "Failed to initialize cipher");
415         }
416
417         free_sptps_key(s->key);
418         s->key = NULL;
419         s->instate = true;
420
421         return true;
422 }
423
424 static uint8_t select_cipher_suite(uint16_t mask, uint8_t pref1, uint8_t pref2) {
425         // Check if there is a viable preference, if so select the lowest one
426         uint8_t selection = 255;
427
428         if(mask & (1U << pref1)) {
429                 selection = pref1;
430         }
431
432         if(pref2 < selection && (mask & (1U << pref2))) {
433                 selection = pref2;
434         }
435
436         // Otherwise, select the lowest cipher suite both sides support
437         if(selection == 255) {
438                 selection = 0;
439
440                 while(!(mask & 1U)) {
441                         selection++;
442                         mask >>= 1;
443                 }
444         }
445
446         return selection;
447 }
448
449 // Receive a Key EXchange record, respond by sending a SIG record.
450 static bool receive_kex(sptps_t *s, const uint8_t *data, uint16_t len) {
451         // Verify length of the HELLO record
452
453         if(len != sizeof(sptps_kex_t)) {
454                 return error(s, EIO, "Invalid KEX record length");
455         }
456
457         if(*data != SPTPS_VERSION) {
458                 return error(s, EINVAL, "Received incorrect version %d", *data);
459         }
460
461         uint16_t suites;
462         memcpy(&suites, data + 2, 2);
463         suites &= s->cipher_suites;
464
465         if(!suites) {
466                 return error(s, EIO, "No matching cipher suites");
467         }
468
469         s->cipher_suite = select_cipher_suite(suites, s->preferred_suite, data[1] & 0xf);
470
471         // Make a copy of the KEX message, send_sig() and receive_sig() need it
472         if(s->hiskex) {
473                 return error(s, EINVAL, "Received a second KEX message before first has been processed");
474         }
475
476         s->hiskex = new_sptps_kex();
477         memcpy(s->hiskex, data, sizeof(sptps_kex_t));
478
479         if(s->initiator) {
480                 return send_sig(s);
481         } else {
482                 return true;
483         }
484 }
485
486 // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys.
487 static bool receive_sig(sptps_t *s, const uint8_t *data, uint16_t len) {
488         // Verify length of KEX record.
489         if(len != ecdsa_size(s->hiskey)) {
490                 return error(s, EIO, "Invalid KEX record length");
491         }
492
493         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
494         const size_t msglen = sigmsg_len(s->labellen);
495         uint8_t *msg = alloca(msglen);
496         fill_msg(msg, !s->initiator, s->hiskex, s->mykex, s);
497
498         // Verify signature.
499         if(!ecdsa_verify(s->hiskey, msg, msglen, data)) {
500                 return error(s, EIO, "Failed to verify SIG record");
501         }
502
503         // Compute shared secret.
504         uint8_t shared[ECDH_SHARED_SIZE];
505
506         if(!ecdh_compute_shared(s->ecdh, s->hiskex->pubkey, shared)) {
507                 memzero(shared, sizeof(shared));
508                 return error(s, EINVAL, "Failed to compute ECDH shared secret");
509         }
510
511         s->ecdh = NULL;
512
513         // Generate key material from shared secret.
514         bool generated = generate_key_material(s, shared, sizeof(shared));
515         memzero(shared, sizeof(shared));
516
517         if(!generated) {
518                 return false;
519         }
520
521         if(!s->initiator && !send_sig(s)) {
522                 return false;
523         }
524
525         free_sptps_kex(s->mykex);
526         s->mykex = NULL;
527
528         free_sptps_kex(s->hiskex);
529         s->hiskex = NULL;
530
531         // Send cipher change record
532         if(s->outstate && !send_ack(s)) {
533                 return false;
534         }
535
536         if(!cipher_init(s->cipher_suite, &s->outcipher, s->key, !s->initiator)) {
537                 return error(s, EINVAL, "Failed to initialize cipher");
538         }
539
540         return true;
541 }
542
543 // Force another Key EXchange (for testing purposes).
544 bool sptps_force_kex(sptps_t *s) {
545         if(!s->outstate || s->state != SPTPS_SECONDARY_KEX) {
546                 return error(s, EINVAL, "Cannot force KEX in current state");
547         }
548
549         s->state = SPTPS_KEX;
550         return send_kex(s);
551 }
552
553 // Receive a handshake record.
554 static bool receive_handshake(sptps_t *s, const uint8_t *data, uint16_t len) {
555         // Only a few states to deal with handshaking.
556         switch(s->state) {
557         case SPTPS_SECONDARY_KEX:
558
559                 // We receive a secondary KEX request, first respond by sending our own.
560                 if(!send_kex(s)) {
561                         return false;
562                 }
563
564         // Fall through
565         case SPTPS_KEX:
566
567                 // We have sent our KEX request, we expect our peer to sent one as well.
568                 if(!receive_kex(s, data, len)) {
569                         return false;
570                 }
571
572                 s->state = SPTPS_SIG;
573                 return true;
574
575         case SPTPS_SIG:
576
577                 // If we already sent our secondary public ECDH key, we expect the peer to send his.
578                 if(!receive_sig(s, data, len)) {
579                         return false;
580                 }
581
582                 if(s->outstate) {
583                         s->state = SPTPS_ACK;
584                 } else {
585                         s->outstate = true;
586
587                         if(!receive_ack(s, NULL, 0)) {
588                                 return false;
589                         }
590
591                         s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
592                         s->state = SPTPS_SECONDARY_KEX;
593                 }
594
595                 return true;
596
597         case SPTPS_ACK:
598
599                 // We expect a handshake message to indicate transition to the new keys.
600                 if(!receive_ack(s, data, len)) {
601                         return false;
602                 }
603
604                 s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
605                 s->state = SPTPS_SECONDARY_KEX;
606                 return true;
607
608         // TODO: split ACK into a VERify and ACK?
609         default:
610                 return error(s, EIO, "Invalid session state %d", s->state);
611         }
612 }
613
614 static bool sptps_check_seqno(sptps_t *s, uint32_t seqno, bool update_state) {
615         // Replay protection using a sliding window of configurable size.
616         // s->inseqno is expected sequence number
617         // seqno is received sequence number
618         // s->late[] is a circular buffer, a 1 bit means a packet has not been received yet
619         // The circular buffer contains bits for sequence numbers from s->inseqno - s->replaywin * 8 to (but excluding) s->inseqno.
620         if(s->replaywin) {
621                 if(seqno != s->inseqno) {
622                         if(seqno >= s->inseqno + s->replaywin * 8) {
623                                 // Prevent packets that jump far ahead of the queue from causing many others to be dropped.
624                                 bool farfuture = s->farfuture < s->replaywin >> 2;
625
626                                 if(update_state) {
627                                         s->farfuture++;
628                                 }
629
630                                 if(farfuture) {
631                                         return update_state ? error(s, EIO, "Packet is %d seqs in the future, dropped (%u)\n", seqno - s->inseqno, s->farfuture) : false;
632                                 }
633
634                                 // Unless we have seen lots of them, in which case we consider the others lost.
635                                 if(update_state) {
636                                         warning(s, "Lost %d packets\n", seqno - s->inseqno);
637                                 }
638
639                                 if(update_state) {
640                                         // Mark all packets in the replay window as being late.
641                                         memset(s->late, 255, s->replaywin);
642                                 }
643                         } else if(seqno < s->inseqno) {
644                                 // If the sequence number is farther in the past than the bitmap goes, or if the packet was already received, drop it.
645                                 if((s->inseqno >= s->replaywin * 8 && seqno < s->inseqno - s->replaywin * 8) || !(s->late[(seqno / 8) % s->replaywin] & (1 << seqno % 8))) {
646                                         return update_state ? error(s, EIO, "Received late or replayed packet, seqno %d, last received %d\n", seqno, s->inseqno) : false;
647                                 }
648                         } else if(update_state) {
649                                 // We missed some packets. Mark them in the bitmap as being late.
650                                 for(uint32_t i = s->inseqno; i < seqno; i++) {
651                                         s->late[(i / 8) % s->replaywin] |= 1 << i % 8;
652                                 }
653                         }
654                 }
655
656                 if(update_state) {
657                         // Mark the current packet as not being late.
658                         s->late[(seqno / 8) % s->replaywin] &= ~(1 << seqno % 8);
659                         s->farfuture = 0;
660                 }
661         }
662
663         if(update_state) {
664                 if(seqno >= s->inseqno) {
665                         s->inseqno = seqno + 1;
666                 }
667
668                 if(!s->inseqno) {
669                         s->received = 0;
670                 } else {
671                         s->received++;
672                 }
673         }
674
675         return true;
676 }
677
678 // Check datagram for valid HMAC
679 bool sptps_verify_datagram(sptps_t *s, const void *vdata, size_t len) {
680         if(!s->instate || len < 21) {
681                 return error(s, EIO, "Received short packet");
682         }
683
684         const uint8_t *data = vdata;
685         uint32_t seqno;
686         memcpy(&seqno, data, 4);
687
688         if(!sptps_check_seqno(s, seqno, false)) {
689                 return false;
690         }
691
692         uint8_t *buffer = alloca(len);
693         return cipher_decrypt(s->cipher_suite, s->incipher, seqno, data + 4, len - 4, buffer, NULL);
694 }
695
696 // Receive incoming data, datagram version.
697 static bool sptps_receive_data_datagram(sptps_t *s, const uint8_t *data, size_t len) {
698         if(len < (s->instate ? 21 : 5)) {
699                 return error(s, EIO, "Received short packet");
700         }
701
702         uint32_t seqno;
703         memcpy(&seqno, data, 4);
704         data += 4;
705         len -= 4;
706
707         if(!s->instate) {
708                 if(seqno != s->inseqno) {
709                         return error(s, EIO, "Invalid packet seqno: %d != %d", seqno, s->inseqno);
710                 }
711
712                 s->inseqno = seqno + 1;
713
714                 uint8_t type = *(data++);
715                 len--;
716
717                 if(type != SPTPS_HANDSHAKE) {
718                         return error(s, EIO, "Application record received before handshake finished");
719                 }
720
721                 return receive_handshake(s, data, len);
722         }
723
724         // Decrypt
725
726         uint8_t *buffer = alloca(len);
727         size_t outlen;
728
729         if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, data, len, buffer, &outlen)) {
730                 return error(s, EIO, "Failed to decrypt and verify packet");
731         }
732
733         if(!sptps_check_seqno(s, seqno, true)) {
734                 return false;
735         }
736
737         // Append a NULL byte for safety.
738         buffer[outlen] = 0;
739
740         data = buffer;
741         len = outlen;
742
743         uint8_t type = *(data++);
744         len--;
745
746         if(type < SPTPS_HANDSHAKE) {
747                 if(!s->instate) {
748                         return error(s, EIO, "Application record received before handshake finished");
749                 }
750
751                 if(!s->receive_record(s->handle, type, data, len)) {
752                         return false;
753                 }
754         } else if(type == SPTPS_HANDSHAKE) {
755                 if(!receive_handshake(s, data, len)) {
756                         return false;
757                 }
758         } else {
759                 return error(s, EIO, "Invalid record type %d", type);
760         }
761
762         return true;
763 }
764
765 // Receive incoming data. Check if it contains a complete record, if so, handle it.
766 size_t sptps_receive_data(sptps_t *s, const void *vdata, size_t len) {
767         const uint8_t *data = vdata;
768         size_t total_read = 0;
769
770         if(!s->state) {
771                 return error(s, EIO, "Invalid session state zero");
772         }
773
774         if(s->datagram) {
775                 return sptps_receive_data_datagram(s, data, len) ? len : false;
776         }
777
778         // First read the 2 length bytes.
779         if(s->buflen < 2) {
780                 size_t toread = 2 - s->buflen;
781
782                 if(toread > len) {
783                         toread = len;
784                 }
785
786                 memcpy(s->inbuf + s->buflen, data, toread);
787
788                 total_read += toread;
789                 s->buflen += toread;
790                 len -= toread;
791                 data += toread;
792
793                 // Exit early if we don't have the full length.
794                 if(s->buflen < 2) {
795                         return total_read;
796                 }
797
798                 // Get the length bytes
799
800                 memcpy(&s->reclen, s->inbuf, 2);
801
802                 // If we have the length bytes, ensure our buffer can hold the whole request.
803                 s->inbuf = realloc(s->inbuf, s->reclen + SPTPS_OVERHEAD);
804
805                 if(!s->inbuf) {
806                         return error(s, errno, "%s", strerror(errno));
807                 }
808
809                 // Exit early if we have no more data to process.
810                 if(!len) {
811                         return total_read;
812                 }
813         }
814
815         // Read up to the end of the record.
816         size_t toread = s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER) - s->buflen;
817
818         if(toread > len) {
819                 toread = len;
820         }
821
822         memcpy(s->inbuf + s->buflen, data, toread);
823         total_read += toread;
824         s->buflen += toread;
825
826         // If we don't have a whole record, exit.
827         if(s->buflen < s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER)) {
828                 return total_read;
829         }
830
831         // Update sequence number.
832
833         uint32_t seqno = s->inseqno++;
834
835         // Check HMAC and decrypt.
836         if(s->instate) {
837                 if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) {
838                         return error(s, EINVAL, "Failed to decrypt and verify record");
839                 }
840         }
841
842         // Append a NULL byte for safety.
843         s->inbuf[s->reclen + SPTPS_HEADER] = 0;
844
845         uint8_t type = s->inbuf[2];
846
847         if(type < SPTPS_HANDSHAKE) {
848                 if(!s->instate) {
849                         return error(s, EIO, "Application record received before handshake finished");
850                 }
851
852                 if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen)) {
853                         return false;
854                 }
855         } else if(type == SPTPS_HANDSHAKE) {
856                 if(!receive_handshake(s, s->inbuf + 3, s->reclen)) {
857                         return false;
858                 }
859         } else {
860                 return error(s, EIO, "Invalid record type %d", type);
861         }
862
863         s->buflen = 0;
864
865         return total_read;
866 }
867
868 // Start a SPTPS session.
869 bool sptps_start(sptps_t *s, const sptps_params_t *params) {
870         // Initialise struct sptps
871         memset(s, 0, sizeof(*s));
872
873         s->handle = params->handle;
874         s->initiator = params->initiator;
875         s->datagram = params->datagram;
876         s->mykey = params->mykey;
877         s->hiskey = params->hiskey;
878         s->replaywin = sptps_replaywin;
879         s->cipher_suites = params->cipher_suites ? params->cipher_suites & SPTPS_ALL_CIPHER_SUITES : SPTPS_ALL_CIPHER_SUITES;
880         s->preferred_suite = params->preferred_suite;
881
882         if(s->replaywin) {
883                 s->late = malloc(s->replaywin);
884
885                 if(!s->late) {
886                         return error(s, errno, "%s", strerror(errno));
887                 }
888
889                 memset(s->late, 0, s->replaywin);
890         }
891
892         s->labellen = params->labellen ? params->labellen : strlen(params->label);
893         s->label = malloc(s->labellen);
894
895         if(!s->label) {
896                 return error(s, errno, "%s", strerror(errno));
897         }
898
899         memcpy(s->label, params->label, s->labellen);
900
901         if(!s->datagram) {
902                 s->inbuf = malloc(7);
903
904                 if(!s->inbuf) {
905                         return error(s, errno, "%s", strerror(errno));
906                 }
907
908                 s->buflen = 0;
909         }
910
911
912         s->send_data = params->send_data;
913         s->receive_record = params->receive_record;
914
915         // Do first KEX immediately
916         s->state = SPTPS_KEX;
917         return send_kex(s);
918 }
919
920 // Stop a SPTPS session.
921 bool sptps_stop(sptps_t *s) {
922         // Clean up any resources.
923         cipher_exit(s->cipher_suite, s->incipher);
924         cipher_exit(s->cipher_suite, s->outcipher);
925         ecdh_free(s->ecdh);
926         free(s->inbuf);
927         free_sptps_kex(s->mykex);
928         free_sptps_kex(s->hiskex);
929         free_sptps_key(s->key);
930         free(s->label);
931         free(s->late);
932         memset(s, 0, sizeof(*s));
933         return true;
934 }