Started replacement of Ast by Binded_representation in Dcalc [skip-ci]

This commit is contained in:
Denis Merigoux 2022-04-02 12:29:43 +02:00
parent be26fa2474
commit 8f39b65bb6
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
5 changed files with 159 additions and 306 deletions

View File

@ -133,26 +133,49 @@ type scope_let_kind =
| Assertion
type scope_let = {
scope_let_var : expr Bindlib.var Pos.marked;
scope_let_kind : scope_let_kind;
scope_let_typ : typ Pos.marked;
scope_let_expr : expr Pos.marked Bindlib.box;
scope_let_typ : typ Utils.Pos.marked;
scope_let_expr : expr Utils.Pos.marked;
scope_let_next : (expr, scope_body_expr) Bindlib.binder;
scope_let_pos : Utils.Pos.t;
}
and scope_body_expr = Result of expr Utils.Pos.marked | ScopeLet of scope_let
type scope_body = {
scope_body_lets : scope_let list;
scope_body_result : expr Pos.marked Bindlib.box;
(** {x1 = x1; x2 = x2; x3 = x3; ... } *)
scope_body_arg : expr Bindlib.var; (** x: input_struct *)
scope_body_input_struct : StructName.t;
scope_body_output_struct : StructName.t;
scope_body_expr : (expr, scope_body_expr) Bindlib.binder;
}
type program = {
decl_ctx : decl_ctx;
scopes : (ScopeName.t * expr Bindlib.var * scope_body) list;
type scope_def = {
scope_name : ScopeName.t;
scope_body : scope_body;
scope_next : (expr, scopes) Bindlib.binder;
}
and scopes = Nil | ScopeDef of scope_def
type program = { decl_ctx : decl_ctx; scopes : scopes }
let rec fold_scope_lets
~(f : 'a -> scope_let -> 'a)
~(init : 'a)
(scope_body_expr : scope_body_expr) : 'a =
match scope_body_expr with
| Result _ -> init
| ScopeLet scope_let ->
let _, next = Bindlib.unbind scope_let.scope_let_next in
fold_scope_lets ~f ~init:(f init scope_let) next
let rec fold_scope_defs
~(f : 'a -> scope_def -> 'a) ~(init : 'a) (scopes : scopes) : 'a =
match scopes with
| Nil -> init
| ScopeDef scope_def ->
let _, next = Bindlib.unbind scope_def.scope_next in
fold_scope_defs ~f ~init:(f init scope_def) next
module Var = struct
type t = expr Bindlib.var
@ -165,35 +188,53 @@ module Var = struct
end
module VarMap = Map.Make (Var)
module VarSet = Set.Make (Var)
let union : unit VarMap.t -> unit VarMap.t -> unit VarMap.t =
VarMap.union (fun _ _ _ -> Some ())
let rec free_vars_set (e : expr Pos.marked) : unit VarMap.t =
let rec free_vars_expr (e : expr Pos.marked) : VarSet.t =
match Pos.unmark e with
| EVar (v, _) -> VarMap.singleton v ()
| EVar (v, _) -> VarSet.singleton v
| ETuple (es, _) | EArray es ->
es |> List.map free_vars_set |> List.fold_left union VarMap.empty
es |> List.map free_vars_expr |> List.fold_left VarSet.union VarSet.empty
| ETupleAccess (e1, _, _, _)
| EAssert e1
| ErrorOnEmpty e1
| EInj (e1, _, _, _) ->
free_vars_set e1
free_vars_expr e1
| EApp (e1, es) | EMatch (e1, es, _) ->
e1 :: es |> List.map free_vars_set |> List.fold_left union VarMap.empty
e1 :: es |> List.map free_vars_expr
|> List.fold_left VarSet.union VarSet.empty
| EDefault (es, ejust, econs) ->
ejust :: econs :: es |> List.map free_vars_set
|> List.fold_left union VarMap.empty
| EOp _ | ELit _ -> VarMap.empty
ejust :: econs :: es |> List.map free_vars_expr
|> List.fold_left VarSet.union VarSet.empty
| EOp _ | ELit _ -> VarSet.empty
| EIfThenElse (e1, e2, e3) ->
[ e1; e2; e3 ] |> List.map free_vars_set
|> List.fold_left union VarMap.empty
[ e1; e2; e3 ] |> List.map free_vars_expr
|> List.fold_left VarSet.union VarSet.empty
| EAbs ((binder, _), _) ->
let vs, body = Bindlib.unmbind binder in
Array.fold_right VarMap.remove vs (free_vars_set body)
Array.fold_right VarSet.remove vs (free_vars_expr body)
let free_vars_list (e : expr Pos.marked) : Var.t list =
free_vars_set e |> VarMap.bindings |> List.map fst
let rec free_vars_scope_body_expr (scope_lets : scope_body_expr) : VarSet.t =
match scope_lets with
| Result e -> free_vars_expr e
| ScopeLet { scope_let_expr = e; scope_let_next = next; _ } ->
let v, body = Bindlib.unbind next in
VarSet.union (free_vars_expr e)
(VarSet.remove v (free_vars_scope_body_expr body))
let free_vars_scope_body (scope_body : scope_body) : VarSet.t =
let { scope_body_expr = binder; _ } = scope_body in
let v, body = Bindlib.unbind binder in
VarSet.remove v (free_vars_scope_body_expr body)
let rec free_vars_scopes (scopes : scopes) : VarSet.t =
match scopes with
| Nil -> VarSet.empty
| ScopeDef { scope_body = body; scope_next = next; _ } ->
let v, next = Bindlib.unbind next in
VarSet.union
(VarSet.remove v (free_vars_scopes next))
(free_vars_scope_body body)
type vars = expr Bindlib.mvar
@ -312,19 +353,30 @@ and equal_exprs_list (es1 : expr Pos.marked list) (es2 : expr Pos.marked list) :
assume here that both lists have equal length *)
List.for_all (fun (x, y) -> equal_exprs x y) (List.combine es1 es2)
let rec unfold_scope_body_expr (ctx : decl_ctx) (scope_let : scope_body_expr) :
expr Pos.marked Bindlib.box =
match scope_let with
| Result e -> Bindlib.box e
| ScopeLet
{
scope_let_kind = _;
scope_let_typ;
scope_let_expr;
scope_let_next;
scope_let_pos;
} ->
let var, next = Bindlib.unbind scope_let_next in
make_let_in var scope_let_typ
(Bindlib.box scope_let_expr)
(unfold_scope_body_expr ctx next)
scope_let_pos
let build_whole_scope_expr
(ctx : decl_ctx) (body : scope_body) (pos_scope : Pos.t) =
let body_expr =
List.fold_right
(fun scope_let acc ->
make_let_in
(Pos.unmark scope_let.scope_let_var)
scope_let.scope_let_typ scope_let.scope_let_expr acc
(Pos.get_position scope_let.scope_let_var))
body.scope_body_lets body.scope_body_result
in
let var, body_expr = Bindlib.unbind body.scope_body_expr in
let body_expr = unfold_scope_body_expr ctx body_expr in
make_abs
(Array.of_list [ body.scope_body_arg ])
(Array.of_list [ var ])
body_expr pos_scope
[
( TTuple
@ -352,25 +404,36 @@ let build_scope_typ_from_sig
in
(TArrow (input_typ, result_typ), pos)
let build_whole_program_expr (p : program) (main_scope : ScopeName.t) =
let end_result =
make_var
(let _, x, _ =
List.find
(fun (s_name, _, _) -> ScopeName.compare main_scope s_name = 0)
p.scopes
in
(x, Pos.no_pos))
in
List.fold_right
(fun (scope_name, scope_var, scope_body) acc ->
let pos = Pos.get_position (ScopeName.get_info scope_name) in
type scope_name_or_var = ScopeName of ScopeName.t | ScopeVar of Var.t
let rec unfold_scopes
(ctx : decl_ctx) (s : scopes) (main_scope : scope_name_or_var) :
expr Pos.marked Bindlib.box =
match s with
| Nil -> (
match main_scope with
| ScopeVar v ->
Bindlib.box_apply (fun v -> (v, Pos.no_pos)) (Bindlib.box_var v)
| ScopeName _ -> failwith "should not happen")
| ScopeDef { scope_name; scope_body; scope_next } ->
let scope_var, scope_next = Bindlib.unbind scope_next in
let scope_pos = Pos.get_position (ScopeName.get_info scope_name) in
let main_scope =
match main_scope with
| ScopeVar v -> ScopeVar v
| ScopeName n ->
if ScopeName.compare n scope_name = 0 then ScopeVar scope_var
else ScopeName n
in
make_let_in scope_var
(build_scope_typ_from_sig p.decl_ctx scope_body.scope_body_input_struct
scope_body.scope_body_output_struct pos)
(build_whole_scope_expr p.decl_ctx scope_body pos)
acc pos)
p.scopes end_result
(build_scope_typ_from_sig ctx scope_body.scope_body_input_struct
scope_body.scope_body_output_struct scope_pos)
(build_whole_scope_expr ctx scope_body scope_pos)
(unfold_scopes ctx scope_next main_scope)
scope_pos
let build_whole_program_expr (p : program) (main_scope : ScopeName.t) =
unfold_scopes p.decl_ctx p.scopes (ScopeName main_scope)
let rec expr_size (e : expr Pos.marked) : int =
match Pos.unmark e with
@ -396,14 +459,3 @@ let rec expr_size (e : expr Pos.marked) : int =
(fun acc except -> acc + expr_size except)
(1 + expr_size just + expr_size cons)
exceptions
let variable_types (p : program) : typ Pos.marked VarMap.t =
List.fold_left
(fun acc (_, _, scope) ->
List.fold_left
(fun acc scope_let ->
VarMap.add
(Pos.unmark scope_let.scope_let_var)
scope_let.scope_let_typ acc)
acc scope.scope_body_lets)
VarMap.empty p.scopes

View File

@ -144,33 +144,52 @@ type scope_let_kind =
| Assertion (** [let _ = assert e]*)
type scope_let = {
scope_let_var : expr Bindlib.var Pos.marked;
scope_let_kind : scope_let_kind;
scope_let_typ : typ Pos.marked;
scope_let_expr : expr Pos.marked Bindlib.box;
scope_let_typ : typ Utils.Pos.marked;
scope_let_expr : expr Utils.Pos.marked;
scope_let_next : (expr, scope_body_expr) Bindlib.binder;
scope_let_pos : Utils.Pos.t;
}
(** A scope let-binding has all the information necessary to make a proper
let-binding expression, plus an annotation for the kind of the let-binding
that comes from the compilation of a {!module: Scopelang.Ast} statement. *)
and scope_body_expr = Result of expr Utils.Pos.marked | ScopeLet of scope_let
type scope_body = {
scope_body_lets : scope_let list;
scope_body_result : expr Pos.marked Bindlib.box;
scope_body_arg : expr Bindlib.var;
scope_body_input_struct : StructName.t;
scope_body_output_struct : StructName.t;
scope_body_expr : (expr, scope_body_expr) Bindlib.binder;
}
(** Instead of being a single expression, we give a little more ad-hoc structure
to the scope body by decomposing it in an ordered list of let-bindings, and
a result expression that uses the let-binded variables. *)
a result expression that uses the let-binded variables. The first binder is
the argument of type [scope_body_input_struct]. *)
type program = {
decl_ctx : decl_ctx;
scopes : (ScopeName.t * expr Bindlib.var * scope_body) list;
type scope_def = {
scope_name : ScopeName.t;
scope_body : scope_body;
scope_next : (expr, scopes) Bindlib.binder;
}
(** Finally, we do the same transformation for the whole program for the kinded
lets. This permit us to use bindlib variables for scopes names. *)
and scopes = Nil | ScopeDef of scope_def
type program = { decl_ctx : decl_ctx; scopes : scopes }
(** {1 Helpers} *)
(**{2 Program traversal}*)
(** Be careful when using these traversal functions, as the bound variables they
open will be different at each traversal. *)
val fold_scope_lets :
f:('a -> scope_let -> 'a) -> init:'a -> scope_body_expr -> 'a
val fold_scope_defs : f:('a -> scope_def -> 'a) -> init:'a -> scopes -> 'a
(** {2 Variables}*)
module Var : sig
@ -181,9 +200,12 @@ module Var : sig
end
module VarMap : Map.S with type key = Var.t
module VarSet : Set.S with type elt = Var.t
val free_vars_set : expr Pos.marked -> unit VarMap.t
val free_vars_list : expr Pos.marked -> Var.t list
val free_vars_expr : expr Pos.marked -> VarSet.t
val free_vars_scope_body_expr : scope_body_expr -> VarSet.t
val free_vars_scope_body : scope_body -> VarSet.t
val free_vars_scopes : scopes -> VarSet.t
type vars = expr Bindlib.mvar
@ -235,8 +257,3 @@ val build_whole_program_expr :
val expr_size : expr Pos.marked -> int
(** Used by the optimizer to know when to stop *)
val variable_types : program -> typ Pos.marked VarMap.t
(** Traverses all the scopes and retrieves all the types for the variables that
may appear in scope or subscope variable definitions, giving them as a big
map. *)

View File

@ -1,146 +0,0 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Utils
module D = Ast
type scope_lets =
| Result of D.expr Pos.marked
| ScopeLet of {
scope_let_kind : D.scope_let_kind;
scope_let_typ : D.typ Pos.marked;
scope_let_expr : D.expr Pos.marked;
scope_let_next : (D.expr, scope_lets) Bindlib.binder;
scope_let_pos : Pos.t;
}
type scope_body = {
scope_body_input_struct : D.StructName.t;
scope_body_output_struct : D.StructName.t;
scope_body_result : (D.expr, scope_lets) Bindlib.binder;
}
type scopes =
| Nil
| ScopeDef of {
scope_name : D.ScopeName.t;
scope_body : scope_body;
scope_next : (D.expr, scopes) Bindlib.binder;
}
let union : unit D.VarMap.t -> unit D.VarMap.t -> unit D.VarMap.t =
D.VarMap.union (fun _ _ _ -> Some ())
let rec free_vars_set_scope_lets (scope_lets : scope_lets) : unit D.VarMap.t =
match scope_lets with
| Result e -> D.free_vars_set e
| ScopeLet { scope_let_expr = e; scope_let_next = next; _ } ->
let v, body = Bindlib.unbind next in
union (D.free_vars_set e)
(D.VarMap.remove v (free_vars_set_scope_lets body))
let free_vars_set_scope_body (scope_body : scope_body) : unit D.VarMap.t =
let { scope_body_result = binder; _ } = scope_body in
let v, body = Bindlib.unbind binder in
D.VarMap.remove v (free_vars_set_scope_lets body)
let rec free_vars_set_scopes (scopes : scopes) : unit D.VarMap.t =
match scopes with
| Nil -> D.VarMap.empty
| ScopeDef { scope_body = body; scope_next = next; _ } ->
let v, next = Bindlib.unbind next in
union
(D.VarMap.remove v (free_vars_set_scopes next))
(free_vars_set_scope_body body)
let free_vars_list_scope_lets (scope_lets : scope_lets) : D.Var.t list =
free_vars_set_scope_lets scope_lets |> D.VarMap.bindings |> List.map fst
let free_vars_list_scope_body (scope_body : scope_body) : D.Var.t list =
free_vars_set_scope_body scope_body |> D.VarMap.bindings |> List.map fst
let free_vars_list_scopes (scopes : scopes) : D.Var.t list =
free_vars_set_scopes scopes |> D.VarMap.bindings |> List.map fst
(** Actual transformation for scopes. *)
let bind_scope_lets (acc : scope_lets Bindlib.box) (scope_let : D.scope_let) :
scope_lets Bindlib.box =
let pos = snd scope_let.D.scope_let_var in
(* Cli.debug_print @@ Format.asprintf "binding let %a. Variable occurs = %b"
Print.format_var (fst scope_let.D.scope_let_var) (Bindlib.occur (fst
scope_let.D.scope_let_var) acc); *)
let binder = Bindlib.bind_var (fst scope_let.D.scope_let_var) acc in
Bindlib.box_apply2
(fun expr binder ->
(* Cli.debug_print @@ Format.asprintf "free variables in expression: %a"
(Format.pp_print_list Print.format_var) (D.free_vars_list expr); *)
ScopeLet
{
scope_let_kind = scope_let.D.scope_let_kind;
scope_let_typ = scope_let.D.scope_let_typ;
scope_let_expr = expr;
scope_let_next = binder;
scope_let_pos = pos;
})
scope_let.D.scope_let_expr binder
let bind_scope_body (body : D.scope_body) : scope_body Bindlib.box =
(* it is a fold_right and not a fold_left. *)
let body_result =
ListLabels.fold_right body.D.scope_body_lets
~init:(Bindlib.box_apply (fun e -> Result e) body.D.scope_body_result)
~f:(Fun.flip bind_scope_lets)
in
(* Cli.debug_print @@ Format.asprintf "binding arg %a" Print.format_var
body.D.scope_body_arg; *)
let scope_body_result = Bindlib.bind_var body.D.scope_body_arg body_result in
(* Cli.debug_print @@ Format.asprintf "isfinal term is closed: %b"
(Bindlib.is_closed scope_body_result); *)
Bindlib.box_apply
(fun scope_body_result ->
(* Cli.debug_print @@ Format.asprintf "rank of the final term: %i"
(Bindlib.binder_rank scope_body_result); *)
{
scope_body_output_struct = body.D.scope_body_output_struct;
scope_body_input_struct = body.D.scope_body_input_struct;
scope_body_result;
})
scope_body_result
let bind_scope
((scope_name, scope_var, scope_body) :
D.ScopeName.t * D.expr Bindlib.var * D.scope_body)
(acc : scopes Bindlib.box) : scopes Bindlib.box =
Bindlib.box_apply2
(fun scope_body scope_next ->
ScopeDef { scope_name; scope_body; scope_next })
(bind_scope_body scope_body)
(Bindlib.bind_var scope_var acc)
let bind_scopes
(scopes : (D.ScopeName.t * D.expr Bindlib.var * D.scope_body) list) :
scopes Bindlib.box =
let result =
ListLabels.fold_right scopes ~init:(Bindlib.box Nil) ~f:bind_scope
in
(* Cli.debug_print @@ Format.asprintf "free variable in the program : [%a]"
(Format.pp_print_list Print.format_var) (free_vars_list_scopes
(Bindlib.unbox result)); *)
result

View File

@ -1,68 +0,0 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
module D = Ast
(** Alternative representation of the Dcalc Ast. It is currently used in the
transformation without exceptions. We make heavy use of bindlib, binding
each scope-let-variable and each scope explicitly. *)
(** In [Ast], [Ast.scope_lets] is defined as a list of kind, var, and boxed
expression. This representation binds using bindlib the tail of the list
with the variable defined in the let. *)
type scope_lets =
| Result of D.expr Utils.Pos.marked
| ScopeLet of {
scope_let_kind : D.scope_let_kind;
scope_let_typ : D.typ Utils.Pos.marked;
scope_let_expr : D.expr Utils.Pos.marked;
scope_let_next : (D.expr, scope_lets) Bindlib.binder;
scope_let_pos : Utils.Pos.t;
}
type scope_body = {
scope_body_input_struct : D.StructName.t;
scope_body_output_struct : D.StructName.t;
scope_body_result : (D.expr, scope_lets) Bindlib.binder;
}
(** As a consequence, the scope_body contains only a result and input/output
signature, as the other elements are stored inside the scope_let. The binder
present is the argument of type [scope_body_input_struct]. *)
(** Finally, we do the same transformation for the whole program for the kinded
lets. This permit us to use bindlib variables for scopes names. *)
type scopes =
| Nil
| ScopeDef of {
scope_name : D.ScopeName.t;
scope_body : scope_body;
scope_next : (D.expr, scopes) Bindlib.binder;
}
val free_vars_list_scope_lets : scope_lets -> D.Var.t list
(** List of variables not binded inside a scope_lets *)
val free_vars_list_scope_body : scope_body -> D.Var.t list
(** List of variables not binded inside a scope_body. *)
val free_vars_list_scopes : scopes -> D.Var.t list
(** List of variables not binded inside scopes*)
val bind_scopes :
(D.ScopeName.t * D.expr Bindlib.var * D.scope_body) list -> scopes Bindlib.box
(** Transform a list of scopes into our representation of scopes. It requires
that scopes are topologically-well-ordered, and ensure there is no free
variables in the returned [scopes] *)

View File

@ -17,7 +17,6 @@
open Utils
module D = Dcalc.Ast
module A = Ast
open Dcalc.Binded_representation
(** The main idea around this pass is to compile Dcalc to Lcalc without using
[raise EmptyError] nor [try _ with EmptyError -> _]. To do so, we use the
@ -392,7 +391,7 @@ and translate_expr ?(append_esome = true) (ctx : ctx) (e : D.expr Pos.marked) :
A.make_matchopt pos_hoist v (D.TAny, pos_hoist) c' (A.make_none pos_hoist)
acc)
let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
let rec translate_scope_let (ctx : ctx) (lets : D.scope_body_expr) =
match lets with
| Result e -> translate_expr ~append_esome:false ctx e
| ScopeLet
@ -484,11 +483,11 @@ let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
(translate_expr ctx ~append_esome:false expr)
(translate_scope_let ctx' next)
let translate_scope_body (scope_pos : Pos.t) (ctx : ctx) (body : scope_body) :
let translate_scope_body (scope_pos : Pos.t) (ctx : ctx) (body : D.scope_body) :
A.expr Pos.marked Bindlib.box =
match body with
| {
scope_body_result = result;
scope_body_expr = result;
scope_body_input_struct = input_struct;
scope_body_output_struct = _output_struct;
} ->
@ -502,7 +501,7 @@ let translate_scope_body (scope_pos : Pos.t) (ctx : ctx) (body : scope_body) :
[ (D.TTuple ([], Some input_struct), Pos.no_pos) ]
Pos.no_pos
let rec translate_scopes (ctx : ctx) (scopes : scopes) :
let rec translate_scopes (ctx : ctx) (scopes : D.scopes) :
Ast.scope_body list Bindlib.box =
match scopes with
| Nil -> Bindlib.box []
@ -528,13 +527,13 @@ let rec translate_scopes (ctx : ctx) (scopes : scopes) :
:: tail)
new_body tail
let translate_scopes (ctx : ctx) (scopes : scopes) : Ast.scope_body list =
let translate_scopes (ctx : ctx) (scopes : D.scopes) : Ast.scope_body list =
Bindlib.unbox (translate_scopes ctx scopes)
let translate_program (prgm : D.program) : A.program =
let inputs_structs =
ListLabels.fold_left prgm.scopes ~init:[] ~f:(fun acc (_, _, body) ->
body.D.scope_body_input_struct :: acc)
D.fold_scope_defs prgm.scopes ~init:[] ~f:(fun acc scope_def ->
scope_def.D.scope_body.scope_body_input_struct :: acc)
in
(* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]"
@ -566,8 +565,7 @@ let translate_program (prgm : D.program) : A.program =
in
let scopes =
prgm.scopes |> bind_scopes |> Bindlib.unbox
|> translate_scopes { decl_ctx; vars = D.VarMap.empty }
prgm.scopes |> translate_scopes { decl_ctx; vars = D.VarMap.empty }
in
{ scopes; decl_ctx }