more optimization on fold

This commit is contained in:
adelaett 2023-03-17 11:34:52 +01:00
parent 4038ea02be
commit 850a1fdb56
5 changed files with 78 additions and 88 deletions

View File

@ -168,6 +168,16 @@ let monad_eoe ?(toplevel = false) arg ~(mark : 'a mark) =
if toplevel then Expr.ematch arg Ast.option_enum cases mark
else monad_return ~mark (Expr.ematch arg Ast.option_enum cases mark)
let monad_handle_default ~(_mark : 'a mark) _except _cond _just =
(* let handle_default_opt (exceptions : 'a eoption array) (just : bool
eoption) (cons : 'a eoption) : 'a eoption = let except = Array.fold_left
(fun acc except -> match acc, except with | ENone _, _ -> except | ESome _,
ENone _ -> acc | ESome _, ESome _ -> raise (ConflictError pos)) (ENone ())
exceptions in match except with | ESome _ -> except | ENone _ -> ( match
just with | ESome b -> if b then cons else ENone () | ENone _ -> ENone
()) *)
assert false
let _ = monad_return
let _ = monad_empty
let _ = monad_bind_var
@ -180,6 +190,7 @@ let _ = monad_eoe
let _ = monad_map
let _ = monad_mmap_mvar
let _ = monad_mmap
let _ = monad_handle_default
let trans_var _ctx (x : 'm D.expr Var.t) : 'm Ast.expr Var.t = Var.translate x
let trans_op : (dcalc, 'a) Op.t -> (lcalc, 'a) Op.t = Operator.translate

View File

@ -125,6 +125,18 @@ let rec beta_expr (e : 'm expr) : 'm expr boxed =
m
| _ -> visitor_map beta_expr e
let rec fold_expr (e : 'm expr) : 'm expr boxed =
match Marked.unmark e with
| EApp { f = EOp { op = Op.Fold; _ }, _; args = [_f; init; (EArray [], _)] }
->
visitor_map fold_expr init
| EApp { f = EOp { op = Op.Fold; _ }, _; args = [f; init; (EArray [e'], _)] }
->
Expr.make_app (visitor_map fold_expr f)
[visitor_map fold_expr init; visitor_map fold_expr e']
(Expr.pos e)
| _ -> visitor_map fold_expr e
let iota_optimizations (p : 'm program) : 'm program =
let new_code_items =
Scope.map_exprs ~f:iota_expr ~varf:(fun v -> v) p.code_items
@ -137,6 +149,12 @@ let iota2_optimizations (p : 'm program) : 'm program =
in
{ p with code_items = Bindlib.unbox new_code_items }
let fold_optimizations (p : 'm program) : 'm program =
let new_code_items =
Scope.map_exprs ~f:fold_expr ~varf:(fun v -> v) p.code_items
in
{ p with code_items = Bindlib.unbox new_code_items }
(* TODO: beta optimizations apply inlining of the program. We left the inclusion
of beta-optimization as future work since its produce code that is harder to
read, and can produce exponential blowup of the size of the generated
@ -183,84 +201,24 @@ let peephole_optimizations (p : 'm program) : 'm program =
in
{ p with code_items = Bindlib.unbox new_code_items }
let rec fix_opti
?(maxiter = 100)
~(fs : ('m program -> 'm program) list)
(p : 'm program) =
assert (maxiter >= 0);
let p' = ListLabels.fold_left ~init:p fs ~f:(fun p f -> f p) in
if Program.equal p' p || maxiter = 0 then p'
else fix_opti ~fs p ~maxiter:(maxiter - 1)
let optimize_program (p : 'm program) : untyped program =
p
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota2_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> iota_optimizations
|> peephole_optimizations
|> Program.untype
Program.untype
(fix_opti p
~fs:
[
iota_optimizations;
iota2_optimizations;
fold_optimizations;
(* _beta_optimizations; *)
peephole_optimizations;
])

View File

@ -46,8 +46,26 @@ let rec find_scope name vars = function
let var, next = Bindlib.unbind next_bind in
find_scope name (var :: vars) next
let rec all_scopes code_item_list =
match code_item_list with
| Nil -> []
| Cons (ScopeDef (n, _), next_bind) ->
let _var, next = Bindlib.unbind next_bind in
n :: all_scopes next
| Cons (_, next_bind) ->
let _var, next = Bindlib.unbind next_bind in
all_scopes next
let to_expr p main_scope =
let _, main_scope_body = find_scope main_scope [] p.code_items in
Scope.unfold p.decl_ctx p.code_items
(Scope.get_body_mark main_scope_body)
(ScopeName main_scope)
let equal p p' =
let ss = all_scopes p.code_items in
let ss' = all_scopes p'.code_items in
ListLabels.for_all2 ss ss' ~f:(fun s s' ->
ScopeName.equal s s'
&& Expr.equal (Expr.unbox @@ to_expr p s) (Expr.unbox @@ to_expr p' s))

View File

@ -37,3 +37,8 @@ val to_expr :
(** Usage: [build_whole_program_expr program main_scope] builds an expression
corresponding to the main program and returning the main scope as a
function. *)
val equal :
(([< dcalc | lcalc ], _) gexpr as 'e) program ->
(([< dcalc | lcalc ], _) gexpr as 'e) program ->
bool

View File

@ -68,13 +68,11 @@ let rec typ_to_ast ?(unsafe = false) (ty : unionfind_typ) : A.typ =
"Internal error: typing at this point could not be resolved"
(* Checks that there are no type variables remaining *)
let rec all_resolved ty =
match Marked.unmark (UnionFind.get (UnionFind.find ty)) with
| TAny _ -> false
| TLit _ | TStruct _ | TEnum _ -> true
| TOption t1 | TArray t1 -> all_resolved t1
| TArrow (t1, t2) -> List.for_all all_resolved t1 && all_resolved t2
| TTuple ts -> List.for_all all_resolved ts
(* let rec all_resolved ty = match Marked.unmark (UnionFind.get (UnionFind.find
ty)) with | TAny _ -> false | TLit _ | TStruct _ | TEnum _ -> true | TOption
t1 | TArray t1 -> all_resolved t1 | TArrow (t1, t2) -> List.for_all
all_resolved t1 && all_resolved t2 | TTuple ts -> List.for_all all_resolved
ts *)
let rec ast_to_typ (ty : A.typ) : unionfind_typ =
let ty' =
@ -626,7 +624,7 @@ and typecheck_expr_top_down :
let t_ret = unionfind (TAny (Any.fresh ())) in
let t_func = unionfind (TArrow (tau_args, t_ret)) in
let mark = uf_mark t_func in
assert (List.for_all all_resolved tau_args);
(* assert (List.for_all all_resolved tau_args); *)
let xs, body = Bindlib.unmbind binder in
let xs' = Array.map Var.translate xs in
let env =