298e419324f9e517ae24d9928f02dbc731e0260d
[people/adir/gpxe.git] / src / crypto / axtls / aes.c
1 /*
2  *  Copyright(C) 2006 Cameron Rich
3  *
4  *  This library is free software; you can redistribute it and/or modify
5  *  it under the terms of the GNU Lesser General Public License as published by
6  *  the Free Software Foundation; either version 2 of the License, or
7  *  (at your option) any later version.
8  *
9  *  This library is distributed in the hope that it will be useful,
10  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
11  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  *  GNU Lesser General Public License for more details.
13  *
14  *  You should have received a copy of the GNU Lesser General Public License
15  *  along with this library; if not, write to the Free Software
16  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17  */
18
19 /**
20  * AES implementation - this is a small code version. There are much faster
21  * versions around but they are much larger in size (i.e. they use large 
22  * submix tables).
23  */
24
25 #include <string.h>
26 #include "crypto.h"
27
28 /* all commented out in skeleton mode */
29 #ifndef CONFIG_SSL_SKELETON_MODE
30
/*
 * Byte-granular rotations of a 32-bit word: rotN(x) rotates x right by
 * 8*N bits (equivalently left by 32 - 8*N bits).
 *
 * The argument is evaluated twice, so it must be free of side effects, and
 * it must have an unsigned 32-bit type for the two shifts to recombine into
 * a rotation.
 */
#define rot1(x) (((x) >>  8) | ((x) << 24))
#define rot2(x) (((x) >> 16) | ((x) << 16))
#define rot3(x) (((x) >> 24) | ((x) <<  8))
34
/*
 * mul2(x, t): multiply each of the four bytes packed in the 32-bit word x
 * by two in GF(2^8), all four lanes at once.  The trick is taken from
 * Dr B. R. Gladman <brg@gladman.uk.net>.
 *
 * Doubling a byte is a left shift by one, followed by an xor with 0x1b if
 * the byte's top bit was set before the shift.  Per byte lane:
 *
 *   t = x & 0x80            1000 0000 if the top bit is set, otherwise 0
 *   t >> 7                  0000 0001 (or 0)
 *   t - (t >> 7)            0111 1111 (or 0)
 *   (t - (t >> 7)) & 0x1b   0001 1011 (or 0): the reduction term
 *
 * (x + x) & 0xfe is the shifted byte; the mask removes the carry that
 * leaked in from the lane below.
 *
 * `t' must be an lvalue; it is overwritten with x & mt.  `x' is evaluated
 * several times, so it must not have side effects.  `ml' is not referenced
 * by any code visible in this part of the file.
 */
#define mt  0x80808080
#define ml  0x7f7f7f7f
#define mh  0xfefefefe
#define mm  0x1b1b1b1b
#define mul2(x,t)       ((t)=((x)&mt), \
                        ((((x)+(x))&mh)^(((t)-((t)>>7))&mm)))
52
/*
 * inv_mix_col(x, f2, f4, f8, f9): apply the inverse MixColumn transform to
 * one column packed big-endian in the 32-bit word x (row 0 in the most
 * significant byte) and yield the transformed word.
 *
 * f2, f4, f8 and f9 are caller-supplied 32-bit scratch lvalues; all four are
 * overwritten.  Step by step, per byte lane:
 *
 *   f2 = 2*x, f4 = 4*x, f8 = 8*x          (three chained mul2 calls)
 *   f9 = x ^ 8*x                = 0x09*x
 *   f8 = 2*x ^ 4*x ^ 8*x        = 0x0e*x
 *   f2 ^= f9                    = 0x0b*x
 *   f4 ^= f9                    = 0x0d*x
 *   result = 0x0e*x ^ rot3(0x0b*x) ^ rot2(0x0d*x) ^ rot1(0x09*x)
 *
 * `x' is evaluated more than once and must not have side effects.
 */
#define inv_mix_col(x,f2,f4,f8,f9) (\
                        (f2)=mul2(x,f2), \
                        (f4)=mul2(f2,f4), \
                        (f8)=mul2(f4,f8), \
                        (f9)=(x)^(f8), \
                        (f8)=((f2)^(f4)^(f8)), \
                        (f2)^=(f9), \
                        (f4)^=(f9), \
                        (f8)^=rot3(f2), \
                        (f8)^=rot2(f4), \
                        (f8)^rot1(f9))
64
/*
 * Endian-independent 32-bit word access through a uint32_t pointer.
 *
 *   n2l(c, l): l = the word at *c converted from network to host byte
 *              order, then advance c to the next word.
 *   l2n(l, c): store l at *c converted from host to network byte order,
 *              then advance c to the next word.
 *
 * Both expand to a single parenthesised expression, so they remain one
 * statement when used as the body of an unbraced if/else or loop.
 */
#define n2l(c,l) ((l) = ntohl(*(c)), (c)++)
#define l2n(l,c) (*(c)++ = htonl(l))
68
/*
 * AES S-box: the forward byte substitution table, indexed by the input
 * byte.  Each row below holds the 16 entries whose index has the high
 * nibble shown in the row comment.
 */
static const uint8_t aes_sbox[256] =
{
    /* 0x0_ */ 0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
    /* 0x1_ */ 0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
    /* 0x2_ */ 0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
    /* 0x3_ */ 0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
    /* 0x4_ */ 0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
    /* 0x5_ */ 0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
    /* 0x6_ */ 0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
    /* 0x7_ */ 0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
    /* 0x8_ */ 0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
    /* 0x9_ */ 0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
    /* 0xa_ */ 0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
    /* 0xb_ */ 0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
    /* 0xc_ */ 0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
    /* 0xd_ */ 0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
    /* 0xe_ */ 0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
    /* 0xf_ */ 0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16,
};
107
/*
 * AES inverse S-box: the byte substitution table used for decryption,
 * indexed by the input byte.  Each row below holds the 16 entries whose
 * index has the high nibble shown in the row comment.
 */
static const uint8_t aes_isbox[256] =
{
    /* 0x0_ */ 0x52,0x09,0x6a,0xd5,0x30,0x36,0xa5,0x38,0xbf,0x40,0xa3,0x9e,0x81,0xf3,0xd7,0xfb,
    /* 0x1_ */ 0x7c,0xe3,0x39,0x82,0x9b,0x2f,0xff,0x87,0x34,0x8e,0x43,0x44,0xc4,0xde,0xe9,0xcb,
    /* 0x2_ */ 0x54,0x7b,0x94,0x32,0xa6,0xc2,0x23,0x3d,0xee,0x4c,0x95,0x0b,0x42,0xfa,0xc3,0x4e,
    /* 0x3_ */ 0x08,0x2e,0xa1,0x66,0x28,0xd9,0x24,0xb2,0x76,0x5b,0xa2,0x49,0x6d,0x8b,0xd1,0x25,
    /* 0x4_ */ 0x72,0xf8,0xf6,0x64,0x86,0x68,0x98,0x16,0xd4,0xa4,0x5c,0xcc,0x5d,0x65,0xb6,0x92,
    /* 0x5_ */ 0x6c,0x70,0x48,0x50,0xfd,0xed,0xb9,0xda,0x5e,0x15,0x46,0x57,0xa7,0x8d,0x9d,0x84,
    /* 0x6_ */ 0x90,0xd8,0xab,0x00,0x8c,0xbc,0xd3,0x0a,0xf7,0xe4,0x58,0x05,0xb8,0xb3,0x45,0x06,
    /* 0x7_ */ 0xd0,0x2c,0x1e,0x8f,0xca,0x3f,0x0f,0x02,0xc1,0xaf,0xbd,0x03,0x01,0x13,0x8a,0x6b,
    /* 0x8_ */ 0x3a,0x91,0x11,0x41,0x4f,0x67,0xdc,0xea,0x97,0xf2,0xcf,0xce,0xf0,0xb4,0xe6,0x73,
    /* 0x9_ */ 0x96,0xac,0x74,0x22,0xe7,0xad,0x35,0x85,0xe2,0xf9,0x37,0xe8,0x1c,0x75,0xdf,0x6e,
    /* 0xa_ */ 0x47,0xf1,0x1a,0x71,0x1d,0x29,0xc5,0x89,0x6f,0xb7,0x62,0x0e,0xaa,0x18,0xbe,0x1b,
    /* 0xb_ */ 0xfc,0x56,0x3e,0x4b,0xc6,0xd2,0x79,0x20,0x9a,0xdb,0xc0,0xfe,0x78,0xcd,0x5a,0xf4,
    /* 0xc_ */ 0x1f,0xdd,0xa8,0x33,0x88,0x07,0xc7,0x31,0xb1,0x12,0x10,0x59,0x27,0x80,0xec,0x5f,
    /* 0xd_ */ 0x60,0x51,0x7f,0xa9,0x19,0xb5,0x4a,0x0d,0x2d,0xe5,0x7a,0x9f,0x93,0xc9,0x9c,0xef,
    /* 0xe_ */ 0xa0,0xe0,0x3b,0x4d,0xae,0x2a,0xf5,0xb0,0xc8,0xeb,0xbb,0x3c,0x83,0x53,0x99,0x61,
    /* 0xf_ */ 0x17,0x2b,0x04,0x7e,0xba,0x77,0xd6,0x26,0xe1,0x69,0x14,0x63,0x55,0x21,0x0c,0x7d,
};
146
/*
 * Key-schedule round constants.  Rcon[0] is 0x01 and every following entry
 * is the previous one doubled in GF(2^8) (polynomial x^8+x^4+x^3+x+1).
 * AES_set_key() consumes one entry per full key-length step: 10 entries for
 * a 128-bit key and 7 for a 256-bit key.
 */
static const unsigned char Rcon[30] =
{
    0x01,0x02,0x04,0x08,0x10,0x20,0x40,0x80,0x1b,0x36,
    0x6c,0xd8,0xab,0x4d,0x9a,0x2f,0x5e,0xbc,0x63,0xc6,
    0x97,0x35,0x6a,0xd4,0xb3,0x7d,0xfa,0xef,0xc5,0x91,
};
154
155 /* ----- static functions ----- */
156 static void AES_encrypt(const AES_CTX *ctx, uint32_t *data);
157 static void AES_decrypt(const AES_CTX *ctx, uint32_t *data);
158
/*
 * Double a field element in GF(2^8) using the irreducible polynomial
 * x^8+x^4+x^3+x+1: shift left by one and, if bit 7 of the input was set,
 * fold the overflow back in by xoring with 0x1b.
 *
 * Only the low eight bits of the result are returned; any higher bits of
 * the intermediate value are discarded by the narrowing conversion.
 */
static unsigned char AES_xtime(uint32_t x)
{
    uint32_t doubled = x << 1;

    if (x & 0x80)
    {
        doubled ^= 0x1b;
    }

    return (unsigned char)doubled;
}
165
166 /**
167  * Set up AES with the key/iv and cipher size.
168  */
169 void AES_set_key(AES_CTX *ctx, const uint8_t *key, 
170         const uint8_t *iv, AES_MODE mode)
171 {
172     int i, ii;
173     uint32_t *W, tmp, tmp2;
174     const unsigned char *ip;
175     int words;
176
177     switch (mode)
178     {
179         case AES_MODE_128:
180             i = 10;
181             words = 4;
182             break;
183
184         case AES_MODE_256:
185             i = 14;
186             words = 8;
187             break;
188
189         default:        /* fail silently */
190             return;
191     }
192
193     ctx->rounds = i;
194     ctx->key_size = words;
195     W = ctx->ks;
196     for (i = 0; i < words; i+=2)
197     {
198         W[i+0]= ((uint32_t)key[ 0]<<24)|
199             ((uint32_t)key[ 1]<<16)|
200             ((uint32_t)key[ 2]<< 8)|
201             ((uint32_t)key[ 3]    );
202         W[i+1]= ((uint32_t)key[ 4]<<24)|
203             ((uint32_t)key[ 5]<<16)|
204             ((uint32_t)key[ 6]<< 8)|
205             ((uint32_t)key[ 7]    );
206         key += 8;
207     }
208
209     ip = Rcon;
210     ii = 4 * (ctx->rounds+1);
211     for (i = words; i<ii; i++)
212     {
213         tmp = W[i-1];
214
215         if ((i % words) == 0)
216         {
217             tmp2 =(uint32_t)aes_sbox[(tmp    )&0xff]<< 8;
218             tmp2|=(uint32_t)aes_sbox[(tmp>> 8)&0xff]<<16;
219             tmp2|=(uint32_t)aes_sbox[(tmp>>16)&0xff]<<24;
220             tmp2|=(uint32_t)aes_sbox[(tmp>>24)     ];
221             tmp=tmp2^(((unsigned int)*ip)<<24);
222             ip++;
223         }
224
225         if ((words == 8) && ((i % words) == 4))
226         {
227             tmp2 =(uint32_t)aes_sbox[(tmp    )&0xff]    ;
228             tmp2|=(uint32_t)aes_sbox[(tmp>> 8)&0xff]<< 8;
229             tmp2|=(uint32_t)aes_sbox[(tmp>>16)&0xff]<<16;
230             tmp2|=(uint32_t)aes_sbox[(tmp>>24)     ]<<24;
231             tmp=tmp2;
232         }
233
234         W[i]=W[i-words]^tmp;
235     }
236
237     /* copy the iv across */
238     memcpy(ctx->iv, iv, 16);
239 }
240
#if 0
/** currently unused function - compiled out **/

/**
 * Change a key for decryption.
 *
 * Applies inv_mix_col() in place to schedule words ks[4] through
 * ks[4*rounds - 1], i.e. to every round key except the first and the last.
 */
void AES_convert_key(AES_CTX *ctx)
{
    int i;
    uint32_t *k, w, t1, t2, t3, t4;

    /* skip the first round key */
    k = ctx->ks;
    k += 4;

    for (i = ctx->rounds * 4; i > 4; i--)
    {
        w = *k;
        w = inv_mix_col(w, t1, t2, t3, t4);
        *k++ = w;
    }
}
#endif
263
264 /**
265  * Encrypt a byte sequence (with a block size 16) using the AES cipher.
266  */
267 void AES_cbc_encrypt(AES_CTX *ctx, const uint8_t *msg, uint8_t *out, int length)
268 {
269     uint32_t tin0, tin1, tin2, tin3;
270     uint32_t tout0, tout1, tout2, tout3;
271     uint32_t tin[4];
272     uint32_t *iv = (uint32_t *)ctx->iv;
273     uint32_t *msg_32 = (uint32_t *)msg;
274     uint32_t *out_32 = (uint32_t *)out;
275
276     n2l(iv, tout0);
277     n2l(iv, tout1);
278     n2l(iv, tout2);
279     n2l(iv, tout3);
280     iv -= 4;
281
282     for (length -= 16; length >= 0; length -= 16)
283     {
284         n2l(msg_32, tin0);
285         n2l(msg_32, tin1);
286         n2l(msg_32, tin2);
287         n2l(msg_32, tin3);
288         tin[0] = tin0^tout0;
289         tin[1] = tin1^tout1;
290         tin[2] = tin2^tout2;
291         tin[3] = tin3^tout3;
292
293         AES_encrypt(ctx, tin);
294
295         tout0 = tin[0]; 
296         l2n(tout0, out_32);
297         tout1 = tin[1]; 
298         l2n(tout1, out_32);
299         tout2 = tin[2]; 
300         l2n(tout2, out_32);
301         tout3 = tin[3]; 
302         l2n(tout3, out_32);
303     }
304
305     l2n(tout0, iv);
306     l2n(tout1, iv);
307     l2n(tout2, iv);
308     l2n(tout3, iv);
309 }
310
311 /**
312  * Decrypt a byte sequence (with a block size 16) using the AES cipher.
313  */
314 void AES_cbc_decrypt(AES_CTX *ctx, const uint8_t *msg, uint8_t *out, int length)
315 {
316     uint32_t tin0, tin1, tin2, tin3;
317     uint32_t xor0,xor1,xor2,xor3;
318     uint32_t tout0,tout1,tout2,tout3;
319     uint32_t data[4];
320     uint32_t *iv = (uint32_t *)ctx->iv;
321     uint32_t *msg_32 = (uint32_t *)msg;
322     uint32_t *out_32 = (uint32_t *)out;
323
324     n2l(iv ,xor0);
325     n2l(iv, xor1);
326     n2l(iv, xor2);
327     n2l(iv, xor3);
328     iv -= 4;
329
330     for (length-=16; length >= 0; length -= 16)
331     {
332         n2l(msg_32, tin0);
333         n2l(msg_32, tin1);
334         n2l(msg_32, tin2);
335         n2l(msg_32, tin3);
336
337         data[0] = tin0;
338         data[1] = tin1;
339         data[2] = tin2;
340         data[3] = tin3;
341
342         AES_decrypt(ctx, data);
343
344         tout0 = data[0]^xor0;
345         tout1 = data[1]^xor1;
346         tout2 = data[2]^xor2;
347         tout3 = data[3]^xor3;
348
349         xor0 = tin0;
350         xor1 = tin1;
351         xor2 = tin2;
352         xor3 = tin3;
353
354         l2n(tout0, out_32);
355         l2n(tout1, out_32);
356         l2n(tout2, out_32);
357         l2n(tout3, out_32);
358     }
359
360     l2n(xor0, iv);
361     l2n(xor1, iv);
362     l2n(xor2, iv);
363     l2n(xor3, iv);
364 }
365
366 /**
367  * Encrypt a single block (16 bytes) of data
368  */
369 static void AES_encrypt(const AES_CTX *ctx, uint32_t *data)
370 {
371     /* To make this code smaller, generate the sbox entries on the fly.
372      * This will have a really heavy effect upon performance.
373      */
374     uint32_t tmp[4];
375     uint32_t tmp1, old_a0, a0, a1, a2, a3, row;
376     int curr_rnd;
377     int rounds = ctx->rounds; 
378     const uint32_t *k = ctx->ks;
379
380     /* Pre-round key addition */
381     for (row = 0; row < 4; row++)
382     {
383         data[row] ^= *(k++);
384     }
385
386     /* Encrypt one block. */
387     for (curr_rnd = 0; curr_rnd < rounds; curr_rnd++)
388     {
389         /* Perform ByteSub and ShiftRow operations together */
390         for (row = 0; row < 4; row++)
391         {
392             a0 = (uint32_t)aes_sbox[(data[row%4]>>24)&0xFF];
393             a1 = (uint32_t)aes_sbox[(data[(row+1)%4]>>16)&0xFF];
394             a2 = (uint32_t)aes_sbox[(data[(row+2)%4]>>8)&0xFF]; 
395             a3 = (uint32_t)aes_sbox[(data[(row+3)%4])&0xFF];
396
397             /* Perform MixColumn iff not last round */
398             if (curr_rnd < (rounds - 1))
399             {
400                 tmp1 = a0 ^ a1 ^ a2 ^ a3;
401                 old_a0 = a0;
402
403                 a0 ^= tmp1 ^ AES_xtime(a0 ^ a1);
404                 a1 ^= tmp1 ^ AES_xtime(a1 ^ a2);
405                 a2 ^= tmp1 ^ AES_xtime(a2 ^ a3);
406                 a3 ^= tmp1 ^ AES_xtime(a3 ^ old_a0);
407
408             }
409
410             tmp[row] = ((a0 << 24) | (a1 << 16) | (a2 << 8) | a3);
411         }
412
413         /* KeyAddition - note that it is vital that this loop is separate from
414            the MixColumn operation, which must be atomic...*/ 
415         for (row = 0; row < 4; row++)
416         {
417             data[row] = tmp[row] ^ *(k++);
418         }
419     }
420 }
421
422 /**
423  * Decrypt a single block (16 bytes) of data
424  */
425 static void AES_decrypt(const AES_CTX *ctx, uint32_t *data)
426
427     uint32_t tmp[4];
428     uint32_t xt0,xt1,xt2,xt3,xt4,xt5,xt6;
429     uint32_t a0, a1, a2, a3, row;
430     int curr_rnd;
431     int rounds = ctx->rounds;
432     uint32_t *k = (uint32_t*)ctx->ks + ((rounds+1)*4);
433
434     /* pre-round key addition */
435     for (row=4; row > 0;row--)
436     {
437         data[row-1] ^= *(--k);
438     }
439
440     /* Decrypt one block */
441     for (curr_rnd=0; curr_rnd < rounds; curr_rnd++)
442     {
443         /* Perform ByteSub and ShiftRow operations together */
444         for (row = 4; row > 0; row--)
445         {
446             a0 = aes_isbox[(data[(row+3)%4]>>24)&0xFF];
447             a1 = aes_isbox[(data[(row+2)%4]>>16)&0xFF];
448             a2 = aes_isbox[(data[(row+1)%4]>>8)&0xFF];
449             a3 = aes_isbox[(data[row%4])&0xFF];
450
451             /* Perform MixColumn iff not last round */
452             if (curr_rnd<(rounds-1))
453             {
454                 /* The MDS cofefficients (0x09, 0x0B, 0x0D, 0x0E)
455                    are quite large compared to encryption; this 
456                    operation slows decryption down noticeably. */
457                 xt0 = AES_xtime(a0^a1);
458                 xt1 = AES_xtime(a1^a2);
459                 xt2 = AES_xtime(a2^a3);
460                 xt3 = AES_xtime(a3^a0);
461                 xt4 = AES_xtime(xt0^xt1);
462                 xt5 = AES_xtime(xt1^xt2);
463                 xt6 = AES_xtime(xt4^xt5);
464
465                 xt0 ^= a1^a2^a3^xt4^xt6;
466                 xt1 ^= a0^a2^a3^xt5^xt6;
467                 xt2 ^= a0^a1^a3^xt4^xt6;
468                 xt3 ^= a0^a1^a2^xt5^xt6;
469                 tmp[row-1] = ((xt0<<24)|(xt1<<16)|(xt2<<8)|xt3);
470             }
471             else
472                 tmp[row-1] = ((a0<<24)|(a1<<16)|(a2<<8)|a3);
473         }
474
475         for (row = 4; row > 0; row--)
476         {
477             data[row-1] = tmp[row-1] ^ *(--k);
478         }
479     }
480 }
481
482 #endif