diff --git a/src/catala/default_calculus/ast.ml b/src/catala/default_calculus/ast.ml index 5e30ce72..690d9558 100644 --- a/src/catala/default_calculus/ast.ml +++ b/src/catala/default_calculus/ast.ml @@ -165,6 +165,17 @@ let make_let_in (x : Var.t) (tau : typ Pos.marked) (e1 : expr Pos.marked Bindlib (Pos.get_position (Bindlib.unbox e2))) (Bindlib.box_list [ e1 ]) +let make_multiple_let_in (xs : Var.t array) (taus : typ Pos.marked list) + (e1 : expr Pos.marked list Bindlib.box) (e2 : expr Pos.marked Bindlib.box) : + expr Pos.marked Bindlib.box = + Bindlib.box_apply2 + (fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2))) + (make_abs xs e2 + (Pos.get_position (Bindlib.unbox e2)) + taus + (Pos.get_position (Bindlib.unbox e2))) + e1 + type binder = (expr, expr Pos.marked) Bindlib.binder type program = { decl_ctx : decl_ctx; scopes : (Var.t * expr Pos.marked) list } diff --git a/src/catala/default_calculus/print.ml b/src/catala/default_calculus/print.ml index 119dfa98..5ceabfb9 100644 --- a/src/catala/default_calculus/print.ml +++ b/src/catala/default_calculus/print.ml @@ -253,7 +253,7 @@ let rec format_expr (ctx : Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos. let xs_tau_arg = List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args in Format.fprintf fmt "@[%a%a@]" (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + ~pp_sep:(fun fmt () -> Format.fprintf fmt "") (fun fmt (x, tau, arg) -> Format.fprintf fmt "@[@[let@ %a@ :@ %a@ =@ %a@]@ in@\n@]" format_var x (format_typ ctx) tau format_expr arg)) diff --git a/src/catala/lambda_calculus/to_ocaml.ml b/src/catala/lambda_calculus/to_ocaml.ml index d1da6780..98986c22 100644 --- a/src/catala/lambda_calculus/to_ocaml.ml +++ b/src/catala/lambda_calculus/to_ocaml.ml @@ -252,7 +252,7 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp let xs_tau_arg = List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args in Format.fprintf fmt "%a%a" (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + ~pp_sep:(fun fmt () -> Format.fprintf fmt "") (fun fmt (x, tau, arg) -> Format.fprintf fmt "@[let@ %a@ :@ %a@ =@ %a@]@ in@\n" format_var x format_typ tau format_with_parens arg)) diff --git a/src/catala/scope_language/scope_to_dcalc.ml b/src/catala/scope_language/scope_to_dcalc.ml index a7cc13c1..fb49fec9 100644 --- a/src/catala/scope_language/scope_to_dcalc.ml +++ b/src/catala/scope_language/scope_to_dcalc.ml @@ -457,10 +457,12 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list) let next_e, new_ctx = translate_rules new_ctx rest (sigma_name, pos_sigma) sigma_return_struct_name in - let results_bindings, _ = - List.fold_right - (fun (_, tau, dvar) (acc, i) -> - let result_access = + let results_bindings = + let xs = Array.of_list (List.map (fun (_, _, v) -> v) all_subscope_vars_dcalc) in + let taus = List.map (fun (_, tau, _) -> (tau, pos_sigma)) all_subscope_vars_dcalc in + let e1s = + List.mapi + (fun i _ -> Bindlib.box_apply (fun r -> ( Dcalc.Ast.ETupleAccess @@ -469,11 +471,10 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list) Some called_scope_return_struct, List.map (fun (_, t, _) -> (t, pos_sigma)) all_subscope_vars_dcalc ), pos_sigma )) - (Dcalc.Ast.make_var (result_tuple_var, pos_sigma)) - in - (Dcalc.Ast.make_let_in dvar (tau, pos_sigma) result_access acc, i - 1)) - all_subscope_vars_dcalc - (next_e, List.length all_subscope_vars_dcalc - 1) + (Dcalc.Ast.make_var (result_tuple_var, pos_sigma))) + all_subscope_vars_dcalc + in + Dcalc.Ast.make_multiple_let_in xs taus (Bindlib.box_list e1s) next_e in let result_tuple_typ = ( Dcalc.Ast.TTuple @@ -529,10 +530,17 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx) scope_variables in (* first we create variables from the fields of the input struct *) - let rules, _ = - List.fold_right - (fun (_, tau, dvar) (acc, i) -> - let result_access = + let rules = + let xs = Array.of_list (List.map (fun (_, _, v) -> v) scope_variables) in + let taus = + List.map + (fun (_, tau, _) -> + (Dcalc.Ast.TArrow ((Dcalc.Ast.TLit TUnit, pos_sigma), (tau, pos_sigma)), pos_sigma)) + scope_variables + in + let e1s = + List.mapi + (fun i _ -> Bindlib.box_apply (fun r -> ( Dcalc.Ast.ETupleAccess @@ -545,14 +553,10 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx) pos_sigma )) scope_variables ), pos_sigma )) - (Dcalc.Ast.make_var (scope_input_var, pos_sigma)) - in - ( Dcalc.Ast.make_let_in dvar - (Dcalc.Ast.TArrow ((Dcalc.Ast.TLit TUnit, pos_sigma), (tau, pos_sigma)), pos_sigma) - result_access acc, - i - 1 )) - scope_variables - (rules, List.length scope_variables - 1) + (Dcalc.Ast.make_var (scope_input_var, pos_sigma))) + scope_variables + in + Dcalc.Ast.make_multiple_let_in xs taus (Bindlib.box_list e1s) rules in let scope_return_struct_fields = List.map diff --git a/tests/test_enum/bad/output/quick_pattern_2.catala.A.out b/tests/test_enum/bad/output/quick_pattern_2.catala.A.out index e920e8e3..893e9eec 100644 --- a/tests/test_enum/bad/output/quick_pattern_2.catala.A.out +++ b/tests/test_enum/bad/output/quick_pattern_2.catala.A.out @@ -1,5 +1,5 @@ [ERROR] Error during typechecking, incompatible types: -[ERROR] --> F [Case3: any[77]] +[ERROR] --> F [Case3: any[73]] [ERROR] --> E [Case1: integer | Case2: unit] [ERROR] [ERROR] Error coming from typechecking the following expression: @@ -9,7 +9,7 @@ [ERROR] | ^ [ERROR] + Article [ERROR] -[ERROR] Type F [Case3: any[77]] coming from expression: +[ERROR] Type F [Case3: any[73]] coming from expression: [ERROR] --> test_enum/bad/quick_pattern_2.catala [ERROR] | [ERROR] 28 | def y := x with Case3 diff --git a/tests/test_enum/bad/output/quick_pattern_3.catala.A.out b/tests/test_enum/bad/output/quick_pattern_3.catala.A.out index 59745cd3..57e53b55 100644 --- a/tests/test_enum/bad/output/quick_pattern_3.catala.A.out +++ b/tests/test_enum/bad/output/quick_pattern_3.catala.A.out @@ -1,5 +1,5 @@ [ERROR] Error during typechecking, incompatible types: -[ERROR] --> F [Case3: any[20] | Case4: any[21]] +[ERROR] --> F [Case3: any[19] | Case4: any[20]] [ERROR] --> E [Case1: unit | Case2: unit] [ERROR] [ERROR] Error coming from typechecking the following expression: @@ -9,7 +9,7 @@ [ERROR] | ^ [ERROR] + Article [ERROR] -[ERROR] Type F [Case3: any[20] | Case4: any[21]] coming from expression: +[ERROR] Type F [Case3: any[19] | Case4: any[20]] coming from expression: [ERROR] --> test_enum/bad/quick_pattern_3.catala [ERROR] | [ERROR] 18 | def y := x with Case3 diff --git a/tests/test_enum/bad/output/quick_pattern_4.catala.A.out b/tests/test_enum/bad/output/quick_pattern_4.catala.A.out index 92b033f2..5ede1a7c 100644 --- a/tests/test_enum/bad/output/quick_pattern_4.catala.A.out +++ b/tests/test_enum/bad/output/quick_pattern_4.catala.A.out @@ -1,5 +1,5 @@ [ERROR] Error during typechecking, incompatible types: -[ERROR] --> F [Case3: any[20]] +[ERROR] --> F [Case3: any[19]] [ERROR] --> E [Case1: unit | Case2: unit] [ERROR] [ERROR] Error coming from typechecking the following expression: @@ -9,7 +9,7 @@ [ERROR] | ^ [ERROR] + Test [ERROR] -[ERROR] Type F [Case3: any[20]] coming from expression: +[ERROR] Type F [Case3: any[19]] coming from expression: [ERROR] --> test_enum/bad/quick_pattern_4.catala [ERROR] | [ERROR] 17 | def y := x with Case3