Terminate cleanly on SIGINT or SIGHUP
[people/xl0/gpxe.git] / src / util / hijack.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <unistd.h>
4 #include <string.h>
5 #include <stdarg.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <libgen.h>
9 #include <signal.h>
10 #include <net/if.h>
11 #include <net/ethernet.h>
12 #include <sys/select.h>
13 #include <sys/socket.h>
14 #include <sys/stat.h>
15 #include <sys/un.h>
16 #include <syslog.h>
17 #include <getopt.h>
18 #include <pcap.h>
19
20 #define SNAPLEN 1600
21
22 struct hijack {
23         pcap_t *pcap;
24         int fd;
25         int datalink;
26         int filtered;
27         unsigned long rx_count;
28         unsigned long tx_count;
29 };
30
31 struct hijack_listener {
32         struct sockaddr_un sun;
33         int fd;
34 };
35
36 struct hijack_options {
37         char interface[IF_NAMESIZE];
38         int daemonise;
39 };
40
41 static int daemonised = 0;
42
43 static int signalled = 0;
44
45 static void flag_signalled ( int signal __attribute__ (( unused )) ) {
46         signalled = 1;
47 }
48
49 /**
50  * Log error message
51  *
52  */
53 static __attribute__ (( format ( printf, 2, 3 ) )) void
54 logmsg ( int level, const char *format, ... ) {
55         va_list ap;
56
57         va_start ( ap, format );
58         if ( daemonised ) {
59                 vsyslog ( ( LOG_DAEMON | level ), format, ap );
60         } else {
61                 vfprintf ( stderr, format, ap );
62         }
63         va_end ( ap );
64 }
65
66 /**
67  * Open pcap device
68  *
69  */
70 static int hijack_open ( const char *interface, struct hijack *hijack ) {
71         char errbuf[PCAP_ERRBUF_SIZE];
72
73         /* Open interface via pcap */
74         errbuf[0] = '\0';
75         hijack->pcap = pcap_open_live ( interface, SNAPLEN, 1, 0, errbuf );
76         if ( ! hijack->pcap ) {
77                 logmsg ( LOG_ERR, "Failed to open %s: %s\n",
78                          interface, errbuf );
79                 goto err;
80         }
81         if ( errbuf[0] )
82                 logmsg ( LOG_WARNING, "Warning: %s\n", errbuf );
83
84         /* Set capture interface to non-blocking mode */
85         if ( pcap_setnonblock ( hijack->pcap, 1, errbuf ) < 0 ) {
86                 logmsg ( LOG_ERR, "Could not make %s non-blocking: %s\n",
87                          interface, errbuf );
88                 goto err;
89         }
90
91         /* Get file descriptor for select() */
92         hijack->fd = pcap_get_selectable_fd ( hijack->pcap );
93         if ( hijack->fd < 0 ) {
94                 logmsg ( LOG_ERR, "Cannot get selectable file descriptor "
95                          "for %s\n", interface );
96                 goto err;
97         }
98
99         /* Get link layer type */
100         hijack->datalink = pcap_datalink ( hijack->pcap );
101
102         return 0;
103
104  err:
105         if ( hijack->pcap )
106                 pcap_close ( hijack->pcap );
107         return -1;
108 }
109
110 /**
111  * Close pcap device
112  *
113  */
114 static void hijack_close ( struct hijack *hijack ) {
115         pcap_close ( hijack->pcap );
116 }
117
118 /**
119  * Install filter for hijacked connection
120  *
121  */
122 static int hijack_install_filter ( struct hijack *hijack,
123                                    char *filter ) {
124         struct bpf_program program;
125
126         /* Compile filter */
127         if ( pcap_compile ( hijack->pcap, &program, filter, 1, 0 ) < 0 ) {
128                 logmsg ( LOG_ERR, "could not compile filter \"%s\": %s\n",
129                          filter, pcap_geterr ( hijack->pcap ) );
130                 goto err_nofree;
131         }
132
133         /* Install filter */
134         if ( pcap_setfilter ( hijack->pcap, &program ) < 0 ) {
135                 logmsg ( LOG_ERR, "could not install filter \"%s\": %s\n",
136                          filter, pcap_geterr ( hijack->pcap ) );
137                 goto err;
138         }
139         
140         logmsg ( LOG_INFO, "using filter \"%s\"\n", filter );
141
142         pcap_freecode ( &program );
143         return 0;
144
145  err:   
146         pcap_freecode ( &program );
147  err_nofree:
148         return -1;
149 }
150
151 /**
152  * Set up filter for hijacked ethernet connection
153  *
154  */
155 static int hijack_filter_ethernet ( struct hijack *hijack, const char *buf,
156                                     size_t len ) {
157         char filter[55]; /* see format string */
158         struct ether_header *ether_header = ( struct ether_header * ) buf;
159         unsigned char *hwaddr = ether_header->ether_shost;
160
161         if ( len < sizeof ( *ether_header ) )
162                 return -1;
163
164         snprintf ( filter, sizeof ( filter ), "broadcast or multicast or "
165                    "ether host %02x:%02x:%02x:%02x:%02x:%02x", hwaddr[0],
166                    hwaddr[1], hwaddr[2], hwaddr[3], hwaddr[4], hwaddr[5] );
167
168         return hijack_install_filter ( hijack, filter );
169 }
170
171 /**
172  * Set up filter for hijacked connection
173  *
174  */
175 static int hijack_filter ( struct hijack *hijack, const char *buf,
176                            size_t len ) {
177         switch ( hijack->datalink ) {
178         case DLT_EN10MB:
179                 return hijack_filter_ethernet ( hijack, buf, len );
180         default:
181                 logmsg ( LOG_ERR, "unsupported protocol %s: cannot filter\n",
182                          ( pcap_datalink_val_to_name ( hijack->datalink ) ?
183                            pcap_datalink_val_to_name ( hijack->datalink ) :
184                            "UNKNOWN" ) );
185                 /* Return success so we don't get called again */
186                 return 0;
187         }
188 }
189
190 /**
191  * Forward data from hijacker
192  *
193  */
194 static ssize_t forward_from_hijacker ( struct hijack *hijack, int fd ) {
195         char buf[SNAPLEN];
196         ssize_t len;
197
198         /* Read packet from hijacker */
199         len = read ( fd, buf, sizeof ( buf ) );
200         if ( len < 0 ) {
201                 logmsg ( LOG_ERR, "read from hijacker failed: %s\n",
202                          strerror ( errno ) );
203                 return -1;
204         }
205         if ( len == 0 )
206                 return 0;
207
208         /* Set up filter if not already in place */
209         if ( ! hijack->filtered ) {
210                 if ( hijack_filter ( hijack, buf, len ) == 0 )
211                         hijack->filtered = 1;
212         }
213
214         /* Transmit packet to network */
215         if ( pcap_inject ( hijack->pcap, buf, len ) != len ) {
216                 logmsg ( LOG_ERR, "write to hijacked port failed: %s\n",
217                          pcap_geterr ( hijack->pcap ) );
218                 return -1;
219         }
220
221         hijack->tx_count++;
222         return len;
223 };
224
225 /**
226  * Forward data to hijacker
227  *
228  */
229 static ssize_t forward_to_hijacker ( int fd, struct hijack *hijack ) {
230         struct pcap_pkthdr *pkt_header;
231         const unsigned char *pkt_data;
232         ssize_t len;
233
234         /* Receive packet from network */
235         if ( pcap_next_ex ( hijack->pcap, &pkt_header, &pkt_data ) < 0 ) {
236                 logmsg ( LOG_ERR, "read from hijacked port failed: %s\n",
237                          pcap_geterr ( hijack->pcap ) );
238                 return -1;
239         }
240         if ( pkt_header->caplen != pkt_header->len ) {
241                 logmsg ( LOG_ERR, "read partial packet (%d of %d bytes)\n",
242                          pkt_header->caplen, pkt_header->len );
243                 return -1;
244         }
245         if ( pkt_header->caplen == 0 )
246                 return 0;
247         len = pkt_header->caplen;
248
249         /* Write packet to hijacker */
250         if ( write ( fd, pkt_data, len ) != len ) {
251                 logmsg ( LOG_ERR, "write to hijacker failed: %s\n",
252                          strerror ( errno ) );
253                 return -1;
254         }
255
256         hijack->rx_count++;
257         return len;
258 };
259
260
261 /**
262  * Run hijacker
263  *
264  */
265 static int run_hijacker ( const char *interface, int fd ) {
266         struct hijack hijack;
267         fd_set fdset;
268         int max_fd;
269         ssize_t len;
270
271         logmsg ( LOG_INFO, "new connection for %s\n", interface );
272
273         /* Open connection to network */
274         memset ( &hijack, 0, sizeof ( hijack ) );
275         if ( hijack_open ( interface, &hijack ) < 0 )
276                 goto err;
277         
278         /* Do the forwarding */
279         max_fd = ( ( fd > hijack.fd ) ? fd : hijack.fd );
280         while ( 1 ) {
281                 /* Wait for available data */
282                 FD_ZERO ( &fdset );
283                 FD_SET ( fd, &fdset );
284                 FD_SET ( hijack.fd, &fdset );
285                 if ( select ( ( max_fd + 1 ), &fdset, NULL, NULL, 0 ) < 0 ) {
286                         logmsg ( LOG_ERR, "select failed: %s\n",
287                                  strerror ( errno ) );
288                         goto err;
289                 }
290                 if ( FD_ISSET ( fd, &fdset ) ) {
291                         len = forward_from_hijacker ( &hijack, fd );
292                         if ( len < 0 )
293                                 goto err;
294                         if ( len == 0 )
295                                 break;
296                 }
297                 if ( FD_ISSET ( hijack.fd, &fdset ) ) {
298                         len = forward_to_hijacker ( fd, &hijack );
299                         if ( len < 0 )
300                                 goto err;
301                         if ( len == 0 )
302                                 break;
303                 }
304         }
305
306         hijack_close ( &hijack );
307         logmsg ( LOG_INFO, "closed connection for %s\n", interface );
308         logmsg ( LOG_INFO, "received %ld packets, sent %ld packets\n",
309                  hijack.rx_count, hijack.tx_count );
310
311         return 0;
312
313  err:
314         if ( hijack.pcap )
315                 hijack_close ( &hijack );
316         return -1;
317 }
318
319 /**
320  * Open listener socket
321  *
322  */
323 static int open_listener ( const char *interface,
324                            struct hijack_listener *listener ) {
325         
326         /* Create socket */
327         listener->fd = socket ( PF_UNIX, SOCK_SEQPACKET, 0 );
328         if ( listener->fd < 0 ) {
329                 logmsg ( LOG_ERR, "Could not create socket: %s\n",
330                          strerror ( errno ) );
331                 goto err;
332         }
333
334         /* Bind to local filename */
335         listener->sun.sun_family = AF_UNIX,
336         snprintf ( listener->sun.sun_path, sizeof ( listener->sun.sun_path ),
337                    "/var/run/hijack-%s", interface );
338         if ( bind ( listener->fd, ( struct sockaddr * ) &listener->sun,
339                     sizeof ( listener->sun ) ) < 0 ) {
340                 logmsg ( LOG_ERR, "Could not bind socket to %s: %s\n",
341                          listener->sun.sun_path, strerror ( errno ) );
342                 goto err;
343         }
344
345         /* Set as a listening socket */
346         if ( listen ( listener->fd, 0 ) < 0 ) {
347                 logmsg ( LOG_ERR, "Could not listen to %s: %s\n",
348                          listener->sun.sun_path, strerror ( errno ) );
349                 goto err;
350         }
351
352         return 0;
353         
354  err:
355         if ( listener->fd >= 0 )
356                 close ( listener->fd );
357         return -1;
358 }
359
360 /**
361  * Listen on listener socket
362  *
363  */
364 static int listen_for_hijackers ( struct hijack_listener *listener,
365                                   const char *interface ) {
366         int fd;
367         pid_t child;
368         int rc;
369
370         logmsg ( LOG_INFO, "Listening on %s\n", listener->sun.sun_path );
371
372         while ( ! signalled ) {
373                 /* Accept new connection, interruptibly */
374                 siginterrupt ( SIGINT, 1 );
375                 siginterrupt ( SIGHUP, 1 );
376                 fd = accept ( listener->fd, NULL, 0 );
377                 siginterrupt ( SIGINT, 0 );
378                 siginterrupt ( SIGHUP, 0 );
379                 if ( fd < 0 ) {
380                         if ( errno == EINTR ) {
381                                 continue;
382                         } else {
383                                 logmsg ( LOG_ERR, "accept failed: %s\n",
384                                          strerror ( errno ) );
385                                 goto err;
386                         }
387                 }
388
389                 /* Fork child process */
390                 child = fork();
391                 if ( child < 0 ) {
392                         logmsg ( LOG_ERR, "fork failed: %s\n",
393                                  strerror ( errno ) );
394                         goto err;
395                 }
396                 if ( child == 0 ) {
397                         /* I am the child; run the hijacker */
398                         rc = run_hijacker ( interface, fd );
399                         close ( fd );
400                         exit ( rc );
401                 }
402                 
403                 close ( fd );
404         }
405
406         logmsg ( LOG_INFO, "Stopped listening on %s\n",
407                  listener->sun.sun_path );
408         return 0;
409
410  err:
411         if ( fd >= 0 )
412                 close ( fd );
413         return -1;
414 }
415
416 /**
417  * Close listener socket
418  *
419  */
420 static void close_listener ( struct hijack_listener *listener ) {
421         close ( listener->fd );
422         unlink ( listener->sun.sun_path );
423 }
424
425 /**
426  * Print usage
427  *
428  */
429 static void usage ( char **argv ) {
430         logmsg ( LOG_ERR,
431                  "Usage: %s [options]\n"
432                  "\n"
433                  "Options:\n"
434                  "  -h|--help               Print this help message\n"
435                  "  -i|--interface intf     Use specified network interface\n"
436                  "  -n|--nodaemon           Run in foreground\n",
437                  argv[0] );
438 }
439
440 /**
441  * Parse command-line options
442  *
443  */
444 static int parse_options ( int argc, char **argv,
445                            struct hijack_options *options ) {
446         static struct option long_options[] = {
447                 { "interface", 1, NULL, 'i' },
448                 { "nodaemon", 0, NULL, 'n' },
449                 { "help", 0, NULL, 'h' },
450                 { },
451         };
452         int c;
453
454         /* Set default options */
455         memset ( options, 0, sizeof ( *options ) );
456         strncpy ( options->interface, "eth0", sizeof ( options->interface ) );
457         options->daemonise = 1;
458
459         /* Parse command-line options */
460         while ( 1 ) {
461                 int option_index = 0;
462                 
463                 c = getopt_long ( argc, argv, "i:hn", long_options,
464                                   &option_index );
465                 if ( c < 0 )
466                         break;
467
468                 switch ( c ) {
469                 case 'i':
470                         strncpy ( options->interface, optarg,
471                                   sizeof ( options->interface ) );
472                         break;
473                 case 'n':
474                         options->daemonise = 0;
475                         break;
476                 case 'h':
477                         usage( argv );
478                         return -1;
479                 case '?':
480                         /* Unrecognised option */
481                         return -1;
482                 default:
483                         logmsg ( LOG_ERR, "Unrecognised option '-%c'\n", c );
484                         return -1;
485                 }
486         }
487
488         /* Check there's nothing left over on the command line */
489         if ( optind != argc ) {
490                 usage ( argv );
491                 return -1;
492         }
493
494         return 0;
495 }
496
497 /**
498  * Daemonise
499  *
500  */
501 static int daemonise ( const char *interface ) {
502         char pidfile[16 + IF_NAMESIZE + 4]; /* "/var/run/hijack-<intf>.pid" */
503         char pid[16];
504         int pidlen;
505         int fd = -1;
506
507         /* Daemonise */
508         if ( daemon ( 0, 0 ) < 0 ) {
509                 logmsg ( LOG_ERR, "Could not daemonise: %s\n",
510                          strerror ( errno ) );
511                 goto err;
512         }
513         daemonised = 1; /* Direct messages to syslog now */
514
515         /* Open pid file */
516         snprintf ( pidfile, sizeof ( pidfile ), "/var/run/hijack-%s.pid",
517                    interface );
518         fd = open ( pidfile, ( O_WRONLY | O_CREAT | O_TRUNC ),
519                     ( S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH ) );
520         if ( fd < 0 ) {
521                 logmsg ( LOG_ERR, "Could not open %s for writing: %s\n",
522                          pidfile, strerror ( errno ) );
523                 goto err;
524         }
525
526         /* Write pid to file */
527         pidlen = snprintf ( pid, sizeof ( pid ), "%d\n", getpid() );
528         if ( write ( fd, pid, pidlen ) != pidlen ) {
529                 logmsg ( LOG_ERR, "Could not write %s: %s\n",
530                          pidfile, strerror ( errno ) );
531                 goto err;
532         }
533
534         close ( fd );
535         return 0;
536
537  err:
538         if ( fd >= 0 )
539                 close ( fd );
540         return -1;
541 }
542
543 int main ( int argc, char **argv ) {
544         struct hijack_options options;
545         struct hijack_listener listener;
546         struct sigaction sa;
547
548         /* Parse command-line options */
549         if ( parse_options ( argc, argv, &options ) < 0 )
550                 exit ( 1 );
551
552         /* Set up syslog connection */
553         openlog ( basename ( argv[0] ), LOG_PID, LOG_DAEMON );
554
555         /* Set up listening socket */
556         if ( open_listener ( options.interface, &listener ) < 0 )
557                 exit ( 1 );
558
559         /* Daemonise on demand */
560         if ( options.daemonise ) {
561                 if ( daemonise ( options.interface ) < 0 )
562                         exit ( 1 );
563         }
564
565         /* Avoid creating zombies */
566         memset ( &sa, 0, sizeof ( sa ) );
567         sa.sa_handler = SIG_IGN;
568         sa.sa_flags = SA_RESTART | SA_NOCLDWAIT;
569         if ( sigaction ( SIGCHLD, &sa, NULL ) < 0 ) {
570                 logmsg ( LOG_ERR, "Could not set SIGCHLD handler: %s",
571                          strerror ( errno ) );
572                 exit ( 1 );
573         }
574
575         /* Set 'signalled' flag on SIGINT or SIGHUP */
576         sa.sa_handler = flag_signalled;
577         sa.sa_flags = SA_RESTART | SA_RESETHAND;
578         if ( sigaction ( SIGINT, &sa, NULL ) < 0 ) {
579                 logmsg ( LOG_ERR, "Could not set SIGINT handler: %s",
580                          strerror ( errno ) );
581                 exit ( 1 );
582         }
583         if ( sigaction ( SIGHUP, &sa, NULL ) < 0 ) {
584                 logmsg ( LOG_ERR, "Could not set SIGHUP handler: %s",
585                          strerror ( errno ) );
586                 exit ( 1 );
587         }
588
589         /* Listen for hijackers */
590         if ( listen_for_hijackers ( &listener, options.interface ) < 0 )
591                 exit ( 1 );
592
593         close_listener ( &listener );
594         
595         return 0;
596 }