diff --git a/compiler/dcalc/optimizations.ml b/compiler/dcalc/optimizations.ml index 649ab806..f082c43d 100644 --- a/compiler/dcalc/optimizations.ml +++ b/compiler/dcalc/optimizations.ml @@ -253,4 +253,4 @@ let optimize_program (p : 'm program) : untyped program = (program_map partial_evaluation { var_values = Var.Map.empty; decl_ctx = p.decl_ctx } p) - |> Expr.untype_program + |> Program.untype diff --git a/compiler/driver.ml b/compiler/driver.ml index 685d34cc..6002ca81 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -200,7 +200,7 @@ let driver source_file (options : Cli.options) : int = (Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) ( scope_uid, Option.get - (Shared_ast.Expr.fold_left_scope_defs ~init:None + (Shared_ast.Scope.fold_left ~init:None ~f:(fun acc scope_def _ -> if Shared_ast.ScopeName.compare scope_def.scope_name @@ -285,7 +285,7 @@ let driver source_file (options : Cli.options) : int = Cli.debug_print "Optimizing lambda calculus..."; Lcalc.Optimizations.optimize_program prgm end - else Shared_ast.Expr.untype_program prgm + else Shared_ast.Program.untype prgm in let prgm = if options.closure_conversion then ( @@ -305,7 +305,7 @@ let driver source_file (options : Cli.options) : int = (Lcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) ( scope_uid, Option.get - (Shared_ast.Expr.fold_left_scope_defs ~init:None + (Shared_ast.Scope.fold_left ~init:None ~f:(fun acc scope_def _ -> if Shared_ast.ScopeName.compare scope_def.scope_name diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 95ef8c1d..46eeacb5 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -275,7 +275,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) : let closure_conversion (p : 'm program) : 'm program Bindlib.box = let new_scopes, _ = - Expr.fold_left_scope_defs + Scope.fold_left ~f:(fun (acc_new_scopes, global_vars) scope scope_var -> (* [acc_new_scopes] represents what has been translated in the past, it needs a continuation to attach the rest of the translated scopes. *) @@ -290,7 +290,7 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box = } in let new_scope_lets = - Expr.map_exprs_in_scope_lets + Scope.map_exprs_in_lets ~f:(closure_conversion_expr ctx) ~varf:(fun v -> v) scope_body_expr diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index 3e848954..5e738bb8 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -546,7 +546,7 @@ let rec translate_scopes (ctx : 'm ctx) (scopes : 'm D.expr scopes) : let translate_program (prgm : 'm D.program) : 'm A.program = let inputs_structs = - Expr.fold_left_scope_defs prgm.scopes ~init:[] ~f:(fun acc scope_def _ -> + Scope.fold_left prgm.scopes ~init:[] ~f:(fun acc scope_def _ -> scope_def.scope_body.scope_body_input_struct :: acc) in diff --git a/compiler/lcalc/optimizations.ml b/compiler/lcalc/optimizations.ml index cbca7967..25526b48 100644 --- a/compiler/lcalc/optimizations.ml +++ b/compiler/lcalc/optimizations.ml @@ -102,7 +102,7 @@ let rec beta_expr (_ : unit) (e : 'm marked_expr) : 'm marked_expr Bindlib.box = let iota_optimizations (p : 'm program) : 'm program = let new_scopes = - Expr.map_exprs_in_scopes ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes + Scope.map_exprs ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes in { p with scopes = Bindlib.unbox new_scopes } @@ -112,7 +112,7 @@ let iota_optimizations (p : 'm program) : 'm program = program. *) let _beta_optimizations (p : 'm program) : 'm program = let new_scopes = - Expr.map_exprs_in_scopes ~f:(beta_expr ()) ~varf:(fun v -> v) p.scopes + Scope.map_exprs ~f:(beta_expr ()) ~varf:(fun v -> v) p.scopes in { p with scopes = Bindlib.unbox new_scopes } @@ -146,9 +146,9 @@ let rec peephole_expr (_ : unit) (e : 'm marked_expr) : let peephole_optimizations (p : 'm program) : 'm program = let new_scopes = - Expr.map_exprs_in_scopes ~f:(peephole_expr ()) ~varf:(fun v -> v) p.scopes + Scope.map_exprs ~f:(peephole_expr ()) ~varf:(fun v -> v) p.scopes in { p with scopes = Bindlib.unbox new_scopes } let optimize_program (p : 'm program) : untyped program = - p |> iota_optimizations |> peephole_optimizations |> Expr.untype_program + p |> iota_optimizations |> peephole_optimizations |> Program.untype diff --git a/compiler/scalc/compile_from_lambda.ml b/compiler/scalc/compile_from_lambda.ml index 2342244d..60fefc0d 100644 --- a/compiler/scalc/compile_from_lambda.ml +++ b/compiler/scalc/compile_from_lambda.ml @@ -335,7 +335,7 @@ let translate_program (p : 'm L.program) : A.program = decl_ctx = p.decl_ctx; scopes = (let _, new_scopes = - Expr.fold_left_scope_defs + Scope.fold_left ~f:(fun (func_dict, new_scopes) scope_def scope_var -> let scope_input_var, scope_body_expr = Bindlib.unbind scope_def.scope_body.scope_body_expr diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index 50ca7531..ce6ceb3a 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -166,80 +166,6 @@ let rec map_top_down ~f e = map () ~f:(fun () -> map_top_down ~f) (f e) let map_marks ~f e = map_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e -let rec fold_left_scope_lets ~f ~init scope_body_expr = - match scope_body_expr with - | Result _ -> init - | ScopeLet scope_let -> - let var, next = Bindlib.unbind scope_let.scope_let_next in - fold_left_scope_lets ~f ~init:(f init scope_let var) next - -let rec fold_right_scope_lets ~f ~init scope_body_expr = - match scope_body_expr with - | Result result -> init result - | ScopeLet scope_let -> - let var, next = Bindlib.unbind scope_let.scope_let_next in - let next_result = fold_right_scope_lets ~f ~init next in - f scope_let var next_result - -let map_exprs_in_scope_lets ~f ~varf scope_body_expr = - fold_right_scope_lets - ~f:(fun scope_let var_next acc -> - Bindlib.box_apply2 - (fun scope_let_next scope_let_expr -> - ScopeLet { scope_let with scope_let_next; scope_let_expr }) - (Bindlib.bind_var (varf var_next) acc) - (f scope_let.scope_let_expr)) - ~init:(fun res -> Bindlib.box_apply (fun res -> Result res) (f res)) - scope_body_expr - -let rec fold_left_scope_defs ~f ~init scopes = - match scopes with - | Nil -> init - | ScopeDef scope_def -> - let var, next = Bindlib.unbind scope_def.scope_next in - fold_left_scope_defs ~f ~init:(f init scope_def var) next - -let rec fold_right_scope_defs ~f ~init scopes = - match scopes with - | Nil -> init - | ScopeDef scope_def -> - let var_next, next = Bindlib.unbind scope_def.scope_next in - let result_next = fold_right_scope_defs ~f ~init next in - f scope_def var_next result_next - -let map_scope_defs ~f scopes = - fold_right_scope_defs - ~f:(fun scope_def var_next acc -> - let new_scope_def = f scope_def in - let new_next = Bindlib.bind_var var_next acc in - Bindlib.box_apply2 - (fun new_scope_def new_next -> - ScopeDef { new_scope_def with scope_next = new_next }) - new_scope_def new_next) - ~init:(Bindlib.box Nil) scopes - -let map_exprs_in_scopes ~f ~varf scopes = - fold_right_scope_defs - ~f:(fun scope_def var_next acc -> - let scope_input_var, scope_lets = - Bindlib.unbind scope_def.scope_body.scope_body_expr - in - let new_scope_body_expr = map_exprs_in_scope_lets ~f ~varf scope_lets in - let new_scope_body_expr = - Bindlib.bind_var (varf scope_input_var) new_scope_body_expr - in - let new_next = Bindlib.bind_var (varf var_next) acc in - Bindlib.box_apply2 - (fun scope_body_expr scope_next -> - ScopeDef - { - scope_def with - scope_body = { scope_def.scope_body with scope_body_expr }; - scope_next; - }) - new_scope_body_expr new_next) - ~init:(Bindlib.box Nil) scopes - (* - *) (** See [Bindlib.box_term] documentation for why we are doing that. *) @@ -248,12 +174,3 @@ let box e = id_t () e let untype e = map_marks ~f:(fun m -> Untyped { pos = mark_pos m }) e - -let untype_program (prg : ('a, 'm mark) gexpr program) : - ('a, untyped mark) gexpr program = - { - prg with - scopes = - Bindlib.unbox - (map_exprs_in_scopes ~f:untype ~varf:Var.translate prg.scopes); - } diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index d4678b2b..c568cc25 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -15,7 +15,7 @@ License for the specific language governing permissions and limitations under the License. *) -(** Functions handling the types of [shared_ast] *) +(** Functions handling the expressions of [shared_ast] *) open Utils open Types @@ -134,9 +134,6 @@ val get_scope_body_mark : (_, 'm mark) gexpr scope_body -> 'm mark val untype : ('a, 'm mark) marked_gexpr -> ('a, untyped mark) marked_gexpr Bindlib.box -val untype_program : - (([< any ] as 'a), 'm mark) gexpr program -> ('a, untyped mark) gexpr program - (** {2 Handling of boxing} *) val box : ('a, 't) marked_gexpr -> ('a, 't) marked_gexpr Bindlib.box @@ -177,62 +174,3 @@ val map_top_down : val map_marks : f:('t1 -> 't2) -> ('a, 't1) marked_gexpr -> ('a, 't2) marked_gexpr Bindlib.box - -val fold_left_scope_lets : - f:('a -> 'e scope_let -> 'e Bindlib.var -> 'a) -> - init:'a -> - 'e scope_body_expr -> - 'a -(** Usage: - [fold_left_scope_lets ~f:(fun acc scope_let scope_let_var -> ...) ~init scope_lets], - where [scope_let_var] is the variable bound to the scope let in the next - scope lets to be examined. *) - -val fold_right_scope_lets : - f:('expr1 scope_let -> 'expr1 Bindlib.var -> 'a -> 'a) -> - init:('expr1 marked -> 'a) -> - 'expr1 scope_body_expr -> - 'a -(** Usage: - [fold_right_scope_lets ~f:(fun scope_let scope_let_var acc -> ...) ~init scope_lets], - where [scope_let_var] is the variable bound to the scope let in the next - scope lets to be examined (which are before in the program order). *) - -val map_exprs_in_scope_lets : - f:('expr1 marked -> 'expr2 marked Bindlib.box) -> - varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) -> - 'expr1 scope_body_expr -> - 'expr2 scope_body_expr Bindlib.box - -val fold_left_scope_defs : - f:('a -> 'expr1 scope_def -> 'expr1 Bindlib.var -> 'a) -> - init:'a -> - 'expr1 scopes -> - 'a -(** Usage: - [fold_left_scope_defs ~f:(fun acc scope_def scope_var -> ...) ~init scope_def], - where [scope_var] is the variable bound to the scope in the next scopes to - be examined. *) - -val fold_right_scope_defs : - f:('expr1 scope_def -> 'expr1 Bindlib.var -> 'a -> 'a) -> - init:'a -> - 'expr1 scopes -> - 'a -(** Usage: - [fold_right_scope_defs ~f:(fun scope_def scope_var acc -> ...) ~init scope_def], - where [scope_var] is the variable bound to the scope in the next scopes to - be examined (which are before in the program order). *) - -val map_scope_defs : - f:('e scope_def -> 'e scope_def Bindlib.box) -> - 'e scopes -> - 'e scopes Bindlib.box - -val map_exprs_in_scopes : - f:('expr1 marked -> 'expr2 marked Bindlib.box) -> - varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) -> - 'expr1 scopes -> - 'expr2 scopes Bindlib.box -(** This is the main map visitor for all the expressions inside all the scopes - of the program. *) diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml new file mode 100644 index 00000000..76ab52f1 --- /dev/null +++ b/compiler/shared_ast/program.ml @@ -0,0 +1,27 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2020-2022 Inria, + contributor: Denis Merigoux , Alain Delaët-Tixeuil + , Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +open Types + +let untype (prg : ('a, 'm mark) gexpr program) : + ('a, untyped mark) gexpr program = + { + prg with + scopes = + Bindlib.unbox + (Scope.map_exprs ~f:Expr.untype ~varf:Var.translate prg.scopes); + } diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli new file mode 100644 index 00000000..2c8863f0 --- /dev/null +++ b/compiler/shared_ast/program.mli @@ -0,0 +1,21 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2020-2022 Inria, + contributor: Denis Merigoux , Alain Delaët-Tixeuil + , Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +open Types + +val untype : + (([< any ] as 'a), 'm mark) gexpr program -> ('a, untyped mark) gexpr program diff --git a/compiler/shared_ast/scope.ml b/compiler/shared_ast/scope.ml new file mode 100644 index 00000000..d83b2d21 --- /dev/null +++ b/compiler/shared_ast/scope.ml @@ -0,0 +1,92 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2020-2022 Inria, + contributor: Denis Merigoux , Alain Delaët-Tixeuil + , Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +open Types + +let rec fold_left_lets ~f ~init scope_body_expr = + match scope_body_expr with + | Result _ -> init + | ScopeLet scope_let -> + let var, next = Bindlib.unbind scope_let.scope_let_next in + fold_left_lets ~f ~init:(f init scope_let var) next + +let rec fold_right_lets ~f ~init scope_body_expr = + match scope_body_expr with + | Result result -> init result + | ScopeLet scope_let -> + let var, next = Bindlib.unbind scope_let.scope_let_next in + let next_result = fold_right_lets ~f ~init next in + f scope_let var next_result + +let map_exprs_in_lets ~f ~varf scope_body_expr = + fold_right_lets + ~f:(fun scope_let var_next acc -> + Bindlib.box_apply2 + (fun scope_let_next scope_let_expr -> + ScopeLet { scope_let with scope_let_next; scope_let_expr }) + (Bindlib.bind_var (varf var_next) acc) + (f scope_let.scope_let_expr)) + ~init:(fun res -> Bindlib.box_apply (fun res -> Result res) (f res)) + scope_body_expr + +let rec fold_left ~f ~init scopes = + match scopes with + | Nil -> init + | ScopeDef scope_def -> + let var, next = Bindlib.unbind scope_def.scope_next in + fold_left ~f ~init:(f init scope_def var) next + +let rec fold_right ~f ~init scopes = + match scopes with + | Nil -> init + | ScopeDef scope_def -> + let var_next, next = Bindlib.unbind scope_def.scope_next in + let result_next = fold_right ~f ~init next in + f scope_def var_next result_next + +let map ~f scopes = + fold_right + ~f:(fun scope_def var_next acc -> + let new_def = f scope_def in + let new_next = Bindlib.bind_var var_next acc in + Bindlib.box_apply2 + (fun new_def new_next -> + ScopeDef { new_def with scope_next = new_next }) + new_def new_next) + ~init:(Bindlib.box Nil) scopes + +let map_exprs ~f ~varf scopes = + fold_right + ~f:(fun scope_def var_next acc -> + let scope_input_var, scope_lets = + Bindlib.unbind scope_def.scope_body.scope_body_expr + in + let new_body_expr = map_exprs_in_lets ~f ~varf scope_lets in + let new_body_expr = + Bindlib.bind_var (varf scope_input_var) new_body_expr + in + let new_next = Bindlib.bind_var (varf var_next) acc in + Bindlib.box_apply2 + (fun scope_body_expr scope_next -> + ScopeDef + { + scope_def with + scope_body = { scope_def.scope_body with scope_body_expr }; + scope_next; + }) + new_body_expr new_next) + ~init:(Bindlib.box Nil) scopes diff --git a/compiler/shared_ast/scope.mli b/compiler/shared_ast/scope.mli new file mode 100644 index 00000000..a5fbee91 --- /dev/null +++ b/compiler/shared_ast/scope.mli @@ -0,0 +1,82 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2020-2022 Inria, + contributor: Denis Merigoux , Alain Delaët-Tixeuil + , Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +(** Functions handling the scope structures of [shared_ast] *) + +open Types + +(** {2 Traversal functions} *) + +val fold_left_lets : + f:('a -> 'e scope_let -> 'e Bindlib.var -> 'a) -> + init:'a -> + 'e scope_body_expr -> + 'a +(** Usage: + [fold_left_lets ~f:(fun acc scope_let scope_let_var -> ...) ~init scope_lets], + where [scope_let_var] is the variable bound to the scope let in the next + scope lets to be examined. *) + +val fold_right_lets : + f:('expr1 scope_let -> 'expr1 Bindlib.var -> 'a -> 'a) -> + init:('expr1 marked -> 'a) -> + 'expr1 scope_body_expr -> + 'a +(** Usage: + [fold_right_lets ~f:(fun scope_let scope_let_var acc -> ...) ~init scope_lets], + where [scope_let_var] is the variable bound to the scope let in the next + scope lets to be examined (which are before in the program order). *) + +val map_exprs_in_lets : + f:('expr1 marked -> 'expr2 marked Bindlib.box) -> + varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) -> + 'expr1 scope_body_expr -> + 'expr2 scope_body_expr Bindlib.box + +val fold_left : + f:('a -> 'expr1 scope_def -> 'expr1 Bindlib.var -> 'a) -> + init:'a -> + 'expr1 scopes -> + 'a +(** Usage: [fold_left ~f:(fun acc scope_def scope_var -> ...) ~init scope_def], + where [scope_var] is the variable bound to the scope in the next scopes to + be examined. *) + +val fold_right : + f:('expr1 scope_def -> 'expr1 Bindlib.var -> 'a -> 'a) -> + init:'a -> + 'expr1 scopes -> + 'a +(** Usage: + [fold_right_scope ~f:(fun scope_def scope_var acc -> ...) ~init scope_def], + where [scope_var] is the variable bound to the scope in the next scopes to + be examined (which are before in the program order). *) + +val map : + f:('e scope_def -> 'e scope_def Bindlib.box) -> + 'e scopes -> + 'e scopes Bindlib.box + +val map_exprs : + f:('expr1 marked -> 'expr2 marked Bindlib.box) -> + varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) -> + 'expr1 scopes -> + 'expr2 scopes Bindlib.box +(** This is the main map visitor for all the expressions inside all the scopes + of the program. *) + +(** {2 Other helpers} *) diff --git a/compiler/shared_ast/shared_ast.ml b/compiler/shared_ast/shared_ast.ml index c1ed9a8b..487f649b 100644 --- a/compiler/shared_ast/shared_ast.ml +++ b/compiler/shared_ast/shared_ast.ml @@ -15,5 +15,7 @@ the License. *) include Types -module Expr = Expr module Var = Var +module Expr = Expr +module Scope = Scope +module Program = Program