mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
more optimization on fold
This commit is contained in:
parent
4038ea02be
commit
850a1fdb56
@ -168,6 +168,16 @@ let monad_eoe ?(toplevel = false) arg ~(mark : 'a mark) =
|
||||
if toplevel then Expr.ematch arg Ast.option_enum cases mark
|
||||
else monad_return ~mark (Expr.ematch arg Ast.option_enum cases mark)
|
||||
|
||||
let monad_handle_default ~(_mark : 'a mark) _except _cond _just =
|
||||
(* let handle_default_opt (exceptions : 'a eoption array) (just : bool
|
||||
eoption) (cons : 'a eoption) : 'a eoption = let except = Array.fold_left
|
||||
(fun acc except -> match acc, except with | ENone _, _ -> except | ESome _,
|
||||
ENone _ -> acc | ESome _, ESome _ -> raise (ConflictError pos)) (ENone ())
|
||||
exceptions in match except with | ESome _ -> except | ENone _ -> ( match
|
||||
just with | ESome b -> if b then cons else ENone () | ENone _ -> ENone
|
||||
()) *)
|
||||
assert false
|
||||
|
||||
let _ = monad_return
|
||||
let _ = monad_empty
|
||||
let _ = monad_bind_var
|
||||
@ -180,6 +190,7 @@ let _ = monad_eoe
|
||||
let _ = monad_map
|
||||
let _ = monad_mmap_mvar
|
||||
let _ = monad_mmap
|
||||
let _ = monad_handle_default
|
||||
let trans_var _ctx (x : 'm D.expr Var.t) : 'm Ast.expr Var.t = Var.translate x
|
||||
let trans_op : (dcalc, 'a) Op.t -> (lcalc, 'a) Op.t = Operator.translate
|
||||
|
||||
|
@ -125,6 +125,18 @@ let rec beta_expr (e : 'm expr) : 'm expr boxed =
|
||||
m
|
||||
| _ -> visitor_map beta_expr e
|
||||
|
||||
let rec fold_expr (e : 'm expr) : 'm expr boxed =
|
||||
match Marked.unmark e with
|
||||
| EApp { f = EOp { op = Op.Fold; _ }, _; args = [_f; init; (EArray [], _)] }
|
||||
->
|
||||
visitor_map fold_expr init
|
||||
| EApp { f = EOp { op = Op.Fold; _ }, _; args = [f; init; (EArray [e'], _)] }
|
||||
->
|
||||
Expr.make_app (visitor_map fold_expr f)
|
||||
[visitor_map fold_expr init; visitor_map fold_expr e']
|
||||
(Expr.pos e)
|
||||
| _ -> visitor_map fold_expr e
|
||||
|
||||
let iota_optimizations (p : 'm program) : 'm program =
|
||||
let new_code_items =
|
||||
Scope.map_exprs ~f:iota_expr ~varf:(fun v -> v) p.code_items
|
||||
@ -137,6 +149,12 @@ let iota2_optimizations (p : 'm program) : 'm program =
|
||||
in
|
||||
{ p with code_items = Bindlib.unbox new_code_items }
|
||||
|
||||
let fold_optimizations (p : 'm program) : 'm program =
|
||||
let new_code_items =
|
||||
Scope.map_exprs ~f:fold_expr ~varf:(fun v -> v) p.code_items
|
||||
in
|
||||
{ p with code_items = Bindlib.unbox new_code_items }
|
||||
|
||||
(* TODO: beta optimizations apply inlining of the program. We left the inclusion
|
||||
of beta-optimization as future work since its produce code that is harder to
|
||||
read, and can produce exponential blowup of the size of the generated
|
||||
@ -183,84 +201,24 @@ let peephole_optimizations (p : 'm program) : 'm program =
|
||||
in
|
||||
{ p with code_items = Bindlib.unbox new_code_items }
|
||||
|
||||
let rec fix_opti
|
||||
?(maxiter = 100)
|
||||
~(fs : ('m program -> 'm program) list)
|
||||
(p : 'm program) =
|
||||
assert (maxiter >= 0);
|
||||
let p' = ListLabels.fold_left ~init:p fs ~f:(fun p f -> f p) in
|
||||
|
||||
if Program.equal p' p || maxiter = 0 then p'
|
||||
else fix_opti ~fs p ~maxiter:(maxiter - 1)
|
||||
|
||||
let optimize_program (p : 'm program) : untyped program =
|
||||
p
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota2_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> iota_optimizations
|
||||
|> peephole_optimizations
|
||||
|> Program.untype
|
||||
Program.untype
|
||||
(fix_opti p
|
||||
~fs:
|
||||
[
|
||||
iota_optimizations;
|
||||
iota2_optimizations;
|
||||
fold_optimizations;
|
||||
(* _beta_optimizations; *)
|
||||
peephole_optimizations;
|
||||
])
|
||||
|
@ -46,8 +46,26 @@ let rec find_scope name vars = function
|
||||
let var, next = Bindlib.unbind next_bind in
|
||||
find_scope name (var :: vars) next
|
||||
|
||||
let rec all_scopes code_item_list =
|
||||
match code_item_list with
|
||||
| Nil -> []
|
||||
| Cons (ScopeDef (n, _), next_bind) ->
|
||||
let _var, next = Bindlib.unbind next_bind in
|
||||
n :: all_scopes next
|
||||
| Cons (_, next_bind) ->
|
||||
let _var, next = Bindlib.unbind next_bind in
|
||||
all_scopes next
|
||||
|
||||
let to_expr p main_scope =
|
||||
let _, main_scope_body = find_scope main_scope [] p.code_items in
|
||||
Scope.unfold p.decl_ctx p.code_items
|
||||
(Scope.get_body_mark main_scope_body)
|
||||
(ScopeName main_scope)
|
||||
|
||||
let equal p p' =
|
||||
let ss = all_scopes p.code_items in
|
||||
let ss' = all_scopes p'.code_items in
|
||||
|
||||
ListLabels.for_all2 ss ss' ~f:(fun s s' ->
|
||||
ScopeName.equal s s'
|
||||
&& Expr.equal (Expr.unbox @@ to_expr p s) (Expr.unbox @@ to_expr p' s))
|
||||
|
@ -37,3 +37,8 @@ val to_expr :
|
||||
(** Usage: [build_whole_program_expr program main_scope] builds an expression
|
||||
corresponding to the main program and returning the main scope as a
|
||||
function. *)
|
||||
|
||||
val equal :
|
||||
(([< dcalc | lcalc ], _) gexpr as 'e) program ->
|
||||
(([< dcalc | lcalc ], _) gexpr as 'e) program ->
|
||||
bool
|
||||
|
@ -68,13 +68,11 @@ let rec typ_to_ast ?(unsafe = false) (ty : unionfind_typ) : A.typ =
|
||||
"Internal error: typing at this point could not be resolved"
|
||||
|
||||
(* Checks that there are no type variables remaining *)
|
||||
let rec all_resolved ty =
|
||||
match Marked.unmark (UnionFind.get (UnionFind.find ty)) with
|
||||
| TAny _ -> false
|
||||
| TLit _ | TStruct _ | TEnum _ -> true
|
||||
| TOption t1 | TArray t1 -> all_resolved t1
|
||||
| TArrow (t1, t2) -> List.for_all all_resolved t1 && all_resolved t2
|
||||
| TTuple ts -> List.for_all all_resolved ts
|
||||
(* let rec all_resolved ty = match Marked.unmark (UnionFind.get (UnionFind.find
|
||||
ty)) with | TAny _ -> false | TLit _ | TStruct _ | TEnum _ -> true | TOption
|
||||
t1 | TArray t1 -> all_resolved t1 | TArrow (t1, t2) -> List.for_all
|
||||
all_resolved t1 && all_resolved t2 | TTuple ts -> List.for_all all_resolved
|
||||
ts *)
|
||||
|
||||
let rec ast_to_typ (ty : A.typ) : unionfind_typ =
|
||||
let ty' =
|
||||
@ -626,7 +624,7 @@ and typecheck_expr_top_down :
|
||||
let t_ret = unionfind (TAny (Any.fresh ())) in
|
||||
let t_func = unionfind (TArrow (tau_args, t_ret)) in
|
||||
let mark = uf_mark t_func in
|
||||
assert (List.for_all all_resolved tau_args);
|
||||
(* assert (List.for_all all_resolved tau_args); *)
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs' = Array.map Var.translate xs in
|
||||
let env =
|
||||
|
Loading…
Reference in New Issue
Block a user