finished refactoring

This commit is contained in:
adelaett 2023-02-20 17:58:29 +01:00
parent e519b7f146
commit 839a7ffd83
6 changed files with 52 additions and 35 deletions

View File

@ -352,33 +352,40 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
field sc_sig.scope_sig_output_struct (Expr.with_ty m typ) field sc_sig.scope_sig_output_struct (Expr.with_ty m typ)
in in
match Marked.unmark typ with match Marked.unmark typ with
| TArrow (t_in, t_out) -> | TArrow (ts_in, t_out) ->
(* Here the output scope struct field is a function so we (* Here the output scope struct field is a function so we
eta-expand it and insert logging instructions. Invariant: eta-expand it and insert logging instructions. Invariant:
works because user-defined functions in scope have only one works because user-defined functions in scope have only one
argument. *) argument. *)
let param_var = Var.make "param" in let params_vars =
ListLabels.mapi ts_in ~f:(fun i _ ->
Var.make ("param" ^ string_of_int i))
in
let f_markings = let f_markings =
[ScopeName.get_info scope; StructField.get_info field] [ScopeName.get_info scope; StructField.get_info field]
in in
Expr.make_abs Expr.make_abs
(Array.of_list [param_var]) (Array.of_list params_vars)
(tag_with_log_entry (tag_with_log_entry
(tag_with_log_entry (tag_with_log_entry
(Expr.eapp (Expr.eapp
(tag_with_log_entry original_field_expr BeginCall (tag_with_log_entry original_field_expr BeginCall
f_markings) f_markings)
[ (ListLabels.mapi (List.combine params_vars ts_in)
tag_with_log_entry ~f:(fun i (param_var, t_in) ->
(Expr.make_var param_var (Expr.with_ty m t_in)) tag_with_log_entry
(VarDef (Marked.unmark t_in)) (Expr.make_var param_var (Expr.with_ty m t_in))
(f_markings @ [Marked.mark (Expr.pos e) "input"]); (VarDef (Marked.unmark t_in))
] (f_markings
@ [
Marked.mark (Expr.pos e)
("input" ^ string_of_int i);
])))
(Expr.with_ty m t_out)) (Expr.with_ty m t_out))
(VarDef (Marked.unmark t_out)) (VarDef (Marked.unmark t_out))
(f_markings @ [Marked.mark (Expr.pos e) "output"])) (f_markings @ [Marked.mark (Expr.pos e) "output"]))
EndCall f_markings) EndCall f_markings)
[t_in] (Expr.pos e) ts_in (Expr.pos e)
| _ -> original_field_expr) | _ -> original_field_expr)
(StructName.Map.find sc_sig.scope_sig_output_struct ctx.structs)) (StructName.Map.find sc_sig.scope_sig_output_struct ctx.structs))
(Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e)) (Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e))
@ -443,7 +450,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
| m -> tag_with_log_entry e1_func BeginCall m | m -> tag_with_log_entry e1_func BeginCall m
in in
let new_args = List.map (translate_expr ctx) args in let new_args = List.map (translate_expr ctx) args in
let input_typ, output_typ = let input_typs, output_typ =
(* NOTE: this is a temporary solution, it works because it's assume that (* NOTE: this is a temporary solution, it works because it's assume that
all function calls are from scope variable. However, this will change all function calls are from scope variable. However, this will change
-- for more information see -- for more information see
@ -452,8 +459,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
let _, typ, _ = ScopeVar.Map.find (Marked.unmark var) vars in let _, typ, _ = ScopeVar.Map.find (Marked.unmark var) vars in
match typ with match typ with
| TArrow (marked_input_typ, marked_output_typ) -> | TArrow (marked_input_typ, marked_output_typ) ->
Marked.unmark marked_input_typ, Marked.unmark marked_output_typ ( List.map Marked.unmark marked_input_typ,
| _ -> TAny, TAny Marked.unmark marked_output_typ )
| _ -> [TAny], TAny
in in
match Marked.unmark f with match Marked.unmark f with
| ELocation (ScopelangScopeVar var) -> | ELocation (ScopelangScopeVar var) ->
@ -467,21 +475,22 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
TopdefName.Map.find (Marked.unmark tvar) ctx.toplevel_vars TopdefName.Map.find (Marked.unmark tvar) ctx.toplevel_vars
in in
match typ with match typ with
| TArrow ((tin, _), (tout, _)) -> tin, tout | TArrow (tin, (tout, _)) -> List.map Marked.unmark tin, tout
| _ -> | _ ->
Errors.raise_spanned_error (Expr.pos e) Errors.raise_spanned_error (Expr.pos e)
"Application of non-function toplevel variable") "Application of non-function toplevel variable")
| _ -> TAny, TAny | _ -> [TAny], TAny
in in
let new_args = let new_args =
match markings, new_args with ListLabels.mapi (List.combine new_args input_typs)
| (_ :: _ as m), [new_arg] -> ~f:(fun i (new_arg, input_typ) ->
[ match markings with
tag_with_log_entry new_arg (VarDef input_typ) | _ :: _ as m ->
(m @ [Marked.mark (Expr.pos e) "input"]); tag_with_log_entry new_arg (VarDef input_typ)
] (m @ [Marked.mark (Expr.pos e) ("input" ^ string_of_int i)])
| _ -> new_args | _ -> new_arg)
in in
let new_e = Expr.eapp e1_func new_args m in let new_e = Expr.eapp e1_func new_args m in
let new_e = let new_e =
match markings with match markings with
@ -640,7 +649,7 @@ let translate_rule
| OnlyInput -> tau | OnlyInput -> tau
| Reentrant -> | Reentrant ->
if is_func then tau if is_func then tau
else TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos); else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos);
scope_let_expr = thunked_or_nonempty_new_e; scope_let_expr = thunked_or_nonempty_new_e;
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
}) })
@ -935,7 +944,7 @@ let translate_scope_decl
match var_ctx.scope_var_typ with match var_ctx.scope_var_typ with
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma | TArrow _ -> var_ctx.scope_var_typ, pos_sigma
| _ -> | _ ->
( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)), ( TArrow ([TLit TUnit, pos_sigma], (var_ctx.scope_var_typ, pos_sigma)),
pos_sigma )) pos_sigma ))
| NoInput -> failwith "should not happen" | NoInput -> failwith "should not happen"
in in

View File

@ -522,9 +522,9 @@ let interpret_program :
match Marked.unmark ty with match Marked.unmark ty with
| TArrow (ty_in, ty_out) -> | TArrow (ty_in, ty_out) ->
Expr.make_abs Expr.make_abs
[| Var.make "_" |] (Array.of_list @@ List.map (fun _ -> Var.make "_") ty_in)
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out) (Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
[ty_in] (Expr.mark_pos mark_e) ty_in (Expr.mark_pos mark_e)
| _ -> | _ ->
Errors.raise_spanned_error (Marked.get_mark ty) Errors.raise_spanned_error (Marked.get_mark ty)
"This scope needs input arguments to be executed. But the Catala \ "This scope needs input arguments to be executed. But the Catala \

View File

@ -129,8 +129,8 @@ let rec translate_typ (tau : typ) : typ =
| TAny -> TAny | TAny -> TAny
| TArray ts -> TArray (translate_typ ts) | TArray ts -> TArray (translate_typ ts)
(* catala is not polymorphic *) (* catala is not polymorphic *)
| TArrow ((TLit TUnit, _), t2) -> TOption (translate_typ t2) | TArrow ([(TLit TUnit, _)], t2) -> TOption (translate_typ t2)
| TArrow (t1, t2) -> TArrow (translate_typ t1, translate_typ t2) | TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2)
end end
(** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps. (** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps.
@ -458,7 +458,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
thunked, then the variable is context. If it's not thunked, it's a thunked, then the variable is context. If it's not thunked, it's a
regular input. *) regular input. *)
match Marked.unmark typ with match Marked.unmark typ with
| TArrow ((TLit TUnit, _), _) -> false | TArrow ([(TLit TUnit, _)], _) -> false
| _ -> true) | _ -> true)
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope | ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
| DestructuringSubScopeResults | Assertion -> | DestructuringSubScopeResults | Assertion ->

View File

@ -167,8 +167,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
format_enum_name Ast.option_enum format_enum_name Ast.option_enum
| TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e) | TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e)
| TArrow (t1, t2) -> | TArrow (t1, t2) ->
Format.fprintf fmt "@[<hov 2>%a ->@ %a@]" format_typ_with_parens t1 Format.fprintf fmt "@[<hov 2>%a@]"
format_typ_with_parens t2 (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " ->@ ")
format_typ_with_parens)
(t1 @ [t2])
| TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1 | TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1
| TAny -> Format.fprintf fmt "_" | TAny -> Format.fprintf fmt "_"

View File

@ -186,8 +186,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
Format.fprintf fmt "Optional[%a]" format_typ some_typ Format.fprintf fmt "Optional[%a]" format_typ some_typ
| TEnum e -> Format.fprintf fmt "%a" format_enum_name e | TEnum e -> Format.fprintf fmt "%a" format_enum_name e
| TArrow (t1, t2) -> | TArrow (t1, t2) ->
Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1 Format.fprintf fmt "Callable[[%a], %a]"
format_typ_with_parens t2 (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
format_typ_with_parens)
t1 format_typ_with_parens t2
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1 | TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
| TAny -> Format.fprintf fmt "Any" | TAny -> Format.fprintf fmt "Any"

View File

@ -405,10 +405,12 @@ let find_or_create_funcdecl (ctx : context) (v : typed expr Var.t) (ty : typ) :
| None -> ( | None -> (
match Marked.unmark ty with match Marked.unmark ty with
| TArrow (t1, t2) -> | TArrow (t1, t2) ->
let ctx, z3_t1 = translate_typ ctx (Marked.unmark t1) in let ctx, z3_t1 =
List.fold_left_map translate_typ ctx (List.map Marked.unmark t1)
in
let ctx, z3_t2 = translate_typ ctx (Marked.unmark t2) in let ctx, z3_t2 = translate_typ ctx (Marked.unmark t2) in
let name = unique_name v in let name = unique_name v in
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name [z3_t1] z3_t2 in let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name z3_t1 z3_t2 in
let ctx = add_funcdecl v fd ctx in let ctx = add_funcdecl v fd ctx in
let ctx = add_z3var name v ty ctx in let ctx = add_z3var name v ty ctx in
ctx, fd ctx, fd