Now capable of sending "Hello world!" via TCP.
[people/xl0/gpxe.git] / src / util / prototester.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <unistd.h>
4 #include <string.h>
5 #include <errno.h>
6 #include <time.h>
7 #include <sys/socket.h>
8 #include <sys/un.h>
9 #include <net/if.h>
10 #include <net/ethernet.h>
11 #include <netinet/in.h>
12 #include <arpa/inet.h>
13 #include <getopt.h>
14 #include <assert.h>
15
16 typedef int irq_action_t;
17
18 struct nic {
19         struct nic_operations   *nic_op;
20         unsigned char           *node_addr;
21         unsigned char           *packet;
22         unsigned int            packetlen;
23         void                    *priv_data;     /* driver private data */
24 };
25
26 struct nic_operations {
27         int ( *connect ) ( struct nic * );
28         int ( *poll ) ( struct nic *, int retrieve );
29         void ( *transmit ) ( struct nic *, const char *,
30                              unsigned int, unsigned int, const char * );
31         void ( *irq ) ( struct nic *, irq_action_t );
32 };
33
34 /*****************************************************************************
35  *
36  * Net device layer
37  *
38  */
39
40 #include "../proto/uip/uip_arp.h"
41
42 static unsigned char node_addr[ETH_ALEN];
43 static unsigned char packet[ETH_FRAME_LEN];
44 struct nic static_nic = {
45         .node_addr = node_addr,
46         .packet = packet,
47 };
48
49 /* Must be a macro because priv_data[] is of variable size */
50 #define alloc_netdevice( priv_size ) ( {        \
51         static char priv_data[priv_size];       \
52         static_nic.priv_data = priv_data;       \
53         &static_nic; } )
54
55 static int register_netdevice ( struct nic *nic ) {
56         struct uip_eth_addr hwaddr;
57
58         memcpy ( &hwaddr, nic->node_addr, sizeof ( hwaddr ) );
59         uip_setethaddr ( hwaddr );
60         return 0;
61 }
62
63 static inline void unregister_netdevice ( struct nic *nic ) {
64         /* Do nothing */
65 }
66
67 static inline void free_netdevice ( struct nic *nic ) {
68         /* Do nothing */
69 }
70
71 static int netdev_poll ( int retrieve, void **data, size_t *len ) {
72         int rc = static_nic.nic_op->poll ( &static_nic, retrieve );
73         *data = static_nic.packet;
74         *len = static_nic.packetlen;
75         return rc;
76 }
77
78 static void netdev_transmit ( const void *data, size_t len ) {
79         uint16_t type = ntohs ( *( ( uint16_t * ) ( data + 12 ) ) );
80         static_nic.nic_op->transmit ( &static_nic, data, type,
81                                       len - ETH_HLEN,
82                                       data + ETH_HLEN );
83 }
84
85 /*****************************************************************************
86  *
87  * Hijack device interface
88  *
89  * This requires a hijack daemon to be running
90  *
91  */
92
93 struct hijack {
94         int fd;
95 };
96
97 struct hijack_device {
98         char *name;
99         void *priv;
100 };
101
102 static inline void hijack_set_drvdata ( struct hijack_device *hijack_dev,
103                                         void *data ) {
104         hijack_dev->priv = data;
105 }
106
107 static inline void * hijack_get_drvdata ( struct hijack_device *hijack_dev ) {
108         return hijack_dev->priv;
109 }
110
111 static int hijack_poll ( struct nic *nic, int retrieve ) {
112         struct hijack *hijack = nic->priv_data;
113         fd_set fdset;
114         struct timeval tv;
115         int ready;
116         ssize_t len;
117
118         /* Poll for data */
119         FD_ZERO ( &fdset );
120         FD_SET ( hijack->fd, &fdset );
121         tv.tv_sec = 0;
122         tv.tv_usec = 500; /* 500us to avoid hogging CPU */
123         ready = select ( ( hijack->fd + 1 ), &fdset, NULL, NULL, &tv );
124         if ( ready < 0 ) {
125                 fprintf ( stderr, "select() failed: %s\n",
126                           strerror ( errno ) );
127                 return 0;
128         }
129         if ( ready == 0 )
130                 return 0;
131
132         /* Return if we're not retrieving data yet */
133         if ( ! retrieve )
134                 return 1;
135
136         /* Fetch data */
137         len = read ( hijack->fd, nic->packet, ETH_FRAME_LEN );
138         if ( len < 0 ) {
139                 fprintf ( stderr, "read() failed: %s\n",
140                           strerror ( errno ) );
141                 return 0;
142         }
143         nic->packetlen = len;
144
145         return 1;
146 }
147
148 static void hijack_transmit ( struct nic *nic, const char *dest,
149                               unsigned int type, unsigned int size,
150                               const char *packet ) {
151         struct hijack *hijack = nic->priv_data;
152         unsigned int nstype = htons ( type );
153         unsigned int total_size = ETH_HLEN + size;
154         char txbuf[ total_size ];
155
156         /* Build packet header */
157         memcpy ( txbuf, dest, ETH_ALEN );
158         memcpy ( txbuf + ETH_ALEN, nic->node_addr, ETH_ALEN );
159         memcpy ( txbuf + 2 * ETH_ALEN, &nstype, 2 );
160         memcpy ( txbuf + ETH_HLEN, packet, size );
161
162         /* Transmit data */
163         if ( write ( hijack->fd, txbuf, total_size ) != total_size ) {
164                 fprintf ( stderr, "write() failed: %s\n",
165                           strerror ( errno ) );
166         }
167 }
168
169 static int hijack_connect ( struct nic *nic ) {
170         return 1;
171 }
172
173 static void hijack_irq ( struct nic *nic, irq_action_t action ) {
174         /* Do nothing */
175 }
176
177 static struct nic_operations hijack_operations = {
178         .connect        = hijack_connect,
179         .transmit       = hijack_transmit,
180         .poll           = hijack_poll,
181         .irq            = hijack_irq,
182 };
183
184 int hijack_probe ( struct hijack_device *hijack_dev ) {
185         struct nic *nic;
186         struct hijack *hijack;
187         struct sockaddr_un sun;
188         int i;
189
190         nic = alloc_netdevice ( sizeof ( *hijack ) );
191         if ( ! nic ) {
192                 fprintf ( stderr, "alloc_netdevice() failed\n" );
193                 goto err_alloc;
194         }
195         hijack = nic->priv_data;
196         memset ( hijack, 0, sizeof ( *hijack ) );
197
198         /* Create socket */
199         hijack->fd = socket ( PF_UNIX, SOCK_SEQPACKET, 0 );
200         if ( hijack->fd < 0 ) {
201                 fprintf ( stderr, "socket() failed: %s\n",
202                           strerror ( errno ) );
203                 goto err;
204         }
205
206         /* Connect to hijack daemon */
207         sun.sun_family = AF_UNIX;
208         snprintf ( sun.sun_path, sizeof ( sun.sun_path ), "/var/run/hijack-%s",
209                    hijack_dev->name );
210         if ( connect ( hijack->fd, ( struct sockaddr * ) &sun,
211                        sizeof ( sun ) ) < 0 ) {
212                 fprintf ( stderr, "could not connect to %s: %s\n",
213                           sun.sun_path, strerror ( errno ) );
214                 goto err;
215         }
216
217         /* Generate MAC address */
218         srand ( time ( NULL ) );
219         for ( i = 0 ; i < ETH_ALEN ; i++ ) {
220                 nic->node_addr[i] = ( rand() & 0xff );
221         }
222         nic->node_addr[0] &= 0xfe; /* clear multicast bit */
223         nic->node_addr[0] |= 0x02; /* set "locally-assigned" bit */
224
225         nic->nic_op = &hijack_operations;
226         if ( register_netdevice ( nic ) < 0 )
227                 goto err;
228
229         hijack_set_drvdata ( hijack_dev, nic );
230         return 1;
231
232  err:
233         if ( hijack->fd >= 0 )
234                 close ( hijack->fd );
235         free_netdevice ( nic );
236  err_alloc:
237         return 0;
238 }
239
240 static void hijack_disable ( struct hijack_device *hijack_dev ) {
241         struct nic *nic = hijack_get_drvdata ( hijack_dev );
242         struct hijack *hijack = nic->priv_data;
243         
244         unregister_netdevice ( nic );
245         close ( hijack->fd );
246 }
247
248 /*****************************************************************************
249  *
250  * uIP wrapper layer
251  *
252  */
253
254 #include "../proto/uip/uip.h"
255 #include "../proto/uip/uip_arp.h"
256
257 struct tcp_connection;
258
259 struct tcp_operations {
260         void ( * aborted ) ( struct tcp_connection *conn );
261         void ( * timedout ) ( struct tcp_connection *conn );
262         void ( * closed ) ( struct tcp_connection *conn );
263         void ( * connected ) ( struct tcp_connection *conn );
264         void ( * acked ) ( struct tcp_connection *conn, size_t len );
265         void ( * newdata ) ( struct tcp_connection *conn );
266         void ( * senddata ) ( struct tcp_connection *conn );
267 };
268
269 struct tcp_connection {
270         struct sockaddr_in sin;
271         struct tcp_operations *tcp_op;
272 };
273
274 static int tcp_connect ( struct tcp_connection *conn ) {
275         struct uip_conn *uip_conn;
276         u16_t ipaddr[2];
277
278         assert ( conn->sin.sin_addr.s_addr != 0 );
279         assert ( conn->sin.sin_port != 0 );
280         assert ( conn->tcp_op != NULL );
281         assert ( sizeof ( uip_conn->appstate ) == sizeof ( conn ) );
282
283         * ( ( uint32_t * ) ipaddr ) = conn->sin.sin_addr.s_addr;
284         uip_conn = uip_connect ( ipaddr, conn->sin.sin_port );
285         if ( ! uip_conn )
286                 return -1;
287
288         *( ( void ** ) uip_conn->appstate ) = conn;
289         return 0;
290 }
291
292 static void tcp_send ( struct tcp_connection *conn, const void *data,
293                        size_t len ) {
294         assert ( conn = *( ( void ** ) uip_conn->appstate ) );
295         uip_send ( ( void * ) data, len );
296 }
297
298 static void tcp_close ( struct tcp_connection *conn ) {
299         assert ( conn = *( ( void ** ) uip_conn->appstate ) );
300         uip_close();
301 }
302
303 void uip_tcp_appcall ( void ) {
304         struct tcp_connection *conn = *( ( void ** ) uip_conn->appstate );
305         struct tcp_operations *op = conn->tcp_op;
306
307         assert ( conn->tcp_op->closed != NULL );
308         assert ( conn->tcp_op->connected != NULL );
309         assert ( conn->tcp_op->acked != NULL );
310         assert ( conn->tcp_op->newdata != NULL );
311         assert ( conn->tcp_op->senddata != NULL );
312
313         if ( uip_aborted() && op->aborted ) /* optional method */
314                 op->aborted ( conn );
315         if ( uip_timedout() && op->timedout ) /* optional method */
316                 op->timedout ( conn );
317         if ( uip_closed() && op->closed ) /* optional method */
318                 op->closed ( conn );
319         if ( uip_connected() )
320                 op->connected ( conn );
321         if ( uip_acked() )
322                 op->acked ( conn, uip_conn->len );
323         if ( uip_newdata() )
324                 op->newdata ( conn );
325         if ( uip_rexmit() || uip_newdata() || uip_acked() ||
326              uip_connected() || uip_poll() )
327                 op->senddata ( conn );
328 }
329
330 void uip_udp_appcall ( void ) {
331 }
332
333 static void init_tcpip ( void ) {
334         uip_init();
335         uip_arp_init();
336 }
337
338 #define UIP_HLEN ( 40 + UIP_LLH_LEN )
339
340 static void uip_transmit ( void ) {
341         uip_arp_out();
342         if ( uip_len > UIP_HLEN ) {
343                 memcpy ( uip_buf + UIP_HLEN, ( void * ) uip_appdata,
344                          uip_len - UIP_HLEN );
345         }
346         netdev_transmit ( uip_buf, uip_len );
347         uip_len = 0;
348 }
349
350 static void run_tcpip ( void ) {
351         void *data;
352         size_t len;
353         uint16_t type;
354         int i;
355         
356         if ( netdev_poll ( 1, &data, &len ) ) {
357                 /* We have data */
358                 memcpy ( uip_buf, data, len );
359                 uip_len = len;
360                 type = ntohs ( *( ( uint16_t * ) ( uip_buf + 12 ) ) );
361                 if ( type == ETHERTYPE_ARP ) {
362                         uip_arp_arpin();
363                 } else {
364                         uip_arp_ipin();
365                         uip_input();
366                 }
367                 if ( uip_len > 0 )
368                         uip_transmit();
369         } else {
370                 for ( i = 0 ; i < UIP_CONNS ; i++ ) {
371                         uip_periodic ( i );
372                         if ( uip_len > 0 )
373                                 uip_transmit();
374                 }
375         }
376 }
377
378 /*****************************************************************************
379  *
380  * "Hello world" protocol tester
381  *
382  */
383
384 #include <stddef.h>
385 #define container_of(ptr, type, member) ({                      \
386         const typeof( ((type *)0)->member ) *__mptr = (ptr);    \
387         (type *)( (char *)__mptr - offsetof(type,member) );})
388
389 enum hello_state {
390         HELLO_SENDING_MESSAGE = 0,
391         HELLO_SENDING_ENDL,
392 };
393
394 struct hello_request {
395         struct tcp_connection tcp;
396         const char *message;
397         enum hello_state state;
398         int remaining;
399         void ( *callback ) ( struct hello_request *hello );
400         int complete;
401 };
402
403 static inline struct hello_request *
404 tcp_to_hello ( struct tcp_connection *conn ) {
405         return container_of ( conn, struct hello_request, tcp );
406 }
407
408 static void hello_aborted ( struct tcp_connection *conn ) {
409         struct hello_request *hello = tcp_to_hello ( conn );
410
411         printf ( "Connection aborted\n" );
412         hello->complete = 1;
413 }
414
415 static void hello_timedout ( struct tcp_connection *conn ) {
416         struct hello_request *hello = tcp_to_hello ( conn );
417
418         printf ( "Connection timed out\n" );
419         hello->complete = 1;
420 }
421
422 static void hello_closed ( struct tcp_connection *conn ) {
423         struct hello_request *hello = tcp_to_hello ( conn );
424
425         hello->complete = 1;
426 }
427
428 static void hello_connected ( struct tcp_connection *conn ) {
429         struct hello_request *hello = tcp_to_hello ( conn );
430 }
431
432 static void hello_acked ( struct tcp_connection *conn, size_t len ) {
433         struct hello_request *hello = tcp_to_hello ( conn );
434
435         hello->message += len;
436         hello->remaining -= len;
437         if ( hello->remaining <= 0 ) {
438                 switch ( hello->state ) {
439                 case HELLO_SENDING_MESSAGE:
440                         hello->state = HELLO_SENDING_ENDL;
441                         hello->message = "\r\n";
442                         hello->remaining = 2;
443                         break;
444                 case HELLO_SENDING_ENDL:
445                         tcp_close ( conn );
446                         break;
447                 default:
448                         assert ( 0 );
449                 }
450         }
451 }
452
453 static void hello_newdata ( struct tcp_connection *conn ) {
454         struct hello_request *hello = tcp_to_hello ( conn );
455 }
456
457 static void hello_senddata ( struct tcp_connection *conn ) {
458         struct hello_request *hello = tcp_to_hello ( conn );
459
460         tcp_send ( conn, hello->message, hello->remaining );
461 }
462
463 static struct tcp_operations hello_tcp_operations = {
464         .aborted        = hello_aborted,
465         .timedout       = hello_timedout,
466         .closed         = hello_closed,
467         .connected      = hello_connected,
468         .acked          = hello_acked,
469         .newdata        = hello_newdata,
470         .senddata       = hello_senddata,
471 };
472
473 static int hello_connect ( struct hello_request *hello ) {
474         hello->tcp.tcp_op = &hello_tcp_operations;
475         hello->remaining = strlen ( hello->message );
476         return tcp_connect ( &hello->tcp );
477 }
478
479 struct hello_options {
480         struct sockaddr_in server;
481         const char *message;
482 };
483
484 static void hello_usage ( char **argv ) {
485         fprintf ( stderr,
486                   "Usage: %s [global options] hello [hello-specific options]\n"
487                   "\n"
488                   "hello-specific options:\n"
489                   "  -h|--host              Host IP address\n"
490                   "  -p|--port              Port number\n"
491                   "  -m|--message           Message to send\n",
492                   argv[0] );
493 }
494
495 static int hello_parse_options ( int argc, char **argv,
496                                  struct hello_options *options ) {
497         static struct option long_options[] = {
498                 { "host", 1, NULL, 'h' },
499                 { "port", 1, NULL, 'p' },
500                 { "message", 1, NULL, 'm' },
501                 { },
502         };
503         int c;
504         char *endptr;
505
506         /* Set default options */
507         memset ( options, 0, sizeof ( *options ) );
508         inet_aton ( "192.168.0.1", &options->server.sin_addr );
509         options->server.sin_port = htons ( 80 );
510         options->message = "Hello world!";
511
512         /* Parse command-line options */
513         while ( 1 ) {
514                 int option_index = 0;
515                 
516                 c = getopt_long ( argc, argv, "h:p:", long_options,
517                                   &option_index );
518                 if ( c < 0 )
519                         break;
520
521                 switch ( c ) {
522                 case 'h':
523                         if ( inet_aton ( optarg,
524                                          &options->server.sin_addr ) == 0 ) {
525                                 fprintf ( stderr, "Invalid IP address %s\n",
526                                           optarg );
527                                 return -1;
528                         }
529                         break;
530                 case 'p':
531                         options->server.sin_port =
532                                 htons ( strtoul ( optarg, &endptr, 0 ) );
533                         if ( *endptr != '\0' ) {
534                                 fprintf ( stderr, "Invalid port %s\n",
535                                           optarg );
536                                 return -1;
537                         }
538                         break;
539                 case 'm':
540                         options->message = optarg;
541                         break;
542                 case '?':
543                         /* Unrecognised option */
544                         return -1;
545                 default:
546                         fprintf ( stderr, "Unrecognised option '-%c'\n", c );
547                         return -1;
548                 }
549         }
550
551         /* Check there are no remaining arguments */
552         if ( optind != argc ) {
553                 hello_usage ( argv );
554                 return -1;
555         }
556         
557         return optind;
558 }
559
560 static void test_hello_callback ( struct hello_request *hello ) {
561         
562 }
563
564 static int test_hello ( int argc, char **argv ) {
565         struct hello_options options;
566         struct hello_request hello;
567
568         /* Parse hello-specific options */
569         if ( hello_parse_options ( argc, argv, &options ) < 0 )
570                 return -1;
571
572         /* Construct hello request */
573         memset ( &hello, 0, sizeof ( hello ) );
574         hello.tcp.sin = options.server;
575         hello.message = options.message;
576         hello.callback = test_hello_callback;
577         fprintf ( stderr, "Saying \"%s\" to %s:%d\n", hello.message,
578                   inet_ntoa ( hello.tcp.sin.sin_addr ),
579                   ntohs ( hello.tcp.sin.sin_port ) );
580
581         /* Issue hello request and run to completion */
582         hello_connect ( &hello );
583         while ( ! hello.complete ) {
584                 run_tcpip ();
585         }
586
587         return 0;
588 }
589
590 /*****************************************************************************
591  *
592  * Protocol tester
593  *
594  */
595
596 struct protocol_test {
597         const char *name;
598         int ( *exec ) ( int argc, char **argv );
599 };
600
601 static struct protocol_test tests[] = {
602         { "hello", test_hello },
603 };
604
605 #define NUM_TESTS ( sizeof ( tests ) / sizeof ( tests[0] ) )
606
607 static void list_tests ( void ) {
608         int i;
609
610         for ( i = 0 ; i < NUM_TESTS ; i++ ) {
611                 printf ( "%s\n", tests[i].name );
612         }
613 }
614
615 static struct protocol_test * get_test_from_name ( const char *name ) {
616         int i;
617
618         for ( i = 0 ; i < NUM_TESTS ; i++ ) {
619                 if ( strcmp ( name, tests[i].name ) == 0 )
620                         return &tests[i];
621         }
622
623         return NULL;
624 }
625
626 /*****************************************************************************
627  *
628  * Parse command-line options
629  *
630  */
631
632 struct tester_options {
633         char interface[IF_NAMESIZE];
634 };
635
636 static void usage ( char **argv ) {
637         fprintf ( stderr,
638                   "Usage: %s [global options] <test> [test-specific options]\n"
639                   "\n"
640                   "Global options:\n"
641                   "  -h|--help              Print this help message\n"
642                   "  -i|--interface intf    Use specified network interface\n"
643                   "  -l|--list              List available tests\n"
644                   "\n"
645                   "Use \"%s <test> -h\" to view test-specific options\n",
646                   argv[0], argv[0] );
647 }
648
649 static int parse_options ( int argc, char **argv,
650                            struct tester_options *options ) {
651         static struct option long_options[] = {
652                 { "interface", 1, NULL, 'i' },
653                 { "list", 0, NULL, 'l' },
654                 { "help", 0, NULL, 'h' },
655                 { },
656         };
657         int c;
658
659         /* Set default options */
660         memset ( options, 0, sizeof ( *options ) );
661         strncpy ( options->interface, "eth0", sizeof ( options->interface ) );
662
663         /* Parse command-line options */
664         while ( 1 ) {
665                 int option_index = 0;
666                 
667                 c = getopt_long ( argc, argv, "+i:hl", long_options,
668                                   &option_index );
669                 if ( c < 0 )
670                         break;
671
672                 switch ( c ) {
673                 case 'i':
674                         strncpy ( options->interface, optarg,
675                                   sizeof ( options->interface ) );
676                         break;
677                 case 'l':
678                         list_tests ();
679                         return -1;
680                 case 'h':
681                         usage ( argv );
682                         return -1;
683                 case '?':
684                         /* Unrecognised option */
685                         return -1;
686                 default:
687                         fprintf ( stderr, "Unrecognised option '-%c'\n", c );
688                         return -1;
689                 }
690         }
691
692         /* Check there is a test specified */
693         if ( optind == argc ) {
694                 usage ( argv );
695                 return -1;
696         }
697         
698         return optind;
699 }
700
701 /*****************************************************************************
702  *
703  * Main program
704  *
705  */
706
707 int main ( int argc, char **argv ) {
708         struct tester_options options;
709         struct protocol_test *test;
710         struct hijack_device hijack_dev;
711
712         /* Parse command-line options */
713         if ( parse_options ( argc, argv, &options ) < 0 )
714                 exit ( 1 );
715
716         /* Identify test */
717         test = get_test_from_name ( argv[optind] );
718         if ( ! test ) {
719                 fprintf ( stderr, "Unrecognised test \"%s\"\n", argv[optind] );
720                 exit ( 1 );
721         }
722         optind++;
723
724         /* Initialise the protocol stack */
725         init_tcpip();
726
727         /* Open the hijack device */
728         hijack_dev.name = options.interface;
729         if ( ! hijack_probe ( &hijack_dev ) )
730                 exit ( 1 );
731
732         /* Run the test */
733         if ( test->exec ( argc, argv ) < 0 )
734                 exit ( 1 );
735
736         /* Close the hijack device */
737         hijack_disable ( &hijack_dev );
738
739         return 0;
740 }