From ea94e3b7059fa5e307411a3e78150b7f0a991d49 Mon Sep 17 00:00:00 2001 From: Nicolas Chataing Date: Fri, 31 Jul 2020 10:18:31 +0200 Subject: [PATCH] Changed functions expressions to have list of arguments --- src/catala/ir/lambda.ml | 13 +++++---- src/catala/translation/firstpass.ml | 41 +++++++++++++++++------------ 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/src/catala/ir/lambda.ml b/src/catala/ir/lambda.ml index 16855082..2083794a 100644 --- a/src/catala/ir/lambda.ml +++ b/src/catala/ir/lambda.ml @@ -34,8 +34,8 @@ type term = untyped_term Pos.marked * typ option and untyped_term = | EVar of uid - | EFun of binding * term - | EApp of term * term + | EFun of binding list * term + | EApp of term * term list | EIfThenElse of term * term * term | ELiteral of literal | EOp of op @@ -73,10 +73,13 @@ let print_op (op : op) : string = let rec print_term (((t, _), _) : term) : string = match t with | EVar uid -> Printf.sprintf "var(%n)" uid - | EFun ((binder, _), body) -> + | EFun (binders, body) -> let sbody = print_term body in - Printf.sprintf "fun %n -> %s" binder sbody - | EApp (t1, t2) -> Printf.sprintf "(%s) (%s)" (print_term t1) (print_term t2) + Printf.sprintf "fun %s -> %s" + (binders |> List.map (fun (x, _) -> Printf.sprintf "%d" x) |> String.concat " ") + sbody + | EApp (f, args) -> + Printf.sprintf "(%s) [%s]" (print_term f) (args |> List.map print_term |> String.concat ";") | EIfThenElse (tif, tthen, telse) -> Printf.sprintf "IF %s THEN %s ELSE %s" (print_term tif) (print_term tthen) (print_term telse) | ELiteral l -> print_literal l diff --git a/src/catala/translation/firstpass.ml b/src/catala/translation/firstpass.ml index 313e80de..5b852455 100644 --- a/src/catala/translation/firstpass.ml +++ b/src/catala/translation/firstpass.ml @@ -27,11 +27,10 @@ let rec expr_to_lambda ?(subdef : uid option) (scope : Context.uid) (ctxt : Cont ((EIfThenElse (rec_helper e_if, rec_helper e_then, rec_helper e_else), pos), None) | Binop (op, e1, e2) -> let op_term = (Pos.same_pos_as (EOp (Binop (Pos.unmark op))) op, None) in - let op_1 = ((EApp (op_term, rec_helper e1), pos), None) in - ((EApp (op_1, rec_helper e2), pos), None) + ((EApp (op_term, [ rec_helper e1; rec_helper e2 ]), pos), None) | Unop (op, e) -> let op_term = (Pos.same_pos_as (EOp (Unop (Pos.unmark op))) op, None) in - ((EApp (op_term, rec_helper e), pos), None) + ((EApp (op_term, [ rec_helper e ]), pos), None) | Literal l -> ((ELiteral l, pos), None) | Ident x -> let uid = Context.get_var_uid scope ctxt (x, pos) in @@ -66,20 +65,29 @@ let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.t let typ = Context.get_uid_typ ctxt uid in let term = ((EVar uid, pos), Some typ) in (term, typ) - | EFun (binding, body) -> + | EFun (bindings, body) -> (* Note that given the context formation process, the binder will already be present in the context (since we are working with uids), however it's added there for the sake of clarity *) - let uid, arg_typ = binding in - let uid_data : Context.uid_data = { uid_typ = arg_typ; uid_sort = Context.IdBinder } in - let body, ret_typ = typing { ctxt with data = Uid.UidMap.add uid uid_data ctxt.data } body in - let fun_typ = TArrow (arg_typ, ret_typ) in - (((EFun (binding, body), pos), Some fun_typ), fun_typ) - | EApp (t1, t2) -> ( - let t1, typ1 = typing ctxt t1 in - let t2, typ2 = typing ctxt t2 in - match typ1 with - | TArrow (arg_typ, ret_typ) -> - if arg_typ <> typ2 then assert false else (((EApp (t1, t2), pos), Some ret_typ), ret_typ) + let ctxt_data = + List.fold_left + (fun data (uid, arg_typ) -> + let uid_data : Context.uid_data = { uid_typ = arg_typ; uid_sort = Context.IdBinder } in + Uid.UidMap.add uid uid_data data) + ctxt.data bindings + in + + let body, ret_typ = typing { ctxt with data = ctxt_data } body in + let rec build_typ = function + | [] -> ret_typ + | (_, arg_t) :: args -> TArrow (arg_t, build_typ args) + in + let fun_typ = build_typ bindings in + (((EFun (bindings, body), pos), Some fun_typ), fun_typ) + | EApp (f, args) -> ( + let _f, f_typ = typing ctxt f in + let _args, _args_typ = args |> List.map (typing ctxt) |> List.split in + match f_typ with + | TArrow (_arg_typ, _ret_typ) -> assert false | TBool | TInt | TDummy -> assert false ) | EIfThenElse (t_if, t_then, t_else) -> let t_if, typ_if = typing ctxt t_if in @@ -110,8 +118,7 @@ let merge_conditions (precond : Lambda.term option) (cond : Lambda.term option) match (precond, cond) with | Some precond, Some cond -> let op_term = ((EOp (Binop And), Pos.no_pos), None) in - let term = ((EApp (op_term, precond), Pos.no_pos), None) in - ((EApp (term, cond), Pos.no_pos), None) + ((EApp (op_term, [ precond; cond ]), Pos.no_pos), None) | Some cond, None | None, Some cond -> cond | None, None -> ((ELiteral (Ast.Bool true), Pos.no_pos), Some TBool)