Can now also print data sent by the remote side.
[people/xl0/gpxe.git] / src / util / prototester.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <unistd.h>
4 #include <string.h>
5 #include <errno.h>
6 #include <time.h>
7 #include <sys/socket.h>
8 #include <sys/un.h>
9 #include <net/if.h>
10 #include <net/ethernet.h>
11 #include <netinet/in.h>
12 #include <arpa/inet.h>
13 #include <getopt.h>
14 #include <assert.h>
15
16 typedef int irq_action_t;
17
18 struct nic {
19         struct nic_operations   *nic_op;
20         unsigned char           *node_addr;
21         unsigned char           *packet;
22         unsigned int            packetlen;
23         void                    *priv_data;     /* driver private data */
24 };
25
26 struct nic_operations {
27         int ( *connect ) ( struct nic * );
28         int ( *poll ) ( struct nic *, int retrieve );
29         void ( *transmit ) ( struct nic *, const char *,
30                              unsigned int, unsigned int, const char * );
31         void ( *irq ) ( struct nic *, irq_action_t );
32 };
33
34 /*****************************************************************************
35  *
36  * Net device layer
37  *
38  */
39
40 #include "../proto/uip/uip_arp.h"
41
42 static unsigned char node_addr[ETH_ALEN];
43 static unsigned char packet[ETH_FRAME_LEN];
44 struct nic static_nic = {
45         .node_addr = node_addr,
46         .packet = packet,
47 };
48
49 /* Must be a macro because priv_data[] is of variable size */
50 #define alloc_netdevice( priv_size ) ( {        \
51         static char priv_data[priv_size];       \
52         static_nic.priv_data = priv_data;       \
53         &static_nic; } )
54
55 static int register_netdevice ( struct nic *nic ) {
56         struct uip_eth_addr hwaddr;
57
58         memcpy ( &hwaddr, nic->node_addr, sizeof ( hwaddr ) );
59         uip_setethaddr ( hwaddr );
60         return 0;
61 }
62
63 static inline void unregister_netdevice ( struct nic *nic ) {
64         /* Do nothing */
65 }
66
67 static inline void free_netdevice ( struct nic *nic ) {
68         /* Do nothing */
69 }
70
71 static int netdev_poll ( int retrieve, void **data, size_t *len ) {
72         int rc = static_nic.nic_op->poll ( &static_nic, retrieve );
73         *data = static_nic.packet;
74         *len = static_nic.packetlen;
75         return rc;
76 }
77
78 static void netdev_transmit ( const void *data, size_t len ) {
79         uint16_t type = ntohs ( *( ( uint16_t * ) ( data + 12 ) ) );
80         static_nic.nic_op->transmit ( &static_nic, data, type,
81                                       len - ETH_HLEN,
82                                       data + ETH_HLEN );
83 }
84
85 /*****************************************************************************
86  *
87  * Hijack device interface
88  *
89  * This requires a hijack daemon to be running
90  *
91  */
92
93 struct hijack {
94         int fd;
95 };
96
97 struct hijack_device {
98         char *name;
99         void *priv;
100 };
101
102 static inline void hijack_set_drvdata ( struct hijack_device *hijack_dev,
103                                         void *data ) {
104         hijack_dev->priv = data;
105 }
106
107 static inline void * hijack_get_drvdata ( struct hijack_device *hijack_dev ) {
108         return hijack_dev->priv;
109 }
110
111 static int hijack_poll ( struct nic *nic, int retrieve ) {
112         struct hijack *hijack = nic->priv_data;
113         fd_set fdset;
114         struct timeval tv;
115         int ready;
116         ssize_t len;
117
118         /* Poll for data */
119         FD_ZERO ( &fdset );
120         FD_SET ( hijack->fd, &fdset );
121         tv.tv_sec = 0;
122         tv.tv_usec = 500; /* 500us to avoid hogging CPU */
123         ready = select ( ( hijack->fd + 1 ), &fdset, NULL, NULL, &tv );
124         if ( ready < 0 ) {
125                 fprintf ( stderr, "select() failed: %s\n",
126                           strerror ( errno ) );
127                 return 0;
128         }
129         if ( ready == 0 )
130                 return 0;
131
132         /* Return if we're not retrieving data yet */
133         if ( ! retrieve )
134                 return 1;
135
136         /* Fetch data */
137         len = read ( hijack->fd, nic->packet, ETH_FRAME_LEN );
138         if ( len < 0 ) {
139                 fprintf ( stderr, "read() failed: %s\n",
140                           strerror ( errno ) );
141                 return 0;
142         }
143         nic->packetlen = len;
144
145         return 1;
146 }
147
148 static void hijack_transmit ( struct nic *nic, const char *dest,
149                               unsigned int type, unsigned int size,
150                               const char *packet ) {
151         struct hijack *hijack = nic->priv_data;
152         unsigned int nstype = htons ( type );
153         unsigned int total_size = ETH_HLEN + size;
154         char txbuf[ total_size ];
155
156         /* Build packet header */
157         memcpy ( txbuf, dest, ETH_ALEN );
158         memcpy ( txbuf + ETH_ALEN, nic->node_addr, ETH_ALEN );
159         memcpy ( txbuf + 2 * ETH_ALEN, &nstype, 2 );
160         memcpy ( txbuf + ETH_HLEN, packet, size );
161
162         /* Transmit data */
163         if ( write ( hijack->fd, txbuf, total_size ) != total_size ) {
164                 fprintf ( stderr, "write() failed: %s\n",
165                           strerror ( errno ) );
166         }
167 }
168
169 static int hijack_connect ( struct nic *nic ) {
170         return 1;
171 }
172
173 static void hijack_irq ( struct nic *nic, irq_action_t action ) {
174         /* Do nothing */
175 }
176
177 static struct nic_operations hijack_operations = {
178         .connect        = hijack_connect,
179         .transmit       = hijack_transmit,
180         .poll           = hijack_poll,
181         .irq            = hijack_irq,
182 };
183
184 int hijack_probe ( struct hijack_device *hijack_dev ) {
185         struct nic *nic;
186         struct hijack *hijack;
187         struct sockaddr_un sun;
188         int i;
189
190         nic = alloc_netdevice ( sizeof ( *hijack ) );
191         if ( ! nic ) {
192                 fprintf ( stderr, "alloc_netdevice() failed\n" );
193                 goto err_alloc;
194         }
195         hijack = nic->priv_data;
196         memset ( hijack, 0, sizeof ( *hijack ) );
197
198         /* Create socket */
199         hijack->fd = socket ( PF_UNIX, SOCK_SEQPACKET, 0 );
200         if ( hijack->fd < 0 ) {
201                 fprintf ( stderr, "socket() failed: %s\n",
202                           strerror ( errno ) );
203                 goto err;
204         }
205
206         /* Connect to hijack daemon */
207         sun.sun_family = AF_UNIX;
208         snprintf ( sun.sun_path, sizeof ( sun.sun_path ), "/var/run/hijack-%s",
209                    hijack_dev->name );
210         if ( connect ( hijack->fd, ( struct sockaddr * ) &sun,
211                        sizeof ( sun ) ) < 0 ) {
212                 fprintf ( stderr, "could not connect to %s: %s\n",
213                           sun.sun_path, strerror ( errno ) );
214                 goto err;
215         }
216
217         /* Generate MAC address */
218         srand ( time ( NULL ) );
219         for ( i = 0 ; i < ETH_ALEN ; i++ ) {
220                 nic->node_addr[i] = ( rand() & 0xff );
221         }
222         nic->node_addr[0] &= 0xfe; /* clear multicast bit */
223         nic->node_addr[0] |= 0x02; /* set "locally-assigned" bit */
224
225         nic->nic_op = &hijack_operations;
226         if ( register_netdevice ( nic ) < 0 )
227                 goto err;
228
229         hijack_set_drvdata ( hijack_dev, nic );
230         return 1;
231
232  err:
233         if ( hijack->fd >= 0 )
234                 close ( hijack->fd );
235         free_netdevice ( nic );
236  err_alloc:
237         return 0;
238 }
239
240 static void hijack_disable ( struct hijack_device *hijack_dev ) {
241         struct nic *nic = hijack_get_drvdata ( hijack_dev );
242         struct hijack *hijack = nic->priv_data;
243         
244         unregister_netdevice ( nic );
245         close ( hijack->fd );
246 }
247
248 /*****************************************************************************
249  *
250  * uIP wrapper layer
251  *
252  */
253
254 #include "../proto/uip/uip.h"
255 #include "../proto/uip/uip_arp.h"
256
257 struct tcp_connection;
258
259 struct tcp_operations {
260         void ( * aborted ) ( struct tcp_connection *conn );
261         void ( * timedout ) ( struct tcp_connection *conn );
262         void ( * closed ) ( struct tcp_connection *conn );
263         void ( * connected ) ( struct tcp_connection *conn );
264         void ( * acked ) ( struct tcp_connection *conn, size_t len );
265         void ( * newdata ) ( struct tcp_connection *conn,
266                              void *data, size_t len );
267         void ( * senddata ) ( struct tcp_connection *conn );
268 };
269
270 struct tcp_connection {
271         struct sockaddr_in sin;
272         struct tcp_operations *tcp_op;
273 };
274
275 int tcp_connect ( struct tcp_connection *conn ) {
276         struct uip_conn *uip_conn;
277         u16_t ipaddr[2];
278
279         assert ( conn->sin.sin_addr.s_addr != 0 );
280         assert ( conn->sin.sin_port != 0 );
281         assert ( conn->tcp_op != NULL );
282         assert ( sizeof ( uip_conn->appstate ) == sizeof ( conn ) );
283
284         * ( ( uint32_t * ) ipaddr ) = conn->sin.sin_addr.s_addr;
285         uip_conn = uip_connect ( ipaddr, conn->sin.sin_port );
286         if ( ! uip_conn )
287                 return -1;
288
289         *( ( void ** ) uip_conn->appstate ) = conn;
290         return 0;
291 }
292
293 void tcp_send ( struct tcp_connection *conn, const void *data,
294                        size_t len ) {
295         assert ( conn = *( ( void ** ) uip_conn->appstate ) );
296         uip_send ( ( void * ) data, len );
297 }
298
299 void tcp_close ( struct tcp_connection *conn ) {
300         assert ( conn = *( ( void ** ) uip_conn->appstate ) );
301         uip_close();
302 }
303
304 void uip_tcp_appcall ( void ) {
305         struct tcp_connection *conn = *( ( void ** ) uip_conn->appstate );
306         struct tcp_operations *op = conn->tcp_op;
307
308         assert ( conn->tcp_op->closed != NULL );
309         assert ( conn->tcp_op->connected != NULL );
310         assert ( conn->tcp_op->acked != NULL );
311         assert ( conn->tcp_op->newdata != NULL );
312         assert ( conn->tcp_op->senddata != NULL );
313
314         if ( uip_aborted() && op->aborted ) /* optional method */
315                 op->aborted ( conn );
316         if ( uip_timedout() && op->timedout ) /* optional method */
317                 op->timedout ( conn );
318         if ( uip_closed() && op->closed ) /* optional method */
319                 op->closed ( conn );
320         if ( uip_connected() )
321                 op->connected ( conn );
322         if ( uip_acked() )
323                 op->acked ( conn, uip_conn->len );
324         if ( uip_newdata() )
325                 op->newdata ( conn, ( void * ) uip_appdata, uip_len );
326         if ( uip_rexmit() || uip_newdata() || uip_acked() ||
327              uip_connected() || uip_poll() )
328                 op->senddata ( conn );
329 }
330
331 void uip_udp_appcall ( void ) {
332 }
333
334 static void init_tcpip ( void ) {
335         uip_init();
336         uip_arp_init();
337 }
338
339 #define UIP_HLEN ( 40 + UIP_LLH_LEN )
340
341 static void uip_transmit ( void ) {
342         uip_arp_out();
343         if ( uip_len > UIP_HLEN ) {
344                 memcpy ( uip_buf + UIP_HLEN, ( void * ) uip_appdata,
345                          uip_len - UIP_HLEN );
346         }
347         netdev_transmit ( uip_buf, uip_len );
348         uip_len = 0;
349 }
350
351 static void run_tcpip ( void ) {
352         void *data;
353         size_t len;
354         uint16_t type;
355         int i;
356         
357         if ( netdev_poll ( 1, &data, &len ) ) {
358                 /* We have data */
359                 memcpy ( uip_buf, data, len );
360                 uip_len = len;
361                 type = ntohs ( *( ( uint16_t * ) ( uip_buf + 12 ) ) );
362                 if ( type == ETHERTYPE_ARP ) {
363                         uip_arp_arpin();
364                 } else {
365                         uip_arp_ipin();
366                         uip_input();
367                 }
368                 if ( uip_len > 0 )
369                         uip_transmit();
370         } else {
371                 for ( i = 0 ; i < UIP_CONNS ; i++ ) {
372                         uip_periodic ( i );
373                         if ( uip_len > 0 )
374                                 uip_transmit();
375                 }
376         }
377 }
378
379 /*****************************************************************************
380  *
381  * "Hello world" protocol tester
382  *
383  */
384
385 #include <stddef.h>
386 #define container_of(ptr, type, member) ({                      \
387         const typeof( ((type *)0)->member ) *__mptr = (ptr);    \
388         (type *)( (char *)__mptr - offsetof(type,member) );})
389
390 enum hello_state {
391         HELLO_SENDING_MESSAGE = 1,
392         HELLO_SENDING_ENDL,
393 };
394
395 struct hello_request {
396         struct tcp_connection tcp;
397         const char *message;
398         enum hello_state state;
399         size_t remaining;
400         void ( *callback ) ( char *data, size_t len );
401         int complete;
402 };
403
404 static inline struct hello_request *
405 tcp_to_hello ( struct tcp_connection *conn ) {
406         return container_of ( conn, struct hello_request, tcp );
407 }
408
409 static void hello_aborted ( struct tcp_connection *conn ) {
410         struct hello_request *hello = tcp_to_hello ( conn );
411
412         printf ( "Connection aborted\n" );
413         hello->complete = 1;
414 }
415
416 static void hello_timedout ( struct tcp_connection *conn ) {
417         struct hello_request *hello = tcp_to_hello ( conn );
418
419         printf ( "Connection timed out\n" );
420         hello->complete = 1;
421 }
422
423 static void hello_closed ( struct tcp_connection *conn ) {
424         struct hello_request *hello = tcp_to_hello ( conn );
425
426         hello->complete = 1;
427 }
428
429 static void hello_connected ( struct tcp_connection *conn ) {
430         struct hello_request *hello = tcp_to_hello ( conn );
431
432         printf ( "Connection established\n" );
433         hello->state = HELLO_SENDING_MESSAGE;
434 }
435
436 static void hello_acked ( struct tcp_connection *conn, size_t len ) {
437         struct hello_request *hello = tcp_to_hello ( conn );
438
439         hello->message += len;
440         hello->remaining -= len;
441         if ( hello->remaining == 0 ) {
442                 switch ( hello->state ) {
443                 case HELLO_SENDING_MESSAGE:
444                         hello->state = HELLO_SENDING_ENDL;
445                         hello->message = "\r\n";
446                         hello->remaining = 2;
447                         break;
448                 case HELLO_SENDING_ENDL:
449                         /* Nothing to do once we've finished sending
450                          * the end-of-line indicator.
451                          */
452                         break;
453                 default:
454                         assert ( 0 );
455                 }
456         }
457 }
458
459 static void hello_newdata ( struct tcp_connection *conn, void *data,
460                             size_t len ) {
461         struct hello_request *hello = tcp_to_hello ( conn );
462
463         hello->callback ( data, len );
464 }
465
466 static void hello_senddata ( struct tcp_connection *conn ) {
467         struct hello_request *hello = tcp_to_hello ( conn );
468
469         tcp_send ( conn, hello->message, hello->remaining );
470 }
471
472 static struct tcp_operations hello_tcp_operations = {
473         .aborted        = hello_aborted,
474         .timedout       = hello_timedout,
475         .closed         = hello_closed,
476         .connected      = hello_connected,
477         .acked          = hello_acked,
478         .newdata        = hello_newdata,
479         .senddata       = hello_senddata,
480 };
481
482 static int hello_connect ( struct hello_request *hello ) {
483         hello->tcp.tcp_op = &hello_tcp_operations;
484         hello->remaining = strlen ( hello->message );
485         return tcp_connect ( &hello->tcp );
486 }
487
488 struct hello_options {
489         struct sockaddr_in server;
490         const char *message;
491 };
492
493 static void hello_usage ( char **argv ) {
494         fprintf ( stderr,
495                   "Usage: %s [global options] hello [hello-specific options]\n"
496                   "\n"
497                   "hello-specific options:\n"
498                   "  -h|--host              Host IP address\n"
499                   "  -p|--port              Port number\n"
500                   "  -m|--message           Message to send\n",
501                   argv[0] );
502 }
503
504 static int hello_parse_options ( int argc, char **argv,
505                                  struct hello_options *options ) {
506         static struct option long_options[] = {
507                 { "host", 1, NULL, 'h' },
508                 { "port", 1, NULL, 'p' },
509                 { "message", 1, NULL, 'm' },
510                 { },
511         };
512         int c;
513         char *endptr;
514
515         /* Set default options */
516         memset ( options, 0, sizeof ( *options ) );
517         inet_aton ( "192.168.0.1", &options->server.sin_addr );
518         options->server.sin_port = htons ( 80 );
519         options->message = "Hello world!";
520
521         /* Parse command-line options */
522         while ( 1 ) {
523                 int option_index = 0;
524                 
525                 c = getopt_long ( argc, argv, "h:p:", long_options,
526                                   &option_index );
527                 if ( c < 0 )
528                         break;
529
530                 switch ( c ) {
531                 case 'h':
532                         if ( inet_aton ( optarg,
533                                          &options->server.sin_addr ) == 0 ) {
534                                 fprintf ( stderr, "Invalid IP address %s\n",
535                                           optarg );
536                                 return -1;
537                         }
538                         break;
539                 case 'p':
540                         options->server.sin_port =
541                                 htons ( strtoul ( optarg, &endptr, 0 ) );
542                         if ( *endptr != '\0' ) {
543                                 fprintf ( stderr, "Invalid port %s\n",
544                                           optarg );
545                                 return -1;
546                         }
547                         break;
548                 case 'm':
549                         options->message = optarg;
550                         break;
551                 case '?':
552                         /* Unrecognised option */
553                         return -1;
554                 default:
555                         fprintf ( stderr, "Unrecognised option '-%c'\n", c );
556                         return -1;
557                 }
558         }
559
560         /* Check there are no remaining arguments */
561         if ( optind != argc ) {
562                 hello_usage ( argv );
563                 return -1;
564         }
565         
566         return optind;
567 }
568
569 static void test_hello_callback ( char *data, size_t len ) {
570         int i;
571         char c;
572
573         for ( i = 0 ; i < len ; i++ ) {
574                 c = data[i];
575                 if ( c == '\r' ) {
576                         /* Print nothing */
577                 } else if ( ( c == '\n' ) || ( c >= 32 ) || ( c <= 126 ) ) {
578                         putchar ( c );
579                 } else {
580                         putchar ( '.' );
581                 }
582         }       
583 }
584
585 static int test_hello ( int argc, char **argv ) {
586         struct hello_options options;
587         struct hello_request hello;
588
589         /* Parse hello-specific options */
590         if ( hello_parse_options ( argc, argv, &options ) < 0 )
591                 return -1;
592
593         /* Construct hello request */
594         memset ( &hello, 0, sizeof ( hello ) );
595         hello.tcp.sin = options.server;
596         hello.message = options.message;
597         hello.callback = test_hello_callback;
598         fprintf ( stderr, "Saying \"%s\" to %s:%d\n", hello.message,
599                   inet_ntoa ( hello.tcp.sin.sin_addr ),
600                   ntohs ( hello.tcp.sin.sin_port ) );
601
602         /* Issue hello request and run to completion */
603         hello_connect ( &hello );
604         while ( ! hello.complete ) {
605                 run_tcpip ();
606         }
607
608         return 0;
609 }
610
611 /*****************************************************************************
612  *
613  * Protocol tester
614  *
615  */
616
617 struct protocol_test {
618         const char *name;
619         int ( *exec ) ( int argc, char **argv );
620 };
621
622 static struct protocol_test tests[] = {
623         { "hello", test_hello },
624 };
625
626 #define NUM_TESTS ( sizeof ( tests ) / sizeof ( tests[0] ) )
627
628 static void list_tests ( void ) {
629         int i;
630
631         for ( i = 0 ; i < NUM_TESTS ; i++ ) {
632                 printf ( "%s\n", tests[i].name );
633         }
634 }
635
636 static struct protocol_test * get_test_from_name ( const char *name ) {
637         int i;
638
639         for ( i = 0 ; i < NUM_TESTS ; i++ ) {
640                 if ( strcmp ( name, tests[i].name ) == 0 )
641                         return &tests[i];
642         }
643
644         return NULL;
645 }
646
647 /*****************************************************************************
648  *
649  * Parse command-line options
650  *
651  */
652
653 struct tester_options {
654         char interface[IF_NAMESIZE];
655 };
656
657 static void usage ( char **argv ) {
658         fprintf ( stderr,
659                   "Usage: %s [global options] <test> [test-specific options]\n"
660                   "\n"
661                   "Global options:\n"
662                   "  -h|--help              Print this help message\n"
663                   "  -i|--interface intf    Use specified network interface\n"
664                   "  -l|--list              List available tests\n"
665                   "\n"
666                   "Use \"%s <test> -h\" to view test-specific options\n",
667                   argv[0], argv[0] );
668 }
669
670 static int parse_options ( int argc, char **argv,
671                            struct tester_options *options ) {
672         static struct option long_options[] = {
673                 { "interface", 1, NULL, 'i' },
674                 { "list", 0, NULL, 'l' },
675                 { "help", 0, NULL, 'h' },
676                 { },
677         };
678         int c;
679
680         /* Set default options */
681         memset ( options, 0, sizeof ( *options ) );
682         strncpy ( options->interface, "eth0", sizeof ( options->interface ) );
683
684         /* Parse command-line options */
685         while ( 1 ) {
686                 int option_index = 0;
687                 
688                 c = getopt_long ( argc, argv, "+i:hl", long_options,
689                                   &option_index );
690                 if ( c < 0 )
691                         break;
692
693                 switch ( c ) {
694                 case 'i':
695                         strncpy ( options->interface, optarg,
696                                   sizeof ( options->interface ) );
697                         break;
698                 case 'l':
699                         list_tests ();
700                         return -1;
701                 case 'h':
702                         usage ( argv );
703                         return -1;
704                 case '?':
705                         /* Unrecognised option */
706                         return -1;
707                 default:
708                         fprintf ( stderr, "Unrecognised option '-%c'\n", c );
709                         return -1;
710                 }
711         }
712
713         /* Check there is a test specified */
714         if ( optind == argc ) {
715                 usage ( argv );
716                 return -1;
717         }
718         
719         return optind;
720 }
721
722 /*****************************************************************************
723  *
724  * Main program
725  *
726  */
727
728 int main ( int argc, char **argv ) {
729         struct tester_options options;
730         struct protocol_test *test;
731         struct hijack_device hijack_dev;
732
733         /* Parse command-line options */
734         if ( parse_options ( argc, argv, &options ) < 0 )
735                 exit ( 1 );
736
737         /* Identify test */
738         test = get_test_from_name ( argv[optind] );
739         if ( ! test ) {
740                 fprintf ( stderr, "Unrecognised test \"%s\"\n", argv[optind] );
741                 exit ( 1 );
742         }
743         optind++;
744
745         /* Initialise the protocol stack */
746         init_tcpip();
747
748         /* Open the hijack device */
749         hijack_dev.name = options.interface;
750         if ( ! hijack_probe ( &hijack_dev ) )
751                 exit ( 1 );
752
753         /* Run the test */
754         if ( test->exec ( argc, argv ) < 0 )
755                 exit ( 1 );
756
757         /* Close the hijack device */
758         hijack_disable ( &hijack_dev );
759
760         return 0;
761 }