Don't send an ACK message after the first key exchange in the SPTPS protocol.
[tinc] / src / sptps.c
1 /*
2     sptps.c -- Simple Peer-to-Peer Security
3     Copyright (C) 2011 Guus Sliepen <guus@tinc-vpn.org>,
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #include "cipher.h"
23 #include "crypto.h"
24 #include "digest.h"
25 #include "ecdh.h"
26 #include "ecdsa.h"
27 #include "prf.h"
28 #include "sptps.h"
29
30 /*
31    Nonce MUST be exchanged first (done)
32    Signatures MUST be done over both nonces, to guarantee the signature is fresh
33    Otherwise: if ECDHE key of one side is compromised, it can be reused!
34
35    Add explicit tag to beginning of structure to distinguish the client and server when signing. (done)
36
37    Sign all handshake messages up to ECDHE kex with long-term public keys. (done)
38
39    HMACed KEX finished message to prevent downgrade attacks and prove you have the right key material (done by virtue of ECDSA over the whole ECDHE exchange?)
40
41    Explicit close message needs to be added.
42
43    Maybe do add some alert messages to give helpful error messages? Not more than TLS sends.
44
45    Use counter mode instead of OFB. (done)
46
47    Make sure ECC operations are fixed time (aka prevent side-channel attacks).
48 */
49
50 // Log an error message.
51 static bool error(sptps_t *s, int s_errno, const char *msg) {
52         fprintf(stderr, "SPTPS error: %s\n", msg);
53         errno = s_errno;
54         return false;
55 }
56
57 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
58 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
59         char buffer[len + 23UL];
60
61         // Create header with sequence number, length and record type
62         uint32_t seqno = htonl(s->outseqno++);
63         uint16_t netlen = htons(len);
64
65         memcpy(buffer, &netlen, 2);
66         memcpy(buffer + 2, &seqno, 4);
67         buffer[6] = type;
68
69         // Add plaintext (TODO: avoid unnecessary copy)
70         memcpy(buffer + 7, data, len);
71
72         if(s->outstate) {
73                 // If first handshake has finished, encrypt and HMAC
74                 cipher_set_counter(&s->outcipher, &seqno, sizeof seqno);
75                 if(!cipher_counter_xor(&s->outcipher, buffer + 6, len + 1UL, buffer + 6))
76                         return false;
77
78                 if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
79                         return false;
80
81                 return s->send_data(s->handle, buffer + 2, len + 21UL);
82         } else {
83                 // Otherwise send as plaintext
84                 return s->send_data(s->handle, buffer + 2, len + 5UL);
85         }
86 }
87 // Send a record (private version, accepts all record types, handles encryption and authentication).
88 static bool send_record_priv(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
89         if(s->datagram)
90                 return send_record_priv_datagram(s, type, data, len);
91
92         char buffer[len + 23UL];
93
94         // Create header with sequence number, length and record type
95         uint32_t seqno = htonl(s->outseqno++);
96         uint16_t netlen = htons(len);
97
98         memcpy(buffer, &seqno, 4);
99         memcpy(buffer + 4, &netlen, 2);
100         buffer[6] = type;
101
102         // Add plaintext (TODO: avoid unnecessary copy)
103         memcpy(buffer + 7, data, len);
104
105         if(s->outstate) {
106                 // If first handshake has finished, encrypt and HMAC
107                 if(!cipher_counter_xor(&s->outcipher, buffer + 4, len + 3UL, buffer + 4))
108                         return false;
109
110                 if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
111                         return false;
112
113                 return s->send_data(s->handle, buffer + 4, len + 19UL);
114         } else {
115                 // Otherwise send as plaintext
116                 return s->send_data(s->handle, buffer + 4, len + 3UL);
117         }
118 }
119
120 // Send an application record.
121 bool sptps_send_record(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
122         // Sanity checks: application cannot send data before handshake is finished,
123         // and only record types 0..127 are allowed.
124         if(!s->outstate)
125                 return error(s, EINVAL, "Handshake phase not finished yet");
126
127         if(type >= SPTPS_HANDSHAKE)
128                 return error(s, EINVAL, "Invalid application record type");
129
130         return send_record_priv(s, type, data, len);
131 }
132
133 // Send a Key EXchange record, containing a random nonce and an ECDHE public key.
134 static bool send_kex(sptps_t *s) {
135         size_t keylen = ECDH_SIZE;
136
137         // Make room for our KEX message, which we will keep around since send_sig() needs it.
138         if(s->mykex)
139                 abort();
140         s->mykex = realloc(s->mykex, 1 + 32 + keylen);
141         if(!s->mykex)
142                 return error(s, errno, strerror(errno));
143
144         // Set version byte to zero.
145         s->mykex[0] = SPTPS_VERSION;
146
147         // Create a random nonce.
148         randomize(s->mykex + 1, 32);
149
150         // Create a new ECDH public key.
151         if(!ecdh_generate_public(&s->ecdh, s->mykex + 1 + 32))
152                 return false;
153
154         return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, 1 + 32 + keylen);
155 }
156
157 // Send a SIGnature record, containing an ECDSA signature over both KEX records.
158 static bool send_sig(sptps_t *s) {
159         size_t keylen = ECDH_SIZE;
160         size_t siglen = ecdsa_size(&s->mykey);
161
162         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
163         char msg[(1 + 32 + keylen) * 2 + 1];
164         char sig[siglen];
165
166         msg[0] = s->initiator;
167         memcpy(msg + 1, s->mykex, 1 + 32 + keylen);
168         memcpy(msg + 2 + 32 + keylen, s->hiskex, 1 + 32 + keylen);
169
170         // Sign the result.
171         if(!ecdsa_sign(&s->mykey, msg, sizeof msg, sig))
172                 return false;
173
174         // Send the SIG exchange record.
175         return send_record_priv(s, SPTPS_HANDSHAKE, sig, sizeof sig);
176 }
177
178 // Generate key material from the shared secret created from the ECDHE key exchange.
179 static bool generate_key_material(sptps_t *s, const char *shared, size_t len) {
180         // Initialise cipher and digest structures if necessary
181         if(!s->outstate) {
182                 bool result
183                         =  cipher_open_by_name(&s->incipher, "aes-256-ecb")
184                         && cipher_open_by_name(&s->outcipher, "aes-256-ecb")
185                         && digest_open_by_name(&s->indigest, "sha256", 16)
186                         && digest_open_by_name(&s->outdigest, "sha256", 16);
187                 if(!result)
188                         return false;
189         }
190
191         // Allocate memory for key material
192         size_t keylen = digest_keylength(&s->indigest) + digest_keylength(&s->outdigest) + cipher_keylength(&s->incipher) + cipher_keylength(&s->outcipher);
193
194         s->key = realloc(s->key, keylen);
195         if(!s->key)
196                 return error(s, errno, strerror(errno));
197
198         // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce
199         char seed[s->labellen + 64 + 13];
200         strcpy(seed, "key expansion");
201         if(s->initiator) {
202                 memcpy(seed + 13, s->mykex + 1, 32);
203                 memcpy(seed + 45, s->hiskex + 1, 32);
204         } else {
205                 memcpy(seed + 13, s->hiskex + 1, 32);
206                 memcpy(seed + 45, s->mykex + 1, 32);
207         }
208         memcpy(seed + 78, s->label, s->labellen);
209
210         // Use PRF to generate the key material
211         if(!prf(shared, len, seed, s->labellen + 64 + 13, s->key, keylen))
212                 return false;
213
214         return true;
215 }
216
217 // Send an ACKnowledgement record.
218 static bool send_ack(sptps_t *s) {
219         return send_record_priv(s, SPTPS_HANDSHAKE, "", 0);
220 }
221
222 // Receive an ACKnowledgement record.
223 static bool receive_ack(sptps_t *s, const char *data, uint16_t len) {
224         if(len)
225                 return error(s, EIO, "Invalid ACK record length");
226
227         if(s->initiator) {
228                 bool result
229                         = cipher_set_counter_key(&s->incipher, s->key)
230                         && digest_set_key(&s->indigest, s->key + cipher_keylength(&s->incipher), digest_keylength(&s->indigest));
231                 if(!result)
232                         return false;
233         } else {
234                 bool result
235                         = cipher_set_counter_key(&s->incipher, s->key + cipher_keylength(&s->outcipher) + digest_keylength(&s->outdigest))
236                         && digest_set_key(&s->indigest, s->key + cipher_keylength(&s->outcipher) + digest_keylength(&s->outdigest) + cipher_keylength(&s->incipher), digest_keylength(&s->indigest));
237                 if(!result)
238                         return false;
239         }
240
241         free(s->key);
242         s->key = NULL;
243         s->instate = true;
244
245         return true;
246 }
247
248 // Receive a Key EXchange record, respond by sending a SIG record.
249 static bool receive_kex(sptps_t *s, const char *data, uint16_t len) {
250         // Verify length of the HELLO record
251         if(len != 1 + 32 + ECDH_SIZE)
252                 return error(s, EIO, "Invalid KEX record length");
253
254         // Ignore version number for now.
255
256         // Make a copy of the KEX message, send_sig() and receive_sig() need it
257         if(s->hiskex)
258                 abort();
259         s->hiskex = realloc(s->hiskex, len);
260         if(!s->hiskex)
261                 return error(s, errno, strerror(errno));
262
263         memcpy(s->hiskex, data, len);
264
265         return send_sig(s);
266 }
267
268 // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys.
269 static bool receive_sig(sptps_t *s, const char *data, uint16_t len) {
270         size_t keylen = ECDH_SIZE;
271         size_t siglen = ecdsa_size(&s->hiskey);
272
273         // Verify length of KEX record.
274         if(len != siglen)
275                 return error(s, EIO, "Invalid KEX record length");
276
277         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
278         char msg[(1 + 32 + keylen) * 2 + 1];
279
280         msg[0] = !s->initiator;
281         memcpy(msg + 1, s->hiskex, 1 + 32 + keylen);
282         memcpy(msg + 2 + 32 + keylen, s->mykex, 1 + 32 + keylen);
283
284         // Verify signature.
285         if(!ecdsa_verify(&s->hiskey, msg, sizeof msg, data))
286                 return false;
287
288         // Compute shared secret.
289         char shared[ECDH_SHARED_SIZE];
290         if(!ecdh_compute_shared(&s->ecdh, s->hiskex + 1 + 32, shared))
291                 return false;
292
293         // Generate key material from shared secret.
294         if(!generate_key_material(s, shared, sizeof shared))
295                 return false;
296
297         free(s->mykex);
298         free(s->hiskex);
299
300         s->mykex = NULL;
301         s->hiskex = NULL;
302
303         // Send cipher change record
304         if(s->outstate && !send_ack(s))
305                 return false;
306
307         // TODO: only set new keys after ACK has been set/received
308         if(s->initiator) {
309                 bool result
310                         = cipher_set_counter_key(&s->outcipher, s->key + cipher_keylength(&s->incipher) + digest_keylength(&s->indigest))
311                         && digest_set_key(&s->outdigest, s->key + cipher_keylength(&s->incipher) + digest_keylength(&s->indigest) + cipher_keylength(&s->outcipher), digest_keylength(&s->outdigest));
312                 if(!result)
313                         return false;
314         } else {
315                 bool result
316                         =  cipher_set_counter_key(&s->outcipher, s->key)
317                         && digest_set_key(&s->outdigest, s->key + cipher_keylength(&s->outcipher), digest_keylength(&s->outdigest));
318                 if(!result)
319                         return false;
320         }
321
322         return true;
323 }
324
325 // Force another Key EXchange (for testing purposes).
326 bool sptps_force_kex(sptps_t *s) {
327         if(!s->outstate || s->state != SPTPS_SECONDARY_KEX)
328                 return error(s, EINVAL, "Cannot force KEX in current state");
329
330         s->state = SPTPS_KEX;
331         return send_kex(s);
332 }
333
334 // Receive a handshake record.
335 static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) {
336         // Only a few states to deal with handshaking.
337         fprintf(stderr, "Received handshake message, current state %d\n", s->state);
338         switch(s->state) {
339                 case SPTPS_SECONDARY_KEX:
340                         // We receive a secondary KEX request, first respond by sending our own.
341                         if(!send_kex(s))
342                                 return false;
343                 case SPTPS_KEX:
344                         // We have sent our KEX request, we expect our peer to sent one as well.
345                         if(!receive_kex(s, data, len))
346                                 return false;
347                         s->state = SPTPS_SIG;
348                         return true;
349                 case SPTPS_SIG:
350                         // If we already sent our secondary public ECDH key, we expect the peer to send his.
351                         if(!receive_sig(s, data, len))
352                                 return false;
353                         if(s->outstate)
354                                 s->state = SPTPS_ACK;
355                         else {
356                                 s->outstate = true;
357                                 if(!receive_ack(s, NULL, 0))
358                                         return false;
359                                 s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
360                                 s->state = SPTPS_SECONDARY_KEX;
361                         }
362
363                         return true;
364                 case SPTPS_ACK:
365                         // We expect a handshake message to indicate transition to the new keys.
366                         if(!receive_ack(s, data, len))
367                                 return false;
368                         s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
369                         s->state = SPTPS_SECONDARY_KEX;
370                         return true;
371                 // TODO: split ACK into a VERify and ACK?
372                 default:
373                         return error(s, EIO, "Invalid session state");
374         }
375 }
376
377 // Receive incoming data, datagram version.
378 static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len) {
379         if(len < (s->instate ? 21 : 5))
380                 return error(s, EIO, "Received short packet");
381
382         uint32_t seqno;
383         memcpy(&seqno, data, 4);
384         seqno = ntohl(seqno);
385
386         if(!s->instate) {
387                 if(seqno != s->inseqno) {
388                         fprintf(stderr, "Received invalid packet seqno: %d != %d\n", seqno, s->inseqno);
389                         return error(s, EIO, "Invalid packet seqno");
390                 }
391
392                 s->inseqno = seqno + 1;
393
394                 uint8_t type = data[4];
395
396                 if(type != SPTPS_HANDSHAKE)
397                         return error(s, EIO, "Application record received before handshake finished");
398
399                 return receive_handshake(s, data + 5, len - 5);
400         }
401
402         if(seqno < s->inseqno) {
403                 fprintf(stderr, "Received late or replayed packet: %d < %d\n", seqno, s->inseqno);
404                 return true;
405         }
406
407         if(seqno > s->inseqno)
408                 fprintf(stderr, "Missed %d packets\n", seqno - s->inseqno);
409
410         s->inseqno = seqno + 1;
411
412         uint16_t netlen = htons(len - 21);
413
414         char buffer[len + 23];
415
416         memcpy(buffer, &netlen, 2);
417         memcpy(buffer + 2, data, len);
418
419         memcpy(&seqno, buffer + 2, 4);
420
421         // Check HMAC and decrypt.
422         if(!digest_verify(&s->indigest, buffer, len - 14, buffer + len - 14))
423                 return error(s, EIO, "Invalid HMAC");
424
425         cipher_set_counter(&s->incipher, &seqno, sizeof seqno);
426         if(!cipher_counter_xor(&s->incipher, buffer + 6, len - 4, buffer + 6))
427                 return false;
428
429         // Append a NULL byte for safety.
430         buffer[len - 14] = 0;
431
432         uint8_t type = buffer[6];
433
434         if(type < SPTPS_HANDSHAKE) {
435                 if(!s->instate)
436                         return error(s, EIO, "Application record received before handshake finished");
437                 if(!s->receive_record(s->handle, type, buffer + 7, len - 21))
438                         return false;
439         } else {
440                 return error(s, EIO, "Invalid record type");
441         }
442
443         return true;
444 }
445 // Receive incoming data. Check if it contains a complete record, if so, handle it.
446 bool sptps_receive_data(sptps_t *s, const char *data, size_t len) {
447         if(s->datagram)
448                 return sptps_receive_data_datagram(s, data, len);
449
450         while(len) {
451                 // First read the 2 length bytes.
452                 if(s->buflen < 6) {
453                         size_t toread = 6 - s->buflen;
454                         if(toread > len)
455                                 toread = len;
456
457                         memcpy(s->inbuf + s->buflen, data, toread);
458
459                         s->buflen += toread;
460                         len -= toread;
461                         data += toread;
462                 
463                         // Exit early if we don't have the full length.
464                         if(s->buflen < 6)
465                                 return true;
466
467                         // Decrypt the length bytes
468
469                         if(s->instate) {
470                                 if(!cipher_counter_xor(&s->incipher, s->inbuf + 4, 2, &s->reclen))
471                                         return false;
472                         } else {
473                                 memcpy(&s->reclen, s->inbuf + 4, 2);
474                         }
475
476                         s->reclen = ntohs(s->reclen);
477
478                         // If we have the length bytes, ensure our buffer can hold the whole request.
479                         s->inbuf = realloc(s->inbuf, s->reclen + 23UL);
480                         if(!s->inbuf)
481                                 return error(s, errno, strerror(errno));
482
483                         // Add sequence number.
484                         uint32_t seqno = htonl(s->inseqno++);
485                         memcpy(s->inbuf, &seqno, 4);
486
487                         // Exit early if we have no more data to process.
488                         if(!len)
489                                 return true;
490                 }
491
492                 // Read up to the end of the record.
493                 size_t toread = s->reclen + (s->instate ? 23UL : 7UL) - s->buflen;
494                 if(toread > len)
495                         toread = len;
496
497                 memcpy(s->inbuf + s->buflen, data, toread);
498                 s->buflen += toread;
499                 len -= toread;
500                 data += toread;
501
502                 // If we don't have a whole record, exit.
503                 if(s->buflen < s->reclen + (s->instate ? 23UL : 7UL))
504                         return true;
505
506                 // Check HMAC and decrypt.
507                 if(s->instate) {
508                         if(!digest_verify(&s->indigest, s->inbuf, s->reclen + 7UL, s->inbuf + s->reclen + 7UL))
509                                 return error(s, EIO, "Invalid HMAC");
510
511                         if(!cipher_counter_xor(&s->incipher, s->inbuf + 6UL, s->reclen + 1UL, s->inbuf + 6UL))
512                                 return false;
513                 }
514
515                 // Append a NULL byte for safety.
516                 s->inbuf[s->reclen + 7UL] = 0;
517
518                 uint8_t type = s->inbuf[6];
519
520                 if(type < SPTPS_HANDSHAKE) {
521                         if(!s->instate)
522                                 return error(s, EIO, "Application record received before handshake finished");
523                         if(!s->receive_record(s->handle, type, s->inbuf + 7, s->reclen))
524                                 return false;
525                 } else if(type == SPTPS_HANDSHAKE) {
526                         if(!receive_handshake(s, s->inbuf + 7, s->reclen))
527                                 return false;
528                 } else {
529                         return error(s, EIO, "Invalid record type");
530                 }
531
532                 s->buflen = 4;
533         }
534
535         return true;
536 }
537
538 // Start a SPTPS session.
539 bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
540         // Initialise struct sptps
541         memset(s, 0, sizeof *s);
542
543         s->handle = handle;
544         s->initiator = initiator;
545         s->datagram = datagram;
546         s->mykey = mykey;
547         s->hiskey = hiskey;
548
549         s->label = malloc(labellen);
550         if(!s->label)
551                 return error(s, errno, strerror(errno));
552
553         if(!datagram) {
554                 s->inbuf = malloc(7);
555                 if(!s->inbuf)
556                         return error(s, errno, strerror(errno));
557                 s->buflen = 4;
558                 memset(s->inbuf, 0, 4);
559         }
560
561         memcpy(s->label, label, labellen);
562         s->labellen = labellen;
563
564         s->send_data = send_data;
565         s->receive_record = receive_record;
566
567         // Do first KEX immediately
568         s->state = SPTPS_KEX;
569         return send_kex(s);
570 }
571
572 // Stop a SPTPS session.
573 bool sptps_stop(sptps_t *s) {
574         // Clean up any resources.
575         ecdh_free(&s->ecdh);
576         free(s->inbuf);
577         free(s->mykex);
578         free(s->hiskex);
579         free(s->key);
580         free(s->label);
581         return true;
582 }