Split Shared_ast.Expr of scope and program functions

This commit is contained in:
Louis Gesbert 2022-08-16 17:09:26 +02:00
parent 4bb49c14f1
commit 8e7f65d204
13 changed files with 238 additions and 159 deletions

View File

@ -253,4 +253,4 @@ let optimize_program (p : 'm program) : untyped program =
(program_map partial_evaluation (program_map partial_evaluation
{ var_values = Var.Map.empty; decl_ctx = p.decl_ctx } { var_values = Var.Map.empty; decl_ctx = p.decl_ctx }
p) p)
|> Expr.untype_program |> Program.untype

View File

@ -200,7 +200,7 @@ let driver source_file (options : Cli.options) : int =
(Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) (Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx)
( scope_uid, ( scope_uid,
Option.get Option.get
(Shared_ast.Expr.fold_left_scope_defs ~init:None (Shared_ast.Scope.fold_left ~init:None
~f:(fun acc scope_def _ -> ~f:(fun acc scope_def _ ->
if if
Shared_ast.ScopeName.compare scope_def.scope_name Shared_ast.ScopeName.compare scope_def.scope_name
@ -285,7 +285,7 @@ let driver source_file (options : Cli.options) : int =
Cli.debug_print "Optimizing lambda calculus..."; Cli.debug_print "Optimizing lambda calculus...";
Lcalc.Optimizations.optimize_program prgm Lcalc.Optimizations.optimize_program prgm
end end
else Shared_ast.Expr.untype_program prgm else Shared_ast.Program.untype prgm
in in
let prgm = let prgm =
if options.closure_conversion then ( if options.closure_conversion then (
@ -305,7 +305,7 @@ let driver source_file (options : Cli.options) : int =
(Lcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) (Lcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx)
( scope_uid, ( scope_uid,
Option.get Option.get
(Shared_ast.Expr.fold_left_scope_defs ~init:None (Shared_ast.Scope.fold_left ~init:None
~f:(fun acc scope_def _ -> ~f:(fun acc scope_def _ ->
if if
Shared_ast.ScopeName.compare scope_def.scope_name Shared_ast.ScopeName.compare scope_def.scope_name

View File

@ -275,7 +275,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) :
let closure_conversion (p : 'm program) : 'm program Bindlib.box = let closure_conversion (p : 'm program) : 'm program Bindlib.box =
let new_scopes, _ = let new_scopes, _ =
Expr.fold_left_scope_defs Scope.fold_left
~f:(fun (acc_new_scopes, global_vars) scope scope_var -> ~f:(fun (acc_new_scopes, global_vars) scope scope_var ->
(* [acc_new_scopes] represents what has been translated in the past, it (* [acc_new_scopes] represents what has been translated in the past, it
needs a continuation to attach the rest of the translated scopes. *) needs a continuation to attach the rest of the translated scopes. *)
@ -290,7 +290,7 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
} }
in in
let new_scope_lets = let new_scope_lets =
Expr.map_exprs_in_scope_lets Scope.map_exprs_in_lets
~f:(closure_conversion_expr ctx) ~f:(closure_conversion_expr ctx)
~varf:(fun v -> v) ~varf:(fun v -> v)
scope_body_expr scope_body_expr

View File

@ -546,7 +546,7 @@ let rec translate_scopes (ctx : 'm ctx) (scopes : 'm D.expr scopes) :
let translate_program (prgm : 'm D.program) : 'm A.program = let translate_program (prgm : 'm D.program) : 'm A.program =
let inputs_structs = let inputs_structs =
Expr.fold_left_scope_defs prgm.scopes ~init:[] ~f:(fun acc scope_def _ -> Scope.fold_left prgm.scopes ~init:[] ~f:(fun acc scope_def _ ->
scope_def.scope_body.scope_body_input_struct :: acc) scope_def.scope_body.scope_body_input_struct :: acc)
in in

View File

@ -102,7 +102,7 @@ let rec beta_expr (_ : unit) (e : 'm marked_expr) : 'm marked_expr Bindlib.box =
let iota_optimizations (p : 'm program) : 'm program = let iota_optimizations (p : 'm program) : 'm program =
let new_scopes = let new_scopes =
Expr.map_exprs_in_scopes ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes Scope.map_exprs ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes
in in
{ p with scopes = Bindlib.unbox new_scopes } { p with scopes = Bindlib.unbox new_scopes }
@ -112,7 +112,7 @@ let iota_optimizations (p : 'm program) : 'm program =
program. *) program. *)
let _beta_optimizations (p : 'm program) : 'm program = let _beta_optimizations (p : 'm program) : 'm program =
let new_scopes = let new_scopes =
Expr.map_exprs_in_scopes ~f:(beta_expr ()) ~varf:(fun v -> v) p.scopes Scope.map_exprs ~f:(beta_expr ()) ~varf:(fun v -> v) p.scopes
in in
{ p with scopes = Bindlib.unbox new_scopes } { p with scopes = Bindlib.unbox new_scopes }
@ -146,9 +146,9 @@ let rec peephole_expr (_ : unit) (e : 'm marked_expr) :
let peephole_optimizations (p : 'm program) : 'm program = let peephole_optimizations (p : 'm program) : 'm program =
let new_scopes = let new_scopes =
Expr.map_exprs_in_scopes ~f:(peephole_expr ()) ~varf:(fun v -> v) p.scopes Scope.map_exprs ~f:(peephole_expr ()) ~varf:(fun v -> v) p.scopes
in in
{ p with scopes = Bindlib.unbox new_scopes } { p with scopes = Bindlib.unbox new_scopes }
let optimize_program (p : 'm program) : untyped program = let optimize_program (p : 'm program) : untyped program =
p |> iota_optimizations |> peephole_optimizations |> Expr.untype_program p |> iota_optimizations |> peephole_optimizations |> Program.untype

View File

@ -335,7 +335,7 @@ let translate_program (p : 'm L.program) : A.program =
decl_ctx = p.decl_ctx; decl_ctx = p.decl_ctx;
scopes = scopes =
(let _, new_scopes = (let _, new_scopes =
Expr.fold_left_scope_defs Scope.fold_left
~f:(fun (func_dict, new_scopes) scope_def scope_var -> ~f:(fun (func_dict, new_scopes) scope_def scope_var ->
let scope_input_var, scope_body_expr = let scope_input_var, scope_body_expr =
Bindlib.unbind scope_def.scope_body.scope_body_expr Bindlib.unbind scope_def.scope_body.scope_body_expr

View File

@ -166,80 +166,6 @@ let rec map_top_down ~f e = map () ~f:(fun () -> map_top_down ~f) (f e)
let map_marks ~f e = let map_marks ~f e =
map_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e map_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
let rec fold_left_scope_lets ~f ~init scope_body_expr =
match scope_body_expr with
| Result _ -> init
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
fold_left_scope_lets ~f ~init:(f init scope_let var) next
let rec fold_right_scope_lets ~f ~init scope_body_expr =
match scope_body_expr with
| Result result -> init result
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
let next_result = fold_right_scope_lets ~f ~init next in
f scope_let var next_result
let map_exprs_in_scope_lets ~f ~varf scope_body_expr =
fold_right_scope_lets
~f:(fun scope_let var_next acc ->
Bindlib.box_apply2
(fun scope_let_next scope_let_expr ->
ScopeLet { scope_let with scope_let_next; scope_let_expr })
(Bindlib.bind_var (varf var_next) acc)
(f scope_let.scope_let_expr))
~init:(fun res -> Bindlib.box_apply (fun res -> Result res) (f res))
scope_body_expr
let rec fold_left_scope_defs ~f ~init scopes =
match scopes with
| Nil -> init
| ScopeDef scope_def ->
let var, next = Bindlib.unbind scope_def.scope_next in
fold_left_scope_defs ~f ~init:(f init scope_def var) next
let rec fold_right_scope_defs ~f ~init scopes =
match scopes with
| Nil -> init
| ScopeDef scope_def ->
let var_next, next = Bindlib.unbind scope_def.scope_next in
let result_next = fold_right_scope_defs ~f ~init next in
f scope_def var_next result_next
let map_scope_defs ~f scopes =
fold_right_scope_defs
~f:(fun scope_def var_next acc ->
let new_scope_def = f scope_def in
let new_next = Bindlib.bind_var var_next acc in
Bindlib.box_apply2
(fun new_scope_def new_next ->
ScopeDef { new_scope_def with scope_next = new_next })
new_scope_def new_next)
~init:(Bindlib.box Nil) scopes
let map_exprs_in_scopes ~f ~varf scopes =
fold_right_scope_defs
~f:(fun scope_def var_next acc ->
let scope_input_var, scope_lets =
Bindlib.unbind scope_def.scope_body.scope_body_expr
in
let new_scope_body_expr = map_exprs_in_scope_lets ~f ~varf scope_lets in
let new_scope_body_expr =
Bindlib.bind_var (varf scope_input_var) new_scope_body_expr
in
let new_next = Bindlib.bind_var (varf var_next) acc in
Bindlib.box_apply2
(fun scope_body_expr scope_next ->
ScopeDef
{
scope_def with
scope_body = { scope_def.scope_body with scope_body_expr };
scope_next;
})
new_scope_body_expr new_next)
~init:(Bindlib.box Nil) scopes
(* - *) (* - *)
(** See [Bindlib.box_term] documentation for why we are doing that. *) (** See [Bindlib.box_term] documentation for why we are doing that. *)
@ -248,12 +174,3 @@ let box e =
id_t () e id_t () e
let untype e = map_marks ~f:(fun m -> Untyped { pos = mark_pos m }) e let untype e = map_marks ~f:(fun m -> Untyped { pos = mark_pos m }) e
let untype_program (prg : ('a, 'm mark) gexpr program) :
('a, untyped mark) gexpr program =
{
prg with
scopes =
Bindlib.unbox
(map_exprs_in_scopes ~f:untype ~varf:Var.translate prg.scopes);
}

View File

@ -15,7 +15,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Functions handling the types of [shared_ast] *) (** Functions handling the expressions of [shared_ast] *)
open Utils open Utils
open Types open Types
@ -134,9 +134,6 @@ val get_scope_body_mark : (_, 'm mark) gexpr scope_body -> 'm mark
val untype : val untype :
('a, 'm mark) marked_gexpr -> ('a, untyped mark) marked_gexpr Bindlib.box ('a, 'm mark) marked_gexpr -> ('a, untyped mark) marked_gexpr Bindlib.box
val untype_program :
(([< any ] as 'a), 'm mark) gexpr program -> ('a, untyped mark) gexpr program
(** {2 Handling of boxing} *) (** {2 Handling of boxing} *)
val box : ('a, 't) marked_gexpr -> ('a, 't) marked_gexpr Bindlib.box val box : ('a, 't) marked_gexpr -> ('a, 't) marked_gexpr Bindlib.box
@ -177,62 +174,3 @@ val map_top_down :
val map_marks : val map_marks :
f:('t1 -> 't2) -> ('a, 't1) marked_gexpr -> ('a, 't2) marked_gexpr Bindlib.box f:('t1 -> 't2) -> ('a, 't1) marked_gexpr -> ('a, 't2) marked_gexpr Bindlib.box
val fold_left_scope_lets :
f:('a -> 'e scope_let -> 'e Bindlib.var -> 'a) ->
init:'a ->
'e scope_body_expr ->
'a
(** Usage:
[fold_left_scope_lets ~f:(fun acc scope_let scope_let_var -> ...) ~init scope_lets],
where [scope_let_var] is the variable bound to the scope let in the next
scope lets to be examined. *)
val fold_right_scope_lets :
f:('expr1 scope_let -> 'expr1 Bindlib.var -> 'a -> 'a) ->
init:('expr1 marked -> 'a) ->
'expr1 scope_body_expr ->
'a
(** Usage:
[fold_right_scope_lets ~f:(fun scope_let scope_let_var acc -> ...) ~init scope_lets],
where [scope_let_var] is the variable bound to the scope let in the next
scope lets to be examined (which are before in the program order). *)
val map_exprs_in_scope_lets :
f:('expr1 marked -> 'expr2 marked Bindlib.box) ->
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
'expr1 scope_body_expr ->
'expr2 scope_body_expr Bindlib.box
val fold_left_scope_defs :
f:('a -> 'expr1 scope_def -> 'expr1 Bindlib.var -> 'a) ->
init:'a ->
'expr1 scopes ->
'a
(** Usage:
[fold_left_scope_defs ~f:(fun acc scope_def scope_var -> ...) ~init scope_def],
where [scope_var] is the variable bound to the scope in the next scopes to
be examined. *)
val fold_right_scope_defs :
f:('expr1 scope_def -> 'expr1 Bindlib.var -> 'a -> 'a) ->
init:'a ->
'expr1 scopes ->
'a
(** Usage:
[fold_right_scope_defs ~f:(fun scope_def scope_var acc -> ...) ~init scope_def],
where [scope_var] is the variable bound to the scope in the next scopes to
be examined (which are before in the program order). *)
val map_scope_defs :
f:('e scope_def -> 'e scope_def Bindlib.box) ->
'e scopes ->
'e scopes Bindlib.box
val map_exprs_in_scopes :
f:('expr1 marked -> 'expr2 marked Bindlib.box) ->
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
'expr1 scopes ->
'expr2 scopes Bindlib.box
(** This is the main map visitor for all the expressions inside all the scopes
of the program. *)

View File

@ -0,0 +1,27 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët-Tixeuil
<alain.delaet--tixeuil@inria.fr>, Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Types
let untype (prg : ('a, 'm mark) gexpr program) :
('a, untyped mark) gexpr program =
{
prg with
scopes =
Bindlib.unbox
(Scope.map_exprs ~f:Expr.untype ~varf:Var.translate prg.scopes);
}

View File

@ -0,0 +1,21 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët-Tixeuil
<alain.delaet--tixeuil@inria.fr>, Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Types
val untype :
(([< any ] as 'a), 'm mark) gexpr program -> ('a, untyped mark) gexpr program

View File

@ -0,0 +1,92 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët-Tixeuil
<alain.delaet--tixeuil@inria.fr>, Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Types
let rec fold_left_lets ~f ~init scope_body_expr =
match scope_body_expr with
| Result _ -> init
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
fold_left_lets ~f ~init:(f init scope_let var) next
let rec fold_right_lets ~f ~init scope_body_expr =
match scope_body_expr with
| Result result -> init result
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
let next_result = fold_right_lets ~f ~init next in
f scope_let var next_result
let map_exprs_in_lets ~f ~varf scope_body_expr =
fold_right_lets
~f:(fun scope_let var_next acc ->
Bindlib.box_apply2
(fun scope_let_next scope_let_expr ->
ScopeLet { scope_let with scope_let_next; scope_let_expr })
(Bindlib.bind_var (varf var_next) acc)
(f scope_let.scope_let_expr))
~init:(fun res -> Bindlib.box_apply (fun res -> Result res) (f res))
scope_body_expr
let rec fold_left ~f ~init scopes =
match scopes with
| Nil -> init
| ScopeDef scope_def ->
let var, next = Bindlib.unbind scope_def.scope_next in
fold_left ~f ~init:(f init scope_def var) next
let rec fold_right ~f ~init scopes =
match scopes with
| Nil -> init
| ScopeDef scope_def ->
let var_next, next = Bindlib.unbind scope_def.scope_next in
let result_next = fold_right ~f ~init next in
f scope_def var_next result_next
let map ~f scopes =
fold_right
~f:(fun scope_def var_next acc ->
let new_def = f scope_def in
let new_next = Bindlib.bind_var var_next acc in
Bindlib.box_apply2
(fun new_def new_next ->
ScopeDef { new_def with scope_next = new_next })
new_def new_next)
~init:(Bindlib.box Nil) scopes
let map_exprs ~f ~varf scopes =
fold_right
~f:(fun scope_def var_next acc ->
let scope_input_var, scope_lets =
Bindlib.unbind scope_def.scope_body.scope_body_expr
in
let new_body_expr = map_exprs_in_lets ~f ~varf scope_lets in
let new_body_expr =
Bindlib.bind_var (varf scope_input_var) new_body_expr
in
let new_next = Bindlib.bind_var (varf var_next) acc in
Bindlib.box_apply2
(fun scope_body_expr scope_next ->
ScopeDef
{
scope_def with
scope_body = { scope_def.scope_body with scope_body_expr };
scope_next;
})
new_body_expr new_next)
~init:(Bindlib.box Nil) scopes

View File

@ -0,0 +1,82 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët-Tixeuil
<alain.delaet--tixeuil@inria.fr>, Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Functions handling the scope structures of [shared_ast] *)
open Types
(** {2 Traversal functions} *)
val fold_left_lets :
f:('a -> 'e scope_let -> 'e Bindlib.var -> 'a) ->
init:'a ->
'e scope_body_expr ->
'a
(** Usage:
[fold_left_lets ~f:(fun acc scope_let scope_let_var -> ...) ~init scope_lets],
where [scope_let_var] is the variable bound to the scope let in the next
scope lets to be examined. *)
val fold_right_lets :
f:('expr1 scope_let -> 'expr1 Bindlib.var -> 'a -> 'a) ->
init:('expr1 marked -> 'a) ->
'expr1 scope_body_expr ->
'a
(** Usage:
[fold_right_lets ~f:(fun scope_let scope_let_var acc -> ...) ~init scope_lets],
where [scope_let_var] is the variable bound to the scope let in the next
scope lets to be examined (which are before in the program order). *)
val map_exprs_in_lets :
f:('expr1 marked -> 'expr2 marked Bindlib.box) ->
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
'expr1 scope_body_expr ->
'expr2 scope_body_expr Bindlib.box
val fold_left :
f:('a -> 'expr1 scope_def -> 'expr1 Bindlib.var -> 'a) ->
init:'a ->
'expr1 scopes ->
'a
(** Usage: [fold_left ~f:(fun acc scope_def scope_var -> ...) ~init scope_def],
where [scope_var] is the variable bound to the scope in the next scopes to
be examined. *)
val fold_right :
f:('expr1 scope_def -> 'expr1 Bindlib.var -> 'a -> 'a) ->
init:'a ->
'expr1 scopes ->
'a
(** Usage:
[fold_right_scope ~f:(fun scope_def scope_var acc -> ...) ~init scope_def],
where [scope_var] is the variable bound to the scope in the next scopes to
be examined (which are before in the program order). *)
val map :
f:('e scope_def -> 'e scope_def Bindlib.box) ->
'e scopes ->
'e scopes Bindlib.box
val map_exprs :
f:('expr1 marked -> 'expr2 marked Bindlib.box) ->
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
'expr1 scopes ->
'expr2 scopes Bindlib.box
(** This is the main map visitor for all the expressions inside all the scopes
of the program. *)
(** {2 Other helpers} *)

View File

@ -15,5 +15,7 @@
the License. *) the License. *)
include Types include Types
module Expr = Expr
module Var = Var module Var = Var
module Expr = Expr
module Scope = Scope
module Program = Program