mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
dcalc ast: make the map function more polymorphic
This commit is contained in:
parent
49efb5ddd7
commit
5f882e35a2
@ -293,33 +293,51 @@ let eifthenelse e1 e2 e3 mark =
|
||||
let eerroronempty e1 mark =
|
||||
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) e1
|
||||
|
||||
let translate_var v =
|
||||
Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
|
||||
|
||||
let map_expr ctx ~f e =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| EVar v -> evar v (Marked.get_mark e)
|
||||
| EVar v -> evar (translate_var v) m
|
||||
| EApp (e1, args) ->
|
||||
eapp (f ctx e1) (List.map (f ctx) args) (Marked.get_mark e)
|
||||
eapp (f ctx e1) (List.map (f ctx) args) m
|
||||
| EAbs (binder, typs) ->
|
||||
eabs (Bindlib.box_mbinder (f ctx) binder) typs (Marked.get_mark e)
|
||||
| ETuple (args, s) -> etuple (List.map (f ctx) args) s (Marked.get_mark e)
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
eabs
|
||||
(Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body))
|
||||
typs
|
||||
m
|
||||
| ETuple (args, s) -> etuple (List.map (f ctx) args) s m
|
||||
| ETupleAccess (e1, n, s_name, typs) ->
|
||||
etupleaccess ((f ctx) e1) n s_name typs (Marked.get_mark e)
|
||||
etupleaccess ((f ctx) e1) n s_name typs m
|
||||
| EInj (e1, i, e_name, typs) ->
|
||||
einj ((f ctx) e1) i e_name typs (Marked.get_mark e)
|
||||
einj ((f ctx) e1) i e_name typs m
|
||||
| EMatch (arg, arms, e_name) ->
|
||||
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name (Marked.get_mark e)
|
||||
| EArray args -> earray (List.map (f ctx) args) (Marked.get_mark e)
|
||||
| ELit l -> elit l (Marked.get_mark e)
|
||||
| EAssert e1 -> eassert ((f ctx) e1) (Marked.get_mark e)
|
||||
| EOp op -> Bindlib.box (EOp op, Marked.get_mark e)
|
||||
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name m
|
||||
| EArray args -> earray (List.map (f ctx) args) m
|
||||
| ELit l -> elit l m
|
||||
| EAssert e1 -> eassert ((f ctx) e1) m
|
||||
| EOp op -> Bindlib.box (EOp op, m)
|
||||
| EDefault (excepts, just, cons) ->
|
||||
edefault
|
||||
(List.map (f ctx) excepts)
|
||||
((f ctx) just)
|
||||
((f ctx) cons)
|
||||
(Marked.get_mark e)
|
||||
m
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) (Marked.get_mark e)
|
||||
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) (Marked.get_mark e)
|
||||
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m
|
||||
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m
|
||||
|
||||
let rec map_expr_top_down ~f e =
|
||||
map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e)
|
||||
|
||||
let map_expr_marks ~f e =
|
||||
Bindlib.unbox @@
|
||||
map_expr_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
|
||||
|
||||
let untype_expr e =
|
||||
map_expr_marks ~f:(fun m -> Untyped {pos=mark_pos m}) e
|
||||
|
||||
type ('expr, 'm) box_expr_sig =
|
||||
('expr, 'm) marked -> ('expr, 'm) marked Bindlib.box
|
||||
|
@ -237,6 +237,7 @@ val map_mark: (Pos.t -> Pos.t) -> (Infer.unionfind_typ -> Infer.unionfind_typ) -
|
||||
val map_mark2: (Pos.t -> Pos.t -> Pos.t) -> (typed -> typed -> Infer.unionfind_typ) -> 'm mark -> 'm mark -> 'm mark
|
||||
val fold_marks: (Pos.t list -> Pos.t) -> (typed list -> Infer.unionfind_typ) -> 'm mark list -> 'm mark
|
||||
val get_scope_body_mark: ('expr, 'm) scope_body -> 'm mark
|
||||
val untype_expr: 'm marked_expr -> untyped marked_expr
|
||||
|
||||
(** {2 Boxed constructors} *)
|
||||
|
||||
@ -318,9 +319,9 @@ val box_expr : ('m expr, 'm) box_expr_sig
|
||||
|
||||
val map_expr :
|
||||
'a ->
|
||||
f:('a -> 'm marked_expr -> 'm marked_expr Bindlib.box) ->
|
||||
'm marked_expr ->
|
||||
'm marked_expr Bindlib.box
|
||||
f:('a -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box) ->
|
||||
('m1 expr, 'm2 mark) Marked.t ->
|
||||
'm2 marked_expr Bindlib.box
|
||||
(** If you want to apply a map transform to an expression, you can save up
|
||||
writing a painful match over all the cases of the AST. For instance, if you
|
||||
want to remove all errors on empty, you can write
|
||||
@ -338,6 +339,11 @@ val map_expr :
|
||||
The first argument of map_expr is an optional context that you can carry
|
||||
around during your map traversal. *)
|
||||
|
||||
val map_expr_top_down: f:('m1 marked_expr -> ('m1 expr, 'm2 mark) Marked.t) -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box
|
||||
(** Recursively applies [f] to the nodes of the expression tree. The type returned by [f] is hybrid since the mark at top-level has been rewritten, but not yet the marks in the subtrees. *)
|
||||
|
||||
val map_expr_marks: f:('m1 mark -> 'm2 mark) -> 'm1 marked_expr -> 'm2 marked_expr
|
||||
|
||||
val fold_left_scope_lets :
|
||||
f:('a -> ('expr, 'm) scope_let -> 'expr Bindlib.var -> 'a) ->
|
||||
init:'a ->
|
||||
@ -403,6 +409,9 @@ type 'm var = 'm expr Bindlib.var
|
||||
|
||||
val new_var: string -> 'm var
|
||||
|
||||
(** used to convert between e.g. [untyped expr var] into a [typed expr var] *)
|
||||
val translate_var: 'm1 var -> 'm2 var
|
||||
|
||||
module Var : sig
|
||||
type t
|
||||
|
||||
|
@ -192,10 +192,6 @@ type env = typ Marked.pos UnionFind.elem A.VarMap.t
|
||||
let add_pos e ty = Marked.mark (A.pos e) ty
|
||||
let ty (_, A.Typed { ty; _ }) = ty
|
||||
|
||||
(** used to convert an [untyped expr var] into a [typed expr var] *)
|
||||
let translate_var: 'm1 A.var -> 'm2 A.var =
|
||||
fun v -> Bindlib.copy_var v (fun x -> A.EVar x) (Bindlib.name_of v)
|
||||
|
||||
let (let+) x f = Bindlib.box_apply f x
|
||||
let (and+) x1 x2 = Bindlib.box_pair x1 x2
|
||||
|
||||
@ -234,7 +230,7 @@ let rec typecheck_expr_bottom_up
|
||||
| A.EVar v -> begin
|
||||
match A.VarMap.find_opt (A.Var.t v) env with
|
||||
| Some t ->
|
||||
let+ v' = Bindlib.box_var (translate_var v) in
|
||||
let+ v' = Bindlib.box_var (A.translate_var v) in
|
||||
mark v' t
|
||||
| None ->
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
@ -311,7 +307,7 @@ let rec typecheck_expr_bottom_up
|
||||
(Bindlib.mbinder_arity binder) (List.length taus)
|
||||
else
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs' = Array.map translate_var xs in
|
||||
let xs' = Array.map A.translate_var xs in
|
||||
let xstaus =
|
||||
List.mapi (fun i tau ->
|
||||
xs'.(i), ast_to_typ tau)
|
||||
@ -413,7 +409,7 @@ and typecheck_expr_top_down
|
||||
match Marked.unmark e with
|
||||
| A.EVar v -> begin
|
||||
match A.VarMap.find_opt (A.Var.t v) env with
|
||||
| Some tau' -> let+ v' = Bindlib.box_var (translate_var v) in
|
||||
| Some tau' -> let+ v' = Bindlib.box_var (A.translate_var v) in
|
||||
unify_and_mark v' tau'
|
||||
| None ->
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
@ -491,7 +487,7 @@ and typecheck_expr_top_down
|
||||
(Bindlib.mbinder_arity binder) (List.length t_args)
|
||||
else
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs' = Array.map translate_var xs in
|
||||
let xs' = Array.map A.translate_var xs in
|
||||
let xstaus =
|
||||
List.map2
|
||||
(fun x t_arg ->
|
||||
@ -624,7 +620,7 @@ let infer_types_program prg =
|
||||
let var, next = Bindlib.unbind scope_let_next in
|
||||
let env = A.VarMap.add (A.Var.t var) ty env in
|
||||
let next = process_scope_body_expr env next in
|
||||
let scope_let_next = Bindlib.bind_var (translate_var var) next in
|
||||
let scope_let_next = Bindlib.bind_var (A.translate_var var) next in
|
||||
Bindlib.box_apply2 (fun scope_let_expr scope_let_next ->
|
||||
A.ScopeLet {
|
||||
scope_let_kind;
|
||||
@ -638,14 +634,14 @@ let infer_types_program prg =
|
||||
let scope_body_expr =
|
||||
let var, e = Bindlib.unbind body in
|
||||
let env = A.VarMap.add (A.Var.t var) ty_in env in
|
||||
Bindlib.bind_var (translate_var var)
|
||||
Bindlib.bind_var (A.translate_var var)
|
||||
(process_scope_body_expr env e)
|
||||
in
|
||||
let scope_next =
|
||||
let scope_var, next = Bindlib.unbind scope_next in
|
||||
let env = A.VarMap.add (A.Var.t scope_var) ty_scope env in
|
||||
let next = process_scopes env next in
|
||||
Bindlib.bind_var (translate_var scope_var) next;
|
||||
Bindlib.bind_var (A.translate_var scope_var) next;
|
||||
in
|
||||
Bindlib.box_apply2 (fun scope_body_expr scope_next ->
|
||||
A.ScopeDef {
|
||||
|
Loading…
Reference in New Issue
Block a user