Handle SPTPS datagrams in try_mac().
[tinc] / src / sptps.c
1 /*
2     sptps.c -- Simple Peer-to-Peer Security
3     Copyright (C) 2011 Guus Sliepen <guus@tinc-vpn.org>,
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #include "cipher.h"
23 #include "crypto.h"
24 #include "digest.h"
25 #include "ecdh.h"
26 #include "ecdsa.h"
27 #include "prf.h"
28 #include "sptps.h"
29
30 /*
31    Nonce MUST be exchanged first (done)
32    Signatures MUST be done over both nonces, to guarantee the signature is fresh
33    Otherwise: if ECDHE key of one side is compromised, it can be reused!
34
35    Add explicit tag to beginning of structure to distinguish the client and server when signing. (done)
36
37    Sign all handshake messages up to ECDHE kex with long-term public keys. (done)
38
39    HMACed KEX finished message to prevent downgrade attacks and prove you have the right key material (done by virtue of ECDSA over the whole ECDHE exchange?)
40
41    Explicit close message needs to be added.
42
43    Maybe do add some alert messages to give helpful error messages? Not more than TLS sends.
44
45    Use counter mode instead of OFB. (done)
46
47    Make sure ECC operations are fixed time (aka prevent side-channel attacks).
48 */
49
50 // Log an error message.
51 static bool error(sptps_t *s, int s_errno, const char *msg) {
52         fprintf(stderr, "SPTPS error: %s\n", msg);
53         errno = s_errno;
54         return false;
55 }
56
57 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
58 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
59         char buffer[len + 23UL];
60
61         // Create header with sequence number, length and record type
62         uint32_t seqno = htonl(s->outseqno++);
63         uint16_t netlen = htons(len);
64
65         memcpy(buffer, &netlen, 2);
66         memcpy(buffer + 2, &seqno, 4);
67         buffer[6] = type;
68
69         // Add plaintext (TODO: avoid unnecessary copy)
70         memcpy(buffer + 7, data, len);
71
72         if(s->outstate) {
73                 // If first handshake has finished, encrypt and HMAC
74                 cipher_set_counter(&s->outcipher, &seqno, sizeof seqno);
75                 if(!cipher_counter_xor(&s->outcipher, buffer + 6, len + 1UL, buffer + 6))
76                         return false;
77
78                 if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
79                         return false;
80
81                 return s->send_data(s->handle, type, buffer + 2, len + 21UL);
82         } else {
83                 // Otherwise send as plaintext
84                 return s->send_data(s->handle, type, buffer + 2, len + 5UL);
85         }
86 }
87 // Send a record (private version, accepts all record types, handles encryption and authentication).
88 static bool send_record_priv(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
89         if(s->datagram)
90                 return send_record_priv_datagram(s, type, data, len);
91
92         char buffer[len + 23UL];
93
94         // Create header with sequence number, length and record type
95         uint32_t seqno = htonl(s->outseqno++);
96         uint16_t netlen = htons(len);
97
98         memcpy(buffer, &seqno, 4);
99         memcpy(buffer + 4, &netlen, 2);
100         buffer[6] = type;
101
102         // Add plaintext (TODO: avoid unnecessary copy)
103         memcpy(buffer + 7, data, len);
104
105         if(s->outstate) {
106                 // If first handshake has finished, encrypt and HMAC
107                 if(!cipher_counter_xor(&s->outcipher, buffer + 4, len + 3UL, buffer + 4))
108                         return false;
109
110                 if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
111                         return false;
112
113                 return s->send_data(s->handle, type, buffer + 4, len + 19UL);
114         } else {
115                 // Otherwise send as plaintext
116                 return s->send_data(s->handle, type, buffer + 4, len + 3UL);
117         }
118 }
119
120 // Send an application record.
121 bool sptps_send_record(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
122         // Sanity checks: application cannot send data before handshake is finished,
123         // and only record types 0..127 are allowed.
124         if(!s->outstate)
125                 return error(s, EINVAL, "Handshake phase not finished yet");
126
127         if(type >= SPTPS_HANDSHAKE)
128                 return error(s, EINVAL, "Invalid application record type");
129
130         return send_record_priv(s, type, data, len);
131 }
132
133 // Send a Key EXchange record, containing a random nonce and an ECDHE public key.
134 static bool send_kex(sptps_t *s) {
135         size_t keylen = ECDH_SIZE;
136
137         // Make room for our KEX message, which we will keep around since send_sig() needs it.
138         if(s->mykex)
139                 abort();
140         s->mykex = realloc(s->mykex, 1 + 32 + keylen);
141         if(!s->mykex)
142                 return error(s, errno, strerror(errno));
143
144         // Set version byte to zero.
145         s->mykex[0] = SPTPS_VERSION;
146
147         // Create a random nonce.
148         randomize(s->mykex + 1, 32);
149
150         // Create a new ECDH public key.
151         if(!ecdh_generate_public(&s->ecdh, s->mykex + 1 + 32))
152                 return false;
153
154         return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, 1 + 32 + keylen);
155 }
156
157 // Send a SIGnature record, containing an ECDSA signature over both KEX records.
158 static bool send_sig(sptps_t *s) {
159         size_t keylen = ECDH_SIZE;
160         size_t siglen = ecdsa_size(&s->mykey);
161
162         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator, plus label
163         char msg[(1 + 32 + keylen) * 2 + 1 + s->labellen];
164         char sig[siglen];
165
166         msg[0] = s->initiator;
167         memcpy(msg + 1, s->mykex, 1 + 32 + keylen);
168         memcpy(msg + 1 + 33 + keylen, s->hiskex, 1 + 32 + keylen);
169         memcpy(msg + 1 + 2 * (33 + keylen), s->label, s->labellen);
170
171         // Sign the result.
172         if(!ecdsa_sign(&s->mykey, msg, sizeof msg, sig))
173                 return false;
174
175         // Send the SIG exchange record.
176         return send_record_priv(s, SPTPS_HANDSHAKE, sig, sizeof sig);
177 }
178
179 // Generate key material from the shared secret created from the ECDHE key exchange.
180 static bool generate_key_material(sptps_t *s, const char *shared, size_t len) {
181         // Initialise cipher and digest structures if necessary
182         if(!s->outstate) {
183                 bool result
184                         =  cipher_open_by_name(&s->incipher, "aes-256-ecb")
185                         && cipher_open_by_name(&s->outcipher, "aes-256-ecb")
186                         && digest_open_by_name(&s->indigest, "sha256", 16)
187                         && digest_open_by_name(&s->outdigest, "sha256", 16);
188                 if(!result)
189                         return false;
190         }
191
192         // Allocate memory for key material
193         size_t keylen = digest_keylength(&s->indigest) + digest_keylength(&s->outdigest) + cipher_keylength(&s->incipher) + cipher_keylength(&s->outcipher);
194
195         s->key = realloc(s->key, keylen);
196         if(!s->key)
197                 return error(s, errno, strerror(errno));
198
199         // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce
200         char seed[s->labellen + 64 + 13];
201         strcpy(seed, "key expansion");
202         if(s->initiator) {
203                 memcpy(seed + 13, s->mykex + 1, 32);
204                 memcpy(seed + 45, s->hiskex + 1, 32);
205         } else {
206                 memcpy(seed + 13, s->hiskex + 1, 32);
207                 memcpy(seed + 45, s->mykex + 1, 32);
208         }
209         memcpy(seed + 78, s->label, s->labellen);
210
211         // Use PRF to generate the key material
212         if(!prf(shared, len, seed, s->labellen + 64 + 13, s->key, keylen))
213                 return false;
214
215         return true;
216 }
217
218 // Send an ACKnowledgement record.
219 static bool send_ack(sptps_t *s) {
220         return send_record_priv(s, SPTPS_HANDSHAKE, "", 0);
221 }
222
223 // Receive an ACKnowledgement record.
224 static bool receive_ack(sptps_t *s, const char *data, uint16_t len) {
225         if(len)
226                 return error(s, EIO, "Invalid ACK record length");
227
228         if(s->initiator) {
229                 bool result
230                         = cipher_set_counter_key(&s->incipher, s->key)
231                         && digest_set_key(&s->indigest, s->key + cipher_keylength(&s->incipher), digest_keylength(&s->indigest));
232                 if(!result)
233                         return false;
234         } else {
235                 bool result
236                         = cipher_set_counter_key(&s->incipher, s->key + cipher_keylength(&s->outcipher) + digest_keylength(&s->outdigest))
237                         && digest_set_key(&s->indigest, s->key + cipher_keylength(&s->outcipher) + digest_keylength(&s->outdigest) + cipher_keylength(&s->incipher), digest_keylength(&s->indigest));
238                 if(!result)
239                         return false;
240         }
241
242         free(s->key);
243         s->key = NULL;
244         s->instate = true;
245
246         return true;
247 }
248
249 // Receive a Key EXchange record, respond by sending a SIG record.
250 static bool receive_kex(sptps_t *s, const char *data, uint16_t len) {
251         // Verify length of the HELLO record
252         if(len != 1 + 32 + ECDH_SIZE)
253                 return error(s, EIO, "Invalid KEX record length");
254
255         // Ignore version number for now.
256
257         // Make a copy of the KEX message, send_sig() and receive_sig() need it
258         if(s->hiskex)
259                 abort();
260         s->hiskex = realloc(s->hiskex, len);
261         if(!s->hiskex)
262                 return error(s, errno, strerror(errno));
263
264         memcpy(s->hiskex, data, len);
265
266         return send_sig(s);
267 }
268
269 // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys.
270 static bool receive_sig(sptps_t *s, const char *data, uint16_t len) {
271         size_t keylen = ECDH_SIZE;
272         size_t siglen = ecdsa_size(&s->hiskey);
273
274         // Verify length of KEX record.
275         if(len != siglen)
276                 return error(s, EIO, "Invalid KEX record length");
277
278         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
279         char msg[(1 + 32 + keylen) * 2 + 1 + s->labellen];
280
281         msg[0] = !s->initiator;
282         memcpy(msg + 1, s->hiskex, 1 + 32 + keylen);
283         memcpy(msg + 1 + 33 + keylen, s->mykex, 1 + 32 + keylen);
284         memcpy(msg + 1 + 2 * (33 + keylen), s->label, s->labellen);
285
286         // Verify signature.
287         if(!ecdsa_verify(&s->hiskey, msg, sizeof msg, data))
288                 return false;
289
290         // Compute shared secret.
291         char shared[ECDH_SHARED_SIZE];
292         if(!ecdh_compute_shared(&s->ecdh, s->hiskex + 1 + 32, shared))
293                 return false;
294
295         // Generate key material from shared secret.
296         if(!generate_key_material(s, shared, sizeof shared))
297                 return false;
298
299         free(s->mykex);
300         free(s->hiskex);
301
302         s->mykex = NULL;
303         s->hiskex = NULL;
304
305         // Send cipher change record
306         if(s->outstate && !send_ack(s))
307                 return false;
308
309         // TODO: only set new keys after ACK has been set/received
310         if(s->initiator) {
311                 bool result
312                         = cipher_set_counter_key(&s->outcipher, s->key + cipher_keylength(&s->incipher) + digest_keylength(&s->indigest))
313                         && digest_set_key(&s->outdigest, s->key + cipher_keylength(&s->incipher) + digest_keylength(&s->indigest) + cipher_keylength(&s->outcipher), digest_keylength(&s->outdigest));
314                 if(!result)
315                         return false;
316         } else {
317                 bool result
318                         =  cipher_set_counter_key(&s->outcipher, s->key)
319                         && digest_set_key(&s->outdigest, s->key + cipher_keylength(&s->outcipher), digest_keylength(&s->outdigest));
320                 if(!result)
321                         return false;
322         }
323
324         return true;
325 }
326
327 // Force another Key EXchange (for testing purposes).
328 bool sptps_force_kex(sptps_t *s) {
329         if(!s->outstate || s->state != SPTPS_SECONDARY_KEX)
330                 return error(s, EINVAL, "Cannot force KEX in current state");
331
332         s->state = SPTPS_KEX;
333         return send_kex(s);
334 }
335
336 // Receive a handshake record.
337 static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) {
338         // Only a few states to deal with handshaking.
339         fprintf(stderr, "Received handshake message, current state %d\n", s->state);
340         switch(s->state) {
341                 case SPTPS_SECONDARY_KEX:
342                         // We receive a secondary KEX request, first respond by sending our own.
343                         if(!send_kex(s))
344                                 return false;
345                 case SPTPS_KEX:
346                         // We have sent our KEX request, we expect our peer to sent one as well.
347                         if(!receive_kex(s, data, len))
348                                 return false;
349                         s->state = SPTPS_SIG;
350                         return true;
351                 case SPTPS_SIG:
352                         // If we already sent our secondary public ECDH key, we expect the peer to send his.
353                         if(!receive_sig(s, data, len))
354                                 return false;
355                         if(s->outstate)
356                                 s->state = SPTPS_ACK;
357                         else {
358                                 s->outstate = true;
359                                 if(!receive_ack(s, NULL, 0))
360                                         return false;
361                                 s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
362                                 s->state = SPTPS_SECONDARY_KEX;
363                         }
364
365                         return true;
366                 case SPTPS_ACK:
367                         // We expect a handshake message to indicate transition to the new keys.
368                         if(!receive_ack(s, data, len))
369                                 return false;
370                         s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
371                         s->state = SPTPS_SECONDARY_KEX;
372                         return true;
373                 // TODO: split ACK into a VERify and ACK?
374                 default:
375                         return error(s, EIO, "Invalid session state");
376         }
377 }
378
379 // Check datagram for valid HMAC
380 bool sptps_verify_datagram(sptps_t *s, const char *data, size_t len) {
381         if(!s->instate || len < 21)
382                 return false;
383
384         char buffer[len + 23];
385         uint16_t netlen = htons(len - 21);
386
387         memcpy(buffer, &netlen, 2);
388         memcpy(buffer + 2, data, len);
389
390         return digest_verify(&s->indigest, buffer, len - 14, buffer + len - 14);
391 }
392
393 // Receive incoming data, datagram version.
394 static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len) {
395         if(len < (s->instate ? 21 : 5))
396                 return error(s, EIO, "Received short packet");
397
398         uint32_t seqno;
399         memcpy(&seqno, data, 4);
400         seqno = ntohl(seqno);
401
402         if(!s->instate) {
403                 if(seqno != s->inseqno) {
404                         fprintf(stderr, "Received invalid packet seqno: %d != %d\n", seqno, s->inseqno);
405                         return error(s, EIO, "Invalid packet seqno");
406                 }
407
408                 s->inseqno = seqno + 1;
409
410                 uint8_t type = data[4];
411
412                 if(type != SPTPS_HANDSHAKE)
413                         return error(s, EIO, "Application record received before handshake finished");
414
415                 return receive_handshake(s, data + 5, len - 5);
416         }
417
418         if(seqno < s->inseqno) {
419                 fprintf(stderr, "Received late or replayed packet: %d < %d\n", seqno, s->inseqno);
420                 return true;
421         }
422
423         if(seqno > s->inseqno)
424                 fprintf(stderr, "Missed %d packets\n", seqno - s->inseqno);
425
426         s->inseqno = seqno + 1;
427
428         uint16_t netlen = htons(len - 21);
429
430         char buffer[len + 23];
431
432         memcpy(buffer, &netlen, 2);
433         memcpy(buffer + 2, data, len);
434
435         memcpy(&seqno, buffer + 2, 4);
436
437         // Check HMAC and decrypt.
438         if(!digest_verify(&s->indigest, buffer, len - 14, buffer + len - 14))
439                 return error(s, EIO, "Invalid HMAC");
440
441         cipher_set_counter(&s->incipher, &seqno, sizeof seqno);
442         if(!cipher_counter_xor(&s->incipher, buffer + 6, len - 4, buffer + 6))
443                 return false;
444
445         // Append a NULL byte for safety.
446         buffer[len - 14] = 0;
447
448         uint8_t type = buffer[6];
449
450         if(type < SPTPS_HANDSHAKE) {
451                 if(!s->instate)
452                         return error(s, EIO, "Application record received before handshake finished");
453                 if(!s->receive_record(s->handle, type, buffer + 7, len - 21))
454                         return false;
455         } else if(type == SPTPS_HANDSHAKE) {
456                 if(!receive_handshake(s, buffer + 7, len - 21))
457                         return false;
458         } else {
459                 return error(s, EIO, "Invalid record type");
460         }
461
462         return true;
463 }
464 // Receive incoming data. Check if it contains a complete record, if so, handle it.
465 bool sptps_receive_data(sptps_t *s, const char *data, size_t len) {
466         if(s->datagram)
467                 return sptps_receive_data_datagram(s, data, len);
468
469         while(len) {
470                 // First read the 2 length bytes.
471                 if(s->buflen < 6) {
472                         size_t toread = 6 - s->buflen;
473                         if(toread > len)
474                                 toread = len;
475
476                         memcpy(s->inbuf + s->buflen, data, toread);
477
478                         s->buflen += toread;
479                         len -= toread;
480                         data += toread;
481                 
482                         // Exit early if we don't have the full length.
483                         if(s->buflen < 6)
484                                 return true;
485
486                         // Decrypt the length bytes
487
488                         if(s->instate) {
489                                 if(!cipher_counter_xor(&s->incipher, s->inbuf + 4, 2, &s->reclen))
490                                         return false;
491                         } else {
492                                 memcpy(&s->reclen, s->inbuf + 4, 2);
493                         }
494
495                         s->reclen = ntohs(s->reclen);
496
497                         // If we have the length bytes, ensure our buffer can hold the whole request.
498                         s->inbuf = realloc(s->inbuf, s->reclen + 23UL);
499                         if(!s->inbuf)
500                                 return error(s, errno, strerror(errno));
501
502                         // Add sequence number.
503                         uint32_t seqno = htonl(s->inseqno++);
504                         memcpy(s->inbuf, &seqno, 4);
505
506                         // Exit early if we have no more data to process.
507                         if(!len)
508                                 return true;
509                 }
510
511                 // Read up to the end of the record.
512                 size_t toread = s->reclen + (s->instate ? 23UL : 7UL) - s->buflen;
513                 if(toread > len)
514                         toread = len;
515
516                 memcpy(s->inbuf + s->buflen, data, toread);
517                 s->buflen += toread;
518                 len -= toread;
519                 data += toread;
520
521                 // If we don't have a whole record, exit.
522                 if(s->buflen < s->reclen + (s->instate ? 23UL : 7UL))
523                         return true;
524
525                 // Check HMAC and decrypt.
526                 if(s->instate) {
527                         if(!digest_verify(&s->indigest, s->inbuf, s->reclen + 7UL, s->inbuf + s->reclen + 7UL))
528                                 return error(s, EIO, "Invalid HMAC");
529
530                         if(!cipher_counter_xor(&s->incipher, s->inbuf + 6UL, s->reclen + 1UL, s->inbuf + 6UL))
531                                 return false;
532                 }
533
534                 // Append a NULL byte for safety.
535                 s->inbuf[s->reclen + 7UL] = 0;
536
537                 uint8_t type = s->inbuf[6];
538
539                 if(type < SPTPS_HANDSHAKE) {
540                         if(!s->instate)
541                                 return error(s, EIO, "Application record received before handshake finished");
542                         if(!s->receive_record(s->handle, type, s->inbuf + 7, s->reclen))
543                                 return false;
544                 } else if(type == SPTPS_HANDSHAKE) {
545                         if(!receive_handshake(s, s->inbuf + 7, s->reclen))
546                                 return false;
547                 } else {
548                         return error(s, EIO, "Invalid record type");
549                 }
550
551                 s->buflen = 4;
552         }
553
554         return true;
555 }
556
557 // Start a SPTPS session.
558 bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
559         // Initialise struct sptps
560         memset(s, 0, sizeof *s);
561
562         s->handle = handle;
563         s->initiator = initiator;
564         s->datagram = datagram;
565         s->mykey = mykey;
566         s->hiskey = hiskey;
567
568         s->label = malloc(labellen);
569         if(!s->label)
570                 return error(s, errno, strerror(errno));
571
572         if(!datagram) {
573                 s->inbuf = malloc(7);
574                 if(!s->inbuf)
575                         return error(s, errno, strerror(errno));
576                 s->buflen = 4;
577                 memset(s->inbuf, 0, 4);
578         }
579
580         memcpy(s->label, label, labellen);
581         s->labellen = labellen;
582
583         s->send_data = send_data;
584         s->receive_record = receive_record;
585
586         // Do first KEX immediately
587         s->state = SPTPS_KEX;
588         return send_kex(s);
589 }
590
591 // Stop a SPTPS session.
592 bool sptps_stop(sptps_t *s) {
593         // Clean up any resources.
594         ecdh_free(&s->ecdh);
595         free(s->inbuf);
596         s->inbuf = NULL;
597         free(s->mykex);
598         s->mykex = NULL;
599         free(s->hiskex);
600         s->hiskex = NULL;
601         free(s->key);
602         s->key = NULL;
603         free(s->label);
604         s->label = NULL;
605         return true;
606 }