Moved protocols to proto/
[people/xl0/gpxe.git] / src / proto / tftm.c
1 /**************************************************************************
2 *
3 *    proto_tftm.c -- Etherboot Multicast TFTP 
4 *    Written 2003-2003 by Timothy Legge <tlegge@rogers.com>
5 *
6 *    This program is free software; you can redistribute it and/or modify
7 *    it under the terms of the GNU General Public License as published by
8 *    the Free Software Foundation; either version 2 of the License, or
9 *    (at your option) any later version.
10 *
11 *    This program is distributed in the hope that it will be useful,
12 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *    GNU General Public License for more details.
15 *
16 *    You should have received a copy of the GNU General Public License
17 *    along with this program; if not, write to the Free Software
18 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19 *
20 *    This code is based on the DOWNLOAD_PROTO_TFTM section of 
21 *    Etherboot 5.3 core/nic.c and:
22 *    
23 *    Anselm Martin Hoffmeister's previous proto_tftm.c multicast work
24 *    Eric Biederman's proto_slam.c
25 *
26 *    $Revision$
27 *    $Author$
28 *    $Date$
29 *
30 *    REVISION HISTORY:
31 *    ================
32 *    09-07-2003 timlegge        Release Version, Capable of Multicast Booting
33 *    08-30-2003 timlegge        Initial version, Assumes consecutive blocks
34 *
35 *    Indent Options: indent -kr -i8
36 ***************************************************************************/
37
38 #ifdef DOWNLOAD_PROTO_TFTM
39 #include "etherboot.h"
40 #include "nic.h"
41
/* Define TFTM_DEBUG to get verbose protocol tracing on the console. */
/* #define TFTM_DEBUG */

/*
 * debug((fmt, ...)) - print a trace message when TFTM_DEBUG is defined.
 * The argument list must be wrapped in an extra pair of parentheses so the
 * whole list can be passed through a single macro parameter.  Without
 * TFTM_DEBUG the macro expands to an empty statement.
 */
#ifndef TFTM_DEBUG
#define debug(x) do { } while (0)
#else
#define debug(x) printf x
#endif
/*
 * Parameters of one multicast TFTP download.  url_tftm() fills it in and
 * passes it to proto_tftm(); await_tftm() receives it as its opaque pointer.
 */
48 struct tftm_info {
49         in_addr server_ip;      /* boot server: arptable[ARP_SERVER] by default, or the address given in the URL (url_tftm()) */
50         in_addr multicast_ip;   /* group address: initialised to local_ip, replaced by the address in the OACK "multicast" option */
51         in_addr local_ip;       /* our own address (arptable[ARP_CLIENT]) */
52         uint16_t server_port;   /* set to TFTM_PORT by url_tftm() */
53         uint16_t multicast_port;        /* initialised to TFTM_PORT, replaced by the port in the OACK "multicast" option */
54         uint16_t local_port;    /* set to TFTM_PORT by url_tftm() */
55         int (*fnc) (unsigned char *, unsigned int, unsigned int, int);  /* loader callback; proto_tftm() passes the image as fnc(image, 1, filesize, 1) */
56         int sent_nack;          /* NOTE(review): not read anywhere in this file */
57         const char *name;       /* Filename sent in the read request (set by url_tftm()) */
58 };
59
/*
 * Progress of the current download, shared by url_tftm() (which resets it),
 * proto_tftm() and opt_get_multicast().
 * NOTE(review): `state` has external linkage and a very generic name; unless
 * another file needs it, it should probably be `static`.
 */
60 struct tftm_state {
61         unsigned long block_size;       /* blksize from the OACK; opt_get_multicast() insists on TFTM_MIN_PACKET */
62         unsigned long total_bytes;      /* reset by url_tftm(); NOTE(review): not otherwise used in this file */
63         unsigned long total_packets;    /* number of DATA blocks, derived from tsize and blksize on the first accepted OACK */
64         char ismaster;                  /* 1 if the OACK "multicast" option makes us master client (we then ACK each packet) */
65         unsigned long received_packets; /* distinct DATA blocks stored so far */
66         unsigned char *image;           /* download buffer of filesize bytes, from allot() */
67         unsigned char *bitmap;          /* bit n set once block n is stored; bit total_packets + 1 is used by proto_tftm() as a scan sentinel */
68         char recvd_oack;                /* set once the first acceptable OACK has been processed */
69 } state;
70
/* Default UDP port for the server, local and multicast endpoints (see url_tftm()). */
71 #define TFTM_PORT 1758
/*
 * blksize requested in the read request and the only value accepted back in
 * the OACK; also the upper bound on OACK option data and DATA payload length.
 */
72 #define TFTM_MIN_PACKET 1024
73
74
/* OACK option parser, defined at the end of this file; used by proto_tftm(). */
75 int opt_get_multicast(struct tftp_t *tr, unsigned short *len,
76                       unsigned long *filesize, struct tftm_info *info);
77
78 static int await_tftm(int ival, void *ptr, unsigned short ptype __unused,
79                       struct iphdr *ip, struct udphdr *udp)
80 {
81         struct tftm_info *info = ptr;
82
83         /* Check for Unicast data being received */
84         if (ip->dest.s_addr == arptable[ARP_CLIENT].ipaddr.s_addr) {
85                 if (!udp) {
86                         return 0;
87                 }
88                 if (arptable[ARP_CLIENT].ipaddr.s_addr != ip->dest.s_addr)
89                         return 0;
90                 if (ntohs(udp->dest) != ival)
91                         return 0;
92
93                 return 1;       /* Unicast Data Received */
94         }
95
96         /* Also check for Multicast data being received */
97         if ((ip->dest.s_addr == info->multicast_ip.s_addr) &&
98             (ntohs(udp->dest) == info->multicast_port) &&
99             (nic.packetlen >= ETH_HLEN + sizeof(struct iphdr) +
100              sizeof(struct udphdr))) {
101                 return 1;       /* Multicast data received */
102         }
103         return 0;
104 }
105
106 int proto_tftm(struct tftm_info *info)
107 {
108         int retry = 0;
109         static unsigned short iport = 2000;
110         unsigned short oport = 0;
111         unsigned short len, block = 0, prevblock = 0;
112         struct tftp_t *tr;
113         struct tftpreq_t tp;
114         unsigned long filesize = 0;
115
116         state.image = 0;
117         state.bitmap = 0;
118
119         rx_qdrain();
120
121         /* Warning: the following assumes the layout of bootp_t.
122            But that's fixed by the IP, UDP and BOOTP specs. */
123
124         /* Send a tftm-request to the server */
125         tp.opcode = htons(TFTP_RRQ);    /* Const for "\0x0" "\0x1" =^= ReadReQuest */
126         len =
127             sizeof(tp.ip) + sizeof(tp.udp) + sizeof(tp.opcode) +
128             sprintf((char *) tp.u.rrq,
129                     "%s%coctet%cmulticast%c%cblksize%c%d%ctsize%c",
130                     info->name, 0, 0, 0, 0, 0, TFTM_MIN_PACKET, 0, 0) + 1;
131
132         if (!udp_transmit(arptable[ARP_SERVER].ipaddr.s_addr, ++iport,
133                           TFTM_PORT, len, &tp))
134                 return (0);
135
136         /* loop to listen for packets and to receive the file */
137         for (;;) {
138                 long timeout;
139 #ifdef  CONGESTED
140                 timeout =
141                     rfc2131_sleep_interval(block ? TFTP_REXMT : TIMEOUT,
142                                            retry);
143 #else
144                 timeout = rfc2131_sleep_interval(TIMEOUT, retry);
145 #endif
146                 /* Calls the await_reply function in nic.c which in turn calls
147                    await_tftm (1st parameter) as above */
148                 if (!await_reply(await_tftm, iport, info, timeout)) {
149                         if (!block && retry++ < MAX_TFTP_RETRIES) {     /* maybe initial request was lost */
150                                 if (!udp_transmit
151                                     (arptable[ARP_SERVER].ipaddr.s_addr,
152                                      ++iport, TFTM_PORT, len, &tp))
153                                         return (0);
154                                 continue;
155                         }
156 #ifdef  CONGESTED
157                         if (block && ((retry += TFTP_REXMT) < TFTP_TIMEOUT)) {  /* we resend our last ack */
158 #ifdef  MDEBUG
159                                 printf("<REXMT>\n");
160 #endif
161                                 debug(("Timed out receiving file"));
162                                 len =
163                                     sizeof(tp.ip) + sizeof(tp.udp) +
164                                     sizeof(tp.opcode) +
165                                     sprintf((char *) tp.u.rrq,
166                                             "%s%coctet%cmulticast%c%cblksize%c%d%ctsize%c",
167                                             info->name, 0, 0, 0, 0, 0,
168                                             TFTM_MIN_PACKET, 0, 0) + 1;
169
170                                 udp_transmit
171                                     (arptable[ARP_SERVER].ipaddr.s_addr,
172                                      ++iport, TFTM_PORT, len, &tp);
173                                         continue;
174                         }
175 #endif
176                         break;  /* timeout */
177                 }
178
179                 tr = (struct tftp_t *) &nic.packet[ETH_HLEN];
180
181                 if (tr->opcode == ntohs(TFTP_ERROR)) {
182                         printf("TFTP error %d (%s)\n",
183                                ntohs(tr->u.err.errcode), tr->u.err.errmsg);
184                         break;
185                 }
186
187                 if (tr->opcode == ntohs(TFTP_OACK)) {
188                         int i =
189                             opt_get_multicast(tr, &len, &filesize, info);
190
191                         if (i == 0 || (i != 7 && !state.recvd_oack)) {  /* Multicast unsupported */
192                                 /* Transmit an error message to the server to end the transmission */
193                                 printf
194                                     ("TFTM-Server doesn't understand options [blksize tsize multicast]\n");
195                                 tp.opcode = htons(TFTP_ERROR);
196                                 tp.u.err.errcode = 8;
197                                 /*
198                                  *      Warning: the following assumes the layout of bootp_t.
199                                  *      But that's fixed by the IP, UDP and BOOTP specs.
200                                  */
201                                 len =
202                                     sizeof(tp.ip) + sizeof(tp.udp) +
203                                     sizeof(tp.opcode) +
204                                     sizeof(tp.u.err.errcode) +
205                                     /*
206                                      *      Normally bad form to omit the format string, but in this case
207                                      *      the string we are copying from is fixed. sprintf is just being
208                                      *      used as a strcpy and strlen.
209                                      */
210                                     sprintf((char *) tp.u.err.errmsg,
211                                             "RFC2090 error") + 1;
212                                 udp_transmit(arptable[ARP_SERVER].ipaddr.
213                                              s_addr, iport,
214                                              ntohs(tr->udp.src), len, &tp);
215                                 block = tp.u.ack.block = 0;     /* this ensures, that */
216                                 /* the packet does not get */
217                                 /* processed as data! */
218                                 return (0);
219                         } else {
220                                 unsigned long bitmap_len;
221                                 /* */
222                                 if (!state.recvd_oack) {
223
224                                         state.total_packets =
225                                             1 + (filesize -
226                                                  (filesize %
227                                                   state.block_size)) /
228                                             state.block_size;
229                                         bitmap_len =
230                                             (state.total_packets + 7) / 8;
231                                         if (!state.image) {
232                                                 state.bitmap =
233                                                     allot(bitmap_len);
234                                                 state.image =
235                                                     allot(filesize);
236
237                                                 if ((unsigned long) state.
238                                                     image < 1024 * 1024) {
239                                                         printf
240                                                             ("ALERT: tftp filesize to large for available memory\n");
241                                                         return 0;
242                                                 }
243                                                 memset(state.bitmap, 0,
244                                                        bitmap_len);
245                                         }
246                                         /* If I'm running over multicast join the multicast group */
247                                         join_group(IGMP_SERVER,
248                                                    info->multicast_ip.
249                                                    s_addr);
250                                 }
251                                 state.recvd_oack = 1;
252                         }
253
254
255
256                 } else if (tr->opcode == htons(TFTP_DATA)) {
257                         unsigned long data_len;
258                         unsigned char *data;
259                         struct udphdr *udp;
260                         udp =
261                             (struct udphdr *) &nic.packet[ETH_HLEN +
262                                                           sizeof(struct
263                                                                  iphdr)];
264                         len =
265                             ntohs(tr->udp.len) - sizeof(struct udphdr) - 4;
266                         data =
267                             nic.packet + ETH_HLEN + sizeof(struct iphdr) +
268                             sizeof(struct udphdr) + 4;
269
270                         if (len > TFTM_MIN_PACKET)      /* shouldn't happen */
271                                 continue;       /* ignore it */
272
273                         block = ntohs(tp.u.ack.block = tr->u.data.block);
274
275                         if (block > state.total_packets) {
276                                 printf("ALERT: Invalid packet number\n");
277                                 continue;
278                         }
279
280                         /* Compute the expected data length */
281                         if (block != state.total_packets) {
282                                 data_len = state.block_size;
283                         } else {
284                                 data_len = filesize % state.block_size;
285                         }
286                         /* If the packet size is wrong drop the packet and then continue */
287                         if (ntohs(udp->len) !=
288                             (data_len + (data - (unsigned char *) udp))) {
289                                 printf
290                                     ("ALERT: udp packet is not the correct size: %d\n",
291                                      block);
292                                 continue;
293                         }
294                         if (nic.packetlen < data_len + (data - nic.packet)) {
295                                 printf
296                                     ("ALERT: Ethernet packet shorter than data_len: %d\n",
297                                      block);
298                                 continue;
299                         }
300
301                         if (data_len > state.block_size) {
302                                 data_len = state.block_size;
303                         }
304                         if (((state.
305                               bitmap[block >> 3] >> (block & 7)) & 1) ==
306                             0) {
307                                 /* Non duplicate packet */
308                                 state.bitmap[block >> 3] |=
309                                     (1 << (block & 7));
310                                 memcpy(state.image +
311                                        ((block - 1) * state.block_size),
312                                        data, data_len);
313                                 state.received_packets++;
314                         } else {
315
316 /*                              printf("<DUP>\n"); */
317                         }
318                 }
319
320                 else {          /* neither TFTP_OACK, TFTP_DATA nor TFTP_ERROR */
321                         break;
322                 }
323
324                 if (state.received_packets <= state.total_packets) {
325                         unsigned long b;
326                         unsigned long len;
327                         unsigned long max;
328                         int value;
329                         int last;
330
331                         /* Compute the last bit and store an inverted trailer */
332                         max = state.total_packets + 1;
333                         value =
334                             ((state.
335                               bitmap[(max - 1) >> 3] >> ((max -
336                                                           1) & 7)) & 1);
337                         value = !value;
338                         state.bitmap[max >> 3] &= ~(1 << (max & 7));
339                         state.bitmap[max >> 3] |= value << (max & 7);
340
341                         len = 0;
342                         last = 0;       /* Start with the received packets */
343                         for (b = 1; b <= max; b++) {
344                                 value =
345                                     (state.bitmap[b >> 3] >> (b & 7)) & 1;
346
347                                 if (value == 0) {
348                                         tp.u.ack.block = htons(b - 1);  /* Acknowledge the previous block */
349                                         break;
350                                 }
351                         }
352                 }
353                 if (state.ismaster) {
354                         tp.opcode = htons(TFTP_ACK);
355                         oport = ntohs(tr->udp.src);
356                         udp_transmit(arptable[ARP_SERVER].ipaddr.s_addr, iport, oport, TFTP_MIN_PACKET, &tp);   /* ack */
357                 }
358                 if (state.received_packets == state.total_packets) {
359                         /* If the client is finished and not the master,
360                          * ack the last packet */
361                         if (!state.ismaster) {
362                                 tp.opcode = htons(TFTP_ACK);
363                                 /* Ack Last packet to end xfer */
364                                 tp.u.ack.block = htons(state.total_packets);
365                                 oport = ntohs(tr->udp.src);
366                                 udp_transmit(arptable[ARP_SERVER].ipaddr.s_addr, iport, oport, TFTP_MIN_PACKET, &tp);   /* ack */
367                         }
368                         /* We are done get out */
369                         forget(state.bitmap);
370                         break;
371                 }
372
373                 if ((unsigned short) (block - prevblock) != 1) {
374                         /* Retransmission or OACK, don't process via callback
375                          * and don't change the value of prevblock.  */
376                         continue;
377                 }
378
379                 prevblock = block;
380                 retry = 0;      /* It's the right place to zero the timer? */
381
382         }
383         /* Leave the multicast group */
384         leave_group(IGMP_SERVER);
385         return info->fnc(state.image, 1, filesize, 1);
386 }
387
388 int url_tftm(const char *name,
389              int (*fnc) (unsigned char *, unsigned int, unsigned int, int))
390 {
391
392         int ret;
393         struct tftm_info info;
394
395         /* Set the defaults */
396         info.server_ip.s_addr = arptable[ARP_SERVER].ipaddr.s_addr;
397         info.server_port = TFTM_PORT;
398         info.local_ip.s_addr = arptable[ARP_CLIENT].ipaddr.s_addr;
399         info.local_port = TFTM_PORT;    /* Does not matter. So take tftm port too. */
400         info.multicast_ip.s_addr = info.local_ip.s_addr;
401         info.multicast_port = TFTM_PORT;
402         info.fnc = fnc;
403         state.ismaster = 0;
404         info.name = name;
405
406         state.block_size = 0;
407         state.total_bytes = 0;
408         state.total_packets = 0;
409         state.received_packets = 0;
410         state.image = 0;
411         state.bitmap = 0;
412         state.recvd_oack = 0;
413
414         if (name[0] != '/') {
415                 /* server ip given, so use it */
416                 name += inet_aton(info.name, &info.server_ip);
417                 /* No way to specify another port for now */
418         }
419         if (name[0] != '/') {
420                 printf("Bad tftm-URI: [%s]\n", info.name);
421                 return 0;
422         }
423
424         ret = proto_tftm(&info);
425
426         return ret;
427 }
428
429 /******************************
430 * Parse the multicast options
431 *******************************/
432 int opt_get_multicast(struct tftp_t *tr, unsigned short *len,
433                       unsigned long *filesize, struct tftm_info *info)
434 {
435         const char *p = tr->u.oack.data, *e = 0;
436         int i = 0;
437         *len = ntohs(tr->udp.len) - sizeof(struct udphdr) - 2;
438         if (*len > TFTM_MIN_PACKET)
439                 return -1;
440         e = p + *len;
441
442         while (*p != '\0' && p < e) {
443                 if (!strcasecmp("tsize", p)) {
444                         p += 6;
445                         if ((*filesize = strtoul(p, &p, 10)) > 0)
446                                 i |= 4;
447                         debug(("\n"));
448                         debug(("tsize=%d\n", *filesize));
449                         while (p < e && *p)
450                                 p++;
451                         if (p < e)
452                                 p++;
453                 } else if (!strcasecmp("blksize", p)) {
454                         i |= 2;
455                         p += 8;
456                         state.block_size = strtoul(p, &p, 10);
457                         if (state.block_size != TFTM_MIN_PACKET) {
458                                 printf
459                                     ("TFTM-Server rejected required transfer blocksize %d\n",
460                                      TFTM_MIN_PACKET);
461                                 return 0;
462                         }
463                         debug(("blksize=%d\n", state.block_size));
464                         while (p < e && *p)
465                                 p++;
466                         if (p < e)
467                                 p++;
468                 } else if (!strncmp(p, "multicast", 10)) {
469                         i |= 1;
470                         p += 10;
471                         debug(("multicast options: %s\n", p));
472                         p += 1 + inet_aton(p, &info->multicast_ip);
473                         debug(("multicast ip = %@\n", info->multicast_ip));
474                         info->multicast_port = strtoul(p, &p, 10);
475                         ++p;
476                         debug(("multicast port = %d\n",
477                                info->multicast_port));
478                         state.ismaster = (*p == '1' ? 1 : 0);
479                         debug(("multicast ismaster = %d\n",
480                                state.ismaster));
481                         while (p < e && *p)
482                                 p++;
483                         if (p < e)
484                                 p++;
485                 }
486         }
487         if (p > e)
488                 return 0;
489         return i;
490 }
491 #endif                          /* DOWNLOAD_PROTO_TFTM */