Updated DNS to use not-yet-implemented UDP data-xfer API.
[people/xl0/gpxe-arm.git] / src / net / udp / dns.c
1 /*
2  * Copyright (C) 2006 Michael Brown <mbrown@fensystems.co.uk>.
3  *
4  * Portions copyright (C) 2004 Anselm M. Hoffmeister
5  * <stockholm@users.sourceforge.net>.
6  *
7  * This program is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License as
9  * published by the Free Software Foundation; either version 2 of the
10  * License, or any later version.
11  *
12  * This program is distributed in the hope that it will be useful, but
13  * WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
20  */
21
22 #include <stdint.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <errno.h>
26 #include <byteswap.h>
27 #include <gpxe/refcnt.h>
28 #include <gpxe/xfer.h>
29 #include <gpxe/open.h>
30 #include <gpxe/resolv.h>
31 #include <gpxe/retry.h>
32 #include <gpxe/tcpip.h>
33 #include <gpxe/dns.h>
34
35 /** @file
36  *
37  * DNS protocol
38  *
39  */
40
41 /** The DNS server */
42 struct sockaddr_tcpip nameserver = {
43         .st_port = htons ( DNS_PORT ),
44 };
45
46 /** A DNS request */
47 struct dns_request {
48         /** Reference counter */
49         struct refcnt refcnt;
50         /** Name resolution interface */
51         struct resolv_interface resolv;
52         /** Data transfer interface */
53         struct xfer_interface socket;
54         /** Retry timer */
55         struct retry_timer timer;
56
57         /** Socket address to fill in with resolved address */
58         struct sockaddr sa;
59         /** Current query packet */
60         struct dns_query query;
61         /** Location of query info structure within current packet
62          *
63          * The query info structure is located immediately after the
64          * compressed name.
65          */
66         struct dns_query_info *qinfo;
67         /** Recursion counter */
68         unsigned int recursion;
69 };
70
71 /**
72  * Mark DNS request as complete
73  *
74  * @v dns               DNS request
75  * @v rc                Return status code
76  */
77 static void dns_done ( struct dns_request *dns, int rc ) {
78
79         /* Stop the retry timer */
80         stop_timer ( &dns->timer );
81
82         /* Close data transfer interface */
83         xfer_nullify ( &dns->socket );
84         xfer_close ( &dns->socket, rc );
85
86         /* Mark name resolution as complete */
87         resolv_done ( &dns->resolv, &dns->sa, rc );
88 }
89
90 /**
91  * Compare DNS reply name against the query name from the original request
92  *
93  * @v dns               DNS request
94  * @v reply             DNS reply
95  * @v rname             Reply name
96  * @ret zero            Names match
97  * @ret non-zero        Names do not match
98  */
99 static int dns_name_cmp ( struct dns_request *dns,
100                           const struct dns_header *reply, 
101                           const char *rname ) {
102         const char *qname = dns->query.payload;
103         int i;
104
105         while ( 1 ) {
106                 /* Obtain next section of rname */
107                 while ( ( *rname ) & 0xc0 ) {
108                         rname = ( ( ( char * ) reply ) +
109                                   ( ntohs( *((uint16_t *)rname) ) & ~0xc000 ));
110                 }
111                 /* Check that lengths match */
112                 if ( *rname != *qname )
113                         return -1;
114                 /* If length is zero, we have reached the end */
115                 if ( ! *qname )
116                         return 0;
117                 /* Check that data matches */
118                 for ( i = *qname + 1; i > 0 ; i-- ) {
119                         if ( *(rname++) != *(qname++) )
120                                 return -1;
121                 }
122         }
123 }
124
125 /**
126  * Skip over a (possibly compressed) DNS name
127  *
128  * @v name              DNS name
129  * @ret name            Next DNS name
130  */
131 static const char * dns_skip_name ( const char *name ) {
132         while ( 1 ) {
133                 if ( ! *name ) {
134                         /* End of name */
135                         return ( name + 1);
136                 }
137                 if ( *name & 0xc0 ) {
138                         /* Start of a compressed name */
139                         return ( name + 2 );
140                 }
141                 /* Uncompressed name portion */
142                 name += *name + 1;
143         }
144 }
145
146 /**
147  * Find an RR in a reply packet corresponding to our query
148  *
149  * @v dns               DNS request
150  * @v reply             DNS reply
151  * @ret rr              DNS RR, or NULL if not found
152  */
153 static union dns_rr_info * dns_find_rr ( struct dns_request *dns,
154                                          const struct dns_header *reply ) {
155         int i, cmp;
156         const char *p = ( ( char * ) reply ) + sizeof ( struct dns_header );
157         union dns_rr_info *rr_info;
158
159         /* Skip over the questions section */
160         for ( i = ntohs ( reply->qdcount ) ; i > 0 ; i-- ) {
161                 p = dns_skip_name ( p ) + sizeof ( struct dns_query_info );
162         }
163
164         /* Process the answers section */
165         for ( i = ntohs ( reply->ancount ) ; i > 0 ; i-- ) {
166                 cmp = dns_name_cmp ( dns, reply, p );
167                 p = dns_skip_name ( p );
168                 rr_info = ( ( union dns_rr_info * ) p );
169                 if ( cmp == 0 )
170                         return rr_info;
171                 p += ( sizeof ( rr_info->common ) +
172                        ntohs ( rr_info->common.rdlength ) );
173         }
174
175         return NULL;
176 }
177
178 /**
179  * Convert a standard NUL-terminated string to a DNS name
180  *
181  * @v string            Name as a NUL-terminated string
182  * @v buf               Buffer in which to place DNS name
183  * @ret next            Byte following constructed DNS name
184  *
185  * DNS names consist of "<length>element" pairs.
186  */
187 static char * dns_make_name ( const char *string, char *buf ) {
188         char *length_byte = buf++;
189         char c;
190
191         while ( ( c = *(string++) ) ) {
192                 if ( c == '.' ) {
193                         *length_byte = buf - length_byte - 1;
194                         length_byte = buf;
195                 }
196                 *(buf++) = c;
197         }
198         *length_byte = buf - length_byte - 1;
199         *(buf++) = '\0';
200         return buf;
201 }
202
203 /**
204  * Convert an uncompressed DNS name to a NUL-terminated string
205  *
206  * @v name              DNS name
207  * @ret string          NUL-terminated string
208  *
209  * Produce a printable version of a DNS name.  Used only for debugging.
210  */
211 static inline char * dns_unmake_name ( char *name ) {
212         char *p;
213         unsigned int len;
214
215         p = name;
216         while ( ( len = *p ) ) {
217                 *(p++) = '.';
218                 p += len;
219         }
220
221         return name + 1;
222 }
223
224 /**
225  * Decompress a DNS name
226  *
227  * @v reply             DNS replay
228  * @v name              DNS name
229  * @v buf               Buffer into which to decompress DNS name
230  * @ret next            Byte following decompressed DNS name
231  */
232 static char * dns_decompress_name ( const struct dns_header *reply,
233                                     const char *name, char *buf ) {
234         int i, len;
235
236         do {
237                 /* Obtain next section of name */
238                 while ( ( *name ) & 0xc0 ) {
239                         name = ( ( char * ) reply +
240                                  ( ntohs ( *((uint16_t *)name) ) & ~0xc000 ) );
241                 }
242                 /* Copy data */
243                 len = *name;
244                 for ( i = len + 1 ; i > 0 ; i-- ) {
245                         *(buf++) = *(name++);
246                 }
247         } while ( len );
248         return buf;
249 }
250
251 /**
252  * Send next packet in DNS request
253  *
254  * @v dns               DNS request
255  */
256 static int dns_send_packet ( struct dns_request *dns ) {
257         static unsigned int qid = 0;
258         size_t qlen;
259
260         /* Increment query ID */
261         dns->query.dns.id = htons ( ++qid );
262
263         DBGC ( dns, "DNS %p sending query ID %d\n", dns, qid );
264
265         /* Start retransmission timer */
266         start_timer ( &dns->timer );
267
268         /* Send the data */
269         qlen = ( ( ( void * ) dns->qinfo ) - ( ( void * ) &dns->query )
270                  + sizeof ( dns->qinfo ) );
271         return xfer_deliver_raw ( &dns->socket, &dns->query, qlen );
272 }
273
274 /**
275  * Handle DNS retransmission timer expiry
276  *
277  * @v timer             Retry timer
278  * @v fail              Failure indicator
279  */
280 static void dns_timer_expired ( struct retry_timer *timer, int fail ) {
281         struct dns_request *dns =
282                 container_of ( timer, struct dns_request, timer );
283
284         if ( fail ) {
285                 dns_done ( dns, -ETIMEDOUT );
286         } else {
287                 dns_send_packet ( dns );
288         }
289 }
290
291 /**
292  * Receive new data
293  *
294  * @v socket            UDP socket
295  * @v data              DNS reply
296  * @v len               Length of DNS reply
297  * @ret rc              Return status code
298  */
299 static int dns_xfer_deliver_raw ( struct xfer_interface *socket,
300                                   const void *data, size_t len ) {
301         struct dns_request *dns =
302                 container_of ( socket, struct dns_request, socket );
303         const struct dns_header *reply = data;
304         union dns_rr_info *rr_info;
305         struct sockaddr_in *sin;
306         unsigned int qtype = dns->qinfo->qtype;
307
308         /* Sanity check */
309         if ( len < sizeof ( *reply ) ) {
310                 DBGC ( dns, "DNS %p received underlength packet length %zd\n",
311                        dns, len );
312                 return -EINVAL;
313         }
314
315         /* Check reply ID matches query ID */
316         if ( reply->id != dns->query.dns.id ) {
317                 DBGC ( dns, "DNS %p received unexpected reply ID %d "
318                        "(wanted %d)\n", dns, ntohs ( reply->id ),
319                        ntohs ( dns->query.dns.id ) );
320                 return -EINVAL;
321         }
322
323         DBGC ( dns, "DNS %p received reply ID %d\n", dns, ntohs ( reply->id ));
324
325         /* Stop the retry timer.  After this point, each code path
326          * must either restart the timer by calling dns_send_packet(),
327          * or mark the DNS operation as complete by calling
328          * dns_done()
329          */
330         stop_timer ( &dns->timer );
331
332         /* Search through response for useful answers.  Do this
333          * multiple times, to take advantage of useful nameservers
334          * which send us e.g. the CNAME *and* the A record for the
335          * pointed-to name.
336          */
337         while ( ( rr_info = dns_find_rr ( dns, reply ) ) ) {
338                 switch ( rr_info->common.type ) {
339
340                 case htons ( DNS_TYPE_A ):
341
342                         /* Found the target A record */
343                         DBGC ( dns, "DNS %p found address %s\n",
344                                dns, inet_ntoa ( rr_info->a.in_addr ) );
345                         sin = ( struct sockaddr_in * ) &dns->sa;
346                         sin->sin_family = AF_INET;
347                         sin->sin_addr = rr_info->a.in_addr;
348
349                         /* Mark operation as complete */
350                         dns_done ( dns, 0 );
351                         return 0;
352
353                 case htons ( DNS_TYPE_CNAME ):
354
355                         /* Found a CNAME record; update query and recurse */
356                         DBGC ( dns, "DNS %p found CNAME\n", dns );
357                         dns->qinfo = ( void * ) dns_decompress_name ( reply,
358                                                          rr_info->cname.cname,
359                                                          dns->query.payload );
360                         dns->qinfo->qtype = htons ( DNS_TYPE_A );
361                         dns->qinfo->qclass = htons ( DNS_CLASS_IN );
362                         
363                         /* Terminate the operation if we recurse too far */
364                         if ( ++dns->recursion > DNS_MAX_CNAME_RECURSION ) {
365                                 DBGC ( dns, "DNS %p recursion exceeded\n",
366                                        dns );
367                                 dns_done ( dns, -ELOOP );
368                                 return 0;
369                         }
370                         break;
371
372                 default:
373                         DBGC ( dns, "DNS %p got unknown record type %d\n",
374                                dns, ntohs ( rr_info->common.type ) );
375                         break;
376                 }
377         }
378         
379         /* Determine what to do next based on the type of query we
380          * issued and the reponse we received
381          */
382         switch ( qtype ) {
383
384         case htons ( DNS_TYPE_A ):
385                 /* We asked for an A record and got nothing;
386                  * try the CNAME.
387                  */
388                 DBGC ( dns, "DNS %p found no A record; trying CNAME\n", dns );
389                 dns->qinfo->qtype = htons ( DNS_TYPE_CNAME );
390                 dns_send_packet ( dns );
391                 return 0;
392
393         case htons ( DNS_TYPE_CNAME ):
394                 /* We asked for a CNAME record.  If we got a response
395                  * (i.e. if the next A query is already set up), then
396                  * issue it, otherwise abort.
397                  */
398                 if ( dns->qinfo->qtype == htons ( DNS_TYPE_A ) ) {
399                         dns_send_packet ( dns );
400                         return 0;
401                 } else {
402                         DBGC ( dns, "DNS %p found no CNAME record\n", dns );
403                         dns_done ( dns, -ENXIO );
404                         return 0;
405                 }
406
407         default:
408                 assert ( 0 );
409                 dns_done ( dns, -EINVAL );
410                 return 0;
411         }
412 }
413
414 /**
415  * Receive new data
416  *
417  * @v socket            UDP socket
418  * @v rc                Reason for close
419  */
420 static void dns_xfer_close ( struct xfer_interface *socket, int rc ) {
421         struct dns_request *dns =
422                 container_of ( socket, struct dns_request, socket );
423
424         if ( ! rc )
425                 rc = -ECONNABORTED;
426
427         dns_done ( dns, rc );
428 }
429
430 /** DNS socket operations */
431 static struct xfer_interface_operations dns_socket_operations = {
432         .close          = dns_xfer_close,
433         .vredirect      = xfer_vopen,
434         .request        = ignore_xfer_request,
435         .seek           = ignore_xfer_seek,
436         .alloc_iob      = default_xfer_alloc_iob,
437         .deliver_iob    = xfer_deliver_as_raw,
438         .deliver_raw    = dns_xfer_deliver_raw,
439 };
440
441 /**
442  * Resolve name using DNS
443  *
444  * @v resolv            Name resolution interface
445  * @v name              Name to resolve
446  * @v sa                Socket address to fill in
447  * @ret rc              Return status code
448  */
449 static int dns_resolv ( struct resolv_interface *resolv,
450                         const char *name, struct sockaddr *sa ) {
451         struct dns_request *dns;
452         int rc;
453
454         /* Fail immediately if no DNS servers */
455         if ( ! nameserver.st_family ) {
456                 DBG ( "DNS not attempting to resolve \"%s\": "
457                       "no DNS servers\n", name );
458                 return -ENXIO;
459         }
460
461         /* Allocate DNS structure */
462         dns = malloc ( sizeof ( *dns ) );
463         if ( ! dns )
464                 return -ENOMEM;
465         memset ( dns, 0, sizeof ( *dns ) );
466         resolv_init ( &dns->resolv, &null_resolv_ops, &dns->refcnt );
467         xfer_init ( &dns->socket, &dns_socket_operations, &dns->refcnt );
468         dns->timer.expired = dns_timer_expired;
469         memcpy ( &dns->sa, sa, sizeof ( dns->sa ) );
470
471         /* Create query */
472         dns->query.dns.flags = htons ( DNS_FLAG_QUERY | DNS_FLAG_OPCODE_QUERY |
473                                        DNS_FLAG_RD );
474         dns->query.dns.qdcount = htons ( 1 );
475         dns->qinfo = ( void * ) dns_make_name ( name, dns->query.payload );
476         dns->qinfo->qtype = htons ( DNS_TYPE_A );
477         dns->qinfo->qclass = htons ( DNS_CLASS_IN );
478
479         /* Open UDP connection */
480         if ( ( rc = xfer_open_socket ( &dns->socket, SOCK_DGRAM,
481                                        ( struct sockaddr * ) &nameserver,
482                                        NULL ) ) != 0 ) {
483                 DBGC ( dns, "DNS %p could not open socket: %s\n",
484                        dns, strerror ( rc ) );
485                 goto err;
486         }
487
488         /* Send first DNS packet */
489         dns_send_packet ( dns );
490
491         /* Attach parent interface, mortalise self, and return */
492         resolv_plug_plug ( &dns->resolv, resolv );
493         ref_put ( &dns->refcnt );
494         return 0;       
495
496  err:
497         ref_put ( &dns->refcnt );
498         return rc;
499 }
500
501 /** DNS name resolver */
502 struct resolver dns_resolver __resolver ( RESOLV_NORMAL ) = {
503         .name = "DNS",
504         .resolv = dns_resolv,
505 };