Moved "hello world" protocol implementation out of prototester.c and into
[people/xl0/gpxe.git] / src / util / prototester.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <unistd.h>
4 #include <string.h>
5 #include <errno.h>
6 #include <time.h>
7 #include <sys/socket.h>
8 #include <sys/un.h>
9 #include <net/if.h>
10 #include <net/ethernet.h>
11 #include <getopt.h>
12 #include <assert.h>
13
14 #include <gpxe/tcp.h>
15 #include <gpxe/hello.h>
16
17 typedef int irq_action_t;
18
19 struct nic {
20         struct nic_operations   *nic_op;
21         unsigned char           *node_addr;
22         unsigned char           *packet;
23         unsigned int            packetlen;
24         void                    *priv_data;     /* driver private data */
25 };
26
27 struct nic_operations {
28         int ( *connect ) ( struct nic * );
29         int ( *poll ) ( struct nic *, int retrieve );
30         void ( *transmit ) ( struct nic *, const char *,
31                              unsigned int, unsigned int, const char * );
32         void ( *irq ) ( struct nic *, irq_action_t );
33 };
34
35 /*****************************************************************************
36  *
37  * Net device layer
38  *
39  */
40
41 #include "../proto/uip/uip_arp.h"
42
43 static unsigned char node_addr[ETH_ALEN];
44 static unsigned char packet[ETH_FRAME_LEN];
45 struct nic static_nic = {
46         .node_addr = node_addr,
47         .packet = packet,
48 };
49
50 /* Must be a macro because priv_data[] is of variable size */
51 #define alloc_netdevice( priv_size ) ( {        \
52         static char priv_data[priv_size];       \
53         static_nic.priv_data = priv_data;       \
54         &static_nic; } )
55
56 static int register_netdevice ( struct nic *nic ) {
57         struct uip_eth_addr hwaddr;
58
59         memcpy ( &hwaddr, nic->node_addr, sizeof ( hwaddr ) );
60         uip_setethaddr ( hwaddr );
61         return 0;
62 }
63
64 static inline void unregister_netdevice ( struct nic *nic ) {
65         /* Do nothing */
66 }
67
68 static inline void free_netdevice ( struct nic *nic ) {
69         /* Do nothing */
70 }
71
72 int netdev_poll ( int retrieve, void **data, size_t *len ) {
73         int rc = static_nic.nic_op->poll ( &static_nic, retrieve );
74         *data = static_nic.packet;
75         *len = static_nic.packetlen;
76         return rc;
77 }
78
79 void netdev_transmit ( const void *data, size_t len ) {
80         uint16_t type = ntohs ( *( ( uint16_t * ) ( data + 12 ) ) );
81         static_nic.nic_op->transmit ( &static_nic, data, type,
82                                       len - ETH_HLEN,
83                                       data + ETH_HLEN );
84 }
85
86 /*****************************************************************************
87  *
88  * Hijack device interface
89  *
90  * This requires a hijack daemon to be running
91  *
92  */
93
94 struct hijack {
95         int fd;
96 };
97
98 struct hijack_device {
99         char *name;
100         void *priv;
101 };
102
103 static inline void hijack_set_drvdata ( struct hijack_device *hijack_dev,
104                                         void *data ) {
105         hijack_dev->priv = data;
106 }
107
108 static inline void * hijack_get_drvdata ( struct hijack_device *hijack_dev ) {
109         return hijack_dev->priv;
110 }
111
112 static int hijack_poll ( struct nic *nic, int retrieve ) {
113         struct hijack *hijack = nic->priv_data;
114         fd_set fdset;
115         struct timeval tv;
116         int ready;
117         ssize_t len;
118
119         /* Poll for data */
120         FD_ZERO ( &fdset );
121         FD_SET ( hijack->fd, &fdset );
122         tv.tv_sec = 0;
123         tv.tv_usec = 500; /* 500us to avoid hogging CPU */
124         ready = select ( ( hijack->fd + 1 ), &fdset, NULL, NULL, &tv );
125         if ( ready < 0 ) {
126                 fprintf ( stderr, "select() failed: %s\n",
127                           strerror ( errno ) );
128                 return 0;
129         }
130         if ( ready == 0 )
131                 return 0;
132
133         /* Return if we're not retrieving data yet */
134         if ( ! retrieve )
135                 return 1;
136
137         /* Fetch data */
138         len = read ( hijack->fd, nic->packet, ETH_FRAME_LEN );
139         if ( len < 0 ) {
140                 fprintf ( stderr, "read() failed: %s\n",
141                           strerror ( errno ) );
142                 return 0;
143         }
144         nic->packetlen = len;
145
146         return 1;
147 }
148
149 static void hijack_transmit ( struct nic *nic, const char *dest,
150                               unsigned int type, unsigned int size,
151                               const char *packet ) {
152         struct hijack *hijack = nic->priv_data;
153         unsigned int nstype = htons ( type );
154         unsigned int total_size = ETH_HLEN + size;
155         char txbuf[ total_size ];
156
157         /* Build packet header */
158         memcpy ( txbuf, dest, ETH_ALEN );
159         memcpy ( txbuf + ETH_ALEN, nic->node_addr, ETH_ALEN );
160         memcpy ( txbuf + 2 * ETH_ALEN, &nstype, 2 );
161         memcpy ( txbuf + ETH_HLEN, packet, size );
162
163         /* Transmit data */
164         if ( write ( hijack->fd, txbuf, total_size ) != total_size ) {
165                 fprintf ( stderr, "write() failed: %s\n",
166                           strerror ( errno ) );
167         }
168 }
169
170 static int hijack_connect ( struct nic *nic ) {
171         return 1;
172 }
173
174 static void hijack_irq ( struct nic *nic, irq_action_t action ) {
175         /* Do nothing */
176 }
177
178 static struct nic_operations hijack_operations = {
179         .connect        = hijack_connect,
180         .transmit       = hijack_transmit,
181         .poll           = hijack_poll,
182         .irq            = hijack_irq,
183 };
184
185 int hijack_probe ( struct hijack_device *hijack_dev ) {
186         struct nic *nic;
187         struct hijack *hijack;
188         struct sockaddr_un sun;
189         int i;
190
191         nic = alloc_netdevice ( sizeof ( *hijack ) );
192         if ( ! nic ) {
193                 fprintf ( stderr, "alloc_netdevice() failed\n" );
194                 goto err_alloc;
195         }
196         hijack = nic->priv_data;
197         memset ( hijack, 0, sizeof ( *hijack ) );
198
199         /* Create socket */
200         hijack->fd = socket ( PF_UNIX, SOCK_SEQPACKET, 0 );
201         if ( hijack->fd < 0 ) {
202                 fprintf ( stderr, "socket() failed: %s\n",
203                           strerror ( errno ) );
204                 goto err;
205         }
206
207         /* Connect to hijack daemon */
208         sun.sun_family = AF_UNIX;
209         snprintf ( sun.sun_path, sizeof ( sun.sun_path ), "/var/run/hijack-%s",
210                    hijack_dev->name );
211         if ( connect ( hijack->fd, ( struct sockaddr * ) &sun,
212                        sizeof ( sun ) ) < 0 ) {
213                 fprintf ( stderr, "could not connect to %s: %s\n",
214                           sun.sun_path, strerror ( errno ) );
215                 goto err;
216         }
217
218         /* Generate MAC address */
219         srand ( time ( NULL ) );
220         for ( i = 0 ; i < ETH_ALEN ; i++ ) {
221                 nic->node_addr[i] = ( rand() & 0xff );
222         }
223         nic->node_addr[0] &= 0xfe; /* clear multicast bit */
224         nic->node_addr[0] |= 0x02; /* set "locally-assigned" bit */
225
226         nic->nic_op = &hijack_operations;
227         if ( register_netdevice ( nic ) < 0 )
228                 goto err;
229
230         hijack_set_drvdata ( hijack_dev, nic );
231         return 1;
232
233  err:
234         if ( hijack->fd >= 0 )
235                 close ( hijack->fd );
236         free_netdevice ( nic );
237  err_alloc:
238         return 0;
239 }
240
241 static void hijack_disable ( struct hijack_device *hijack_dev ) {
242         struct nic *nic = hijack_get_drvdata ( hijack_dev );
243         struct hijack *hijack = nic->priv_data;
244         
245         unregister_netdevice ( nic );
246         close ( hijack->fd );
247 }
248
249 /*****************************************************************************
250  *
251  * "Hello world" protocol tester
252  *
253  */
254
255 struct hello_options {
256         struct sockaddr_in server;
257         const char *message;
258 };
259
260 static void hello_usage ( char **argv ) {
261         fprintf ( stderr,
262                   "Usage: %s [global options] hello [hello-specific options]\n"
263                   "\n"
264                   "hello-specific options:\n"
265                   "  -h|--host              Host IP address\n"
266                   "  -p|--port              Port number\n"
267                   "  -m|--message           Message to send\n",
268                   argv[0] );
269 }
270
271 static int hello_parse_options ( int argc, char **argv,
272                                  struct hello_options *options ) {
273         static struct option long_options[] = {
274                 { "host", 1, NULL, 'h' },
275                 { "port", 1, NULL, 'p' },
276                 { "message", 1, NULL, 'm' },
277                 { },
278         };
279         int c;
280         char *endptr;
281
282         /* Set default options */
283         memset ( options, 0, sizeof ( *options ) );
284         inet_aton ( "192.168.0.1", &options->server.sin_addr );
285         options->server.sin_port = htons ( 80 );
286         options->message = "Hello world!";
287
288         /* Parse command-line options */
289         while ( 1 ) {
290                 int option_index = 0;
291                 
292                 c = getopt_long ( argc, argv, "h:p:", long_options,
293                                   &option_index );
294                 if ( c < 0 )
295                         break;
296
297                 switch ( c ) {
298                 case 'h':
299                         if ( inet_aton ( optarg,
300                                          &options->server.sin_addr ) == 0 ) {
301                                 fprintf ( stderr, "Invalid IP address %s\n",
302                                           optarg );
303                                 return -1;
304                         }
305                         break;
306                 case 'p':
307                         options->server.sin_port =
308                                 htons ( strtoul ( optarg, &endptr, 0 ) );
309                         if ( *endptr != '\0' ) {
310                                 fprintf ( stderr, "Invalid port %s\n",
311                                           optarg );
312                                 return -1;
313                         }
314                         break;
315                 case 'm':
316                         options->message = optarg;
317                         break;
318                 case '?':
319                         /* Unrecognised option */
320                         return -1;
321                 default:
322                         fprintf ( stderr, "Unrecognised option '-%c'\n", c );
323                         return -1;
324                 }
325         }
326
327         /* Check there are no remaining arguments */
328         if ( optind != argc ) {
329                 hello_usage ( argv );
330                 return -1;
331         }
332         
333         return optind;
334 }
335
336 static void test_hello_callback ( char *data, size_t len ) {
337         int i;
338         char c;
339
340         for ( i = 0 ; i < len ; i++ ) {
341                 c = data[i];
342                 if ( c == '\r' ) {
343                         /* Print nothing */
344                 } else if ( ( c == '\n' ) || ( c >= 32 ) || ( c <= 126 ) ) {
345                         putchar ( c );
346                 } else {
347                         putchar ( '.' );
348                 }
349         }       
350 }
351
352 static int test_hello ( int argc, char **argv ) {
353         struct hello_options options;
354         struct hello_request hello;
355
356         /* Parse hello-specific options */
357         if ( hello_parse_options ( argc, argv, &options ) < 0 )
358                 return -1;
359
360         /* Construct hello request */
361         memset ( &hello, 0, sizeof ( hello ) );
362         hello.tcp.sin = options.server;
363         hello.message = options.message;
364         hello.callback = test_hello_callback;
365         fprintf ( stderr, "Saying \"%s\" to %s:%d\n", hello.message,
366                   inet_ntoa ( hello.tcp.sin.sin_addr ),
367                   ntohs ( hello.tcp.sin.sin_port ) );
368
369         /* Issue hello request and run to completion */
370         hello_connect ( &hello );
371         while ( ! hello.complete ) {
372                 run_tcpip ();
373         }
374
375         return 0;
376 }
377
378 /*****************************************************************************
379  *
380  * Protocol tester
381  *
382  */
383
384 struct protocol_test {
385         const char *name;
386         int ( *exec ) ( int argc, char **argv );
387 };
388
389 static struct protocol_test tests[] = {
390         { "hello", test_hello },
391 };
392
393 #define NUM_TESTS ( sizeof ( tests ) / sizeof ( tests[0] ) )
394
395 static void list_tests ( void ) {
396         int i;
397
398         for ( i = 0 ; i < NUM_TESTS ; i++ ) {
399                 printf ( "%s\n", tests[i].name );
400         }
401 }
402
403 static struct protocol_test * get_test_from_name ( const char *name ) {
404         int i;
405
406         for ( i = 0 ; i < NUM_TESTS ; i++ ) {
407                 if ( strcmp ( name, tests[i].name ) == 0 )
408                         return &tests[i];
409         }
410
411         return NULL;
412 }
413
414 /*****************************************************************************
415  *
416  * Parse command-line options
417  *
418  */
419
420 struct tester_options {
421         char interface[IF_NAMESIZE];
422 };
423
424 static void usage ( char **argv ) {
425         fprintf ( stderr,
426                   "Usage: %s [global options] <test> [test-specific options]\n"
427                   "\n"
428                   "Global options:\n"
429                   "  -h|--help              Print this help message\n"
430                   "  -i|--interface intf    Use specified network interface\n"
431                   "  -l|--list              List available tests\n"
432                   "\n"
433                   "Use \"%s <test> -h\" to view test-specific options\n",
434                   argv[0], argv[0] );
435 }
436
437 static int parse_options ( int argc, char **argv,
438                            struct tester_options *options ) {
439         static struct option long_options[] = {
440                 { "interface", 1, NULL, 'i' },
441                 { "list", 0, NULL, 'l' },
442                 { "help", 0, NULL, 'h' },
443                 { },
444         };
445         int c;
446
447         /* Set default options */
448         memset ( options, 0, sizeof ( *options ) );
449         strncpy ( options->interface, "eth0", sizeof ( options->interface ) );
450
451         /* Parse command-line options */
452         while ( 1 ) {
453                 int option_index = 0;
454                 
455                 c = getopt_long ( argc, argv, "+i:hl", long_options,
456                                   &option_index );
457                 if ( c < 0 )
458                         break;
459
460                 switch ( c ) {
461                 case 'i':
462                         strncpy ( options->interface, optarg,
463                                   sizeof ( options->interface ) );
464                         break;
465                 case 'l':
466                         list_tests ();
467                         return -1;
468                 case 'h':
469                         usage ( argv );
470                         return -1;
471                 case '?':
472                         /* Unrecognised option */
473                         return -1;
474                 default:
475                         fprintf ( stderr, "Unrecognised option '-%c'\n", c );
476                         return -1;
477                 }
478         }
479
480         /* Check there is a test specified */
481         if ( optind == argc ) {
482                 usage ( argv );
483                 return -1;
484         }
485         
486         return optind;
487 }
488
489 /*****************************************************************************
490  *
491  * Main program
492  *
493  */
494
495 int main ( int argc, char **argv ) {
496         struct tester_options options;
497         struct protocol_test *test;
498         struct hijack_device hijack_dev;
499
500         /* Parse command-line options */
501         if ( parse_options ( argc, argv, &options ) < 0 )
502                 exit ( 1 );
503
504         /* Identify test */
505         test = get_test_from_name ( argv[optind] );
506         if ( ! test ) {
507                 fprintf ( stderr, "Unrecognised test \"%s\"\n", argv[optind] );
508                 exit ( 1 );
509         }
510         optind++;
511
512         /* Initialise the protocol stack */
513         init_tcpip();
514
515         /* Open the hijack device */
516         hijack_dev.name = options.interface;
517         if ( ! hijack_probe ( &hijack_dev ) )
518                 exit ( 1 );
519
520         /* Run the test */
521         if ( test->exec ( argc, argv ) < 0 )
522                 exit ( 1 );
523
524         /* Close the hijack device */
525         hijack_disable ( &hijack_dev );
526
527         return 0;
528 }