Bugs fixed. Fixed memory leaks (hopefully).
[people/lynusvaz/gpxe.git] / src / hci / arith.c
1 /*
 * Recursive descent arithmetic calculator: evaluates a bracketed expression
 * using the operators listed below, plus ( ) for grouping.
4  */
5
6 /*
7 Ops: !, ~                               (Highest)
8         *, /, %
9         +, -
10         <<, >>
11         <, <=, >, >=
        !=, =           (operands compared as strings; equality is a single '=')
13         &
14         |
        ^               (note: unlike C, ^ binds more loosely than |)
16         &&
17         ||                              (Lowest)
18 */
19
20 #include <ctype.h>
21 #include <stdint.h>
22 #include <string.h>
23 #include <stdlib.h>
24 #include <stdio.h>
25 #include <errno.h>
26
27 #ifndef __ARITH_TEST__
28 #include <lib.h>
29 #endif
30
#define NUM_OPS		20	/* number of entries in op_table and op_prio */
#define MAX_PRIO	11
#define MIN_TOK		258	/* token code of op_table entry 0 */
#define TOK_PLUS	(MIN_TOK + 5)
#define TOK_MINUS	(MIN_TOK + 6)
#define TOK_NUMBER	256
#define TOK_STRING	257
#define TOK_TOTAL	20

/* Lexer / parser state shared by the functions of this file. */
char *inp, *prev;	/* current input position; position before the last token */
int tok;		/* current token: a plain character, TOK_NUMBER, TOK_STRING,
			 * MIN_TOK + operator index, or -1 at end of input */
int err_val;		/* 0, or an error: -1 parse error, -2 division by 0,
			 * -3 internal error, -ENOMEM */
long tok_value;		/* value of a TOK_NUMBER; for a TOK_STRING the string
			 * pointer is stored in this long */
int brackets;		/* current bracket nesting depth */

/*
 * Operator spellings, each padded with '@' to exactly three characters so
 * that (offset / 3) is the operator index.  The same index is used as
 * (tok - MIN_TOK), as the "op" argument of eval() and to select op_prio[].
 */
char *op_table = "!@@" "~@@" "*@@" "/@@" "%@@" "+@@" "-@@" "=@@" "<@@" "<=@" "<<@" ">@@" ">=@" ">>@" "!=@"  "&@@" "|@@" "^@@" "&&@" "||@";
char *keyword_table = " \t\v()!~*/%+-<=>&|^";	/* characters that cannot appear in a string */

/*
 * Priority of every op_table entry (higher binds tighter).  The entries
 * must follow op_table's order exactly.
 */
signed char op_prio[NUM_OPS] = {
	10, 10,		/* !  ~       */
	9, 9, 9,	/* *  /  %    */
	8, 8,		/* +  -       */
	5,		/* =          */
	6, 6,		/* <  <=      */
	7,		/* <<         */
	6, 6,		/* >  >=      */
	7,		/* >>         */
	5,		/* !=         */
	4, 3, 2,	/* &  |  ^    */
	1, 0		/* && ||      */
};
49
50 /*
51         Changes:
52         1. Common function for all operators.
53         2. Common input function.
54         
55         Notes:
56         1. Better way to store operators of > 1 characters? I have tried to keep handling operators consistent.
57 */      
58         
59 static void ignore_whitespace(void);
60
61 static void input(void) {
62         char t_op[3] = { '\0', '\0', '\0'};
63         char *p1, *p2;
64         size_t len;
65         
66         if(tok == -1)
67                 return;
68
69         prev = inp;
70         ignore_whitespace();
71         
72         if(*inp != '\0') {
73                 if(isdigit(*inp)) {
74                         tok_value = 0;
75                         tok = TOK_NUMBER;
76                         tok_value = strtoul(inp, &inp, 0);
77                         return;
78                 }
79                 
80                 len = strcspn(inp, keyword_table);
81                 
82                 if(len > 0)     {
83                         char str_val[len + 1];
84                         strncpy(str_val, inp, len);
85                         str_val[len] = '\0';
86                         if(asprintf((char **)&tok_value, "%s", str_val) < 0) {
87                                 err_val = -ENOMEM;
88                         }
89                         inp += len;
90                         tok = TOK_STRING;
91                         return;
92                 }
93                 
94                 t_op[0] = *inp++;
95                 p1 = strstr(op_table, t_op);
96                 if(!p1) {
97                         tok = *t_op;
98                         return;
99                 }
100                 t_op[1] = *inp;
101                 p2 = strstr(op_table, t_op);
102                 if(!p2 || p1 == p2) {
103                         tok = MIN_TOK + (p1 - op_table)/3;
104                         return;
105                 }
106                 inp++;
107                 tok = MIN_TOK + (p2 - op_table)/3;
108         }
109         else
110                 tok = -1;
111 }
112
113 static int parse_expr(char **buffer);
114
115 static void ignore_whitespace(void) {
116         while (isspace(*inp)) {
117                 inp++;
118         }
119 }
120
121 static int accept(int ch) {
122         if (tok == ch) {
123                 input();
124                 return 1;
125         }
126         return 0;
127 }
128
129 static void skip(int ch) {
130         if (!accept(ch)) {
131                 err_val = -1;
132                 printf("expected '%c', got '%c'\n", (char)ch, (char)tok);
133         }
134 }
135
136 static int parse_num(char **buffer) {
137         long num = 0;
138         int flag = 1;
139         
140         if(tok == TOK_MINUS || tok == TOK_PLUS || tok == '(' || tok == TOK_NUMBER) {
141         
142                 if(accept(TOK_MINUS))                           //Handle -NUM and +NUM
143                         flag = -1;
144                 else if(accept(TOK_PLUS)) {}
145         
146                 if (accept('(')) {
147                         brackets++;
148                         parse_expr(buffer);
149                         if(err_val)     {
150                                 return -1;
151                         }
152                         skip(')');
153                         brackets--;
154                         if(err_val)     {
155                                 free(*buffer);
156                                 return -1;
157                         }
158                         if(flag < 0) {
159                                 if(**buffer == '-') {
160                                         **buffer = '+';
161                                 } else {
162                                         char t[strlen(*buffer) + 2];
163                                         t[0] = '-';
164                                         strcpy(t + 1, *buffer);
165                                         free(*buffer);
166                                         if(asprintf(buffer, "%s", t) < 0) {
167                                                 err_val = -ENOMEM;
168                                                 return -ENOMEM;
169                                         }
170                                 }                               
171                         }
172                         return strlen(*buffer);
173                 }
174                 if(tok == TOK_NUMBER) {
175                         num = flag * tok_value;
176                         input();
177                         if(asprintf(buffer, "%ld", num) < 0) {
178                                 err_val = -ENOMEM;
179                                 return err_val;
180                         }
181                         return strlen(*buffer);
182                 }
183                 err_val = -1;
184                 return -1;
185         }
186         
187         if (tok == TOK_STRING)  {
188                 *buffer = (char *)tok_value;
189                 input();
190                 return strlen(*buffer);
191         }
192         err_val = -1;
193         return -1;
194 }
195
196 //"!" "~" "*" "/" "%" "+" "-" "<" "<=" "<<" ">" ">=" ">>" "!=" "==" "&" "|" "^" "&&" "||";
197
198 static int eval(int op, char *op1, char *op2, char **buffer) {
199         long value;
200         
201         long lhs, rhs;
202         int flag1 = 0, flag2 = 0;
203         char *o1 = op1, *o2 = op2;
204         
205         if(op1 && *op1 == '-') {
206                 flag1 = 1;
207                 o1++;
208         }
209         if(*op2 == '-') {
210                 flag2 = 1;
211                 o2++;
212         }
213         lhs = op1 ? strtoul(o1, NULL, 0) : 0;
214         if(flag1) {
215                 lhs = -lhs;
216         }
217         rhs = strtoul(o2, NULL, 0);
218         if(flag2) {
219                 rhs = -rhs;
220         }
221         
222         switch(op)
223         {
224                 case 0:
225                         value = !rhs;
226                         break;
227                 case 1: 
228                         value = ~rhs;
229                         break;
230                 case 2: 
231                         value = lhs * rhs;
232                         break;
233                 case 3: 
234                         if(rhs != 0)
235                                 value = lhs / rhs;
236                         else {
237                                 err_val = -2;
238                         }
239                         break;
240                 case 4: 
241                         value = lhs % rhs;
242                         break;
243                 case 5:
244                         value = lhs + rhs;
245                         break;
246                 case 6: 
247                         value = lhs - rhs;
248                         break;
249                 case 7:
250                         value = !strcmp(op1, op2);
251                         break;
252                 case 8: 
253                         value = lhs < rhs;
254                         break;
255                 case 9: 
256                         value = lhs <= rhs;
257                         break;
258                 case 10: 
259                         value = lhs << rhs;
260                         break;
261                 case 11: 
262                         value = lhs > rhs;
263                         break;
264                 case 12: 
265                         value = lhs >= rhs;
266                         break;
267                 case 13: 
268                         value = lhs >> rhs;
269                         break;
270                 case 14:
271                         value = strcmp(op1, op2) ? 1 : 0;
272                         break;
273                 case 15:
274                         value = lhs & rhs;
275                         break;
276                 case 16: 
277                         value = lhs | rhs;
278                         break;
279                 case 17:
280                         value = lhs ^ rhs;
281                         break;
282                 case 18: 
283                         value = lhs && rhs;
284                         break;
285                 case 19: 
286                         value = lhs || rhs;
287                         break;
288                 
289                 default:                //This should not happen
290                         *buffer = NULL;
291                         err_val = -3;
292                         return err_val; 
293         }
294         if(asprintf(buffer, "%ld", value) < 0) {
295                 err_val = -ENOMEM;
296                 return err_val;
297         }
298         return strlen(*buffer);
299 }
300
301 static int parse_prio(int prio, char **buffer) {
302         int op;
303         char *lc, *rc;
304                 
305         if(tok < MIN_TOK || tok == TOK_MINUS || tok == TOK_PLUS) {
306                 parse_num(&lc);
307         } else {
308                 if(tok < MIN_TOK + 2) {
309                         lc = NULL;
310                 } else {
311                         err_val = -1;
312                         return -1;
313                 }
314         }
315         
316         if(err_val) {
317                 return -1;
318         }
319         while(tok != -1 && tok != ')') {
320                 long lhs;
321                 if(tok < MIN_TOK) {
322                         err_val = -1;
323                         if(lc) free(lc);
324                         return -1;
325                 }
326                 if(op_prio[tok - MIN_TOK] <= prio - (tok - MIN_TOK <= 1) ? 1 : 0) {
327                         break;
328                 }
329                 
330                 op  = tok;
331                 input();
332                 parse_prio(op_prio[op - MIN_TOK], &rc);
333                 
334                 if(err_val)     {
335                         if(lc) free(lc);
336                         return -1;
337                 }
338                 
339                 lhs = eval(op - MIN_TOK, lc, rc, buffer);
340                 free(rc);
341                 if(lc) free(lc);
342                 if(err_val) {
343                         return -1;
344                 }
345                 lc = *buffer;
346         }
347         *buffer = lc;
348         return strlen(*buffer);
349 }
350
/*
 * Parse a complete expression.  Priority -1 lies below every operator's
 * priority, so no operator ends the expression early.
 */
static int parse_expr(char **buffer) {
	int lowest_prio = -1;

	return parse_prio(lowest_prio, buffer);
}
354
355 int parse_arith(char *inp_string, char **end, char **buffer) {
356         err_val = tok = 0;
357         inp = inp_string;
358         brackets = 0;
359         input();
360         *buffer = NULL;
361         
362         skip('(');
363         brackets++;
364         parse_expr(buffer);
365         if(!err_val) {
366                 skip(')');
367         }
368         
369         if(err_val)     {                       //Read till we get a ')'
370                 *end = strchr(inp, ')');
371                 if(!*end) {
372                         *end = inp;
373                 } else {
374                         (*end) += 1;
375                 }
376                 switch (err_val) {
377                         case -1:
378                                 printf("parse error\n");
379                                 break;
380                         case -2:
381                                 printf("division by 0\n");
382                                 break;
383                         case -ENOMEM:
384                                 printf("out of memory\n");
385                                 break;
386                 }
387                 return -1;
388         }
389         
390         *end = prev;
391         return strlen(*buffer);
392 }
393
#ifdef __ARITH_TEST__
/*
 * Stand-alone test driver: evaluate argv[1] with parse_arith() and print
 * the result together with the unparsed tail of the input.
 */
int main(int argc, char *argv[]) {
	char *ret_val;
	char *tail;
	int r;

	if (argc < 2) {
		fprintf(stderr, "usage: %s '(expression)'\n",
			argc > 0 ? argv[0] : "arith");
		return 1;
	}

	r = parse_arith(argv[1], &tail, &ret_val);
	if (r < 0) {
		/* No result after a failure: ret_val must not be printed or freed. */
		printf("%d  Tail: %s\n", r, tail);
	} else {
		printf("%s Tail:%s\n", ret_val, tail);
		free(ret_val);
	}
	return 0;
}
#endif