Use datagram SPTPS for packet exchange between nodes.
[tinc] / src / sptps.c
1 /*
2     sptps.c -- Simple Peer-to-Peer Security
3     Copyright (C) 2011 Guus Sliepen <guus@tinc-vpn.org>,
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #include "cipher.h"
23 #include "crypto.h"
24 #include "digest.h"
25 #include "ecdh.h"
26 #include "ecdsa.h"
27 #include "prf.h"
28 #include "sptps.h"
29
30 /*
31    Nonce MUST be exchanged first (done)
32    Signatures MUST be done over both nonces, to guarantee the signature is fresh
33    Otherwise: if ECDHE key of one side is compromised, it can be reused!
34
35    Add explicit tag to beginning of structure to distinguish the client and server when signing. (done)
36
37    Sign all handshake messages up to ECDHE kex with long-term public keys. (done)
38
39    HMACed KEX finished message to prevent downgrade attacks and prove you have the right key material (done by virtue of ECDSA over the whole ECDHE exchange?)
40
41    Explicit close message needs to be added.
42
43    Maybe do add some alert messages to give helpful error messages? Not more than TLS sends.
44
45    Use counter mode instead of OFB. (done)
46
47    Make sure ECC operations are fixed time (aka prevent side-channel attacks).
48 */
49
50 // Log an error message.
51 static bool error(sptps_t *s, int s_errno, const char *msg) {
52         fprintf(stderr, "SPTPS error: %s\n", msg);
53         errno = s_errno;
54         return false;
55 }
56
57 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
58 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
59         char buffer[len + 23UL];
60
61         // Create header with sequence number, length and record type
62         uint32_t seqno = htonl(s->outseqno++);
63         uint16_t netlen = htons(len);
64
65         memcpy(buffer, &netlen, 2);
66         memcpy(buffer + 2, &seqno, 4);
67         buffer[6] = type;
68
69         // Add plaintext (TODO: avoid unnecessary copy)
70         memcpy(buffer + 7, data, len);
71
72         if(s->outstate) {
73                 // If first handshake has finished, encrypt and HMAC
74                 cipher_set_counter(&s->outcipher, &seqno, sizeof seqno);
75                 if(!cipher_counter_xor(&s->outcipher, buffer + 6, len + 1UL, buffer + 6))
76                         return false;
77
78                 if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
79                         return false;
80
81                 return s->send_data(s->handle, type, buffer + 2, len + 21UL);
82         } else {
83                 // Otherwise send as plaintext
84                 return s->send_data(s->handle, type, buffer + 2, len + 5UL);
85         }
86 }
87 // Send a record (private version, accepts all record types, handles encryption and authentication).
88 static bool send_record_priv(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
89         if(s->datagram)
90                 return send_record_priv_datagram(s, type, data, len);
91
92         char buffer[len + 23UL];
93
94         // Create header with sequence number, length and record type
95         uint32_t seqno = htonl(s->outseqno++);
96         uint16_t netlen = htons(len);
97
98         memcpy(buffer, &seqno, 4);
99         memcpy(buffer + 4, &netlen, 2);
100         buffer[6] = type;
101
102         // Add plaintext (TODO: avoid unnecessary copy)
103         memcpy(buffer + 7, data, len);
104
105         if(s->outstate) {
106                 // If first handshake has finished, encrypt and HMAC
107                 if(!cipher_counter_xor(&s->outcipher, buffer + 4, len + 3UL, buffer + 4))
108                         return false;
109
110                 if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
111                         return false;
112
113                 return s->send_data(s->handle, type, buffer + 4, len + 19UL);
114         } else {
115                 // Otherwise send as plaintext
116                 return s->send_data(s->handle, type, buffer + 4, len + 3UL);
117         }
118 }
119
120 // Send an application record.
121 bool sptps_send_record(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
122         // Sanity checks: application cannot send data before handshake is finished,
123         // and only record types 0..127 are allowed.
124         if(!s->outstate)
125                 return error(s, EINVAL, "Handshake phase not finished yet");
126
127         if(type >= SPTPS_HANDSHAKE)
128                 return error(s, EINVAL, "Invalid application record type");
129
130         return send_record_priv(s, type, data, len);
131 }
132
133 // Send a Key EXchange record, containing a random nonce and an ECDHE public key.
134 static bool send_kex(sptps_t *s) {
135         size_t keylen = ECDH_SIZE;
136
137         // Make room for our KEX message, which we will keep around since send_sig() needs it.
138         if(s->mykex)
139                 abort();
140         s->mykex = realloc(s->mykex, 1 + 32 + keylen);
141         if(!s->mykex)
142                 return error(s, errno, strerror(errno));
143
144         // Set version byte to zero.
145         s->mykex[0] = SPTPS_VERSION;
146
147         // Create a random nonce.
148         randomize(s->mykex + 1, 32);
149
150         // Create a new ECDH public key.
151         if(!ecdh_generate_public(&s->ecdh, s->mykex + 1 + 32))
152                 return false;
153
154         return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, 1 + 32 + keylen);
155 }
156
157 // Send a SIGnature record, containing an ECDSA signature over both KEX records.
158 static bool send_sig(sptps_t *s) {
159         size_t keylen = ECDH_SIZE;
160         size_t siglen = ecdsa_size(&s->mykey);
161
162         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator, plus label
163         char msg[(1 + 32 + keylen) * 2 + 1 + s->labellen];
164         char sig[siglen];
165
166         msg[0] = s->initiator;
167         memcpy(msg + 1, s->mykex, 1 + 32 + keylen);
168         memcpy(msg + 1 + 33 + keylen, s->hiskex, 1 + 32 + keylen);
169         memcpy(msg + 1 + 2 * (33 + keylen), s->label, s->labellen);
170
171         // Sign the result.
172         if(!ecdsa_sign(&s->mykey, msg, sizeof msg, sig))
173                 return false;
174
175         // Send the SIG exchange record.
176         return send_record_priv(s, SPTPS_HANDSHAKE, sig, sizeof sig);
177 }
178
179 // Generate key material from the shared secret created from the ECDHE key exchange.
180 static bool generate_key_material(sptps_t *s, const char *shared, size_t len) {
181         // Initialise cipher and digest structures if necessary
182         if(!s->outstate) {
183                 bool result
184                         =  cipher_open_by_name(&s->incipher, "aes-256-ecb")
185                         && cipher_open_by_name(&s->outcipher, "aes-256-ecb")
186                         && digest_open_by_name(&s->indigest, "sha256", 16)
187                         && digest_open_by_name(&s->outdigest, "sha256", 16);
188                 if(!result)
189                         return false;
190         }
191
192         // Allocate memory for key material
193         size_t keylen = digest_keylength(&s->indigest) + digest_keylength(&s->outdigest) + cipher_keylength(&s->incipher) + cipher_keylength(&s->outcipher);
194
195         s->key = realloc(s->key, keylen);
196         if(!s->key)
197                 return error(s, errno, strerror(errno));
198
199         // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce
200         char seed[s->labellen + 64 + 13];
201         strcpy(seed, "key expansion");
202         if(s->initiator) {
203                 memcpy(seed + 13, s->mykex + 1, 32);
204                 memcpy(seed + 45, s->hiskex + 1, 32);
205         } else {
206                 memcpy(seed + 13, s->hiskex + 1, 32);
207                 memcpy(seed + 45, s->mykex + 1, 32);
208         }
209         memcpy(seed + 78, s->label, s->labellen);
210
211         // Use PRF to generate the key material
212         if(!prf(shared, len, seed, s->labellen + 64 + 13, s->key, keylen))
213                 return false;
214
215         return true;
216 }
217
218 // Send an ACKnowledgement record.
219 static bool send_ack(sptps_t *s) {
220         return send_record_priv(s, SPTPS_HANDSHAKE, "", 0);
221 }
222
223 // Receive an ACKnowledgement record.
224 static bool receive_ack(sptps_t *s, const char *data, uint16_t len) {
225         if(len)
226                 return error(s, EIO, "Invalid ACK record length");
227
228         if(s->initiator) {
229                 bool result
230                         = cipher_set_counter_key(&s->incipher, s->key)
231                         && digest_set_key(&s->indigest, s->key + cipher_keylength(&s->incipher), digest_keylength(&s->indigest));
232                 if(!result)
233                         return false;
234         } else {
235                 bool result
236                         = cipher_set_counter_key(&s->incipher, s->key + cipher_keylength(&s->outcipher) + digest_keylength(&s->outdigest))
237                         && digest_set_key(&s->indigest, s->key + cipher_keylength(&s->outcipher) + digest_keylength(&s->outdigest) + cipher_keylength(&s->incipher), digest_keylength(&s->indigest));
238                 if(!result)
239                         return false;
240         }
241
242         free(s->key);
243         s->key = NULL;
244         s->instate = true;
245
246         return true;
247 }
248
249 // Receive a Key EXchange record, respond by sending a SIG record.
250 static bool receive_kex(sptps_t *s, const char *data, uint16_t len) {
251         // Verify length of the HELLO record
252         if(len != 1 + 32 + ECDH_SIZE)
253                 return error(s, EIO, "Invalid KEX record length");
254
255         // Ignore version number for now.
256
257         // Make a copy of the KEX message, send_sig() and receive_sig() need it
258         if(s->hiskex)
259                 abort();
260         s->hiskex = realloc(s->hiskex, len);
261         if(!s->hiskex)
262                 return error(s, errno, strerror(errno));
263
264         memcpy(s->hiskex, data, len);
265
266         return send_sig(s);
267 }
268
269 // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys.
270 static bool receive_sig(sptps_t *s, const char *data, uint16_t len) {
271         size_t keylen = ECDH_SIZE;
272         size_t siglen = ecdsa_size(&s->hiskey);
273
274         // Verify length of KEX record.
275         if(len != siglen)
276                 return error(s, EIO, "Invalid KEX record length");
277
278         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
279         char msg[(1 + 32 + keylen) * 2 + 1 + s->labellen];
280
281         msg[0] = !s->initiator;
282         memcpy(msg + 1, s->hiskex, 1 + 32 + keylen);
283         memcpy(msg + 1 + 33 + keylen, s->mykex, 1 + 32 + keylen);
284         memcpy(msg + 1 + 2 * (33 + keylen), s->label, s->labellen);
285
286         // Verify signature.
287         if(!ecdsa_verify(&s->hiskey, msg, sizeof msg, data))
288                 return false;
289
290         // Compute shared secret.
291         char shared[ECDH_SHARED_SIZE];
292         if(!ecdh_compute_shared(&s->ecdh, s->hiskex + 1 + 32, shared))
293                 return false;
294
295         // Generate key material from shared secret.
296         if(!generate_key_material(s, shared, sizeof shared))
297                 return false;
298
299         free(s->mykex);
300         free(s->hiskex);
301
302         s->mykex = NULL;
303         s->hiskex = NULL;
304
305         // Send cipher change record
306         if(s->outstate && !send_ack(s))
307                 return false;
308
309         // TODO: only set new keys after ACK has been set/received
310         if(s->initiator) {
311                 bool result
312                         = cipher_set_counter_key(&s->outcipher, s->key + cipher_keylength(&s->incipher) + digest_keylength(&s->indigest))
313                         && digest_set_key(&s->outdigest, s->key + cipher_keylength(&s->incipher) + digest_keylength(&s->indigest) + cipher_keylength(&s->outcipher), digest_keylength(&s->outdigest));
314                 if(!result)
315                         return false;
316         } else {
317                 bool result
318                         =  cipher_set_counter_key(&s->outcipher, s->key)
319                         && digest_set_key(&s->outdigest, s->key + cipher_keylength(&s->outcipher), digest_keylength(&s->outdigest));
320                 if(!result)
321                         return false;
322         }
323
324         return true;
325 }
326
327 // Force another Key EXchange (for testing purposes).
328 bool sptps_force_kex(sptps_t *s) {
329         if(!s->outstate || s->state != SPTPS_SECONDARY_KEX)
330                 return error(s, EINVAL, "Cannot force KEX in current state");
331
332         s->state = SPTPS_KEX;
333         return send_kex(s);
334 }
335
336 // Receive a handshake record.
337 static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) {
338         // Only a few states to deal with handshaking.
339         fprintf(stderr, "Received handshake message, current state %d\n", s->state);
340         switch(s->state) {
341                 case SPTPS_SECONDARY_KEX:
342                         // We receive a secondary KEX request, first respond by sending our own.
343                         if(!send_kex(s))
344                                 return false;
345                 case SPTPS_KEX:
346                         // We have sent our KEX request, we expect our peer to sent one as well.
347                         if(!receive_kex(s, data, len))
348                                 return false;
349                         s->state = SPTPS_SIG;
350                         return true;
351                 case SPTPS_SIG:
352                         // If we already sent our secondary public ECDH key, we expect the peer to send his.
353                         if(!receive_sig(s, data, len))
354                                 return false;
355                         if(s->outstate)
356                                 s->state = SPTPS_ACK;
357                         else {
358                                 s->outstate = true;
359                                 if(!receive_ack(s, NULL, 0))
360                                         return false;
361                                 s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
362                                 s->state = SPTPS_SECONDARY_KEX;
363                         }
364
365                         return true;
366                 case SPTPS_ACK:
367                         // We expect a handshake message to indicate transition to the new keys.
368                         if(!receive_ack(s, data, len))
369                                 return false;
370                         s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
371                         s->state = SPTPS_SECONDARY_KEX;
372                         return true;
373                 // TODO: split ACK into a VERify and ACK?
374                 default:
375                         return error(s, EIO, "Invalid session state");
376         }
377 }
378
379 // Receive incoming data, datagram version.
380 static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len) {
381         if(len < (s->instate ? 21 : 5))
382                 return error(s, EIO, "Received short packet");
383
384         uint32_t seqno;
385         memcpy(&seqno, data, 4);
386         seqno = ntohl(seqno);
387
388         if(!s->instate) {
389                 if(seqno != s->inseqno) {
390                         fprintf(stderr, "Received invalid packet seqno: %d != %d\n", seqno, s->inseqno);
391                         return error(s, EIO, "Invalid packet seqno");
392                 }
393
394                 s->inseqno = seqno + 1;
395
396                 uint8_t type = data[4];
397
398                 if(type != SPTPS_HANDSHAKE)
399                         return error(s, EIO, "Application record received before handshake finished");
400
401                 return receive_handshake(s, data + 5, len - 5);
402         }
403
404         if(seqno < s->inseqno) {
405                 fprintf(stderr, "Received late or replayed packet: %d < %d\n", seqno, s->inseqno);
406                 return true;
407         }
408
409         if(seqno > s->inseqno)
410                 fprintf(stderr, "Missed %d packets\n", seqno - s->inseqno);
411
412         s->inseqno = seqno + 1;
413
414         uint16_t netlen = htons(len - 21);
415
416         char buffer[len + 23];
417
418         memcpy(buffer, &netlen, 2);
419         memcpy(buffer + 2, data, len);
420
421         memcpy(&seqno, buffer + 2, 4);
422
423         // Check HMAC and decrypt.
424         if(!digest_verify(&s->indigest, buffer, len - 14, buffer + len - 14))
425                 return error(s, EIO, "Invalid HMAC");
426
427         cipher_set_counter(&s->incipher, &seqno, sizeof seqno);
428         if(!cipher_counter_xor(&s->incipher, buffer + 6, len - 4, buffer + 6))
429                 return false;
430
431         // Append a NULL byte for safety.
432         buffer[len - 14] = 0;
433
434         uint8_t type = buffer[6];
435
436         if(type < SPTPS_HANDSHAKE) {
437                 if(!s->instate)
438                         return error(s, EIO, "Application record received before handshake finished");
439                 if(!s->receive_record(s->handle, type, buffer + 7, len - 21))
440                         return false;
441         } else if(type == SPTPS_HANDSHAKE) {
442                 if(!receive_handshake(s, buffer + 7, len - 21))
443                         return false;
444         } else {
445                 return error(s, EIO, "Invalid record type");
446         }
447
448         return true;
449 }
450 // Receive incoming data. Check if it contains a complete record, if so, handle it.
451 bool sptps_receive_data(sptps_t *s, const char *data, size_t len) {
452         if(s->datagram)
453                 return sptps_receive_data_datagram(s, data, len);
454
455         while(len) {
456                 // First read the 2 length bytes.
457                 if(s->buflen < 6) {
458                         size_t toread = 6 - s->buflen;
459                         if(toread > len)
460                                 toread = len;
461
462                         memcpy(s->inbuf + s->buflen, data, toread);
463
464                         s->buflen += toread;
465                         len -= toread;
466                         data += toread;
467                 
468                         // Exit early if we don't have the full length.
469                         if(s->buflen < 6)
470                                 return true;
471
472                         // Decrypt the length bytes
473
474                         if(s->instate) {
475                                 if(!cipher_counter_xor(&s->incipher, s->inbuf + 4, 2, &s->reclen))
476                                         return false;
477                         } else {
478                                 memcpy(&s->reclen, s->inbuf + 4, 2);
479                         }
480
481                         s->reclen = ntohs(s->reclen);
482
483                         // If we have the length bytes, ensure our buffer can hold the whole request.
484                         s->inbuf = realloc(s->inbuf, s->reclen + 23UL);
485                         if(!s->inbuf)
486                                 return error(s, errno, strerror(errno));
487
488                         // Add sequence number.
489                         uint32_t seqno = htonl(s->inseqno++);
490                         memcpy(s->inbuf, &seqno, 4);
491
492                         // Exit early if we have no more data to process.
493                         if(!len)
494                                 return true;
495                 }
496
497                 // Read up to the end of the record.
498                 size_t toread = s->reclen + (s->instate ? 23UL : 7UL) - s->buflen;
499                 if(toread > len)
500                         toread = len;
501
502                 memcpy(s->inbuf + s->buflen, data, toread);
503                 s->buflen += toread;
504                 len -= toread;
505                 data += toread;
506
507                 // If we don't have a whole record, exit.
508                 if(s->buflen < s->reclen + (s->instate ? 23UL : 7UL))
509                         return true;
510
511                 // Check HMAC and decrypt.
512                 if(s->instate) {
513                         if(!digest_verify(&s->indigest, s->inbuf, s->reclen + 7UL, s->inbuf + s->reclen + 7UL))
514                                 return error(s, EIO, "Invalid HMAC");
515
516                         if(!cipher_counter_xor(&s->incipher, s->inbuf + 6UL, s->reclen + 1UL, s->inbuf + 6UL))
517                                 return false;
518                 }
519
520                 // Append a NULL byte for safety.
521                 s->inbuf[s->reclen + 7UL] = 0;
522
523                 uint8_t type = s->inbuf[6];
524
525                 if(type < SPTPS_HANDSHAKE) {
526                         if(!s->instate)
527                                 return error(s, EIO, "Application record received before handshake finished");
528                         if(!s->receive_record(s->handle, type, s->inbuf + 7, s->reclen))
529                                 return false;
530                 } else if(type == SPTPS_HANDSHAKE) {
531                         if(!receive_handshake(s, s->inbuf + 7, s->reclen))
532                                 return false;
533                 } else {
534                         return error(s, EIO, "Invalid record type");
535                 }
536
537                 s->buflen = 4;
538         }
539
540         return true;
541 }
542
543 // Start a SPTPS session.
544 bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
545         // Initialise struct sptps
546         memset(s, 0, sizeof *s);
547
548         s->handle = handle;
549         s->initiator = initiator;
550         s->datagram = datagram;
551         s->mykey = mykey;
552         s->hiskey = hiskey;
553
554         s->label = malloc(labellen);
555         if(!s->label)
556                 return error(s, errno, strerror(errno));
557
558         if(!datagram) {
559                 s->inbuf = malloc(7);
560                 if(!s->inbuf)
561                         return error(s, errno, strerror(errno));
562                 s->buflen = 4;
563                 memset(s->inbuf, 0, 4);
564         }
565
566         memcpy(s->label, label, labellen);
567         s->labellen = labellen;
568
569         s->send_data = send_data;
570         s->receive_record = receive_record;
571
572         // Do first KEX immediately
573         s->state = SPTPS_KEX;
574         return send_kex(s);
575 }
576
577 // Stop a SPTPS session.
578 bool sptps_stop(sptps_t *s) {
579         // Clean up any resources.
580         ecdh_free(&s->ecdh);
581         free(s->inbuf);
582         s->inbuf = NULL;
583         free(s->mykex);
584         s->mykex = NULL;
585         free(s->hiskex);
586         s->hiskex = NULL;
587         free(s->key);
588         s->key = NULL;
589         free(s->label);
590         s->label = NULL;
591         return true;
592 }