mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Fix handing of context variables in modules
The call convention imposes a translation of their types within the scope input structs definitions in dcalc.
This commit is contained in:
parent
f8b6e60e16
commit
af8ff472a5
@ -922,13 +922,23 @@ let translate_rules
|
||||
(Expr.Box.lift return_exp)),
|
||||
new_ctx )
|
||||
|
||||
let input_var_typ typ io_in =
|
||||
match io_in.Desugared.Ast.io_input with
|
||||
| Runtime.OnlyInput, pos -> typ, pos
|
||||
| Runtime.Reentrant, pos -> (
|
||||
match typ with
|
||||
| TArrow _ -> typ, pos
|
||||
| _ ->
|
||||
( TArrow ([TLit TUnit, pos], (typ, pos)),
|
||||
pos ))
|
||||
| Runtime.NoInput, _ -> invalid_arg "input_var_typ"
|
||||
|
||||
(* From a scope declaration and definitions, create the corresponding scope body
|
||||
wrapped in the appropriate call convention. *)
|
||||
let translate_scope_decl
|
||||
(ctx : 'm ctx)
|
||||
(scope_name : ScopeName.t)
|
||||
(sigma : 'm Scopelang.Ast.scope_decl) :
|
||||
'm Ast.expr scope_body Bindlib.box * struct_ctx =
|
||||
(sigma : 'm Scopelang.Ast.scope_decl) =
|
||||
let sigma_info = ScopeName.get_info sigma.scope_decl_name in
|
||||
let scope_sig =
|
||||
ScopeName.Map.find sigma.scope_decl_name ctx.scopes_parameters.scope_sigs
|
||||
@ -1007,17 +1017,6 @@ let translate_scope_decl
|
||||
| _ -> true)
|
||||
scope_variables
|
||||
in
|
||||
let input_var_typ (var_ctx : scope_var_ctx) =
|
||||
match Mark.remove var_ctx.scope_var_io.io_input with
|
||||
| OnlyInput -> var_ctx.scope_var_typ, pos_sigma
|
||||
| Reentrant -> (
|
||||
match var_ctx.scope_var_typ with
|
||||
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
|
||||
| _ ->
|
||||
( TArrow ([TLit TUnit, pos_sigma], (var_ctx.scope_var_typ, pos_sigma)),
|
||||
pos_sigma ))
|
||||
| NoInput -> failwith "should not happen"
|
||||
in
|
||||
let input_destructurings next =
|
||||
List.fold_right
|
||||
(fun (var_ctx, v) next ->
|
||||
@ -1033,7 +1032,7 @@ let translate_scope_decl
|
||||
scope_let_kind = DestructuringInputStruct;
|
||||
scope_let_next = next;
|
||||
scope_let_pos = pos_sigma;
|
||||
scope_let_typ = input_var_typ var_ctx;
|
||||
scope_let_typ = input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
|
||||
scope_let_expr =
|
||||
( EStructAccess
|
||||
{ name = scope_input_struct_name; e = r; field },
|
||||
@ -1044,31 +1043,15 @@ let translate_scope_decl
|
||||
(Expr.make_var scope_input_var (mark_tany scope_mark pos_sigma))))
|
||||
scope_input_variables next
|
||||
in
|
||||
let scope_body =
|
||||
Bindlib.box_apply
|
||||
(fun scope_body_expr ->
|
||||
{
|
||||
scope_body_expr;
|
||||
scope_body_input_struct = scope_input_struct_name;
|
||||
scope_body_output_struct = scope_return_struct_name;
|
||||
})
|
||||
(Bindlib.bind_var scope_input_var
|
||||
(input_destructurings rules_with_return_expr))
|
||||
in
|
||||
let field_map =
|
||||
List.fold_left
|
||||
(fun acc (var_ctx, _) ->
|
||||
let var = var_ctx.scope_var_name in
|
||||
let field =
|
||||
(ScopeVar.Map.find var scope_sig.scope_sig_in_fields).scope_input_name
|
||||
in
|
||||
StructField.Map.add field (input_var_typ var_ctx) acc)
|
||||
StructField.Map.empty scope_input_variables
|
||||
in
|
||||
let new_struct_ctx =
|
||||
StructName.Map.singleton scope_input_struct_name field_map
|
||||
in
|
||||
scope_body, new_struct_ctx
|
||||
Bindlib.box_apply
|
||||
(fun scope_body_expr ->
|
||||
{
|
||||
scope_body_expr;
|
||||
scope_body_input_struct = scope_input_struct_name;
|
||||
scope_body_output_struct = scope_return_struct_name;
|
||||
})
|
||||
(Bindlib.bind_var scope_input_var
|
||||
(input_destructurings rules_with_return_expr))
|
||||
|
||||
let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
||||
let defs_dependencies = Scopelang.Dependency.build_program_dep_graph prgm in
|
||||
@ -1114,7 +1097,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
||||
{
|
||||
scope_input_name = StructField.fresh (s, Mark.get info);
|
||||
scope_input_io = vis.Desugared.Ast.io_input;
|
||||
scope_input_typ = Mark.remove typ;
|
||||
scope_input_typ = Mark.remove (input_var_typ (Mark.remove typ) vis);
|
||||
})
|
||||
scope.Scopelang.Ast.scope_sig
|
||||
in
|
||||
@ -1155,35 +1138,40 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
||||
ModuleName.Map.map process_modules prgm.Scopelang.Ast.program_modules;
|
||||
}
|
||||
in
|
||||
let add_scope_in_structs scope_sigs structs =
|
||||
ScopeName.Map.fold
|
||||
(fun _ scope_sig_ctx acc ->
|
||||
let fields =
|
||||
ScopeVar.Map.fold
|
||||
(fun _ sivc acc ->
|
||||
let pos =
|
||||
Mark.get (StructField.get_info sivc.scope_input_name)
|
||||
in
|
||||
StructField.Map.add sivc.scope_input_name
|
||||
(sivc.scope_input_typ, pos)
|
||||
acc)
|
||||
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
|
||||
in
|
||||
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
|
||||
scope_sigs.scope_sigs structs
|
||||
in
|
||||
let rec gather_module_in_structs acc sctx =
|
||||
(* Expose all added in_structs from submodules at toplevel *)
|
||||
ModuleName.Map.fold
|
||||
(fun _ scope_sigs acc ->
|
||||
let acc = gather_module_in_structs acc scope_sigs.scope_sigs_modules in
|
||||
ScopeName.Map.fold
|
||||
(fun _ scope_sig_ctx acc ->
|
||||
let fields =
|
||||
ScopeVar.Map.fold
|
||||
(fun _ sivc acc ->
|
||||
let pos =
|
||||
Mark.get (StructField.get_info sivc.scope_input_name)
|
||||
in
|
||||
StructField.Map.add sivc.scope_input_name
|
||||
(sivc.scope_input_typ, pos)
|
||||
acc)
|
||||
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
|
||||
in
|
||||
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
|
||||
scope_sigs.scope_sigs acc)
|
||||
add_scope_in_structs scope_sigs
|
||||
(gather_module_in_structs acc scope_sigs.scope_sigs_modules))
|
||||
sctx acc
|
||||
in
|
||||
let decl_ctx =
|
||||
{
|
||||
decl_ctx with
|
||||
ctx_structs =
|
||||
gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules;
|
||||
add_scope_in_structs sctx
|
||||
(gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules);
|
||||
}
|
||||
in
|
||||
Message.emit_debug "STRUCTS: %a" (StructName.Map.format_keys ~pp_sep:Format.pp_print_space) decl_ctx.ctx_structs;
|
||||
let top_ctx =
|
||||
let toplevel_vars =
|
||||
TopdefName.Map.mapi
|
||||
@ -1205,21 +1193,20 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
||||
ending with the top-level scope. The decl_ctx is filled in left-to-right
|
||||
order, then the chained scopes aggregated from the right. *)
|
||||
let rec translate_defs ctx = function
|
||||
| [] -> Bindlib.box Nil, ctx
|
||||
| [] -> Bindlib.box Nil
|
||||
| def :: next ->
|
||||
let ctx, dvar, def =
|
||||
let dvar, def =
|
||||
match def with
|
||||
| Scopelang.Dependency.Topdef gname ->
|
||||
let expr, ty = TopdefName.Map.find gname prgm.program_topdefs in
|
||||
let expr = translate_expr ctx expr in
|
||||
( ctx,
|
||||
fst (TopdefName.Map.find gname ctx.toplevel_vars),
|
||||
( fst (TopdefName.Map.find gname ctx.toplevel_vars),
|
||||
Bindlib.box_apply
|
||||
(fun e -> Topdef (gname, ty, e))
|
||||
(Expr.Box.lift expr) )
|
||||
| Scopelang.Dependency.Scope scope_name ->
|
||||
let scope = ScopeName.Map.find scope_name prgm.program_scopes in
|
||||
let scope_body, scope_in_struct =
|
||||
let scope_body =
|
||||
translate_scope_decl ctx scope_name (Mark.remove scope)
|
||||
in
|
||||
let scope_var =
|
||||
@ -1230,33 +1217,21 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
||||
| Local_scope_ref v -> v
|
||||
| External_scope_ref _ -> assert false
|
||||
in
|
||||
( {
|
||||
ctx with
|
||||
decl_ctx =
|
||||
{
|
||||
ctx.decl_ctx with
|
||||
ctx_structs =
|
||||
StructName.Map.union
|
||||
(fun _ _ -> assert false)
|
||||
ctx.decl_ctx.ctx_structs scope_in_struct;
|
||||
};
|
||||
},
|
||||
scope_var,
|
||||
( scope_var,
|
||||
Bindlib.box_apply
|
||||
(fun body -> ScopeDef (scope_name, body))
|
||||
scope_body )
|
||||
in
|
||||
let scope_next, ctx = translate_defs ctx next in
|
||||
let scope_next = translate_defs ctx next in
|
||||
let next_bind = Bindlib.bind_var dvar scope_next in
|
||||
( Bindlib.box_apply2
|
||||
(fun item next_bind -> Cons (item, next_bind))
|
||||
def next_bind,
|
||||
ctx )
|
||||
Bindlib.box_apply2
|
||||
(fun item next_bind -> Cons (item, next_bind))
|
||||
def next_bind
|
||||
in
|
||||
let items, ctx = translate_defs top_ctx defs_ordering in
|
||||
let items = translate_defs top_ctx defs_ordering in
|
||||
{
|
||||
code_items = Bindlib.unbox items;
|
||||
decl_ctx = ctx.decl_ctx;
|
||||
decl_ctx;
|
||||
module_name = prgm.Scopelang.Ast.program_module_name;
|
||||
lang = prgm.program_lang;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user