Add a helper to fold on expressions

This commit is contained in:
Louis Gesbert 2022-10-10 15:15:36 +02:00
parent f103fb1ea5
commit f9f834e30a
5 changed files with 56 additions and 101 deletions

View File

@ -208,40 +208,16 @@ type scope = {
type program = { program_scopes : scope ScopeMap.t; program_ctx : decl_ctx }
let rec locations_used (e : expr) : LocationSet.t =
match Marked.unmark e with
| ELocation l -> LocationSet.singleton (l, Expr.pos e)
| EVar _ | ELit _ | EOp _ -> LocationSet.empty
| EAbs (binder, _) ->
let rec locations_used e : LocationSet.t =
match e with
| ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m)
| EAbs (binder, _), _ ->
let _, body = Bindlib.unmbind binder in
locations_used body
| EStruct (_, es) ->
StructFieldMap.fold
(fun _ e' acc -> LocationSet.union acc (locations_used e'))
es LocationSet.empty
| EStructAccess (e1, _, _) -> locations_used e1
| EEnumInj (e1, _, _) -> locations_used e1
| EMatchS (e1, _, es) ->
EnumConstructorMap.fold
(fun _ e' acc -> LocationSet.union acc (locations_used e'))
es (locations_used e1)
| EApp (e1, args) ->
List.fold_left
(fun acc arg -> LocationSet.union (locations_used arg) acc)
(locations_used e1) args
| EIfThenElse (e1, e2, e3) ->
LocationSet.union (locations_used e1)
(LocationSet.union (locations_used e2) (locations_used e3))
| EDefault (excepts, just, cons) ->
List.fold_left
(fun acc except -> LocationSet.union (locations_used except) acc)
(LocationSet.union (locations_used just) (locations_used cons))
excepts
| EArray es ->
List.fold_left
(fun acc e' -> LocationSet.union acc (locations_used e'))
LocationSet.empty es
| ErrorOnEmpty e' -> locations_used e'
| e ->
Expr.shallow_fold
(fun e -> LocationSet.union (locations_used e))
e LocationSet.empty
let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t =
let add_locs (acc : Pos.t ScopeDefMap.t) (locs : LocationSet.t) :

View File

@ -29,39 +29,15 @@ end)
type 'm expr = (scopelang, 'm mark) gexpr
let rec locations_used (e : 'm expr) : LocationSet.t =
match Marked.unmark e with
| ELocation l -> LocationSet.singleton (l, Expr.pos e)
| EVar _ | ELit _ | EOp _ -> LocationSet.empty
| EAbs (binder, _) ->
match e with
| ELocation l, pos -> LocationSet.singleton (l, Expr.mark_pos pos)
| EAbs (binder, _), _ ->
let _, body = Bindlib.unmbind binder in
locations_used body
| EStruct (_, es) ->
StructFieldMap.fold
(fun _ e' acc -> LocationSet.union acc (locations_used e'))
es LocationSet.empty
| EStructAccess (e1, _, _) -> locations_used e1
| EEnumInj (e1, _, _) -> locations_used e1
| EMatchS (e1, _, es) ->
EnumConstructorMap.fold
(fun _ e' acc -> LocationSet.union acc (locations_used e'))
es (locations_used e1)
| EApp (e1, args) ->
List.fold_left
(fun acc arg -> LocationSet.union (locations_used arg) acc)
(locations_used e1) args
| EIfThenElse (e1, e2, e3) ->
LocationSet.union (locations_used e1)
(LocationSet.union (locations_used e2) (locations_used e3))
| EDefault (excepts, just, cons) ->
List.fold_left
(fun acc except -> LocationSet.union (locations_used except) acc)
(LocationSet.union (locations_used just) (locations_used cons))
excepts
| EArray es ->
List.fold_left
(fun acc e' -> LocationSet.union acc (locations_used e'))
LocationSet.empty es
| ErrorOnEmpty e' -> locations_used e'
| e ->
Expr.shallow_fold
(fun e -> LocationSet.union (locations_used e))
e LocationSet.empty
type io_input = NoInput | OnlyInput | Reentrant
type io = { io_output : bool Marked.pos; io_input : io_input Marked.pos }

View File

@ -1,7 +1,9 @@
(library
(name scopelang)
(public_name catala.scopelang)
(libraries utils dcalc ocamlgraph))
(libraries utils dcalc ocamlgraph)
(flags
(:standard -short-paths)))
(documentation
(package catala)

View File

@ -231,6 +231,34 @@ let rec map_top_down ~f e = map () ~f:(fun () -> map_top_down ~f) (f e)
let map_marks ~f e =
map_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
(* Folds the given function on the direct children of the given expression. Does
not open binders. *)
let shallow_fold
(type a)
(f : (a, 'm) gexpr -> 'acc -> 'acc)
(e : (a, 'm) gexpr)
(acc : 'acc) : 'acc =
let lfold x acc = List.fold_left (fun acc x -> f x acc) acc x in
match Marked.unmark e with
| ELit _ | EOp _ | EVar _ | ERaise _ | ELocation _ -> acc
| EApp (e1, args) -> acc |> f e1 |> lfold args
| EArray args -> acc |> lfold args
| EAbs _ -> acc
| EIfThenElse (e1, e2, e3) -> acc |> f e1 |> f e2 |> f e3
| ETuple (args, _) -> acc |> lfold args
| ETupleAccess (e1, _, _, _) -> acc |> f e1
| EInj (e1, _, _, _) -> acc |> f e1
| EMatch (arg, arms, _) -> acc |> f arg |> lfold arms
| EAssert e1 -> acc |> f e1
| EDefault (excepts, just, cons) -> acc |> lfold excepts |> f just |> f cons
| ErrorOnEmpty e1 -> acc |> f e1
| ECatch (e1, _, e2) -> acc |> f e1 |> f e2
| EStruct (_, fields) -> acc |> StructFieldMap.fold (fun _ -> f) fields
| EStructAccess (e1, _, _) -> acc |> f e1
| EEnumInj (e1, _, _) -> acc |> f e1
| EMatchS (e1, _, cases) ->
acc |> f e1 |> EnumConstructorMap.fold (fun _ -> f) cases
(* - *)
(** See [Bindlib.box_term] documentation for why we are doing that. *)
@ -661,45 +689,12 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
| ERaise _, _ -> -1 | _, ERaise _ -> 1
| ECatch _, _ -> . | _, ECatch _ -> .
let rec free_vars : type a. (a, 't) gexpr -> (a, 't) gexpr Var.Set.t =
fun e ->
match Marked.unmark e with
| EOp _ | ELit _ | ERaise _ -> Var.Set.empty
| EVar v -> Var.Set.singleton v
| ETuple (es, _) ->
es |> List.map free_vars |> List.fold_left Var.Set.union Var.Set.empty
| EArray es ->
es |> List.map free_vars |> List.fold_left Var.Set.union Var.Set.empty
| ETupleAccess (e1, _, _, _) -> free_vars e1
| EAssert e1 -> free_vars e1
| EInj (e1, _, _, _) -> free_vars e1
| ErrorOnEmpty e1 -> free_vars e1
| ECatch (etry, _, ewith) -> Var.Set.union (free_vars etry) (free_vars ewith)
| EApp (e1, es) ->
e1 :: es |> List.map free_vars |> List.fold_left Var.Set.union Var.Set.empty
| EMatch (e1, es, _) ->
e1 :: es |> List.map free_vars |> List.fold_left Var.Set.union Var.Set.empty
| EDefault (es, ejust, econs) ->
ejust :: econs :: es
|> List.map free_vars
|> List.fold_left Var.Set.union Var.Set.empty
| EIfThenElse (e1, e2, e3) ->
[e1; e2; e3]
|> List.map free_vars
|> List.fold_left Var.Set.union Var.Set.empty
| EAbs (binder, _) ->
let rec free_vars : type a. (a, 't) gexpr -> (a, 't) gexpr Var.Set.t = function
| EVar v, _ -> Var.Set.singleton v
| EAbs (binder, _), _ ->
let vs, body = Bindlib.unmbind binder in
Array.fold_right Var.Set.remove vs (free_vars body)
| ELocation _ -> Var.Set.empty
| EStruct (_, fields) ->
StructFieldMap.fold
(fun _ e -> Var.Set.union (free_vars e))
fields Var.Set.empty
| EStructAccess (e1, _, _) -> free_vars e1
| EEnumInj (e1, _, _) -> free_vars e1
| EMatchS (e1, _, cases) ->
free_vars e1
|> EnumConstructorMap.fold (fun _ e -> Var.Set.union (free_vars e)) cases
| e -> shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty
let remove_logging_calls e =
let rec f () e =

View File

@ -221,6 +221,12 @@ val map_top_down :
val map_marks : f:('t1 -> 't2) -> ('a, 't1) gexpr -> ('a, 't2) boxed_gexpr
val shallow_fold :
(('a, 't) gexpr -> 'acc -> 'acc) -> ('a, 't) gexpr -> 'acc -> 'acc
(** Applies a function on all sub-terms of the given expression. Does not
recurse, and doesn't open binders. Useful as helper for recursive calls
within traversal functions *)
(** {2 Expression building helpers} *)
val make_var : ('a, 't) gexpr Var.t -> 't -> ('a, 't) boxed_gexpr