scopes does not return optional terms

This commit is contained in:
adelaett 2023-03-23 10:45:44 +01:00
parent 78c0842dc6
commit 72ceafd67c
2 changed files with 61 additions and 41 deletions

View File

@ -29,7 +29,7 @@
buildDunePackage {
pname = "catala";
version = "0.7.0"; # TODO parse `catala.opam` with opam2json
version = "0.8.0"; # TODO parse `catala.opam` with opam2json
minimumOCamlVersion = "4.11";
@ -37,7 +37,7 @@ buildDunePackage {
duneVersion = "3";
nativeBuildInputs = [ cppo menhir ];
nativeBuildInputs = [ cppo menhir ocaml-crunch ];
propagatedBuildInputs = [
alcotest

View File

@ -44,7 +44,7 @@ module A = Ast
open Shared_ast
type info_pure = { info_pure : bool }
type info_pure = { info_pure : bool; is_scope : bool }
(** Information about each encontered Dcalc variable is stored inside a context
: what is the corresponding LCalc variable; an expression corresponding to
@ -57,22 +57,8 @@ type info_pure = { info_pure : bool }
Since positions where there is thunked expressions is exactly where we will
put option expressions. Hence, the transformation simply reduce [unit -> 'a]
into ['a option] recursivly. There is no polymorphism inside catala. *)
let rec translate_typ (tau : typ) : typ =
(Fun.flip Marked.same_mark_as)
tau
begin
match Marked.unmark tau with
| TLit l -> TLit l
| TTuple ts -> TTuple (List.map translate_typ ts)
| TStruct s -> TStruct s
| TEnum en -> TEnum en
| TOption _ -> assert false
| TAny -> TAny
| TArray ts -> TArray (translate_typ ts)
(* catala is not polymorphic *)
| TArrow ([(TLit TUnit, _)], t2) -> TOption (translate_typ t2)
| TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2)
end
let trans_typ (tau : typ) : typ = Marked.same_mark_as TAny tau
let monad_return e ~(mark : 'a mark) =
Expr.einj e Ast.some_constr Ast.option_enum mark
@ -251,10 +237,18 @@ let rec trans ctx (e : 'm D.expr) : (lcalc, 'm mark) boxed_gexpr =
Invariant failed : scope calls are encoded using an other technique. *)
Cli.debug_format "%s %b" (Bindlib.name_of ff)
(Var.Map.find ff ctx).info_pure;
let f_var = Var.make "fff" in
monad_bind_var ~mark
(monad_mbind (Expr.evar f_var mark) (List.map (trans ctx) args) ~mark)
f_var (trans ctx f)
(* todo: is_scope test *)
if (Var.Map.find ff ctx).is_scope then
let f_var = Var.make "fff" in
monad_bind_var ~mark
(monad_mmap (Expr.evar f_var mark) (List.map (trans ctx) args) ~mark)
f_var (trans ctx f)
else
let f_var = Var.make "fff" in
monad_bind_var ~mark
(monad_mbind (Expr.evar f_var mark) (List.map (trans ctx) args) ~mark)
f_var (trans ctx f)
| EApp { f = EAbs { binder; _ }, _; args } ->
(* Invariant: every let have only one argument. (checked by
invariant_let) *)
@ -262,7 +256,7 @@ let rec trans ctx (e : 'm D.expr) : (lcalc, 'm mark) boxed_gexpr =
let[@warning "-8"] [| var |] = var in
let var' = Var.translate var in
let[@warning "-8"] [arg] = args in
let ctx' = Var.Map.add var { info_pure = true } ctx in
let ctx' = Var.Map.add var { info_pure = true; is_scope = false } ctx in
monad_bind_var (trans ctx arg) var' (trans ctx' body) ~mark
| EApp { f = EApp { f = EOp { op = Op.Log _; _ }, _; args = _ }, _; _ } ->
assert false
@ -287,7 +281,7 @@ let rec trans ctx (e : 'm D.expr) : (lcalc, 'm mark) boxed_gexpr =
let vars, body = Bindlib.unmbind binder in
let ctx' =
ArrayLabels.fold_right vars ~init:ctx ~f:(fun var ->
Var.Map.add var { info_pure = true })
Var.Map.add var { info_pure = true; is_scope = false })
in
let binder =
Expr.bind (Array.map Var.translate vars) (trans ctx' body)
@ -379,7 +373,9 @@ let rec trans_scope_let ctx s =
let _, scope_let_expr = Bindlib.unmbind binder in
let next_var, next_body = Bindlib.unbind scope_let_next in
let ctx' = Var.Map.add next_var { info_pure = false } ctx in
let ctx' =
Var.Map.add next_var { info_pure = false; is_scope = false } ctx
in
let next_var = Var.translate next_var in
let next_body = trans_scope_body_expr ctx' next_body in
@ -392,7 +388,7 @@ let rec trans_scope_let ctx s =
(fun scope_let_expr scope_let_next ->
{
scope_let_kind = SubScopeVarDefinition;
scope_let_typ = translate_typ scope_let_typ;
scope_let_typ = trans_typ scope_let_typ;
scope_let_expr;
scope_let_next;
scope_let_pos;
@ -408,7 +404,9 @@ let rec trans_scope_let ctx s =
(* special case: regular input to the subscope *)
let next_var, next_body = Bindlib.unbind scope_let_next in
let ctx' = Var.Map.add next_var { info_pure = false } ctx in
let ctx' =
Var.Map.add next_var { info_pure = false; is_scope = false } ctx
in
let next_var = Var.translate next_var in
let next_body = trans_scope_body_expr ctx' next_body in
@ -423,7 +421,7 @@ let rec trans_scope_let ctx s =
(fun scope_let_expr scope_let_next ->
{
scope_let_kind = SubScopeVarDefinition;
scope_let_typ = translate_typ scope_let_typ;
scope_let_typ = trans_typ scope_let_typ;
scope_let_expr;
scope_let_next;
scope_let_pos;
@ -452,11 +450,12 @@ let rec trans_scope_let ctx s =
thunked, then the variable is context. If it's not thunked, it's a
regular input. *)
match Marked.unmark scope_let_typ with
| TArrow ([(TLit TUnit, _)], _) -> { info_pure = false }
| _ -> { info_pure = false })
| TArrow ([(TLit TUnit, _)], _) ->
{ info_pure = false; is_scope = false }
| _ -> { info_pure = false; is_scope = false })
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
| DestructuringSubScopeResults | Assertion ->
{ info_pure = false })
{ info_pure = false; is_scope = false })
ctx
in
@ -471,7 +470,7 @@ let rec trans_scope_let ctx s =
(fun scope_let_expr scope_let_next ->
{
scope_let_kind;
scope_let_typ = translate_typ scope_let_typ;
scope_let_typ = trans_typ scope_let_typ;
scope_let_expr;
scope_let_next;
scope_let_pos;
@ -488,10 +487,9 @@ and trans_scope_body_expr ctx s :
Bindlib.box_apply
(fun e -> Result e)
(Expr.Box.lift
@@ monad_return ~mark:(Marked.get_mark e)
(Expr.estruct name
(StructField.Map.map (trans ctx) fields)
(Marked.get_mark e)))
@@ Expr.estruct name
(StructField.Map.map (trans ctx) fields)
(Marked.get_mark e))
| _ -> assert false
end
| ScopeLet s ->
@ -502,7 +500,9 @@ let trans_scope_body
{ scope_body_input_struct; scope_body_output_struct; scope_body_expr } =
let var, body = Bindlib.unbind scope_body_expr in
let body =
trans_scope_body_expr (Var.Map.add var { info_pure = true } ctx) body
trans_scope_body_expr
(Var.Map.add var { info_pure = true; is_scope = false } ctx)
body
in
let binder = Bindlib.bind_var (Var.translate var) body in
Bindlib.box_apply
@ -520,23 +520,43 @@ let rec trans_code_items ctx c :
| Topdef (name, typ, e) ->
let next =
Bindlib.bind_var (Var.translate var)
(trans_code_items (Var.Map.add var { info_pure = false } ctx) next)
(trans_code_items
(Var.Map.add var { info_pure = false; is_scope = false } ctx)
next)
in
let e = Expr.Box.lift @@ trans ctx e in
(* TODO: need to add an error_on_empty *)
Bindlib.box_apply2
(fun next e -> Cons (Topdef (name, translate_typ typ, e), next))
(fun next e -> Cons (Topdef (name, trans_typ typ, e), next))
next e
| ScopeDef (name, body) ->
let next =
Bindlib.bind_var (Var.translate var)
(trans_code_items (Var.Map.add var { info_pure = true } ctx) next)
(trans_code_items
(Var.Map.add var { info_pure = true; is_scope = true } ctx)
next)
in
let body = trans_scope_body ctx body in
Bindlib.box_apply2
(fun next body -> Cons (ScopeDef (name, body), next))
next body)
let rec translate_typ (tau : typ) : typ =
(Fun.flip Marked.same_mark_as)
tau
begin
match Marked.unmark tau with
| TLit l -> TLit l
| TTuple ts -> TTuple (List.map translate_typ ts)
| TStruct s -> TStruct s
| TEnum en -> TEnum en
| TOption _ -> assert false
| TAny -> TAny
| TArray ts -> TArray (translate_typ ts) (* catala is not polymorphic *)
| TArrow ([(TLit TUnit, _)], t2) -> TOption (translate_typ t2)
| TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2)
end
let translate_program (prgm : typed D.program) : untyped A.program =
let inputs_structs =
Scope.fold_left prgm.code_items ~init:[] ~f:(fun acc def _ ->