Closure conversion: recursive hoisting

This commit is contained in:
Louis Gesbert 2024-06-24 11:51:03 +02:00
parent e78ea378bd
commit c0ad0e8820
2 changed files with 13 additions and 8 deletions

View File

@ -30,6 +30,11 @@ type 'm ctx = {
let new_var ?(pfx = "") name_context =
name_context.counter <- name_context.counter + 1;
Var.make (pfx ^ name_context.prefix ^ string_of_int name_context.counter)
(* TODO: Closures end up as a toplevel names. However for now we assume
toplevel names are unique, this is a temporary workaround to avoid
name wrangling in the backends. We need to have a better system for
name disambiguation when for instance printing to Dcalc/Lcalc/Scalc but
also OCaml, Python, etc. *)
let new_context prefix = { prefix; counter = 0 }
@ -562,17 +567,17 @@ let rec hoist_closures_expr :
args ([], [])
in
collected_closures, Expr.eappop ~op ~args:new_args ~tys (Mark.get e)
| EAbs { tys; _ } ->
| EAbs { binder; tys } ->
(* this is the closure we want to hoist *)
let closure_var = new_var ~pfx:"closure_" name_context in
(* TODO: This will end up as a toplevel name. However for now we assume
toplevel names are unique, but this breaks this assertions and can lead
to name wrangling in the backends. We need to have a better system for
name disambiguation when for instance printing to Dcalc/Lcalc/Scalc but
also OCaml, Python, etc. *)
let pos = Expr.mark_pos m in
let ty = Expr.maybe_ty ~typ:(TArrow (tys, (TAny, pos))) m in
( [{ name = closure_var; ty; closure = Expr.rebox e }],
let vars, body = Bindlib.unmbind binder in
let collected_closures, new_body =
(hoist_closures_expr name_context) body
in
let closure = Expr.make_abs vars new_body tys pos in
( { name = closure_var; ty; closure } :: collected_closures,
Expr.make_var closure_var m )
| EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
| EArray _ | ELit _ | EAssert _ | EFatalError _ | EAppOp _ | EIfThenElse _

View File

@ -53,7 +53,7 @@ let rec format_expr
(StructField.Map.bindings es)
Print.punctuation "}"
| ETuple es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" Print.punctuation "()"
Format.fprintf fmt "@[<hov 2>%a%a%a@]" Print.punctuation "("
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt e -> Format.fprintf fmt "%a" format_expr e))