Fix handing of context variables in modules

The call convention imposes a translation of their types within the scope input
structs definitions in dcalc.
This commit is contained in:
Louis Gesbert 2023-10-09 12:48:43 +02:00
parent f8b6e60e16
commit af8ff472a5

View File

@ -922,13 +922,23 @@ let translate_rules
(Expr.Box.lift return_exp)),
new_ctx )
let input_var_typ typ io_in =
match io_in.Desugared.Ast.io_input with
| Runtime.OnlyInput, pos -> typ, pos
| Runtime.Reentrant, pos -> (
match typ with
| TArrow _ -> typ, pos
| _ ->
( TArrow ([TLit TUnit, pos], (typ, pos)),
pos ))
| Runtime.NoInput, _ -> invalid_arg "input_var_typ"
(* From a scope declaration and definitions, create the corresponding scope body
wrapped in the appropriate call convention. *)
let translate_scope_decl
(ctx : 'm ctx)
(scope_name : ScopeName.t)
(sigma : 'm Scopelang.Ast.scope_decl) :
'm Ast.expr scope_body Bindlib.box * struct_ctx =
(sigma : 'm Scopelang.Ast.scope_decl) =
let sigma_info = ScopeName.get_info sigma.scope_decl_name in
let scope_sig =
ScopeName.Map.find sigma.scope_decl_name ctx.scopes_parameters.scope_sigs
@ -1007,17 +1017,6 @@ let translate_scope_decl
| _ -> true)
scope_variables
in
let input_var_typ (var_ctx : scope_var_ctx) =
match Mark.remove var_ctx.scope_var_io.io_input with
| OnlyInput -> var_ctx.scope_var_typ, pos_sigma
| Reentrant -> (
match var_ctx.scope_var_typ with
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
| _ ->
( TArrow ([TLit TUnit, pos_sigma], (var_ctx.scope_var_typ, pos_sigma)),
pos_sigma ))
| NoInput -> failwith "should not happen"
in
let input_destructurings next =
List.fold_right
(fun (var_ctx, v) next ->
@ -1033,7 +1032,7 @@ let translate_scope_decl
scope_let_kind = DestructuringInputStruct;
scope_let_next = next;
scope_let_pos = pos_sigma;
scope_let_typ = input_var_typ var_ctx;
scope_let_typ = input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
scope_let_expr =
( EStructAccess
{ name = scope_input_struct_name; e = r; field },
@ -1044,7 +1043,6 @@ let translate_scope_decl
(Expr.make_var scope_input_var (mark_tany scope_mark pos_sigma))))
scope_input_variables next
in
let scope_body =
Bindlib.box_apply
(fun scope_body_expr ->
{
@ -1054,21 +1052,6 @@ let translate_scope_decl
})
(Bindlib.bind_var scope_input_var
(input_destructurings rules_with_return_expr))
in
let field_map =
List.fold_left
(fun acc (var_ctx, _) ->
let var = var_ctx.scope_var_name in
let field =
(ScopeVar.Map.find var scope_sig.scope_sig_in_fields).scope_input_name
in
StructField.Map.add field (input_var_typ var_ctx) acc)
StructField.Map.empty scope_input_variables
in
let new_struct_ctx =
StructName.Map.singleton scope_input_struct_name field_map
in
scope_body, new_struct_ctx
let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
let defs_dependencies = Scopelang.Dependency.build_program_dep_graph prgm in
@ -1114,7 +1097,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
{
scope_input_name = StructField.fresh (s, Mark.get info);
scope_input_io = vis.Desugared.Ast.io_input;
scope_input_typ = Mark.remove typ;
scope_input_typ = Mark.remove (input_var_typ (Mark.remove typ) vis);
})
scope.Scopelang.Ast.scope_sig
in
@ -1155,11 +1138,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
ModuleName.Map.map process_modules prgm.Scopelang.Ast.program_modules;
}
in
let rec gather_module_in_structs acc sctx =
(* Expose all added in_structs from submodules at toplevel *)
ModuleName.Map.fold
(fun _ scope_sigs acc ->
let acc = gather_module_in_structs acc scope_sigs.scope_sigs_modules in
let add_scope_in_structs scope_sigs structs =
ScopeName.Map.fold
(fun _ scope_sig_ctx acc ->
let fields =
@ -1174,16 +1153,25 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
in
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
scope_sigs.scope_sigs acc)
scope_sigs.scope_sigs structs
in
let rec gather_module_in_structs acc sctx =
(* Expose all added in_structs from submodules at toplevel *)
ModuleName.Map.fold
(fun _ scope_sigs acc ->
add_scope_in_structs scope_sigs
(gather_module_in_structs acc scope_sigs.scope_sigs_modules))
sctx acc
in
let decl_ctx =
{
decl_ctx with
ctx_structs =
gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules;
add_scope_in_structs sctx
(gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules);
}
in
Message.emit_debug "STRUCTS: %a" (StructName.Map.format_keys ~pp_sep:Format.pp_print_space) decl_ctx.ctx_structs;
let top_ctx =
let toplevel_vars =
TopdefName.Map.mapi
@ -1205,21 +1193,20 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
ending with the top-level scope. The decl_ctx is filled in left-to-right
order, then the chained scopes aggregated from the right. *)
let rec translate_defs ctx = function
| [] -> Bindlib.box Nil, ctx
| [] -> Bindlib.box Nil
| def :: next ->
let ctx, dvar, def =
let dvar, def =
match def with
| Scopelang.Dependency.Topdef gname ->
let expr, ty = TopdefName.Map.find gname prgm.program_topdefs in
let expr = translate_expr ctx expr in
( ctx,
fst (TopdefName.Map.find gname ctx.toplevel_vars),
( fst (TopdefName.Map.find gname ctx.toplevel_vars),
Bindlib.box_apply
(fun e -> Topdef (gname, ty, e))
(Expr.Box.lift expr) )
| Scopelang.Dependency.Scope scope_name ->
let scope = ScopeName.Map.find scope_name prgm.program_scopes in
let scope_body, scope_in_struct =
let scope_body =
translate_scope_decl ctx scope_name (Mark.remove scope)
in
let scope_var =
@ -1230,33 +1217,21 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
| Local_scope_ref v -> v
| External_scope_ref _ -> assert false
in
( {
ctx with
decl_ctx =
{
ctx.decl_ctx with
ctx_structs =
StructName.Map.union
(fun _ _ -> assert false)
ctx.decl_ctx.ctx_structs scope_in_struct;
};
},
scope_var,
( scope_var,
Bindlib.box_apply
(fun body -> ScopeDef (scope_name, body))
scope_body )
in
let scope_next, ctx = translate_defs ctx next in
let scope_next = translate_defs ctx next in
let next_bind = Bindlib.bind_var dvar scope_next in
( Bindlib.box_apply2
Bindlib.box_apply2
(fun item next_bind -> Cons (item, next_bind))
def next_bind,
ctx )
def next_bind
in
let items, ctx = translate_defs top_ctx defs_ordering in
let items = translate_defs top_ctx defs_ordering in
{
code_items = Bindlib.unbox items;
decl_ctx = ctx.decl_ctx;
decl_ctx;
module_name = prgm.Scopelang.Ast.program_module_name;
lang = prgm.program_lang;
}