mirror of
https://github.com/CatalaLang/catala.git
synced 2024-09-19 16:28:12 +03:00
Adapt translation to new i/o invariants, bug discovered
This commit is contained in:
parent
cab4e5c17e
commit
48f064ccea
@ -81,18 +81,13 @@ let bind_scope_lets (acc : scope_lets Bindlib.box) (scope_let : D.scope_let) :
|
||||
scope_lets Bindlib.box =
|
||||
let pos = snd scope_let.D.scope_let_var in
|
||||
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "binding let %a. Variable occurs = %b" Print.format_var
|
||||
(fst scope_let.D.scope_let_var)
|
||||
(Bindlib.occur (fst scope_let.D.scope_let_var) acc);
|
||||
|
||||
(* Cli.debug_print @@ Format.asprintf "binding let %a. Variable occurs = %b" Print.format_var (fst
|
||||
scope_let.D.scope_let_var) (Bindlib.occur (fst scope_let.D.scope_let_var) acc); *)
|
||||
let binder = Bindlib.bind_var (fst scope_let.D.scope_let_var) acc in
|
||||
Bindlib.box_apply2
|
||||
(fun expr binder ->
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "free variables in expression: %a"
|
||||
(Format.pp_print_list Print.format_var)
|
||||
(D.free_vars expr);
|
||||
(* Cli.debug_print @@ Format.asprintf "free variables in expression: %a" (Format.pp_print_list
|
||||
Print.format_var) (D.free_vars expr); *)
|
||||
ScopeLet
|
||||
{
|
||||
scope_let_kind = scope_let.D.scope_let_kind;
|
||||
@ -111,15 +106,15 @@ let bind_scope_body (body : D.scope_body) : scope_body Bindlib.box =
|
||||
~f:(Fun.flip bind_scope_lets)
|
||||
in
|
||||
|
||||
Cli.debug_print @@ Format.asprintf "binding arg %a" Print.format_var body.D.scope_body_arg;
|
||||
(* Cli.debug_print @@ Format.asprintf "binding arg %a" Print.format_var body.D.scope_body_arg; *)
|
||||
let scope_body_result = Bindlib.bind_var body.D.scope_body_arg body_result in
|
||||
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "isfinal term is closed: %b" (Bindlib.is_closed scope_body_result);
|
||||
(* Cli.debug_print @@ Format.asprintf "isfinal term is closed: %b" (Bindlib.is_closed
|
||||
scope_body_result); *)
|
||||
Bindlib.box_apply
|
||||
(fun scope_body_result ->
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "rank of the final term: %i" (Bindlib.binder_rank scope_body_result);
|
||||
(* Cli.debug_print @@ Format.asprintf "rank of the final term: %i" (Bindlib.binder_rank
|
||||
scope_body_result); *)
|
||||
{
|
||||
scope_body_output_struct = body.D.scope_body_output_struct;
|
||||
scope_body_input_struct = body.D.scope_body_input_struct;
|
||||
@ -137,10 +132,6 @@ let bind_scope
|
||||
let bind_scopes (scopes : (D.ScopeName.t * D.expr Bindlib.var * D.scope_body) list) :
|
||||
scopes Bindlib.box =
|
||||
let result = ListLabels.fold_right scopes ~init:(Bindlib.box Nil) ~f:bind_scope in
|
||||
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "free variable in the program : [%a]"
|
||||
(Format.pp_print_list Print.format_var)
|
||||
(free_vars_scopes (Bindlib.unbox result));
|
||||
|
||||
(* Cli.debug_print @@ Format.asprintf "free variable in the program : [%a]" (Format.pp_print_list
|
||||
Print.format_var) (free_vars_scopes (Bindlib.unbox result)); *)
|
||||
result
|
||||
|
@ -49,10 +49,12 @@ type info = { expr : A.expr Pos.marked Bindlib.box; var : A.expr Bindlib.var; is
|
||||
let pp_info (fmt : Format.formatter) (info : info) =
|
||||
Format.fprintf fmt "{var: %a; is_pure: %b}" Print.format_var info.var info.is_pure
|
||||
|
||||
type ctx = info D.VarMap.t
|
||||
(** information context about variables in the current scope *)
|
||||
type ctx = {
|
||||
decl_ctx : D.decl_ctx;
|
||||
vars : info D.VarMap.t; (** information context about variables in the current scope *)
|
||||
}
|
||||
|
||||
let pp_ctx (fmt : Format.formatter) (ctx : ctx) =
|
||||
let _pp_ctx (fmt : Format.formatter) (ctx : ctx) =
|
||||
let pp_binding (fmt : Format.formatter) ((v, info) : D.Var.t * info) =
|
||||
Format.fprintf fmt "%a: %a" Dcalc.Print.format_var v pp_info info
|
||||
in
|
||||
@ -61,16 +63,13 @@ let pp_ctx (fmt : Format.formatter) (ctx : ctx) =
|
||||
Format.pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt "; ") pp_binding
|
||||
in
|
||||
|
||||
Format.fprintf fmt "@[<2>[%a]@]" pp_bindings (D.VarMap.bindings ctx)
|
||||
Format.fprintf fmt "@[<2>[%a]@]" pp_bindings (D.VarMap.bindings ctx.vars)
|
||||
|
||||
(** [find ~info n ctx] is a warpper to ocaml's Map.find that handle errors in a slightly better way. *)
|
||||
let find ?(info : string = "none") (n : D.Var.t) (ctx : ctx) : info =
|
||||
let _ =
|
||||
Format.asprintf "Searching for variable %a inside context %a" Dcalc.Print.format_var n pp_ctx
|
||||
ctx
|
||||
|> Cli.debug_print
|
||||
in
|
||||
try D.VarMap.find n ctx
|
||||
(* let _ = Format.asprintf "Searching for variable %a inside context %a" Dcalc.Print.format_var n
|
||||
pp_ctx ctx |> Cli.debug_print in *)
|
||||
try D.VarMap.find n ctx.vars
|
||||
with Not_found ->
|
||||
Errors.raise_spanned_error
|
||||
(Format.asprintf
|
||||
@ -86,10 +85,9 @@ let add_var (pos : Pos.t) (var : D.Var.t) (is_pure : bool) (ctx : ctx) : ctx =
|
||||
let new_var = A.Var.make (Bindlib.name_of var, pos) in
|
||||
let expr = A.make_var (new_var, pos) in
|
||||
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "D.%a |-> A.%a" Dcalc.Print.format_var var Print.format_var new_var;
|
||||
|
||||
D.VarMap.update var (fun _ -> Some { expr; var = new_var; is_pure }) ctx
|
||||
(* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Dcalc.Print.format_var var Print.format_var
|
||||
new_var; *)
|
||||
{ ctx with vars = D.VarMap.update var (fun _ -> Some { expr; var = new_var; is_pure }) ctx.vars }
|
||||
|
||||
(** [tau' = translate_typ tau] translate the a dcalc type into a lcalc type.
|
||||
|
||||
@ -150,22 +148,18 @@ let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) :
|
||||
thunked, hence matched in the next case. This assumption can change in the future, and this
|
||||
case is here for this reason. *)
|
||||
let v, pos_v = v in
|
||||
if not (find ~info:"search for a variable" v ctx).is_pure then begin
|
||||
if not (find ~info:"search for a variable" v ctx).is_pure then
|
||||
let v' = A.Var.make (Bindlib.name_of v, pos_v) in
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "Found an unpure variable %a, created a variable %a to replace it"
|
||||
Dcalc.Print.format_var v Print.format_var v';
|
||||
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, created a variable %a to
|
||||
replace it" Dcalc.Print.format_var v Print.format_var v'; *)
|
||||
(A.make_var (v', pos), A.VarMap.singleton v' e)
|
||||
end
|
||||
else ((find ~info:"should never happend" v ctx).expr, A.VarMap.empty)
|
||||
| D.EApp ((D.EVar (v, pos_v), p), [ (D.ELit D.LUnit, _) ]) ->
|
||||
if not (find ~info:"search for a variable" v ctx).is_pure then begin
|
||||
if not (find ~info:"search for a variable" v ctx).is_pure then
|
||||
let v' = A.Var.make (Bindlib.name_of v, pos_v) in
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "Found an unpure variable %a, created a variable %a to replace it"
|
||||
Dcalc.Print.format_var v Print.format_var v';
|
||||
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, created a variable %a to
|
||||
replace it" Dcalc.Print.format_var v Print.format_var v'; *)
|
||||
(A.make_var (v', pos), A.VarMap.singleton v' (D.EVar (v, pos_v), p))
|
||||
end
|
||||
else
|
||||
Errors.raise_spanned_error
|
||||
"Internal error: an pure variable was found in an unpure environment." pos
|
||||
@ -282,16 +276,12 @@ and translate_expr ?(append_esome = true) (ctx : ctx) (e : D.expr Pos.marked) :
|
||||
let _pos = Pos.get_position e in
|
||||
|
||||
(* build the hoists *)
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "hoist for the expression: [%a]"
|
||||
(Format.pp_print_list Print.format_var)
|
||||
(List.map fst hoists);
|
||||
|
||||
(* Cli.debug_print @@ Format.asprintf "hoist for the expression: [%a]" (Format.pp_print_list
|
||||
Print.format_var) (List.map fst hoists); *)
|
||||
ListLabels.fold_left hoists
|
||||
~init:(if append_esome then A.make_some e' else e')
|
||||
~f:(fun acc (v, (hoist, pos_hoist)) ->
|
||||
Cli.debug_print @@ Format.asprintf "hoist using A.%a" Print.format_var v;
|
||||
|
||||
(* Cli.debug_print @@ Format.asprintf "hoist using A.%a" Print.format_var v; *)
|
||||
let c' : A.expr Pos.marked Bindlib.box =
|
||||
match hoist with
|
||||
(* Here we have to handle only the cases appearing in hoists, as defined the
|
||||
@ -335,7 +325,7 @@ and translate_expr ?(append_esome = true) (ctx : ctx) (e : D.expr Pos.marked) :
|
||||
in
|
||||
|
||||
(* [ match {{ c' }} with | None -> None | Some {{ v }} -> {{ acc }} end ] *)
|
||||
Cli.debug_print @@ Format.asprintf "build matchopt using %a" Print.format_var v;
|
||||
(* Cli.debug_print @@ Format.asprintf "build matchopt using %a" Print.format_var v; *)
|
||||
A.make_matchopt pos_hoist v (D.TAny, pos_hoist) c' (A.make_none pos_hoist) acc)
|
||||
|
||||
let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
|
||||
@ -349,21 +339,42 @@ let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
|
||||
scope_let_next = next;
|
||||
scope_let_pos = pos;
|
||||
} ->
|
||||
(* special case : the subscope variable is always thunked. We remove this thunking. *)
|
||||
(* special case : the subscope variable is thunked (context i/o). We remove this thunking. *)
|
||||
let _, expr = Bindlib.unmbind binder in
|
||||
|
||||
let var_is_pure = true in
|
||||
let var, next = Bindlib.unbind next in
|
||||
Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var;
|
||||
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *)
|
||||
let ctx' = add_var pos var var_is_pure ctx in
|
||||
let new_var = (find ~info:"variable that was just created" var ctx').var in
|
||||
A.make_let_in new_var (translate_typ typ)
|
||||
(translate_expr ctx ~append_esome:true expr)
|
||||
(translate_expr ctx ~append_esome:false expr)
|
||||
(translate_scope_let ctx' next)
|
||||
| ScopeLet { scope_let_kind = SubScopeVarDefinition; scope_let_pos = pos; _ } ->
|
||||
| ScopeLet
|
||||
{
|
||||
scope_let_kind = SubScopeVarDefinition;
|
||||
scope_let_typ = typ;
|
||||
scope_let_expr = (D.ErrorOnEmpty _, _) as expr;
|
||||
scope_let_next = next;
|
||||
scope_let_pos = pos;
|
||||
} ->
|
||||
(* special case: regular input to the subscope *)
|
||||
let var_is_pure = true in
|
||||
let var, next = Bindlib.unbind next in
|
||||
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *)
|
||||
let ctx' = add_var pos var var_is_pure ctx in
|
||||
let new_var = (find ~info:"variable that was just created" var ctx').var in
|
||||
A.make_let_in new_var (translate_typ typ)
|
||||
(translate_expr ctx ~append_esome:false expr)
|
||||
(translate_scope_let ctx' next)
|
||||
| ScopeLet
|
||||
{ scope_let_kind = SubScopeVarDefinition; scope_let_pos = pos; scope_let_expr = expr; _ } ->
|
||||
Errors.raise_spanned_error
|
||||
"Internal Error: found an SubScopeVarDefinition that does not satisfy the thunked \
|
||||
invariant when translating Dcalc to Lcalc without exceptions."
|
||||
(Format.asprintf
|
||||
"Internal Error: found an SubScopeVarDefinition that does not satisfy the invariants \
|
||||
when translating Dcalc to Lcalc without exceptions: @[<hov 2>%a@]"
|
||||
(Dcalc.Print.format_expr ctx.decl_ctx)
|
||||
expr)
|
||||
pos
|
||||
| ScopeLet
|
||||
{
|
||||
@ -375,21 +386,25 @@ let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
|
||||
} ->
|
||||
let var_is_pure =
|
||||
match kind with
|
||||
| DestructuringInputStruct -> false
|
||||
| DestructuringInputStruct -> (
|
||||
(* Here, we have to distinguish between context and input variables. We can do so by
|
||||
looking at the typ of the destructuring: if it's thunked, then the variable is
|
||||
context. If it's not thunked, it's a regular input. *)
|
||||
match Pos.unmark typ with D.TArrow ((D.TLit D.TUnit, _), _) -> false | _ -> true)
|
||||
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
|
||||
| DestructuringSubScopeResults | Assertion ->
|
||||
true
|
||||
in
|
||||
let var, next = Bindlib.unbind next in
|
||||
Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var;
|
||||
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *)
|
||||
let ctx' = add_var pos var var_is_pure ctx in
|
||||
let new_var = (find ~info:"variable that was just created" var ctx').var in
|
||||
A.make_let_in new_var (translate_typ typ)
|
||||
(translate_expr ctx ~append_esome:false expr)
|
||||
(translate_scope_let ctx' next)
|
||||
|
||||
let translate_scope_body (scope_pos : Pos.t) (_decl_ctx : D.decl_ctx) (ctx : ctx)
|
||||
(body : scope_body) : A.expr Pos.marked Bindlib.box =
|
||||
let translate_scope_body (scope_pos : Pos.t) (ctx : ctx) (body : scope_body) :
|
||||
A.expr Pos.marked Bindlib.box =
|
||||
match body with
|
||||
| {
|
||||
scope_body_result = result;
|
||||
@ -404,8 +419,7 @@ let translate_scope_body (scope_pos : Pos.t) (_decl_ctx : D.decl_ctx) (ctx : ctx
|
||||
[ (D.TTuple ([], Some input_struct), Pos.no_pos) ]
|
||||
Pos.no_pos
|
||||
|
||||
let rec translate_scopes (decl_ctx : D.decl_ctx) (ctx : ctx) (scopes : scopes) :
|
||||
Ast.scope_body list Bindlib.box =
|
||||
let rec translate_scopes (ctx : ctx) (scopes : scopes) : Ast.scope_body list Bindlib.box =
|
||||
match scopes with
|
||||
| Nil -> Bindlib.box []
|
||||
| ScopeDef { scope_name; scope_body; scope_next } ->
|
||||
@ -415,8 +429,8 @@ let rec translate_scopes (decl_ctx : D.decl_ctx) (ctx : ctx) (scopes : scopes) :
|
||||
|
||||
let scope_pos = Pos.get_position (D.ScopeName.get_info scope_name) in
|
||||
|
||||
let new_body = translate_scope_body scope_pos decl_ctx ctx scope_body in
|
||||
let tail = translate_scopes decl_ctx new_ctx next in
|
||||
let new_body = translate_scope_body scope_pos ctx scope_body in
|
||||
let tail = translate_scopes new_ctx next in
|
||||
|
||||
Bindlib.box_apply2
|
||||
(fun body tail ->
|
||||
@ -428,8 +442,8 @@ let rec translate_scopes (decl_ctx : D.decl_ctx) (ctx : ctx) (scopes : scopes) :
|
||||
:: tail)
|
||||
new_body tail
|
||||
|
||||
let translate_scopes (decl_ctx : D.decl_ctx) (ctx : ctx) (scopes : scopes) : Ast.scope_body list =
|
||||
Bindlib.unbox (translate_scopes decl_ctx ctx scopes)
|
||||
let translate_scopes (ctx : ctx) (scopes : scopes) : Ast.scope_body list =
|
||||
Bindlib.unbox (translate_scopes ctx scopes)
|
||||
|
||||
let translate_program (prgm : D.program) : A.program =
|
||||
(* modify the *)
|
||||
@ -438,32 +452,33 @@ let translate_program (prgm : D.program) : A.program =
|
||||
body.D.scope_body_input_struct :: acc)
|
||||
in
|
||||
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "List of structs to modify: [%a]"
|
||||
(Format.pp_print_list D.StructName.format_t)
|
||||
inputs_structs;
|
||||
|
||||
(* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]" (Format.pp_print_list
|
||||
D.StructName.format_t) inputs_structs; *)
|
||||
let decl_ctx =
|
||||
{
|
||||
prgm.decl_ctx with
|
||||
D.ctx_enums = prgm.decl_ctx.ctx_enums |> D.EnumMap.add A.option_enum A.option_enum_config;
|
||||
}
|
||||
in
|
||||
let decl_ctx =
|
||||
{
|
||||
decl_ctx with
|
||||
D.ctx_structs =
|
||||
prgm.decl_ctx.ctx_structs
|
||||
|> D.StructMap.mapi (fun n l ->
|
||||
if List.mem n inputs_structs then
|
||||
ListLabels.map l ~f:(fun (n, tau) ->
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "Input type: %a" (Dcalc.Print.format_typ prgm.decl_ctx) tau;
|
||||
Cli.debug_print
|
||||
@@ Format.asprintf "Output type: %a"
|
||||
(Dcalc.Print.format_typ prgm.decl_ctx)
|
||||
(translate_typ tau);
|
||||
(* Cli.debug_print @@ Format.asprintf "Input type: %a" (Dcalc.Print.format_typ
|
||||
decl_ctx) tau; Cli.debug_print @@ Format.asprintf "Output type: %a"
|
||||
(Dcalc.Print.format_typ decl_ctx) (translate_typ tau); *)
|
||||
(n, translate_typ tau))
|
||||
else l);
|
||||
}
|
||||
in
|
||||
|
||||
let scopes =
|
||||
prgm.scopes |> bind_scopes |> Bindlib.unbox |> translate_scopes decl_ctx D.VarMap.empty
|
||||
prgm.scopes |> bind_scopes |> Bindlib.unbox
|
||||
|> translate_scopes { decl_ctx; vars = D.VarMap.empty }
|
||||
in
|
||||
|
||||
{ scopes; decl_ctx }
|
||||
|
@ -111,7 +111,7 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Fo
|
||||
(fst (List.nth (Dcalc.Ast.EnumMap.find en ctx.ctx_enums) n))
|
||||
format_expr e
|
||||
| EMatch (e, es, e_name) ->
|
||||
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" format_keyword "match" format_expr e
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" format_keyword "match" format_expr e
|
||||
format_keyword "with"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
|
Loading…
Reference in New Issue
Block a user