Protocol names are x-slam and x-tftm
[people/xl0/gpxe.git] / src / proto / slam.c
1 #include "etherboot.h"
2 #include "proto.h"
3 #include "nic.h"
4
5 #define SLAM_PORT 10000
6 #define SLAM_MULTICAST_IP ((239<<24)|(255<<16)|(1<<8)|(1<<0))
7 #define SLAM_MULTICAST_PORT 10000
8 #define SLAM_LOCAL_PORT 10000
9
10 /* Set the timeout intervals to at least 1 second so
11  * on a 100Mbit ethernet can receive 10000 packets
12  * in one second.  
13  *
14  * The only case that is likely to trigger all of the nodes
15  * firing a nack packet is a slow server.  The odds of this
16  * happening could be reduced being slightly smarter and utilizing 
17  * the multicast channels for nacks.   But that only improves the odds
18  * it doesn't improve the worst case.  So unless this proves to be
19  * a common case having the control data going unicast should increase
20  * the odds of the data not being dropped.  
21  *
22  * When doing exponential backoff we increase just the timeout
23  * interval and not the base to optimize for throughput.  This is only
24  * expected to happen when the server is down.  So having some nodes
25  * pinging immediately should get the transmission restarted quickly after a
26  * server restart.  The host nic won't be to baddly swamped because of
27  * the random distribution of the nodes.
28  *
29  */
30 #define SLAM_INITIAL_MIN_TIMEOUT      (TICKS_PER_SEC/3)
31 #define SLAM_INITIAL_TIMEOUT_INTERVAL (TICKS_PER_SEC)
32 #define SLAM_BASE_MIN_TIMEOUT         (2*TICKS_PER_SEC)
33 #define SLAM_BASE_TIMEOUT_INTERVAL    (4*TICKS_PER_SEC)
34 #define SLAM_BACKOFF_LIMIT 5
35 #define SLAM_MAX_RETRIES 20
36
37 /*** Packets Formats ***
38  * Data Packet:
39  *   transaction
40  *   total bytes
41  *   block size
42  *   packet #
43  *   data
44  *
45  * Status Request Packet
46  *   transaction
47  *   total bytes
48  *   block size
49  *
50  * Status Packet
51  *   received packets
52  *   requested packets
53  *   received packets
54  *   requested packets
55  *   ...
56  *   received packets
57  *   requested packtes
58  *   0
59  */
60
61 #define MAX_HDR (7 + 7 + 7) /* transaction, total size, block size */
62 #define MIN_HDR (1 + 1 + 1) /* transactino, total size, block size */
63
64 #define MAX_SLAM_REQUEST MAX_HDR
65 #define MIN_SLAM_REQUEST MIN_HDR
66
67 #define MIN_SLAM_DATA (MIN_HDR + 1)
68
69 static struct slam_nack {
70         struct iphdr ip;
71         struct udphdr udp;
72         unsigned char data[ETH_MAX_MTU - 
73                 (sizeof(struct iphdr) + sizeof(struct udphdr))];
74 } nack;
75
76 struct slam_state {
77         unsigned char hdr[MAX_HDR];
78         unsigned long hdr_len;
79         unsigned long block_size;
80         unsigned long total_bytes;
81         unsigned long total_packets;
82
83         unsigned long received_packets;
84
85         unsigned char *image;
86         unsigned char *bitmap;
87 } state;
88
89
90 static void init_slam_state(void)
91 {
92         state.hdr_len = sizeof(state.hdr);
93         memset(state.hdr, 0, state.hdr_len);
94         state.block_size = 0;
95         state.total_packets = 0;
96
97         state.received_packets = 0;
98
99         state.image = 0;
100         state.bitmap = 0;
101 }
102
103 struct slam_info {
104         struct sockaddr_in server;
105         struct sockaddr_in local;
106         struct sockaddr_in multicast;
107         int ( * process ) ( unsigned char *data,
108                             unsigned int blocknum,
109                             unsigned int len, int eof );
110         int sent_nack;
111 };
112
113 #define SLAM_TIMEOUT 0
114 #define SLAM_REQUEST 1
115 #define SLAM_DATA    2
116 static int await_slam(int ival __unused, void *ptr,
117                       unsigned short ptype __unused, struct iphdr *ip,
118                       struct udphdr *udp, struct tcphdr *tcp __unused)
119 {
120         struct slam_info *info = ptr;
121         if (!udp) {
122                 return 0;
123         }
124         /* I can receive two kinds of packets here, a multicast data packet,
125          * or a unicast request for information 
126          */
127         /* Check for a data request packet */
128         if ((ip->dest.s_addr == arptable[ARP_CLIENT].ipaddr.s_addr) &&
129                 (ntohs(udp->dest) == info->local.sin_port) && 
130                 (nic.packetlen >= 
131                         ETH_HLEN + 
132                         sizeof(struct iphdr) + 
133                         sizeof(struct udphdr) +
134                         MIN_SLAM_REQUEST)) {
135                 return SLAM_REQUEST;
136         }
137         /* Check for a multicast data packet */
138         if ((ip->dest.s_addr == info->multicast.sin_addr.s_addr) &&
139                 (ntohs(udp->dest) == info->multicast.sin_port) &&
140                 (nic.packetlen >= 
141                         ETH_HLEN + 
142                         sizeof(struct iphdr) + 
143                         sizeof(struct udphdr) +
144                         MIN_SLAM_DATA)) {
145                 return SLAM_DATA;
146         }
147 #if 0
148         printf("#");
149         printf("dest: %@ port: %d len: %d\n", 
150                 ip->dest.s_addr, ntohs(udp->dest), nic.packetlen);
151 #endif
152         return 0;
153                 
154 }
155
156 static int slam_encode(
157         unsigned char **ptr, unsigned char *end, unsigned long value)
158 {
159         unsigned char *data = *ptr;
160         int bytes;
161         bytes = sizeof(value);
162         while ((bytes > 0) && ((0xff & (value >> ((bytes -1)<<3))) == 0)) {
163                 bytes--;
164         }
165         if (bytes <= 0) {
166                 bytes = 1;
167         }
168         if (data + bytes >= end) {
169                 return -1;
170         }
171         if ((0xe0 & (value >> ((bytes -1)<<3))) == 0) {
172                 /* packed together */
173                 *data = (bytes << 5) | (value >> ((bytes -1)<<3));
174         } else {
175                 bytes++;
176                 *data = (bytes << 5);
177         }
178         bytes--;
179         data++;
180         while(bytes) {
181                 *(data++) = 0xff & (value >> ((bytes -1)<<3));
182                 bytes--;
183         }
184         *ptr = data;
185         return 0;
186 }
187
188 static int slam_skip(unsigned char **ptr, unsigned char *end) 
189 {
190         int bytes;
191         if (*ptr >= end) {
192                 return -1;
193         }
194         bytes = ((**ptr) >> 5) & 7;
195         if (bytes == 0) {
196                 return -1;
197         }
198         if (*ptr + bytes >= end) {
199                 return -1;
200         }
201         (*ptr) += bytes;
202         return 0;
203         
204 }
205
206 static unsigned long slam_decode(unsigned char **ptr, unsigned char *end,
207                                  int *err)
208 {
209         unsigned long value;
210         unsigned bytes;
211         if (*ptr >= end) {
212                 *err = -1;
213         }
214         bytes = ((**ptr) >> 5) & 7;
215         if ((bytes == 0) || (bytes > sizeof(unsigned long))) {
216                 *err = -1;
217                 return 0;
218         }
219         if ((*ptr) + bytes >= end) {
220                 *err =  -1;
221         }
222         value = (**ptr) & 0x1f;
223         bytes--;
224         (*ptr)++;
225         while(bytes) {
226                 value <<= 8;
227                 value |= **ptr;
228                 (*ptr)++;
229                 bytes--;
230         }
231         return value;
232 }
233
234
235 static long slam_sleep_interval(int exp)
236 {
237         long range;
238         long divisor;
239         long interval;
240         range = SLAM_BASE_TIMEOUT_INTERVAL;
241         if (exp < 0) { 
242                 divisor = RAND_MAX/SLAM_INITIAL_TIMEOUT_INTERVAL;
243         } else {
244                 if (exp > SLAM_BACKOFF_LIMIT) 
245                         exp = SLAM_BACKOFF_LIMIT;
246                 divisor = RAND_MAX/(range << exp);
247         }
248         interval = random()/divisor;
249         if (exp < 0) {
250                 interval += SLAM_INITIAL_MIN_TIMEOUT;
251         } else {
252                 interval += SLAM_BASE_MIN_TIMEOUT;
253         }
254         return interval;
255 }
256
257
258 static unsigned char *reinit_slam_state(
259         unsigned char *header, unsigned char *end)
260 {
261         unsigned long total_bytes;
262         unsigned long block_size;
263
264         unsigned long bitmap_len;
265         unsigned long max_packet_len;
266         unsigned char *data;
267         int err;
268
269 #if 0
270         printf("reinit\n");
271 #endif
272         data = header;
273
274         state.hdr_len = 0;
275         err = slam_skip(&data, end); /* transaction id */
276         total_bytes = slam_decode(&data, end, &err);
277         block_size  = slam_decode(&data, end, &err);
278         if (err) {
279                 printf("ALERT: slam size out of range\n");
280                 return 0;
281         }
282         state.block_size = block_size;
283         state.total_bytes = total_bytes;
284         state.total_packets = (total_bytes + block_size - 1)/block_size;
285         state.hdr_len = data - header;
286         state.received_packets = 0;
287
288         data = state.hdr;
289         slam_encode(&data, &state.hdr[sizeof(state.hdr)], state.total_packets);
290         max_packet_len = data - state.hdr;
291         memcpy(state.hdr, header, state.hdr_len);
292         
293 #if 0
294         printf("block_size:     %ld\n", block_size);
295         printf("total_bytes:    %ld\n", total_bytes);
296         printf("total_packets:  %ld\n", state.total_packets);
297         printf("hdr_len:        %ld\n", state.hdr_len);
298         printf("max_packet_len: %ld\n", max_packet_len);
299 #endif
300
301         if (state.block_size > ETH_MAX_MTU - (
302                 sizeof(struct iphdr) + sizeof(struct udphdr) +
303                 state.hdr_len + max_packet_len)) {
304                 printf("ALERT: slam blocksize to large\n");
305                 return 0;
306         }
307         if (state.bitmap) {
308                 forget(state.bitmap);
309         }
310         bitmap_len   = (state.total_packets + 1 + 7)/8;
311         state.bitmap = allot(bitmap_len);
312         state.image  = allot(total_bytes);
313         if ((unsigned long)state.image < 1024*1024) {
314                 printf("ALERT: slam filesize to large for available memory\n");
315                 return 0;
316         }
317         memset(state.bitmap, 0, bitmap_len);
318
319         return header + state.hdr_len;
320 }
321
322 static int slam_recv_data(unsigned char *data)
323 {
324         unsigned long packet;
325         unsigned long data_len;
326         int err;
327         struct udphdr *udp;
328         udp = (struct udphdr *)&nic.packet[ETH_HLEN + sizeof(struct iphdr)];
329         err = 0;
330         packet = slam_decode(&data, &nic.packet[nic.packetlen], &err);
331         if (err || (packet > state.total_packets)) {
332                 printf("ALERT: Invalid packet number\n");
333                 return 0;
334         }
335         /* Compute the expected data length */
336         if (packet != state.total_packets -1) {
337                 data_len = state.block_size;
338         } else {
339                 data_len = state.total_bytes % state.block_size;
340         }
341         /* If the packet size is wrong drop the packet and then continue */
342         if (ntohs(udp->len) != (data_len + (data - (unsigned char*)udp))) {
343                 printf("ALERT: udp packet is not the correct size\n");
344                 return 1;
345         }
346         if (nic.packetlen < data_len + (data - nic.packet)) {
347                 printf("ALERT: Ethernet packet shorter than data_len\n");
348                 return 1;
349         }
350         if (data_len > state.block_size) {
351                 data_len = state.block_size;
352         }
353         if (((state.bitmap[packet >> 3] >> (packet & 7)) & 1) == 0) {
354                 /* Non duplicate packet */
355                 state.bitmap[packet >> 3] |= (1 << (packet & 7));
356                 memcpy(state.image + (packet*state.block_size), data, data_len);
357                 state.received_packets++;
358         } else {
359 #ifdef MDEBUG
360                 printf("<DUP>\n");
361 #endif
362         }
363         return 1;
364 }
365
366 static void transmit_nack(unsigned char *ptr, struct slam_info *info)
367 {
368         int nack_len;
369         /* Ensure the packet is null terminated */
370         *ptr++ = 0;
371         nack_len = ptr - (unsigned char *)&nack;
372         build_udp_hdr(info->server.sin_addr.s_addr, info->local.sin_port,
373                       info->server.sin_port, 1, nack_len, &nack);
374         ip_transmit(nack_len, &nack);
375 #if defined(MDEBUG) && 0
376         printf("Sent NACK to %@ bytes: %d have:%ld/%ld\n", 
377                 info->server_ip, nack_len,
378                 state.received_packets, state.total_packets);
379 #endif
380 }
381
382 static void slam_send_nack(struct slam_info *info)
383 {
384         unsigned char *ptr, *end;
385         /* Either I timed out or I was explicitly 
386          * asked for a request packet 
387          */
388         ptr = &nack.data[0];
389         /* Reserve space for the trailling null */
390         end = &nack.data[sizeof(nack.data) -1]; 
391         if (!state.bitmap) {
392                 slam_encode(&ptr, end, 0);
393                 slam_encode(&ptr, end, 1);
394         }
395         else {
396                 /* Walk the bitmap */
397                 unsigned long i;
398                 unsigned long len;
399                 unsigned long max;
400                 int value;
401                 int last;
402                 /* Compute the last bit and store an inverted trailer */
403                 max = state.total_packets;
404                 value = ((state.bitmap[(max -1) >> 3] >> ((max -1) & 7) ) & 1);
405                 value = !value;
406                 state.bitmap[max >> 3] &= ~(1 << (max & 7));
407                 state.bitmap[max >> 3] |= value << (max & 7);
408
409                 len = 0;
410                 last = 1; /* Start with the received packets */
411                 for(i = 0; i <= max; i++) {
412                         value = (state.bitmap[i>>3] >> (i & 7)) & 1;
413                         if (value == last) {
414                                 len++;
415                         } else {
416                                 if (slam_encode(&ptr, end, len))
417                                         break;
418                                 last = value;
419                                 len = 1;
420                         }
421                 }
422         }
423         info->sent_nack = 1;
424         transmit_nack(ptr, info);
425 }
426
427 static void slam_send_disconnect(struct slam_info *info)
428 {
429         if (info->sent_nack) {
430                 /* A disconnect is a packet with just the null terminator */
431                 transmit_nack(&nack.data[0], info);
432         }
433         info->sent_nack = 0;
434 }
435
436
437 static int proto_slam(struct slam_info *info)
438 {
439         int retry;
440         long timeout;
441
442         init_slam_state();
443
444         retry = -1;
445         rx_qdrain();
446         /* Arp for my server */
447         if (arptable[ARP_SERVER].ipaddr.s_addr != info->server.sin_addr.s_addr) {
448                 arptable[ARP_SERVER].ipaddr.s_addr = info->server.sin_addr.s_addr;
449                 memset(arptable[ARP_SERVER].node, 0, ETH_ALEN);
450         }
451         /* If I'm running over multicast join the multicast group */
452         join_group(IGMP_SERVER, info->multicast.sin_addr.s_addr);
453         for(;;) {
454                 unsigned char *header;
455                 unsigned char *data;
456                 int type;
457                 header = data = 0;
458
459                 timeout = slam_sleep_interval(retry);
460                 type = await_reply(await_slam, 0, info, timeout);
461                 /* Compute the timeout for next time */
462                 if (type == SLAM_TIMEOUT) {
463                         /* If I timeouted recompute the next timeout */
464                         if (retry++ > SLAM_MAX_RETRIES) {
465                                 return 0;
466                         }
467                 } else {
468                         retry = 0;
469                 }
470                 if ((type == SLAM_DATA) || (type == SLAM_REQUEST)) {
471                         /* Check the incomming packet and reinit the data 
472                          * structures if necessary.
473                          */
474                         header = &nic.packet[ETH_HLEN + 
475                                 sizeof(struct iphdr) + sizeof(struct udphdr)];
476                         data = header + state.hdr_len;
477                         if (memcmp(state.hdr, header, state.hdr_len) != 0) {
478                                 /* Something is fishy reset the transaction */
479                                 data = reinit_slam_state(header, &nic.packet[nic.packetlen]);
480                                 if (!data) {
481                                         return 0;
482                                 }
483                         }
484                 }
485                 if (type == SLAM_DATA) {
486                         if (!slam_recv_data(data)) {
487                                 return 0;
488                         }
489                         if (state.received_packets == state.total_packets) {
490                                 /* We are done get out */
491                                 break;
492                         }
493                 }
494                 if ((type == SLAM_TIMEOUT) || (type == SLAM_REQUEST)) {
495                         /* Either I timed out or I was explicitly 
496                          * asked by a request packet 
497                          */
498                         slam_send_nack(info);
499                 }
500         }
501         slam_send_disconnect(info);
502
503         /* Leave the multicast group */
504         leave_group(IGMP_SERVER);
505         /* FIXME don't overwrite myself */
506         /* load file to correct location */
507         return info->process(state.image, 1, state.total_bytes, 1);
508 }
509
510 static int url_slam ( char *url __unused,
511                       struct sockaddr_in *server,
512                       char *file,
513                       int ( * process ) ( unsigned char *data,
514                                           unsigned int blocknum,
515                                           unsigned int len, int eof ) ) {
516         struct slam_info info;
517         /* Set the defaults */
518         info.server = *server;
519         if ( ! info.server.sin_port )
520                 info.server.sin_port = SLAM_PORT;
521         info.multicast.sin_addr.s_addr = htonl(SLAM_MULTICAST_IP);
522         info.multicast.sin_port      = SLAM_MULTICAST_PORT;
523         info.local.sin_addr.s_addr   = arptable[ARP_CLIENT].ipaddr.s_addr;
524         info.local.sin_port          = SLAM_LOCAL_PORT;
525         info.process                 = process;
526         info.sent_nack = 0;
527         if (file[0]) {
528                 printf("\nBad url\n");
529                 return 0;
530         }
531         return proto_slam(&info);
532 }
533
534 static struct protocol slam_protocol __protocol = {
535         "x-slam", url_slam
536 };