diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index 3849b12c..872ec1b9 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -352,33 +352,40 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : field sc_sig.scope_sig_output_struct (Expr.with_ty m typ) in match Marked.unmark typ with - | TArrow (t_in, t_out) -> + | TArrow (ts_in, t_out) -> (* Here the output scope struct field is a function so we eta-expand it and insert logging instructions. Invariant: works because user-defined functions in scope have only one argument. *) - let param_var = Var.make "param" in + let params_vars = + ListLabels.mapi ts_in ~f:(fun i _ -> + Var.make ("param" ^ string_of_int i)) + in let f_markings = [ScopeName.get_info scope; StructField.get_info field] in Expr.make_abs - (Array.of_list [param_var]) + (Array.of_list params_vars) (tag_with_log_entry (tag_with_log_entry (Expr.eapp (tag_with_log_entry original_field_expr BeginCall f_markings) - [ - tag_with_log_entry - (Expr.make_var param_var (Expr.with_ty m t_in)) - (VarDef (Marked.unmark t_in)) - (f_markings @ [Marked.mark (Expr.pos e) "input"]); - ] + (ListLabels.mapi (List.combine params_vars ts_in) + ~f:(fun i (param_var, t_in) -> + tag_with_log_entry + (Expr.make_var param_var (Expr.with_ty m t_in)) + (VarDef (Marked.unmark t_in)) + (f_markings + @ [ + Marked.mark (Expr.pos e) + ("input" ^ string_of_int i); + ]))) (Expr.with_ty m t_out)) (VarDef (Marked.unmark t_out)) (f_markings @ [Marked.mark (Expr.pos e) "output"])) EndCall f_markings) - [t_in] (Expr.pos e) + ts_in (Expr.pos e) | _ -> original_field_expr) (StructName.Map.find sc_sig.scope_sig_output_struct ctx.structs)) (Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e)) @@ -443,7 +450,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : | m -> tag_with_log_entry e1_func BeginCall m in let new_args = List.map (translate_expr ctx) args in - let input_typ, output_typ = + let input_typs, output_typ = (* NOTE: this is a temporary solution, it works because it's assume that all function calls are from scope variable. However, this will change -- for more information see @@ -452,8 +459,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : let _, typ, _ = ScopeVar.Map.find (Marked.unmark var) vars in match typ with | TArrow (marked_input_typ, marked_output_typ) -> - Marked.unmark marked_input_typ, Marked.unmark marked_output_typ - | _ -> TAny, TAny + ( List.map Marked.unmark marked_input_typ, + Marked.unmark marked_output_typ ) + | _ -> [TAny], TAny in match Marked.unmark f with | ELocation (ScopelangScopeVar var) -> @@ -467,21 +475,22 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : TopdefName.Map.find (Marked.unmark tvar) ctx.toplevel_vars in match typ with - | TArrow ((tin, _), (tout, _)) -> tin, tout + | TArrow (tin, (tout, _)) -> List.map Marked.unmark tin, tout | _ -> Errors.raise_spanned_error (Expr.pos e) "Application of non-function toplevel variable") - | _ -> TAny, TAny + | _ -> [TAny], TAny in let new_args = - match markings, new_args with - | (_ :: _ as m), [new_arg] -> - [ - tag_with_log_entry new_arg (VarDef input_typ) - (m @ [Marked.mark (Expr.pos e) "input"]); - ] - | _ -> new_args + ListLabels.mapi (List.combine new_args input_typs) + ~f:(fun i (new_arg, input_typ) -> + match markings with + | _ :: _ as m -> + tag_with_log_entry new_arg (VarDef input_typ) + (m @ [Marked.mark (Expr.pos e) ("input" ^ string_of_int i)]) + | _ -> new_arg) in + let new_e = Expr.eapp e1_func new_args m in let new_e = match markings with @@ -640,7 +649,7 @@ let translate_rule | OnlyInput -> tau | Reentrant -> if is_func then tau - else TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos); + else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos); scope_let_expr = thunked_or_nonempty_new_e; scope_let_kind = SubScopeVarDefinition; }) @@ -935,7 +944,7 @@ let translate_scope_decl match var_ctx.scope_var_typ with | TArrow _ -> var_ctx.scope_var_typ, pos_sigma | _ -> - ( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)), + ( TArrow ([TLit TUnit, pos_sigma], (var_ctx.scope_var_typ, pos_sigma)), pos_sigma )) | NoInput -> failwith "should not happen" in diff --git a/compiler/dcalc/interpreter.ml b/compiler/dcalc/interpreter.ml index e6260c04..a36472b4 100644 --- a/compiler/dcalc/interpreter.ml +++ b/compiler/dcalc/interpreter.ml @@ -522,9 +522,9 @@ let interpret_program : match Marked.unmark ty with | TArrow (ty_in, ty_out) -> Expr.make_abs - [| Var.make "_" |] + (Array.of_list @@ List.map (fun _ -> Var.make "_") ty_in) (Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out) - [ty_in] (Expr.mark_pos mark_e) + ty_in (Expr.mark_pos mark_e) | _ -> Errors.raise_spanned_error (Marked.get_mark ty) "This scope needs input arguments to be executed. But the Catala \ diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index aecc1efa..6ee04c66 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -129,8 +129,8 @@ let rec translate_typ (tau : typ) : typ = | TAny -> TAny | TArray ts -> TArray (translate_typ ts) (* catala is not polymorphic *) - | TArrow ((TLit TUnit, _), t2) -> TOption (translate_typ t2) - | TArrow (t1, t2) -> TArrow (translate_typ t1, translate_typ t2) + | TArrow ([(TLit TUnit, _)], t2) -> TOption (translate_typ t2) + | TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2) end (** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps. @@ -458,7 +458,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) : thunked, then the variable is context. If it's not thunked, it's a regular input. *) match Marked.unmark typ with - | TArrow ((TLit TUnit, _), _) -> false + | TArrow ([(TLit TUnit, _)], _) -> false | _ -> true) | ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope | DestructuringSubScopeResults | Assertion -> diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 21887866..1f7397df 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -167,8 +167,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit = format_enum_name Ast.option_enum | TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e) | TArrow (t1, t2) -> - Format.fprintf fmt "@[%a ->@ %a@]" format_typ_with_parens t1 - format_typ_with_parens t2 + Format.fprintf fmt "@[%a@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt " ->@ ") + format_typ_with_parens) + (t1 @ [t2]) | TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1 | TAny -> Format.fprintf fmt "_" diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 633e4c1b..32ba7cb0 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -186,8 +186,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit = Format.fprintf fmt "Optional[%a]" format_typ some_typ | TEnum e -> Format.fprintf fmt "%a" format_enum_name e | TArrow (t1, t2) -> - Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1 - format_typ_with_parens t2 + Format.fprintf fmt "Callable[[%a], %a]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + format_typ_with_parens) + t1 format_typ_with_parens t2 | TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1 | TAny -> Format.fprintf fmt "Any" diff --git a/compiler/verification/z3backend.real.ml b/compiler/verification/z3backend.real.ml index f1afe275..afe47766 100644 --- a/compiler/verification/z3backend.real.ml +++ b/compiler/verification/z3backend.real.ml @@ -405,10 +405,12 @@ let find_or_create_funcdecl (ctx : context) (v : typed expr Var.t) (ty : typ) : | None -> ( match Marked.unmark ty with | TArrow (t1, t2) -> - let ctx, z3_t1 = translate_typ ctx (Marked.unmark t1) in + let ctx, z3_t1 = + List.fold_left_map translate_typ ctx (List.map Marked.unmark t1) + in let ctx, z3_t2 = translate_typ ctx (Marked.unmark t2) in let name = unique_name v in - let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name [z3_t1] z3_t2 in + let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name z3_t1 z3_t2 in let ctx = add_funcdecl v fd ctx in let ctx = add_z3var name v ty ctx in ctx, fd