mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
x10 performance on Catala compilation & interpretation
Cleaner rewriting of main let-binding chaining procedure from Scopelang to Dcalc Removed costly unboxing in DCalc.Ast.make_let_in seemed to do the trick
This commit is contained in:
parent
a271d96b3a
commit
2c0e8a7864
@ -151,27 +151,13 @@ let make_app (e : expr Pos.marked Bindlib.box) (u : expr Pos.marked Bindlib.box
|
||||
Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u)
|
||||
|
||||
let make_let_in (x : Var.t) (tau : typ Pos.marked) (e1 : expr Pos.marked Bindlib.box)
|
||||
(e2 : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box =
|
||||
Bindlib.box_apply2
|
||||
(fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2)))
|
||||
(make_abs
|
||||
(Array.of_list [ x ])
|
||||
e2
|
||||
(Pos.get_position (Bindlib.unbox e2))
|
||||
[ tau ]
|
||||
(Pos.get_position (Bindlib.unbox e2)))
|
||||
(Bindlib.box_list [ e1 ])
|
||||
(e2 : expr Pos.marked Bindlib.box) (pos : Pos.t) : expr Pos.marked Bindlib.box =
|
||||
make_app (make_abs (Array.of_list [ x ]) e2 pos [ tau ] pos) [ e1 ] pos
|
||||
|
||||
let make_multiple_let_in (xs : Var.t array) (taus : typ Pos.marked list)
|
||||
(e1 : expr Pos.marked list Bindlib.box) (e2 : expr Pos.marked Bindlib.box) :
|
||||
(e1 : expr Pos.marked Bindlib.box list) (e2 : expr Pos.marked Bindlib.box) (pos : Pos.t) :
|
||||
expr Pos.marked Bindlib.box =
|
||||
Bindlib.box_apply2
|
||||
(fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2)))
|
||||
(make_abs xs e2
|
||||
(Pos.get_position (Bindlib.unbox e2))
|
||||
taus
|
||||
(Pos.get_position (Bindlib.unbox e2)))
|
||||
e1
|
||||
make_app (make_abs xs e2 pos taus pos) e1 pos
|
||||
|
||||
type binder = (expr, expr Pos.marked) Bindlib.binder
|
||||
|
||||
|
@ -166,13 +166,15 @@ val make_let_in :
|
||||
typ Pos.marked ->
|
||||
expr Pos.marked Bindlib.box ->
|
||||
expr Pos.marked Bindlib.box ->
|
||||
Pos.t ->
|
||||
expr Pos.marked Bindlib.box
|
||||
|
||||
val make_multiple_let_in :
|
||||
Var.t array ->
|
||||
typ Pos.marked list ->
|
||||
expr Pos.marked list Bindlib.box ->
|
||||
expr Pos.marked Bindlib.box list ->
|
||||
expr Pos.marked Bindlib.box ->
|
||||
Pos.t ->
|
||||
expr Pos.marked Bindlib.box
|
||||
|
||||
type binder = (expr, expr Pos.marked) Bindlib.binder
|
||||
|
@ -161,7 +161,7 @@ let needs_parens (e : expr Pos.marked) : bool =
|
||||
match Pos.unmark e with EAbs _ | ETuple (_, Some _) -> true | _ -> false
|
||||
|
||||
let format_var (fmt : Format.formatter) (v : Var.t) : unit =
|
||||
Format.fprintf fmt "%s" (Bindlib.name_of v)
|
||||
Format.fprintf fmt "%s_%d" (Bindlib.name_of v) (Bindlib.uid_of v)
|
||||
|
||||
let rec format_expr (ctx : Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) : unit =
|
||||
let format_expr = format_expr ctx in
|
||||
|
@ -281,23 +281,21 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
|
||||
(fun es -> Dcalc.Ast.EArray es)
|
||||
(Bindlib.box_list (List.map (translate_expr ctx) es)))
|
||||
|
||||
let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
|
||||
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info)
|
||||
(sigma_return_struct_name : Ast.StructName.t) : Dcalc.Ast.expr Pos.marked Bindlib.box * ctx =
|
||||
(** The result of a rule translation is a list of assignment, with variables and expressions. We
|
||||
also return the new translation context available after the assignment to use in later rule
|
||||
translations. The list is actually a list of list because we want to group in assignments that
|
||||
are independent of each other to speed up the translation by minimizing Bindlib.bind_mvar *)
|
||||
let translate_rule (ctx : ctx) (rule : Ast.rule)
|
||||
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) :
|
||||
Dcalc.Ast.Var.t list Pos.marked list
|
||||
* Dcalc.Ast.typ Pos.marked list list
|
||||
* Dcalc.Ast.expr Pos.marked Bindlib.box list list
|
||||
* ctx =
|
||||
match rule with
|
||||
| Definition ((ScopeVar a, var_def_pos), tau, e) ->
|
||||
let a_name = Ast.ScopeVar.get_info (Pos.unmark a) in
|
||||
let a_var = Dcalc.Ast.Var.make a_name in
|
||||
let tau = translate_typ ctx tau in
|
||||
let new_ctx =
|
||||
{
|
||||
ctx with
|
||||
scope_vars = Ast.ScopeVarMap.add (Pos.unmark a) (a_var, Pos.unmark tau) ctx.scope_vars;
|
||||
}
|
||||
in
|
||||
let next_e, new_ctx =
|
||||
translate_rules new_ctx rest (sigma_name, pos_sigma) sigma_return_struct_name
|
||||
in
|
||||
let new_e = translate_expr ctx e in
|
||||
let a_expr = Dcalc.Ast.make_var (a_var, var_def_pos) in
|
||||
let merged_expr =
|
||||
@ -310,46 +308,21 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
|
||||
(Dcalc.Ast.VarDef (Pos.unmark tau))
|
||||
[ (sigma_name, pos_sigma); a_name ]
|
||||
in
|
||||
|
||||
let next_e = Dcalc.Ast.make_let_in a_var tau merged_expr next_e in
|
||||
(next_e, new_ctx)
|
||||
( [ ([ a_var ], Pos.get_position a) ],
|
||||
[ [ tau ] ],
|
||||
[ [ merged_expr ] ],
|
||||
{
|
||||
ctx with
|
||||
scope_vars = Ast.ScopeVarMap.add (Pos.unmark a) (a_var, Pos.unmark tau) ctx.scope_vars;
|
||||
} )
|
||||
| Definition ((SubScopeVar (_subs_name, subs_index, subs_var), var_def_pos), tau, e) ->
|
||||
let a_name =
|
||||
Pos.map_under_mark
|
||||
(fun str -> str ^ "." ^ Pos.unmark (Ast.ScopeVar.get_info (Pos.unmark subs_var)))
|
||||
(Ast.SubScopeName.get_info (Pos.unmark subs_index))
|
||||
in
|
||||
let a_var = (Dcalc.Ast.Var.make a_name, var_def_pos) in
|
||||
let a_var = Dcalc.Ast.Var.make a_name in
|
||||
let tau = translate_typ ctx tau in
|
||||
let new_ctx =
|
||||
{
|
||||
ctx with
|
||||
subscope_vars =
|
||||
Ast.SubScopeMap.update (Pos.unmark subs_index)
|
||||
(fun map ->
|
||||
match map with
|
||||
| Some map ->
|
||||
Some
|
||||
(Ast.ScopeVarMap.add (Pos.unmark subs_var)
|
||||
(Pos.unmark a_var, Pos.unmark tau)
|
||||
map)
|
||||
| None ->
|
||||
Some
|
||||
(Ast.ScopeVarMap.singleton (Pos.unmark subs_var)
|
||||
(Pos.unmark a_var, Pos.unmark tau)))
|
||||
ctx.subscope_vars;
|
||||
}
|
||||
in
|
||||
let next_e, new_ctx =
|
||||
translate_rules new_ctx rest (sigma_name, pos_sigma) sigma_return_struct_name
|
||||
in
|
||||
let intermediate_e =
|
||||
Dcalc.Ast.make_abs
|
||||
(Array.of_list [ Pos.unmark a_var ])
|
||||
next_e var_def_pos
|
||||
[ (Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos) ]
|
||||
(Pos.get_position e)
|
||||
in
|
||||
let new_e =
|
||||
tag_with_log_entry (translate_expr ctx e)
|
||||
(Dcalc.Ast.VarDef (Pos.unmark tau))
|
||||
@ -363,8 +336,21 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
|
||||
[ (Dcalc.Ast.TLit TUnit, var_def_pos) ]
|
||||
var_def_pos
|
||||
in
|
||||
let out_e = Dcalc.Ast.make_app intermediate_e [ thunked_new_e ] (Pos.get_position e) in
|
||||
(out_e, new_ctx)
|
||||
( [ ([ a_var ], Pos.get_position a_name) ],
|
||||
[ [ (Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos) ] ],
|
||||
[ [ thunked_new_e ] ],
|
||||
{
|
||||
ctx with
|
||||
subscope_vars =
|
||||
Ast.SubScopeMap.update (Pos.unmark subs_index)
|
||||
(fun map ->
|
||||
match map with
|
||||
| Some map ->
|
||||
Some (Ast.ScopeVarMap.add (Pos.unmark subs_var) (a_var, Pos.unmark tau) map)
|
||||
| None ->
|
||||
Some (Ast.ScopeVarMap.singleton (Pos.unmark subs_var) (a_var, Pos.unmark tau)))
|
||||
ctx.subscope_vars;
|
||||
} )
|
||||
| Call (subname, subindex) ->
|
||||
let ( all_subscope_vars,
|
||||
scope_dcalc_var,
|
||||
@ -409,17 +395,6 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
|
||||
(subvar, tau, sub_dcalc_var))
|
||||
all_subscope_vars
|
||||
in
|
||||
let new_ctx =
|
||||
{
|
||||
ctx with
|
||||
subscope_vars =
|
||||
Ast.SubScopeMap.add subindex
|
||||
(List.fold_left
|
||||
(fun acc (var, tau, dvar) -> Ast.ScopeVarMap.add var (dvar, tau) acc)
|
||||
Ast.ScopeVarMap.empty all_subscope_vars_dcalc)
|
||||
ctx.subscope_vars;
|
||||
}
|
||||
in
|
||||
let subscope_func =
|
||||
tag_with_log_entry
|
||||
(Dcalc.Ast.make_var
|
||||
@ -443,28 +418,24 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
|
||||
Ast.ScopeName.get_info subname;
|
||||
]
|
||||
in
|
||||
let result_tuple_var = Dcalc.Ast.Var.make ("result", Pos.no_pos) in
|
||||
let next_e, new_ctx =
|
||||
translate_rules new_ctx rest (sigma_name, pos_sigma) sigma_return_struct_name
|
||||
let result_tuple_var = Dcalc.Ast.Var.make ("result", pos_sigma) in
|
||||
let results_bindings_xs = List.map (fun (_, _, v) -> v) all_subscope_vars_dcalc in
|
||||
let results_bindings_taus =
|
||||
List.map (fun (_, tau, _) -> (tau, pos_sigma)) all_subscope_vars_dcalc
|
||||
in
|
||||
let results_bindings =
|
||||
let xs = Array.of_list (List.map (fun (_, _, v) -> v) all_subscope_vars_dcalc) in
|
||||
let taus = List.map (fun (_, tau, _) -> (tau, pos_sigma)) all_subscope_vars_dcalc in
|
||||
let e1s =
|
||||
List.mapi
|
||||
(fun i _ ->
|
||||
Bindlib.box_apply
|
||||
(fun r ->
|
||||
( Dcalc.Ast.ETupleAccess
|
||||
( r,
|
||||
i,
|
||||
Some called_scope_return_struct,
|
||||
List.map (fun (_, t, _) -> (t, pos_sigma)) all_subscope_vars_dcalc ),
|
||||
pos_sigma ))
|
||||
(Dcalc.Ast.make_var (result_tuple_var, pos_sigma)))
|
||||
all_subscope_vars_dcalc
|
||||
in
|
||||
Dcalc.Ast.make_multiple_let_in xs taus (Bindlib.box_list e1s) next_e
|
||||
let results_bindings_e1s =
|
||||
List.mapi
|
||||
(fun i _ ->
|
||||
Bindlib.box_apply
|
||||
(fun r ->
|
||||
( Dcalc.Ast.ETupleAccess
|
||||
( r,
|
||||
i,
|
||||
Some called_scope_return_struct,
|
||||
List.map (fun (_, t, _) -> (t, pos_sigma)) all_subscope_vars_dcalc ),
|
||||
pos_sigma ))
|
||||
(Dcalc.Ast.make_var (result_tuple_var, pos_sigma)))
|
||||
all_subscope_vars_dcalc
|
||||
in
|
||||
let result_tuple_typ =
|
||||
( Dcalc.Ast.TTuple
|
||||
@ -472,35 +443,54 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
|
||||
Some called_scope_return_struct ),
|
||||
pos_sigma )
|
||||
in
|
||||
(Dcalc.Ast.make_let_in result_tuple_var result_tuple_typ call_expr results_bindings, new_ctx)
|
||||
( [ ([ result_tuple_var ], pos_sigma); (results_bindings_xs, pos_sigma) ],
|
||||
[ [ result_tuple_typ ]; results_bindings_taus ],
|
||||
[ [ call_expr ]; results_bindings_e1s ],
|
||||
{
|
||||
ctx with
|
||||
subscope_vars =
|
||||
Ast.SubScopeMap.add subindex
|
||||
(List.fold_left
|
||||
(fun acc (var, tau, dvar) -> Ast.ScopeVarMap.add var (dvar, tau) acc)
|
||||
Ast.ScopeVarMap.empty all_subscope_vars_dcalc)
|
||||
ctx.subscope_vars;
|
||||
} )
|
||||
| Assertion e ->
|
||||
let next_e, new_ctx =
|
||||
translate_rules ctx rest (sigma_name, pos_sigma) sigma_return_struct_name
|
||||
in
|
||||
let new_e = translate_expr ctx e in
|
||||
( Dcalc.Ast.make_let_in
|
||||
(Dcalc.Ast.Var.make ("_", Pos.no_pos))
|
||||
(Dcalc.Ast.TLit TUnit, Pos.no_pos)
|
||||
(Bindlib.box_apply (fun new_e -> Pos.same_pos_as (Dcalc.Ast.EAssert new_e) e) new_e)
|
||||
next_e,
|
||||
new_ctx )
|
||||
( [ ([ Dcalc.Ast.Var.make ("_", Pos.get_position e) ], Pos.get_position e) ],
|
||||
[ [ (Dcalc.Ast.TLit TUnit, Pos.get_position e) ] ],
|
||||
[ [ Bindlib.box_apply (fun new_e -> Pos.same_pos_as (Dcalc.Ast.EAssert new_e) e) new_e ] ],
|
||||
ctx )
|
||||
|
||||
and translate_rules (ctx : ctx) (rules : Ast.rule list)
|
||||
let translate_rules (ctx : ctx) (rules : Ast.rule list)
|
||||
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info)
|
||||
(sigma_return_struct_name : Ast.StructName.t) : Dcalc.Ast.expr Pos.marked Bindlib.box * ctx =
|
||||
match rules with
|
||||
| [] ->
|
||||
let scope_variables = Ast.ScopeVarMap.bindings ctx.scope_vars in
|
||||
let return_exp =
|
||||
Bindlib.box_apply
|
||||
(fun args -> (Dcalc.Ast.ETuple (args, Some sigma_return_struct_name), pos_sigma))
|
||||
(Bindlib.box_list
|
||||
(List.map
|
||||
(fun (_, (dcalc_var, _)) -> Dcalc.Ast.make_var (dcalc_var, pos_sigma))
|
||||
scope_variables))
|
||||
in
|
||||
(return_exp, ctx)
|
||||
| hd :: tl -> translate_rule ctx hd tl (sigma_name, pos_sigma) sigma_return_struct_name
|
||||
let vars, taus, exprs, new_ctx =
|
||||
List.fold_left
|
||||
(fun (vars, taus, exprs, ctx) rule ->
|
||||
let new_vars, new_taus, new_exprs, new_ctx =
|
||||
translate_rule ctx rule (sigma_name, pos_sigma)
|
||||
in
|
||||
(vars @ new_vars, taus @ new_taus, exprs @ new_exprs, new_ctx))
|
||||
([], [], [], ctx) rules
|
||||
in
|
||||
let scope_variables = Ast.ScopeVarMap.bindings new_ctx.scope_vars in
|
||||
let return_exp =
|
||||
Bindlib.box_apply
|
||||
(fun args -> (Dcalc.Ast.ETuple (args, Some sigma_return_struct_name), pos_sigma))
|
||||
(Bindlib.box_list
|
||||
(List.map
|
||||
(fun (_, (dcalc_var, _)) -> Dcalc.Ast.make_var (dcalc_var, pos_sigma))
|
||||
scope_variables))
|
||||
in
|
||||
let let_bindings_chain =
|
||||
List.fold_right
|
||||
(fun ((vars, pos), taus, exprs) acc ->
|
||||
Dcalc.Ast.make_multiple_let_in (Array.of_list vars) taus exprs acc pos)
|
||||
(List.map2 (fun x (y, z) -> (x, y, z)) vars (List.map2 (fun x y -> (x, y)) taus exprs))
|
||||
return_exp
|
||||
in
|
||||
(let_bindings_chain, new_ctx)
|
||||
|
||||
let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
|
||||
(sctx : scope_sigs_ctx) (scope_name : Ast.ScopeName.t) (sigma : Ast.scope_decl) :
|
||||
@ -546,7 +536,7 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
|
||||
(Dcalc.Ast.make_var (scope_input_var, pos_sigma)))
|
||||
scope_variables
|
||||
in
|
||||
Dcalc.Ast.make_multiple_let_in xs taus (Bindlib.box_list e1s) rules
|
||||
Dcalc.Ast.make_multiple_let_in xs taus e1s rules pos_sigma
|
||||
in
|
||||
let scope_return_struct_fields =
|
||||
List.map
|
||||
@ -683,7 +673,7 @@ let translate_program (prgm : Ast.program) (top_level_scope_name : Ast.ScopeName
|
||||
decl_ctx.Dcalc.Ast.ctx_structs scope_out_struct;
|
||||
}
|
||||
in
|
||||
( Dcalc.Ast.make_let_in dvar scope_typ scope_expr acc,
|
||||
( Dcalc.Ast.make_let_in dvar scope_typ scope_expr acc pos_scope,
|
||||
(scope_name, dvar, Bindlib.unbox scope_expr) :: scopes,
|
||||
decl_ctx ))
|
||||
scope_ordering (acc, [], decl_ctx)
|
||||
|
@ -126,16 +126,17 @@ let info =
|
||||
|
||||
let time : float ref = ref (Unix.gettimeofday ())
|
||||
|
||||
let print_with_style (styles : ANSITerminal.style list) (str : ('a, unit, string) format) =
|
||||
if !style_flag then ANSITerminal.sprintf styles str else Printf.sprintf str
|
||||
|
||||
let time_marker () =
|
||||
let new_time = Unix.gettimeofday () in
|
||||
let old_time = !time in
|
||||
time := new_time;
|
||||
let delta = (new_time -. old_time) *. 1000. in
|
||||
if delta > 50. then
|
||||
ANSITerminal.printf [ ANSITerminal.Bold; ANSITerminal.black ] "[TIME] %.0f ms\n" delta
|
||||
|
||||
let print_with_style (styles : ANSITerminal.style list) (str : ('a, unit, string) format) =
|
||||
if !style_flag then ANSITerminal.sprintf styles str else Printf.sprintf str
|
||||
Printf.printf "%s"
|
||||
(print_with_style [ ANSITerminal.Bold; ANSITerminal.black ] "[TIME] %.0f ms\n" delta)
|
||||
|
||||
(** Prints [\[DEBUG\]] in purple on the terminal standard output *)
|
||||
let debug_marker () =
|
||||
|
Loading…
Reference in New Issue
Block a user