Add datagram mode to the SPTPS protocol.
[tinc] / src / sptps.c
1 /*
2     sptps.c -- Simple Peer-to-Peer Security
3     Copyright (C) 2011 Guus Sliepen <guus@tinc-vpn.org>,
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #include "cipher.h"
23 #include "crypto.h"
24 #include "digest.h"
25 #include "ecdh.h"
26 #include "ecdsa.h"
27 #include "prf.h"
28 #include "sptps.h"
29
30 /*
31    Nonce MUST be exchanged first (done)
32    Signatures MUST be done over both nonces, to guarantee the signature is fresh
33    Otherwise: if ECDHE key of one side is compromised, it can be reused!
34
35    Add explicit tag to beginning of structure to distinguish the client and server when signing. (done)
36
37    Sign all handshake messages up to ECDHE kex with long-term public keys. (done)
38
39    HMACed KEX finished message to prevent downgrade attacks and prove you have the right key material (done by virtue of ECDSA over the whole ECDHE exchange?)
40
41    Explicit close message needs to be added.
42
43    Maybe do add some alert messages to give helpful error messages? Not more than TLS sends.
44
45    Use counter mode instead of OFB. (done)
46
47    Make sure ECC operations are fixed time (aka prevent side-channel attacks).
48 */
49
50 // Log an error message.
51 static bool error(sptps_t *s, int s_errno, const char *msg) {
52         fprintf(stderr, "SPTPS error: %s\n", msg);
53         errno = s_errno;
54         return false;
55 }
56
57 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
58 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
59         char buffer[len + 23UL];
60
61         // Create header with sequence number, length and record type
62         uint32_t seqno = htonl(s->outseqno++);
63         uint16_t netlen = htons(len);
64
65         memcpy(buffer, &netlen, 2);
66         memcpy(buffer + 2, &seqno, 4);
67         buffer[6] = type;
68
69         // Add plaintext (TODO: avoid unnecessary copy)
70         memcpy(buffer + 7, data, len);
71
72         if(s->outstate) {
73                 // If first handshake has finished, encrypt and HMAC
74                 cipher_set_counter(&s->outcipher, &seqno, sizeof seqno);
75                 if(!cipher_counter_xor(&s->outcipher, buffer + 6, len + 1UL, buffer + 6))
76                         return false;
77
78                 if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
79                         return false;
80
81                 return s->send_data(s->handle, buffer + 2, len + 21UL);
82         } else {
83                 // Otherwise send as plaintext
84                 return s->send_data(s->handle, buffer + 2, len + 5UL);
85         }
86 }
87 // Send a record (private version, accepts all record types, handles encryption and authentication).
88 static bool send_record_priv(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
89         if(s->datagram)
90                 return send_record_priv_datagram(s, type, data, len);
91
92         char buffer[len + 23UL];
93
94         // Create header with sequence number, length and record type
95         uint32_t seqno = htonl(s->outseqno++);
96         uint16_t netlen = htons(len);
97
98         memcpy(buffer, &seqno, 4);
99         memcpy(buffer + 4, &netlen, 2);
100         buffer[6] = type;
101
102         // Add plaintext (TODO: avoid unnecessary copy)
103         memcpy(buffer + 7, data, len);
104
105         if(s->outstate) {
106                 // If first handshake has finished, encrypt and HMAC
107                 if(!cipher_counter_xor(&s->outcipher, buffer + 4, len + 3UL, buffer + 4))
108                         return false;
109
110                 if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
111                         return false;
112
113                 return s->send_data(s->handle, buffer + 4, len + 19UL);
114         } else {
115                 // Otherwise send as plaintext
116                 return s->send_data(s->handle, buffer + 4, len + 3UL);
117         }
118 }
119
120 // Send an application record.
121 bool sptps_send_record(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
122         // Sanity checks: application cannot send data before handshake is finished,
123         // and only record types 0..127 are allowed.
124         if(!s->outstate)
125                 return error(s, EINVAL, "Handshake phase not finished yet");
126
127         if(type >= SPTPS_HANDSHAKE)
128                 return error(s, EINVAL, "Invalid application record type");
129
130         return send_record_priv(s, type, data, len);
131 }
132
133 // Send a Key EXchange record, containing a random nonce and an ECDHE public key.
134 static bool send_kex(sptps_t *s) {
135         size_t keylen = ECDH_SIZE;
136
137         // Make room for our KEX message, which we will keep around since send_sig() needs it.
138         if(s->mykex)
139                 abort();
140         s->mykex = realloc(s->mykex, 1 + 32 + keylen);
141         if(!s->mykex)
142                 return error(s, errno, strerror(errno));
143
144         // Set version byte to zero.
145         s->mykex[0] = SPTPS_VERSION;
146
147         // Create a random nonce.
148         randomize(s->mykex + 1, 32);
149
150         // Create a new ECDH public key.
151         if(!ecdh_generate_public(&s->ecdh, s->mykex + 1 + 32))
152                 return false;
153
154         return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, 1 + 32 + keylen);
155 }
156
157 // Send a SIGnature record, containing an ECDSA signature over both KEX records.
158 static bool send_sig(sptps_t *s) {
159         size_t keylen = ECDH_SIZE;
160         size_t siglen = ecdsa_size(&s->mykey);
161
162         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
163         char msg[(1 + 32 + keylen) * 2 + 1];
164         char sig[siglen];
165
166         msg[0] = s->initiator;
167         memcpy(msg + 1, s->mykex, 1 + 32 + keylen);
168         memcpy(msg + 2 + 32 + keylen, s->hiskex, 1 + 32 + keylen);
169
170         // Sign the result.
171         if(!ecdsa_sign(&s->mykey, msg, sizeof msg, sig))
172                 return false;
173
174         // Send the SIG exchange record.
175         return send_record_priv(s, SPTPS_HANDSHAKE, sig, sizeof sig);
176 }
177
178 // Generate key material from the shared secret created from the ECDHE key exchange.
179 static bool generate_key_material(sptps_t *s, const char *shared, size_t len) {
180         // Initialise cipher and digest structures if necessary
181         if(!s->outstate) {
182                 bool result
183                         =  cipher_open_by_name(&s->incipher, "aes-256-ecb")
184                         && cipher_open_by_name(&s->outcipher, "aes-256-ecb")
185                         && digest_open_by_name(&s->indigest, "sha256", 16)
186                         && digest_open_by_name(&s->outdigest, "sha256", 16);
187                 if(!result)
188                         return false;
189         }
190
191         // Allocate memory for key material
192         size_t keylen = digest_keylength(&s->indigest) + digest_keylength(&s->outdigest) + cipher_keylength(&s->incipher) + cipher_keylength(&s->outcipher);
193
194         s->key = realloc(s->key, keylen);
195         if(!s->key)
196                 return error(s, errno, strerror(errno));
197
198         // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce
199         char seed[s->labellen + 64 + 13];
200         strcpy(seed, "key expansion");
201         if(s->initiator) {
202                 memcpy(seed + 13, s->mykex + 1, 32);
203                 memcpy(seed + 45, s->hiskex + 1, 32);
204         } else {
205                 memcpy(seed + 13, s->hiskex + 1, 32);
206                 memcpy(seed + 45, s->mykex + 1, 32);
207         }
208         memcpy(seed + 78, s->label, s->labellen);
209
210         // Use PRF to generate the key material
211         if(!prf(shared, len, seed, s->labellen + 64 + 13, s->key, keylen))
212                 return false;
213
214         return true;
215 }
216
217 // Send an ACKnowledgement record.
218 static bool send_ack(sptps_t *s) {
219         return send_record_priv(s, SPTPS_HANDSHAKE, "", 0);
220 }
221
222 // Receive an ACKnowledgement record.
223 static bool receive_ack(sptps_t *s, const char *data, uint16_t len) {
224         if(len)
225                 return error(s, EIO, "Invalid ACK record length");
226
227         if(s->initiator) {
228                 bool result
229                         = cipher_set_counter_key(&s->incipher, s->key)
230                         && digest_set_key(&s->indigest, s->key + cipher_keylength(&s->incipher), digest_keylength(&s->indigest));
231                 if(!result)
232                         return false;
233         } else {
234                 bool result
235                         = cipher_set_counter_key(&s->incipher, s->key + cipher_keylength(&s->outcipher) + digest_keylength(&s->outdigest))
236                         && digest_set_key(&s->indigest, s->key + cipher_keylength(&s->outcipher) + digest_keylength(&s->outdigest) + cipher_keylength(&s->incipher), digest_keylength(&s->indigest));
237                 if(!result)
238                         return false;
239         }
240
241         free(s->key);
242         s->key = NULL;
243         s->instate = true;
244
245         return true;
246 }
247
248 // Receive a Key EXchange record, respond by sending a SIG record.
249 static bool receive_kex(sptps_t *s, const char *data, uint16_t len) {
250         // Verify length of the HELLO record
251         if(len != 1 + 32 + ECDH_SIZE)
252                 return error(s, EIO, "Invalid KEX record length");
253
254         // Ignore version number for now.
255
256         // Make a copy of the KEX message, send_sig() and receive_sig() need it
257         if(s->hiskex)
258                 abort();
259         s->hiskex = realloc(s->hiskex, len);
260         if(!s->hiskex)
261                 return error(s, errno, strerror(errno));
262
263         memcpy(s->hiskex, data, len);
264
265         return send_sig(s);
266 }
267
268 // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys.
269 static bool receive_sig(sptps_t *s, const char *data, uint16_t len) {
270         size_t keylen = ECDH_SIZE;
271         size_t siglen = ecdsa_size(&s->hiskey);
272
273         // Verify length of KEX record.
274         if(len != siglen)
275                 return error(s, EIO, "Invalid KEX record length");
276
277         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
278         char msg[(1 + 32 + keylen) * 2 + 1];
279
280         msg[0] = !s->initiator;
281         memcpy(msg + 1, s->hiskex, 1 + 32 + keylen);
282         memcpy(msg + 2 + 32 + keylen, s->mykex, 1 + 32 + keylen);
283
284         // Verify signature.
285         if(!ecdsa_verify(&s->hiskey, msg, sizeof msg, data))
286                 return false;
287
288         // Compute shared secret.
289         char shared[ECDH_SHARED_SIZE];
290         if(!ecdh_compute_shared(&s->ecdh, s->hiskex + 1 + 32, shared))
291                 return false;
292
293         // Generate key material from shared secret.
294         if(!generate_key_material(s, shared, sizeof shared))
295                 return false;
296
297         free(s->mykex);
298         free(s->hiskex);
299
300         s->mykex = NULL;
301         s->hiskex = NULL;
302
303         // Send cipher change record
304         if(!send_ack(s))
305                 return false;
306
307         // TODO: only set new keys after ACK has been set/received
308         if(s->initiator) {
309                 bool result
310                         = cipher_set_counter_key(&s->outcipher, s->key + cipher_keylength(&s->incipher) + digest_keylength(&s->indigest))
311                         && digest_set_key(&s->outdigest, s->key + cipher_keylength(&s->incipher) + digest_keylength(&s->indigest) + cipher_keylength(&s->outcipher), digest_keylength(&s->outdigest));
312                 if(!result)
313                         return false;
314         } else {
315                 bool result
316                         =  cipher_set_counter_key(&s->outcipher, s->key)
317                         && digest_set_key(&s->outdigest, s->key + cipher_keylength(&s->outcipher), digest_keylength(&s->outdigest));
318                 if(!result)
319                         return false;
320         }
321
322         s->outstate = true;
323
324         return true;
325 }
326
327 // Force another Key EXchange (for testing purposes).
328 bool sptps_force_kex(sptps_t *s) {
329         if(!s->outstate || s->state != SPTPS_SECONDARY_KEX)
330                 return error(s, EINVAL, "Cannot force KEX in current state");
331
332         s->state = SPTPS_KEX;
333         return send_kex(s);
334 }
335
336 // Receive a handshake record.
337 static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) {
338         // Only a few states to deal with handshaking.
339         fprintf(stderr, "Received handshake message, current state %d\n", s->state);
340         switch(s->state) {
341                 case SPTPS_SECONDARY_KEX:
342                         // We receive a secondary KEX request, first respond by sending our own.
343                         if(!send_kex(s))
344                                 return false;
345                 case SPTPS_KEX:
346                         // We have sent our KEX request, we expect our peer to sent one as well.
347                         if(!receive_kex(s, data, len))
348                                 return false;
349                         s->state = SPTPS_SIG;
350                         return true;
351                 case SPTPS_SIG:
352                         // If we already sent our secondary public ECDH key, we expect the peer to send his.
353                         if(!receive_sig(s, data, len))
354                                 return false;
355                         s->state = SPTPS_ACK;
356                         return true;
357                 case SPTPS_ACK:
358                         // We expect a handshake message to indicate transition to the new keys.
359                         if(!receive_ack(s, data, len))
360                                 return false;
361                         s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
362                         s->state = SPTPS_SECONDARY_KEX;
363                         return true;
364                 // TODO: split ACK into a VERify and ACK?
365                 default:
366                         return error(s, EIO, "Invalid session state");
367         }
368 }
369
370 // Receive incoming data, datagram version.
371 static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len) {
372         if(len < (s->instate ? 21 : 5))
373                 return error(s, EIO, "Received short packet");
374
375         uint32_t seqno;
376         memcpy(&seqno, data, 4);
377         seqno = ntohl(seqno);
378
379         if(!s->instate) {
380                 if(seqno != s->inseqno) {
381                         fprintf(stderr, "Received invalid packet seqno: %d != %d\n", seqno, s->inseqno);
382                         return error(s, EIO, "Invalid packet seqno");
383                 }
384
385                 s->inseqno = seqno + 1;
386
387                 uint8_t type = data[4];
388
389                 if(type != SPTPS_HANDSHAKE)
390                         return error(s, EIO, "Application record received before handshake finished");
391
392                 return receive_handshake(s, data + 5, len - 5);
393         }
394
395         if(seqno < s->inseqno) {
396                 fprintf(stderr, "Received late or replayed packet: %d < %d\n", seqno, s->inseqno);
397                 return true;
398         }
399
400         if(seqno > s->inseqno)
401                 fprintf(stderr, "Missed %d packets\n", seqno - s->inseqno);
402
403         s->inseqno = seqno + 1;
404
405         uint16_t netlen = htons(len - 21);
406
407         char buffer[len + 23];
408
409         memcpy(buffer, &netlen, 2);
410         memcpy(buffer + 2, data, len);
411
412         memcpy(&seqno, buffer + 2, 4);
413
414         // Check HMAC and decrypt.
415         if(!digest_verify(&s->indigest, buffer, len - 14, buffer + len - 14))
416                 return error(s, EIO, "Invalid HMAC");
417
418         cipher_set_counter(&s->incipher, &seqno, sizeof seqno);
419         if(!cipher_counter_xor(&s->incipher, buffer + 6, len - 4, buffer + 6))
420                 return false;
421
422         // Append a NULL byte for safety.
423         buffer[len - 14] = 0;
424
425         uint8_t type = buffer[6];
426
427         if(type < SPTPS_HANDSHAKE) {
428                 if(!s->instate)
429                         return error(s, EIO, "Application record received before handshake finished");
430                 if(!s->receive_record(s->handle, type, buffer + 7, len - 21))
431                         return false;
432         } else {
433                 return error(s, EIO, "Invalid record type");
434         }
435
436         return true;
437 }
438 // Receive incoming data. Check if it contains a complete record, if so, handle it.
439 bool sptps_receive_data(sptps_t *s, const char *data, size_t len) {
440         if(s->datagram)
441                 return sptps_receive_data_datagram(s, data, len);
442
443         while(len) {
444                 // First read the 2 length bytes.
445                 if(s->buflen < 6) {
446                         size_t toread = 6 - s->buflen;
447                         if(toread > len)
448                                 toread = len;
449
450                         memcpy(s->inbuf + s->buflen, data, toread);
451
452                         s->buflen += toread;
453                         len -= toread;
454                         data += toread;
455                 
456                         // Exit early if we don't have the full length.
457                         if(s->buflen < 6)
458                                 return true;
459
460                         // Decrypt the length bytes
461
462                         if(s->instate) {
463                                 if(!cipher_counter_xor(&s->incipher, s->inbuf + 4, 2, &s->reclen))
464                                         return false;
465                         } else {
466                                 memcpy(&s->reclen, s->inbuf + 4, 2);
467                         }
468
469                         s->reclen = ntohs(s->reclen);
470
471                         // If we have the length bytes, ensure our buffer can hold the whole request.
472                         s->inbuf = realloc(s->inbuf, s->reclen + 23UL);
473                         if(!s->inbuf)
474                                 return error(s, errno, strerror(errno));
475
476                         // Add sequence number.
477                         uint32_t seqno = htonl(s->inseqno++);
478                         memcpy(s->inbuf, &seqno, 4);
479
480                         // Exit early if we have no more data to process.
481                         if(!len)
482                                 return true;
483                 }
484
485                 // Read up to the end of the record.
486                 size_t toread = s->reclen + (s->instate ? 23UL : 7UL) - s->buflen;
487                 if(toread > len)
488                         toread = len;
489
490                 memcpy(s->inbuf + s->buflen, data, toread);
491                 s->buflen += toread;
492                 len -= toread;
493                 data += toread;
494
495                 // If we don't have a whole record, exit.
496                 if(s->buflen < s->reclen + (s->instate ? 23UL : 7UL))
497                         return true;
498
499                 // Check HMAC and decrypt.
500                 if(s->instate) {
501                         if(!digest_verify(&s->indigest, s->inbuf, s->reclen + 7UL, s->inbuf + s->reclen + 7UL))
502                                 return error(s, EIO, "Invalid HMAC");
503
504                         if(!cipher_counter_xor(&s->incipher, s->inbuf + 6UL, s->reclen + 1UL, s->inbuf + 6UL))
505                                 return false;
506                 }
507
508                 // Append a NULL byte for safety.
509                 s->inbuf[s->reclen + 7UL] = 0;
510
511                 uint8_t type = s->inbuf[6];
512
513                 if(type < SPTPS_HANDSHAKE) {
514                         if(!s->instate)
515                                 return error(s, EIO, "Application record received before handshake finished");
516                         if(!s->receive_record(s->handle, type, s->inbuf + 7, s->reclen))
517                                 return false;
518                 } else if(type == SPTPS_HANDSHAKE) {
519                         if(!receive_handshake(s, s->inbuf + 7, s->reclen))
520                                 return false;
521                 } else {
522                         return error(s, EIO, "Invalid record type");
523                 }
524
525                 s->buflen = 4;
526         }
527
528         return true;
529 }
530
531 // Start a SPTPS session.
532 bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
533         // Initialise struct sptps
534         memset(s, 0, sizeof *s);
535
536         s->handle = handle;
537         s->initiator = initiator;
538         s->datagram = datagram;
539         s->mykey = mykey;
540         s->hiskey = hiskey;
541
542         s->label = malloc(labellen);
543         if(!s->label)
544                 return error(s, errno, strerror(errno));
545
546         if(!datagram) {
547                 s->inbuf = malloc(7);
548                 if(!s->inbuf)
549                         return error(s, errno, strerror(errno));
550                 s->buflen = 4;
551                 memset(s->inbuf, 0, 4);
552         }
553
554         memcpy(s->label, label, labellen);
555         s->labellen = labellen;
556
557         s->send_data = send_data;
558         s->receive_record = receive_record;
559
560         // Do first KEX immediately
561         s->state = SPTPS_KEX;
562         return send_kex(s);
563 }
564
565 // Stop a SPTPS session.
566 bool sptps_stop(sptps_t *s) {
567         // Clean up any resources.
568         ecdh_free(&s->ecdh);
569         free(s->inbuf);
570         free(s->mykex);
571         free(s->hiskex);
572         free(s->key);
573         free(s->label);
574         return true;
575 }