Quickly hacked to use a buffer rather than a processor.
[people/lynusvaz/gpxe.git] / src / proto / slam.c
1 #include "etherboot.h"
2 #include "proto.h"
3 #include "nic.h"
4
5 /*
6  * IMPORTANT
7  *
8  * This file should be rewritten to avoid the use of a bitmap.  Our
9  * buffer routines can cope with being handed blocks in an arbitrary
10  * order, duplicate blocks, etc.  This code could be substantially
11  * simplified by taking advantage of these features.
12  *
13  */
14
15 #define SLAM_PORT 10000
16 #define SLAM_MULTICAST_IP ((239<<24)|(255<<16)|(1<<8)|(1<<0))
17 #define SLAM_MULTICAST_PORT 10000
18 #define SLAM_LOCAL_PORT 10000
19
20 /* Set the timeout intervals to at least 1 second so
21  * on a 100Mbit ethernet can receive 10000 packets
22  * in one second.  
23  *
24  * The only case that is likely to trigger all of the nodes
25  * firing a nack packet is a slow server.  The odds of this
26  * happening could be reduced being slightly smarter and utilizing 
27  * the multicast channels for nacks.   But that only improves the odds
28  * it doesn't improve the worst case.  So unless this proves to be
29  * a common case having the control data going unicast should increase
30  * the odds of the data not being dropped.  
31  *
32  * When doing exponential backoff we increase just the timeout
33  * interval and not the base to optimize for throughput.  This is only
34  * expected to happen when the server is down.  So having some nodes
35  * pinging immediately should get the transmission restarted quickly after a
36  * server restart.  The host nic won't be to baddly swamped because of
37  * the random distribution of the nodes.
38  *
39  */
40 #define SLAM_INITIAL_MIN_TIMEOUT      (TICKS_PER_SEC/3)
41 #define SLAM_INITIAL_TIMEOUT_INTERVAL (TICKS_PER_SEC)
42 #define SLAM_BASE_MIN_TIMEOUT         (2*TICKS_PER_SEC)
43 #define SLAM_BASE_TIMEOUT_INTERVAL    (4*TICKS_PER_SEC)
44 #define SLAM_BACKOFF_LIMIT 5
45 #define SLAM_MAX_RETRIES 20
46
47 /*** Packets Formats ***
48  * Data Packet:
49  *   transaction
50  *   total bytes
51  *   block size
52  *   packet #
53  *   data
54  *
55  * Status Request Packet
56  *   transaction
57  *   total bytes
58  *   block size
59  *
60  * Status Packet
61  *   received packets
62  *   requested packets
63  *   received packets
64  *   requested packets
65  *   ...
66  *   received packets
67  *   requested packtes
68  *   0
69  */
70
71 #define MAX_HDR (7 + 7 + 7) /* transaction, total size, block size */
72 #define MIN_HDR (1 + 1 + 1) /* transactino, total size, block size */
73
74 #define MAX_SLAM_REQUEST MAX_HDR
75 #define MIN_SLAM_REQUEST MIN_HDR
76
77 #define MIN_SLAM_DATA (MIN_HDR + 1)
78
79 static struct slam_nack {
80         struct iphdr ip;
81         struct udphdr udp;
82         unsigned char data[ETH_MAX_MTU - 
83                 (sizeof(struct iphdr) + sizeof(struct udphdr))];
84 } nack;
85
86 struct slam_state {
87         unsigned char hdr[MAX_HDR];
88         unsigned long hdr_len;
89         unsigned long block_size;
90         unsigned long total_bytes;
91         unsigned long total_packets;
92
93         unsigned long received_packets;
94
95         struct buffer *buffer;
96         unsigned char *image;
97         unsigned char *bitmap;
98 } state;
99
100
101 static void init_slam_state(void)
102 {
103         state.hdr_len = sizeof(state.hdr);
104         memset(state.hdr, 0, state.hdr_len);
105         state.block_size = 0;
106         state.total_packets = 0;
107
108         state.received_packets = 0;
109
110         state.image = 0;
111         state.bitmap = 0;
112 }
113
114 struct slam_info {
115         struct sockaddr_in server;
116         struct sockaddr_in local;
117         struct sockaddr_in multicast;
118         int sent_nack;
119         struct buffer *buffer;
120 };
121
122 #define SLAM_TIMEOUT 0
123 #define SLAM_REQUEST 1
124 #define SLAM_DATA    2
125 static int await_slam(int ival __unused, void *ptr,
126                       unsigned short ptype __unused, struct iphdr *ip,
127                       struct udphdr *udp, struct tcphdr *tcp __unused)
128 {
129         struct slam_info *info = ptr;
130         if (!udp) {
131                 return 0;
132         }
133         /* I can receive two kinds of packets here, a multicast data packet,
134          * or a unicast request for information 
135          */
136         /* Check for a data request packet */
137         if ((ip->dest.s_addr == arptable[ARP_CLIENT].ipaddr.s_addr) &&
138                 (ntohs(udp->dest) == info->local.sin_port) && 
139                 (nic.packetlen >= 
140                         ETH_HLEN + 
141                         sizeof(struct iphdr) + 
142                         sizeof(struct udphdr) +
143                         MIN_SLAM_REQUEST)) {
144                 return SLAM_REQUEST;
145         }
146         /* Check for a multicast data packet */
147         if ((ip->dest.s_addr == info->multicast.sin_addr.s_addr) &&
148                 (ntohs(udp->dest) == info->multicast.sin_port) &&
149                 (nic.packetlen >= 
150                         ETH_HLEN + 
151                         sizeof(struct iphdr) + 
152                         sizeof(struct udphdr) +
153                         MIN_SLAM_DATA)) {
154                 return SLAM_DATA;
155         }
156 #if 0
157         printf("#");
158         printf("dest: %@ port: %d len: %d\n", 
159                 ip->dest.s_addr, ntohs(udp->dest), nic.packetlen);
160 #endif
161         return 0;
162                 
163 }
164
165 static int slam_encode(
166         unsigned char **ptr, unsigned char *end, unsigned long value)
167 {
168         unsigned char *data = *ptr;
169         int bytes;
170         bytes = sizeof(value);
171         while ((bytes > 0) && ((0xff & (value >> ((bytes -1)<<3))) == 0)) {
172                 bytes--;
173         }
174         if (bytes <= 0) {
175                 bytes = 1;
176         }
177         if (data + bytes >= end) {
178                 return -1;
179         }
180         if ((0xe0 & (value >> ((bytes -1)<<3))) == 0) {
181                 /* packed together */
182                 *data = (bytes << 5) | (value >> ((bytes -1)<<3));
183         } else {
184                 bytes++;
185                 *data = (bytes << 5);
186         }
187         bytes--;
188         data++;
189         while(bytes) {
190                 *(data++) = 0xff & (value >> ((bytes -1)<<3));
191                 bytes--;
192         }
193         *ptr = data;
194         return 0;
195 }
196
197 static int slam_skip(unsigned char **ptr, unsigned char *end) 
198 {
199         int bytes;
200         if (*ptr >= end) {
201                 return -1;
202         }
203         bytes = ((**ptr) >> 5) & 7;
204         if (bytes == 0) {
205                 return -1;
206         }
207         if (*ptr + bytes >= end) {
208                 return -1;
209         }
210         (*ptr) += bytes;
211         return 0;
212         
213 }
214
215 static unsigned long slam_decode(unsigned char **ptr, unsigned char *end,
216                                  int *err)
217 {
218         unsigned long value;
219         unsigned bytes;
220         if (*ptr >= end) {
221                 *err = -1;
222         }
223         bytes = ((**ptr) >> 5) & 7;
224         if ((bytes == 0) || (bytes > sizeof(unsigned long))) {
225                 *err = -1;
226                 return 0;
227         }
228         if ((*ptr) + bytes >= end) {
229                 *err =  -1;
230         }
231         value = (**ptr) & 0x1f;
232         bytes--;
233         (*ptr)++;
234         while(bytes) {
235                 value <<= 8;
236                 value |= **ptr;
237                 (*ptr)++;
238                 bytes--;
239         }
240         return value;
241 }
242
243
244 static long slam_sleep_interval(int exp)
245 {
246         long range;
247         long divisor;
248         long interval;
249         range = SLAM_BASE_TIMEOUT_INTERVAL;
250         if (exp < 0) { 
251                 divisor = RAND_MAX/SLAM_INITIAL_TIMEOUT_INTERVAL;
252         } else {
253                 if (exp > SLAM_BACKOFF_LIMIT) 
254                         exp = SLAM_BACKOFF_LIMIT;
255                 divisor = RAND_MAX/(range << exp);
256         }
257         interval = random()/divisor;
258         if (exp < 0) {
259                 interval += SLAM_INITIAL_MIN_TIMEOUT;
260         } else {
261                 interval += SLAM_BASE_MIN_TIMEOUT;
262         }
263         return interval;
264 }
265
266
267 static unsigned char *reinit_slam_state(
268         unsigned char *header, unsigned char *end)
269 {
270         unsigned long total_bytes;
271         unsigned long block_size;
272
273         unsigned long bitmap_len;
274         unsigned long max_packet_len;
275         unsigned char *data;
276         int err;
277
278 #if 0
279         printf("reinit\n");
280 #endif
281         data = header;
282
283         state.hdr_len = 0;
284         err = slam_skip(&data, end); /* transaction id */
285         total_bytes = slam_decode(&data, end, &err);
286         block_size  = slam_decode(&data, end, &err);
287         if (err) {
288                 printf("ALERT: slam size out of range\n");
289                 return 0;
290         }
291         state.block_size = block_size;
292         state.total_bytes = total_bytes;
293         state.total_packets = (total_bytes + block_size - 1)/block_size;
294         state.hdr_len = data - header;
295         state.received_packets = 0;
296
297         data = state.hdr;
298         slam_encode(&data, &state.hdr[sizeof(state.hdr)], state.total_packets);
299         max_packet_len = data - state.hdr;
300         memcpy(state.hdr, header, state.hdr_len);
301         
302 #if 0
303         printf("block_size:     %ld\n", block_size);
304         printf("total_bytes:    %ld\n", total_bytes);
305         printf("total_packets:  %ld\n", state.total_packets);
306         printf("hdr_len:        %ld\n", state.hdr_len);
307         printf("max_packet_len: %ld\n", max_packet_len);
308 #endif
309
310         if (state.block_size > ETH_MAX_MTU - (
311                 sizeof(struct iphdr) + sizeof(struct udphdr) +
312                 state.hdr_len + max_packet_len)) {
313                 printf("ALERT: slam blocksize to large\n");
314                 return 0;
315         }
316         bitmap_len   = (state.total_packets + 1 + 7)/8;
317         state.image  = phys_to_virt ( state.buffer->start );
318         /* We don't use the buffer routines properly yet; fake it */
319         state.buffer->fill = total_bytes;
320         state.bitmap = state.image + total_bytes;
321         if ((unsigned long)state.image < 1024*1024) {
322                 printf("ALERT: slam filesize to large for available memory\n");
323                 return 0;
324         }
325         memset(state.bitmap, 0, bitmap_len);
326
327         return header + state.hdr_len;
328 }
329
330 static int slam_recv_data(unsigned char *data)
331 {
332         unsigned long packet;
333         unsigned long data_len;
334         int err;
335         struct udphdr *udp;
336         udp = (struct udphdr *)&nic.packet[ETH_HLEN + sizeof(struct iphdr)];
337         err = 0;
338         packet = slam_decode(&data, &nic.packet[nic.packetlen], &err);
339         if (err || (packet > state.total_packets)) {
340                 printf("ALERT: Invalid packet number\n");
341                 return 0;
342         }
343         /* Compute the expected data length */
344         if (packet != state.total_packets -1) {
345                 data_len = state.block_size;
346         } else {
347                 data_len = state.total_bytes % state.block_size;
348         }
349         /* If the packet size is wrong drop the packet and then continue */
350         if (ntohs(udp->len) != (data_len + (data - (unsigned char*)udp))) {
351                 printf("ALERT: udp packet is not the correct size\n");
352                 return 1;
353         }
354         if (nic.packetlen < data_len + (data - nic.packet)) {
355                 printf("ALERT: Ethernet packet shorter than data_len\n");
356                 return 1;
357         }
358         if (data_len > state.block_size) {
359                 data_len = state.block_size;
360         }
361         if (((state.bitmap[packet >> 3] >> (packet & 7)) & 1) == 0) {
362                 /* Non duplicate packet */
363                 state.bitmap[packet >> 3] |= (1 << (packet & 7));
364                 memcpy(state.image + (packet*state.block_size), data, data_len);
365                 state.received_packets++;
366         } else {
367 #ifdef MDEBUG
368                 printf("<DUP>\n");
369 #endif
370         }
371         return 1;
372 }
373
374 static void transmit_nack(unsigned char *ptr, struct slam_info *info)
375 {
376         int nack_len;
377         /* Ensure the packet is null terminated */
378         *ptr++ = 0;
379         nack_len = ptr - (unsigned char *)&nack;
380         build_udp_hdr(info->server.sin_addr.s_addr, info->local.sin_port,
381                       info->server.sin_port, 1, nack_len, &nack);
382         ip_transmit(nack_len, &nack);
383 #if defined(MDEBUG) && 0
384         printf("Sent NACK to %@ bytes: %d have:%ld/%ld\n", 
385                 info->server_ip, nack_len,
386                 state.received_packets, state.total_packets);
387 #endif
388 }
389
390 static void slam_send_nack(struct slam_info *info)
391 {
392         unsigned char *ptr, *end;
393         /* Either I timed out or I was explicitly 
394          * asked for a request packet 
395          */
396         ptr = &nack.data[0];
397         /* Reserve space for the trailling null */
398         end = &nack.data[sizeof(nack.data) -1]; 
399         if (!state.bitmap) {
400                 slam_encode(&ptr, end, 0);
401                 slam_encode(&ptr, end, 1);
402         }
403         else {
404                 /* Walk the bitmap */
405                 unsigned long i;
406                 unsigned long len;
407                 unsigned long max;
408                 int value;
409                 int last;
410                 /* Compute the last bit and store an inverted trailer */
411                 max = state.total_packets;
412                 value = ((state.bitmap[(max -1) >> 3] >> ((max -1) & 7) ) & 1);
413                 value = !value;
414                 state.bitmap[max >> 3] &= ~(1 << (max & 7));
415                 state.bitmap[max >> 3] |= value << (max & 7);
416
417                 len = 0;
418                 last = 1; /* Start with the received packets */
419                 for(i = 0; i <= max; i++) {
420                         value = (state.bitmap[i>>3] >> (i & 7)) & 1;
421                         if (value == last) {
422                                 len++;
423                         } else {
424                                 if (slam_encode(&ptr, end, len))
425                                         break;
426                                 last = value;
427                                 len = 1;
428                         }
429                 }
430         }
431         info->sent_nack = 1;
432         transmit_nack(ptr, info);
433 }
434
435 static void slam_send_disconnect(struct slam_info *info)
436 {
437         if (info->sent_nack) {
438                 /* A disconnect is a packet with just the null terminator */
439                 transmit_nack(&nack.data[0], info);
440         }
441         info->sent_nack = 0;
442 }
443
444
445 static int proto_slam(struct slam_info *info)
446 {
447         int retry;
448         long timeout;
449
450         init_slam_state();
451         state.buffer = info->buffer;
452
453         retry = -1;
454         rx_qdrain();
455         /* Arp for my server */
456         if (arptable[ARP_SERVER].ipaddr.s_addr != info->server.sin_addr.s_addr) {
457                 arptable[ARP_SERVER].ipaddr.s_addr = info->server.sin_addr.s_addr;
458                 memset(arptable[ARP_SERVER].node, 0, ETH_ALEN);
459         }
460         /* If I'm running over multicast join the multicast group */
461         join_group(IGMP_SERVER, info->multicast.sin_addr.s_addr);
462         for(;;) {
463                 unsigned char *header;
464                 unsigned char *data;
465                 int type;
466                 header = data = 0;
467
468                 timeout = slam_sleep_interval(retry);
469                 type = await_reply(await_slam, 0, info, timeout);
470                 /* Compute the timeout for next time */
471                 if (type == SLAM_TIMEOUT) {
472                         /* If I timeouted recompute the next timeout */
473                         if (retry++ > SLAM_MAX_RETRIES) {
474                                 return 0;
475                         }
476                 } else {
477                         retry = 0;
478                 }
479                 if ((type == SLAM_DATA) || (type == SLAM_REQUEST)) {
480                         /* Check the incomming packet and reinit the data 
481                          * structures if necessary.
482                          */
483                         header = &nic.packet[ETH_HLEN + 
484                                 sizeof(struct iphdr) + sizeof(struct udphdr)];
485                         data = header + state.hdr_len;
486                         if (memcmp(state.hdr, header, state.hdr_len) != 0) {
487                                 /* Something is fishy reset the transaction */
488                                 data = reinit_slam_state(header, &nic.packet[nic.packetlen]);
489                                 if (!data) {
490                                         return 0;
491                                 }
492                         }
493                 }
494                 if (type == SLAM_DATA) {
495                         if (!slam_recv_data(data)) {
496                                 return 0;
497                         }
498                         if (state.received_packets == state.total_packets) {
499                                 /* We are done get out */
500                                 break;
501                         }
502                 }
503                 if ((type == SLAM_TIMEOUT) || (type == SLAM_REQUEST)) {
504                         /* Either I timed out or I was explicitly 
505                          * asked by a request packet 
506                          */
507                         slam_send_nack(info);
508                 }
509         }
510         slam_send_disconnect(info);
511
512         /* Leave the multicast group */
513         leave_group(IGMP_SERVER);
514         /* FIXME don't overwrite myself */
515         /* load file to correct location */
516         return 1;
517 }
518
519 static int url_slam ( char *url __unused, struct sockaddr_in *server,
520                       char *file, struct buffer *buffer ) {
521         struct slam_info info;
522         /* Set the defaults */
523         info.server = *server;
524         info.multicast.sin_addr.s_addr = htonl(SLAM_MULTICAST_IP);
525         info.multicast.sin_port      = SLAM_MULTICAST_PORT;
526         info.local.sin_addr.s_addr   = arptable[ARP_CLIENT].ipaddr.s_addr;
527         info.local.sin_port          = SLAM_LOCAL_PORT;
528         info.buffer                  = buffer;
529         info.sent_nack = 0;
530         if (file[0]) {
531                 printf("\nBad url\n");
532                 return 0;
533         }
534         return proto_slam(&info);
535 }
536
537 static struct protocol slam_protocol __protocol = {
538         .name = "x-slam",
539         .default_port = SLAM_PORT,
540         .load = url_slam,
541 };