fix tentative for functions applications

This commit is contained in:
adelaett 2023-01-06 12:05:38 +01:00
parent 5b33b39636
commit 44ce5a636b
2 changed files with 150 additions and 127 deletions

View File

@ -14,34 +14,39 @@
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
module D = Dcalc.Ast
module A = Ast
(** The main idea around this pass is to compile Dcalc to Lcalc without using
[raise EmptyError] nor [try _ with EmptyError -> _]. To do so, we use the
same technique as in rust or erlang to handle this kind of exceptions. Each
[raise EmptyError] will be translated as [None] and each
[try e1 with EmtpyError -> e2] as
[match e1 with | None -> e2 | Some x -> x].
When doing this naively, this requires to add matches and Some constructor
everywhere. We apply here an other technique where we generate what we call
`hoists`. Hoists are expression whom could minimally [raise EmptyError]. For
instance in [let x = <e1, e2, ..., en| e_just :- e_cons> * 3 in x + 1], the
sub-expression [<e1, e2, ..., en| e_just :- e_cons>] can produce an empty
error. So we make a hoist with a new variable [y] linked to the Dcalc
expression [<e1, e2, ..., en| e_just :- e_cons>], and we return as the
translated expression [let x = y * 3 in x + 1].
The compilation of expressions is found in the functions
[translate_and_hoist ctx e] and [translate_expr ctx e]. Every
option-generating expression when calling [translate_and_hoist] will be
hoisted and later handled by the [translate_expr] function. Every other
cases is found in the translate_and_hoist function. *)
open Catala_utils
module D = Dcalc.Ast
module A = Ast
(** The main idea around this pass is to compile Dcalc to Lcalc without using
[raise EmptyError] nor [try _ with EmptyError -> _]. To do so, we use the
same technique as in rust or erlang to handle this kind of exceptions. Each
[raise EmptyError] will be translated as [None] and each
[try e1 with EmtpyError -> e2] as
[match e1 with | None -> e2 | Some x -> x].
When doing this naively, this requires to add matches and Some constructor
everywhere. We apply here an other technique where we generate what we call
`hoists`. Hoists are expression whom could minimally [raise EmptyError]. For
instance in [let x = <e1, e2, ..., en| e_just :- e_cons> * 3 in x + 1], the
sub-expression [<e1, e2, ..., en| e_just :- e_cons>] can produce an empty
error. So we make a hoist with a new variable [y] linked to the Dcalc
expression [<e1, e2, ..., en| e_just :- e_cons>], and we return as the
translated expression [let x = y * 3 in x + 1].
The compilation of expressions is found in the functions
[translate_and_hoist ctx e] and [translate_expr ctx e]. Every
option-generating expression when calling [translate_and_hoist] will be
hoisted and later handled by the [translate_expr] function. Every other
cases is found in the translate_and_hoist function.
Problem arise when there is a function application.
*)
open Shared_ast
let admit: 'a = assert false
type 'm hoists = ('m A.expr, 'm D.expr) Var.Map.t
(** Hoists definition. It represent bindings between [A.Var.t] and [D.expr]. *)
@ -81,12 +86,12 @@ let _pp_ctx (fmt : Format.formatter) (ctx : 'm ctx) =
let find ?(info : string = "none") (n : 'm D.expr Var.t) (ctx : 'm ctx) :
'm info =
(* let _ = Format.asprintf "Searching for variable %a inside context %a"
Print.var n pp_ctx ctx |> Cli.debug_print in *)
Print.var n pp_ctx ctx |> Cli.debug_print in *)
try Var.Map.find n ctx.vars
with Not_found ->
Errors.raise_spanned_error Pos.no_pos
"Internal Error: Variable %a was not found in the current environment. \
Additional informations : %s."
Additional informations : %s."
Print.var n info
(** [add_var pos var is_pure ctx] add to the context [ctx] the Dcalc variable
@ -102,7 +107,7 @@ let add_var
let expr = Expr.make_var new_var mark in
(* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Print.var var Print.var
new_var; *)
new_var; *)
{
ctx with
vars =
@ -125,7 +130,7 @@ let rec translate_typ (tau : typ) : typ =
| TTuple ts -> TTuple (List.map translate_typ ts)
| TStruct s -> TStruct s
| TEnum en -> TEnum en
| TOption t -> TOption t
| TOption _ -> assert false
| TAny -> TAny
| TArray ts -> TArray (translate_typ ts)
(* catala is not polymorphic *)
@ -141,7 +146,7 @@ let disjoint_union_maps (pos : Pos.t) (cs : ('e, 'a) Var.Map.t list) :
Var.Map.union (fun _ _ _ ->
Errors.raise_spanned_error pos
"Internal Error: Two supposed to be disjoints maps have one shared \
key.")
key.")
in
List.fold_left disjoint_union Var.Map.empty cs
@ -157,24 +162,24 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
let pos = Expr.mark_pos mark in
match Marked.unmark e with
(* empty-producing/using terms. We hoist those. (D.EVar in some cases,
EApp(D.EVar _, [ELit LUnit]), EDefault _, ELit LEmptyDefault) I'm unsure
about assert. *)
EApp(D.EVar _, [ELit LUnit]), EDefault _, ELit LEmptyDefault) I'm unsure
about assert. *)
| EVar v ->
(* todo: for now, every unpure (such that [is_pure] is [false] in the
current context) is thunked, hence matched in the next case. This
assumption can change in the future, and this case is here for this
reason. *)
current context) is thunked, hence matched in the next case. This
assumption can change in the future, and this case is here for this
reason. *)
if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = Var.make (Bindlib.name_of v) in
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
created a variable %a to replace it" Print.var v Print.var v'; *)
created a variable %a to replace it" Print.var v Print.var v'; *)
Expr.make_var v' mark, Var.Map.singleton v' e
else (find ~info:"should never happen" v ctx).expr, Var.Map.empty
| EApp { f = EVar v, p; args = [(ELit LUnit, _)] } ->
if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = Var.make (Bindlib.name_of v) in
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
created a variable %a to replace it" Print.var v Print.var v'; *)
created a variable %a to replace it" Print.var v Print.var v'; *)
Expr.make_var v' mark, Var.Map.singleton v' (EVar v, p)
else
Errors.raise_spanned_error (Expr.pos e)
@ -186,7 +191,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
let v' = Var.make "empty_litteral" in
Expr.make_var v' mark, Var.Map.singleton v' e
(* This one is a very special case. It transform an unpure expression
environement to a pure expression. *)
environement to a pure expression. *)
| EErrorOnEmpty arg ->
(* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *)
let silent_var = Var.make "_" in
@ -197,8 +202,8 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
( A.make_matchopt_with_abs_arms arg'
(Expr.make_abs [| silent_var |]
(Expr.eraise NoValueProvided (Expr.with_ty mark rty))
[rty] pos)
(Expr.eraise NoValueProvided (Expr.with_ty mark rty))
[rty] pos)
(Expr.make_abs [| x |] (Expr.make_var x mark) [rty] pos),
Var.Map.empty )
(* pure terms *)
@ -219,40 +224,41 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
e', disjoint_union_maps (Expr.pos e) [h1; h2; h3]
| EAssert e1 ->
(* same behavior as in the ICFP paper: if e1 is empty, then no error is
raised. *)
raised. *)
let e1', h1 = translate_and_hoist ctx e1 in
Expr.eassert e1' mark, h1
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let ctx, lc_vars =
ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) ->
(* we suppose the invariant that when applying a function, its
arguments cannot be of the type "option".
(* We suppose the invariant that when applying a function, its
arguments cannot be of the type "option".
The code should behave correctly in the without this assumption if
we put here an is_pure=false, but the types are more compilcated.
(unimplemented for now) *)
The code should behave correctly in the without this assumption if
we put here an is_pure=false, but the types are more compilcated.
(unimplemented for now) *)
let ctx = add_var mark var true ctx in
let lc_var = (find var ctx).var in
ctx, lc_var :: lc_vars)
in
let lc_vars = Array.of_list lc_vars in
(* here we take the guess that if we cannot build the closure because one of
the variable is empty, then we cannot build the function. *)
let new_body, hoists = translate_and_hoist ctx body in
(* Even if abstractions cannot have unpure arguments, it is possible its returns unpure values. For instance, the term $fun x -> <|x > 0 :- x>$ is valid and appear in the basecode. Hence, we need to translate it using the transalte_expr function. This is linked to a more complex handling of the EApp case. *)
let new_body = translate_expr ctx body in
let new_binder = Expr.bind lc_vars new_body in
Expr.eabs new_binder (List.map translate_typ tys) mark, hoists
| EApp { f = e1; args } ->
let e1', h1 = translate_and_hoist ctx e1 in
let args', h_args =
args |> List.map (translate_and_hoist ctx) |> List.split
in
let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_args) in
let e' = Expr.eapp e1' args' mark in
e', hoists
Expr.eabs new_binder (List.map translate_typ tys) mark, Var.Map.empty
| EApp {f; args} ->
begin match Marked.unmark f with
| EOp _ ->
let f', h1 = translate_and_hoist ctx f in
let args', h_args = args |> List.map (translate_and_hoist ctx) |> List.split in
let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_args) in
Expr.eapp f' args' mark, hoists
| _ ->
let v' = Var.make "function_application" in
Expr.make_var v' mark, Var.Map.singleton v' e
end
| EStruct { name; fields } ->
let fields', h_fields =
StructField.Map.fold
@ -286,6 +292,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
let e1' = Expr.einj e1' cons name mark in
e1', hoists
| EMatch { name; e = e1; cases } ->
(* The current encoding of matches is e with an expression, that will be deconstructed and a series of cases. Each cases is an key constructor and a expression that contains a lambda expression. Hence the following encoding is correct: hoist each branches & the destructed expression. *)
let e1', h1 = translate_and_hoist ctx e1 in
let cases', h_cases =
EnumConstructor.Map.fold
@ -304,6 +311,66 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
Expr.earray es' mark, disjoint_union_maps (Expr.pos e) hoists
| EOp { op; tys } -> Expr.eop (Operator.translate op) tys mark, Var.Map.empty
and translate_hoists ~append_esome ctx hoists kont =
ListLabels.fold_left hoists
~init:(if append_esome then A.make_some kont else kont)
~f:(fun acc (v, (hoist, mark_hoist)) ->
(* Cli.debug_print @@ Format.asprintf "hoist using A.%a" Print.var v; *)
let pos = Expr.mark_pos mark_hoist in
let c' : 'm A.expr boxed =
match hoist with
(* Here we have to handle only the cases appearing in hoists, as defined
the [translate_and_hoist] function. *)
| EVar v -> (find ~info:"should never happen" v ctx).expr
| EDefault { excepts; just; cons } ->
let excepts' = List.map (translate_expr ctx) excepts in
let just' = translate_expr ctx just in
let cons' = translate_expr ctx cons in
(* calls handle_option. *)
Expr.make_app
(Expr.make_var (Var.translate A.handle_default_opt) mark_hoist)
[Expr.earray excepts' mark_hoist; just'; cons']
pos
| ELit LEmptyError -> A.make_none mark_hoist
| EApp { f; args } ->
let f = translate_expr ctx f in
let args = List.map (translate_expr ctx) args in
(* let*m args' = args' and* f' = f' in f' args' *)
A.make_bind_cont mark_hoist f
begin fun f ->
A.make_bindm_cont mark_hoist args begin fun args ->
Expr.make_app f args pos
(* A.make_bind_cont mark_hosit (Expr.make_app f args pos) *)
end
end
(* assert false *)
| EAssert arg ->
let arg' = translate_expr ctx arg in
(* [ match arg with | None -> raise NoValueProvided | Some v -> assert
{{ v }} ] *)
let silent_var = Var.make "_" in
let x = Var.make "assertion_argument" in
A.make_matchopt_with_abs_arms arg'
(Expr.make_abs [| silent_var |]
(Expr.eraise NoValueProvided mark_hoist)
[TAny, Expr.mark_pos mark_hoist]
pos)
(Expr.make_abs [| x |]
(Expr.eassert (Expr.make_var x mark_hoist) mark_hoist)
[TAny, Expr.mark_pos mark_hoist]
pos)
| _ ->
Errors.raise_spanned_error (Expr.mark_pos mark_hoist)
"Internal Error: An term was found in a position where it should \
not be"
in
A.make_matchopt pos v
(TAny, Expr.mark_pos mark_hoist)
c' (A.make_none mark_hoist) acc)
and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) :
'm A.expr boxed =
let e', hoists = translate_and_hoist ctx e in
@ -313,58 +380,12 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) :
(* build the hoists *)
(* Cli.debug_print @@ Format.asprintf "hoist for the expression: [%a]"
(Format.pp_print_list Print.var) (List.map fst hoists); *)
ListLabels.fold_left hoists
~init:(if append_esome then A.make_some e' else e')
~f:(fun acc (v, (hoist, mark_hoist)) ->
(* Cli.debug_print @@ Format.asprintf "hoist using A.%a" Print.var v; *)
let pos = Expr.mark_pos mark_hoist in
let c' : 'm A.expr boxed =
match hoist with
(* Here we have to handle only the cases appearing in hoists, as defined
the [translate_and_hoist] function. *)
| EVar v -> (find ~info:"should never happen" v ctx).expr
| EDefault { excepts; just; cons } ->
let excepts' = List.map (translate_expr ctx) excepts in
let just' = translate_expr ctx just in
let cons' = translate_expr ctx cons in
(* calls handle_option. *)
Expr.make_app
(Expr.make_var (Var.translate A.handle_default_opt) mark_hoist)
[Expr.earray excepts' mark_hoist; just'; cons']
pos
| ELit LEmptyError -> A.make_none mark_hoist
| EAssert arg ->
let arg' = translate_expr ctx arg in
(Format.pp_print_list Print.var) (List.map fst hoists); *)
(* [ match arg with | None -> raise NoValueProvided | Some v -> assert
{{ v }} ] *)
let silent_var = Var.make "_" in
let x = Var.make "assertion_argument" in
A.make_matchopt_with_abs_arms arg'
(Expr.make_abs [| silent_var |]
(Expr.eraise NoValueProvided mark_hoist)
[TAny, Expr.mark_pos mark_hoist]
pos)
(Expr.make_abs [| x |]
(Expr.eassert (Expr.make_var x mark_hoist) mark_hoist)
[TAny, Expr.mark_pos mark_hoist]
pos)
| _ ->
Errors.raise_spanned_error (Expr.mark_pos mark_hoist)
"Internal Error: An term was found in a position where it should \
not be"
in
(* [ match {{ c' }} with | None -> None | Some {{ v }} -> {{ acc }} end
] *)
(* Cli.debug_print @@ Format.asprintf "build matchopt using %a" Print.var
v; *)
A.make_matchopt pos v
(TAny, Expr.mark_pos mark_hoist)
c' (A.make_none mark_hoist) acc)
translate_hoists ~append_esome ctx hoists e'
let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
'm A.expr scope_body_expr Bindlib.box =
match lets with
@ -381,7 +402,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
scope_let_pos = pos;
} ->
(* special case : the subscope variable is thunked (context i/o). We remove
this thunking. *)
this thunking. *)
let _, expr = Bindlib.unmbind binder in
let var_is_pure = true in
@ -439,8 +460,8 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
} ->
Errors.raise_spanned_error pos
"Internal Error: found an SubScopeVarDefinition that does not satisfy \
the invariants when translating Dcalc to Lcalc without exceptions: \
@[<hov 2>%a@]"
the invariants when translating Dcalc to Lcalc without exceptions: \
@[<hov 2>%a@]"
(Expr.format ctx.decl_ctx) expr
| ScopeLet
{
@ -454,9 +475,9 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
match kind with
| DestructuringInputStruct -> (
(* Here, we have to distinguish between context and input variables. We
can do so by looking at the typ of the destructuring: if it's
thunked, then the variable is context. If it's not thunked, it's a
regular input. *)
can do so by looking at the typ of the destructuring: if it's
thunked, then the variable is context. If it's not thunked, it's a
regular input. *)
match Marked.unmark typ with
| TArrow ([(TLit TUnit, _)], _) -> false
| _ -> true)
@ -485,12 +506,12 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
let translate_scope_body
(scope_pos : Pos.t)
(ctx : 'm ctx)
(body : 'm D.expr scope_body) : 'm A.expr scope_body Bindlib.box =
(body : typed D.expr scope_body) : 'm A.expr scope_body Bindlib.box =
match body with
| {
scope_body_expr = result;
scope_body_input_struct = input_struct;
scope_body_output_struct = output_struct;
scope_body_expr = result;
scope_body_input_struct = input_struct;
scope_body_output_struct = output_struct;
} ->
let v, lets = Bindlib.unbind result in
let vmark =
@ -532,7 +553,7 @@ let translate_code_items (ctx : 'm ctx) (scopes : 'm D.expr code_item_list) :
in
scopes
let translate_program (prgm : 'm D.program) : 'm A.program =
let translate_program (prgm : typed D.program) : 'm A.program =
let inputs_structs =
Scope.fold_left prgm.code_items ~init:[] ~f:(fun acc def _ ->
match def with
@ -540,7 +561,7 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
| Topdef _ -> acc)
in
(* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]"
(Format.pp_print_list D.StructName.format_t) inputs_structs; *)
(Format.pp_print_list D.StructName.format_t) inputs_structs; *)
let decl_ctx =
{
prgm.decl_ctx with
@ -555,13 +576,13 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
ctx_structs =
prgm.decl_ctx.ctx_structs
|> StructName.Map.mapi (fun n str ->
if List.mem n inputs_structs then
StructField.Map.map translate_typ str
(* Cli.debug_print @@ Format.asprintf "Input type: %a"
if List.mem n inputs_structs then
StructField.Map.map translate_typ str
(* Cli.debug_print @@ Format.asprintf "Input type: %a"
(Print.typ decl_ctx) tau; Cli.debug_print @@ Format.asprintf
"Output type: %a" (Print.typ decl_ctx) (translate_typ
tau); *)
else str);
else str);
}
in

View File

@ -19,4 +19,6 @@
transformation is one piece to permit to compile toward legacy languages
that does not contains exceptions. *)
val translate_program : 'm Dcalc.Ast.program -> 'm Ast.program
open Shared_ast
val translate_program :typed Dcalc.Ast.program -> 'm Ast.program