/* * Performs type binding and type checking of expressions. * * Copyright © 2025 Samuel Lidén Borell * * SPDX-License-Identifier: EUPL-1.2+ OR LGPL-2.1-or-later */ #include #include #include "compiler.h" #include "out.h" struct SignedNum { uint64_t num; bool neg; }; static int max_subexpr_id(struct Expr *expr) { int max = -1; for (; expr != NULL; expr = expr->rpnnext) { assert(expr->id != max); /* Can't have duplicates! */ if (expr->id > max) { max = expr->id; } } assert(max >= 0); return max; } bool is_const(const struct TypeRefNumeric *range) { return range->min == range->max && range->min_neg == range->max_neg; } bool is_expr_const(const struct Expr *expr) { const struct TypeRef *tr = expr->typeref; assert(tr != NULL); switch (tr->kind) { case TR_BOOL: case TR_INT: return is_const(tr->u.num); case TR_UNKNOWN: case TR_CLASS: return false; } unreachable(); return false; } static bool is_intbool_literal(const struct Expr *expr) { switch ((int)expr->kind) { case E_FALSE: case E_TRUE: case E_INTEGER: return true; } return false; } /* uint64_t min, max; unsigned min_neg : 1; unsigned max_neg : 1; unsigned maybe_zero : 1; */ static const struct TypeRefNumeric num_zero = { 0, 0, 0,0,1 }; static const struct TypeRefNumeric num_one = { 1, 1, 0,0,0 }; const struct TypeRefNumeric range_bool = { 0, 1, 0,0,1 }; static const struct TypeRefNumeric *make_int_range( uint64_t min, uint64_t max, bool min_neg, bool max_neg) { struct TypeRefNumeric *range; if (!min_neg && !max_neg) { if (min == max) { if (min == 0) return &num_zero; if (min == 1) return &num_one; } else if (min == 0 && max == 1) { return &range_bool; } } range = malloc(sizeof(struct TypeRefNumeric)); NO_NULL(range); range->min = min; range->max = max; range->min_neg = min_neg && min != 0 ? 1 : 0; range->max_neg = max_neg && max != 0 ? 1 : 0; range->maybe_zero = min_neg != max_neg || min == 0; return range; } static const struct TypeRefNumeric *make_int_range_res( const struct SignedNum *minval, const struct SignedNum *maxval) { return make_int_range(minval->num, maxval->num, minval->neg, maxval->neg); } static void require_type(const struct Expr *expr, enum TypeRefKind kind, const char *msg) { assert(expr->typeref != NULL); if (expr->typeref->kind == TR_UNKNOWN) { /* TODO TR_UNKNOWN means unimplemented. E_IDENT...E_CALL are unimplemented */ return; } if (expr->typeref->kind != kind) { /* TODO set source filename and line somewhere */ error(msg); } } static void require_bool(const struct Expr *expr) { require_type(expr, TR_BOOL, "Expected a bool operand"); } static void require_int(const struct Expr *expr) { require_type(expr, TR_INT, "Expected an integer operand"); } static bool is_integer(const struct Expr *expr) { assert(expr->typeref != NULL); return expr->typeref->kind == TR_INT; } static uint64_t checked_add(uint64_t a, uint64_t b) { uint64_t r = a + b; if (r < a) { error("Addition might overflow (even disregarding the type)"); } return r; } static struct SignedNum range_add( uint64_t a, uint64_t b, unsigned _a_neg, unsigned _b_neg) { uint64_t num; bool neg; struct SignedNum res; bool a_neg = (bool)_a_neg; bool b_neg = (bool)_b_neg; if (a_neg == b_neg) { neg = a_neg; num = checked_add(a, b); } else if (a > b) { neg = a_neg; num = a - b; } else { neg = b_neg; num = b - a; } res.num = num; res.neg = neg; return res; } static struct SignedNum checked_mul( uint64_t a, uint64_t b, unsigned a_neg, unsigned b_neg) { uint64_t num; struct SignedNum res; num = a * b; if (b != 0 && num/b != a) { error("Multiplication might overflow (even disregarding the type)"); } res.num = num; res.neg = (bool)(a_neg ^ b_neg); return res; } static struct SignedNum num_div( uint64_t a, uint64_t b, unsigned a_neg, unsigned b_neg) { struct SignedNum res; assert(b != 0); res.num = a / b; res.neg = (bool)(a_neg ^ b_neg); return res; } static struct SignedNum num_mod( uint64_t a, uint64_t b, unsigned a_neg, unsigned b_neg) { struct SignedNum res; assert(b != 0); assert(a_neg <= 1); assert(b_neg <= 1); if (a_neg == b_neg) { res.num = a % b; } else { res.num = b - (a % b); if (res.num == b) { res.num = 0; } } res.neg = (bool)b_neg; return res; } /** Computes minimum and maximum bounds of a modulus operations. Takes some common cases into account, but not all. */ static void range_mod( const struct TypeRefNumeric *a, const struct TypeRefNumeric *b, struct SignedNum *rmin, struct SignedNum *rmax) { assert(rmin != rmax); /* Minimum bound can be `0` or `-(b-1)` */ if (b->min_neg) { rmin->neg = true; rmin->num = b->min - 1; } else { rmin->neg = false; rmin->num = 0; } /* Maximum bound can be `0` or `b-1` */ if (b->max_neg) { rmax->neg = false; rmax->num = 0; } else { rmax->neg = false; rmax->num = b->max - 1; } /* If `abs(a) < abs(b)`, then we can narrow down the range further */ if (a->min_neg == a->max_neg && is_const(b) && a->min < b->max && a->max < b->max) { if (a->min_neg == b->min_neg) { rmin->neg = rmax->neg = b->min_neg; rmin->num = a->min; rmax->num = a->max; } else if (a->min != 0 && a->max != 0) { uint64_t denom = b->min; rmin->neg = rmax->neg = b->min_neg; rmin->num = denom - a->max; rmax->num = denom - a->min; } } } static int compare_num(const struct SignedNum *a, const struct SignedNum *b) { bool invert; if (a->neg) { if (!b->neg) return -1; invert = true; } else { if (b->neg) return 1; invert = false; } if (a->num == b->num) return 0; else if (a->num < b->num) return invert ? 1 : -1; else return invert ? -1 : 1; } static void compare_nums( const struct SignedNum *n, struct SignedNum *rmin, struct SignedNum *rmax) { if (compare_num(n, rmin) < 0) { *rmin = *n; } if (compare_num(n, rmax) > 0) { *rmax = *n; } } static void sort_nums( const struct SignedNum *a, const struct SignedNum *b, const struct SignedNum *c, const struct SignedNum *d, struct SignedNum *rmin, struct SignedNum *rmax) { rmin->num = UINT64_MAX; rmin->neg = false; rmax->num = UINT64_MAX; rmax->neg = true; compare_nums(a, rmin, rmax); compare_nums(b, rmin, rmax); compare_nums(c, rmin, rmax); compare_nums(d, rmin, rmax); } static void typeref_free(struct TypeRef *tr) { struct TypeRefNumeric *num; if (!tr) return; if (tr->kind == TR_INT || tr->kind == TR_BOOL) { num = (struct TypeRefNumeric *)tr->u.num; if (num != &num_zero && num != &num_one && num != &range_bool) { free(num); } } free(tr); } static void expr_free(struct Expr *expr) { if (!expr) return; typeref_free(expr->typeref); free(expr); } /** Checks if the subexpression is compile-time constant, and if so, replaces it with a single Expr node with a literal value. */ static void eliminate_if_const( struct Expr **expr_ptr, struct TypeRef *tr, struct Expr **id_to_expr, struct Expr *opnd1, struct Expr *opnd2) { struct Expr *e; assert(opnd1 != NULL); assert(opnd1 != opnd2); assert(tr != NULL); assert(expr_ptr != NULL); if (!is_intbool_literal(opnd1)) return; if (opnd2 && !is_intbool_literal(opnd2)) return; e = *expr_ptr; assert(e != opnd1 && e != opnd2); if (e->kind == E_BOOL_AND || e->kind == E_BOOL_OR) { /* TODO Handle short-circuiting operators also. That's trickier, because it also requires handling of the sequence points. There are three cases to take into account then: - only opnd1 constant - only opnd2 constant - both operands are constants */ return; } /* If we go here, then constant evaluation has been done, and just needs to be finished up. */ assert(tr->kind == TR_INT || tr->kind == TR_BOOL); assert(is_const(tr->u.num)); /* Put evaluated value in opnd1, and discard everything else */ if (!opnd2) { assert(opnd1->rpnnext == e); } else { assert(opnd1->rpnnext == opnd2); assert(opnd2->rpnnext == e); id_to_expr[opnd2->id] = NULL; } id_to_expr[opnd1->id] = NULL; id_to_expr[e->id] = opnd1; typeref_free(opnd1->typeref); opnd1->kind = ( tr->kind == TR_INT ? E_INTEGER : tr->u.num->min == 0 ? E_FALSE : E_TRUE); opnd1->id = e->id; opnd1->typeref = NULL; /* replaced with tr last in typecheck_expr loop */ opnd1->rpnnext = e->rpnnext; expr_free(e); expr_free(opnd2); *expr_ptr = opnd1; } void typecheck_expr(const struct TypeRef *typeref, struct Expr *expr) { size_t id_max; struct Expr *e, **id_to_expr; const struct TypeRef *last_typeref; int last_id; /* TODO the id_to_expr stuff causes the expr to be scanned 3 times instead of 1, and could be optimized. */ assert(expr != NULL); id_max = (size_t)max_subexpr_id(expr); id_to_expr = calloc(id_max+1, sizeof(struct Expr *)); NO_NULL(id_to_expr); for (e = expr; e != NULL; e = e->rpnnext) { assert(e->id >= 0); assert(e->typeref == NULL); id_to_expr[e->id] = e; } last_id = 0; last_typeref = NULL; /* TODO varstates and integer ranges is tricky: - can set "allowed_min" and "allowed_max" values in expressions with only one variable, e.g. in `byte b = x + 200` the allowed_max of x is 55. - but what about: `byte b = x + y` or `byte b = x + y + z + w` - and what about: `byte b = 3 * (2 + 1/x)` - solution 1: - allow only one fully free variable. - set allowed_min=min and allowed_max=max on all other variables? - might be really hard for convoluted exprs. - solution 2: - infer allowed_min/allowed_max from `assert` and some basic operations, e.g. `byte b = x + 1` - still tricky. - solution 3 (best?): - have a special "persistent assert" that adds a constraint for the variable for the following code. `assert always x > 0` - this can set the allowed_min,allowed_max values on the variable. - for immutable variables, this isn't needed at all. (still needs varstate tracking for branches) - differening (or missing) `assert always` in branches should be forbidden. But an `assert always` can always be widened in the more restrictive branch(es). there's the same issue with optional values. - how to specify that the reference itself is mutable? - keyword choice / bikeshedding problem... */ for (e = expr; e != NULL; e = e->rpnnext) { struct TypeRef tr; tr.kind = TR_UNKNOWN; switch (e->kind) { case E_GROUP_TEMP: case E_SEQPOINT: e->typeref = id_to_expr[last_id]->typeref; break; case E_NONE: tr.kind = TR_CLASS; tr.quals = 0; /* TODO set qualifiers */ tr.u.class_ = NULL; break; case E_FALSE: tr.kind = TR_BOOL; tr.quals = 0; tr.u.num = &num_zero; break; case E_TRUE: tr.kind = TR_BOOL; tr.quals = 0; tr.u.num = &num_one; break; case E_INTEGER: tr.kind = TR_INT; tr.quals = 0; tr.u.num = make_int_range(e->u.intval.num, e->u.intval.num, 0, 0); break; case E_STRING: tr.kind = TR_CLASS; tr.quals = 0; tr.quals = 0; /* TODO set qualifiers */ tr.u.class_ = NULL; /* TODO */ break; case E_IDENT: { struct Var *var; if (e->u.ident.namelen != 0) { /* TODO typeidents */ ast_error("typeidents are unimplemented"); } else { var = e->u.ident.u.var; /* TODO varstate tracking */ } tr = *var->typeref; goto force_set_typeref; } case E_MEMBER: case E_ARRAY: case E_CALL: /* TODO implement these. the code below just sets a dummy type */ tr.kind = TR_UNKNOWN; tr.quals = 0; tr.u.num = &num_one; goto force_set_typeref; case E_NEGATE: { struct Expr *opnd = id_to_expr[last_id]; const struct TypeRefNumeric *opndnum; require_int(opnd); tr.kind = TR_INT; tr.quals = 0; opndnum = opnd->typeref->u.num; tr.u.num = make_int_range( opndnum->max, opndnum->min, !opndnum->max_neg, !opndnum->min_neg); eliminate_if_const(&e, &tr, id_to_expr, opnd, NULL); break; } case E_BOOL_NOT: { struct Expr *opnd = id_to_expr[last_id]; const struct TypeRefNumeric *opndnum; require_bool(opnd); tr.kind = TR_BOOL; tr.quals = 0; opndnum = opnd->typeref->u.num; if (is_const(opndnum)) { tr.u.num = opndnum->min ? &num_zero : &num_one; eliminate_if_const(&e, &tr, id_to_expr, opnd, NULL); } else { tr.u.num = &range_bool; } break; } case E_ADD: case E_SUB: case E_MUL: case E_DIV: case E_MOD: { struct Expr *opnd1 = id_to_expr[e->u.binary.left_id]; struct Expr *opnd2 = id_to_expr[last_id]; const struct TypeRefNumeric *a, *b; /* Result min/max values (i.e. range) */ struct SignedNum rmin; struct SignedNum rmax; require_int(opnd1); require_int(opnd2); a = opnd1->typeref->u.num; b = opnd2->typeref->u.num; switch ((int)e->kind) { case E_ADD: case E_SUB: { unsigned b_min_neg = b->min_neg, b_max_neg = b->max_neg; if (e->kind == E_SUB) { b_min_neg = !b_min_neg; b_max_neg = !b_max_neg; } rmin = range_add(a->min, b->min, a->min_neg, b_min_neg); rmax = range_add(a->max, b->max, a->max_neg, b_max_neg); break; } case E_MUL: { struct SignedNum ll, lh, hl, hh; ll = checked_mul(a->min, b->min, a->min_neg, b->min_neg); lh = checked_mul(a->min, b->max, a->min_neg, b->max_neg); hl = checked_mul(a->max, b->min, a->max_neg, b->min_neg); hh = checked_mul(a->max, b->max, a->max_neg, b->max_neg); sort_nums(&ll, &lh, &hl, &hh, &rmin, &rmax); break; } case E_DIV: { struct SignedNum ll, lh, hl, hh; if (b->maybe_zero) { error("Possibility of division by zero"); } ll = num_div(a->min, b->min, a->min_neg, b->min_neg); lh = num_div(a->min, b->max, a->min_neg, b->max_neg); hl = num_div(a->max, b->min, a->max_neg, b->min_neg); hh = num_div(a->max, b->max, a->max_neg, b->max_neg); sort_nums(&ll, &lh, &hl, &hh, &rmin, &rmax); break; } case E_MOD: if (b->maybe_zero) { error("Possibility of division by zero"); } assert(b->max != 0); if (is_const(a) && is_const(b)) { rmin = rmax = num_mod(a->min, b->min, a->min_neg, b->min_neg); } else if (is_const(a) && a->min == 0) { rmin.num = 0; rmin.neg = false; rmax = rmin; } else { range_mod(a, b, &rmin, &rmax); } break; default: unreachable(); } tr.kind = TR_INT; tr.quals = 0; tr.u.num = make_int_range_res(&rmin, &rmax); eliminate_if_const(&e, &tr, id_to_expr, opnd1, opnd2); break; } case E_EQUAL: case E_NOT_EQUAL: if (!is_integer(id_to_expr[e->u.binary.left_id])) { struct Expr *opnd1 = id_to_expr[e->u.binary.left_id]; struct Expr *opnd2 = id_to_expr[last_id]; tr.kind = TR_BOOL; tr.quals = 0; if (is_expr_const(opnd1) && is_expr_const(opnd2) && opnd1->typeref->kind == TR_BOOL && opnd2->typeref->kind == TR_BOOL) { bool result = (opnd1->typeref->u.num->min == opnd2->typeref->u.num->min); if (e->kind == E_NOT_EQUAL) result = !result; tr.u.num = result ? &num_one : &num_zero; } else { tr.u.num = &range_bool; } check_type_compat(opnd1->typeref, opnd2->typeref, TC_COMPARE); eliminate_if_const(&e, &tr, id_to_expr, opnd1, opnd2); break; } /* Fall through */ case E_LESS: case E_GREATER: case E_LESS_EQUAL: case E_GREATER_EQUAL: { struct Expr *opnd1 = id_to_expr[e->u.binary.left_id]; struct Expr *opnd2 = id_to_expr[last_id]; const struct TypeRefNumeric *a, *b; require_int(opnd1); require_int(opnd2); a = opnd1->typeref->u.num; b = opnd2->typeref->u.num; if (is_const(a) && is_const(b)) { struct SignedNum an, bn; int cmp; bool res; an.num = a->min; an.neg = a->min_neg; bn.num = b->min; bn.neg = b->min_neg; cmp = compare_num(&an, &bn); switch ((int)e->kind) { case E_EQUAL: res = (cmp == 0); break; case E_NOT_EQUAL: res = (cmp != 0); break; case E_LESS: res = (cmp < 0); break; case E_GREATER: res = (cmp > 0); break; case E_LESS_EQUAL: res = (cmp <= 0); break; case E_GREATER_EQUAL: res = (cmp >= 0); break; default: unreachable(); } tr.kind = TR_BOOL; tr.quals = 0; tr.u.num = res ? &num_one : &num_zero; eliminate_if_const(&e, &tr, id_to_expr, opnd1, opnd2); } else { tr.kind = TR_BOOL; tr.quals = 0; tr.u.num = &range_bool; /* Detect always-out-of-range cases */ check_type_compat(opnd1->typeref, opnd2->typeref, TC_COMPARE); } break; } case E_BOOL_AND: case E_BOOL_OR: { struct Expr *opnd1 = id_to_expr[e->u.binary.left_id]; struct Expr *opnd2 = id_to_expr[last_id]; const struct TypeRefNumeric *opnd1num, *opnd2num; require_bool(opnd1); require_bool(opnd2); opnd1num = opnd1->typeref->u.num; opnd2num = opnd2->typeref->u.num; if (is_const(opnd1num) && is_const(opnd2num)) { if (e->kind == E_BOOL_AND) { /* If first is true, take second value (might be false), otherwise, take first value (which is false). */ e->typeref = opnd1num->min ? opnd2->typeref : opnd1->typeref; } else { e->typeref = opnd1num->min ? opnd1->typeref : opnd2->typeref; } /* TODO add const evaluation of and/or */ } else { tr.kind = TR_BOOL; tr.quals = 0; tr.u.num = &range_bool; } break; } case E_ASSIGN: case E_ASSIGN_FINAL: { struct Expr *opnd_target = id_to_expr[e->u.binary.left_id]; struct Expr *opnd_source = id_to_expr[last_id]; assert(opnd_source->typeref == last_typeref); e->typeref = opnd_source->typeref; check_type_compat(opnd_target->typeref, opnd_source->typeref, TC_ASSIGN); break; } default: unreachable(); } if (tr.kind != TR_UNKNOWN) { force_set_typeref: assert(e->typeref == NULL); e->typeref = malloc(sizeof(struct TypeRef)); NO_NULL(e->typeref); *e->typeref = tr; } else { assert(e->typeref != NULL); }; last_typeref = e->typeref; last_id = e->id; } free(id_to_expr); if (typeref) { check_type_compat(typeref, last_typeref, TC_ASSIGN); } }