From e1dfb520952a27888f518dae47ce50ea87fc5655 Mon Sep 17 00:00:00 2001 From: Denis Merigoux Date: Tue, 24 Nov 2020 11:27:23 +0100 Subject: [PATCH] Added ints and operators to default calculus --- src/catala/default_calculus/ast.ml | 10 ++++- src/catala/default_calculus/interpreter.ml | 47 +++++++++++++++++++--- src/catala/default_calculus/typing.ml | 30 ++++++++++++-- src/catala/desugared/ast.ml | 2 +- 4 files changed, 79 insertions(+), 10 deletions(-) diff --git a/src/catala/default_calculus/ast.ml b/src/catala/default_calculus/ast.ml index eb57a3de..2c142fa4 100644 --- a/src/catala/default_calculus/ast.ml +++ b/src/catala/default_calculus/ast.ml @@ -17,10 +17,17 @@ module Pos = Utils.Pos type typ = | TBool | TUnit + | TInt | TTuple of typ Pos.marked list | TArrow of typ Pos.marked * typ Pos.marked -type lit = LTrue | LFalse | LEmptyError +type lit = LBool of bool | LEmptyError | LInt of Int64.t + +type binop = And | Or | Add | Sub | Mult | Div | Lt | Lte | Gt | Gte | Eq | Neq + +type unop = Not | Minus + +type operator = Binop of binop | Unop of unop type expr = | EVar of expr Pos.marked Bindlib.var @@ -29,6 +36,7 @@ type expr = | ELit of lit | EAbs of Pos.t * (expr Pos.marked, expr Pos.marked) Bindlib.mbinder * typ list | EApp of expr Pos.marked * expr Pos.marked list + | EOp of operator | EDefault of expr Pos.marked * expr Pos.marked * expr Pos.marked list | EIfThenElse of expr Pos.marked * expr Pos.marked * expr Pos.marked diff --git a/src/catala/default_calculus/interpreter.ml b/src/catala/default_calculus/interpreter.ml index 20b11780..6d2dbaef 100644 --- a/src/catala/default_calculus/interpreter.ml +++ b/src/catala/default_calculus/interpreter.ml @@ -19,6 +19,42 @@ module A = Ast let is_empty_error (e : A.expr Pos.marked) : bool = match Pos.unmark e with ELit LEmptyError -> true | _ -> false +let evaluate_operator (op : A.operator Pos.marked) (args : A.expr Pos.marked list) : + A.expr Pos.marked = + Pos.same_pos_as + ( match (Pos.unmark op, List.map Pos.unmark args) with + | A.Binop A.And, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 && b2)) + | A.Binop A.Or, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 || b2)) + | A.Binop A.Add, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LInt (Int64.add i1 i2)) + | A.Binop A.Sub, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LInt (Int64.sub i1 i2)) + | A.Binop A.Mult, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LInt (Int64.mul i1 i2)) + | A.Binop A.Div, [ ELit (LInt i1); ELit (LInt i2) ] -> + if i2 <> Int64.zero then A.ELit (LInt (Int64.div i1 i2)) + else + Errors.raise_multispanned_error "division by zero at runtime" + [ + (Some "The division operator:", Pos.get_position op); + (Some "The null denominator:", Pos.get_position (List.nth args 2)); + ] + | A.Binop A.Lt, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 < i2)) + | A.Binop A.Lte, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 <= i2)) + | A.Binop A.Gt, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 > i2)) + | A.Binop A.Gte, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 >= i2)) + | A.Binop A.Eq, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 = i2)) + | A.Binop A.Eq, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 = b2)) + | A.Binop A.Eq, [ _; _ ] -> A.ELit (LBool false) (* comparing functions return false *) + | A.Binop A.Neq, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 <> i2)) + | A.Binop A.Neq, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 <> b2)) + | A.Binop A.Neq, [ _; _ ] -> A.ELit (LBool true) + | A.Unop A.Not, [ ELit (LBool b) ] -> A.ELit (LBool (not b)) + | A.Unop A.Minus, [ ELit (LInt i) ] -> A.ELit (LInt (Int64.sub Int64.zero i)) + | _ -> + Errors.raise_multispanned_error + "operator applied to the wrong arguments (should not happen if the term was well-typed)" + [ (Some "Operator:", Pos.get_position op) ] + @@ List.mapi (fun i arg -> Some ("Argument n°" ^ string_of_int i, Pos.get_position arg)) ) + op + let rec evaluate_expr (e : A.expr Pos.marked) : A.expr Pos.marked = match Pos.unmark e with | EVar _ -> @@ -38,13 +74,14 @@ let rec evaluate_expr (e : A.expr Pos.marked) : A.expr Pos.marked = (Format.asprintf "wrong function call, expected %d arguments, got %d" (Bindlib.mbinder_arity binder) (List.length args)) (Pos.get_position e) + | EOp op -> evaluate_operator (Pos.same_pos_as op e1) args | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e | _ -> Errors.raise_spanned_error "function has not been reduced to a lambda at evaluation (should not happen if the \ term was well-typed" (Pos.get_position e) ) - | EAbs _ | ELit _ -> e (* thse are values *) + | EAbs _ | ELit _ | EOp _ -> e (* thse are values *) | ETuple es -> Pos.same_pos_as (A.ETuple (List.map evaluate_expr es)) e | ETupleAccess (e1, n) -> ( let e1 = evaluate_expr e1 in @@ -70,8 +107,8 @@ let rec evaluate_expr (e : A.expr Pos.marked) : A.expr Pos.marked = let just = evaluate_expr just in match Pos.unmark just with | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e - | ELit LTrue -> evaluate_expr cons - | ELit LFalse -> ( + | ELit (LBool true) -> evaluate_expr cons + | ELit (LBool false) -> ( let subs = List.map evaluate_expr subs in let empty_count = List.length (List.filter is_empty_error subs) in match List.length subs - empty_count with @@ -96,8 +133,8 @@ let rec evaluate_expr (e : A.expr Pos.marked) : A.expr Pos.marked = (Pos.get_position e) ) | EIfThenElse (cond, et, ef) -> ( match Pos.unmark (evaluate_expr cond) with - | ELit LTrue -> evaluate_expr et - | ELit LFalse -> evaluate_expr ef + | ELit (LBool true) -> evaluate_expr et + | ELit (LBool false) -> evaluate_expr ef | _ -> Errors.raise_spanned_error "expected a boolean literal for the result of this condition (should not happen if the \ diff --git a/src/catala/default_calculus/typing.ml b/src/catala/default_calculus/typing.ml index d92091c9..6163e1f2 100644 --- a/src/catala/default_calculus/typing.ml +++ b/src/catala/default_calculus/typing.ml @@ -21,6 +21,7 @@ module A = Ast type typ = | TUnit + | TInt | TBool | TArrow of typ Pos.marked UnionFind.elem * typ Pos.marked UnionFind.elem | TTuple of typ Pos.marked UnionFind.elem list @@ -31,6 +32,7 @@ let rec format_typ (fmt : Format.formatter) (ty : typ Pos.marked UnionFind.elem) match Pos.unmark ty_repr with | TUnit -> Format.fprintf fmt "unit" | TBool -> Format.fprintf fmt "bool" + | TInt -> Format.fprintf fmt "int" | TAny -> Format.fprintf fmt "any" | TTuple ts -> Format.fprintf fmt "(%a)" @@ -42,7 +44,7 @@ let rec unify (t1 : typ Pos.marked UnionFind.elem) (t2 : typ Pos.marked UnionFin let t1_repr = UnionFind.get (UnionFind.find t1) in let t2_repr = UnionFind.get (UnionFind.find t2) in match (t1_repr, t2_repr) with - | (TUnit, _), (TUnit, _) | (TBool, _), (TBool, _) -> () + | (TUnit, _), (TUnit, _) | (TBool, _), (TBool, _) | (TInt, _), (TInt, _) -> () | (TArrow (t11, t12), _), (TArrow (t21, t22), _) -> unify t11 t21; unify t12 t22 @@ -60,10 +62,25 @@ let rec unify (t1 : typ Pos.marked UnionFind.elem) (t2 : typ Pos.marked UnionFin (Some (Format.asprintf "Type %a coming from expression:" format_typ t2), t2_pos); ] +let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem = + let pos = Pos.get_position op in + let bt = UnionFind.make (TBool, pos) in + let it = UnionFind.make (TInt, pos) in + let any = UnionFind.make (TAny, pos) in + let arr x y = UnionFind.make (TArrow (x, y), pos) in + match Pos.unmark op with + | A.Binop (A.And | A.Or) -> arr bt (arr bt bt) + | A.Binop (A.Add | A.Sub | A.Mult | A.Div) -> arr it (arr it it) + | A.Binop (A.Lt | A.Lte | A.Gt | A.Gte) -> arr it (arr it bt) + | A.Binop (A.Eq | A.Neq) -> arr any (arr any bt) + | A.Unop A.Minus -> arr it it + | A.Unop A.Not -> arr bt bt + let rec ast_to_typ (ty : A.typ) : typ = match ty with | A.TUnit -> TUnit | A.TBool -> TBool + | A.TInt -> TInt | A.TArrow (t1, t2) -> TArrow ( UnionFind.make (Pos.map_under_mark ast_to_typ t1), @@ -76,6 +93,7 @@ let rec typ_to_ast (ty : typ Pos.marked UnionFind.elem) : A.typ Pos.marked = match ty with | TUnit -> A.TUnit | TBool -> A.TBool + | TInt -> A.TInt | TTuple ts -> A.TTuple (List.map typ_to_ast ts) | TArrow (t1, t2) -> A.TArrow (typ_to_ast t1, typ_to_ast t2) | TAny -> A.TUnit) @@ -92,7 +110,8 @@ let rec typecheck_expr_bottom_up (env : env) (e : A.expr Pos.marked) : typ Pos.m | None -> Errors.raise_spanned_error "Variable not found in the current context" (Pos.get_position e) ) - | ELit (LTrue | LFalse) -> UnionFind.make (Pos.same_pos_as TBool e) + | ELit (LBool _) -> UnionFind.make (Pos.same_pos_as TBool e) + | ELit (LInt _) -> UnionFind.make (Pos.same_pos_as TInt e) | ELit LEmptyError -> UnionFind.make (Pos.same_pos_as TAny e) | ETuple es -> let ts = List.map (typecheck_expr_bottom_up env) es in @@ -138,6 +157,7 @@ let rec typecheck_expr_bottom_up (env : env) (e : A.expr Pos.marked) : typ Pos.m in typecheck_expr_top_down env e1 t_app; t_app + | EOp op -> op_type (Pos.same_pos_as op e) | EDefault (just, cons, subs) -> typecheck_expr_top_down env just (UnionFind.make (Pos.same_pos_as TBool just)); let tcons = typecheck_expr_bottom_up env cons in @@ -158,7 +178,8 @@ and typecheck_expr_top_down (env : env) (e : A.expr Pos.marked) | None -> Errors.raise_spanned_error "Variable not found in the current context" (Pos.get_position e) ) - | ELit (LTrue | LFalse) -> unify tau (UnionFind.make (Pos.same_pos_as TBool e)) + | ELit (LBool _) -> unify tau (UnionFind.make (Pos.same_pos_as TBool e)) + | ELit (LInt _) -> unify tau (UnionFind.make (Pos.same_pos_as TInt e)) | ELit LEmptyError -> unify tau (UnionFind.make (Pos.same_pos_as TAny e)) | ETuple es -> ( let tau' = UnionFind.get (UnionFind.find tau) in @@ -216,6 +237,9 @@ and typecheck_expr_top_down (env : env) (e : A.expr Pos.marked) t_args tau in unify te1 t_func + | EOp op -> + let op_typ = op_type (Pos.same_pos_as op e) in + unify op_typ tau | EDefault (just, cons, subs) -> typecheck_expr_top_down env just (UnionFind.make (Pos.same_pos_as TBool just)); typecheck_expr_top_down env cons tau; diff --git a/src/catala/desugared/ast.ml b/src/catala/desugared/ast.ml index 983f6095..6609a1a7 100644 --- a/src/catala/desugared/ast.ml +++ b/src/catala/desugared/ast.ml @@ -57,7 +57,7 @@ type definition = Scopelang.Ast.expr Pos.marked let empty_def (pos : Pos.t) : definition = ( Scopelang.Ast.EDefault - ( (Scopelang.Ast.ELit Dcalc.Ast.LFalse, pos), + ( (Scopelang.Ast.ELit (Dcalc.Ast.LBool false), pos), (Scopelang.Ast.ELit Dcalc.Ast.LEmptyError, pos), [] ), pos )