Allow specifying the local IP address via --from.
[people/xl0/gpxe.git] / src / util / prototester.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <unistd.h>
4 #include <string.h>
5 #include <errno.h>
6 #include <time.h>
7 #include <sys/socket.h>
8 #include <sys/un.h>
9 #include <net/if.h>
10 #include <net/ethernet.h>
11 #include <getopt.h>
12 #include <assert.h>
13
14 #include <gpxe/ip.h>
15 #include <gpxe/tcp.h>
16 #include <gpxe/hello.h>
17
18 typedef int irq_action_t;
19
20 struct nic {
21         struct nic_operations   *nic_op;
22         unsigned char           *node_addr;
23         unsigned char           *packet;
24         unsigned int            packetlen;
25         void                    *priv_data;     /* driver private data */
26 };
27
28 struct nic_operations {
29         int ( *connect ) ( struct nic * );
30         int ( *poll ) ( struct nic *, int retrieve );
31         void ( *transmit ) ( struct nic *, const char *,
32                              unsigned int, unsigned int, const char * );
33         void ( *irq ) ( struct nic *, irq_action_t );
34 };
35
36 /*****************************************************************************
37  *
38  * Net device layer
39  *
40  */
41
42 #include "../proto/uip/uip_arp.h"
43
44 static unsigned char node_addr[ETH_ALEN];
45 static unsigned char packet[ETH_FRAME_LEN];
46 struct nic static_nic = {
47         .node_addr = node_addr,
48         .packet = packet,
49 };
50
51 /* Must be a macro because priv_data[] is of variable size */
52 #define alloc_netdevice( priv_size ) ( {        \
53         static char priv_data[priv_size];       \
54         static_nic.priv_data = priv_data;       \
55         &static_nic; } )
56
57 static int register_netdevice ( struct nic *nic ) {
58         struct uip_eth_addr hwaddr;
59
60         memcpy ( &hwaddr, nic->node_addr, sizeof ( hwaddr ) );
61         uip_setethaddr ( hwaddr );
62         return 0;
63 }
64
65 static inline void unregister_netdevice ( struct nic *nic ) {
66         /* Do nothing */
67 }
68
69 static inline void free_netdevice ( struct nic *nic ) {
70         /* Do nothing */
71 }
72
73 int netdev_poll ( int retrieve, void **data, size_t *len ) {
74         int rc = static_nic.nic_op->poll ( &static_nic, retrieve );
75         *data = static_nic.packet;
76         *len = static_nic.packetlen;
77         return rc;
78 }
79
80 void netdev_transmit ( const void *data, size_t len ) {
81         uint16_t type = ntohs ( *( ( uint16_t * ) ( data + 12 ) ) );
82         static_nic.nic_op->transmit ( &static_nic, data, type,
83                                       len - ETH_HLEN,
84                                       data + ETH_HLEN );
85 }
86
87 /*****************************************************************************
88  *
89  * Hijack device interface
90  *
91  * This requires a hijack daemon to be running
92  *
93  */
94
95 struct hijack {
96         int fd;
97 };
98
99 struct hijack_device {
100         char *name;
101         void *priv;
102 };
103
104 static inline void hijack_set_drvdata ( struct hijack_device *hijack_dev,
105                                         void *data ) {
106         hijack_dev->priv = data;
107 }
108
109 static inline void * hijack_get_drvdata ( struct hijack_device *hijack_dev ) {
110         return hijack_dev->priv;
111 }
112
113 static int hijack_poll ( struct nic *nic, int retrieve ) {
114         struct hijack *hijack = nic->priv_data;
115         fd_set fdset;
116         struct timeval tv;
117         int ready;
118         ssize_t len;
119
120         /* Poll for data */
121         FD_ZERO ( &fdset );
122         FD_SET ( hijack->fd, &fdset );
123         tv.tv_sec = 0;
124         tv.tv_usec = 500; /* 500us to avoid hogging CPU */
125         ready = select ( ( hijack->fd + 1 ), &fdset, NULL, NULL, &tv );
126         if ( ready < 0 ) {
127                 fprintf ( stderr, "select() failed: %s\n",
128                           strerror ( errno ) );
129                 return 0;
130         }
131         if ( ready == 0 )
132                 return 0;
133
134         /* Return if we're not retrieving data yet */
135         if ( ! retrieve )
136                 return 1;
137
138         /* Fetch data */
139         len = read ( hijack->fd, nic->packet, ETH_FRAME_LEN );
140         if ( len < 0 ) {
141                 fprintf ( stderr, "read() failed: %s\n",
142                           strerror ( errno ) );
143                 return 0;
144         }
145         nic->packetlen = len;
146
147         return 1;
148 }
149
150 static void hijack_transmit ( struct nic *nic, const char *dest,
151                               unsigned int type, unsigned int size,
152                               const char *packet ) {
153         struct hijack *hijack = nic->priv_data;
154         unsigned int nstype = htons ( type );
155         unsigned int total_size = ETH_HLEN + size;
156         char txbuf[ total_size ];
157
158         /* Build packet header */
159         memcpy ( txbuf, dest, ETH_ALEN );
160         memcpy ( txbuf + ETH_ALEN, nic->node_addr, ETH_ALEN );
161         memcpy ( txbuf + 2 * ETH_ALEN, &nstype, 2 );
162         memcpy ( txbuf + ETH_HLEN, packet, size );
163
164         /* Transmit data */
165         if ( write ( hijack->fd, txbuf, total_size ) != total_size ) {
166                 fprintf ( stderr, "write() failed: %s\n",
167                           strerror ( errno ) );
168         }
169 }
170
171 static int hijack_connect ( struct nic *nic ) {
172         return 1;
173 }
174
175 static void hijack_irq ( struct nic *nic, irq_action_t action ) {
176         /* Do nothing */
177 }
178
179 static struct nic_operations hijack_operations = {
180         .connect        = hijack_connect,
181         .transmit       = hijack_transmit,
182         .poll           = hijack_poll,
183         .irq            = hijack_irq,
184 };
185
186 int hijack_probe ( struct hijack_device *hijack_dev ) {
187         struct nic *nic;
188         struct hijack *hijack;
189         struct sockaddr_un sun;
190         int i;
191
192         nic = alloc_netdevice ( sizeof ( *hijack ) );
193         if ( ! nic ) {
194                 fprintf ( stderr, "alloc_netdevice() failed\n" );
195                 goto err_alloc;
196         }
197         hijack = nic->priv_data;
198         memset ( hijack, 0, sizeof ( *hijack ) );
199
200         /* Create socket */
201         hijack->fd = socket ( PF_UNIX, SOCK_SEQPACKET, 0 );
202         if ( hijack->fd < 0 ) {
203                 fprintf ( stderr, "socket() failed: %s\n",
204                           strerror ( errno ) );
205                 goto err;
206         }
207
208         /* Connect to hijack daemon */
209         sun.sun_family = AF_UNIX;
210         snprintf ( sun.sun_path, sizeof ( sun.sun_path ), "/var/run/hijack-%s",
211                    hijack_dev->name );
212         if ( connect ( hijack->fd, ( struct sockaddr * ) &sun,
213                        sizeof ( sun ) ) < 0 ) {
214                 fprintf ( stderr, "could not connect to %s: %s\n",
215                           sun.sun_path, strerror ( errno ) );
216                 goto err;
217         }
218
219         /* Generate MAC address */
220         srand ( time ( NULL ) );
221         for ( i = 0 ; i < ETH_ALEN ; i++ ) {
222                 nic->node_addr[i] = ( rand() & 0xff );
223         }
224         nic->node_addr[0] &= 0xfe; /* clear multicast bit */
225         nic->node_addr[0] |= 0x02; /* set "locally-assigned" bit */
226
227         nic->nic_op = &hijack_operations;
228         if ( register_netdevice ( nic ) < 0 )
229                 goto err;
230
231         hijack_set_drvdata ( hijack_dev, nic );
232         return 1;
233
234  err:
235         if ( hijack->fd >= 0 )
236                 close ( hijack->fd );
237         free_netdevice ( nic );
238  err_alloc:
239         return 0;
240 }
241
242 static void hijack_disable ( struct hijack_device *hijack_dev ) {
243         struct nic *nic = hijack_get_drvdata ( hijack_dev );
244         struct hijack *hijack = nic->priv_data;
245         
246         unregister_netdevice ( nic );
247         close ( hijack->fd );
248 }
249
250 /*****************************************************************************
251  *
252  * "Hello world" protocol tester
253  *
254  */
255
256 struct hello_options {
257         struct sockaddr_in server;
258         const char *message;
259 };
260
261 static void hello_usage ( char **argv ) {
262         fprintf ( stderr,
263                   "Usage: %s [global options] hello [hello-specific options]\n"
264                   "\n"
265                   "hello-specific options:\n"
266                   "  -h|--host              Host IP address\n"
267                   "  -p|--port              Port number\n"
268                   "  -m|--message           Message to send\n",
269                   argv[0] );
270 }
271
272 static int hello_parse_options ( int argc, char **argv,
273                                  struct hello_options *options ) {
274         static struct option long_options[] = {
275                 { "host", 1, NULL, 'h' },
276                 { "port", 1, NULL, 'p' },
277                 { "message", 1, NULL, 'm' },
278                 { },
279         };
280         int c;
281         char *endptr;
282
283         /* Set default options */
284         memset ( options, 0, sizeof ( *options ) );
285         inet_aton ( "192.168.0.1", &options->server.sin_addr );
286         options->server.sin_port = htons ( 80 );
287         options->message = "Hello world!";
288
289         /* Parse command-line options */
290         while ( 1 ) {
291                 int option_index = 0;
292                 
293                 c = getopt_long ( argc, argv, "h:p:", long_options,
294                                   &option_index );
295                 if ( c < 0 )
296                         break;
297
298                 switch ( c ) {
299                 case 'h':
300                         if ( inet_aton ( optarg,
301                                          &options->server.sin_addr ) == 0 ) {
302                                 fprintf ( stderr, "Invalid IP address %s\n",
303                                           optarg );
304                                 return -1;
305                         }
306                         break;
307                 case 'p':
308                         options->server.sin_port =
309                                 htons ( strtoul ( optarg, &endptr, 0 ) );
310                         if ( *endptr != '\0' ) {
311                                 fprintf ( stderr, "Invalid port %s\n",
312                                           optarg );
313                                 return -1;
314                         }
315                         break;
316                 case 'm':
317                         options->message = optarg;
318                         break;
319                 case '?':
320                         /* Unrecognised option */
321                         return -1;
322                 default:
323                         fprintf ( stderr, "Unrecognised option '-%c'\n", c );
324                         return -1;
325                 }
326         }
327
328         /* Check there are no remaining arguments */
329         if ( optind != argc ) {
330                 hello_usage ( argv );
331                 return -1;
332         }
333         
334         return optind;
335 }
336
337 static void test_hello_callback ( char *data, size_t len ) {
338         int i;
339         char c;
340
341         for ( i = 0 ; i < len ; i++ ) {
342                 c = data[i];
343                 if ( c == '\r' ) {
344                         /* Print nothing */
345                 } else if ( ( c == '\n' ) || ( c >= 32 ) || ( c <= 126 ) ) {
346                         putchar ( c );
347                 } else {
348                         putchar ( '.' );
349                 }
350         }       
351 }
352
353 static int test_hello ( int argc, char **argv ) {
354         struct hello_options options;
355         struct hello_request hello;
356
357         /* Parse hello-specific options */
358         if ( hello_parse_options ( argc, argv, &options ) < 0 )
359                 return -1;
360
361         /* Construct hello request */
362         memset ( &hello, 0, sizeof ( hello ) );
363         hello.tcp.sin = options.server;
364         hello.message = options.message;
365         hello.callback = test_hello_callback;
366         fprintf ( stderr, "Saying \"%s\" to %s:%d\n", hello.message,
367                   inet_ntoa ( hello.tcp.sin.sin_addr ),
368                   ntohs ( hello.tcp.sin.sin_port ) );
369
370         /* Issue hello request and run to completion */
371         hello_connect ( &hello );
372         while ( ! hello.complete ) {
373                 run_tcpip ();
374         }
375
376         return 0;
377 }
378
379 /*****************************************************************************
380  *
381  * Protocol tester
382  *
383  */
384
385 struct protocol_test {
386         const char *name;
387         int ( *exec ) ( int argc, char **argv );
388 };
389
390 static struct protocol_test tests[] = {
391         { "hello", test_hello },
392 };
393
394 #define NUM_TESTS ( sizeof ( tests ) / sizeof ( tests[0] ) )
395
396 static void list_tests ( void ) {
397         int i;
398
399         for ( i = 0 ; i < NUM_TESTS ; i++ ) {
400                 printf ( "%s\n", tests[i].name );
401         }
402 }
403
404 static struct protocol_test * get_test_from_name ( const char *name ) {
405         int i;
406
407         for ( i = 0 ; i < NUM_TESTS ; i++ ) {
408                 if ( strcmp ( name, tests[i].name ) == 0 )
409                         return &tests[i];
410         }
411
412         return NULL;
413 }
414
415 /*****************************************************************************
416  *
417  * Parse command-line options
418  *
419  */
420
421 struct tester_options {
422         char interface[IF_NAMESIZE];
423         struct in_addr in_addr;
424 };
425
426 static void usage ( char **argv ) {
427         fprintf ( stderr,
428                   "Usage: %s [global options] <test> [test-specific options]\n"
429                   "\n"
430                   "Global options:\n"
431                   "  -h|--help              Print this help message\n"
432                   "  -i|--interface intf    Use specified network interface\n"
433                   "  -f|--from ip-address   Use specified local IP address\n"
434                   "  -l|--list              List available tests\n"
435                   "\n"
436                   "Use \"%s <test> -h\" to view test-specific options\n",
437                   argv[0], argv[0] );
438 }
439
440 static int parse_options ( int argc, char **argv,
441                            struct tester_options *options ) {
442         static struct option long_options[] = {
443                 { "interface", 1, NULL, 'i' },
444                 { "from", 1, NULL, 'f' },
445                 { "list", 0, NULL, 'l' },
446                 { "help", 0, NULL, 'h' },
447                 { },
448         };
449         int c;
450
451         /* Set default options */
452         memset ( options, 0, sizeof ( *options ) );
453         strncpy ( options->interface, "eth0", sizeof ( options->interface ) );
454         inet_aton ( "192.168.0.2", &options->in_addr );
455
456         /* Parse command-line options */
457         while ( 1 ) {
458                 int option_index = 0;
459                 
460                 c = getopt_long ( argc, argv, "+i:f:hl", long_options,
461                                   &option_index );
462                 if ( c < 0 )
463                         break;
464
465                 switch ( c ) {
466                 case 'i':
467                         strncpy ( options->interface, optarg,
468                                   sizeof ( options->interface ) );
469                         break;
470                 case 'f':
471                         if ( inet_aton ( optarg, &options->in_addr ) == 0 ) {
472                                 fprintf ( stderr, "Invalid IP address %s\n",
473                                           optarg );
474                                 return -1;
475                         }
476                         break;
477                 case 'l':
478                         list_tests ();
479                         return -1;
480                 case 'h':
481                         usage ( argv );
482                         return -1;
483                 case '?':
484                         /* Unrecognised option */
485                         return -1;
486                 default:
487                         fprintf ( stderr, "Unrecognised option '-%c'\n", c );
488                         return -1;
489                 }
490         }
491
492         /* Check there is a test specified */
493         if ( optind == argc ) {
494                 usage ( argv );
495                 return -1;
496         }
497         
498         return optind;
499 }
500
501 /*****************************************************************************
502  *
503  * Main program
504  *
505  */
506
507 int main ( int argc, char **argv ) {
508         struct tester_options options;
509         struct protocol_test *test;
510         struct hijack_device hijack_dev;
511
512         /* Parse command-line options */
513         if ( parse_options ( argc, argv, &options ) < 0 )
514                 exit ( 1 );
515
516         /* Identify test */
517         test = get_test_from_name ( argv[optind] );
518         if ( ! test ) {
519                 fprintf ( stderr, "Unrecognised test \"%s\"\n", argv[optind] );
520                 exit ( 1 );
521         }
522         optind++;
523
524         /* Initialise the protocol stack */
525         init_tcpip();
526         set_ipaddr ( options.in_addr );
527
528         /* Open the hijack device */
529         hijack_dev.name = options.interface;
530         if ( ! hijack_probe ( &hijack_dev ) )
531                 exit ( 1 );
532
533         /* Run the test */
534         if ( test->exec ( argc, argv ) < 0 )
535                 exit ( 1 );
536
537         /* Close the hijack device */
538         hijack_disable ( &hijack_dev );
539
540         return 0;
541 }