dcalc ast: make the map function more polymorphic

This commit is contained in:
Louis Gesbert 2022-07-06 16:56:16 +02:00
parent 49efb5ddd7
commit 5f882e35a2
3 changed files with 51 additions and 28 deletions

View File

@ -293,33 +293,51 @@ let eifthenelse e1 e2 e3 mark =
let eerroronempty e1 mark =
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) e1
let translate_var v =
Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
let map_expr ctx ~f e =
let m = Marked.get_mark e in
match Marked.unmark e with
| EVar v -> evar v (Marked.get_mark e)
| EVar v -> evar (translate_var v) m
| EApp (e1, args) ->
eapp (f ctx e1) (List.map (f ctx) args) (Marked.get_mark e)
eapp (f ctx e1) (List.map (f ctx) args) m
| EAbs (binder, typs) ->
eabs (Bindlib.box_mbinder (f ctx) binder) typs (Marked.get_mark e)
| ETuple (args, s) -> etuple (List.map (f ctx) args) s (Marked.get_mark e)
let vars, body = Bindlib.unmbind binder in
eabs
(Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body))
typs
m
| ETuple (args, s) -> etuple (List.map (f ctx) args) s m
| ETupleAccess (e1, n, s_name, typs) ->
etupleaccess ((f ctx) e1) n s_name typs (Marked.get_mark e)
etupleaccess ((f ctx) e1) n s_name typs m
| EInj (e1, i, e_name, typs) ->
einj ((f ctx) e1) i e_name typs (Marked.get_mark e)
einj ((f ctx) e1) i e_name typs m
| EMatch (arg, arms, e_name) ->
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name (Marked.get_mark e)
| EArray args -> earray (List.map (f ctx) args) (Marked.get_mark e)
| ELit l -> elit l (Marked.get_mark e)
| EAssert e1 -> eassert ((f ctx) e1) (Marked.get_mark e)
| EOp op -> Bindlib.box (EOp op, Marked.get_mark e)
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name m
| EArray args -> earray (List.map (f ctx) args) m
| ELit l -> elit l m
| EAssert e1 -> eassert ((f ctx) e1) m
| EOp op -> Bindlib.box (EOp op, m)
| EDefault (excepts, just, cons) ->
edefault
(List.map (f ctx) excepts)
((f ctx) just)
((f ctx) cons)
(Marked.get_mark e)
m
| EIfThenElse (e1, e2, e3) ->
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) (Marked.get_mark e)
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) (Marked.get_mark e)
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m
let rec map_expr_top_down ~f e =
map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e)
let map_expr_marks ~f e =
Bindlib.unbox @@
map_expr_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
let untype_expr e =
map_expr_marks ~f:(fun m -> Untyped {pos=mark_pos m}) e
type ('expr, 'm) box_expr_sig =
('expr, 'm) marked -> ('expr, 'm) marked Bindlib.box

View File

@ -237,6 +237,7 @@ val map_mark: (Pos.t -> Pos.t) -> (Infer.unionfind_typ -> Infer.unionfind_typ) -
val map_mark2: (Pos.t -> Pos.t -> Pos.t) -> (typed -> typed -> Infer.unionfind_typ) -> 'm mark -> 'm mark -> 'm mark
val fold_marks: (Pos.t list -> Pos.t) -> (typed list -> Infer.unionfind_typ) -> 'm mark list -> 'm mark
val get_scope_body_mark: ('expr, 'm) scope_body -> 'm mark
val untype_expr: 'm marked_expr -> untyped marked_expr
(** {2 Boxed constructors} *)
@ -318,9 +319,9 @@ val box_expr : ('m expr, 'm) box_expr_sig
val map_expr :
'a ->
f:('a -> 'm marked_expr -> 'm marked_expr Bindlib.box) ->
'm marked_expr ->
'm marked_expr Bindlib.box
f:('a -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box) ->
('m1 expr, 'm2 mark) Marked.t ->
'm2 marked_expr Bindlib.box
(** If you want to apply a map transform to an expression, you can save up
writing a painful match over all the cases of the AST. For instance, if you
want to remove all errors on empty, you can write
@ -338,6 +339,11 @@ val map_expr :
The first argument of map_expr is an optional context that you can carry
around during your map traversal. *)
val map_expr_top_down: f:('m1 marked_expr -> ('m1 expr, 'm2 mark) Marked.t) -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box
(** Recursively applies [f] to the nodes of the expression tree. The type returned by [f] is hybrid since the mark at top-level has been rewritten, but not yet the marks in the subtrees. *)
val map_expr_marks: f:('m1 mark -> 'm2 mark) -> 'm1 marked_expr -> 'm2 marked_expr
val fold_left_scope_lets :
f:('a -> ('expr, 'm) scope_let -> 'expr Bindlib.var -> 'a) ->
init:'a ->
@ -403,6 +409,9 @@ type 'm var = 'm expr Bindlib.var
val new_var: string -> 'm var
(** used to convert between e.g. [untyped expr var] into a [typed expr var] *)
val translate_var: 'm1 var -> 'm2 var
module Var : sig
type t

View File

@ -192,10 +192,6 @@ type env = typ Marked.pos UnionFind.elem A.VarMap.t
let add_pos e ty = Marked.mark (A.pos e) ty
let ty (_, A.Typed { ty; _ }) = ty
(** used to convert an [untyped expr var] into a [typed expr var] *)
let translate_var: 'm1 A.var -> 'm2 A.var =
fun v -> Bindlib.copy_var v (fun x -> A.EVar x) (Bindlib.name_of v)
let (let+) x f = Bindlib.box_apply f x
let (and+) x1 x2 = Bindlib.box_pair x1 x2
@ -234,7 +230,7 @@ let rec typecheck_expr_bottom_up
| A.EVar v -> begin
match A.VarMap.find_opt (A.Var.t v) env with
| Some t ->
let+ v' = Bindlib.box_var (translate_var v) in
let+ v' = Bindlib.box_var (A.translate_var v) in
mark v' t
| None ->
Errors.raise_spanned_error (A.pos e)
@ -311,7 +307,7 @@ let rec typecheck_expr_bottom_up
(Bindlib.mbinder_arity binder) (List.length taus)
else
let xs, body = Bindlib.unmbind binder in
let xs' = Array.map translate_var xs in
let xs' = Array.map A.translate_var xs in
let xstaus =
List.mapi (fun i tau ->
xs'.(i), ast_to_typ tau)
@ -413,7 +409,7 @@ and typecheck_expr_top_down
match Marked.unmark e with
| A.EVar v -> begin
match A.VarMap.find_opt (A.Var.t v) env with
| Some tau' -> let+ v' = Bindlib.box_var (translate_var v) in
| Some tau' -> let+ v' = Bindlib.box_var (A.translate_var v) in
unify_and_mark v' tau'
| None ->
Errors.raise_spanned_error (A.pos e)
@ -491,7 +487,7 @@ and typecheck_expr_top_down
(Bindlib.mbinder_arity binder) (List.length t_args)
else
let xs, body = Bindlib.unmbind binder in
let xs' = Array.map translate_var xs in
let xs' = Array.map A.translate_var xs in
let xstaus =
List.map2
(fun x t_arg ->
@ -624,7 +620,7 @@ let infer_types_program prg =
let var, next = Bindlib.unbind scope_let_next in
let env = A.VarMap.add (A.Var.t var) ty env in
let next = process_scope_body_expr env next in
let scope_let_next = Bindlib.bind_var (translate_var var) next in
let scope_let_next = Bindlib.bind_var (A.translate_var var) next in
Bindlib.box_apply2 (fun scope_let_expr scope_let_next ->
A.ScopeLet {
scope_let_kind;
@ -638,14 +634,14 @@ let infer_types_program prg =
let scope_body_expr =
let var, e = Bindlib.unbind body in
let env = A.VarMap.add (A.Var.t var) ty_in env in
Bindlib.bind_var (translate_var var)
Bindlib.bind_var (A.translate_var var)
(process_scope_body_expr env e)
in
let scope_next =
let scope_var, next = Bindlib.unbind scope_next in
let env = A.VarMap.add (A.Var.t scope_var) ty_scope env in
let next = process_scopes env next in
Bindlib.bind_var (translate_var scope_var) next;
Bindlib.bind_var (A.translate_var scope_var) next;
in
Bindlib.box_apply2 (fun scope_body_expr scope_next ->
A.ScopeDef {