Merge branch 'master' into 1.1
[tinc] / src / sptps.c
1 /*
2     sptps.c -- Simple Peer-to-Peer Security
3     Copyright (C) 2011 Guus Sliepen <guus@tinc-vpn.org>,
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #include "cipher.h"
23 #include "crypto.h"
24 #include "digest.h"
25 #include "ecdh.h"
26 #include "ecdsa.h"
27 #include "prf.h"
28 #include "sptps.h"
29
30 /*
31    Nonce MUST be exchanged first (done)
32    Signatures MUST be done over both nonces, to guarantee the signature is fresh
33    Otherwise: if ECDHE key of one side is compromised, it can be reused!
34
35    Add explicit tag to beginning of structure to distinguish the client and server when signing. (done)
36
37    Sign all handshake messages up to ECDHE kex with long-term public keys. (done)
38
39    HMACed KEX finished message to prevent downgrade attacks and prove you have the right key material (done by virtue of ECDSA over the whole ECDHE exchange?)
40
41    Explicit close message needs to be added.
42
43    Maybe do add some alert messages to give helpful error messages? Not more than TLS sends.
44
45    Use counter mode instead of OFB. (done)
46
47    Make sure ECC operations are fixed time (aka prevent side-channel attacks).
48 */
49
50 // Log an error message.
51 static bool error(sptps_t *s, int s_errno, const char *msg) {
52         fprintf(stderr, "SPTPS error: %s\n", msg);
53         errno = s_errno;
54         return false;
55 }
56
57 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
58 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
59         char buffer[len + 23UL];
60
61         // Create header with sequence number, length and record type
62         uint32_t seqno = htonl(s->outseqno++);
63         uint16_t netlen = htons(len);
64
65         memcpy(buffer, &netlen, 2);
66         memcpy(buffer + 2, &seqno, 4);
67         buffer[6] = type;
68
69         // Add plaintext (TODO: avoid unnecessary copy)
70         memcpy(buffer + 7, data, len);
71
72         if(s->outstate) {
73                 // If first handshake has finished, encrypt and HMAC
74                 cipher_set_counter(&s->outcipher, &seqno, sizeof seqno);
75                 if(!cipher_counter_xor(&s->outcipher, buffer + 6, len + 1UL, buffer + 6))
76                         return false;
77
78                 if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
79                         return false;
80
81                 return s->send_data(s->handle, buffer + 2, len + 21UL);
82         } else {
83                 // Otherwise send as plaintext
84                 return s->send_data(s->handle, buffer + 2, len + 5UL);
85         }
86 }
87 // Send a record (private version, accepts all record types, handles encryption and authentication).
88 static bool send_record_priv(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
89         if(s->datagram)
90                 return send_record_priv_datagram(s, type, data, len);
91
92         char buffer[len + 23UL];
93
94         // Create header with sequence number, length and record type
95         uint32_t seqno = htonl(s->outseqno++);
96         uint16_t netlen = htons(len);
97
98         memcpy(buffer, &seqno, 4);
99         memcpy(buffer + 4, &netlen, 2);
100         buffer[6] = type;
101
102         // Add plaintext (TODO: avoid unnecessary copy)
103         memcpy(buffer + 7, data, len);
104
105         if(s->outstate) {
106                 // If first handshake has finished, encrypt and HMAC
107                 if(!cipher_counter_xor(&s->outcipher, buffer + 4, len + 3UL, buffer + 4))
108                         return false;
109
110                 if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
111                         return false;
112
113                 return s->send_data(s->handle, buffer + 4, len + 19UL);
114         } else {
115                 // Otherwise send as plaintext
116                 return s->send_data(s->handle, buffer + 4, len + 3UL);
117         }
118 }
119
120 // Send an application record.
121 bool sptps_send_record(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
122         // Sanity checks: application cannot send data before handshake is finished,
123         // and only record types 0..127 are allowed.
124         if(!s->outstate)
125                 return error(s, EINVAL, "Handshake phase not finished yet");
126
127         if(type >= SPTPS_HANDSHAKE)
128                 return error(s, EINVAL, "Invalid application record type");
129
130         return send_record_priv(s, type, data, len);
131 }
132
133 // Send a Key EXchange record, containing a random nonce and an ECDHE public key.
134 static bool send_kex(sptps_t *s) {
135         size_t keylen = ECDH_SIZE;
136
137         // Make room for our KEX message, which we will keep around since send_sig() needs it.
138         if(s->mykex)
139                 abort();
140         s->mykex = realloc(s->mykex, 1 + 32 + keylen);
141         if(!s->mykex)
142                 return error(s, errno, strerror(errno));
143
144         // Set version byte to zero.
145         s->mykex[0] = SPTPS_VERSION;
146
147         // Create a random nonce.
148         randomize(s->mykex + 1, 32);
149
150         // Create a new ECDH public key.
151         if(!ecdh_generate_public(&s->ecdh, s->mykex + 1 + 32))
152                 return false;
153
154         return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, 1 + 32 + keylen);
155 }
156
157 // Send a SIGnature record, containing an ECDSA signature over both KEX records.
158 static bool send_sig(sptps_t *s) {
159         size_t keylen = ECDH_SIZE;
160         size_t siglen = ecdsa_size(&s->mykey);
161
162         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator, plus label
163         char msg[(1 + 32 + keylen) * 2 + 1 + s->labellen];
164         char sig[siglen];
165
166         msg[0] = s->initiator;
167         memcpy(msg + 1, s->mykex, 1 + 32 + keylen);
168         memcpy(msg + 1 + 33 + keylen, s->hiskex, 1 + 32 + keylen);
169         memcpy(msg + 1 + 2 * (33 + keylen), s->label, s->labellen);
170
171         // Sign the result.
172         if(!ecdsa_sign(&s->mykey, msg, sizeof msg, sig))
173                 return false;
174
175         // Send the SIG exchange record.
176         return send_record_priv(s, SPTPS_HANDSHAKE, sig, sizeof sig);
177 }
178
179 // Generate key material from the shared secret created from the ECDHE key exchange.
180 static bool generate_key_material(sptps_t *s, const char *shared, size_t len) {
181         // Initialise cipher and digest structures if necessary
182         if(!s->outstate) {
183                 bool result
184                         =  cipher_open_by_name(&s->incipher, "aes-256-ecb")
185                         && cipher_open_by_name(&s->outcipher, "aes-256-ecb")
186                         && digest_open_by_name(&s->indigest, "sha256", 16)
187                         && digest_open_by_name(&s->outdigest, "sha256", 16);
188                 if(!result)
189                         return false;
190         }
191
192         // Allocate memory for key material
193         size_t keylen = digest_keylength(&s->indigest) + digest_keylength(&s->outdigest) + cipher_keylength(&s->incipher) + cipher_keylength(&s->outcipher);
194
195         s->key = realloc(s->key, keylen);
196         if(!s->key)
197                 return error(s, errno, strerror(errno));
198
199         // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce
200         char seed[s->labellen + 64 + 13];
201         strcpy(seed, "key expansion");
202         if(s->initiator) {
203                 memcpy(seed + 13, s->mykex + 1, 32);
204                 memcpy(seed + 45, s->hiskex + 1, 32);
205         } else {
206                 memcpy(seed + 13, s->hiskex + 1, 32);
207                 memcpy(seed + 45, s->mykex + 1, 32);
208         }
209         memcpy(seed + 78, s->label, s->labellen);
210
211         // Use PRF to generate the key material
212         if(!prf(shared, len, seed, s->labellen + 64 + 13, s->key, keylen))
213                 return false;
214
215         return true;
216 }
217
218 // Send an ACKnowledgement record.
219 static bool send_ack(sptps_t *s) {
220         return send_record_priv(s, SPTPS_HANDSHAKE, "", 0);
221 }
222
223 // Receive an ACKnowledgement record.
224 static bool receive_ack(sptps_t *s, const char *data, uint16_t len) {
225         if(len)
226                 return error(s, EIO, "Invalid ACK record length");
227
228         if(s->initiator) {
229                 bool result
230                         = cipher_set_counter_key(&s->incipher, s->key)
231                         && digest_set_key(&s->indigest, s->key + cipher_keylength(&s->incipher), digest_keylength(&s->indigest));
232                 if(!result)
233                         return false;
234         } else {
235                 bool result
236                         = cipher_set_counter_key(&s->incipher, s->key + cipher_keylength(&s->outcipher) + digest_keylength(&s->outdigest))
237                         && digest_set_key(&s->indigest, s->key + cipher_keylength(&s->outcipher) + digest_keylength(&s->outdigest) + cipher_keylength(&s->incipher), digest_keylength(&s->indigest));
238                 if(!result)
239                         return false;
240         }
241
242         free(s->key);
243         s->key = NULL;
244         s->instate = true;
245
246         return true;
247 }
248
249 // Receive a Key EXchange record, respond by sending a SIG record.
250 static bool receive_kex(sptps_t *s, const char *data, uint16_t len) {
251         // Verify length of the HELLO record
252         if(len != 1 + 32 + ECDH_SIZE)
253                 return error(s, EIO, "Invalid KEX record length");
254
255         // Ignore version number for now.
256
257         // Make a copy of the KEX message, send_sig() and receive_sig() need it
258         if(s->hiskex)
259                 abort();
260         s->hiskex = realloc(s->hiskex, len);
261         if(!s->hiskex)
262                 return error(s, errno, strerror(errno));
263
264         memcpy(s->hiskex, data, len);
265
266         return send_sig(s);
267 }
268
269 // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys.
270 static bool receive_sig(sptps_t *s, const char *data, uint16_t len) {
271         size_t keylen = ECDH_SIZE;
272         size_t siglen = ecdsa_size(&s->hiskey);
273
274         // Verify length of KEX record.
275         if(len != siglen)
276                 return error(s, EIO, "Invalid KEX record length");
277
278         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
279         char msg[(1 + 32 + keylen) * 2 + 1 + s->labellen];
280
281         msg[0] = !s->initiator;
282         memcpy(msg + 1, s->hiskex, 1 + 32 + keylen);
283         memcpy(msg + 1 + 33 + keylen, s->mykex, 1 + 32 + keylen);
284         memcpy(msg + 1 + 2 * (33 + keylen), s->label, s->labellen);
285
286         // Verify signature.
287         if(!ecdsa_verify(&s->hiskey, msg, sizeof msg, data))
288                 return false;
289
290         // Compute shared secret.
291         char shared[ECDH_SHARED_SIZE];
292         if(!ecdh_compute_shared(&s->ecdh, s->hiskex + 1 + 32, shared))
293                 return false;
294
295         // Generate key material from shared secret.
296         if(!generate_key_material(s, shared, sizeof shared))
297                 return false;
298
299         free(s->mykex);
300         free(s->hiskex);
301
302         s->mykex = NULL;
303         s->hiskex = NULL;
304
305         // Send cipher change record
306         if(s->outstate && !send_ack(s))
307                 return false;
308
309         // TODO: only set new keys after ACK has been set/received
310         if(s->initiator) {
311                 bool result
312                         = cipher_set_counter_key(&s->outcipher, s->key + cipher_keylength(&s->incipher) + digest_keylength(&s->indigest))
313                         && digest_set_key(&s->outdigest, s->key + cipher_keylength(&s->incipher) + digest_keylength(&s->indigest) + cipher_keylength(&s->outcipher), digest_keylength(&s->outdigest));
314                 if(!result)
315                         return false;
316         } else {
317                 bool result
318                         =  cipher_set_counter_key(&s->outcipher, s->key)
319                         && digest_set_key(&s->outdigest, s->key + cipher_keylength(&s->outcipher), digest_keylength(&s->outdigest));
320                 if(!result)
321                         return false;
322         }
323
324         return true;
325 }
326
327 // Force another Key EXchange (for testing purposes).
328 bool sptps_force_kex(sptps_t *s) {
329         if(!s->outstate || s->state != SPTPS_SECONDARY_KEX)
330                 return error(s, EINVAL, "Cannot force KEX in current state");
331
332         s->state = SPTPS_KEX;
333         return send_kex(s);
334 }
335
336 // Receive a handshake record.
337 static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) {
338         // Only a few states to deal with handshaking.
339         fprintf(stderr, "Received handshake message, current state %d\n", s->state);
340         switch(s->state) {
341                 case SPTPS_SECONDARY_KEX:
342                         // We receive a secondary KEX request, first respond by sending our own.
343                         if(!send_kex(s))
344                                 return false;
345                 case SPTPS_KEX:
346                         // We have sent our KEX request, we expect our peer to sent one as well.
347                         if(!receive_kex(s, data, len))
348                                 return false;
349                         s->state = SPTPS_SIG;
350                         return true;
351                 case SPTPS_SIG:
352                         // If we already sent our secondary public ECDH key, we expect the peer to send his.
353                         if(!receive_sig(s, data, len))
354                                 return false;
355                         if(s->outstate)
356                                 s->state = SPTPS_ACK;
357                         else {
358                                 s->outstate = true;
359                                 if(!receive_ack(s, NULL, 0))
360                                         return false;
361                                 s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
362                                 s->state = SPTPS_SECONDARY_KEX;
363                         }
364
365                         return true;
366                 case SPTPS_ACK:
367                         // We expect a handshake message to indicate transition to the new keys.
368                         if(!receive_ack(s, data, len))
369                                 return false;
370                         s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
371                         s->state = SPTPS_SECONDARY_KEX;
372                         return true;
373                 // TODO: split ACK into a VERify and ACK?
374                 default:
375                         return error(s, EIO, "Invalid session state");
376         }
377 }
378
379 // Receive incoming data, datagram version.
380 static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len) {
381         if(len < (s->instate ? 21 : 5))
382                 return error(s, EIO, "Received short packet");
383
384         uint32_t seqno;
385         memcpy(&seqno, data, 4);
386         seqno = ntohl(seqno);
387
388         if(!s->instate) {
389                 if(seqno != s->inseqno) {
390                         fprintf(stderr, "Received invalid packet seqno: %d != %d\n", seqno, s->inseqno);
391                         return error(s, EIO, "Invalid packet seqno");
392                 }
393
394                 s->inseqno = seqno + 1;
395
396                 uint8_t type = data[4];
397
398                 if(type != SPTPS_HANDSHAKE)
399                         return error(s, EIO, "Application record received before handshake finished");
400
401                 return receive_handshake(s, data + 5, len - 5);
402         }
403
404         if(seqno < s->inseqno) {
405                 fprintf(stderr, "Received late or replayed packet: %d < %d\n", seqno, s->inseqno);
406                 return true;
407         }
408
409         if(seqno > s->inseqno)
410                 fprintf(stderr, "Missed %d packets\n", seqno - s->inseqno);
411
412         s->inseqno = seqno + 1;
413
414         uint16_t netlen = htons(len - 21);
415
416         char buffer[len + 23];
417
418         memcpy(buffer, &netlen, 2);
419         memcpy(buffer + 2, data, len);
420
421         memcpy(&seqno, buffer + 2, 4);
422
423         // Check HMAC and decrypt.
424         if(!digest_verify(&s->indigest, buffer, len - 14, buffer + len - 14))
425                 return error(s, EIO, "Invalid HMAC");
426
427         cipher_set_counter(&s->incipher, &seqno, sizeof seqno);
428         if(!cipher_counter_xor(&s->incipher, buffer + 6, len - 4, buffer + 6))
429                 return false;
430
431         // Append a NULL byte for safety.
432         buffer[len - 14] = 0;
433
434         uint8_t type = buffer[6];
435
436         if(type < SPTPS_HANDSHAKE) {
437                 if(!s->instate)
438                         return error(s, EIO, "Application record received before handshake finished");
439                 if(!s->receive_record(s->handle, type, buffer + 7, len - 21))
440                         return false;
441         } else {
442                 return error(s, EIO, "Invalid record type");
443         }
444
445         return true;
446 }
447 // Receive incoming data. Check if it contains a complete record, if so, handle it.
448 bool sptps_receive_data(sptps_t *s, const char *data, size_t len) {
449         if(s->datagram)
450                 return sptps_receive_data_datagram(s, data, len);
451
452         while(len) {
453                 // First read the 2 length bytes.
454                 if(s->buflen < 6) {
455                         size_t toread = 6 - s->buflen;
456                         if(toread > len)
457                                 toread = len;
458
459                         memcpy(s->inbuf + s->buflen, data, toread);
460
461                         s->buflen += toread;
462                         len -= toread;
463                         data += toread;
464                 
465                         // Exit early if we don't have the full length.
466                         if(s->buflen < 6)
467                                 return true;
468
469                         // Decrypt the length bytes
470
471                         if(s->instate) {
472                                 if(!cipher_counter_xor(&s->incipher, s->inbuf + 4, 2, &s->reclen))
473                                         return false;
474                         } else {
475                                 memcpy(&s->reclen, s->inbuf + 4, 2);
476                         }
477
478                         s->reclen = ntohs(s->reclen);
479
480                         // If we have the length bytes, ensure our buffer can hold the whole request.
481                         s->inbuf = realloc(s->inbuf, s->reclen + 23UL);
482                         if(!s->inbuf)
483                                 return error(s, errno, strerror(errno));
484
485                         // Add sequence number.
486                         uint32_t seqno = htonl(s->inseqno++);
487                         memcpy(s->inbuf, &seqno, 4);
488
489                         // Exit early if we have no more data to process.
490                         if(!len)
491                                 return true;
492                 }
493
494                 // Read up to the end of the record.
495                 size_t toread = s->reclen + (s->instate ? 23UL : 7UL) - s->buflen;
496                 if(toread > len)
497                         toread = len;
498
499                 memcpy(s->inbuf + s->buflen, data, toread);
500                 s->buflen += toread;
501                 len -= toread;
502                 data += toread;
503
504                 // If we don't have a whole record, exit.
505                 if(s->buflen < s->reclen + (s->instate ? 23UL : 7UL))
506                         return true;
507
508                 // Check HMAC and decrypt.
509                 if(s->instate) {
510                         if(!digest_verify(&s->indigest, s->inbuf, s->reclen + 7UL, s->inbuf + s->reclen + 7UL))
511                                 return error(s, EIO, "Invalid HMAC");
512
513                         if(!cipher_counter_xor(&s->incipher, s->inbuf + 6UL, s->reclen + 1UL, s->inbuf + 6UL))
514                                 return false;
515                 }
516
517                 // Append a NULL byte for safety.
518                 s->inbuf[s->reclen + 7UL] = 0;
519
520                 uint8_t type = s->inbuf[6];
521
522                 if(type < SPTPS_HANDSHAKE) {
523                         if(!s->instate)
524                                 return error(s, EIO, "Application record received before handshake finished");
525                         if(!s->receive_record(s->handle, type, s->inbuf + 7, s->reclen))
526                                 return false;
527                 } else if(type == SPTPS_HANDSHAKE) {
528                         if(!receive_handshake(s, s->inbuf + 7, s->reclen))
529                                 return false;
530                 } else {
531                         return error(s, EIO, "Invalid record type");
532                 }
533
534                 s->buflen = 4;
535         }
536
537         return true;
538 }
539
540 // Start a SPTPS session.
541 bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
542         // Initialise struct sptps
543         memset(s, 0, sizeof *s);
544
545         s->handle = handle;
546         s->initiator = initiator;
547         s->datagram = datagram;
548         s->mykey = mykey;
549         s->hiskey = hiskey;
550
551         s->label = malloc(labellen);
552         if(!s->label)
553                 return error(s, errno, strerror(errno));
554
555         if(!datagram) {
556                 s->inbuf = malloc(7);
557                 if(!s->inbuf)
558                         return error(s, errno, strerror(errno));
559                 s->buflen = 4;
560                 memset(s->inbuf, 0, 4);
561         }
562
563         memcpy(s->label, label, labellen);
564         s->labellen = labellen;
565
566         s->send_data = send_data;
567         s->receive_record = receive_record;
568
569         // Do first KEX immediately
570         s->state = SPTPS_KEX;
571         return send_kex(s);
572 }
573
574 // Stop a SPTPS session.
575 bool sptps_stop(sptps_t *s) {
576         // Clean up any resources.
577         ecdh_free(&s->ecdh);
578         free(s->inbuf);
579         s->inbuf = NULL;
580         free(s->mykex);
581         s->mykex = NULL;
582         free(s->hiskex);
583         s->hiskex = NULL;
584         free(s->key);
585         s->key = NULL;
586         free(s->label);
587         s->label = NULL;
588         return true;
589 }