]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - testcode/streamtcp.c
Apply upstream fix 08968baec1122a58bb90d8f97ad948a75f8a5d69:
[FreeBSD/FreeBSD.git] / testcode / streamtcp.c
1 /*
2  * testcode/streamtcp.c - debug program perform multiple DNS queries on tcp.
3  *
4  * Copyright (c) 2008, NLnet Labs. All rights reserved.
5  *
6  * This software is open source.
7  * 
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 
12  * Redistributions of source code must retain the above copyright notice,
13  * this list of conditions and the following disclaimer.
14  * 
15  * Redistributions in binary form must reproduce the above copyright notice,
16  * this list of conditions and the following disclaimer in the documentation
17  * and/or other materials provided with the distribution.
18  * 
19  * Neither the name of the NLNET LABS nor the names of its contributors may
20  * be used to endorse or promote products derived from this software without
21  * specific prior written permission.
22  * 
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27  * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
29  * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
30  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
31  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
32  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
33  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34  */
35
36 /**
37  * \file
38  *
39  * This program performs multiple DNS queries on a TCP stream.
40  */
41
#include "config.h"
#ifdef HAVE_GETOPT_H
#include <getopt.h>
#endif
#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include "util/locks.h"
#include "util/log.h"
#include "util/net_help.h"
#include "util/data/msgencode.h"
#include "util/data/msgparse.h"
#include "util/data/msgreply.h"
#include "util/data/dname.h"
#include "sldns/sbuffer.h"
#include "sldns/str2wire.h"
#include "sldns/wire2str.h"
#include <openssl/ssl.h>
#include <openssl/rand.h>
#include <openssl/err.h>

#ifndef PF_INET6
/** define in case streamtcp is compiled on legacy systems */
#define PF_INET6 10
#endif
67
/**
 * Print the command line help for streamtcp to stdout and terminate
 * the program with exit status 1.
 * @param argv: the program arguments; argv[0] is shown as the command name.
 */
static void usage(char* argv[])
{
	static const char* const help_text[] = {
		"        sends the name-type-class queries over TCP.\n",
		"-f server       what ipaddr@portnr to send the queries to\n",
		"-u              use UDP. No retries are attempted.\n",
		"-n              do not wait for an answer.\n",
		"-a              print answers as they arrive.\n",
		"-d secs         delay after connection before sending query\n",
		"-s              use ssl\n",
		"-h              this help text\n",
		NULL
	};
	size_t line;

	printf("usage: %s [options] name type class ...\n", argv[0]);
	for(line = 0; help_text[line] != NULL; line++)
		fputs(help_text[line], stdout);
	exit(1);
}
82
/**
 * Create a socket and connect it to the server.
 * @param svr: server specification, an IP address that may carry an
 *	@port suffix (parsed by extstrtoaddr).
 * @param udp: nonzero for a datagram socket, zero for a stream socket.
 * @return the connected socket; the program exits on any failure.
 */
static int
open_svr(const char* svr, int udp)
{
	struct sockaddr_storage addr;
	socklen_t addrlen;
	int family, socktype;
	int sock;

	memset(&addr, 0, sizeof(addr));
	if(!extstrtoaddr(svr, &addr, &addrlen)) {
		printf("fatal: bad server specs '%s'\n", svr);
		exit(1);
	}

	family = addr_is_ip6(&addr, addrlen) ? PF_INET6 : PF_INET;
	socktype = udp ? SOCK_DGRAM : SOCK_STREAM;
	sock = socket(family, socktype, 0);
	if(sock == -1) {
#ifndef USE_WINSOCK
		perror("socket() error");
#else
		printf("socket: %s\n", wsa_strerror(WSAGetLastError()));
#endif
		exit(1);
	}

	/* connect() is also used for UDP, so plain send()/recv() work */
	if(connect(sock, (struct sockaddr*)&addr, addrlen) < 0) {
#ifndef USE_WINSOCK
		perror("connect() error");
#else
		printf("connect: %s\n", wsa_strerror(WSAGetLastError()));
#endif
		exit(1);
	}
	return sock;
}
116
117 /** write a query over the TCP fd */
118 static void
119 write_q(int fd, int udp, SSL* ssl, sldns_buffer* buf, uint16_t id, 
120         const char* strname, const char* strtype, const char* strclass)
121 {
122         struct query_info qinfo;
123         uint16_t len;
124         /* qname */
125         qinfo.qname = sldns_str2wire_dname(strname, &qinfo.qname_len);
126         if(!qinfo.qname) {
127                 printf("cannot parse query name: '%s'\n", strname);
128                 exit(1);
129         }
130
131         /* qtype and qclass */
132         qinfo.qtype = sldns_get_rr_type_by_name(strtype);
133         qinfo.qclass = sldns_get_rr_class_by_name(strclass);
134
135         /* clear local alias */
136         qinfo.local_alias = NULL;
137
138         /* make query */
139         qinfo_query_encode(buf, &qinfo);
140         sldns_buffer_write_u16_at(buf, 0, id);
141         sldns_buffer_write_u16_at(buf, 2, BIT_RD);
142
143         if(1) {
144                 /* add EDNS DO */
145                 struct edns_data edns;
146                 memset(&edns, 0, sizeof(edns));
147                 edns.edns_present = 1;
148                 edns.bits = EDNS_DO;
149                 edns.udp_size = 4096;
150                 if(sldns_buffer_capacity(buf) >=
151                         sldns_buffer_limit(buf)+calc_edns_field_size(&edns))
152                         attach_edns_record(buf, &edns);
153         }
154
155         /* send it */
156         if(!udp) {
157                 len = (uint16_t)sldns_buffer_limit(buf);
158                 len = htons(len);
159                 if(ssl) {
160                         if(SSL_write(ssl, (void*)&len, (int)sizeof(len)) <= 0) {
161                                 log_crypto_err("cannot SSL_write");
162                                 exit(1);
163                         }
164                 } else {
165                         if(send(fd, (void*)&len, sizeof(len), 0) <
166                                 (ssize_t)sizeof(len)){
167 #ifndef USE_WINSOCK
168                                 perror("send() len failed");
169 #else
170                                 printf("send len: %s\n", 
171                                         wsa_strerror(WSAGetLastError()));
172 #endif
173                                 exit(1);
174                         }
175                 }
176         }
177         if(ssl) {
178                 if(SSL_write(ssl, (void*)sldns_buffer_begin(buf),
179                         (int)sldns_buffer_limit(buf)) <= 0) {
180                         log_crypto_err("cannot SSL_write");
181                         exit(1);
182                 }
183         } else {
184                 if(send(fd, (void*)sldns_buffer_begin(buf),
185                         sldns_buffer_limit(buf), 0) < 
186                         (ssize_t)sldns_buffer_limit(buf)) {
187 #ifndef USE_WINSOCK
188                         perror("send() data failed");
189 #else
190                         printf("send data: %s\n", wsa_strerror(WSAGetLastError()));
191 #endif
192                         exit(1);
193                 }
194         }
195
196         free(qinfo.qname);
197 }
198
199 /** receive DNS datagram over TCP and print it */
200 static void
201 recv_one(int fd, int udp, SSL* ssl, sldns_buffer* buf)
202 {
203         size_t i;
204         char* pktstr;
205         uint16_t len;
206         if(!udp) {
207                 if(ssl) {
208                         int sr = SSL_read(ssl, (void*)&len, (int)sizeof(len));
209                         if(sr == 0) {
210                                 printf("ssl: stream closed\n");
211                                 exit(1);
212                         }
213                         if(sr < 0) {
214                                 log_crypto_err("could not SSL_read");
215                                 exit(1);
216                         }
217                 } else {
218                         ssize_t r = recv(fd, (void*)&len, sizeof(len), 0);
219                         if(r == 0) {
220                                 printf("recv: stream closed\n");
221                                 exit(1);
222                         }       
223                         if(r < (ssize_t)sizeof(len)) {
224 #ifndef USE_WINSOCK
225                                 perror("read() len failed");
226 #else
227                                 printf("read len: %s\n", 
228                                         wsa_strerror(WSAGetLastError()));
229 #endif
230                                 exit(1);
231                         }
232                 }
233                 len = ntohs(len);
234                 sldns_buffer_clear(buf);
235                 sldns_buffer_set_limit(buf, len);
236                 if(ssl) {
237                         int r = SSL_read(ssl, (void*)sldns_buffer_begin(buf),
238                                 (int)len);
239                         if(r <= 0) {
240                                 log_crypto_err("could not SSL_read");
241                                 exit(1);
242                         }
243                         if(r != (int)len)
244                                 fatal_exit("ssl_read %d of %d", r, len);
245                 } else {
246                         if(recv(fd, (void*)sldns_buffer_begin(buf), len, 0) < 
247                                 (ssize_t)len) {
248 #ifndef USE_WINSOCK
249                                 perror("read() data failed");
250 #else
251                                 printf("read data: %s\n", 
252                                         wsa_strerror(WSAGetLastError()));
253 #endif
254                                 exit(1);
255                         }
256                 }
257         } else {
258                 ssize_t l;
259                 sldns_buffer_clear(buf);
260                 if((l=recv(fd, (void*)sldns_buffer_begin(buf), 
261                         sldns_buffer_capacity(buf), 0)) < 0) {
262 #ifndef USE_WINSOCK
263                         perror("read() data failed");
264 #else
265                         printf("read data: %s\n", 
266                                 wsa_strerror(WSAGetLastError()));
267 #endif
268                         exit(1);
269                 }
270                 sldns_buffer_set_limit(buf, (size_t)l);
271                 len = (size_t)l;
272         }
273         printf("\nnext received packet\n");
274         printf("data[%d] ", (int)sldns_buffer_limit(buf));
275         for(i=0; i<sldns_buffer_limit(buf); i++) {
276                 const char* hex = "0123456789ABCDEF";
277                 printf("%c%c", hex[(sldns_buffer_read_u8_at(buf, i)&0xf0)>>4],
278                         hex[sldns_buffer_read_u8_at(buf, i)&0x0f]);
279         }
280         printf("\n");
281
282         pktstr = sldns_wire2str_pkt(sldns_buffer_begin(buf), len);
283         printf("%s", pktstr);
284         free(pktstr);
285 }
286
287 /** see if we can receive any results */
288 static void
289 print_any_answers(int fd, int udp, SSL* ssl, sldns_buffer* buf,
290         int* num_answers, int wait_all)
291 {
292         /* see if the fd can read, if so, print one answer, repeat */
293         int ret;
294         struct timeval tv, *waittv;
295         fd_set rfd;
296         while(*num_answers > 0) {
297                 memset(&rfd, 0, sizeof(rfd));
298                 memset(&tv, 0, sizeof(tv));
299                 FD_ZERO(&rfd);
300                 FD_SET(fd, &rfd);
301                 if(wait_all) waittv = NULL;
302                 else waittv = &tv;
303                 ret = select(fd+1, &rfd, NULL, NULL, waittv);
304                 if(ret < 0) {
305                         if(errno == EINTR || errno == EAGAIN) continue;
306                         perror("select() failed");
307                         exit(1);
308                 }
309                 if(ret == 0) {
310                         if(wait_all) continue;
311                         return;
312                 }
313                 (*num_answers) -= 1;
314                 recv_one(fd, udp, ssl, buf);
315         }
316 }
317
/**
 * Produce a random number, used as the query ID.
 * Takes the value from OpenSSL's RAND_bytes(); if that call does not
 * report success, arc4random() supplies the value instead.
 * @return the random value.
 */
static int get_random(void)
{
	int value;
	if(RAND_bytes((unsigned char*)&value, (int)sizeof(value)) != 1)
		value = (int)arc4random();
	return value;
}
326
327 /** send the TCP queries and print answers */
328 static void
329 send_em(const char* svr, int udp, int usessl, int noanswer, int onarrival,
330         int delay, int num, char** qs)
331 {
332         sldns_buffer* buf = sldns_buffer_new(65553);
333         int fd = open_svr(svr, udp);
334         int i, wait_results = 0;
335         SSL_CTX* ctx = NULL;
336         SSL* ssl = NULL;
337         if(!buf) fatal_exit("out of memory");
338         if(usessl) {
339                 ctx = connect_sslctx_create(NULL, NULL, NULL, 0);
340                 if(!ctx) fatal_exit("cannot create ssl ctx");
341                 ssl = outgoing_ssl_fd(ctx, fd);
342                 if(!ssl) fatal_exit("cannot create ssl");
343                 while(1) {
344                         int r;
345                         ERR_clear_error();
346                         if( (r=SSL_do_handshake(ssl)) == 1)
347                                 break;
348                         r = SSL_get_error(ssl, r);
349                         if(r != SSL_ERROR_WANT_READ &&
350                                 r != SSL_ERROR_WANT_WRITE) {
351                                 log_crypto_err("could not ssl_handshake");
352                                 exit(1);
353                         }
354                 }
355                 if(1) {
356                         X509* x = SSL_get_peer_certificate(ssl);
357                         if(!x) printf("SSL: no peer certificate\n");
358                         else {
359                                 X509_print_fp(stdout, x);
360                                 X509_free(x);
361                         }
362                 }
363         }
364         for(i=0; i<num; i+=3) {
365                 if (delay != 0) {
366 #ifdef HAVE_SLEEP
367                         sleep((unsigned)delay);
368 #else
369                         Sleep(delay*1000);
370 #endif
371                 }
372                 printf("\nNext query is %s %s %s\n", qs[i], qs[i+1], qs[i+2]);
373                 write_q(fd, udp, ssl, buf, (uint16_t)get_random(), qs[i],
374                         qs[i+1], qs[i+2]);
375                 /* print at least one result */
376                 if(onarrival) {
377                         wait_results += 1; /* one more answer to fetch */
378                         print_any_answers(fd, udp, ssl, buf, &wait_results, 0);
379                 } else if(!noanswer) {
380                         recv_one(fd, udp, ssl, buf);
381                 }
382         }
383         if(onarrival)
384                 print_any_answers(fd, udp, ssl, buf, &wait_results, 1);
385
386         if(usessl) {
387                 SSL_shutdown(ssl);
388                 SSL_free(ssl);
389                 SSL_CTX_free(ctx);
390         }
391         sock_close(fd);
392         sldns_buffer_free(buf);
393         printf("orderly exit\n");
394 }
395
396 #ifdef SIGPIPE
397 /** SIGPIPE handler */
398 static RETSIGTYPE sigh(int sig)
399 {
400         if(sig == SIGPIPE) {
401                 printf("got SIGPIPE, remote connection gone\n");
402                 exit(1);
403         }
404         printf("Got unhandled signal %d\n", sig);
405         exit(1);
406 }
407 #endif /* SIGPIPE */
408
409 /** getopt global, in case header files fail to declare it. */
410 extern int optind;
411 /** getopt global, in case header files fail to declare it. */
412 extern char* optarg;
413
414 /** main program for streamtcp */
415 int main(int argc, char** argv) 
416 {
417         int c;
418         const char* svr = "127.0.0.1";
419         int udp = 0;
420         int noanswer = 0;
421         int onarrival = 0;
422         int usessl = 0;
423         int delay = 0;
424
425 #ifdef USE_WINSOCK
426         WSADATA wsa_data;
427         if(WSAStartup(MAKEWORD(2,2), &wsa_data) != 0) {
428                 printf("WSAStartup failed\n");
429                 return 1;
430         }
431 #endif
432
433         /* lock debug start (if any) */
434         log_init(0, 0, 0);
435         checklock_start();
436
437 #ifdef SIGPIPE
438         if(signal(SIGPIPE, &sigh) == SIG_ERR) {
439                 perror("could not install signal handler");
440                 return 1;
441         }
442 #endif
443
444         /* command line options */
445         if(argc == 1) {
446                 usage(argv);
447         }
448         while( (c=getopt(argc, argv, "af:hnsud:")) != -1) {
449                 switch(c) {
450                         case 'f':
451                                 svr = optarg;
452                                 break;
453                         case 'a':
454                                 onarrival = 1;
455                                 break;
456                         case 'n':
457                                 noanswer = 1;
458                                 break;
459                         case 'u':
460                                 udp = 1;
461                                 break;
462                         case 's':
463                                 usessl = 1;
464                                 break;
465                         case 'd':
466                                 if(atoi(optarg)==0 && strcmp(optarg,"0")!=0) {
467                                         printf("error parsing delay, "
468                                             "number expected: %s\n", optarg);
469                                         return 1;
470                                 }
471                                 delay = atoi(optarg);
472                                 break;
473                         case 'h':
474                         case '?':
475                         default:
476                                 usage(argv);
477                 }
478         }
479         argc -= optind;
480         argv += optind;
481
482         if(argc % 3 != 0) {
483                 printf("queries must be multiples of name,type,class\n");
484                 return 1;
485         }
486         if(usessl) {
487 #if OPENSSL_VERSION_NUMBER < 0x10100000 || !defined(HAVE_OPENSSL_INIT_SSL)
488                 ERR_load_SSL_strings();
489 #endif
490 #if OPENSSL_VERSION_NUMBER < 0x10100000 || !defined(HAVE_OPENSSL_INIT_CRYPTO)
491 #  ifndef S_SPLINT_S
492                 OpenSSL_add_all_algorithms();
493 #  endif
494 #else
495                 OPENSSL_init_crypto(OPENSSL_INIT_ADD_ALL_CIPHERS
496                         | OPENSSL_INIT_ADD_ALL_DIGESTS
497                         | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL);
498 #endif
499 #if OPENSSL_VERSION_NUMBER < 0x10100000 || !defined(HAVE_OPENSSL_INIT_SSL)
500                 (void)SSL_library_init();
501 #else
502                 (void)OPENSSL_init_ssl(OPENSSL_INIT_LOAD_SSL_STRINGS, NULL);
503 #endif
504         }
505         send_em(svr, udp, usessl, noanswer, onarrival, delay, argc, argv);
506         checklock_stop();
507 #ifdef USE_WINSOCK
508         WSACleanup();
509 #endif
510         return 0;
511 }