c636790449512694c88b678ae4125b1035021cb3
[people/xl0/gpxe.git] / src / net / udp / dns.c
1 /*
2  * Copyright (C) 2006 Michael Brown <mbrown@fensystems.co.uk>.
3  *
4  * Portions copyright (C) 2004 Anselm M. Hoffmeister
5  * <stockholm@users.sourceforge.net>.
6  *
7  * This program is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License as
9  * published by the Free Software Foundation; either version 2 of the
10  * License, or any later version.
11  *
12  * This program is distributed in the hope that it will be useful, but
13  * WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
20  */
21
22 #include <stdint.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <errno.h>
26 #include <byteswap.h>
27 #include <gpxe/async.h>
28 #include <gpxe/udp.h>
29 #include <gpxe/dhcp.h>
30 #include <gpxe/dns.h>
31
32 /** @file
33  *
34  * DNS protocol
35  *
36  */
37
38 /**
39  * Compare DNS reply name against the query name from the original request
40  *
41  * @v dns               DNS request
42  * @v reply             DNS reply
43  * @v rname             Reply name
44  * @ret zero            Names match
45  * @ret non-zero        Names do not match
46  */
47 static int dns_name_cmp ( struct dns_request *dns, struct dns_header *reply, 
48                           const char *rname ) {
49         const char *qname = dns->query.payload;
50         int i;
51
52         while ( 1 ) {
53                 /* Obtain next section of rname */
54                 while ( ( *rname ) & 0xc0 ) {                   
55                         rname = ( ( ( char * ) reply ) +
56                                   ( ntohs( *((uint16_t *)rname) ) & ~0xc000 ));
57                 }
58                 /* Check that lengths match */
59                 if ( *rname != *qname )
60                         return -1;
61                 /* If length is zero, we have reached the end */
62                 if ( ! *qname )
63                         return 0;
64                 /* Check that data matches */
65                 for ( i = *qname + 1; i > 0 ; i-- ) {
66                         if ( *(rname++) != *(qname++) )
67                                 return -1;
68                 }
69         }
70 }
71
72 /**
73  * Skip over a (possibly compressed) DNS name
74  *
75  * @v name              DNS name
76  * @ret name            Next DNS name
77  */
78 static const char * dns_skip_name ( const char *name ) {
79         while ( 1 ) {
80                 if ( ! *name ) {
81                         /* End of name */
82                         return ( name + 1);
83                 }
84                 if ( *name & 0xc0 ) {
85                         /* Start of a compressed name */
86                         return ( name + 2 );
87                 }
88                 /* Uncompressed name portion */
89                 name += *name + 1;
90         }
91 }
92
93 /**
94  * Find an RR in a reply packet corresponding to our query
95  *
96  * @v dns               DNS request
97  * @v reply             DNS reply
98  * @ret rr              DNS RR, or NULL if not found
99  */
100 static union dns_rr_info * dns_find_rr ( struct dns_request *dns,
101                                          struct dns_header *reply ) {
102         int i, cmp;
103         const char *p = ( ( char * ) reply ) + sizeof ( struct dns_header );
104         union dns_rr_info *rr_info;
105
106         /* Skip over the questions section */
107         for ( i = ntohs ( reply->qdcount ) ; i > 0 ; i-- ) {
108                 p = dns_skip_name ( p ) + sizeof ( struct dns_query_info );
109         }
110
111         /* Process the answers section */
112         for ( i = ntohs ( reply->ancount ) ; i > 0 ; i-- ) {
113                 cmp = dns_name_cmp ( dns, reply, p );
114                 p = dns_skip_name ( p );
115                 rr_info = ( ( union dns_rr_info * ) p );
116                 if ( cmp == 0 )
117                         return rr_info;
118                 p += ( sizeof ( rr_info->common ) +
119                        ntohs ( rr_info->common.rdlength ) );
120         }
121
122         return NULL;
123 }
124
125 /**
126  * Convert a standard NUL-terminated string to a DNS name
127  *
128  * @v string            Name as a NUL-terminated string
129  * @v buf               Buffer in which to place DNS name
130  * @ret next            Byte following constructed DNS name
131  *
132  * DNS names consist of "<length>element" pairs.
133  */
134 static char * dns_make_name ( const char *string, char *buf ) {
135         char *length_byte = buf++;
136         char c;
137
138         while ( ( c = *(string++) ) ) {
139                 if ( c == '.' ) {
140                         *length_byte = buf - length_byte - 1;
141                         length_byte = buf;
142                 }
143                 *(buf++) = c;
144         }
145         *length_byte = buf - length_byte - 1;
146         *(buf++) = '\0';
147         return buf;
148 }
149
150 /**
151  * Convert an uncompressed DNS name to a NUL-terminated string
152  *
153  * @v name              DNS name
154  * @ret string          NUL-terminated string
155  *
156  * Produce a printable version of a DNS name.  Used only for debugging.
157  */
158 static inline char * dns_unmake_name ( char *name ) {
159         char *p;
160         unsigned int len;
161
162         p = name;
163         while ( ( len = *p ) ) {
164                 *(p++) = '.';
165                 p += len;
166         }
167
168         return name + 1;
169 }
170
171 /**
172  * Decompress a DNS name
173  *
174  * @v reply             DNS replay
175  * @v name              DNS name
176  * @v buf               Buffer into which to decompress DNS name
177  * @ret next            Byte following decompressed DNS name
178  */
179 static char * dns_decompress_name ( struct dns_header *reply,
180                                     const char *name, char *buf ) {
181         int i, len;
182
183         do {
184                 /* Obtain next section of name */
185                 while ( ( *name ) & 0xc0 ) {
186                         name = ( ( char * ) reply +
187                                  ( ntohs ( *((uint16_t *)name) ) & ~0xc000 ) );
188                 }
189                 /* Copy data */
190                 len = *name;
191                 for ( i = len + 1 ; i > 0 ; i-- ) {
192                         *(buf++) = *(name++);
193                 }
194         } while ( len );
195         return buf;
196 }
197
198 /**
199  * Mark DNS request as complete
200  *
201  * @v dns               DNS request
202  * @v rc                Return status code
203  */
204 static void dns_done ( struct dns_request *dns, int rc ) {
205
206         /* Stop the retry timer */
207         stop_timer ( &dns->timer );
208
209         /* Close UDP connection */
210         udp_close ( &dns->udp );
211
212         /* Mark async operation as complete */
213         async_done ( &dns->async, rc );
214 }
215
216 /**
217  * Send next packet in DNS request
218  *
219  * @v dns               DNS request
220  */
221 static void dns_send_packet ( struct dns_request *dns ) {
222         static unsigned int qid = 0;
223
224         /* Increment query ID */
225         dns->query.dns.id = htons ( ++qid );
226
227         DBGC ( dns, "DNS %p sending query ID %d\n", dns, qid );
228
229         /* Start retransmission timer */
230         start_timer ( &dns->timer );
231
232         /* Send the data */
233         udp_send ( &dns->udp, &dns->query,
234                    ( ( ( void * ) dns->qinfo ) - ( ( void * ) &dns->query )
235                      + sizeof ( dns->qinfo ) ) );
236 }
237
238 /**
239  * Handle DNS retransmission timer expiry
240  *
241  * @v timer             Retry timer
242  * @v fail              Failure indicator
243  */
244 static void dns_timer_expired ( struct retry_timer *timer, int fail ) {
245         struct dns_request *dns =
246                 container_of ( timer, struct dns_request, timer );
247
248         if ( fail ) {
249                 dns_done ( dns, -ETIMEDOUT );
250         } else {
251                 dns_send_packet ( dns );
252         }
253 }
254
255 /**
256  * Receive new data
257  *
258  * @v udp               UDP connection
259  * @v data              Received data
260  * @v len               Length of received data
261  * @v st_src            Partially-filled source address
262  * @v st_dest           Partially-filled destination address
263  */
264 static int dns_newdata ( struct udp_connection *conn, void *data, size_t len,
265                          struct sockaddr_tcpip *st_src __unused,
266                          struct sockaddr_tcpip *st_dest __unused ) {
267         struct dns_request *dns =
268                 container_of ( conn, struct dns_request, udp );
269         struct dns_header *reply = data;
270         union dns_rr_info *rr_info;
271         struct sockaddr_in *sin;
272         unsigned int qtype = dns->qinfo->qtype;
273
274         /* Sanity check */
275         if ( len < sizeof ( *reply ) ) {
276                 DBGC ( dns, "DNS %p received underlength packet length %zd\n",
277                        dns, len );
278                 return -EINVAL;
279         }
280
281         /* Check reply ID matches query ID */
282         if ( reply->id != dns->query.dns.id ) {
283                 DBGC ( dns, "DNS %p received unexpected reply ID %d "
284                        "(wanted %d)\n", dns, ntohs ( reply->id ),
285                        ntohs ( dns->query.dns.id ) );
286                 return -EINVAL;
287         }
288
289         DBGC ( dns, "DNS %p received reply ID %d\n", dns, ntohs ( reply->id ));
290
291         /* Stop the retry timer.  After this point, each code path
292          * must either restart the timer by calling dns_send_packet(),
293          * or mark the DNS operation as complete by calling
294          * dns_done()
295          */
296         stop_timer ( &dns->timer );
297
298         /* Search through response for useful answers.  Do this
299          * multiple times, to take advantage of useful nameservers
300          * which send us e.g. the CNAME *and* the A record for the
301          * pointed-to name.
302          */
303         while ( ( rr_info = dns_find_rr ( dns, reply ) ) ) {
304                 switch ( rr_info->common.type ) {
305
306                 case htons ( DNS_TYPE_A ):
307
308                         /* Found the target A record */
309                         DBGC ( dns, "DNS %p found address %s\n",
310                                dns, inet_ntoa ( rr_info->a.in_addr ) );
311                         sin = ( struct sockaddr_in * ) dns->sa;
312                         sin->sin_family = AF_INET;
313                         sin->sin_addr = rr_info->a.in_addr;
314
315                         /* Mark operation as complete */
316                         dns_done ( dns, 0 );
317                         return 0;
318
319                 case htons ( DNS_TYPE_CNAME ):
320
321                         /* Found a CNAME record; update query and recurse */
322                         DBGC ( dns, "DNS %p found CNAME\n", dns );
323                         dns->qinfo = ( void * ) dns_decompress_name ( reply,
324                                                          rr_info->cname.cname,
325                                                          dns->query.payload );
326                         dns->qinfo->qtype = htons ( DNS_TYPE_A );
327                         dns->qinfo->qclass = htons ( DNS_CLASS_IN );
328                         
329                         /* Terminate the operation if we recurse too far */
330                         if ( ++dns->recursion > DNS_MAX_CNAME_RECURSION ) {
331                                 DBGC ( dns, "DNS %p recursion exceeded\n",
332                                        dns );
333                                 dns_done ( dns, -ELOOP );
334                                 return 0;
335                         }
336                         break;
337
338                 default:
339                         DBGC ( dns, "DNS %p got unknown record type %d\n",
340                                dns, ntohs ( rr_info->common.type ) );
341                         break;
342                 }
343         }
344         
345         /* Determine what to do next based on the type of query we
346          * issued and the reponse we received
347          */
348         switch ( qtype ) {
349
350         case htons ( DNS_TYPE_A ):
351                 /* We asked for an A record and got nothing;
352                  * try the CNAME.
353                  */
354                 DBGC ( dns, "DNS %p found no A record; trying CNAME\n", dns );
355                 dns->qinfo->qtype = htons ( DNS_TYPE_CNAME );
356                 dns_send_packet ( dns );
357                 return 0;
358
359         case htons ( DNS_TYPE_CNAME ):
360                 /* We asked for a CNAME record.  If we got a response
361                  * (i.e. if the next A query is already set up), then
362                  * issue it, otherwise abort.
363                  */
364                 if ( dns->qinfo->qtype == htons ( DNS_TYPE_A ) ) {
365                         dns_send_packet ( dns );
366                         return 0;
367                 } else {
368                         DBGC ( dns, "DNS %p found no CNAME record\n", dns );
369                         dns_done ( dns, -ENXIO );
370                         return 0;
371                 }
372
373         default:
374                 assert ( 0 );
375                 dns_done ( dns, -EINVAL );
376                 return 0;
377         }
378 }
379
380 /** DNS UDP operations */
381 struct udp_operations dns_udp_operations = {
382         .newdata = dns_newdata,
383 };
384
385 /**
386  * Reap asynchronous operation
387  *
388  * @v async             Asynchronous operation
389  */
390 static void dns_reap ( struct async *async ) {
391         struct dns_request *dns =
392                 container_of ( async, struct dns_request, async );
393
394         free ( dns );
395 }
396
397 /** DNS asynchronous operations */
398 static struct async_operations dns_async_operations = {
399         .reap = dns_reap,
400 };
401
402 /**
403  * Resolve name using DNS
404  *
405  */
406 int dns_resolv ( const char *name, struct sockaddr *sa,
407                  struct async *parent ) {
408         struct dns_request *dns;
409         struct dhcp_option *option;
410         union {
411                 struct sockaddr_tcpip st;
412                 struct sockaddr_in sin;
413         } nameserver;
414
415         int rc;
416
417         /* Allocate DNS structure */
418         dns = malloc ( sizeof ( *dns ) );
419         if ( ! dns ) {
420                 rc = -ENOMEM;
421                 goto err;
422         }
423         memset ( dns, 0, sizeof ( *dns ) );
424         dns->sa = sa;
425         dns->timer.expired = dns_timer_expired;
426         dns->udp.udp_op = &dns_udp_operations;
427         async_init ( &dns->async, &dns_async_operations, parent );
428
429         /* Create query */
430         dns->query.dns.flags = htons ( DNS_FLAG_QUERY | DNS_FLAG_OPCODE_QUERY |
431                                        DNS_FLAG_RD );
432         dns->query.dns.qdcount = htons ( 1 );
433         dns->qinfo = ( void * ) dns_make_name ( name, dns->query.payload );
434         dns->qinfo->qtype = htons ( DNS_TYPE_A );
435         dns->qinfo->qclass = htons ( DNS_CLASS_IN );
436
437         /* Identify nameserver */
438         memset ( &nameserver, 0, sizeof ( nameserver ) );
439         nameserver.sin.sin_family = AF_INET;
440         nameserver.sin.sin_port = htons ( DNS_PORT );
441         if ( ! ( option = find_global_dhcp_option ( DHCP_DNS_SERVERS ) ) ) {
442                 DBGC ( dns, "DNS %p no name servers\n", dns );
443                 rc = -ENXIO;
444                 goto err;
445         }
446         dhcp_ipv4_option ( option, &nameserver.sin.sin_addr );
447
448         /* Open UDP connection */
449         DBGC ( dns, "DNS %p using nameserver %s\n", dns, 
450                inet_ntoa ( nameserver.sin.sin_addr ) );
451         udp_connect ( &dns->udp, &nameserver.st );
452         if ( ( rc = udp_open ( &dns->udp, 0 ) ) != 0 )
453                 goto err;
454
455         /* Send first DNS packet */
456         dns_send_packet ( dns );
457
458         return 0;       
459
460  err:
461         DBGC ( dns, "DNS %p could not create request: %s\n", 
462                dns, strerror ( rc ) );
463         async_uninit ( &dns->async );
464         free ( dns );
465         return rc;
466 }