Adapt translation to new i/o invariants, bug discovered

This commit is contained in:
Denis Merigoux 2022-02-15 11:38:56 +01:00
parent cab4e5c17e
commit 48f064ccea
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
3 changed files with 87 additions and 81 deletions

View File

@ -81,18 +81,13 @@ let bind_scope_lets (acc : scope_lets Bindlib.box) (scope_let : D.scope_let) :
scope_lets Bindlib.box =
let pos = snd scope_let.D.scope_let_var in
Cli.debug_print
@@ Format.asprintf "binding let %a. Variable occurs = %b" Print.format_var
(fst scope_let.D.scope_let_var)
(Bindlib.occur (fst scope_let.D.scope_let_var) acc);
(* Cli.debug_print @@ Format.asprintf "binding let %a. Variable occurs = %b" Print.format_var (fst
scope_let.D.scope_let_var) (Bindlib.occur (fst scope_let.D.scope_let_var) acc); *)
let binder = Bindlib.bind_var (fst scope_let.D.scope_let_var) acc in
Bindlib.box_apply2
(fun expr binder ->
Cli.debug_print
@@ Format.asprintf "free variables in expression: %a"
(Format.pp_print_list Print.format_var)
(D.free_vars expr);
(* Cli.debug_print @@ Format.asprintf "free variables in expression: %a" (Format.pp_print_list
Print.format_var) (D.free_vars expr); *)
ScopeLet
{
scope_let_kind = scope_let.D.scope_let_kind;
@ -111,15 +106,15 @@ let bind_scope_body (body : D.scope_body) : scope_body Bindlib.box =
~f:(Fun.flip bind_scope_lets)
in
Cli.debug_print @@ Format.asprintf "binding arg %a" Print.format_var body.D.scope_body_arg;
(* Cli.debug_print @@ Format.asprintf "binding arg %a" Print.format_var body.D.scope_body_arg; *)
let scope_body_result = Bindlib.bind_var body.D.scope_body_arg body_result in
Cli.debug_print
@@ Format.asprintf "isfinal term is closed: %b" (Bindlib.is_closed scope_body_result);
(* Cli.debug_print @@ Format.asprintf "isfinal term is closed: %b" (Bindlib.is_closed
scope_body_result); *)
Bindlib.box_apply
(fun scope_body_result ->
Cli.debug_print
@@ Format.asprintf "rank of the final term: %i" (Bindlib.binder_rank scope_body_result);
(* Cli.debug_print @@ Format.asprintf "rank of the final term: %i" (Bindlib.binder_rank
scope_body_result); *)
{
scope_body_output_struct = body.D.scope_body_output_struct;
scope_body_input_struct = body.D.scope_body_input_struct;
@ -137,10 +132,6 @@ let bind_scope
let bind_scopes (scopes : (D.ScopeName.t * D.expr Bindlib.var * D.scope_body) list) :
scopes Bindlib.box =
let result = ListLabels.fold_right scopes ~init:(Bindlib.box Nil) ~f:bind_scope in
Cli.debug_print
@@ Format.asprintf "free variable in the program : [%a]"
(Format.pp_print_list Print.format_var)
(free_vars_scopes (Bindlib.unbox result));
(* Cli.debug_print @@ Format.asprintf "free variable in the program : [%a]" (Format.pp_print_list
Print.format_var) (free_vars_scopes (Bindlib.unbox result)); *)
result

View File

@ -49,10 +49,12 @@ type info = { expr : A.expr Pos.marked Bindlib.box; var : A.expr Bindlib.var; is
let pp_info (fmt : Format.formatter) (info : info) =
Format.fprintf fmt "{var: %a; is_pure: %b}" Print.format_var info.var info.is_pure
type ctx = info D.VarMap.t
(** information context about variables in the current scope *)
type ctx = {
decl_ctx : D.decl_ctx;
vars : info D.VarMap.t; (** information context about variables in the current scope *)
}
let pp_ctx (fmt : Format.formatter) (ctx : ctx) =
let _pp_ctx (fmt : Format.formatter) (ctx : ctx) =
let pp_binding (fmt : Format.formatter) ((v, info) : D.Var.t * info) =
Format.fprintf fmt "%a: %a" Dcalc.Print.format_var v pp_info info
in
@ -61,16 +63,13 @@ let pp_ctx (fmt : Format.formatter) (ctx : ctx) =
Format.pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt "; ") pp_binding
in
Format.fprintf fmt "@[<2>[%a]@]" pp_bindings (D.VarMap.bindings ctx)
Format.fprintf fmt "@[<2>[%a]@]" pp_bindings (D.VarMap.bindings ctx.vars)
(** [find ~info n ctx] is a warpper to ocaml's Map.find that handle errors in a slightly better way. *)
let find ?(info : string = "none") (n : D.Var.t) (ctx : ctx) : info =
let _ =
Format.asprintf "Searching for variable %a inside context %a" Dcalc.Print.format_var n pp_ctx
ctx
|> Cli.debug_print
in
try D.VarMap.find n ctx
(* let _ = Format.asprintf "Searching for variable %a inside context %a" Dcalc.Print.format_var n
pp_ctx ctx |> Cli.debug_print in *)
try D.VarMap.find n ctx.vars
with Not_found ->
Errors.raise_spanned_error
(Format.asprintf
@ -86,10 +85,9 @@ let add_var (pos : Pos.t) (var : D.Var.t) (is_pure : bool) (ctx : ctx) : ctx =
let new_var = A.Var.make (Bindlib.name_of var, pos) in
let expr = A.make_var (new_var, pos) in
Cli.debug_print
@@ Format.asprintf "D.%a |-> A.%a" Dcalc.Print.format_var var Print.format_var new_var;
D.VarMap.update var (fun _ -> Some { expr; var = new_var; is_pure }) ctx
(* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Dcalc.Print.format_var var Print.format_var
new_var; *)
{ ctx with vars = D.VarMap.update var (fun _ -> Some { expr; var = new_var; is_pure }) ctx.vars }
(** [tau' = translate_typ tau] translate the a dcalc type into a lcalc type.
@ -150,22 +148,18 @@ let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) :
thunked, hence matched in the next case. This assumption can change in the future, and this
case is here for this reason. *)
let v, pos_v = v in
if not (find ~info:"search for a variable" v ctx).is_pure then begin
if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = A.Var.make (Bindlib.name_of v, pos_v) in
Cli.debug_print
@@ Format.asprintf "Found an unpure variable %a, created a variable %a to replace it"
Dcalc.Print.format_var v Print.format_var v';
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, created a variable %a to
replace it" Dcalc.Print.format_var v Print.format_var v'; *)
(A.make_var (v', pos), A.VarMap.singleton v' e)
end
else ((find ~info:"should never happend" v ctx).expr, A.VarMap.empty)
| D.EApp ((D.EVar (v, pos_v), p), [ (D.ELit D.LUnit, _) ]) ->
if not (find ~info:"search for a variable" v ctx).is_pure then begin
if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = A.Var.make (Bindlib.name_of v, pos_v) in
Cli.debug_print
@@ Format.asprintf "Found an unpure variable %a, created a variable %a to replace it"
Dcalc.Print.format_var v Print.format_var v';
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, created a variable %a to
replace it" Dcalc.Print.format_var v Print.format_var v'; *)
(A.make_var (v', pos), A.VarMap.singleton v' (D.EVar (v, pos_v), p))
end
else
Errors.raise_spanned_error
"Internal error: an pure variable was found in an unpure environment." pos
@ -282,16 +276,12 @@ and translate_expr ?(append_esome = true) (ctx : ctx) (e : D.expr Pos.marked) :
let _pos = Pos.get_position e in
(* build the hoists *)
Cli.debug_print
@@ Format.asprintf "hoist for the expression: [%a]"
(Format.pp_print_list Print.format_var)
(List.map fst hoists);
(* Cli.debug_print @@ Format.asprintf "hoist for the expression: [%a]" (Format.pp_print_list
Print.format_var) (List.map fst hoists); *)
ListLabels.fold_left hoists
~init:(if append_esome then A.make_some e' else e')
~f:(fun acc (v, (hoist, pos_hoist)) ->
Cli.debug_print @@ Format.asprintf "hoist using A.%a" Print.format_var v;
(* Cli.debug_print @@ Format.asprintf "hoist using A.%a" Print.format_var v; *)
let c' : A.expr Pos.marked Bindlib.box =
match hoist with
(* Here we have to handle only the cases appearing in hoists, as defined the
@ -335,7 +325,7 @@ and translate_expr ?(append_esome = true) (ctx : ctx) (e : D.expr Pos.marked) :
in
(* [ match {{ c' }} with | None -> None | Some {{ v }} -> {{ acc }} end ] *)
Cli.debug_print @@ Format.asprintf "build matchopt using %a" Print.format_var v;
(* Cli.debug_print @@ Format.asprintf "build matchopt using %a" Print.format_var v; *)
A.make_matchopt pos_hoist v (D.TAny, pos_hoist) c' (A.make_none pos_hoist) acc)
let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
@ -349,21 +339,42 @@ let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
scope_let_next = next;
scope_let_pos = pos;
} ->
(* special case : the subscope variable is always thunked. We remove this thunking. *)
(* special case : the subscope variable is thunked (context i/o). We remove this thunking. *)
let _, expr = Bindlib.unmbind binder in
let var_is_pure = true in
let var, next = Bindlib.unbind next in
Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var;
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *)
let ctx' = add_var pos var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in
A.make_let_in new_var (translate_typ typ)
(translate_expr ctx ~append_esome:true expr)
(translate_expr ctx ~append_esome:false expr)
(translate_scope_let ctx' next)
| ScopeLet { scope_let_kind = SubScopeVarDefinition; scope_let_pos = pos; _ } ->
| ScopeLet
{
scope_let_kind = SubScopeVarDefinition;
scope_let_typ = typ;
scope_let_expr = (D.ErrorOnEmpty _, _) as expr;
scope_let_next = next;
scope_let_pos = pos;
} ->
(* special case: regular input to the subscope *)
let var_is_pure = true in
let var, next = Bindlib.unbind next in
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *)
let ctx' = add_var pos var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in
A.make_let_in new_var (translate_typ typ)
(translate_expr ctx ~append_esome:false expr)
(translate_scope_let ctx' next)
| ScopeLet
{ scope_let_kind = SubScopeVarDefinition; scope_let_pos = pos; scope_let_expr = expr; _ } ->
Errors.raise_spanned_error
"Internal Error: found an SubScopeVarDefinition that does not satisfy the thunked \
invariant when translating Dcalc to Lcalc without exceptions."
(Format.asprintf
"Internal Error: found an SubScopeVarDefinition that does not satisfy the invariants \
when translating Dcalc to Lcalc without exceptions: @[<hov 2>%a@]"
(Dcalc.Print.format_expr ctx.decl_ctx)
expr)
pos
| ScopeLet
{
@ -375,21 +386,25 @@ let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
} ->
let var_is_pure =
match kind with
| DestructuringInputStruct -> false
| DestructuringInputStruct -> (
(* Here, we have to distinguish between context and input variables. We can do so by
looking at the typ of the destructuring: if it's thunked, then the variable is
context. If it's not thunked, it's a regular input. *)
match Pos.unmark typ with D.TArrow ((D.TLit D.TUnit, _), _) -> false | _ -> true)
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
| DestructuringSubScopeResults | Assertion ->
true
in
let var, next = Bindlib.unbind next in
Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var;
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *)
let ctx' = add_var pos var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in
A.make_let_in new_var (translate_typ typ)
(translate_expr ctx ~append_esome:false expr)
(translate_scope_let ctx' next)
let translate_scope_body (scope_pos : Pos.t) (_decl_ctx : D.decl_ctx) (ctx : ctx)
(body : scope_body) : A.expr Pos.marked Bindlib.box =
let translate_scope_body (scope_pos : Pos.t) (ctx : ctx) (body : scope_body) :
A.expr Pos.marked Bindlib.box =
match body with
| {
scope_body_result = result;
@ -404,8 +419,7 @@ let translate_scope_body (scope_pos : Pos.t) (_decl_ctx : D.decl_ctx) (ctx : ctx
[ (D.TTuple ([], Some input_struct), Pos.no_pos) ]
Pos.no_pos
let rec translate_scopes (decl_ctx : D.decl_ctx) (ctx : ctx) (scopes : scopes) :
Ast.scope_body list Bindlib.box =
let rec translate_scopes (ctx : ctx) (scopes : scopes) : Ast.scope_body list Bindlib.box =
match scopes with
| Nil -> Bindlib.box []
| ScopeDef { scope_name; scope_body; scope_next } ->
@ -415,8 +429,8 @@ let rec translate_scopes (decl_ctx : D.decl_ctx) (ctx : ctx) (scopes : scopes) :
let scope_pos = Pos.get_position (D.ScopeName.get_info scope_name) in
let new_body = translate_scope_body scope_pos decl_ctx ctx scope_body in
let tail = translate_scopes decl_ctx new_ctx next in
let new_body = translate_scope_body scope_pos ctx scope_body in
let tail = translate_scopes new_ctx next in
Bindlib.box_apply2
(fun body tail ->
@ -428,8 +442,8 @@ let rec translate_scopes (decl_ctx : D.decl_ctx) (ctx : ctx) (scopes : scopes) :
:: tail)
new_body tail
let translate_scopes (decl_ctx : D.decl_ctx) (ctx : ctx) (scopes : scopes) : Ast.scope_body list =
Bindlib.unbox (translate_scopes decl_ctx ctx scopes)
let translate_scopes (ctx : ctx) (scopes : scopes) : Ast.scope_body list =
Bindlib.unbox (translate_scopes ctx scopes)
let translate_program (prgm : D.program) : A.program =
(* modify the *)
@ -438,32 +452,33 @@ let translate_program (prgm : D.program) : A.program =
body.D.scope_body_input_struct :: acc)
in
Cli.debug_print
@@ Format.asprintf "List of structs to modify: [%a]"
(Format.pp_print_list D.StructName.format_t)
inputs_structs;
(* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]" (Format.pp_print_list
D.StructName.format_t) inputs_structs; *)
let decl_ctx =
{
prgm.decl_ctx with
D.ctx_enums = prgm.decl_ctx.ctx_enums |> D.EnumMap.add A.option_enum A.option_enum_config;
}
in
let decl_ctx =
{
decl_ctx with
D.ctx_structs =
prgm.decl_ctx.ctx_structs
|> D.StructMap.mapi (fun n l ->
if List.mem n inputs_structs then
ListLabels.map l ~f:(fun (n, tau) ->
Cli.debug_print
@@ Format.asprintf "Input type: %a" (Dcalc.Print.format_typ prgm.decl_ctx) tau;
Cli.debug_print
@@ Format.asprintf "Output type: %a"
(Dcalc.Print.format_typ prgm.decl_ctx)
(translate_typ tau);
(* Cli.debug_print @@ Format.asprintf "Input type: %a" (Dcalc.Print.format_typ
decl_ctx) tau; Cli.debug_print @@ Format.asprintf "Output type: %a"
(Dcalc.Print.format_typ decl_ctx) (translate_typ tau); *)
(n, translate_typ tau))
else l);
}
in
let scopes =
prgm.scopes |> bind_scopes |> Bindlib.unbox |> translate_scopes decl_ctx D.VarMap.empty
prgm.scopes |> bind_scopes |> Bindlib.unbox
|> translate_scopes { decl_ctx; vars = D.VarMap.empty }
in
{ scopes; decl_ctx }

View File

@ -111,7 +111,7 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Fo
(fst (List.nth (Dcalc.Ast.EnumMap.find en ctx.ctx_enums) n))
format_expr e
| EMatch (e, es, e_name) ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" format_keyword "match" format_expr e
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" format_keyword "match" format_expr e
format_keyword "with"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")