Introduce name resolution interface and named socket opener.
[people/dverkamp/gpxe.git] / src / net / udp / dns.c
1 /*
2  * Copyright (C) 2006 Michael Brown <mbrown@fensystems.co.uk>.
3  *
4  * Portions copyright (C) 2004 Anselm M. Hoffmeister
5  * <stockholm@users.sourceforge.net>.
6  *
7  * This program is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License as
9  * published by the Free Software Foundation; either version 2 of the
10  * License, or any later version.
11  *
12  * This program is distributed in the hope that it will be useful, but
13  * WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
20  */
21
22 #include <stdint.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <errno.h>
26 #include <byteswap.h>
27 #include <gpxe/async.h>
28 #include <gpxe/udp.h>
29 #include <gpxe/resolv.h>
30 #include <gpxe/dns.h>
31
32 /** @file
33  *
34  * DNS protocol
35  *
36  */
37
38 /* The DNS server */
39 struct in_addr nameserver = { INADDR_NONE };
40
41 /**
42  * Compare DNS reply name against the query name from the original request
43  *
44  * @v dns               DNS request
45  * @v reply             DNS reply
46  * @v rname             Reply name
47  * @ret zero            Names match
48  * @ret non-zero        Names do not match
49  */
50 static int dns_name_cmp ( struct dns_request *dns, struct dns_header *reply, 
51                           const char *rname ) {
52         const char *qname = dns->query.payload;
53         int i;
54
55         while ( 1 ) {
56                 /* Obtain next section of rname */
57                 while ( ( *rname ) & 0xc0 ) {
58                         rname = ( ( ( char * ) reply ) +
59                                   ( ntohs( *((uint16_t *)rname) ) & ~0xc000 ));
60                 }
61                 /* Check that lengths match */
62                 if ( *rname != *qname )
63                         return -1;
64                 /* If length is zero, we have reached the end */
65                 if ( ! *qname )
66                         return 0;
67                 /* Check that data matches */
68                 for ( i = *qname + 1; i > 0 ; i-- ) {
69                         if ( *(rname++) != *(qname++) )
70                                 return -1;
71                 }
72         }
73 }
74
75 /**
76  * Skip over a (possibly compressed) DNS name
77  *
78  * @v name              DNS name
79  * @ret name            Next DNS name
80  */
81 static const char * dns_skip_name ( const char *name ) {
82         while ( 1 ) {
83                 if ( ! *name ) {
84                         /* End of name */
85                         return ( name + 1);
86                 }
87                 if ( *name & 0xc0 ) {
88                         /* Start of a compressed name */
89                         return ( name + 2 );
90                 }
91                 /* Uncompressed name portion */
92                 name += *name + 1;
93         }
94 }
95
96 /**
97  * Find an RR in a reply packet corresponding to our query
98  *
99  * @v dns               DNS request
100  * @v reply             DNS reply
101  * @ret rr              DNS RR, or NULL if not found
102  */
103 static union dns_rr_info * dns_find_rr ( struct dns_request *dns,
104                                          struct dns_header *reply ) {
105         int i, cmp;
106         const char *p = ( ( char * ) reply ) + sizeof ( struct dns_header );
107         union dns_rr_info *rr_info;
108
109         /* Skip over the questions section */
110         for ( i = ntohs ( reply->qdcount ) ; i > 0 ; i-- ) {
111                 p = dns_skip_name ( p ) + sizeof ( struct dns_query_info );
112         }
113
114         /* Process the answers section */
115         for ( i = ntohs ( reply->ancount ) ; i > 0 ; i-- ) {
116                 cmp = dns_name_cmp ( dns, reply, p );
117                 p = dns_skip_name ( p );
118                 rr_info = ( ( union dns_rr_info * ) p );
119                 if ( cmp == 0 )
120                         return rr_info;
121                 p += ( sizeof ( rr_info->common ) +
122                        ntohs ( rr_info->common.rdlength ) );
123         }
124
125         return NULL;
126 }
127
128 /**
129  * Convert a standard NUL-terminated string to a DNS name
130  *
131  * @v string            Name as a NUL-terminated string
132  * @v buf               Buffer in which to place DNS name
133  * @ret next            Byte following constructed DNS name
134  *
135  * DNS names consist of "<length>element" pairs.
136  */
137 static char * dns_make_name ( const char *string, char *buf ) {
138         char *length_byte = buf++;
139         char c;
140
141         while ( ( c = *(string++) ) ) {
142                 if ( c == '.' ) {
143                         *length_byte = buf - length_byte - 1;
144                         length_byte = buf;
145                 }
146                 *(buf++) = c;
147         }
148         *length_byte = buf - length_byte - 1;
149         *(buf++) = '\0';
150         return buf;
151 }
152
153 /**
154  * Convert an uncompressed DNS name to a NUL-terminated string
155  *
156  * @v name              DNS name
157  * @ret string          NUL-terminated string
158  *
159  * Produce a printable version of a DNS name.  Used only for debugging.
160  */
161 static inline char * dns_unmake_name ( char *name ) {
162         char *p;
163         unsigned int len;
164
165         p = name;
166         while ( ( len = *p ) ) {
167                 *(p++) = '.';
168                 p += len;
169         }
170
171         return name + 1;
172 }
173
174 /**
175  * Decompress a DNS name
176  *
177  * @v reply             DNS replay
178  * @v name              DNS name
179  * @v buf               Buffer into which to decompress DNS name
180  * @ret next            Byte following decompressed DNS name
181  */
182 static char * dns_decompress_name ( struct dns_header *reply,
183                                     const char *name, char *buf ) {
184         int i, len;
185
186         do {
187                 /* Obtain next section of name */
188                 while ( ( *name ) & 0xc0 ) {
189                         name = ( ( char * ) reply +
190                                  ( ntohs ( *((uint16_t *)name) ) & ~0xc000 ) );
191                 }
192                 /* Copy data */
193                 len = *name;
194                 for ( i = len + 1 ; i > 0 ; i-- ) {
195                         *(buf++) = *(name++);
196                 }
197         } while ( len );
198         return buf;
199 }
200
201 /**
202  * Mark DNS request as complete
203  *
204  * @v dns               DNS request
205  * @v rc                Return status code
206  */
207 static void dns_done ( struct dns_request *dns, int rc ) {
208
209         /* Stop the retry timer */
210         stop_timer ( &dns->timer );
211
212         /* Close UDP connection */
213         udp_close ( &dns->udp );
214
215         /* Mark async operation as complete */
216         async_done ( &dns->async, rc );
217 }
218
219 /**
220  * Send next packet in DNS request
221  *
222  * @v dns               DNS request
223  */
224 static void dns_send_packet ( struct dns_request *dns ) {
225         static unsigned int qid = 0;
226
227         /* Increment query ID */
228         dns->query.dns.id = htons ( ++qid );
229
230         DBGC ( dns, "DNS %p sending query ID %d\n", dns, qid );
231
232         /* Start retransmission timer */
233         start_timer ( &dns->timer );
234
235         /* Send the data */
236         udp_send ( &dns->udp, &dns->query,
237                    ( ( ( void * ) dns->qinfo ) - ( ( void * ) &dns->query )
238                      + sizeof ( dns->qinfo ) ) );
239 }
240
241 /**
242  * Handle DNS retransmission timer expiry
243  *
244  * @v timer             Retry timer
245  * @v fail              Failure indicator
246  */
247 static void dns_timer_expired ( struct retry_timer *timer, int fail ) {
248         struct dns_request *dns =
249                 container_of ( timer, struct dns_request, timer );
250
251         if ( fail ) {
252                 dns_done ( dns, -ETIMEDOUT );
253         } else {
254                 dns_send_packet ( dns );
255         }
256 }
257
258 /**
259  * Receive new data
260  *
261  * @v udp               UDP connection
262  * @v data              Received data
263  * @v len               Length of received data
264  * @v st_src            Partially-filled source address
265  * @v st_dest           Partially-filled destination address
266  */
267 static int dns_newdata ( struct udp_connection *conn, void *data, size_t len,
268                          struct sockaddr_tcpip *st_src __unused,
269                          struct sockaddr_tcpip *st_dest __unused ) {
270         struct dns_request *dns =
271                 container_of ( conn, struct dns_request, udp );
272         struct dns_header *reply = data;
273         union dns_rr_info *rr_info;
274         struct sockaddr_in *sin;
275         unsigned int qtype = dns->qinfo->qtype;
276
277         /* Sanity check */
278         if ( len < sizeof ( *reply ) ) {
279                 DBGC ( dns, "DNS %p received underlength packet length %zd\n",
280                        dns, len );
281                 return -EINVAL;
282         }
283
284         /* Check reply ID matches query ID */
285         if ( reply->id != dns->query.dns.id ) {
286                 DBGC ( dns, "DNS %p received unexpected reply ID %d "
287                        "(wanted %d)\n", dns, ntohs ( reply->id ),
288                        ntohs ( dns->query.dns.id ) );
289                 return -EINVAL;
290         }
291
292         DBGC ( dns, "DNS %p received reply ID %d\n", dns, ntohs ( reply->id ));
293
294         /* Stop the retry timer.  After this point, each code path
295          * must either restart the timer by calling dns_send_packet(),
296          * or mark the DNS operation as complete by calling
297          * dns_done()
298          */
299         stop_timer ( &dns->timer );
300
301         /* Search through response for useful answers.  Do this
302          * multiple times, to take advantage of useful nameservers
303          * which send us e.g. the CNAME *and* the A record for the
304          * pointed-to name.
305          */
306         while ( ( rr_info = dns_find_rr ( dns, reply ) ) ) {
307                 switch ( rr_info->common.type ) {
308
309                 case htons ( DNS_TYPE_A ):
310
311                         /* Found the target A record */
312                         DBGC ( dns, "DNS %p found address %s\n",
313                                dns, inet_ntoa ( rr_info->a.in_addr ) );
314                         sin = ( struct sockaddr_in * ) dns->sa;
315                         sin->sin_family = AF_INET;
316                         sin->sin_addr = rr_info->a.in_addr;
317
318                         /* Mark operation as complete */
319                         dns_done ( dns, 0 );
320                         return 0;
321
322                 case htons ( DNS_TYPE_CNAME ):
323
324                         /* Found a CNAME record; update query and recurse */
325                         DBGC ( dns, "DNS %p found CNAME\n", dns );
326                         dns->qinfo = ( void * ) dns_decompress_name ( reply,
327                                                          rr_info->cname.cname,
328                                                          dns->query.payload );
329                         dns->qinfo->qtype = htons ( DNS_TYPE_A );
330                         dns->qinfo->qclass = htons ( DNS_CLASS_IN );
331                         
332                         /* Terminate the operation if we recurse too far */
333                         if ( ++dns->recursion > DNS_MAX_CNAME_RECURSION ) {
334                                 DBGC ( dns, "DNS %p recursion exceeded\n",
335                                        dns );
336                                 dns_done ( dns, -ELOOP );
337                                 return 0;
338                         }
339                         break;
340
341                 default:
342                         DBGC ( dns, "DNS %p got unknown record type %d\n",
343                                dns, ntohs ( rr_info->common.type ) );
344                         break;
345                 }
346         }
347         
348         /* Determine what to do next based on the type of query we
349          * issued and the reponse we received
350          */
351         switch ( qtype ) {
352
353         case htons ( DNS_TYPE_A ):
354                 /* We asked for an A record and got nothing;
355                  * try the CNAME.
356                  */
357                 DBGC ( dns, "DNS %p found no A record; trying CNAME\n", dns );
358                 dns->qinfo->qtype = htons ( DNS_TYPE_CNAME );
359                 dns_send_packet ( dns );
360                 return 0;
361
362         case htons ( DNS_TYPE_CNAME ):
363                 /* We asked for a CNAME record.  If we got a response
364                  * (i.e. if the next A query is already set up), then
365                  * issue it, otherwise abort.
366                  */
367                 if ( dns->qinfo->qtype == htons ( DNS_TYPE_A ) ) {
368                         dns_send_packet ( dns );
369                         return 0;
370                 } else {
371                         DBGC ( dns, "DNS %p found no CNAME record\n", dns );
372                         dns_done ( dns, -ENXIO );
373                         return 0;
374                 }
375
376         default:
377                 assert ( 0 );
378                 dns_done ( dns, -EINVAL );
379                 return 0;
380         }
381 }
382
383 /** DNS UDP operations */
384 struct udp_operations dns_udp_operations = {
385         .newdata = dns_newdata,
386 };
387
388 /**
389  * Reap asynchronous operation
390  *
391  * @v async             Asynchronous operation
392  */
393 static void dns_reap ( struct async *async ) {
394         struct dns_request *dns =
395                 container_of ( async, struct dns_request, async );
396
397         free ( dns );
398 }
399
400 /**
401  * Handle SIGKILL
402  *
403  * @v async             Asynchronous operation
404  */
405 static void dns_sigkill ( struct async *async, enum signal signal __unused ) {
406         struct dns_request *dns =
407                 container_of ( async, struct dns_request, async );
408
409         dns_done ( dns, -ECANCELED );
410 }
411
412 /** DNS asynchronous operations */
413 static struct async_operations dns_async_operations = {
414         .reap = dns_reap,
415         .signal = {
416                 [SIGKILL] = dns_sigkill,
417         },
418 };
419
420 /**
421  * Resolve name using DNS
422  *
423  * @v name              Host name to resolve
424  * @v sa                Socket address to fill in
425  * @v parent            Parent asynchronous operation
426  * @ret rc              Return status code
427  */
428 int dns_resolv ( const char *name, struct sockaddr *sa,
429                  struct async *parent ) {
430         struct dns_request *dns;
431         union {
432                 struct sockaddr_tcpip st;
433                 struct sockaddr_in sin;
434         } server;
435         int rc;
436
437         /* Allocate DNS structure */
438         dns = malloc ( sizeof ( *dns ) );
439         if ( ! dns ) {
440                 rc = -ENOMEM;
441                 goto err;
442         }
443         memset ( dns, 0, sizeof ( *dns ) );
444         dns->sa = sa;
445         dns->timer.expired = dns_timer_expired;
446         dns->udp.udp_op = &dns_udp_operations;
447         async_init ( &dns->async, &dns_async_operations, parent );
448
449         /* Create query */
450         dns->query.dns.flags = htons ( DNS_FLAG_QUERY | DNS_FLAG_OPCODE_QUERY |
451                                        DNS_FLAG_RD );
452         dns->query.dns.qdcount = htons ( 1 );
453         dns->qinfo = ( void * ) dns_make_name ( name, dns->query.payload );
454         dns->qinfo->qtype = htons ( DNS_TYPE_A );
455         dns->qinfo->qclass = htons ( DNS_CLASS_IN );
456
457         /* Identify nameserver */
458         memset ( &server, 0, sizeof ( server ) );
459         server.sin.sin_family = AF_INET;
460         server.sin.sin_port = htons ( DNS_PORT );
461         server.sin.sin_addr = nameserver;
462         if ( server.sin.sin_addr.s_addr == INADDR_NONE ) {
463                 DBGC ( dns, "DNS %p no name servers\n", dns );
464                 rc = -ENXIO;
465                 goto err;
466         }
467
468         /* Open UDP connection */
469         DBGC ( dns, "DNS %p using nameserver %s\n", dns, 
470                inet_ntoa ( server.sin.sin_addr ) );
471         udp_connect ( &dns->udp, &server.st );
472         if ( ( rc = udp_open ( &dns->udp, 0 ) ) != 0 )
473                 goto err;
474
475         /* Send first DNS packet */
476         dns_send_packet ( dns );
477
478         return 0;       
479
480  err:
481         DBGC ( dns, "DNS %p could not create request: %s\n", 
482                dns, strerror ( rc ) );
483         async_uninit ( &dns->async );
484         free ( dns );
485         return rc;
486 }
487
488 /** DNS name resolver */
489 struct resolver dns_resolver __resolver ( RESOLV_NORMAL ) = {
490         .name = "DNS",
491         .resolv = dns_resolv,
492 };