diff --git a/compiler/dcalc/ast.ml b/compiler/dcalc/ast.ml index 3c19da3f..11350503 100644 --- a/compiler/dcalc/ast.ml +++ b/compiler/dcalc/ast.ml @@ -408,6 +408,16 @@ let map_exprs_in_scopes ~f ~varf scopes = new_scope_body_expr new_next) ~init:(Bindlib.box Nil) scopes +let untype_program prg = + { + prg with + scopes = + Bindlib.unbox + (map_exprs_in_scopes + ~f:(fun e -> Bindlib.box (untype_expr e)) + ~varf:translate_var prg.scopes); + } + type 'm var = 'm expr Bindlib.var type 'm vars = 'm expr Bindlib.mvar diff --git a/compiler/dcalc/ast.mli b/compiler/dcalc/ast.mli index 072460ee..7de3bffd 100644 --- a/compiler/dcalc/ast.mli +++ b/compiler/dcalc/ast.mli @@ -251,6 +251,7 @@ val fold_marks : val get_scope_body_mark : ('expr, 'm) scope_body -> 'm mark val untype_expr : 'm marked_expr -> untyped marked_expr +val untype_program : 'm program -> untyped program (** {2 Boxed constructors} *) diff --git a/compiler/dcalc/optimizations.ml b/compiler/dcalc/optimizations.ml index 4439968b..34d4dc7c 100644 --- a/compiler/dcalc/optimizations.ml +++ b/compiler/dcalc/optimizations.ml @@ -247,8 +247,9 @@ let program_map (fun new_scopes -> { p with scopes = new_scopes }) (scopes_map t ctx p.scopes) -let optimize_program (p : 'm program) : 'm program = +let optimize_program (p : 'm program) : untyped program = Bindlib.unbox (program_map partial_evaluation { var_values = VarMap.empty; decl_ctx = p.decl_ctx } p) + |> untype_program diff --git a/compiler/dcalc/optimizations.mli b/compiler/dcalc/optimizations.mli index 27090edc..53c7a600 100644 --- a/compiler/dcalc/optimizations.mli +++ b/compiler/dcalc/optimizations.mli @@ -20,4 +20,4 @@ open Ast val optimize_expr : decl_ctx -> 'm marked_expr -> 'm marked_expr Bindlib.box -val optimize_program : 'm program -> 'm program +val optimize_program : 'm program -> untyped program diff --git a/compiler/driver.ml b/compiler/driver.ml index 6b752eb1..8b4fbe2d 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -284,7 +284,7 @@ let driver source_file (options : Cli.options) : int = Cli.debug_print "Optimizing lambda calculus..."; Lcalc.Optimizations.optimize_program prgm end - else prgm + else Lcalc.Ast.untype_program prgm in let prgm = if options.closure_conversion then ( diff --git a/compiler/lcalc/ast.ml b/compiler/lcalc/ast.ml index 63cbcd31..3fa07c24 100644 --- a/compiler/lcalc/ast.ml +++ b/compiler/lcalc/ast.ml @@ -118,13 +118,17 @@ let eraise e1 pos = Bindlib.box (ERaise e1, pos) let ecatch e1 exn e2 pos = Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2 +let translate_var v = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v) + let map_expr ctx ~f e = + let m = Marked.get_mark e in match Marked.unmark e with - | EVar v -> evar v (Marked.get_mark e) + | EVar v -> evar (translate_var v) (Marked.get_mark e) | EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) (Marked.get_mark e) | EAbs (binder, typs) -> - eabs (Bindlib.box_mbinder (f ctx) binder) typs (Marked.get_mark e) + let vars, body = Bindlib.unmbind binder in + eabs (Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body)) typs m | ETuple (args, s) -> etuple (List.map (f ctx) args) s (Marked.get_mark e) | ETupleAccess (e1, n, s_name, typs) -> etupleaccess ((f ctx) e1) n s_name typs (Marked.get_mark e) @@ -141,6 +145,26 @@ let map_expr ctx ~f e = eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) (Marked.get_mark e) | ECatch (e1, exn, e2) -> ecatch (f ctx e1) exn (f ctx e2) (Marked.get_mark e) +let rec map_expr_top_down ~f e = + map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e) + +let map_expr_marks ~f e = + Bindlib.unbox + @@ map_expr_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e + +let untype_expr e = + map_expr_marks ~f:(fun m -> Untyped { pos = D.mark_pos m }) e + +let untype_program prg = + { + prg with + scopes = + Bindlib.unbox + (D.map_exprs_in_scopes + ~f:(fun e -> Bindlib.box (untype_expr e)) + ~varf:translate_var prg.scopes); + } + (** See [Bindlib.box_term] documentation for why we are doing that. *) let box_expr (e : 'm marked_expr) : 'm marked_expr Bindlib.box = let rec id_t () e = map_expr () ~f:id_t e in diff --git a/compiler/lcalc/ast.mli b/compiler/lcalc/ast.mli index bae1a9f0..10a595e5 100644 --- a/compiler/lcalc/ast.mli +++ b/compiler/lcalc/ast.mli @@ -91,6 +91,28 @@ val new_var : string -> 'm var type 'm binder = ('m expr, 'm marked_expr) Bindlib.binder +(** {2 Program traversal} *) + +val map_expr : + 'a -> + f:('a -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box) -> + ('m1 expr, 'm2 mark) Marked.t -> + 'm2 marked_expr Bindlib.box +(** See [Dcalc.Ast.map_expr] *) + +val map_expr_top_down : + f:('m1 marked_expr -> ('m1 expr, 'm2 mark) Marked.t) -> + 'm1 marked_expr -> + 'm2 marked_expr Bindlib.box +(** See [Dcalc.Ast.map_expr_top_down] *) + +val map_expr_marks : + f:('m1 mark -> 'm2 mark) -> 'm1 marked_expr -> 'm2 marked_expr +(** See [Dcalc.Ast.map_expr_marks] *) + +val untype_expr : 'm marked_expr -> Dcalc.Ast.untyped marked_expr +val untype_program : 'm program -> Dcalc.Ast.untyped program + (** {1 Boxed constructors} *) val evar : 'm expr Bindlib.var -> 'm mark -> 'm marked_expr Bindlib.box diff --git a/compiler/lcalc/optimizations.ml b/compiler/lcalc/optimizations.ml index ef521799..fc477470 100644 --- a/compiler/lcalc/optimizations.ml +++ b/compiler/lcalc/optimizations.ml @@ -150,5 +150,5 @@ let peephole_optimizations (p : 'm program) : 'm program = in { p with scopes = Bindlib.unbox new_scopes } -let optimize_program (p : 'm program) : 'm program = - p |> iota_optimizations |> peephole_optimizations +let optimize_program (p : 'm program) : Dcalc.Ast.untyped program = + p |> iota_optimizations |> peephole_optimizations |> untype_program diff --git a/compiler/lcalc/optimizations.mli b/compiler/lcalc/optimizations.mli index 59776c96..da3af2c5 100644 --- a/compiler/lcalc/optimizations.mli +++ b/compiler/lcalc/optimizations.mli @@ -16,6 +16,6 @@ open Ast -val optimize_program : 'm program -> 'm program +val optimize_program : 'm program -> Dcalc.Ast.untyped program (** Warning/todo: no effort was yet made to ensure correct propagation of type annotations in the typed case *) diff --git a/compiler/plugin.ml b/compiler/plugin.ml index 09713b0a..08fce637 100644 --- a/compiler/plugin.ml +++ b/compiler/plugin.ml @@ -21,7 +21,7 @@ type 'ast gen = { } type t = - | Lcalc of Dcalc.Ast.typed Lcalc.Ast.program gen + | Lcalc of Dcalc.Ast.untyped Lcalc.Ast.program gen | Scalc of Scalc.Ast.program gen let name = function Lcalc { name; _ } | Scalc { name; _ } -> name diff --git a/compiler/plugin.mli b/compiler/plugin.mli index 3f7d0c14..27acbef7 100644 --- a/compiler/plugin.mli +++ b/compiler/plugin.mli @@ -23,7 +23,7 @@ type 'ast gen = { } type t = - | Lcalc of Dcalc.Ast.typed Lcalc.Ast.program gen + | Lcalc of Dcalc.Ast.untyped Lcalc.Ast.program gen | Scalc of Scalc.Ast.program gen val find : string -> t @@ -42,7 +42,7 @@ module PluginAPI : sig name:string -> extension:string -> (string option -> - Dcalc.Ast.typed Lcalc.Ast.program -> + Dcalc.Ast.untyped Lcalc.Ast.program -> Scopelang.Dependency.TVertex.t list -> unit) -> unit