Enable and fix many extra warnings supported by GCC and Clang.
[tinc] / src / sptps_test.c
1 /*
2     sptps_test.c -- Simple Peer-to-Peer Security test program
3     Copyright (C) 2011-2022 Guus Sliepen <guus@tinc-vpn.org>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #ifdef HAVE_LINUX
23 #include <linux/if_tun.h>
24 #endif
25
26 #include <getopt.h>
27
28 #ifdef HAVE_MINGW
29 #include <pthread.h>
30 #endif
31
32 #include "crypto.h"
33 #include "ecdsa.h"
34 #include "meta.h"
35 #include "protocol.h"
36 #include "sptps.h"
37 #include "utils.h"
38
39 #ifndef HAVE_MINGW
40 #define closesocket(s) close(s)
41 #endif
42
43 // Symbols necessary to link with logger.o
// Link-time stub: ignores all of its arguments and always reports failure.
bool send_request(struct connection_t *c, const char *msg, ...) {
	(void)c, (void)msg;
	return false;
}
49
// Defined only so the program links (see the logger.o note above); nothing in
// this file reads or writes it.
list_t connection_list;
51
// Link-time stub: ignores its arguments and always reports failure.
bool send_meta(struct connection_t *c, const void *msg, size_t len) {
	(void)c, (void)msg, (void)len;
	return false;
}
57 }
58
// Non-static globals; presumably referenced by other objects linked into this
// program (cf. the logger.o note above) — confirm before removing any of them.
char *logfilename = NULL;
bool do_detach = false;
struct timeval now;

// Behaviour switches, all off by default and turned on from main()'s options.
static bool special = false;   // -s: leading '#', '^' or '$' in input is handled specially
static bool verbose = false;   // -v: print debug messages to stderr
static bool readonly = false;  // -r: do not forward local input to the peer
static bool writeonly = false; // -w: do not write received records to the output

// Local data source and sink: stdin/stdout unless main() replaces them
// (tun device on Linux, relay socket on MinGW).
static int in = 0;
static int out = 1;

// Address family passed to getaddrinfo(); changed by -4 / -6.
int addressfamily = AF_UNSPEC;
70
71 static bool send_data(void *handle, uint8_t type, const void *data, size_t len) {
72         (void)type;
73         char hex[len * 2 + 1];
74         bin2hex(data, hex, len);
75
76         if(verbose) {
77                 fprintf(stderr, "Sending %lu bytes of data:\n%s\n", (unsigned long)len, hex);
78         }
79
80         const int *sock = handle;
81         const char *p = data;
82
83         while(len) {
84                 ssize_t sent = send(*sock, p, len, 0);
85
86                 if(sent <= 0) {
87                         fprintf(stderr, "Error sending data: %s\n", strerror(errno));
88                         return false;
89                 }
90
91                 p += sent;
92                 len -= sent;
93         }
94
95         return true;
96 }
97
98 static bool receive_record(void *handle, uint8_t type, const void *data, uint16_t len) {
99         (void)handle;
100
101         if(verbose) {
102                 fprintf(stderr, "Received type %d record of %u bytes:\n", type, len);
103         }
104
105         if(writeonly) {
106                 return true;
107         }
108
109         const char *p = data;
110
111         while(len) {
112                 ssize_t written = write(out, p, len);
113
114                 if(written <= 0) {
115                         fprintf(stderr, "Error writing received data: %s\n", strerror(errno));
116                         return false;
117                 }
118
119                 p += written;
120                 len -= written;
121         }
122
123         return true;
124 }
125
/*
 * Long option table for getopt_long(). It must agree with the short option
 * string "dqrstwL:W:v46" used in main(): every entry returns the same value
 * as its short form and takes an argument only if the short form does
 * (only --help, value 1, has no short form).
 */
static struct option const long_options[] = {
	{"datagram", no_argument, NULL, 'd'},
	{"quit", no_argument, NULL, 'q'},
	{"readonly", no_argument, NULL, 'r'},
	{"tun", no_argument, NULL, 't'},
	{"writeonly", no_argument, NULL, 'w'},
	{"packet-loss", required_argument, NULL, 'L'},
	{"replay-window", required_argument, NULL, 'W'},
	{"special", no_argument, NULL, 's'},
	// -v takes no argument; required_argument here made --verbose swallow
	// the following command-line word (e.g. the key file name).
	{"verbose", no_argument, NULL, 'v'},
	{"help", no_argument, NULL, 1},
	{NULL, 0, NULL, 0}
};
138
// argv[0], recorded by main() for use in the usage message.
const char *program_name;

/*
 * Print the command-line synopsis and the list of options to stderr.
 */
static void usage(void) {
	static const char *const options_help =
		"\n"
		"Valid options are:\n"
		"  -d, --datagram          Enable datagram mode.\n"
		"  -q, --quit              Quit when EOF occurs on stdin.\n"
		"  -r, --readonly          Only send data from the socket to stdout.\n"
#ifdef HAVE_LINUX
		"  -t, --tun               Use a tun device instead of stdio.\n"
#endif
		"  -w, --writeonly         Only send data from stdin to the socket.\n"
		"  -L, --packet-loss RATE  Fake packet loss of RATE percent.\n"
		"  -W, --replay-window N   Set replay window to N bytes.\n"
		"  -s, --special           Enable special handling of lines starting with #, ^ and $.\n"
		"  -v, --verbose           Display debug messages.\n"
		"  -4                      Use IPv4.\n"
		"  -6                      Use IPv6.\n"
		"\n"
		"Report bugs to tinc@tinc-vpn.org.\n";

	// Only the synopsis line needs formatting; keep its format a string
	// literal so the rest of the text is never interpreted as a format.
	fprintf(stderr, "Usage: %s [options] my_ed25519_key_file his_ed25519_key_file [host] port\n", program_name);
	fputs(options_help, stderr);
}
164
#ifdef HAVE_MINGW

// Listening socket of the stdin relay below; -1 while no relay is set up.
static int stdin_sock_fd = -1;

// Windows does not allow calling select() on anything but sockets. Therefore,
// to keep the same code as on other operating systems, we have to put a
// separate thread between the stdin and the sptps loop way below. This thread
// reads stdin and sends its content to the main thread through a TCP socket,
// which can be properly select()'ed.
static void *stdin_reader_thread(void *arg) {
	(void)arg;

	while(true) {
		struct sockaddr_in sa;
		// accept() updates the length in place, so reset it for every call.
		socklen_t sa_size = sizeof(sa);
		int peer_fd = accept(stdin_sock_fd, (struct sockaddr *) &sa, &sa_size);

		if(peer_fd < 0) {
			fprintf(stderr, "accept() failed: %s\n", sockstrerror(sockerrno));
			continue;
		}

		if(verbose) {
			fprintf(stderr, "New connection received from :%d\n", ntohs(sa.sin_port));
		}

		char buf[1024];
		ssize_t nread;

		while((nread = read(STDIN_FILENO, buf, sizeof(buf))) > 0) {
			if(verbose) {
				// ssize_t may be narrower than long long (e.g. 32-bit
				// targets), so cast explicitly to match %lld.
				fprintf(stderr, "Read %lld bytes from input\n", (long long)nread);
			}

			char *start = buf;
			ssize_t nleft = nread;

			while(nleft) {
				ssize_t nsend = send(peer_fd, start, nleft, 0);

				if(nsend < 0) {
					if(sockwouldblock(sockerrno)) {
						continue;
					}

					break;
				}

				start += nsend;
				nleft -= nsend;
			}

			if(nleft) {
				fprintf(stderr, "Could not send data: %s\n", sockstrerror(sockerrno));
				break;
			}

			if(verbose) {
				fprintf(stderr, "Sent %lld bytes to peer\n", (long long)nread);
			}
		}

		closesocket(peer_fd);
	}

	// Not reached: the accept loop above never terminates.
	closesocket(stdin_sock_fd);
	stdin_sock_fd = -1;
	return NULL;
}

// Start the stdin relay: listen on an ephemeral loopback port, spawn
// stdin_reader_thread() to serve it, and connect to it. Returns the connected
// client socket (which the main loop reads with recv()), or -1 on failure.
static int start_input_reader(void) {
	if(stdin_sock_fd != -1) {
		fprintf(stderr, "stdin thread can only be started once.\n");
		return -1;
	}

	stdin_sock_fd = socket(AF_INET, SOCK_STREAM, 0);

	if(stdin_sock_fd < 0) {
		fprintf(stderr, "Could not create server socket: %s\n", sockstrerror(sockerrno));
		return -1;
	}

	struct sockaddr_in serv_sa;

	memset(&serv_sa, 0, sizeof(serv_sa));

	serv_sa.sin_family = AF_INET;

	serv_sa.sin_addr.s_addr = htonl(0x7f000001); // 127.0.0.1; port 0 lets the OS choose

	int res = bind(stdin_sock_fd, (struct sockaddr *)&serv_sa, sizeof(serv_sa));

	if(res < 0) {
		fprintf(stderr, "Could not bind socket: %s\n", sockstrerror(sockerrno));
		goto server_err;
	}

	if(listen(stdin_sock_fd, 1) < 0) {
		fprintf(stderr, "Could not listen: %s\n", sockstrerror(sockerrno));
		goto server_err;
	}

	struct sockaddr_in connect_sa;

	socklen_t addr_len = sizeof(connect_sa);

	// Find out which port was actually assigned so we can connect to it.
	if(getsockname(stdin_sock_fd, (struct sockaddr *)&connect_sa, &addr_len) < 0) {
		fprintf(stderr, "Could not determine the address of the stdin thread socket: %s\n", sockstrerror(sockerrno));
		goto server_err;
	}

	if(verbose) {
		fprintf(stderr, "stdin thread is listening on :%d\n", ntohs(connect_sa.sin_port));
	}

	pthread_t th;
	int err = pthread_create(&th, NULL, stdin_reader_thread, NULL);

	if(err) {
		fprintf(stderr, "Could not start reader thread: %s\n", strerror(err));
		goto server_err;
	}

	int client_fd = socket(AF_INET, SOCK_STREAM, 0);

	if(client_fd < 0) {
		fprintf(stderr, "Could not create client socket: %s\n", sockstrerror(sockerrno));
		return -1;
	}

	if(connect(client_fd, (struct sockaddr *)&connect_sa, sizeof(connect_sa)) < 0) {
		fprintf(stderr, "Could not connect: %s\n", sockstrerror(sockerrno));
		closesocket(client_fd);
		return -1;
	}

	return client_fd;

server_err:

	if(stdin_sock_fd != -1) {
		closesocket(stdin_sock_fd);
		stdin_sock_fd = -1;
	}

	return -1;
}

#endif // HAVE_MINGW
314
/*
 * Entry point of the SPTPS test program.
 *
 * Parses the options, then either connects to [host] port (initiator, when a
 * host is given) or binds to port and waits for one peer (responder), using
 * TCP or, with -d, UDP. It loads our private and the peer's public Ed25519
 * key, starts an SPTPS session, and then shuttles data between the local
 * input/output and the socket. This continues until EOF with -q, the peer
 * closes the connection, or an error occurs.
 *
 * Returns 0 when the loop ends normally and sptps_stop() succeeds, 1 otherwise.
 */
int main(int argc, char *argv[]) {
	program_name = argv[0];
	bool initiator = false;
	bool datagram = false;
#ifdef HAVE_LINUX
	bool tun = false;
#endif
	int packetloss = 0;
	int r;
	int option_index = 0;
	bool quit = false;

	while((r = getopt_long(argc, argv, "dqrstwL:W:v46", long_options, &option_index)) != EOF) {
		switch(r) {
		case 0:   /* long option */
			break;

		case 'd': /* datagram mode */
			datagram = true;
			break;

		case 'q': /* close connection on EOF from stdin */
			quit = true;
			break;

		case 'r': /* read only */
			readonly = true;
			break;

		case 't': /* use a tun device instead of stdio */
#ifdef HAVE_LINUX
			tun = true;
#else
			fprintf(stderr, "--tun is only supported on Linux.\n");
			usage();
			return 1;
#endif
			break;

		case 'w': /* write only */
			writeonly = true;
			break;

		case 'L': /* packet loss rate */
			packetloss = atoi(optarg);
			break;

		case 'W': /* replay window size */
			sptps_replaywin = atoi(optarg);
			break;

		case 'v': /* be verbose */
			verbose = true;
			break;

		case 's': /* special character handling */
			special = true;
			break;

		case '?': /* wrong options */
			usage();
			return 1;

		case '4': /* IPv4 */
			addressfamily = AF_INET;
			break;

		case '6': /* IPv6 */
			addressfamily = AF_INET6;
			break;

		case 1: /* help */
			usage();
			return 0;

		default:
			break;
		}
	}

	// Shift so that argv[1] is the first positional argument.
	argc -= optind - 1;
	argv += optind - 1;

	if(argc < 4 || argc > 5) {
		fprintf(stderr, "Wrong number of arguments.\n");
		usage();
		return 1;
	}

	// A host argument (5 words total) makes us the connecting side.
	if(argc > 4) {
		initiator = true;
	}

#ifdef HAVE_LINUX

	if(tun) {
		in = out = open("/dev/net/tun", O_RDWR | O_NONBLOCK);

		if(in < 0) {
			fprintf(stderr, "Could not open tun device: %s\n", strerror(errno));
			return 1;
		}

		struct ifreq ifr = {
			.ifr_flags = IFF_TUN
		};

		if(ioctl(in, TUNSETIFF, &ifr)) {
			fprintf(stderr, "Could not configure tun interface: %s\n", strerror(errno));
			return 1;
		}

		ifr.ifr_name[IFNAMSIZ - 1] = 0;
		fprintf(stderr, "Using tun interface %s\n", ifr.ifr_name);
	}

#endif

#ifdef HAVE_MINGW
	static struct WSAData wsa_state;

	if(WSAStartup(MAKEWORD(2, 2), &wsa_state)) {
		return 1;
	}

#endif

	struct addrinfo *ai, hint;
	memset(&hint, 0, sizeof(hint));

	hint.ai_family = addressfamily;
	hint.ai_socktype = datagram ? SOCK_DGRAM : SOCK_STREAM;
	hint.ai_protocol = datagram ? IPPROTO_UDP : IPPROTO_TCP;
	hint.ai_flags = initiator ? 0 : AI_PASSIVE;

	// getaddrinfo() reports failures through its return value, not errno.
	int gai_err = getaddrinfo(initiator ? argv[3] : NULL, initiator ? argv[4] : argv[3], &hint, &ai);

	if(gai_err || !ai) {
		fprintf(stderr, "getaddrinfo() failed: %s\n", gai_err ? gai_strerror(gai_err) : "no address found");
		return 1;
	}

	int sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);

	if(sock < 0) {
		fprintf(stderr, "Could not create socket: %s\n", sockstrerror(sockerrno));
		freeaddrinfo(ai);
		return 1;
	}

	// From here on, every exit goes through the cleanup label, which frees
	// the keys and closes the socket.
	int ret = 1;
	ecdsa_t *mykey = NULL;
	ecdsa_t *hiskey = NULL;

	int one = 1;
	setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (void *)&one, sizeof(one));

	if(initiator) {
		int res = connect(sock, ai->ai_addr, ai->ai_addrlen);

		freeaddrinfo(ai);
		ai = NULL;

		if(res) {
			fprintf(stderr, "Could not connect to peer: %s\n", sockstrerror(sockerrno));
			goto cleanup;
		}

		fprintf(stderr, "Connected\n");
	} else {
		int res = bind(sock, ai->ai_addr, ai->ai_addrlen);

		freeaddrinfo(ai);
		ai = NULL;

		if(res) {
			fprintf(stderr, "Could not bind socket: %s\n", sockstrerror(sockerrno));
			goto cleanup;
		}

		if(!datagram) {
			if(listen(sock, 1)) {
				fprintf(stderr, "Could not listen on socket: %s\n", sockstrerror(sockerrno));
				goto cleanup;
			}

			fprintf(stderr, "Listening...\n");

			// Only one peer is served, so the listening socket is closed
			// once the connection has been accepted (it used to leak).
			int listen_sock = sock;
			sock = accept(listen_sock, NULL, NULL);

			if(sock < 0) {
				fprintf(stderr, "Could not accept connection: %s\n", sockstrerror(sockerrno));
				closesocket(listen_sock);
				goto cleanup;
			}

			closesocket(listen_sock);
		} else {
			fprintf(stderr, "Listening...\n");

			// Peek at the first datagram to learn the peer's address, then
			// connect() so that plain send()/recv() talk only to that peer.
			char buf[65536];
			// sockaddr_storage fits any address family; a plain struct
			// sockaddr is too small for an IPv6 peer, and connect() would
			// then read past its end.
			struct sockaddr_storage addr;
			socklen_t addrlen = sizeof(addr);

			if(recvfrom(sock, buf, sizeof(buf), MSG_PEEK, (struct sockaddr *)&addr, &addrlen) <= 0) {
				fprintf(stderr, "Could not read from socket: %s\n", sockstrerror(sockerrno));
				goto cleanup;
			}

			if(connect(sock, (struct sockaddr *)&addr, addrlen)) {
				fprintf(stderr, "Could not accept connection: %s\n", sockstrerror(sockerrno));
				goto cleanup;
			}
		}

		fprintf(stderr, "Connected\n");
	}

	crypto_init();
	prng_init();

	FILE *fp = fopen(argv[1], "r");

	if(!fp) {
		fprintf(stderr, "Could not open %s: %s\n", argv[1], strerror(errno));
		goto cleanup;
	}

	mykey = ecdsa_read_pem_private_key(fp);
	fclose(fp);

	if(!mykey) {
		goto cleanup;
	}

	fp = fopen(argv[2], "r");

	if(!fp) {
		fprintf(stderr, "Could not open %s: %s\n", argv[2], strerror(errno));
		goto cleanup;
	}

	hiskey = ecdsa_read_pem_public_key(fp);
	fclose(fp);

	if(!hiskey) {
		goto cleanup;
	}

	if(verbose) {
		fprintf(stderr, "Keys loaded\n");
	}

	sptps_t s;

	if(!sptps_start(&s, &sock, initiator, datagram, mykey, hiskey, "sptps_test", 10, send_data, receive_record)) {
		goto cleanup;
	}

#ifdef HAVE_MINGW

	if(!readonly) {
		in = start_input_reader();

		if(in < 0) {
			fprintf(stderr, "Could not init stdin reader thread\n");
			goto stop_sptps;
		}
	}

#endif

	int max_fd = MAX(sock, in);

	while(true) {
		if(writeonly && readonly) {
			break;
		}

		char buf[65535] = "";
		size_t readsize = datagram ? 1460u : sizeof(buf);

		fd_set fds;
		FD_ZERO(&fds);

		// Only accept local input once the session is able to send.
		if(!readonly && s.instate) {
			FD_SET(in, &fds);
		}

		FD_SET(sock, &fds);

		if(select(max_fd + 1, &fds, NULL, NULL, NULL) <= 0) {
			goto stop_sptps;
		}

		if(FD_ISSET(in, &fds)) {
#ifdef HAVE_MINGW
			ssize_t len = recv(in, buf, readsize, 0);
#else
			ssize_t len = read(in, buf, readsize);
#endif

			if(len < 0) {
				fprintf(stderr, "Could not read from stdin: %s\n", strerror(errno));
				goto stop_sptps;
			}

			if(len == 0) {
#ifdef HAVE_MINGW
				shutdown(in, SD_SEND);
				closesocket(in);
#endif

				if(quit) {
					break;
				}

				readonly = true;
				continue;
			}

			if(special && buf[0] == '#') {
				s.outseqno = atoi(buf + 1);
			}

			if(special && buf[0] == '^') {
				sptps_send_record(&s, SPTPS_HANDSHAKE, NULL, 0);
			} else if(special && buf[0] == '$') {
				sptps_force_kex(&s);

				if(len > 1) {
					sptps_send_record(&s, 0, buf, len);
				}
			} else if(!sptps_send_record(&s, buf[0] == '!' ? 1 : 0, buf, (len == 1 && buf[0] == '\n') ? 0 : buf[0] == '*' ? sizeof(buf) : (size_t)len)) {
				goto stop_sptps;
			}
		}

		if(FD_ISSET(sock, &fds)) {
			ssize_t len = recv(sock, buf, sizeof(buf), 0);

			if(len < 0) {
				fprintf(stderr, "Could not read from socket: %s\n", sockstrerror(sockerrno));
				goto stop_sptps;
			}

			if(len == 0) {
				fprintf(stderr, "Connection terminated by peer.\n");
				break;
			}

			if(verbose) {
				char hex[len * 2 + 1];
				bin2hex(buf, hex, len);
				fprintf(stderr, "Received %ld bytes of data:\n%s\n", (long)len, hex);
			}

			if(packetloss && (int)prng(100) < packetloss) {
				if(verbose) {
					fprintf(stderr, "Dropped.\n");
				}

				continue;
			}

			char *bufp = buf;

			while(len) {
				size_t done = sptps_receive_data(&s, bufp, len);

				if(!done) {
					if(!datagram) {
						goto stop_sptps;
					}

					// A rejected datagram consumes nothing, so feeding it
					// again would loop forever; drop it and move on.
					break;
				}

				bufp += done;
				len -= (ssize_t) done;
			}
		}
	}

	// Only a normal loop exit reaches this point.
	ret = 0;

stop_sptps:

	if(!sptps_stop(&s)) {
		ret = 1;
	}

cleanup:
	free(mykey);
	free(hiskey);

	if(sock >= 0) {
		closesocket(sock);
	}

	return ret;
}