Arithmetic parser with string comparisons
[people/lynusvaz/gpxe.git] / src / hci / arith.c
1 /* Recursive descent arithmetic calculator:
2  * Ops: !, ~                            (Highest)
3  *      *, /, %
4  *      +, -
5  *      <<, >>
6  *      <, <=, >, >=
7  *      !=, ==
8  *      &
9  *      |
10  *      ^
11  *      &&
12  *      ||                              (Lowest)
13 */
14
15 #include <ctype.h>
16 #include <stdint.h>
17 #include <string.h>
18 #include <stdlib.h>
19 #include <stdio.h>
20 #include <errno.h>
21
22 #ifndef __ARITH_TEST__
23 #include <lib.h>
24 #endif
25
26 #define NUM_OPS         20
27 #define MAX_PRIO                11
28 #define MIN_TOK         258
29 #define TOK_PLUS                (MIN_TOK + 5)
30 #define TOK_MINUS       (MIN_TOK + 6)
31 #define TOK_NUMBER      256
32 #define TOK_STRING      257
33 #define TOK_TOTAL               20
34
35 #define SEP1                    -1
36 #define SEP2                    -1, -1
37
38 #ifndef __ARITH_TEST__
39 #define EPARSE          EUNIQ_01
40 #define EDIV0                   EUNIQ_02
41 #define ENOOP           EUNIQ_03
42 #define EWRONGOP        EUNIQ_04
43 #else
44 #define EPARSE          1
45 #define EDIV0                   2
46 #define ENOOP           3
47 #define EWRONGOP        4
48 #endif
49
50 char *inp, *prev, *orig;
51 int tok;
52 int err_val;
53 union {
54         long num_value;
55         char *str_value;
56 }tok_value;
57
58 /* Here is a table of the operators */
59 const char op_table[NUM_OPS * 3 + 1] = {        '!', SEP2 ,  '~', SEP2, '*', SEP2, '/', SEP2, '%', SEP2, '+', SEP2, '-', SEP2,
60                                                 '<', SEP2, '<', '=', SEP1, '<', '<', SEP1, '>', SEP2, '>', '=', SEP1, '>', '>', SEP1,  '&', SEP2,
61                                                 '|', SEP2, '^', SEP2, '&', '&', SEP1, '|', '|', SEP1, '!', '=', SEP1, '=', '=', SEP1, '\0'
62 };
63
64 const char *keyword_table = " \t\v()!~*/%+-<=>&|^";                     /* Characters that cannot appear in a string */
65 signed const char op_prio[NUM_OPS]      = { 10, 10, 9, 9, 9, 8, 8, 6, 6, 7, 6, 6, 7, 4, 3, 2, 1, 0, 5, 5 };
66
67 static void ignore_whitespace ( void );
68 static int parse_expr ( char **buffer );
69
70 static void input ( void ) {
71         char t_op[3] = { '\0', '\0', '\0'};
72         char *p1, *p2;
73         size_t len;
74         
75         if ( tok == -1 )
76                 return;
77
78         prev = inp;
79         ignore_whitespace();
80         
81         if ( *inp == '\0' ) {
82                 tok = -1;
83                 return;
84         }
85         
86         /* Check for a number */
87         if ( isdigit ( *inp ) ) {
88                 tok = TOK_NUMBER;
89                 tok_value.num_value = strtoul ( inp, &inp, 0 );
90                 return;
91         }
92         
93         /* Check for a string */
94         len = strcspn ( inp, keyword_table );
95         
96         if ( len > 0 )  {
97                 inp += len;
98                 tok = TOK_STRING;
99                 
100                 if ( ( tok_value.str_value = malloc ( len + 1 ) ) == NULL ) {
101                         err_val = -ENOMEM;
102                         return;
103                 }
104                 strncpy ( tok_value.str_value, inp - len, len );
105                 tok_value.str_value[len] = 0;
106                 return;
107         }
108         
109         /* Check for an operator */
110         t_op[0] = *inp++;
111         p1 = strstr ( op_table, t_op );
112         if ( !p1 ) {                    /* The character is not present in the list of operators */
113                 tok = *t_op;
114                 return;
115         }
116         
117         t_op[1] = *inp;
118         p2 = strstr ( op_table, t_op );
119         if ( !p2 || p1 == p2 ) {
120                 if ( ( p1 - op_table ) % 3 ) {                  /* Without this, it would take '=' as '<=' */
121                         tok = *t_op;
122                         return;
123                 }
124                 
125                 tok = MIN_TOK + ( p1 - op_table ) / 3;
126                 return;
127         }
128         inp++;
129         tok = MIN_TOK + ( p2 - op_table ) / 3;
130 }
131
132 /* Check if a string is a number: "-1" and "+42" is accepted, but not "-1a". If so, place it in num and return 1 else num = 0 and return 0 */
133 static int isnum ( char *string, long *num ) {
134         int flag = 0;
135         
136         *num = 0;
137         if ( *string == '+' ) {
138                 string++;
139         } else if ( *string == '-' ) {
140                 flag = 1;
141                 string++;
142         }
143         *num = strtoul ( string, &string, 0 );
144         if ( *string != 0 ) {
145                 *num = 0;
146                 return 0;
147         }
148         if ( flag )
149                 *num = -*num;
150         return 1;
151 }
152
153 static void ignore_whitespace ( void ) {
154         while ( isspace ( *inp ) ) {
155                 inp++;
156         }
157 }
158
159 static int accept ( int ch ) {
160         if ( tok == ch ) {
161                 input();
162                 return 1;
163         }
164         return 0;
165 }
166
167 static void skip ( int ch ) {
168         if ( !accept ( ch ) ) {
169                 err_val = -EPARSE;
170         }
171 }
172
173 static int parse_num ( char **buffer ) {
174         long num = 0;
175         int flag = 0;
176
177         if ( tok == TOK_MINUS || tok == TOK_PLUS || tok == '(' || tok == TOK_NUMBER ) {
178         
179                 if ( accept ( TOK_MINUS ) )                             //Handle -NUM and +NUM
180                         flag = 1;
181                 else if ( accept ( TOK_PLUS ) ) {
182                 }
183                         
184                 if ( accept ( '(' ) ) {
185                         parse_expr ( buffer );
186                         if ( err_val )  {
187                                 return ( err_val );
188                         }
189                         skip ( ')' );
190                         if ( err_val )  {
191                                 free ( *buffer );
192                                 return (err_val );
193                         }
194                         if ( flag ) {
195                                 int t;
196                                 t = isnum ( *buffer, &num );
197                                 free ( *buffer );
198                                 if ( t == 0 ) {         /* Trying to do a -string, which should not be permitted */
199                                         return ( err_val = -EWRONGOP );
200                                 }
201                                 return ( ( asprintf ( buffer, "%ld", -num )  < 0 ) ? (err_val = -ENOMEM ) : 0 );
202                         }
203                         return 0;
204                 }
205                 
206                 if ( tok == TOK_NUMBER ) {
207                         if ( flag )
208                                 num = -tok_value.num_value;
209                         else
210                                 num = tok_value.num_value;
211                         input();
212                         return ( ( asprintf ( buffer, "%ld", num ) < 0) ? (err_val = -ENOMEM ) : 0 );
213                 }
214                 return ( err_val = -EPARSE );
215         }
216         if ( tok == TOK_STRING ) {
217                 *buffer = tok_value.str_value;
218                 input();
219                 return 0;
220         }
221         return ( err_val = -EPARSE );
222 }
223
224 static int eval(int op, char *op1, char *op2, char **buffer) {
225         long value;
226         int bothints = 1;
227         long lhs, rhs;
228         
229         if ( op1 ) {
230                 if ( ! isnum ( op1, &lhs ) ) 
231                         bothints = 0;
232         }
233         
234         if ( ! isnum (op2, &rhs ) )
235                 bothints = 0;
236         
237         if ( op <= 17 && ! bothints ) {
238                 return ( err_val = -EWRONGOP );
239         }
240         
241         switch(op)
242         {
243                 case 0:
244                         value = !rhs;
245                         break;
246                 case 1: 
247                         value = ~rhs;
248                         break;
249                 case 2: 
250                         value = lhs * rhs;
251                         break;
252                 case 3: 
253                         if(rhs != 0)
254                                 value = lhs / rhs;
255                         else {
256                                 return ( err_val = -EDIV0 );
257                         }
258                         break;
259                 case 4: 
260                         value = lhs % rhs;
261                         break;
262                 case 5:
263                         value = lhs + rhs;
264                         break;
265                 case 6: 
266                         value = lhs - rhs;
267                         break;
268                 case 7: 
269                         value = lhs < rhs;
270                         break;
271                 case 8: 
272                         value = lhs <= rhs;
273                         break;
274                 case 9: 
275                         value = lhs << rhs;
276                         break;
277                 case 10: 
278                         value = lhs > rhs;
279                         break;
280                 case 11: 
281                         value = lhs >= rhs;
282                         break;
283                 case 12: 
284                         value = lhs >> rhs;
285                         break;
286                 case 13:
287                         value = lhs & rhs;
288                         break;
289                 case 14: 
290                         value = lhs | rhs;
291                         break;
292                 case 15:
293                         value = lhs ^ rhs;
294                         break;
295                 case 16: 
296                         value = lhs && rhs;
297                         break;
298                 case 17: 
299                         value = lhs || rhs;
300                         break;
301                 case 18:
302                         value = strcmp ( op1, op2 ) != 0;
303                         break;
304                 case 19:
305                         value = strcmp ( op1, op2 ) == 0;
306                         break;
307                 default:                /* This means the operator is in the op_table, but not defined in this switch statement */
308                         return ( err_val = -ENOOP );
309         }
310         return ( ( asprintf(buffer, "%ld", value)  < 0) ? ( err_val = -ENOMEM ) : 0 );
311 }
312
313 static int parse_prio(int prio, char **buffer) {
314         int op;
315         char *lc, *rc;
316                 
317         if ( tok < MIN_TOK || tok == TOK_MINUS || tok == TOK_PLUS ) {
318                 parse_num ( &lc );
319         } else {
320                 if ( tok < MIN_TOK + 2 ) {
321                         lc = NULL;
322                 } else {
323                         return ( err_val = -EPARSE );
324                 }
325         }
326         
327         if ( err_val ) {
328                 return err_val;
329         }
330         while( tok != -1 && tok != ')' ) {
331                 long lhs;
332                 if ( tok < MIN_TOK ) {
333                         if ( lc )
334                                 free(lc);
335                         return ( err_val = -EPARSE );
336                 }
337                 if ( op_prio[tok - MIN_TOK] <= prio - ( tok - MIN_TOK <= 1 ) ? 1 : 0 ) {
338                         break;
339                 }
340                 
341                 op  = tok;
342                 input();
343                 parse_prio ( op_prio[op - MIN_TOK], &rc );
344                 
345                 if ( err_val )  {
346                         if ( lc )
347                                 free ( lc );
348                         return err_val;
349                 }
350                 
351                 lhs = eval ( op - MIN_TOK, lc, rc, buffer );
352                 free ( rc );
353                 if ( lc )
354                         free ( lc );
355                 if ( err_val ) {
356                         return err_val;
357                 }
358                 lc = *buffer;
359         }  
360         *buffer = lc;
361         return 0;
362 }
363
364 static int parse_expr ( char **buffer ) {
365         return parse_prio ( -1, buffer );
366 }
367
368 int parse_arith ( char *inp_string, char **end, char **buffer ) {
369         err_val = tok = 0;
370         orig = inp = inp_string;
371         input();
372         *buffer = NULL;
373         
374         skip ( '(' );
375         parse_expr ( buffer );
376         
377         if ( !err_val ) {
378                 
379                 if ( tok != ')' ) {
380                         free ( *buffer );
381                         err_val = -EPARSE;
382                 }
383         }
384         *end = inp;
385         if ( err_val )  {                       //Read till we get a ')'
386                 
387                 if ( tok == TOK_STRING )
388                         free ( tok_value.str_value );
389                 switch ( err_val ) {
390                         case -EPARSE:
391                                 printf ( "parse error:\n%s\n", orig );
392                                 break;
393                         case -EDIV0:
394                                 printf ( "division by 0\n" );
395                                 break;
396                         case -ENOOP:
397                                 printf ( "operator undefined\n" );
398                                 break;
399                         case -EWRONGOP:
400                                 printf ( "wrong type of operand\n" );
401                                 break;
402                         case -ENOMEM:
403                                 printf("out of memory\n");
404                                 break;
405                 }
406                 return - ( EINVAL | -err_val );
407         }
408         
409         return 0;
410 }
411
412 #ifdef __ARITH_TEST__
413 int main ( int argc, char *argv[] ) {
414         char *ret_val;
415         int r = 0;
416         char *head, *tail, *string, *t;
417         char line[100];
418         
419         while ( !feof ( stdin ) ) {
420                 fgets ( line, 100, stdin );
421                 if ( line[strlen ( line ) - 1] == '\n' )
422                         line[strlen ( line ) - 1] = 0;
423                 asprintf ( &string, "%s", line );
424                 while ( ( head = strstr ( string, "$(" ) ) != NULL ) {
425                         *head++ = 0;
426                         r = parse_arith ( head, &tail, &ret_val );
427                         t = string;
428                         if ( r == 0 ) {
429                                 asprintf ( &string, "%s%s%s", string, ret_val, tail );
430                                 free ( ret_val );
431                         } else
432                                 break;
433                         free ( t );
434                 }
435                 if ( r == 0 ) {
436                         printf ( "Line: %s\n", string );
437                 }
438                 free ( string );
439         }
440         return 0;
441 }
442 #endif