From 5f882e35a2daf1ef89b4fb8fe7fe7093255f4c9d Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Wed, 6 Jul 2022 16:56:16 +0200 Subject: [PATCH] dcalc ast: make the map function more polymorphic --- compiler/dcalc/ast.ml | 46 ++++++++++++++++++++++++++++------------ compiler/dcalc/ast.mli | 15 ++++++++++--- compiler/dcalc/typing.ml | 18 ++++++---------- 3 files changed, 51 insertions(+), 28 deletions(-) diff --git a/compiler/dcalc/ast.ml b/compiler/dcalc/ast.ml index 506bd0b2..8da4af76 100644 --- a/compiler/dcalc/ast.ml +++ b/compiler/dcalc/ast.ml @@ -293,33 +293,51 @@ let eifthenelse e1 e2 e3 mark = let eerroronempty e1 mark = Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) e1 +let translate_var v = + Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v) + let map_expr ctx ~f e = + let m = Marked.get_mark e in match Marked.unmark e with - | EVar v -> evar v (Marked.get_mark e) + | EVar v -> evar (translate_var v) m | EApp (e1, args) -> - eapp (f ctx e1) (List.map (f ctx) args) (Marked.get_mark e) + eapp (f ctx e1) (List.map (f ctx) args) m | EAbs (binder, typs) -> - eabs (Bindlib.box_mbinder (f ctx) binder) typs (Marked.get_mark e) - | ETuple (args, s) -> etuple (List.map (f ctx) args) s (Marked.get_mark e) + let vars, body = Bindlib.unmbind binder in + eabs + (Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body)) + typs + m + | ETuple (args, s) -> etuple (List.map (f ctx) args) s m | ETupleAccess (e1, n, s_name, typs) -> - etupleaccess ((f ctx) e1) n s_name typs (Marked.get_mark e) + etupleaccess ((f ctx) e1) n s_name typs m | EInj (e1, i, e_name, typs) -> - einj ((f ctx) e1) i e_name typs (Marked.get_mark e) + einj ((f ctx) e1) i e_name typs m | EMatch (arg, arms, e_name) -> - ematch ((f ctx) arg) (List.map (f ctx) arms) e_name (Marked.get_mark e) - | EArray args -> earray (List.map (f ctx) args) (Marked.get_mark e) - | ELit l -> elit l (Marked.get_mark e) - | EAssert e1 -> eassert ((f ctx) e1) (Marked.get_mark e) - | EOp op -> Bindlib.box (EOp op, Marked.get_mark e) + ematch ((f ctx) arg) (List.map (f ctx) arms) e_name m + | EArray args -> earray (List.map (f ctx) args) m + | ELit l -> elit l m + | EAssert e1 -> eassert ((f ctx) e1) m + | EOp op -> Bindlib.box (EOp op, m) | EDefault (excepts, just, cons) -> edefault (List.map (f ctx) excepts) ((f ctx) just) ((f ctx) cons) - (Marked.get_mark e) + m | EIfThenElse (e1, e2, e3) -> - eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) (Marked.get_mark e) - | ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) (Marked.get_mark e) + eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m + | ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m + +let rec map_expr_top_down ~f e = + map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e) + +let map_expr_marks ~f e = + Bindlib.unbox @@ + map_expr_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e + +let untype_expr e = + map_expr_marks ~f:(fun m -> Untyped {pos=mark_pos m}) e type ('expr, 'm) box_expr_sig = ('expr, 'm) marked -> ('expr, 'm) marked Bindlib.box diff --git a/compiler/dcalc/ast.mli b/compiler/dcalc/ast.mli index e16ff00d..28896343 100644 --- a/compiler/dcalc/ast.mli +++ b/compiler/dcalc/ast.mli @@ -237,6 +237,7 @@ val map_mark: (Pos.t -> Pos.t) -> (Infer.unionfind_typ -> Infer.unionfind_typ) - val map_mark2: (Pos.t -> Pos.t -> Pos.t) -> (typed -> typed -> Infer.unionfind_typ) -> 'm mark -> 'm mark -> 'm mark val fold_marks: (Pos.t list -> Pos.t) -> (typed list -> Infer.unionfind_typ) -> 'm mark list -> 'm mark val get_scope_body_mark: ('expr, 'm) scope_body -> 'm mark +val untype_expr: 'm marked_expr -> untyped marked_expr (** {2 Boxed constructors} *) @@ -318,9 +319,9 @@ val box_expr : ('m expr, 'm) box_expr_sig val map_expr : 'a -> - f:('a -> 'm marked_expr -> 'm marked_expr Bindlib.box) -> - 'm marked_expr -> - 'm marked_expr Bindlib.box + f:('a -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box) -> + ('m1 expr, 'm2 mark) Marked.t -> + 'm2 marked_expr Bindlib.box (** If you want to apply a map transform to an expression, you can save up writing a painful match over all the cases of the AST. For instance, if you want to remove all errors on empty, you can write @@ -338,6 +339,11 @@ val map_expr : The first argument of map_expr is an optional context that you can carry around during your map traversal. *) +val map_expr_top_down: f:('m1 marked_expr -> ('m1 expr, 'm2 mark) Marked.t) -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box +(** Recursively applies [f] to the nodes of the expression tree. The type returned by [f] is hybrid since the mark at top-level has been rewritten, but not yet the marks in the subtrees. *) + +val map_expr_marks: f:('m1 mark -> 'm2 mark) -> 'm1 marked_expr -> 'm2 marked_expr + val fold_left_scope_lets : f:('a -> ('expr, 'm) scope_let -> 'expr Bindlib.var -> 'a) -> init:'a -> @@ -403,6 +409,9 @@ type 'm var = 'm expr Bindlib.var val new_var: string -> 'm var +(** used to convert between e.g. [untyped expr var] into a [typed expr var] *) +val translate_var: 'm1 var -> 'm2 var + module Var : sig type t diff --git a/compiler/dcalc/typing.ml b/compiler/dcalc/typing.ml index 8863bdce..9e358d05 100644 --- a/compiler/dcalc/typing.ml +++ b/compiler/dcalc/typing.ml @@ -192,10 +192,6 @@ type env = typ Marked.pos UnionFind.elem A.VarMap.t let add_pos e ty = Marked.mark (A.pos e) ty let ty (_, A.Typed { ty; _ }) = ty -(** used to convert an [untyped expr var] into a [typed expr var] *) -let translate_var: 'm1 A.var -> 'm2 A.var = - fun v -> Bindlib.copy_var v (fun x -> A.EVar x) (Bindlib.name_of v) - let (let+) x f = Bindlib.box_apply f x let (and+) x1 x2 = Bindlib.box_pair x1 x2 @@ -234,7 +230,7 @@ let rec typecheck_expr_bottom_up | A.EVar v -> begin match A.VarMap.find_opt (A.Var.t v) env with | Some t -> - let+ v' = Bindlib.box_var (translate_var v) in + let+ v' = Bindlib.box_var (A.translate_var v) in mark v' t | None -> Errors.raise_spanned_error (A.pos e) @@ -311,7 +307,7 @@ let rec typecheck_expr_bottom_up (Bindlib.mbinder_arity binder) (List.length taus) else let xs, body = Bindlib.unmbind binder in - let xs' = Array.map translate_var xs in + let xs' = Array.map A.translate_var xs in let xstaus = List.mapi (fun i tau -> xs'.(i), ast_to_typ tau) @@ -413,7 +409,7 @@ and typecheck_expr_top_down match Marked.unmark e with | A.EVar v -> begin match A.VarMap.find_opt (A.Var.t v) env with - | Some tau' -> let+ v' = Bindlib.box_var (translate_var v) in + | Some tau' -> let+ v' = Bindlib.box_var (A.translate_var v) in unify_and_mark v' tau' | None -> Errors.raise_spanned_error (A.pos e) @@ -491,7 +487,7 @@ and typecheck_expr_top_down (Bindlib.mbinder_arity binder) (List.length t_args) else let xs, body = Bindlib.unmbind binder in - let xs' = Array.map translate_var xs in + let xs' = Array.map A.translate_var xs in let xstaus = List.map2 (fun x t_arg -> @@ -624,7 +620,7 @@ let infer_types_program prg = let var, next = Bindlib.unbind scope_let_next in let env = A.VarMap.add (A.Var.t var) ty env in let next = process_scope_body_expr env next in - let scope_let_next = Bindlib.bind_var (translate_var var) next in + let scope_let_next = Bindlib.bind_var (A.translate_var var) next in Bindlib.box_apply2 (fun scope_let_expr scope_let_next -> A.ScopeLet { scope_let_kind; @@ -638,14 +634,14 @@ let infer_types_program prg = let scope_body_expr = let var, e = Bindlib.unbind body in let env = A.VarMap.add (A.Var.t var) ty_in env in - Bindlib.bind_var (translate_var var) + Bindlib.bind_var (A.translate_var var) (process_scope_body_expr env e) in let scope_next = let scope_var, next = Bindlib.unbind scope_next in let env = A.VarMap.add (A.Var.t scope_var) ty_scope env in let next = process_scopes env next in - Bindlib.bind_var (translate_var scope_var) next; + Bindlib.bind_var (A.translate_var scope_var) next; in Bindlib.box_apply2 (fun scope_body_expr scope_next -> A.ScopeDef {