Quickly hacked to use a buffer rather than a processor.
[people/xl0/gpxe.git] / src / proto / tftm.c
1 /**************************************************************************
2 *
3 *    proto_tftm.c -- Etherboot Multicast TFTP 
4 *    Written 2003-2003 by Timothy Legge <tlegge@rogers.com>
5 *
6 *    This program is free software; you can redistribute it and/or modify
7 *    it under the terms of the GNU General Public License as published by
8 *    the Free Software Foundation; either version 2 of the License, or
9 *    (at your option) any later version.
10 *
11 *    This program is distributed in the hope that it will be useful,
12 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *    GNU General Public License for more details.
15 *
16 *    You should have received a copy of the GNU General Public License
17 *    along with this program; if not, write to the Free Software
18 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19 *
20 *    This code is based on the DOWNLOAD_PROTO_TFTM section of 
21 *    Etherboot 5.3 core/nic.c and:
22 *    
23 *    Anselm Martin Hoffmeister's previous proto_tftm.c multicast work
24 *    Eric Biederman's proto_slam.c
25 *
26 *    $Revision$
27 *    $Author$
28 *    $Date$
29 *
30 *    REVISION HISTORY:
31 *    ================
32 *    09-07-2003 timlegge        Release Version, Capable of Multicast Booting
33 *    08-30-2003 timlegge        Initial version, Assumes consecutive blocks
34 *
35 *    Indent Options: indent -kr -i8
36 ***************************************************************************/
37
38 /*
39  * IMPORTANT
40  *
41  * This file should be rewritten to avoid the use of a bitmap.  Our
42  * buffer routines can cope with being handed blocks in an arbitrary
43  * order, duplicate blocks, etc.  This code could be substantially
44  * simplified by taking advantage of these features.
45  *
46  */
47
48 #include "etherboot.h"
49 #include "proto.h"
50 #include "nic.h"
51
52 struct tftm_info {
53         struct sockaddr_in server;
54         struct sockaddr_in local;
55         struct sockaddr_in multicast;
56         int sent_nack;
57         const char *name;       /* Filename */
58 };
59
/*
 * Progress of the current TFTM download.  One file-scope instance is
 * shared by url_tftm(), proto_tftm() and opt_get_multicast(); url_tftm()
 * resets it before each transfer.
 *
 * Kept static: nothing outside this file uses it, and an external
 * tentative definition called "state" can clash with (or, as a common
 * symbol, silently share storage with) another file's "state".
 */
static struct tftm_state {
        unsigned long block_size;       /* negotiated blksize (must equal TFTM_MIN_PACKET) */
        unsigned long total_bytes;      /* reset by url_tftm(), otherwise unused here */
        unsigned long total_packets;    /* DATA blocks expected, numbered 1..total_packets */
        char ismaster;                  /* non-zero while the server names us master client */
        unsigned long received_packets; /* distinct DATA blocks stored so far */
        struct buffer *buffer;          /* download target supplied by the caller */
        unsigned char *image;           /* virtual address of buffer->start */
        unsigned char *bitmap;          /* received-block bitmap, stored just after the image */
        char recvd_oack;                /* set once an acceptable OACK has been processed */
} state;
71
/* Parses the options of an OACK packet; defined near the end of this file. */
static int opt_get_multicast(struct tftp_t *tr,
                             unsigned short *len,
                             unsigned long *filesize,
                             struct tftm_info *info);

/* Default port registered for the "x-tftm" protocol (see tftm_protocol) */
#define TFTM_PORT 1758
/* The only "blksize" accepted; also the largest TFTP payload processed */
#define TFTM_MIN_PACKET 1024
78
79 static int await_tftm(int ival, void *ptr, unsigned short ptype __unused,
80                       struct iphdr *ip, struct udphdr *udp,
81                       struct tcphdr *tcp __unused)
82 {
83         struct tftm_info *info = ptr;
84
85         /* Check for Unicast data being received */
86         if (ip->dest.s_addr == arptable[ARP_CLIENT].ipaddr.s_addr) {
87                 if (!udp) {
88                         return 0;
89                 }
90                 if (arptable[ARP_CLIENT].ipaddr.s_addr != ip->dest.s_addr)
91                         return 0;
92                 if (ntohs(udp->dest) != ival)
93                         return 0;
94
95                 return 1;       /* Unicast Data Received */
96         }
97
98         /* Also check for Multicast data being received */
99         if ((ip->dest.s_addr == info->multicast.sin_addr.s_addr) &&
100             (ntohs(udp->dest) == info->multicast.sin_port) &&
101             (nic.packetlen >= ETH_HLEN + sizeof(struct iphdr) +
102              sizeof(struct udphdr))) {
103                 return 1;       /* Multicast data received */
104         }
105         return 0;
106 }
107
108 static int proto_tftm(struct tftm_info *info)
109 {
110         int retry = 0;
111         static unsigned short iport = 2000;
112         unsigned short oport = 0;
113         unsigned short len, block = 0, prevblock = 0;
114         struct tftp_t *tr;
115         struct tftpreq_t tp;
116         unsigned long filesize = 0;
117
118         state.image = 0;
119         state.bitmap = 0;
120
121         rx_qdrain();
122
123         /* Warning: the following assumes the layout of bootp_t.
124            But that's fixed by the IP, UDP and BOOTP specs. */
125
126         /* Send a tftm-request to the server */
127         tp.opcode = htons(TFTP_RRQ);    /* Const for "\0x0" "\0x1" =^= ReadReQuest */
128         len =
129             sizeof(tp.ip) + sizeof(tp.udp) + sizeof(tp.opcode) +
130             sprintf((char *) tp.u.rrq,
131                     "%s%coctet%cmulticast%c%cblksize%c%d%ctsize%c",
132                     info->name, 0, 0, 0, 0, 0, TFTM_MIN_PACKET, 0, 0) + 1;
133
134         if (!udp_transmit(info->server.sin_addr.s_addr, ++iport,
135                           info->server.sin_port, len, &tp))
136                 return (0);
137
138         /* loop to listen for packets and to receive the file */
139         for (;;) {
140                 long timeout;
141 #ifdef  CONGESTED
142                 timeout =
143                     rfc2131_sleep_interval(block ? TFTP_REXMT : TIMEOUT,
144                                            retry);
145 #else
146                 timeout = rfc2131_sleep_interval(TIMEOUT, retry);
147 #endif
148                 /* Calls the await_reply function in nic.c which in turn calls
149                    await_tftm (1st parameter) as above */
150                 if (!await_reply(await_tftm, iport, info, timeout)) {
151                         if (!block && retry++ < MAX_TFTP_RETRIES) {     /* maybe initial request was lost */
152                                 if (!udp_transmit
153                                     (info->server.sin_addr.s_addr, ++iport,
154                                      info->server.sin_port, len, &tp))
155                                         return (0);
156                                 continue;
157                         }
158 #ifdef  CONGESTED
159                         if (block && ((retry += TFTP_REXMT) < TFTP_TIMEOUT)) {  /* we resend our last ack */
160                                 DBG("Timed out receiving file");
161                                 len =
162                                     sizeof(tp.ip) + sizeof(tp.udp) +
163                                     sizeof(tp.opcode) +
164                                     sprintf((char *) tp.u.rrq,
165                                             "%s%coctet%cmulticast%c%cblksize%c%d%ctsize%c",
166                                             info->name, 0, 0, 0, 0, 0,
167                                             TFTM_MIN_PACKET, 0, 0) + 1;
168
169                                 udp_transmit
170                                         (info->server.sin_addr.s_addr,
171                                          ++iport, info->server.sin_port,
172                                          len, &tp);
173                                         continue;
174                         }
175 #endif
176                         break;  /* timeout */
177                 }
178
179                 tr = (struct tftp_t *) &nic.packet[ETH_HLEN];
180
181                 if (tr->opcode == ntohs(TFTP_ERROR)) {
182                         printf("TFTP error %d (%s)\n",
183                                ntohs(tr->u.err.errcode), tr->u.err.errmsg);
184                         break;
185                 }
186
187                 if (tr->opcode == ntohs(TFTP_OACK)) {
188                         int i =
189                             opt_get_multicast(tr, &len, &filesize, info);
190
191                         if (i == 0 || (i != 7 && !state.recvd_oack)) {  /* Multicast unsupported */
192                                 /* Transmit an error message to the server to end the transmission */
193                                 printf
194                                     ("TFTM-Server doesn't understand options [blksize tsize multicast]\n");
195                                 tp.opcode = htons(TFTP_ERROR);
196                                 tp.u.err.errcode = 8;
197                                 /*
198                                  *      Warning: the following assumes the layout of bootp_t.
199                                  *      But that's fixed by the IP, UDP and BOOTP specs.
200                                  */
201                                 len =
202                                     sizeof(tp.ip) + sizeof(tp.udp) +
203                                     sizeof(tp.opcode) +
204                                     sizeof(tp.u.err.errcode) +
205                                     /*
206                                      *      Normally bad form to omit the format string, but in this case
207                                      *      the string we are copying from is fixed. sprintf is just being
208                                      *      used as a strcpy and strlen.
209                                      */
210                                     sprintf((char *) tp.u.err.errmsg,
211                                             "RFC2090 error") + 1;
212                                 udp_transmit(info->server.sin_addr.s_addr,
213                                              iport, ntohs(tr->udp.src),
214                                              len, &tp);
215                                 block = tp.u.ack.block = 0;     /* this ensures, that */
216                                 /* the packet does not get */
217                                 /* processed as data! */
218                                 return (0);
219                         } else {
220                                 unsigned long bitmap_len;
221                                 /* */
222                                 if (!state.recvd_oack) {
223
224                                         state.total_packets =
225                                             1 + (filesize -
226                                                  (filesize %
227                                                   state.block_size)) /
228                                             state.block_size;
229                                         bitmap_len =
230                                             (state.total_packets + 7) / 8;
231                                         if (!state.image) {
232                                                 state.image = phys_to_virt ( state.buffer->start );
233                                                 state.bitmap = state.image + filesize;
234                                                 /* We don't yet use the buffer routines; fake it */
235                                                 state.buffer->fill = filesize;
236
237                                                 memset(state.bitmap, 0,
238                                                        bitmap_len);
239                                         }
240                                         /* If I'm running over multicast join the multicast group */
241                                         join_group(IGMP_SERVER,
242                                               info->multicast.sin_addr.s_addr);
243                                 }
244                                 state.recvd_oack = 1;
245                         }
246
247
248
249                 } else if (tr->opcode == htons(TFTP_DATA)) {
250                         unsigned long data_len;
251                         unsigned char *data;
252                         struct udphdr *udp;
253                         udp =
254                             (struct udphdr *) &nic.packet[ETH_HLEN +
255                                                           sizeof(struct
256                                                                  iphdr)];
257                         len =
258                             ntohs(tr->udp.len) - sizeof(struct udphdr) - 4;
259                         data =
260                             nic.packet + ETH_HLEN + sizeof(struct iphdr) +
261                             sizeof(struct udphdr) + 4;
262
263                         if (len > TFTM_MIN_PACKET)      /* shouldn't happen */
264                                 continue;       /* ignore it */
265
266                         block = ntohs(tp.u.ack.block = tr->u.data.block);
267
268                         if (block > state.total_packets) {
269                                 printf("ALERT: Invalid packet number\n");
270                                 continue;
271                         }
272
273                         /* Compute the expected data length */
274                         if (block != state.total_packets) {
275                                 data_len = state.block_size;
276                         } else {
277                                 data_len = filesize % state.block_size;
278                         }
279                         /* If the packet size is wrong drop the packet and then continue */
280                         if (ntohs(udp->len) !=
281                             (data_len + (data - (unsigned char *) udp))) {
282                                 printf
283                                     ("ALERT: udp packet is not the correct size: %d\n",
284                                      block);
285                                 continue;
286                         }
287                         if (nic.packetlen < data_len + (data - nic.packet)) {
288                                 printf
289                                     ("ALERT: Ethernet packet shorter than data_len: %d\n",
290                                      block);
291                                 continue;
292                         }
293
294                         if (data_len > state.block_size) {
295                                 data_len = state.block_size;
296                         }
297                         if (((state.
298                               bitmap[block >> 3] >> (block & 7)) & 1) ==
299                             0) {
300                                 /* Non duplicate packet */
301                                 state.bitmap[block >> 3] |=
302                                     (1 << (block & 7));
303                                 memcpy(state.image +
304                                        ((block - 1) * state.block_size),
305                                        data, data_len);
306                                 state.received_packets++;
307                         } else {
308
309 /*                              printf("<DUP>\n"); */
310                         }
311                 }
312
313                 else {          /* neither TFTP_OACK, TFTP_DATA nor TFTP_ERROR */
314                         break;
315                 }
316
317                 if (state.received_packets <= state.total_packets) {
318                         unsigned long b;
319                         unsigned long len;
320                         unsigned long max;
321                         int value;
322                         int last;
323
324                         /* Compute the last bit and store an inverted trailer */
325                         max = state.total_packets + 1;
326                         value =
327                             ((state.
328                               bitmap[(max - 1) >> 3] >> ((max -
329                                                           1) & 7)) & 1);
330                         value = !value;
331                         state.bitmap[max >> 3] &= ~(1 << (max & 7));
332                         state.bitmap[max >> 3] |= value << (max & 7);
333
334                         len = 0;
335                         last = 0;       /* Start with the received packets */
336                         for (b = 1; b <= max; b++) {
337                                 value =
338                                     (state.bitmap[b >> 3] >> (b & 7)) & 1;
339
340                                 if (value == 0) {
341                                         tp.u.ack.block = htons(b - 1);  /* Acknowledge the previous block */
342                                         break;
343                                 }
344                         }
345                 }
346                 if (state.ismaster) {
347                         tp.opcode = htons(TFTP_ACK);
348                         oport = ntohs(tr->udp.src);
349                         udp_transmit(info->server.sin_addr.s_addr, iport,
350                                      oport, TFTP_MIN_PACKET, &tp); /* ack */
351                 }
352                 if (state.received_packets == state.total_packets) {
353                         /* If the client is finished and not the master,
354                          * ack the last packet */
355                         if (!state.ismaster) {
356                                 tp.opcode = htons(TFTP_ACK);
357                                 /* Ack Last packet to end xfer */
358                                 tp.u.ack.block = htons(state.total_packets);
359                                 oport = ntohs(tr->udp.src);
360                                 udp_transmit(info->server.sin_addr.s_addr,
361                                              iport, oport,
362                                              TFTP_MIN_PACKET, &tp); /* ack */
363                         }
364                         /* We are done get out */
365                         break;
366                 }
367
368                 if ((unsigned short) (block - prevblock) != 1) {
369                         /* Retransmission or OACK, don't process via callback
370                          * and don't change the value of prevblock.  */
371                         continue;
372                 }
373
374                 prevblock = block;
375                 retry = 0;      /* It's the right place to zero the timer? */
376
377         }
378         /* Leave the multicast group */
379         leave_group(IGMP_SERVER);
380         return 1;
381 }
382
383 static int url_tftm ( char *url __unused, struct sockaddr_in *server,
384                       char *file, struct buffer *buffer ) {
385
386         int ret;
387         struct tftm_info info;
388
389         /* Set the defaults */
390         info.server = *server;
391         info.local.sin_addr.s_addr = arptable[ARP_CLIENT].ipaddr.s_addr;
392         info.local.sin_port = TFTM_PORT; /* Does not matter. */
393         info.multicast = info.local;
394         state.ismaster = 0;
395         info.name = file;
396
397         state.block_size = 0;
398         state.total_bytes = 0;
399         state.total_packets = 0;
400         state.received_packets = 0;
401         state.buffer = buffer;
402         state.image = 0;
403         state.bitmap = 0;
404         state.recvd_oack = 0;
405
406         if (file[0] != '/') {
407                 printf("Bad tftm-URI: [%s]\n", file);
408                 return 0;
409         }
410
411         ret = proto_tftm(&info);
412
413         return ret;
414 }
415
416 /******************************
417 * Parse the multicast options
418 *******************************/
419 static int opt_get_multicast(struct tftp_t *tr, unsigned short *len,
420                              unsigned long *filesize, struct tftm_info *info)
421 {
422         const char *p = tr->u.oack.data, *e = 0;
423         int i = 0;
424         *len = ntohs(tr->udp.len) - sizeof(struct udphdr) - 2;
425         if (*len > TFTM_MIN_PACKET)
426                 return -1;
427         e = p + *len;
428
429         while (*p != '\0' && p < e) {
430                 if (!strcasecmp("tsize", p)) {
431                         p += 6;
432                         if ((*filesize = strtoul(p, &p, 10)) > 0)
433                                 i |= 4;
434                         DBG("\n");
435                         DBG("tsize=%d\n", *filesize);
436                         while (p < e && *p)
437                                 p++;
438                         if (p < e)
439                                 p++;
440                 } else if (!strcasecmp("blksize", p)) {
441                         i |= 2;
442                         p += 8;
443                         state.block_size = strtoul(p, &p, 10);
444                         if (state.block_size != TFTM_MIN_PACKET) {
445                                 printf
446                                     ("TFTM-Server rejected required transfer blocksize %d\n",
447                                      TFTM_MIN_PACKET);
448                                 return 0;
449                         }
450                         DBG("blksize=%d\n", state.block_size);
451                         while (p < e && *p)
452                                 p++;
453                         if (p < e)
454                                 p++;
455                 } else if (!strncmp(p, "multicast", 10)) {
456                         i |= 1;
457                         p += 10;
458                         DBG("multicast options: %s\n", p);
459                         p += 1 + inet_aton(p, &info->multicast.sin_addr);
460                         DBG("multicast ip = %@\n", info->multicast_ip);
461                         info->multicast.sin_port = strtoul(p, &p, 10);
462                         ++p;
463                         DBG("multicast port = %d\n",
464                             info->multicast.sin_port);
465                         state.ismaster = (*p == '1' ? 1 : 0);
466                         DBG("multicast ismaster = %d\n",
467                                state.ismaster);
468                         while (p < e && *p)
469                                 p++;
470                         if (p < e)
471                                 p++;
472                 }
473         }
474         if (p > e)
475                 return 0;
476         return i;
477 }
478
479 static struct protocol tftm_protocol __protocol = {
480         .name = "x-tftm",
481         .default_port = TFTM_PORT,
482         .load = url_tftm,
483 };