finished refactoring

This commit is contained in:
adelaett 2023-02-20 17:58:29 +01:00
parent e519b7f146
commit 839a7ffd83
6 changed files with 52 additions and 35 deletions

View File

@ -352,33 +352,40 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
field sc_sig.scope_sig_output_struct (Expr.with_ty m typ)
in
match Marked.unmark typ with
| TArrow (t_in, t_out) ->
| TArrow (ts_in, t_out) ->
(* Here the output scope struct field is a function so we
eta-expand it and insert logging instructions. Invariant:
works because user-defined functions in scope have only one
argument. *)
let param_var = Var.make "param" in
let params_vars =
ListLabels.mapi ts_in ~f:(fun i _ ->
Var.make ("param" ^ string_of_int i))
in
let f_markings =
[ScopeName.get_info scope; StructField.get_info field]
in
Expr.make_abs
(Array.of_list [param_var])
(Array.of_list params_vars)
(tag_with_log_entry
(tag_with_log_entry
(Expr.eapp
(tag_with_log_entry original_field_expr BeginCall
f_markings)
[
tag_with_log_entry
(Expr.make_var param_var (Expr.with_ty m t_in))
(VarDef (Marked.unmark t_in))
(f_markings @ [Marked.mark (Expr.pos e) "input"]);
]
(ListLabels.mapi (List.combine params_vars ts_in)
~f:(fun i (param_var, t_in) ->
tag_with_log_entry
(Expr.make_var param_var (Expr.with_ty m t_in))
(VarDef (Marked.unmark t_in))
(f_markings
@ [
Marked.mark (Expr.pos e)
("input" ^ string_of_int i);
])))
(Expr.with_ty m t_out))
(VarDef (Marked.unmark t_out))
(f_markings @ [Marked.mark (Expr.pos e) "output"]))
EndCall f_markings)
[t_in] (Expr.pos e)
ts_in (Expr.pos e)
| _ -> original_field_expr)
(StructName.Map.find sc_sig.scope_sig_output_struct ctx.structs))
(Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e))
@ -443,7 +450,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
| m -> tag_with_log_entry e1_func BeginCall m
in
let new_args = List.map (translate_expr ctx) args in
let input_typ, output_typ =
let input_typs, output_typ =
(* NOTE: this is a temporary solution, it works because it's assume that
all function calls are from scope variable. However, this will change
-- for more information see
@ -452,8 +459,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
let _, typ, _ = ScopeVar.Map.find (Marked.unmark var) vars in
match typ with
| TArrow (marked_input_typ, marked_output_typ) ->
Marked.unmark marked_input_typ, Marked.unmark marked_output_typ
| _ -> TAny, TAny
( List.map Marked.unmark marked_input_typ,
Marked.unmark marked_output_typ )
| _ -> [TAny], TAny
in
match Marked.unmark f with
| ELocation (ScopelangScopeVar var) ->
@ -467,21 +475,22 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
TopdefName.Map.find (Marked.unmark tvar) ctx.toplevel_vars
in
match typ with
| TArrow ((tin, _), (tout, _)) -> tin, tout
| TArrow (tin, (tout, _)) -> List.map Marked.unmark tin, tout
| _ ->
Errors.raise_spanned_error (Expr.pos e)
"Application of non-function toplevel variable")
| _ -> TAny, TAny
| _ -> [TAny], TAny
in
let new_args =
match markings, new_args with
| (_ :: _ as m), [new_arg] ->
[
tag_with_log_entry new_arg (VarDef input_typ)
(m @ [Marked.mark (Expr.pos e) "input"]);
]
| _ -> new_args
ListLabels.mapi (List.combine new_args input_typs)
~f:(fun i (new_arg, input_typ) ->
match markings with
| _ :: _ as m ->
tag_with_log_entry new_arg (VarDef input_typ)
(m @ [Marked.mark (Expr.pos e) ("input" ^ string_of_int i)])
| _ -> new_arg)
in
let new_e = Expr.eapp e1_func new_args m in
let new_e =
match markings with
@ -640,7 +649,7 @@ let translate_rule
| OnlyInput -> tau
| Reentrant ->
if is_func then tau
else TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos);
else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos);
scope_let_expr = thunked_or_nonempty_new_e;
scope_let_kind = SubScopeVarDefinition;
})
@ -935,7 +944,7 @@ let translate_scope_decl
match var_ctx.scope_var_typ with
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
| _ ->
( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)),
( TArrow ([TLit TUnit, pos_sigma], (var_ctx.scope_var_typ, pos_sigma)),
pos_sigma ))
| NoInput -> failwith "should not happen"
in

View File

@ -522,9 +522,9 @@ let interpret_program :
match Marked.unmark ty with
| TArrow (ty_in, ty_out) ->
Expr.make_abs
[| Var.make "_" |]
(Array.of_list @@ List.map (fun _ -> Var.make "_") ty_in)
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
[ty_in] (Expr.mark_pos mark_e)
ty_in (Expr.mark_pos mark_e)
| _ ->
Errors.raise_spanned_error (Marked.get_mark ty)
"This scope needs input arguments to be executed. But the Catala \

View File

@ -129,8 +129,8 @@ let rec translate_typ (tau : typ) : typ =
| TAny -> TAny
| TArray ts -> TArray (translate_typ ts)
(* catala is not polymorphic *)
| TArrow ((TLit TUnit, _), t2) -> TOption (translate_typ t2)
| TArrow (t1, t2) -> TArrow (translate_typ t1, translate_typ t2)
| TArrow ([(TLit TUnit, _)], t2) -> TOption (translate_typ t2)
| TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2)
end
(** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps.
@ -458,7 +458,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
thunked, then the variable is context. If it's not thunked, it's a
regular input. *)
match Marked.unmark typ with
| TArrow ((TLit TUnit, _), _) -> false
| TArrow ([(TLit TUnit, _)], _) -> false
| _ -> true)
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
| DestructuringSubScopeResults | Assertion ->

View File

@ -167,8 +167,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
format_enum_name Ast.option_enum
| TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e)
| TArrow (t1, t2) ->
Format.fprintf fmt "@[<hov 2>%a ->@ %a@]" format_typ_with_parens t1
format_typ_with_parens t2
Format.fprintf fmt "@[<hov 2>%a@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " ->@ ")
format_typ_with_parens)
(t1 @ [t2])
| TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1
| TAny -> Format.fprintf fmt "_"

View File

@ -186,8 +186,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
Format.fprintf fmt "Optional[%a]" format_typ some_typ
| TEnum e -> Format.fprintf fmt "%a" format_enum_name e
| TArrow (t1, t2) ->
Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1
format_typ_with_parens t2
Format.fprintf fmt "Callable[[%a], %a]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
format_typ_with_parens)
t1 format_typ_with_parens t2
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
| TAny -> Format.fprintf fmt "Any"

View File

@ -405,10 +405,12 @@ let find_or_create_funcdecl (ctx : context) (v : typed expr Var.t) (ty : typ) :
| None -> (
match Marked.unmark ty with
| TArrow (t1, t2) ->
let ctx, z3_t1 = translate_typ ctx (Marked.unmark t1) in
let ctx, z3_t1 =
List.fold_left_map translate_typ ctx (List.map Marked.unmark t1)
in
let ctx, z3_t2 = translate_typ ctx (Marked.unmark t2) in
let name = unique_name v in
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name [z3_t1] z3_t2 in
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name z3_t1 z3_t2 in
let ctx = add_funcdecl v fd ctx in
let ctx = add_z3var name v ty ctx in
ctx, fd