mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Add a helper to fold on expressions
This commit is contained in:
parent
f103fb1ea5
commit
f9f834e30a
@ -208,40 +208,16 @@ type scope = {
|
||||
|
||||
type program = { program_scopes : scope ScopeMap.t; program_ctx : decl_ctx }
|
||||
|
||||
let rec locations_used (e : expr) : LocationSet.t =
|
||||
match Marked.unmark e with
|
||||
| ELocation l -> LocationSet.singleton (l, Expr.pos e)
|
||||
| EVar _ | ELit _ | EOp _ -> LocationSet.empty
|
||||
| EAbs (binder, _) ->
|
||||
let rec locations_used e : LocationSet.t =
|
||||
match e with
|
||||
| ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m)
|
||||
| EAbs (binder, _), _ ->
|
||||
let _, body = Bindlib.unmbind binder in
|
||||
locations_used body
|
||||
| EStruct (_, es) ->
|
||||
StructFieldMap.fold
|
||||
(fun _ e' acc -> LocationSet.union acc (locations_used e'))
|
||||
es LocationSet.empty
|
||||
| EStructAccess (e1, _, _) -> locations_used e1
|
||||
| EEnumInj (e1, _, _) -> locations_used e1
|
||||
| EMatchS (e1, _, es) ->
|
||||
EnumConstructorMap.fold
|
||||
(fun _ e' acc -> LocationSet.union acc (locations_used e'))
|
||||
es (locations_used e1)
|
||||
| EApp (e1, args) ->
|
||||
List.fold_left
|
||||
(fun acc arg -> LocationSet.union (locations_used arg) acc)
|
||||
(locations_used e1) args
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
LocationSet.union (locations_used e1)
|
||||
(LocationSet.union (locations_used e2) (locations_used e3))
|
||||
| EDefault (excepts, just, cons) ->
|
||||
List.fold_left
|
||||
(fun acc except -> LocationSet.union (locations_used except) acc)
|
||||
(LocationSet.union (locations_used just) (locations_used cons))
|
||||
excepts
|
||||
| EArray es ->
|
||||
List.fold_left
|
||||
(fun acc e' -> LocationSet.union acc (locations_used e'))
|
||||
LocationSet.empty es
|
||||
| ErrorOnEmpty e' -> locations_used e'
|
||||
| e ->
|
||||
Expr.shallow_fold
|
||||
(fun e -> LocationSet.union (locations_used e))
|
||||
e LocationSet.empty
|
||||
|
||||
let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t =
|
||||
let add_locs (acc : Pos.t ScopeDefMap.t) (locs : LocationSet.t) :
|
||||
|
@ -29,39 +29,15 @@ end)
|
||||
type 'm expr = (scopelang, 'm mark) gexpr
|
||||
|
||||
let rec locations_used (e : 'm expr) : LocationSet.t =
|
||||
match Marked.unmark e with
|
||||
| ELocation l -> LocationSet.singleton (l, Expr.pos e)
|
||||
| EVar _ | ELit _ | EOp _ -> LocationSet.empty
|
||||
| EAbs (binder, _) ->
|
||||
match e with
|
||||
| ELocation l, pos -> LocationSet.singleton (l, Expr.mark_pos pos)
|
||||
| EAbs (binder, _), _ ->
|
||||
let _, body = Bindlib.unmbind binder in
|
||||
locations_used body
|
||||
| EStruct (_, es) ->
|
||||
StructFieldMap.fold
|
||||
(fun _ e' acc -> LocationSet.union acc (locations_used e'))
|
||||
es LocationSet.empty
|
||||
| EStructAccess (e1, _, _) -> locations_used e1
|
||||
| EEnumInj (e1, _, _) -> locations_used e1
|
||||
| EMatchS (e1, _, es) ->
|
||||
EnumConstructorMap.fold
|
||||
(fun _ e' acc -> LocationSet.union acc (locations_used e'))
|
||||
es (locations_used e1)
|
||||
| EApp (e1, args) ->
|
||||
List.fold_left
|
||||
(fun acc arg -> LocationSet.union (locations_used arg) acc)
|
||||
(locations_used e1) args
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
LocationSet.union (locations_used e1)
|
||||
(LocationSet.union (locations_used e2) (locations_used e3))
|
||||
| EDefault (excepts, just, cons) ->
|
||||
List.fold_left
|
||||
(fun acc except -> LocationSet.union (locations_used except) acc)
|
||||
(LocationSet.union (locations_used just) (locations_used cons))
|
||||
excepts
|
||||
| EArray es ->
|
||||
List.fold_left
|
||||
(fun acc e' -> LocationSet.union acc (locations_used e'))
|
||||
LocationSet.empty es
|
||||
| ErrorOnEmpty e' -> locations_used e'
|
||||
| e ->
|
||||
Expr.shallow_fold
|
||||
(fun e -> LocationSet.union (locations_used e))
|
||||
e LocationSet.empty
|
||||
|
||||
type io_input = NoInput | OnlyInput | Reentrant
|
||||
type io = { io_output : bool Marked.pos; io_input : io_input Marked.pos }
|
||||
|
@ -1,7 +1,9 @@
|
||||
(library
|
||||
(name scopelang)
|
||||
(public_name catala.scopelang)
|
||||
(libraries utils dcalc ocamlgraph))
|
||||
(libraries utils dcalc ocamlgraph)
|
||||
(flags
|
||||
(:standard -short-paths)))
|
||||
|
||||
(documentation
|
||||
(package catala)
|
||||
|
@ -231,6 +231,34 @@ let rec map_top_down ~f e = map () ~f:(fun () -> map_top_down ~f) (f e)
|
||||
let map_marks ~f e =
|
||||
map_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
|
||||
|
||||
(* Folds the given function on the direct children of the given expression. Does
|
||||
not open binders. *)
|
||||
let shallow_fold
|
||||
(type a)
|
||||
(f : (a, 'm) gexpr -> 'acc -> 'acc)
|
||||
(e : (a, 'm) gexpr)
|
||||
(acc : 'acc) : 'acc =
|
||||
let lfold x acc = List.fold_left (fun acc x -> f x acc) acc x in
|
||||
match Marked.unmark e with
|
||||
| ELit _ | EOp _ | EVar _ | ERaise _ | ELocation _ -> acc
|
||||
| EApp (e1, args) -> acc |> f e1 |> lfold args
|
||||
| EArray args -> acc |> lfold args
|
||||
| EAbs _ -> acc
|
||||
| EIfThenElse (e1, e2, e3) -> acc |> f e1 |> f e2 |> f e3
|
||||
| ETuple (args, _) -> acc |> lfold args
|
||||
| ETupleAccess (e1, _, _, _) -> acc |> f e1
|
||||
| EInj (e1, _, _, _) -> acc |> f e1
|
||||
| EMatch (arg, arms, _) -> acc |> f arg |> lfold arms
|
||||
| EAssert e1 -> acc |> f e1
|
||||
| EDefault (excepts, just, cons) -> acc |> lfold excepts |> f just |> f cons
|
||||
| ErrorOnEmpty e1 -> acc |> f e1
|
||||
| ECatch (e1, _, e2) -> acc |> f e1 |> f e2
|
||||
| EStruct (_, fields) -> acc |> StructFieldMap.fold (fun _ -> f) fields
|
||||
| EStructAccess (e1, _, _) -> acc |> f e1
|
||||
| EEnumInj (e1, _, _) -> acc |> f e1
|
||||
| EMatchS (e1, _, cases) ->
|
||||
acc |> f e1 |> EnumConstructorMap.fold (fun _ -> f) cases
|
||||
|
||||
(* - *)
|
||||
|
||||
(** See [Bindlib.box_term] documentation for why we are doing that. *)
|
||||
@ -661,45 +689,12 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
|
||||
| ERaise _, _ -> -1 | _, ERaise _ -> 1
|
||||
| ECatch _, _ -> . | _, ECatch _ -> .
|
||||
|
||||
let rec free_vars : type a. (a, 't) gexpr -> (a, 't) gexpr Var.Set.t =
|
||||
fun e ->
|
||||
match Marked.unmark e with
|
||||
| EOp _ | ELit _ | ERaise _ -> Var.Set.empty
|
||||
| EVar v -> Var.Set.singleton v
|
||||
| ETuple (es, _) ->
|
||||
es |> List.map free_vars |> List.fold_left Var.Set.union Var.Set.empty
|
||||
| EArray es ->
|
||||
es |> List.map free_vars |> List.fold_left Var.Set.union Var.Set.empty
|
||||
| ETupleAccess (e1, _, _, _) -> free_vars e1
|
||||
| EAssert e1 -> free_vars e1
|
||||
| EInj (e1, _, _, _) -> free_vars e1
|
||||
| ErrorOnEmpty e1 -> free_vars e1
|
||||
| ECatch (etry, _, ewith) -> Var.Set.union (free_vars etry) (free_vars ewith)
|
||||
| EApp (e1, es) ->
|
||||
e1 :: es |> List.map free_vars |> List.fold_left Var.Set.union Var.Set.empty
|
||||
| EMatch (e1, es, _) ->
|
||||
e1 :: es |> List.map free_vars |> List.fold_left Var.Set.union Var.Set.empty
|
||||
| EDefault (es, ejust, econs) ->
|
||||
ejust :: econs :: es
|
||||
|> List.map free_vars
|
||||
|> List.fold_left Var.Set.union Var.Set.empty
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
[e1; e2; e3]
|
||||
|> List.map free_vars
|
||||
|> List.fold_left Var.Set.union Var.Set.empty
|
||||
| EAbs (binder, _) ->
|
||||
let rec free_vars : type a. (a, 't) gexpr -> (a, 't) gexpr Var.Set.t = function
|
||||
| EVar v, _ -> Var.Set.singleton v
|
||||
| EAbs (binder, _), _ ->
|
||||
let vs, body = Bindlib.unmbind binder in
|
||||
Array.fold_right Var.Set.remove vs (free_vars body)
|
||||
| ELocation _ -> Var.Set.empty
|
||||
| EStruct (_, fields) ->
|
||||
StructFieldMap.fold
|
||||
(fun _ e -> Var.Set.union (free_vars e))
|
||||
fields Var.Set.empty
|
||||
| EStructAccess (e1, _, _) -> free_vars e1
|
||||
| EEnumInj (e1, _, _) -> free_vars e1
|
||||
| EMatchS (e1, _, cases) ->
|
||||
free_vars e1
|
||||
|> EnumConstructorMap.fold (fun _ e -> Var.Set.union (free_vars e)) cases
|
||||
| e -> shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty
|
||||
|
||||
let remove_logging_calls e =
|
||||
let rec f () e =
|
||||
|
@ -221,6 +221,12 @@ val map_top_down :
|
||||
|
||||
val map_marks : f:('t1 -> 't2) -> ('a, 't1) gexpr -> ('a, 't2) boxed_gexpr
|
||||
|
||||
val shallow_fold :
|
||||
(('a, 't) gexpr -> 'acc -> 'acc) -> ('a, 't) gexpr -> 'acc -> 'acc
|
||||
(** Applies a function on all sub-terms of the given expression. Does not
|
||||
recurse, and doesn't open binders. Useful as helper for recursive calls
|
||||
within traversal functions *)
|
||||
|
||||
(** {2 Expression building helpers} *)
|
||||
|
||||
val make_var : ('a, 't) gexpr Var.t -> 't -> ('a, 't) boxed_gexpr
|
||||
|
Loading…
Reference in New Issue
Block a user