Have DHCP set the nameserver, rather than DNS read the DHCP nameserver
[people/xl0/gpxe.git] / src / net / udp / dns.c
1 /*
2  * Copyright (C) 2006 Michael Brown <mbrown@fensystems.co.uk>.
3  *
4  * Portions copyright (C) 2004 Anselm M. Hoffmeister
5  * <stockholm@users.sourceforge.net>.
6  *
7  * This program is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License as
9  * published by the Free Software Foundation; either version 2 of the
10  * License, or any later version.
11  *
12  * This program is distributed in the hope that it will be useful, but
13  * WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
20  */
21
22 #include <stdint.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <errno.h>
26 #include <byteswap.h>
27 #include <gpxe/async.h>
28 #include <gpxe/udp.h>
29 #include <gpxe/dns.h>
30
31 /** @file
32  *
33  * DNS protocol
34  *
35  */
36
37 /* The DNS server */
38 struct in_addr nameserver = { INADDR_NONE };
39
40 /**
41  * Compare DNS reply name against the query name from the original request
42  *
43  * @v dns               DNS request
44  * @v reply             DNS reply
45  * @v rname             Reply name
46  * @ret zero            Names match
47  * @ret non-zero        Names do not match
48  */
49 static int dns_name_cmp ( struct dns_request *dns, struct dns_header *reply, 
50                           const char *rname ) {
51         const char *qname = dns->query.payload;
52         int i;
53
54         while ( 1 ) {
55                 /* Obtain next section of rname */
56                 while ( ( *rname ) & 0xc0 ) {
57                         rname = ( ( ( char * ) reply ) +
58                                   ( ntohs( *((uint16_t *)rname) ) & ~0xc000 ));
59                 }
60                 /* Check that lengths match */
61                 if ( *rname != *qname )
62                         return -1;
63                 /* If length is zero, we have reached the end */
64                 if ( ! *qname )
65                         return 0;
66                 /* Check that data matches */
67                 for ( i = *qname + 1; i > 0 ; i-- ) {
68                         if ( *(rname++) != *(qname++) )
69                                 return -1;
70                 }
71         }
72 }
73
74 /**
75  * Skip over a (possibly compressed) DNS name
76  *
77  * @v name              DNS name
78  * @ret name            Next DNS name
79  */
80 static const char * dns_skip_name ( const char *name ) {
81         while ( 1 ) {
82                 if ( ! *name ) {
83                         /* End of name */
84                         return ( name + 1);
85                 }
86                 if ( *name & 0xc0 ) {
87                         /* Start of a compressed name */
88                         return ( name + 2 );
89                 }
90                 /* Uncompressed name portion */
91                 name += *name + 1;
92         }
93 }
94
95 /**
96  * Find an RR in a reply packet corresponding to our query
97  *
98  * @v dns               DNS request
99  * @v reply             DNS reply
100  * @ret rr              DNS RR, or NULL if not found
101  */
102 static union dns_rr_info * dns_find_rr ( struct dns_request *dns,
103                                          struct dns_header *reply ) {
104         int i, cmp;
105         const char *p = ( ( char * ) reply ) + sizeof ( struct dns_header );
106         union dns_rr_info *rr_info;
107
108         /* Skip over the questions section */
109         for ( i = ntohs ( reply->qdcount ) ; i > 0 ; i-- ) {
110                 p = dns_skip_name ( p ) + sizeof ( struct dns_query_info );
111         }
112
113         /* Process the answers section */
114         for ( i = ntohs ( reply->ancount ) ; i > 0 ; i-- ) {
115                 cmp = dns_name_cmp ( dns, reply, p );
116                 p = dns_skip_name ( p );
117                 rr_info = ( ( union dns_rr_info * ) p );
118                 if ( cmp == 0 )
119                         return rr_info;
120                 p += ( sizeof ( rr_info->common ) +
121                        ntohs ( rr_info->common.rdlength ) );
122         }
123
124         return NULL;
125 }
126
127 /**
128  * Convert a standard NUL-terminated string to a DNS name
129  *
130  * @v string            Name as a NUL-terminated string
131  * @v buf               Buffer in which to place DNS name
132  * @ret next            Byte following constructed DNS name
133  *
134  * DNS names consist of "<length>element" pairs.
135  */
136 static char * dns_make_name ( const char *string, char *buf ) {
137         char *length_byte = buf++;
138         char c;
139
140         while ( ( c = *(string++) ) ) {
141                 if ( c == '.' ) {
142                         *length_byte = buf - length_byte - 1;
143                         length_byte = buf;
144                 }
145                 *(buf++) = c;
146         }
147         *length_byte = buf - length_byte - 1;
148         *(buf++) = '\0';
149         return buf;
150 }
151
152 /**
153  * Convert an uncompressed DNS name to a NUL-terminated string
154  *
155  * @v name              DNS name
156  * @ret string          NUL-terminated string
157  *
158  * Produce a printable version of a DNS name.  Used only for debugging.
159  */
160 static inline char * dns_unmake_name ( char *name ) {
161         char *p;
162         unsigned int len;
163
164         p = name;
165         while ( ( len = *p ) ) {
166                 *(p++) = '.';
167                 p += len;
168         }
169
170         return name + 1;
171 }
172
173 /**
174  * Decompress a DNS name
175  *
176  * @v reply             DNS replay
177  * @v name              DNS name
178  * @v buf               Buffer into which to decompress DNS name
179  * @ret next            Byte following decompressed DNS name
180  */
181 static char * dns_decompress_name ( struct dns_header *reply,
182                                     const char *name, char *buf ) {
183         int i, len;
184
185         do {
186                 /* Obtain next section of name */
187                 while ( ( *name ) & 0xc0 ) {
188                         name = ( ( char * ) reply +
189                                  ( ntohs ( *((uint16_t *)name) ) & ~0xc000 ) );
190                 }
191                 /* Copy data */
192                 len = *name;
193                 for ( i = len + 1 ; i > 0 ; i-- ) {
194                         *(buf++) = *(name++);
195                 }
196         } while ( len );
197         return buf;
198 }
199
200 /**
201  * Mark DNS request as complete
202  *
203  * @v dns               DNS request
204  * @v rc                Return status code
205  */
206 static void dns_done ( struct dns_request *dns, int rc ) {
207
208         /* Stop the retry timer */
209         stop_timer ( &dns->timer );
210
211         /* Close UDP connection */
212         udp_close ( &dns->udp );
213
214         /* Mark async operation as complete */
215         async_done ( &dns->async, rc );
216 }
217
218 /**
219  * Send next packet in DNS request
220  *
221  * @v dns               DNS request
222  */
223 static void dns_send_packet ( struct dns_request *dns ) {
224         static unsigned int qid = 0;
225
226         /* Increment query ID */
227         dns->query.dns.id = htons ( ++qid );
228
229         DBGC ( dns, "DNS %p sending query ID %d\n", dns, qid );
230
231         /* Start retransmission timer */
232         start_timer ( &dns->timer );
233
234         /* Send the data */
235         udp_send ( &dns->udp, &dns->query,
236                    ( ( ( void * ) dns->qinfo ) - ( ( void * ) &dns->query )
237                      + sizeof ( dns->qinfo ) ) );
238 }
239
240 /**
241  * Handle DNS retransmission timer expiry
242  *
243  * @v timer             Retry timer
244  * @v fail              Failure indicator
245  */
246 static void dns_timer_expired ( struct retry_timer *timer, int fail ) {
247         struct dns_request *dns =
248                 container_of ( timer, struct dns_request, timer );
249
250         if ( fail ) {
251                 dns_done ( dns, -ETIMEDOUT );
252         } else {
253                 dns_send_packet ( dns );
254         }
255 }
256
257 /**
258  * Receive new data
259  *
260  * @v udp               UDP connection
261  * @v data              Received data
262  * @v len               Length of received data
263  * @v st_src            Partially-filled source address
264  * @v st_dest           Partially-filled destination address
265  */
266 static int dns_newdata ( struct udp_connection *conn, void *data, size_t len,
267                          struct sockaddr_tcpip *st_src __unused,
268                          struct sockaddr_tcpip *st_dest __unused ) {
269         struct dns_request *dns =
270                 container_of ( conn, struct dns_request, udp );
271         struct dns_header *reply = data;
272         union dns_rr_info *rr_info;
273         struct sockaddr_in *sin;
274         unsigned int qtype = dns->qinfo->qtype;
275
276         /* Sanity check */
277         if ( len < sizeof ( *reply ) ) {
278                 DBGC ( dns, "DNS %p received underlength packet length %zd\n",
279                        dns, len );
280                 return -EINVAL;
281         }
282
283         /* Check reply ID matches query ID */
284         if ( reply->id != dns->query.dns.id ) {
285                 DBGC ( dns, "DNS %p received unexpected reply ID %d "
286                        "(wanted %d)\n", dns, ntohs ( reply->id ),
287                        ntohs ( dns->query.dns.id ) );
288                 return -EINVAL;
289         }
290
291         DBGC ( dns, "DNS %p received reply ID %d\n", dns, ntohs ( reply->id ));
292
293         /* Stop the retry timer.  After this point, each code path
294          * must either restart the timer by calling dns_send_packet(),
295          * or mark the DNS operation as complete by calling
296          * dns_done()
297          */
298         stop_timer ( &dns->timer );
299
300         /* Search through response for useful answers.  Do this
301          * multiple times, to take advantage of useful nameservers
302          * which send us e.g. the CNAME *and* the A record for the
303          * pointed-to name.
304          */
305         while ( ( rr_info = dns_find_rr ( dns, reply ) ) ) {
306                 switch ( rr_info->common.type ) {
307
308                 case htons ( DNS_TYPE_A ):
309
310                         /* Found the target A record */
311                         DBGC ( dns, "DNS %p found address %s\n",
312                                dns, inet_ntoa ( rr_info->a.in_addr ) );
313                         sin = ( struct sockaddr_in * ) dns->sa;
314                         sin->sin_family = AF_INET;
315                         sin->sin_addr = rr_info->a.in_addr;
316
317                         /* Mark operation as complete */
318                         dns_done ( dns, 0 );
319                         return 0;
320
321                 case htons ( DNS_TYPE_CNAME ):
322
323                         /* Found a CNAME record; update query and recurse */
324                         DBGC ( dns, "DNS %p found CNAME\n", dns );
325                         dns->qinfo = ( void * ) dns_decompress_name ( reply,
326                                                          rr_info->cname.cname,
327                                                          dns->query.payload );
328                         dns->qinfo->qtype = htons ( DNS_TYPE_A );
329                         dns->qinfo->qclass = htons ( DNS_CLASS_IN );
330                         
331                         /* Terminate the operation if we recurse too far */
332                         if ( ++dns->recursion > DNS_MAX_CNAME_RECURSION ) {
333                                 DBGC ( dns, "DNS %p recursion exceeded\n",
334                                        dns );
335                                 dns_done ( dns, -ELOOP );
336                                 return 0;
337                         }
338                         break;
339
340                 default:
341                         DBGC ( dns, "DNS %p got unknown record type %d\n",
342                                dns, ntohs ( rr_info->common.type ) );
343                         break;
344                 }
345         }
346         
347         /* Determine what to do next based on the type of query we
348          * issued and the reponse we received
349          */
350         switch ( qtype ) {
351
352         case htons ( DNS_TYPE_A ):
353                 /* We asked for an A record and got nothing;
354                  * try the CNAME.
355                  */
356                 DBGC ( dns, "DNS %p found no A record; trying CNAME\n", dns );
357                 dns->qinfo->qtype = htons ( DNS_TYPE_CNAME );
358                 dns_send_packet ( dns );
359                 return 0;
360
361         case htons ( DNS_TYPE_CNAME ):
362                 /* We asked for a CNAME record.  If we got a response
363                  * (i.e. if the next A query is already set up), then
364                  * issue it, otherwise abort.
365                  */
366                 if ( dns->qinfo->qtype == htons ( DNS_TYPE_A ) ) {
367                         dns_send_packet ( dns );
368                         return 0;
369                 } else {
370                         DBGC ( dns, "DNS %p found no CNAME record\n", dns );
371                         dns_done ( dns, -ENXIO );
372                         return 0;
373                 }
374
375         default:
376                 assert ( 0 );
377                 dns_done ( dns, -EINVAL );
378                 return 0;
379         }
380 }
381
382 /** DNS UDP operations */
383 struct udp_operations dns_udp_operations = {
384         .newdata = dns_newdata,
385 };
386
387 /**
388  * Reap asynchronous operation
389  *
390  * @v async             Asynchronous operation
391  */
392 static void dns_reap ( struct async *async ) {
393         struct dns_request *dns =
394                 container_of ( async, struct dns_request, async );
395
396         free ( dns );
397 }
398
399 /** DNS asynchronous operations */
400 static struct async_operations dns_async_operations = {
401         .reap = dns_reap,
402 };
403
404 /**
405  * Resolve name using DNS
406  *
407  * @v name              Host name to resolve
408  * @v sa                Socket address to fill in
409  * @v parent            Parent asynchronous operation
410  * @ret rc              Return status code
411  */
412 int dns_resolv ( const char *name, struct sockaddr *sa,
413                  struct async *parent ) {
414         struct dns_request *dns;
415         union {
416                 struct sockaddr_tcpip st;
417                 struct sockaddr_in sin;
418         } server;
419         int rc;
420
421         /* Allocate DNS structure */
422         dns = malloc ( sizeof ( *dns ) );
423         if ( ! dns ) {
424                 rc = -ENOMEM;
425                 goto err;
426         }
427         memset ( dns, 0, sizeof ( *dns ) );
428         dns->sa = sa;
429         dns->timer.expired = dns_timer_expired;
430         dns->udp.udp_op = &dns_udp_operations;
431         async_init ( &dns->async, &dns_async_operations, parent );
432
433         /* Create query */
434         dns->query.dns.flags = htons ( DNS_FLAG_QUERY | DNS_FLAG_OPCODE_QUERY |
435                                        DNS_FLAG_RD );
436         dns->query.dns.qdcount = htons ( 1 );
437         dns->qinfo = ( void * ) dns_make_name ( name, dns->query.payload );
438         dns->qinfo->qtype = htons ( DNS_TYPE_A );
439         dns->qinfo->qclass = htons ( DNS_CLASS_IN );
440
441         /* Identify nameserver */
442         memset ( &server, 0, sizeof ( server ) );
443         server.sin.sin_family = AF_INET;
444         server.sin.sin_port = htons ( DNS_PORT );
445         server.sin.sin_addr = nameserver;
446         if ( server.sin.sin_addr.s_addr == INADDR_NONE ) {
447                 DBGC ( dns, "DNS %p no name servers\n", dns );
448                 rc = -ENXIO;
449                 goto err;
450         }
451
452         /* Open UDP connection */
453         DBGC ( dns, "DNS %p using nameserver %s\n", dns, 
454                inet_ntoa ( server.sin.sin_addr ) );
455         udp_connect ( &dns->udp, &server.st );
456         if ( ( rc = udp_open ( &dns->udp, 0 ) ) != 0 )
457                 goto err;
458
459         /* Send first DNS packet */
460         dns_send_packet ( dns );
461
462         return 0;       
463
464  err:
465         DBGC ( dns, "DNS %p could not create request: %s\n", 
466                dns, strerror ( rc ) );
467         async_uninit ( &dns->async );
468         free ( dns );
469         return rc;
470 }