Protocol names are x-slam and x-tftm
[people/xl0/gpxe.git] / src / proto / tftm.c
1 /**************************************************************************
2 *
3 *    proto_tftm.c -- Etherboot Multicast TFTP 
4 *    Written 2003-2003 by Timothy Legge <tlegge@rogers.com>
5 *
6 *    This program is free software; you can redistribute it and/or modify
7 *    it under the terms of the GNU General Public License as published by
8 *    the Free Software Foundation; either version 2 of the License, or
9 *    (at your option) any later version.
10 *
11 *    This program is distributed in the hope that it will be useful,
12 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *    GNU General Public License for more details.
15 *
16 *    You should have received a copy of the GNU General Public License
17 *    along with this program; if not, write to the Free Software
18 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19 *
20 *    This code is based on the DOWNLOAD_PROTO_TFTM section of 
21 *    Etherboot 5.3 core/nic.c and:
22 *    
23 *    Anselm Martin Hoffmeister's previous proto_tftm.c multicast work
24 *    Eric Biederman's proto_slam.c
25 *
26 *    $Revision$
27 *    $Author$
28 *    $Date$
29 *
30 *    REVISION HISTORY:
31 *    ================
32 *    09-07-2003 timlegge        Release Version, Capable of Multicast Booting
33 *    08-30-2003 timlegge        Initial version, Assumes consecutive blocks
34 *
35 *    Indent Options: indent -kr -i8
36 ***************************************************************************/
37
38 #include "etherboot.h"
39 #include "proto.h"
40 #include "nic.h"
41
/*
 * Parameters of one multicast TFTP download.  Filled in by url_tftm()
 * and read by proto_tftm(), await_tftm() and opt_get_multicast().
 */
struct tftm_info {
        struct sockaddr_in server;      /* TFTP server address and port */
        struct sockaddr_in local;       /* this client's address */
        /* Multicast group address/port: starts as a copy of "local" and is
         * overwritten when the OACK's "multicast" option is parsed. */
        struct sockaddr_in multicast;
        /* Consumer of the downloaded data */
        int (*process)(unsigned char *data, unsigned int blocknum,
                       unsigned int len, int eof);
        int sent_nack;                  /* NOTE(review): unused in this file */
        const char *name;               /* Filename */
};
52
/*
 * Transfer state shared by url_tftm(), proto_tftm() and
 * opt_get_multicast().  It is private to this file, so it has internal
 * linkage: a global called "state" would otherwise risk a link-time
 * clash with any other object file using the same name.
 */
static struct tftm_state {
        unsigned long block_size;       /* "blksize" value from the OACK */
        unsigned long total_bytes;      /* NOTE(review): unused in this file */
        unsigned long total_packets;    /* DATA blocks, numbered 1..total_packets */
        char ismaster;                  /* nonzero while we must send ACKs (mc field of "multicast") */
        unsigned long received_packets; /* distinct DATA blocks stored so far */
        unsigned char *image;           /* file buffer, filesize bytes */
        unsigned char *bitmap;          /* bit n set once block n has been stored */
        char recvd_oack;                /* nonzero once the first OACK was accepted */
} state;
63
enum {
        /* Default server port when the URL does not name one */
        TFTM_PORT = 1758,
        /* Block size requested via "blksize"; also the largest payload
         * and option area accepted from the server */
        TFTM_MIN_PACKET = 1024,
};

/* Parses an OACK; defined further down, used by proto_tftm(). */
static int opt_get_multicast(struct tftp_t *tr, unsigned short *len,
                             unsigned long *filesize, struct tftm_info *info);
70
71 static int await_tftm(int ival, void *ptr, unsigned short ptype __unused,
72                       struct iphdr *ip, struct udphdr *udp,
73                       struct tcphdr *tcp __unused)
74 {
75         struct tftm_info *info = ptr;
76
77         /* Check for Unicast data being received */
78         if (ip->dest.s_addr == arptable[ARP_CLIENT].ipaddr.s_addr) {
79                 if (!udp) {
80                         return 0;
81                 }
82                 if (arptable[ARP_CLIENT].ipaddr.s_addr != ip->dest.s_addr)
83                         return 0;
84                 if (ntohs(udp->dest) != ival)
85                         return 0;
86
87                 return 1;       /* Unicast Data Received */
88         }
89
90         /* Also check for Multicast data being received */
91         if ((ip->dest.s_addr == info->multicast.sin_addr.s_addr) &&
92             (ntohs(udp->dest) == info->multicast.sin_port) &&
93             (nic.packetlen >= ETH_HLEN + sizeof(struct iphdr) +
94              sizeof(struct udphdr))) {
95                 return 1;       /* Multicast data received */
96         }
97         return 0;
98 }
99
100 static int proto_tftm(struct tftm_info *info)
101 {
102         int retry = 0;
103         static unsigned short iport = 2000;
104         unsigned short oport = 0;
105         unsigned short len, block = 0, prevblock = 0;
106         struct tftp_t *tr;
107         struct tftpreq_t tp;
108         unsigned long filesize = 0;
109
110         state.image = 0;
111         state.bitmap = 0;
112
113         rx_qdrain();
114
115         /* Warning: the following assumes the layout of bootp_t.
116            But that's fixed by the IP, UDP and BOOTP specs. */
117
118         /* Send a tftm-request to the server */
119         tp.opcode = htons(TFTP_RRQ);    /* Const for "\0x0" "\0x1" =^= ReadReQuest */
120         len =
121             sizeof(tp.ip) + sizeof(tp.udp) + sizeof(tp.opcode) +
122             sprintf((char *) tp.u.rrq,
123                     "%s%coctet%cmulticast%c%cblksize%c%d%ctsize%c",
124                     info->name, 0, 0, 0, 0, 0, TFTM_MIN_PACKET, 0, 0) + 1;
125
126         if (!udp_transmit(info->server.sin_addr.s_addr, ++iport,
127                           info->server.sin_port, len, &tp))
128                 return (0);
129
130         /* loop to listen for packets and to receive the file */
131         for (;;) {
132                 long timeout;
133 #ifdef  CONGESTED
134                 timeout =
135                     rfc2131_sleep_interval(block ? TFTP_REXMT : TIMEOUT,
136                                            retry);
137 #else
138                 timeout = rfc2131_sleep_interval(TIMEOUT, retry);
139 #endif
140                 /* Calls the await_reply function in nic.c which in turn calls
141                    await_tftm (1st parameter) as above */
142                 if (!await_reply(await_tftm, iport, info, timeout)) {
143                         if (!block && retry++ < MAX_TFTP_RETRIES) {     /* maybe initial request was lost */
144                                 if (!udp_transmit
145                                     (info->server.sin_addr.s_addr, ++iport,
146                                      info->server.sin_port, len, &tp))
147                                         return (0);
148                                 continue;
149                         }
150 #ifdef  CONGESTED
151                         if (block && ((retry += TFTP_REXMT) < TFTP_TIMEOUT)) {  /* we resend our last ack */
152                                 DBG("Timed out receiving file");
153                                 len =
154                                     sizeof(tp.ip) + sizeof(tp.udp) +
155                                     sizeof(tp.opcode) +
156                                     sprintf((char *) tp.u.rrq,
157                                             "%s%coctet%cmulticast%c%cblksize%c%d%ctsize%c",
158                                             info->name, 0, 0, 0, 0, 0,
159                                             TFTM_MIN_PACKET, 0, 0) + 1;
160
161                                 udp_transmit
162                                         (info->server.sin_addr.s_addr,
163                                          ++iport, info->server.sin_port,
164                                          len, &tp);
165                                         continue;
166                         }
167 #endif
168                         break;  /* timeout */
169                 }
170
171                 tr = (struct tftp_t *) &nic.packet[ETH_HLEN];
172
173                 if (tr->opcode == ntohs(TFTP_ERROR)) {
174                         printf("TFTP error %d (%s)\n",
175                                ntohs(tr->u.err.errcode), tr->u.err.errmsg);
176                         break;
177                 }
178
179                 if (tr->opcode == ntohs(TFTP_OACK)) {
180                         int i =
181                             opt_get_multicast(tr, &len, &filesize, info);
182
183                         if (i == 0 || (i != 7 && !state.recvd_oack)) {  /* Multicast unsupported */
184                                 /* Transmit an error message to the server to end the transmission */
185                                 printf
186                                     ("TFTM-Server doesn't understand options [blksize tsize multicast]\n");
187                                 tp.opcode = htons(TFTP_ERROR);
188                                 tp.u.err.errcode = 8;
189                                 /*
190                                  *      Warning: the following assumes the layout of bootp_t.
191                                  *      But that's fixed by the IP, UDP and BOOTP specs.
192                                  */
193                                 len =
194                                     sizeof(tp.ip) + sizeof(tp.udp) +
195                                     sizeof(tp.opcode) +
196                                     sizeof(tp.u.err.errcode) +
197                                     /*
198                                      *      Normally bad form to omit the format string, but in this case
199                                      *      the string we are copying from is fixed. sprintf is just being
200                                      *      used as a strcpy and strlen.
201                                      */
202                                     sprintf((char *) tp.u.err.errmsg,
203                                             "RFC2090 error") + 1;
204                                 udp_transmit(info->server.sin_addr.s_addr,
205                                              iport, ntohs(tr->udp.src),
206                                              len, &tp);
207                                 block = tp.u.ack.block = 0;     /* this ensures, that */
208                                 /* the packet does not get */
209                                 /* processed as data! */
210                                 return (0);
211                         } else {
212                                 unsigned long bitmap_len;
213                                 /* */
214                                 if (!state.recvd_oack) {
215
216                                         state.total_packets =
217                                             1 + (filesize -
218                                                  (filesize %
219                                                   state.block_size)) /
220                                             state.block_size;
221                                         bitmap_len =
222                                             (state.total_packets + 7) / 8;
223                                         if (!state.image) {
224                                                 state.bitmap =
225                                                     allot(bitmap_len);
226                                                 state.image =
227                                                     allot(filesize);
228
229                                                 if ((unsigned long) state.
230                                                     image < 1024 * 1024) {
231                                                         printf
232                                                             ("ALERT: tftp filesize to large for available memory\n");
233                                                         return 0;
234                                                 }
235                                                 memset(state.bitmap, 0,
236                                                        bitmap_len);
237                                         }
238                                         /* If I'm running over multicast join the multicast group */
239                                         join_group(IGMP_SERVER,
240                                               info->multicast.sin_addr.s_addr);
241                                 }
242                                 state.recvd_oack = 1;
243                         }
244
245
246
247                 } else if (tr->opcode == htons(TFTP_DATA)) {
248                         unsigned long data_len;
249                         unsigned char *data;
250                         struct udphdr *udp;
251                         udp =
252                             (struct udphdr *) &nic.packet[ETH_HLEN +
253                                                           sizeof(struct
254                                                                  iphdr)];
255                         len =
256                             ntohs(tr->udp.len) - sizeof(struct udphdr) - 4;
257                         data =
258                             nic.packet + ETH_HLEN + sizeof(struct iphdr) +
259                             sizeof(struct udphdr) + 4;
260
261                         if (len > TFTM_MIN_PACKET)      /* shouldn't happen */
262                                 continue;       /* ignore it */
263
264                         block = ntohs(tp.u.ack.block = tr->u.data.block);
265
266                         if (block > state.total_packets) {
267                                 printf("ALERT: Invalid packet number\n");
268                                 continue;
269                         }
270
271                         /* Compute the expected data length */
272                         if (block != state.total_packets) {
273                                 data_len = state.block_size;
274                         } else {
275                                 data_len = filesize % state.block_size;
276                         }
277                         /* If the packet size is wrong drop the packet and then continue */
278                         if (ntohs(udp->len) !=
279                             (data_len + (data - (unsigned char *) udp))) {
280                                 printf
281                                     ("ALERT: udp packet is not the correct size: %d\n",
282                                      block);
283                                 continue;
284                         }
285                         if (nic.packetlen < data_len + (data - nic.packet)) {
286                                 printf
287                                     ("ALERT: Ethernet packet shorter than data_len: %d\n",
288                                      block);
289                                 continue;
290                         }
291
292                         if (data_len > state.block_size) {
293                                 data_len = state.block_size;
294                         }
295                         if (((state.
296                               bitmap[block >> 3] >> (block & 7)) & 1) ==
297                             0) {
298                                 /* Non duplicate packet */
299                                 state.bitmap[block >> 3] |=
300                                     (1 << (block & 7));
301                                 memcpy(state.image +
302                                        ((block - 1) * state.block_size),
303                                        data, data_len);
304                                 state.received_packets++;
305                         } else {
306
307 /*                              printf("<DUP>\n"); */
308                         }
309                 }
310
311                 else {          /* neither TFTP_OACK, TFTP_DATA nor TFTP_ERROR */
312                         break;
313                 }
314
315                 if (state.received_packets <= state.total_packets) {
316                         unsigned long b;
317                         unsigned long len;
318                         unsigned long max;
319                         int value;
320                         int last;
321
322                         /* Compute the last bit and store an inverted trailer */
323                         max = state.total_packets + 1;
324                         value =
325                             ((state.
326                               bitmap[(max - 1) >> 3] >> ((max -
327                                                           1) & 7)) & 1);
328                         value = !value;
329                         state.bitmap[max >> 3] &= ~(1 << (max & 7));
330                         state.bitmap[max >> 3] |= value << (max & 7);
331
332                         len = 0;
333                         last = 0;       /* Start with the received packets */
334                         for (b = 1; b <= max; b++) {
335                                 value =
336                                     (state.bitmap[b >> 3] >> (b & 7)) & 1;
337
338                                 if (value == 0) {
339                                         tp.u.ack.block = htons(b - 1);  /* Acknowledge the previous block */
340                                         break;
341                                 }
342                         }
343                 }
344                 if (state.ismaster) {
345                         tp.opcode = htons(TFTP_ACK);
346                         oport = ntohs(tr->udp.src);
347                         udp_transmit(info->server.sin_addr.s_addr, iport,
348                                      oport, TFTP_MIN_PACKET, &tp); /* ack */
349                 }
350                 if (state.received_packets == state.total_packets) {
351                         /* If the client is finished and not the master,
352                          * ack the last packet */
353                         if (!state.ismaster) {
354                                 tp.opcode = htons(TFTP_ACK);
355                                 /* Ack Last packet to end xfer */
356                                 tp.u.ack.block = htons(state.total_packets);
357                                 oport = ntohs(tr->udp.src);
358                                 udp_transmit(info->server.sin_addr.s_addr,
359                                              iport, oport,
360                                              TFTP_MIN_PACKET, &tp); /* ack */
361                         }
362                         /* We are done get out */
363                         forget(state.bitmap);
364                         break;
365                 }
366
367                 if ((unsigned short) (block - prevblock) != 1) {
368                         /* Retransmission or OACK, don't process via callback
369                          * and don't change the value of prevblock.  */
370                         continue;
371                 }
372
373                 prevblock = block;
374                 retry = 0;      /* It's the right place to zero the timer? */
375
376         }
377         /* Leave the multicast group */
378         leave_group(IGMP_SERVER);
379         return info->process(state.image, 1, filesize, 1);
380 }
381
382 static int url_tftm ( char *url __unused,
383                       struct sockaddr_in *server,
384                       char *file,
385                       int ( * process ) ( unsigned char *data,
386                                           unsigned int blocknum,
387                                           unsigned int len, int eof ) ) {
388
389         int ret;
390         struct tftm_info info;
391
392         /* Set the defaults */
393         info.server = *server;
394         if ( ! info.server.sin_port )
395                 info.server.sin_port = TFTM_PORT;
396         info.local.sin_addr.s_addr = arptable[ARP_CLIENT].ipaddr.s_addr;
397         info.local.sin_port = TFTM_PORT; /* Does not matter. */
398         info.multicast = info.local;
399         info.process = process;
400         state.ismaster = 0;
401         info.name = file;
402
403         state.block_size = 0;
404         state.total_bytes = 0;
405         state.total_packets = 0;
406         state.received_packets = 0;
407         state.image = 0;
408         state.bitmap = 0;
409         state.recvd_oack = 0;
410
411         if (file[0] != '/') {
412                 printf("Bad tftm-URI: [%s]\n", file);
413                 return 0;
414         }
415
416         ret = proto_tftm(&info);
417
418         return ret;
419 }
420
421 /******************************
422 * Parse the multicast options
423 *******************************/
424 static int opt_get_multicast(struct tftp_t *tr, unsigned short *len,
425                              unsigned long *filesize, struct tftm_info *info)
426 {
427         const char *p = tr->u.oack.data, *e = 0;
428         int i = 0;
429         *len = ntohs(tr->udp.len) - sizeof(struct udphdr) - 2;
430         if (*len > TFTM_MIN_PACKET)
431                 return -1;
432         e = p + *len;
433
434         while (*p != '\0' && p < e) {
435                 if (!strcasecmp("tsize", p)) {
436                         p += 6;
437                         if ((*filesize = strtoul(p, &p, 10)) > 0)
438                                 i |= 4;
439                         DBG("\n");
440                         DBG("tsize=%d\n", *filesize);
441                         while (p < e && *p)
442                                 p++;
443                         if (p < e)
444                                 p++;
445                 } else if (!strcasecmp("blksize", p)) {
446                         i |= 2;
447                         p += 8;
448                         state.block_size = strtoul(p, &p, 10);
449                         if (state.block_size != TFTM_MIN_PACKET) {
450                                 printf
451                                     ("TFTM-Server rejected required transfer blocksize %d\n",
452                                      TFTM_MIN_PACKET);
453                                 return 0;
454                         }
455                         DBG("blksize=%d\n", state.block_size);
456                         while (p < e && *p)
457                                 p++;
458                         if (p < e)
459                                 p++;
460                 } else if (!strncmp(p, "multicast", 10)) {
461                         i |= 1;
462                         p += 10;
463                         DBG("multicast options: %s\n", p);
464                         p += 1 + inet_aton(p, &info->multicast.sin_addr);
465                         DBG("multicast ip = %@\n", info->multicast_ip);
466                         info->multicast.sin_port = strtoul(p, &p, 10);
467                         ++p;
468                         DBG("multicast port = %d\n",
469                             info->multicast.sin_port);
470                         state.ismaster = (*p == '1' ? 1 : 0);
471                         DBG("multicast ismaster = %d\n",
472                                state.ismaster);
473                         while (p < e && *p)
474                                 p++;
475                         if (p < e)
476                                 p++;
477                 }
478         }
479         if (p > e)
480                 return 0;
481         return i;
482 }
483
484 static struct protocol tftm_protocol __protocol = {
485         "x-tftm", url_tftm
486 };