x10 performance on Catala compilation & interpretation

Cleaner rewriting of main let-binding chaining procedure from Scopelang to Dcalc
Removed costly unboxing in DCalc.Ast.make_let_in seemed to do the trick
This commit is contained in:
Denis Merigoux 2021-10-28 15:24:39 +02:00
parent a271d96b3a
commit 2c0e8a7864
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
5 changed files with 108 additions and 129 deletions

View File

@ -151,27 +151,13 @@ let make_app (e : expr Pos.marked Bindlib.box) (u : expr Pos.marked Bindlib.box
Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u)
let make_let_in (x : Var.t) (tau : typ Pos.marked) (e1 : expr Pos.marked Bindlib.box)
(e2 : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box =
Bindlib.box_apply2
(fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2)))
(make_abs
(Array.of_list [ x ])
e2
(Pos.get_position (Bindlib.unbox e2))
[ tau ]
(Pos.get_position (Bindlib.unbox e2)))
(Bindlib.box_list [ e1 ])
(e2 : expr Pos.marked Bindlib.box) (pos : Pos.t) : expr Pos.marked Bindlib.box =
make_app (make_abs (Array.of_list [ x ]) e2 pos [ tau ] pos) [ e1 ] pos
let make_multiple_let_in (xs : Var.t array) (taus : typ Pos.marked list)
(e1 : expr Pos.marked list Bindlib.box) (e2 : expr Pos.marked Bindlib.box) :
(e1 : expr Pos.marked Bindlib.box list) (e2 : expr Pos.marked Bindlib.box) (pos : Pos.t) :
expr Pos.marked Bindlib.box =
Bindlib.box_apply2
(fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2)))
(make_abs xs e2
(Pos.get_position (Bindlib.unbox e2))
taus
(Pos.get_position (Bindlib.unbox e2)))
e1
make_app (make_abs xs e2 pos taus pos) e1 pos
type binder = (expr, expr Pos.marked) Bindlib.binder

View File

@ -166,13 +166,15 @@ val make_let_in :
typ Pos.marked ->
expr Pos.marked Bindlib.box ->
expr Pos.marked Bindlib.box ->
Pos.t ->
expr Pos.marked Bindlib.box
val make_multiple_let_in :
Var.t array ->
typ Pos.marked list ->
expr Pos.marked list Bindlib.box ->
expr Pos.marked Bindlib.box list ->
expr Pos.marked Bindlib.box ->
Pos.t ->
expr Pos.marked Bindlib.box
type binder = (expr, expr Pos.marked) Bindlib.binder

View File

@ -161,7 +161,7 @@ let needs_parens (e : expr Pos.marked) : bool =
match Pos.unmark e with EAbs _ | ETuple (_, Some _) -> true | _ -> false
let format_var (fmt : Format.formatter) (v : Var.t) : unit =
Format.fprintf fmt "%s" (Bindlib.name_of v)
Format.fprintf fmt "%s_%d" (Bindlib.name_of v) (Bindlib.uid_of v)
let rec format_expr (ctx : Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) : unit =
let format_expr = format_expr ctx in

View File

@ -281,23 +281,21 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
(fun es -> Dcalc.Ast.EArray es)
(Bindlib.box_list (List.map (translate_expr ctx) es)))
let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info)
(sigma_return_struct_name : Ast.StructName.t) : Dcalc.Ast.expr Pos.marked Bindlib.box * ctx =
(** The result of a rule translation is a list of assignment, with variables and expressions. We
also return the new translation context available after the assignment to use in later rule
translations. The list is actually a list of list because we want to group in assignments that
are independent of each other to speed up the translation by minimizing Bindlib.bind_mvar *)
let translate_rule (ctx : ctx) (rule : Ast.rule)
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) :
Dcalc.Ast.Var.t list Pos.marked list
* Dcalc.Ast.typ Pos.marked list list
* Dcalc.Ast.expr Pos.marked Bindlib.box list list
* ctx =
match rule with
| Definition ((ScopeVar a, var_def_pos), tau, e) ->
let a_name = Ast.ScopeVar.get_info (Pos.unmark a) in
let a_var = Dcalc.Ast.Var.make a_name in
let tau = translate_typ ctx tau in
let new_ctx =
{
ctx with
scope_vars = Ast.ScopeVarMap.add (Pos.unmark a) (a_var, Pos.unmark tau) ctx.scope_vars;
}
in
let next_e, new_ctx =
translate_rules new_ctx rest (sigma_name, pos_sigma) sigma_return_struct_name
in
let new_e = translate_expr ctx e in
let a_expr = Dcalc.Ast.make_var (a_var, var_def_pos) in
let merged_expr =
@ -310,46 +308,21 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
(Dcalc.Ast.VarDef (Pos.unmark tau))
[ (sigma_name, pos_sigma); a_name ]
in
let next_e = Dcalc.Ast.make_let_in a_var tau merged_expr next_e in
(next_e, new_ctx)
( [ ([ a_var ], Pos.get_position a) ],
[ [ tau ] ],
[ [ merged_expr ] ],
{
ctx with
scope_vars = Ast.ScopeVarMap.add (Pos.unmark a) (a_var, Pos.unmark tau) ctx.scope_vars;
} )
| Definition ((SubScopeVar (_subs_name, subs_index, subs_var), var_def_pos), tau, e) ->
let a_name =
Pos.map_under_mark
(fun str -> str ^ "." ^ Pos.unmark (Ast.ScopeVar.get_info (Pos.unmark subs_var)))
(Ast.SubScopeName.get_info (Pos.unmark subs_index))
in
let a_var = (Dcalc.Ast.Var.make a_name, var_def_pos) in
let a_var = Dcalc.Ast.Var.make a_name in
let tau = translate_typ ctx tau in
let new_ctx =
{
ctx with
subscope_vars =
Ast.SubScopeMap.update (Pos.unmark subs_index)
(fun map ->
match map with
| Some map ->
Some
(Ast.ScopeVarMap.add (Pos.unmark subs_var)
(Pos.unmark a_var, Pos.unmark tau)
map)
| None ->
Some
(Ast.ScopeVarMap.singleton (Pos.unmark subs_var)
(Pos.unmark a_var, Pos.unmark tau)))
ctx.subscope_vars;
}
in
let next_e, new_ctx =
translate_rules new_ctx rest (sigma_name, pos_sigma) sigma_return_struct_name
in
let intermediate_e =
Dcalc.Ast.make_abs
(Array.of_list [ Pos.unmark a_var ])
next_e var_def_pos
[ (Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos) ]
(Pos.get_position e)
in
let new_e =
tag_with_log_entry (translate_expr ctx e)
(Dcalc.Ast.VarDef (Pos.unmark tau))
@ -363,8 +336,21 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
[ (Dcalc.Ast.TLit TUnit, var_def_pos) ]
var_def_pos
in
let out_e = Dcalc.Ast.make_app intermediate_e [ thunked_new_e ] (Pos.get_position e) in
(out_e, new_ctx)
( [ ([ a_var ], Pos.get_position a_name) ],
[ [ (Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos) ] ],
[ [ thunked_new_e ] ],
{
ctx with
subscope_vars =
Ast.SubScopeMap.update (Pos.unmark subs_index)
(fun map ->
match map with
| Some map ->
Some (Ast.ScopeVarMap.add (Pos.unmark subs_var) (a_var, Pos.unmark tau) map)
| None ->
Some (Ast.ScopeVarMap.singleton (Pos.unmark subs_var) (a_var, Pos.unmark tau)))
ctx.subscope_vars;
} )
| Call (subname, subindex) ->
let ( all_subscope_vars,
scope_dcalc_var,
@ -409,17 +395,6 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
(subvar, tau, sub_dcalc_var))
all_subscope_vars
in
let new_ctx =
{
ctx with
subscope_vars =
Ast.SubScopeMap.add subindex
(List.fold_left
(fun acc (var, tau, dvar) -> Ast.ScopeVarMap.add var (dvar, tau) acc)
Ast.ScopeVarMap.empty all_subscope_vars_dcalc)
ctx.subscope_vars;
}
in
let subscope_func =
tag_with_log_entry
(Dcalc.Ast.make_var
@ -443,28 +418,24 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
Ast.ScopeName.get_info subname;
]
in
let result_tuple_var = Dcalc.Ast.Var.make ("result", Pos.no_pos) in
let next_e, new_ctx =
translate_rules new_ctx rest (sigma_name, pos_sigma) sigma_return_struct_name
let result_tuple_var = Dcalc.Ast.Var.make ("result", pos_sigma) in
let results_bindings_xs = List.map (fun (_, _, v) -> v) all_subscope_vars_dcalc in
let results_bindings_taus =
List.map (fun (_, tau, _) -> (tau, pos_sigma)) all_subscope_vars_dcalc
in
let results_bindings =
let xs = Array.of_list (List.map (fun (_, _, v) -> v) all_subscope_vars_dcalc) in
let taus = List.map (fun (_, tau, _) -> (tau, pos_sigma)) all_subscope_vars_dcalc in
let e1s =
List.mapi
(fun i _ ->
Bindlib.box_apply
(fun r ->
( Dcalc.Ast.ETupleAccess
( r,
i,
Some called_scope_return_struct,
List.map (fun (_, t, _) -> (t, pos_sigma)) all_subscope_vars_dcalc ),
pos_sigma ))
(Dcalc.Ast.make_var (result_tuple_var, pos_sigma)))
all_subscope_vars_dcalc
in
Dcalc.Ast.make_multiple_let_in xs taus (Bindlib.box_list e1s) next_e
let results_bindings_e1s =
List.mapi
(fun i _ ->
Bindlib.box_apply
(fun r ->
( Dcalc.Ast.ETupleAccess
( r,
i,
Some called_scope_return_struct,
List.map (fun (_, t, _) -> (t, pos_sigma)) all_subscope_vars_dcalc ),
pos_sigma ))
(Dcalc.Ast.make_var (result_tuple_var, pos_sigma)))
all_subscope_vars_dcalc
in
let result_tuple_typ =
( Dcalc.Ast.TTuple
@ -472,35 +443,54 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
Some called_scope_return_struct ),
pos_sigma )
in
(Dcalc.Ast.make_let_in result_tuple_var result_tuple_typ call_expr results_bindings, new_ctx)
( [ ([ result_tuple_var ], pos_sigma); (results_bindings_xs, pos_sigma) ],
[ [ result_tuple_typ ]; results_bindings_taus ],
[ [ call_expr ]; results_bindings_e1s ],
{
ctx with
subscope_vars =
Ast.SubScopeMap.add subindex
(List.fold_left
(fun acc (var, tau, dvar) -> Ast.ScopeVarMap.add var (dvar, tau) acc)
Ast.ScopeVarMap.empty all_subscope_vars_dcalc)
ctx.subscope_vars;
} )
| Assertion e ->
let next_e, new_ctx =
translate_rules ctx rest (sigma_name, pos_sigma) sigma_return_struct_name
in
let new_e = translate_expr ctx e in
( Dcalc.Ast.make_let_in
(Dcalc.Ast.Var.make ("_", Pos.no_pos))
(Dcalc.Ast.TLit TUnit, Pos.no_pos)
(Bindlib.box_apply (fun new_e -> Pos.same_pos_as (Dcalc.Ast.EAssert new_e) e) new_e)
next_e,
new_ctx )
( [ ([ Dcalc.Ast.Var.make ("_", Pos.get_position e) ], Pos.get_position e) ],
[ [ (Dcalc.Ast.TLit TUnit, Pos.get_position e) ] ],
[ [ Bindlib.box_apply (fun new_e -> Pos.same_pos_as (Dcalc.Ast.EAssert new_e) e) new_e ] ],
ctx )
and translate_rules (ctx : ctx) (rules : Ast.rule list)
let translate_rules (ctx : ctx) (rules : Ast.rule list)
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info)
(sigma_return_struct_name : Ast.StructName.t) : Dcalc.Ast.expr Pos.marked Bindlib.box * ctx =
match rules with
| [] ->
let scope_variables = Ast.ScopeVarMap.bindings ctx.scope_vars in
let return_exp =
Bindlib.box_apply
(fun args -> (Dcalc.Ast.ETuple (args, Some sigma_return_struct_name), pos_sigma))
(Bindlib.box_list
(List.map
(fun (_, (dcalc_var, _)) -> Dcalc.Ast.make_var (dcalc_var, pos_sigma))
scope_variables))
in
(return_exp, ctx)
| hd :: tl -> translate_rule ctx hd tl (sigma_name, pos_sigma) sigma_return_struct_name
let vars, taus, exprs, new_ctx =
List.fold_left
(fun (vars, taus, exprs, ctx) rule ->
let new_vars, new_taus, new_exprs, new_ctx =
translate_rule ctx rule (sigma_name, pos_sigma)
in
(vars @ new_vars, taus @ new_taus, exprs @ new_exprs, new_ctx))
([], [], [], ctx) rules
in
let scope_variables = Ast.ScopeVarMap.bindings new_ctx.scope_vars in
let return_exp =
Bindlib.box_apply
(fun args -> (Dcalc.Ast.ETuple (args, Some sigma_return_struct_name), pos_sigma))
(Bindlib.box_list
(List.map
(fun (_, (dcalc_var, _)) -> Dcalc.Ast.make_var (dcalc_var, pos_sigma))
scope_variables))
in
let let_bindings_chain =
List.fold_right
(fun ((vars, pos), taus, exprs) acc ->
Dcalc.Ast.make_multiple_let_in (Array.of_list vars) taus exprs acc pos)
(List.map2 (fun x (y, z) -> (x, y, z)) vars (List.map2 (fun x y -> (x, y)) taus exprs))
return_exp
in
(let_bindings_chain, new_ctx)
let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
(sctx : scope_sigs_ctx) (scope_name : Ast.ScopeName.t) (sigma : Ast.scope_decl) :
@ -546,7 +536,7 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
(Dcalc.Ast.make_var (scope_input_var, pos_sigma)))
scope_variables
in
Dcalc.Ast.make_multiple_let_in xs taus (Bindlib.box_list e1s) rules
Dcalc.Ast.make_multiple_let_in xs taus e1s rules pos_sigma
in
let scope_return_struct_fields =
List.map
@ -683,7 +673,7 @@ let translate_program (prgm : Ast.program) (top_level_scope_name : Ast.ScopeName
decl_ctx.Dcalc.Ast.ctx_structs scope_out_struct;
}
in
( Dcalc.Ast.make_let_in dvar scope_typ scope_expr acc,
( Dcalc.Ast.make_let_in dvar scope_typ scope_expr acc pos_scope,
(scope_name, dvar, Bindlib.unbox scope_expr) :: scopes,
decl_ctx ))
scope_ordering (acc, [], decl_ctx)

View File

@ -126,16 +126,17 @@ let info =
let time : float ref = ref (Unix.gettimeofday ())
let print_with_style (styles : ANSITerminal.style list) (str : ('a, unit, string) format) =
if !style_flag then ANSITerminal.sprintf styles str else Printf.sprintf str
let time_marker () =
let new_time = Unix.gettimeofday () in
let old_time = !time in
time := new_time;
let delta = (new_time -. old_time) *. 1000. in
if delta > 50. then
ANSITerminal.printf [ ANSITerminal.Bold; ANSITerminal.black ] "[TIME] %.0f ms\n" delta
let print_with_style (styles : ANSITerminal.style list) (str : ('a, unit, string) format) =
if !style_flag then ANSITerminal.sprintf styles str else Printf.sprintf str
Printf.printf "%s"
(print_with_style [ ANSITerminal.Bold; ANSITerminal.black ] "[TIME] %.0f ms\n" delta)
(** Prints [\[DEBUG\]] in purple on the terminal standard output *)
let debug_marker () =