Add cipher suite selection options to sptps_test.
[tinc] / src / sptps_test.c
1 /*
2     sptps_test.c -- Simple Peer-to-Peer Security test program
3     Copyright (C) 2011-2022 Guus Sliepen <guus@tinc-vpn.org>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #ifdef HAVE_LINUX
23 #include <linux/if_tun.h>
24 #endif
25
26 #include "crypto.h"
27 #include "ecdsa.h"
28 #include "meta.h"
29 #include "protocol.h"
30 #include "sptps.h"
31 #include "utils.h"
32 #include "names.h"
33 #include "random.h"
34
35 #ifndef HAVE_WINDOWS
36 #define closesocket(s) close(s)
37 #endif
38
39 // Symbols necessary to link with logger.o
// Stub required to link with logger.o: sptps_test has no meta connections,
// so no request is ever sent and the call always reports failure.
bool send_request(struct connection_t *c, const char *msg, ...) {
        (void)msg;
        (void)c;

        const bool sent = false;
        return sent;
}
45
// Definition presumably needed only so the program links against objects that
// reference it; this file itself never uses connection_list. TODO confirm.
list_t connection_list;
47
// Stub required to link with logger.o: there is no meta connection to write
// to, so the data is ignored and failure is reported.
bool send_meta(struct connection_t *c, const void *msg, size_t len) {
        (void)len;
        (void)msg;
        (void)c;

        const bool sent = false;
        return sent;
}
54
// Not referenced in this file; presumably required by other objects that are
// linked into sptps_test. TODO confirm.
bool do_detach;
struct timeval now;

// Flags set from the command line in run_test(); all start out false
// (objects with static storage duration are zero-initialized).
static bool special;
static bool verbose;
static bool readonly;
static bool writeonly;

// Descriptors for plaintext input and output: stdin (0) and stdout (1) by
// default, replaced by the tun device or the Windows stdin relay socket.
static int in;
static int out = 1;

// Address family for getaddrinfo(); changed by -4 / -6.
int addressfamily = AF_UNSPEC;
65
66 static bool send_data(void *handle, uint8_t type, const void *data, size_t len) {
67         (void)type;
68         char *hex = alloca(len * 2 + 1);
69         bin2hex(data, hex, len);
70
71         if(verbose) {
72                 fprintf(stderr, "Sending %lu bytes of data:\n%s\n", (unsigned long)len, hex);
73         }
74
75         const int *sock = handle;
76         const char *p = data;
77
78         while(len) {
79                 ssize_t sent = send(*sock, p, len, 0);
80
81                 if(sent <= 0) {
82                         fprintf(stderr, "Error sending data: %s\n", strerror(errno));
83                         return false;
84                 }
85
86                 p += sent;
87                 len -= sent;
88         }
89
90         return true;
91 }
92
93 static bool receive_record(void *handle, uint8_t type, const void *data, uint16_t len) {
94         (void)handle;
95
96         if(verbose) {
97                 fprintf(stderr, "Received type %d record of %u bytes:\n", type, len);
98         }
99
100         if(writeonly) {
101                 return true;
102         }
103
104         const char *p = data;
105
106         while(len) {
107                 ssize_t written = write(out, p, len);
108
109                 if(written <= 0) {
110                         fprintf(stderr, "Error writing received data: %s\n", strerror(errno));
111                         return false;
112                 }
113
114                 p += written;
115                 len -= written;
116         }
117
118         return true;
119 }
120
// Option identifiers as returned by getopt_long() in run_test(). Options with
// a short form use their option character as the value; long-only options use
// a value outside the ASCII range so they cannot collide with any of them.
typedef enum option_t {
        // Special getopt_long() return values
        OPT_LONG_OPTION     =  0,   // returned for entries with a non-NULL flag pointer
        OPT_BAD_OPTION      = '?',  // unknown option or missing argument

        // Short options, sorted by option character
        OPT_IPV4            = '4',
        OPT_IPV6            = '6',
        OPT_PACKET_LOSS     = 'L',
        OPT_CIPHER_SUITES   = 'M',
        OPT_PREFERRED_SUITE = 'P',
        OPT_REPLAY_WINDOW   = 'W',
        OPT_DATAGRAM        = 'd',
        OPT_QUIT_ON_EOF     = 'q',
        OPT_READONLY        = 'r',
        OPT_SPECIAL_CHAR    = 's',
        OPT_TUN             = 't',
        OPT_VERBOSE         = 'v',
        OPT_WRITEONLY       = 'w',

        // Long-only options
        OPT_HELP            = 255,
} option_t;
143
144 static struct option const long_options[] = {
145         {"datagram",        no_argument,       NULL, OPT_DATAGRAM},
146         {"quit",            no_argument,       NULL, OPT_QUIT_ON_EOF},
147         {"readonly",        no_argument,       NULL, OPT_READONLY},
148         {"writeonly",       no_argument,       NULL, OPT_WRITEONLY},
149         {"packet-loss",     required_argument, NULL, OPT_PACKET_LOSS},
150         {"replay-window",   required_argument, NULL, OPT_REPLAY_WINDOW},
151         {"special",         no_argument,       NULL, OPT_SPECIAL_CHAR},
152         {"tun",             no_argument,       NULL, OPT_TUN},
153         {"verbose",         required_argument, NULL, OPT_VERBOSE},
154         {"cipher-suites",   required_argument, NULL, OPT_CIPHER_SUITES},
155         {"preferred-suite", required_argument, NULL, OPT_PREFERRED_SUITE},
156         {"help",            no_argument,       NULL, OPT_HELP},
157         {NULL,              0,                 NULL, 0}
158 };
159
160 static void usage(void) {
161         fprintf(stderr,
162                 "Usage: %s [options] my_ed25519_key_file his_ed25519_key_file [host] port\n"
163                 "\n"
164                 "Valid options are:\n"
165                 "  -d, --datagram            Enable datagram mode.\n"
166                 "  -q, --quit                Quit when EOF occurs on stdin.\n"
167                 "  -r, --readonly            Only send data from the socket to stdout.\n"
168 #ifdef HAVE_LINUX
169                 "  -t, --tun                 Use a tun device instead of stdio.\n"
170 #endif
171                 "  -w, --writeonly           Only send data from stdin to the socket.\n"
172                 "  -L, --packet-loss RATE    Fake packet loss of RATE percent.\n"
173                 "  -R, --replay-window N     Set replay window to N bytes.\n"
174                 "  -M, --cipher-suites MASK  Set the mask of allowed cipher suites.\n"
175                 "  -P, --preferred-suite N   Set the preferred cipher suite.\n"
176                 "  -s, --special             Enable special handling of lines starting with #, ^ and $.\n"
177                 "  -v, --verbose             Display debug messages.\n"
178                 "  -4                        Use IPv4.\n"
179                 "  -6                        Use IPv6.\n"
180                 "\n"
181                 "Report bugs to tinc@tinc-vpn.org.\n",
182                 program_name);
183 }
184
185 #ifdef HAVE_WINDOWS
186
// Loopback listening socket created by start_input_reader() and accepted on by
// stdin_reader_thread(); -1 while it is not open.
int stdin_sock_fd = -1;
188
189 // Windows does not allow calling select() on anything but sockets. Therefore,
190 // to keep the same code as on other operating systems, we have to put a
191 // separate thread between the stdin and the sptps loop way below. This thread
192 // reads stdin and sends its content to the main thread through a TCP socket,
193 // which can be properly select()'ed.
194 static DWORD WINAPI stdin_reader_thread(LPVOID arg) {
195         struct sockaddr_in sa;
196         socklen_t sa_size = sizeof(sa);
197
198         while(true) {
199                 int peer_fd = accept(stdin_sock_fd, (struct sockaddr *) &sa, &sa_size);
200
201                 if(peer_fd < 0) {
202                         fprintf(stderr, "accept() failed: %s\n", strerror(errno));
203                         continue;
204                 }
205
206                 if(verbose) {
207                         fprintf(stderr, "New connection received from :%d\n", ntohs(sa.sin_port));
208                 }
209
210                 char buf[1024];
211                 ssize_t nread;
212
213                 while((nread = read(STDIN_FILENO, buf, sizeof(buf))) > 0) {
214                         if(verbose) {
215                                 fprintf(stderr, "Read %lld bytes from input\n", nread);
216                         }
217
218                         char *start = buf;
219                         ssize_t nleft = nread;
220
221                         while(nleft) {
222                                 ssize_t nsend = send(peer_fd, start, nleft, 0);
223
224                                 if(nsend < 0) {
225                                         if(sockwouldblock(sockerrno)) {
226                                                 continue;
227                                         }
228
229                                         break;
230                                 }
231
232                                 start += nsend;
233                                 nleft -= nsend;
234                         }
235
236                         if(nleft) {
237                                 fprintf(stderr, "Could not send data: %s\n", strerror(errno));
238                                 break;
239                         }
240
241                         if(verbose) {
242                                 fprintf(stderr, "Sent %lld bytes to peer\n", nread);
243                         }
244                 }
245
246                 closesocket(peer_fd);
247         }
248
249         closesocket(stdin_sock_fd);
250         stdin_sock_fd = -1;
251         return 0;
252 }
253
254 static int start_input_reader(void) {
255         if(stdin_sock_fd != -1) {
256                 fprintf(stderr, "stdin thread can only be started once.\n");
257                 return -1;
258         }
259
260         stdin_sock_fd = socket(AF_INET, SOCK_STREAM, 0);
261
262         if(stdin_sock_fd < 0) {
263                 fprintf(stderr, "Could not create server socket: %s\n", strerror(errno));
264                 return -1;
265         }
266
267         struct sockaddr_in serv_sa;
268
269         memset(&serv_sa, 0, sizeof(serv_sa));
270
271         serv_sa.sin_family = AF_INET;
272
273         serv_sa.sin_addr.s_addr = htonl(0x7f000001); // 127.0.0.1
274
275         int res = bind(stdin_sock_fd, (struct sockaddr *)&serv_sa, sizeof(serv_sa));
276
277         if(res < 0) {
278                 fprintf(stderr, "Could not bind socket: %s\n", strerror(errno));
279                 goto server_err;
280         }
281
282         if(listen(stdin_sock_fd, 1) < 0) {
283                 fprintf(stderr, "Could not listen: %s\n", strerror(errno));
284                 goto server_err;
285         }
286
287         struct sockaddr_in connect_sa;
288
289         socklen_t addr_len = sizeof(connect_sa);
290
291         if(getsockname(stdin_sock_fd, (struct sockaddr *)&connect_sa, &addr_len) < 0) {
292                 fprintf(stderr, "Could not determine the address of the stdin thread socket\n");
293                 goto server_err;
294         }
295
296         if(verbose) {
297                 fprintf(stderr, "stdin thread is listening on :%d\n", ntohs(connect_sa.sin_port));
298         }
299
300         if(!CreateThread(NULL, 0, stdin_reader_thread, NULL, 0, NULL)) {
301                 fprintf(stderr, "Could not start reader thread: %d\n", GetLastError());
302                 goto server_err;
303         }
304
305         int client_fd = socket(AF_INET, SOCK_STREAM, 0);
306
307         if(client_fd < 0) {
308                 fprintf(stderr, "Could not create client socket: %s\n", strerror(errno));
309                 return -1;
310         }
311
312         if(connect(client_fd, (struct sockaddr *)&connect_sa, sizeof(connect_sa)) < 0) {
313                 fprintf(stderr, "Could not connect: %s\n", strerror(errno));
314                 closesocket(client_fd);
315                 return -1;
316         }
317
318         return client_fd;
319
320 server_err:
321
322         if(stdin_sock_fd != -1) {
323                 closesocket(stdin_sock_fd);
324                 stdin_sock_fd = -1;
325         }
326
327         return -1;
328 }
329
330 #endif // HAVE_WINDOWS
331
332 static void print_listening_msg(int sock) {
333         sockaddr_t sa = {0};
334         socklen_t salen = sizeof(sa);
335         int port = 0;
336
337         if(!getsockname(sock, &sa.sa, &salen)) {
338                 port = ntohs(sa.in.sin_port);
339         }
340
341         fprintf(stderr, "Listening on %d...\n", port);
342         fflush(stderr);
343 }
344
// Parse the argument of a numeric command-line option with strtoul() using
// base 0 (decimal, 0x-prefixed hexadecimal or 0-prefixed octal). Unlike a
// bare strtoul() call, empty strings, negative numbers, trailing garbage and
// out-of-range values are rejected instead of silently becoming some number.
// Returns true and stores the parsed value in *value on success.
static bool parse_unsigned_option(const char *arg, unsigned long *value) {
        if(!arg || !*arg || strchr(arg, '-')) {
                return false;
        }

        char *end = NULL;
        errno = 0;
        unsigned long parsed = strtoul(arg, &end, 0);

        if(end == arg || *end != '\0' || errno == ERANGE) {
                return false;
        }

        *value = parsed;
        return true;
}

// Load an Ed25519 key from the PEM file at path: a private key if is_private
// is true, a public key otherwise. The file is closed on every path. Returns
// NULL on failure; an error is printed if the file cannot be opened.
static ecdsa_t *load_key_file(const char *path, bool is_private) {
        FILE *fp = fopen(path, "r");

        if(!fp) {
                fprintf(stderr, "Could not open %s: %s\n", path, strerror(errno));
                return NULL;
        }

        ecdsa_t *key = is_private ? ecdsa_read_pem_private_key(fp) : ecdsa_read_pem_public_key(fp);
        fclose(fp);
        return key;
}

// The test program proper. It does the following, in order:
//  1. Parse the command line.
//  2. Optionally open a tun device.
//  3. Connect or listen on a TCP (or, with -d, UDP) socket.
//  4. Load both keys and start an SPTPS session over the socket.
//  5. Shuttle data between the input/output descriptors and the session
//     until EOF or an error.
// Returns the process exit status.
static int run_test(int argc, char *argv[]) {
        program_name = argv[0];
        bool initiator = false;
        bool datagram = false;
#ifdef HAVE_LINUX
        bool tun = false;
#endif
        int packetloss = 0;
        int r;
        int option_index = 0;
        bool quit = false;
        unsigned long cipher_suites = SPTPS_ALL_CIPHER_SUITES;
        unsigned long preferred_suite = 0;

        // Every short option of option_t must be listed here (followed by ':'
        // if it takes an argument), otherwise getopt_long() rejects it.
        while((r = getopt_long(argc, argv, "dqrstwL:W:vM:P:46", long_options, &option_index)) != EOF) {
                switch((option_t) r) {
                case OPT_LONG_OPTION:
                        break;

                case OPT_BAD_OPTION:
                        usage();
                        return 1;

                case OPT_DATAGRAM:
                        datagram = true;
                        break;

                case OPT_QUIT_ON_EOF:
                        quit = true;
                        break;

                case OPT_READONLY:
                        readonly = true;
                        break;

                case OPT_TUN:
#ifdef HAVE_LINUX
                        tun = true;
#else
                        fprintf(stderr, "--tun is only supported on Linux.\n");
                        usage();
                        return 1;
#endif
                        break;

                case OPT_WRITEONLY:
                        writeonly = true;
                        break;

                case OPT_PACKET_LOSS:
                        packetloss = atoi(optarg);
                        break;

                case OPT_REPLAY_WINDOW:
                        sptps_replaywin = atoi(optarg);
                        break;

                case OPT_CIPHER_SUITES:
                        if(!parse_unsigned_option(optarg, &cipher_suites)) {
                                fprintf(stderr, "Invalid cipher suite mask: %s\n", optarg);
                                usage();
                                return 1;
                        }

                        break;

                case OPT_PREFERRED_SUITE:
                        if(!parse_unsigned_option(optarg, &preferred_suite)) {
                                fprintf(stderr, "Invalid preferred cipher suite: %s\n", optarg);
                                usage();
                                return 1;
                        }

                        break;

                case OPT_VERBOSE:
                        verbose = true;
                        break;

                case OPT_SPECIAL_CHAR:
                        special = true;
                        break;

                case OPT_IPV4:
                        addressfamily = AF_INET;
                        break;

                case OPT_IPV6:
                        addressfamily = AF_INET6;
                        break;

                case OPT_HELP:
                        usage();
                        return 0;

                default:
                        break;
                }
        }

        argc -= optind - 1;
        argv += optind - 1;

        if(argc < 4 || argc > 5) {
                fprintf(stderr, "Wrong number of arguments.\n");
                usage();
                return 1;
        }

        // With a host argument we connect out; otherwise we wait for a peer.
        if(argc > 4) {
                initiator = true;
        }

#ifdef HAVE_LINUX

        if(tun) {
                in = out = open("/dev/net/tun", O_RDWR | O_NONBLOCK);

                if(in < 0) {
                        fprintf(stderr, "Could not open tun device: %s\n", strerror(errno));
                        return 1;
                }

                struct ifreq ifr = {
                        .ifr_flags = IFF_TUN
                };

                if(ioctl(in, TUNSETIFF, &ifr)) {
                        fprintf(stderr, "Could not configure tun interface: %s\n", strerror(errno));
                        close(in);
                        return 1;
                }

                ifr.ifr_name[IFNAMSIZ - 1] = 0;
                fprintf(stderr, "Using tun interface %s\n", ifr.ifr_name);
        }

#endif

#ifdef HAVE_WINDOWS
        static struct WSAData wsa_state;

        if(WSAStartup(MAKEWORD(2, 2), &wsa_state)) {
                return 1;
        }

#endif

        struct addrinfo *ai = NULL;
        struct addrinfo hint;
        memset(&hint, 0, sizeof(hint));

        hint.ai_family = addressfamily;
        hint.ai_socktype = datagram ? SOCK_DGRAM : SOCK_STREAM;
        hint.ai_protocol = datagram ? IPPROTO_UDP : IPPROTO_TCP;
        hint.ai_flags = initiator ? 0 : AI_PASSIVE;

        int gai_err = getaddrinfo(initiator ? argv[3] : NULL, initiator ? argv[4] : argv[3], &hint, &ai);

        if(gai_err || !ai) {
                // getaddrinfo() reports failures through its return value, not errno.
                fprintf(stderr, "getaddrinfo() failed: %s\n", gai_err ? gai_strerror(gai_err) : "no address found");
                return 1;
        }

        int sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);

        if(sock < 0) {
                fprintf(stderr, "Could not create socket: %s\n", sockstrerror(sockerrno));
                freeaddrinfo(ai);
                return 1;
        }

        int one = 1;
        setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (void *)&one, sizeof(one));

        if(initiator) {
                int res = connect(sock, ai->ai_addr, ai->ai_addrlen);

                freeaddrinfo(ai);
                ai = NULL;

                if(res) {
                        fprintf(stderr, "Could not connect to peer: %s\n", sockstrerror(sockerrno));
                        return 1;
                }

                fprintf(stderr, "Connected\n");
        } else {
                int res = bind(sock, ai->ai_addr, ai->ai_addrlen);

                freeaddrinfo(ai);
                ai = NULL;

                if(res) {
                        fprintf(stderr, "Could not bind socket: %s\n", sockstrerror(sockerrno));
                        return 1;
                }

                if(!datagram) {
                        if(listen(sock, 1)) {
                                fprintf(stderr, "Could not listen on socket: %s\n", sockstrerror(sockerrno));
                                return 1;
                        }

                        print_listening_msg(sock);

                        int listen_sock = sock;
                        sock = accept(listen_sock, NULL, NULL);

                        if(sock < 0) {
                                fprintf(stderr, "Could not accept connection: %s\n", sockstrerror(sockerrno));
                                closesocket(listen_sock);
                                return 1;
                        }

                        // Only one peer is served; stop listening instead of leaving
                        // the port open for connections nobody will handle.
                        closesocket(listen_sock);
                } else {
                        print_listening_msg(sock);

                        // Peek at the first datagram to learn the peer's address, then
                        // connect() the socket to it. sockaddr_t is large enough for
                        // IPv6 addresses, unlike a plain struct sockaddr.
                        char buf[65536];
                        sockaddr_t addr;
                        socklen_t addrlen = sizeof(addr);

                        if(recvfrom(sock, buf, sizeof(buf), MSG_PEEK, &addr.sa, &addrlen) <= 0) {
                                fprintf(stderr, "Could not read from socket: %s\n", sockstrerror(sockerrno));
                                return 1;
                        }

                        if(connect(sock, &addr.sa, addrlen)) {
                                fprintf(stderr, "Could not accept connection: %s\n", sockstrerror(sockerrno));
                                return 1;
                        }
                }

                fprintf(stderr, "Connected\n");
        }

        ecdsa_t *mykey = load_key_file(argv[1], true);

        if(!mykey) {
                return 1;
        }

        ecdsa_t *hiskey = load_key_file(argv[2], false);

        if(!hiskey) {
                ecdsa_free(mykey);
                return 1;
        }

        if(verbose) {
                fprintf(stderr, "Keys loaded\n");
        }

        sptps_t s;

        sptps_params_t params = {
                .handle = &sock,
                .initiator = initiator,
                .datagram = datagram,
                .mykey = mykey,
                .hiskey = hiskey,
                .label = "sptps_test",
                .send_data = send_data,
                .receive_record = receive_record,
                .cipher_suites = cipher_suites,
                .preferred_suite = preferred_suite,
        };

        if(!sptps_start(&s, &params)) {
                ecdsa_free(mykey);
                ecdsa_free(hiskey);
                return 1;
        }

#ifdef HAVE_WINDOWS

        if(!readonly) {
                in = start_input_reader();

                if(in < 0) {
                        fprintf(stderr, "Could not init stdin reader thread\n");
                        ecdsa_free(mykey);
                        ecdsa_free(hiskey);
                        return 1;
                }
        }

#endif

        int max_fd = MAX(sock, in);

        while(true) {
                if(writeonly && readonly) {
                        break;
                }

                // Zeroed on every iteration. Stream-mode input reads stop one byte
                // short of the buffer, so the data is always NUL-terminated for the
                // atoi() used by the '#' special command.
                char buf[65535] = "";
                size_t readsize = datagram ? 1460u : sizeof(buf) - 1;

                fd_set fds;
                FD_ZERO(&fds);

                // Only read input once the session is ready to carry application data.
                if(!readonly && s.instate) {
                        FD_SET(in, &fds);
                }

                FD_SET(sock, &fds);

                if(select(max_fd + 1, &fds, NULL, NULL, NULL) <= 0) {
                        ecdsa_free(mykey);
                        ecdsa_free(hiskey);
                        return 1;
                }

                if(FD_ISSET(in, &fds)) {
#ifdef HAVE_WINDOWS
                        ssize_t len = recv(in, buf, readsize, 0);
#else
                        ssize_t len = read(in, buf, readsize);
#endif

                        if(len < 0) {
                                fprintf(stderr, "Could not read from stdin: %s\n", strerror(errno));
                                ecdsa_free(mykey);
                                ecdsa_free(hiskey);
                                return 1;
                        }

                        if(len == 0) {
#ifdef HAVE_WINDOWS
                                shutdown(in, SD_SEND);
                                closesocket(in);
#endif

                                if(quit) {
                                        break;
                                }

                                readonly = true;
                                continue;
                        }

                        if(special && buf[0] == '#') {
                                s.outseqno = atoi(buf + 1);
                        }

                        if(special && buf[0] == '^') {
                                sptps_send_record(&s, SPTPS_HANDSHAKE, NULL, 0);
                        } else if(special && buf[0] == '$') {
                                sptps_force_kex(&s);

                                if(len > 1) {
                                        sptps_send_record(&s, 0, buf, len);
                                }
                        } else if(!sptps_send_record(&s, buf[0] == '!' ? 1 : 0, buf, (len == 1 && buf[0] == '\n') ? 0 : buf[0] == '*' ? sizeof(buf) : (size_t)len)) {
                                ecdsa_free(mykey);
                                ecdsa_free(hiskey);
                                return 1;
                        }
                }

                if(FD_ISSET(sock, &fds)) {
                        ssize_t len = recv(sock, buf, sizeof(buf), 0);

                        if(len < 0) {
                                fprintf(stderr, "Could not read from socket: %s\n", sockstrerror(sockerrno));
                                ecdsa_free(mykey);
                                ecdsa_free(hiskey);
                                return 1;
                        }

                        if(len == 0) {
                                fprintf(stderr, "Connection terminated by peer.\n");
                                break;
                        }

                        if(verbose) {
                                char *hex = alloca(len * 2 + 1);
                                bin2hex(buf, hex, len);
                                fprintf(stderr, "Received %ld bytes of data:\n%s\n", (long)len, hex);
                        }

                        if(packetloss && (int)prng(100) < packetloss) {
                                if(verbose) {
                                        fprintf(stderr, "Dropped.\n");
                                }

                                continue;
                        }

                        char *bufp = buf;

                        while(len) {
                                size_t done = sptps_receive_data(&s, bufp, len);

                                if(!done) {
                                        if(!datagram) {
                                                ecdsa_free(mykey);
                                                ecdsa_free(hiskey);
                                                return 1;
                                        }

                                        // A rejected datagram is simply dropped. Without this
                                        // break the loop would spin forever, since neither
                                        // bufp nor len advances when nothing was consumed.
                                        break;
                                }

                                bufp += done;
                                len -= (ssize_t) done;
                        }
                }
        }

        bool stopped = sptps_stop(&s);

        ecdsa_free(mykey);
        ecdsa_free(hiskey);
        closesocket(sock);

        return !stopped;
}
764
// Program entry point: bring up the random, crypto and PRNG subsystems that
// run_test() relies on, run the test, then shut the random subsystem down
// again before returning run_test()'s result as the exit status.
int main(int argc, char *argv[]) {
        random_init();
        crypto_init();
        prng_init();

        const int exit_status = run_test(argc, argv);

        random_exit();
        return exit_status;
}