mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
finished refactoring
This commit is contained in:
parent
e519b7f146
commit
839a7ffd83
@ -352,33 +352,40 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
|||||||
field sc_sig.scope_sig_output_struct (Expr.with_ty m typ)
|
field sc_sig.scope_sig_output_struct (Expr.with_ty m typ)
|
||||||
in
|
in
|
||||||
match Marked.unmark typ with
|
match Marked.unmark typ with
|
||||||
| TArrow (t_in, t_out) ->
|
| TArrow (ts_in, t_out) ->
|
||||||
(* Here the output scope struct field is a function so we
|
(* Here the output scope struct field is a function so we
|
||||||
eta-expand it and insert logging instructions. Invariant:
|
eta-expand it and insert logging instructions. Invariant:
|
||||||
works because user-defined functions in scope have only one
|
works because user-defined functions in scope have only one
|
||||||
argument. *)
|
argument. *)
|
||||||
let param_var = Var.make "param" in
|
let params_vars =
|
||||||
|
ListLabels.mapi ts_in ~f:(fun i _ ->
|
||||||
|
Var.make ("param" ^ string_of_int i))
|
||||||
|
in
|
||||||
let f_markings =
|
let f_markings =
|
||||||
[ScopeName.get_info scope; StructField.get_info field]
|
[ScopeName.get_info scope; StructField.get_info field]
|
||||||
in
|
in
|
||||||
Expr.make_abs
|
Expr.make_abs
|
||||||
(Array.of_list [param_var])
|
(Array.of_list params_vars)
|
||||||
(tag_with_log_entry
|
(tag_with_log_entry
|
||||||
(tag_with_log_entry
|
(tag_with_log_entry
|
||||||
(Expr.eapp
|
(Expr.eapp
|
||||||
(tag_with_log_entry original_field_expr BeginCall
|
(tag_with_log_entry original_field_expr BeginCall
|
||||||
f_markings)
|
f_markings)
|
||||||
[
|
(ListLabels.mapi (List.combine params_vars ts_in)
|
||||||
tag_with_log_entry
|
~f:(fun i (param_var, t_in) ->
|
||||||
(Expr.make_var param_var (Expr.with_ty m t_in))
|
tag_with_log_entry
|
||||||
(VarDef (Marked.unmark t_in))
|
(Expr.make_var param_var (Expr.with_ty m t_in))
|
||||||
(f_markings @ [Marked.mark (Expr.pos e) "input"]);
|
(VarDef (Marked.unmark t_in))
|
||||||
]
|
(f_markings
|
||||||
|
@ [
|
||||||
|
Marked.mark (Expr.pos e)
|
||||||
|
("input" ^ string_of_int i);
|
||||||
|
])))
|
||||||
(Expr.with_ty m t_out))
|
(Expr.with_ty m t_out))
|
||||||
(VarDef (Marked.unmark t_out))
|
(VarDef (Marked.unmark t_out))
|
||||||
(f_markings @ [Marked.mark (Expr.pos e) "output"]))
|
(f_markings @ [Marked.mark (Expr.pos e) "output"]))
|
||||||
EndCall f_markings)
|
EndCall f_markings)
|
||||||
[t_in] (Expr.pos e)
|
ts_in (Expr.pos e)
|
||||||
| _ -> original_field_expr)
|
| _ -> original_field_expr)
|
||||||
(StructName.Map.find sc_sig.scope_sig_output_struct ctx.structs))
|
(StructName.Map.find sc_sig.scope_sig_output_struct ctx.structs))
|
||||||
(Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e))
|
(Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e))
|
||||||
@ -443,7 +450,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
|||||||
| m -> tag_with_log_entry e1_func BeginCall m
|
| m -> tag_with_log_entry e1_func BeginCall m
|
||||||
in
|
in
|
||||||
let new_args = List.map (translate_expr ctx) args in
|
let new_args = List.map (translate_expr ctx) args in
|
||||||
let input_typ, output_typ =
|
let input_typs, output_typ =
|
||||||
(* NOTE: this is a temporary solution, it works because it's assume that
|
(* NOTE: this is a temporary solution, it works because it's assume that
|
||||||
all function calls are from scope variable. However, this will change
|
all function calls are from scope variable. However, this will change
|
||||||
-- for more information see
|
-- for more information see
|
||||||
@ -452,8 +459,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
|||||||
let _, typ, _ = ScopeVar.Map.find (Marked.unmark var) vars in
|
let _, typ, _ = ScopeVar.Map.find (Marked.unmark var) vars in
|
||||||
match typ with
|
match typ with
|
||||||
| TArrow (marked_input_typ, marked_output_typ) ->
|
| TArrow (marked_input_typ, marked_output_typ) ->
|
||||||
Marked.unmark marked_input_typ, Marked.unmark marked_output_typ
|
( List.map Marked.unmark marked_input_typ,
|
||||||
| _ -> TAny, TAny
|
Marked.unmark marked_output_typ )
|
||||||
|
| _ -> [TAny], TAny
|
||||||
in
|
in
|
||||||
match Marked.unmark f with
|
match Marked.unmark f with
|
||||||
| ELocation (ScopelangScopeVar var) ->
|
| ELocation (ScopelangScopeVar var) ->
|
||||||
@ -467,21 +475,22 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
|||||||
TopdefName.Map.find (Marked.unmark tvar) ctx.toplevel_vars
|
TopdefName.Map.find (Marked.unmark tvar) ctx.toplevel_vars
|
||||||
in
|
in
|
||||||
match typ with
|
match typ with
|
||||||
| TArrow ((tin, _), (tout, _)) -> tin, tout
|
| TArrow (tin, (tout, _)) -> List.map Marked.unmark tin, tout
|
||||||
| _ ->
|
| _ ->
|
||||||
Errors.raise_spanned_error (Expr.pos e)
|
Errors.raise_spanned_error (Expr.pos e)
|
||||||
"Application of non-function toplevel variable")
|
"Application of non-function toplevel variable")
|
||||||
| _ -> TAny, TAny
|
| _ -> [TAny], TAny
|
||||||
in
|
in
|
||||||
let new_args =
|
let new_args =
|
||||||
match markings, new_args with
|
ListLabels.mapi (List.combine new_args input_typs)
|
||||||
| (_ :: _ as m), [new_arg] ->
|
~f:(fun i (new_arg, input_typ) ->
|
||||||
[
|
match markings with
|
||||||
tag_with_log_entry new_arg (VarDef input_typ)
|
| _ :: _ as m ->
|
||||||
(m @ [Marked.mark (Expr.pos e) "input"]);
|
tag_with_log_entry new_arg (VarDef input_typ)
|
||||||
]
|
(m @ [Marked.mark (Expr.pos e) ("input" ^ string_of_int i)])
|
||||||
| _ -> new_args
|
| _ -> new_arg)
|
||||||
in
|
in
|
||||||
|
|
||||||
let new_e = Expr.eapp e1_func new_args m in
|
let new_e = Expr.eapp e1_func new_args m in
|
||||||
let new_e =
|
let new_e =
|
||||||
match markings with
|
match markings with
|
||||||
@ -640,7 +649,7 @@ let translate_rule
|
|||||||
| OnlyInput -> tau
|
| OnlyInput -> tau
|
||||||
| Reentrant ->
|
| Reentrant ->
|
||||||
if is_func then tau
|
if is_func then tau
|
||||||
else TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos);
|
else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos);
|
||||||
scope_let_expr = thunked_or_nonempty_new_e;
|
scope_let_expr = thunked_or_nonempty_new_e;
|
||||||
scope_let_kind = SubScopeVarDefinition;
|
scope_let_kind = SubScopeVarDefinition;
|
||||||
})
|
})
|
||||||
@ -935,7 +944,7 @@ let translate_scope_decl
|
|||||||
match var_ctx.scope_var_typ with
|
match var_ctx.scope_var_typ with
|
||||||
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
|
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
|
||||||
| _ ->
|
| _ ->
|
||||||
( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)),
|
( TArrow ([TLit TUnit, pos_sigma], (var_ctx.scope_var_typ, pos_sigma)),
|
||||||
pos_sigma ))
|
pos_sigma ))
|
||||||
| NoInput -> failwith "should not happen"
|
| NoInput -> failwith "should not happen"
|
||||||
in
|
in
|
||||||
|
@ -522,9 +522,9 @@ let interpret_program :
|
|||||||
match Marked.unmark ty with
|
match Marked.unmark ty with
|
||||||
| TArrow (ty_in, ty_out) ->
|
| TArrow (ty_in, ty_out) ->
|
||||||
Expr.make_abs
|
Expr.make_abs
|
||||||
[| Var.make "_" |]
|
(Array.of_list @@ List.map (fun _ -> Var.make "_") ty_in)
|
||||||
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
|
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
|
||||||
[ty_in] (Expr.mark_pos mark_e)
|
ty_in (Expr.mark_pos mark_e)
|
||||||
| _ ->
|
| _ ->
|
||||||
Errors.raise_spanned_error (Marked.get_mark ty)
|
Errors.raise_spanned_error (Marked.get_mark ty)
|
||||||
"This scope needs input arguments to be executed. But the Catala \
|
"This scope needs input arguments to be executed. But the Catala \
|
||||||
|
@ -129,8 +129,8 @@ let rec translate_typ (tau : typ) : typ =
|
|||||||
| TAny -> TAny
|
| TAny -> TAny
|
||||||
| TArray ts -> TArray (translate_typ ts)
|
| TArray ts -> TArray (translate_typ ts)
|
||||||
(* catala is not polymorphic *)
|
(* catala is not polymorphic *)
|
||||||
| TArrow ((TLit TUnit, _), t2) -> TOption (translate_typ t2)
|
| TArrow ([(TLit TUnit, _)], t2) -> TOption (translate_typ t2)
|
||||||
| TArrow (t1, t2) -> TArrow (translate_typ t1, translate_typ t2)
|
| TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2)
|
||||||
end
|
end
|
||||||
|
|
||||||
(** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps.
|
(** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps.
|
||||||
@ -458,7 +458,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
|
|||||||
thunked, then the variable is context. If it's not thunked, it's a
|
thunked, then the variable is context. If it's not thunked, it's a
|
||||||
regular input. *)
|
regular input. *)
|
||||||
match Marked.unmark typ with
|
match Marked.unmark typ with
|
||||||
| TArrow ((TLit TUnit, _), _) -> false
|
| TArrow ([(TLit TUnit, _)], _) -> false
|
||||||
| _ -> true)
|
| _ -> true)
|
||||||
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
|
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
|
||||||
| DestructuringSubScopeResults | Assertion ->
|
| DestructuringSubScopeResults | Assertion ->
|
||||||
|
@ -167,8 +167,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
|
|||||||
format_enum_name Ast.option_enum
|
format_enum_name Ast.option_enum
|
||||||
| TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e)
|
| TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e)
|
||||||
| TArrow (t1, t2) ->
|
| TArrow (t1, t2) ->
|
||||||
Format.fprintf fmt "@[<hov 2>%a ->@ %a@]" format_typ_with_parens t1
|
Format.fprintf fmt "@[<hov 2>%a@]"
|
||||||
format_typ_with_parens t2
|
(Format.pp_print_list
|
||||||
|
~pp_sep:(fun fmt () -> Format.fprintf fmt " ->@ ")
|
||||||
|
format_typ_with_parens)
|
||||||
|
(t1 @ [t2])
|
||||||
| TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1
|
| TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1
|
||||||
| TAny -> Format.fprintf fmt "_"
|
| TAny -> Format.fprintf fmt "_"
|
||||||
|
|
||||||
|
@ -186,8 +186,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
|
|||||||
Format.fprintf fmt "Optional[%a]" format_typ some_typ
|
Format.fprintf fmt "Optional[%a]" format_typ some_typ
|
||||||
| TEnum e -> Format.fprintf fmt "%a" format_enum_name e
|
| TEnum e -> Format.fprintf fmt "%a" format_enum_name e
|
||||||
| TArrow (t1, t2) ->
|
| TArrow (t1, t2) ->
|
||||||
Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1
|
Format.fprintf fmt "Callable[[%a], %a]"
|
||||||
format_typ_with_parens t2
|
(Format.pp_print_list
|
||||||
|
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||||
|
format_typ_with_parens)
|
||||||
|
t1 format_typ_with_parens t2
|
||||||
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
|
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
|
||||||
| TAny -> Format.fprintf fmt "Any"
|
| TAny -> Format.fprintf fmt "Any"
|
||||||
|
|
||||||
|
@ -405,10 +405,12 @@ let find_or_create_funcdecl (ctx : context) (v : typed expr Var.t) (ty : typ) :
|
|||||||
| None -> (
|
| None -> (
|
||||||
match Marked.unmark ty with
|
match Marked.unmark ty with
|
||||||
| TArrow (t1, t2) ->
|
| TArrow (t1, t2) ->
|
||||||
let ctx, z3_t1 = translate_typ ctx (Marked.unmark t1) in
|
let ctx, z3_t1 =
|
||||||
|
List.fold_left_map translate_typ ctx (List.map Marked.unmark t1)
|
||||||
|
in
|
||||||
let ctx, z3_t2 = translate_typ ctx (Marked.unmark t2) in
|
let ctx, z3_t2 = translate_typ ctx (Marked.unmark t2) in
|
||||||
let name = unique_name v in
|
let name = unique_name v in
|
||||||
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name [z3_t1] z3_t2 in
|
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name z3_t1 z3_t2 in
|
||||||
let ctx = add_funcdecl v fd ctx in
|
let ctx = add_funcdecl v fd ctx in
|
||||||
let ctx = add_z3var name v ty ctx in
|
let ctx = add_z3var name v ty ctx in
|
||||||
ctx, fd
|
ctx, fd
|
||||||
|
Loading…
Reference in New Issue
Block a user