mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
finished refactoring
This commit is contained in:
parent
e519b7f146
commit
839a7ffd83
@ -352,33 +352,40 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
field sc_sig.scope_sig_output_struct (Expr.with_ty m typ)
|
||||
in
|
||||
match Marked.unmark typ with
|
||||
| TArrow (t_in, t_out) ->
|
||||
| TArrow (ts_in, t_out) ->
|
||||
(* Here the output scope struct field is a function so we
|
||||
eta-expand it and insert logging instructions. Invariant:
|
||||
works because user-defined functions in scope have only one
|
||||
argument. *)
|
||||
let param_var = Var.make "param" in
|
||||
let params_vars =
|
||||
ListLabels.mapi ts_in ~f:(fun i _ ->
|
||||
Var.make ("param" ^ string_of_int i))
|
||||
in
|
||||
let f_markings =
|
||||
[ScopeName.get_info scope; StructField.get_info field]
|
||||
in
|
||||
Expr.make_abs
|
||||
(Array.of_list [param_var])
|
||||
(Array.of_list params_vars)
|
||||
(tag_with_log_entry
|
||||
(tag_with_log_entry
|
||||
(Expr.eapp
|
||||
(tag_with_log_entry original_field_expr BeginCall
|
||||
f_markings)
|
||||
[
|
||||
(ListLabels.mapi (List.combine params_vars ts_in)
|
||||
~f:(fun i (param_var, t_in) ->
|
||||
tag_with_log_entry
|
||||
(Expr.make_var param_var (Expr.with_ty m t_in))
|
||||
(VarDef (Marked.unmark t_in))
|
||||
(f_markings @ [Marked.mark (Expr.pos e) "input"]);
|
||||
]
|
||||
(f_markings
|
||||
@ [
|
||||
Marked.mark (Expr.pos e)
|
||||
("input" ^ string_of_int i);
|
||||
])))
|
||||
(Expr.with_ty m t_out))
|
||||
(VarDef (Marked.unmark t_out))
|
||||
(f_markings @ [Marked.mark (Expr.pos e) "output"]))
|
||||
EndCall f_markings)
|
||||
[t_in] (Expr.pos e)
|
||||
ts_in (Expr.pos e)
|
||||
| _ -> original_field_expr)
|
||||
(StructName.Map.find sc_sig.scope_sig_output_struct ctx.structs))
|
||||
(Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e))
|
||||
@ -443,7 +450,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
| m -> tag_with_log_entry e1_func BeginCall m
|
||||
in
|
||||
let new_args = List.map (translate_expr ctx) args in
|
||||
let input_typ, output_typ =
|
||||
let input_typs, output_typ =
|
||||
(* NOTE: this is a temporary solution, it works because it's assume that
|
||||
all function calls are from scope variable. However, this will change
|
||||
-- for more information see
|
||||
@ -452,8 +459,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
let _, typ, _ = ScopeVar.Map.find (Marked.unmark var) vars in
|
||||
match typ with
|
||||
| TArrow (marked_input_typ, marked_output_typ) ->
|
||||
Marked.unmark marked_input_typ, Marked.unmark marked_output_typ
|
||||
| _ -> TAny, TAny
|
||||
( List.map Marked.unmark marked_input_typ,
|
||||
Marked.unmark marked_output_typ )
|
||||
| _ -> [TAny], TAny
|
||||
in
|
||||
match Marked.unmark f with
|
||||
| ELocation (ScopelangScopeVar var) ->
|
||||
@ -467,21 +475,22 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
TopdefName.Map.find (Marked.unmark tvar) ctx.toplevel_vars
|
||||
in
|
||||
match typ with
|
||||
| TArrow ((tin, _), (tout, _)) -> tin, tout
|
||||
| TArrow (tin, (tout, _)) -> List.map Marked.unmark tin, tout
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Expr.pos e)
|
||||
"Application of non-function toplevel variable")
|
||||
| _ -> TAny, TAny
|
||||
| _ -> [TAny], TAny
|
||||
in
|
||||
let new_args =
|
||||
match markings, new_args with
|
||||
| (_ :: _ as m), [new_arg] ->
|
||||
[
|
||||
ListLabels.mapi (List.combine new_args input_typs)
|
||||
~f:(fun i (new_arg, input_typ) ->
|
||||
match markings with
|
||||
| _ :: _ as m ->
|
||||
tag_with_log_entry new_arg (VarDef input_typ)
|
||||
(m @ [Marked.mark (Expr.pos e) "input"]);
|
||||
]
|
||||
| _ -> new_args
|
||||
(m @ [Marked.mark (Expr.pos e) ("input" ^ string_of_int i)])
|
||||
| _ -> new_arg)
|
||||
in
|
||||
|
||||
let new_e = Expr.eapp e1_func new_args m in
|
||||
let new_e =
|
||||
match markings with
|
||||
@ -640,7 +649,7 @@ let translate_rule
|
||||
| OnlyInput -> tau
|
||||
| Reentrant ->
|
||||
if is_func then tau
|
||||
else TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos);
|
||||
else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos);
|
||||
scope_let_expr = thunked_or_nonempty_new_e;
|
||||
scope_let_kind = SubScopeVarDefinition;
|
||||
})
|
||||
@ -935,7 +944,7 @@ let translate_scope_decl
|
||||
match var_ctx.scope_var_typ with
|
||||
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
|
||||
| _ ->
|
||||
( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)),
|
||||
( TArrow ([TLit TUnit, pos_sigma], (var_ctx.scope_var_typ, pos_sigma)),
|
||||
pos_sigma ))
|
||||
| NoInput -> failwith "should not happen"
|
||||
in
|
||||
|
@ -522,9 +522,9 @@ let interpret_program :
|
||||
match Marked.unmark ty with
|
||||
| TArrow (ty_in, ty_out) ->
|
||||
Expr.make_abs
|
||||
[| Var.make "_" |]
|
||||
(Array.of_list @@ List.map (fun _ -> Var.make "_") ty_in)
|
||||
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
|
||||
[ty_in] (Expr.mark_pos mark_e)
|
||||
ty_in (Expr.mark_pos mark_e)
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Marked.get_mark ty)
|
||||
"This scope needs input arguments to be executed. But the Catala \
|
||||
|
@ -129,8 +129,8 @@ let rec translate_typ (tau : typ) : typ =
|
||||
| TAny -> TAny
|
||||
| TArray ts -> TArray (translate_typ ts)
|
||||
(* catala is not polymorphic *)
|
||||
| TArrow ((TLit TUnit, _), t2) -> TOption (translate_typ t2)
|
||||
| TArrow (t1, t2) -> TArrow (translate_typ t1, translate_typ t2)
|
||||
| TArrow ([(TLit TUnit, _)], t2) -> TOption (translate_typ t2)
|
||||
| TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2)
|
||||
end
|
||||
|
||||
(** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps.
|
||||
@ -458,7 +458,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
|
||||
thunked, then the variable is context. If it's not thunked, it's a
|
||||
regular input. *)
|
||||
match Marked.unmark typ with
|
||||
| TArrow ((TLit TUnit, _), _) -> false
|
||||
| TArrow ([(TLit TUnit, _)], _) -> false
|
||||
| _ -> true)
|
||||
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
|
||||
| DestructuringSubScopeResults | Assertion ->
|
||||
|
@ -167,8 +167,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
|
||||
format_enum_name Ast.option_enum
|
||||
| TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e)
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a ->@ %a@]" format_typ_with_parens t1
|
||||
format_typ_with_parens t2
|
||||
Format.fprintf fmt "@[<hov 2>%a@]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt " ->@ ")
|
||||
format_typ_with_parens)
|
||||
(t1 @ [t2])
|
||||
| TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1
|
||||
| TAny -> Format.fprintf fmt "_"
|
||||
|
||||
|
@ -186,8 +186,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
|
||||
Format.fprintf fmt "Optional[%a]" format_typ some_typ
|
||||
| TEnum e -> Format.fprintf fmt "%a" format_enum_name e
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1
|
||||
format_typ_with_parens t2
|
||||
Format.fprintf fmt "Callable[[%a], %a]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
format_typ_with_parens)
|
||||
t1 format_typ_with_parens t2
|
||||
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
|
||||
| TAny -> Format.fprintf fmt "Any"
|
||||
|
||||
|
@ -405,10 +405,12 @@ let find_or_create_funcdecl (ctx : context) (v : typed expr Var.t) (ty : typ) :
|
||||
| None -> (
|
||||
match Marked.unmark ty with
|
||||
| TArrow (t1, t2) ->
|
||||
let ctx, z3_t1 = translate_typ ctx (Marked.unmark t1) in
|
||||
let ctx, z3_t1 =
|
||||
List.fold_left_map translate_typ ctx (List.map Marked.unmark t1)
|
||||
in
|
||||
let ctx, z3_t2 = translate_typ ctx (Marked.unmark t2) in
|
||||
let name = unique_name v in
|
||||
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name [z3_t1] z3_t2 in
|
||||
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name z3_t1 z3_t2 in
|
||||
let ctx = add_funcdecl v fd ctx in
|
||||
let ctx = add_z3var name v ty ctx in
|
||||
ctx, fd
|
||||
|
Loading…
Reference in New Issue
Block a user