diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index 2a3074a9..60aec150 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -922,13 +922,23 @@ let translate_rules (Expr.Box.lift return_exp)), new_ctx ) +let input_var_typ typ io_in = + match io_in.Desugared.Ast.io_input with + | Runtime.OnlyInput, pos -> typ, pos + | Runtime.Reentrant, pos -> ( + match typ with + | TArrow _ -> typ, pos + | _ -> + ( TArrow ([TLit TUnit, pos], (typ, pos)), + pos )) + | Runtime.NoInput, _ -> invalid_arg "input_var_typ" + (* From a scope declaration and definitions, create the corresponding scope body wrapped in the appropriate call convention. *) let translate_scope_decl (ctx : 'm ctx) (scope_name : ScopeName.t) - (sigma : 'm Scopelang.Ast.scope_decl) : - 'm Ast.expr scope_body Bindlib.box * struct_ctx = + (sigma : 'm Scopelang.Ast.scope_decl) = let sigma_info = ScopeName.get_info sigma.scope_decl_name in let scope_sig = ScopeName.Map.find sigma.scope_decl_name ctx.scopes_parameters.scope_sigs @@ -1007,17 +1017,6 @@ let translate_scope_decl | _ -> true) scope_variables in - let input_var_typ (var_ctx : scope_var_ctx) = - match Mark.remove var_ctx.scope_var_io.io_input with - | OnlyInput -> var_ctx.scope_var_typ, pos_sigma - | Reentrant -> ( - match var_ctx.scope_var_typ with - | TArrow _ -> var_ctx.scope_var_typ, pos_sigma - | _ -> - ( TArrow ([TLit TUnit, pos_sigma], (var_ctx.scope_var_typ, pos_sigma)), - pos_sigma )) - | NoInput -> failwith "should not happen" - in let input_destructurings next = List.fold_right (fun (var_ctx, v) next -> @@ -1033,7 +1032,7 @@ let translate_scope_decl scope_let_kind = DestructuringInputStruct; scope_let_next = next; scope_let_pos = pos_sigma; - scope_let_typ = input_var_typ var_ctx; + scope_let_typ = input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io; scope_let_expr = ( EStructAccess { name = scope_input_struct_name; e = r; field }, @@ -1044,31 +1043,15 @@ let translate_scope_decl (Expr.make_var scope_input_var (mark_tany scope_mark pos_sigma)))) scope_input_variables next in - let scope_body = - Bindlib.box_apply - (fun scope_body_expr -> - { - scope_body_expr; - scope_body_input_struct = scope_input_struct_name; - scope_body_output_struct = scope_return_struct_name; - }) - (Bindlib.bind_var scope_input_var - (input_destructurings rules_with_return_expr)) - in - let field_map = - List.fold_left - (fun acc (var_ctx, _) -> - let var = var_ctx.scope_var_name in - let field = - (ScopeVar.Map.find var scope_sig.scope_sig_in_fields).scope_input_name - in - StructField.Map.add field (input_var_typ var_ctx) acc) - StructField.Map.empty scope_input_variables - in - let new_struct_ctx = - StructName.Map.singleton scope_input_struct_name field_map - in - scope_body, new_struct_ctx + Bindlib.box_apply + (fun scope_body_expr -> + { + scope_body_expr; + scope_body_input_struct = scope_input_struct_name; + scope_body_output_struct = scope_return_struct_name; + }) + (Bindlib.bind_var scope_input_var + (input_destructurings rules_with_return_expr)) let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = let defs_dependencies = Scopelang.Dependency.build_program_dep_graph prgm in @@ -1114,7 +1097,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = { scope_input_name = StructField.fresh (s, Mark.get info); scope_input_io = vis.Desugared.Ast.io_input; - scope_input_typ = Mark.remove typ; + scope_input_typ = Mark.remove (input_var_typ (Mark.remove typ) vis); }) scope.Scopelang.Ast.scope_sig in @@ -1155,35 +1138,40 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = ModuleName.Map.map process_modules prgm.Scopelang.Ast.program_modules; } in + let add_scope_in_structs scope_sigs structs = + ScopeName.Map.fold + (fun _ scope_sig_ctx acc -> + let fields = + ScopeVar.Map.fold + (fun _ sivc acc -> + let pos = + Mark.get (StructField.get_info sivc.scope_input_name) + in + StructField.Map.add sivc.scope_input_name + (sivc.scope_input_typ, pos) + acc) + scope_sig_ctx.scope_sig_in_fields StructField.Map.empty + in + StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc) + scope_sigs.scope_sigs structs + in let rec gather_module_in_structs acc sctx = (* Expose all added in_structs from submodules at toplevel *) ModuleName.Map.fold (fun _ scope_sigs acc -> - let acc = gather_module_in_structs acc scope_sigs.scope_sigs_modules in - ScopeName.Map.fold - (fun _ scope_sig_ctx acc -> - let fields = - ScopeVar.Map.fold - (fun _ sivc acc -> - let pos = - Mark.get (StructField.get_info sivc.scope_input_name) - in - StructField.Map.add sivc.scope_input_name - (sivc.scope_input_typ, pos) - acc) - scope_sig_ctx.scope_sig_in_fields StructField.Map.empty - in - StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc) - scope_sigs.scope_sigs acc) + add_scope_in_structs scope_sigs + (gather_module_in_structs acc scope_sigs.scope_sigs_modules)) sctx acc in let decl_ctx = { decl_ctx with ctx_structs = - gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules; + add_scope_in_structs sctx + (gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules); } in + Message.emit_debug "STRUCTS: %a" (StructName.Map.format_keys ~pp_sep:Format.pp_print_space) decl_ctx.ctx_structs; let top_ctx = let toplevel_vars = TopdefName.Map.mapi @@ -1205,21 +1193,20 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = ending with the top-level scope. The decl_ctx is filled in left-to-right order, then the chained scopes aggregated from the right. *) let rec translate_defs ctx = function - | [] -> Bindlib.box Nil, ctx + | [] -> Bindlib.box Nil | def :: next -> - let ctx, dvar, def = + let dvar, def = match def with | Scopelang.Dependency.Topdef gname -> let expr, ty = TopdefName.Map.find gname prgm.program_topdefs in let expr = translate_expr ctx expr in - ( ctx, - fst (TopdefName.Map.find gname ctx.toplevel_vars), + ( fst (TopdefName.Map.find gname ctx.toplevel_vars), Bindlib.box_apply (fun e -> Topdef (gname, ty, e)) (Expr.Box.lift expr) ) | Scopelang.Dependency.Scope scope_name -> let scope = ScopeName.Map.find scope_name prgm.program_scopes in - let scope_body, scope_in_struct = + let scope_body = translate_scope_decl ctx scope_name (Mark.remove scope) in let scope_var = @@ -1230,33 +1217,21 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = | Local_scope_ref v -> v | External_scope_ref _ -> assert false in - ( { - ctx with - decl_ctx = - { - ctx.decl_ctx with - ctx_structs = - StructName.Map.union - (fun _ _ -> assert false) - ctx.decl_ctx.ctx_structs scope_in_struct; - }; - }, - scope_var, + ( scope_var, Bindlib.box_apply (fun body -> ScopeDef (scope_name, body)) scope_body ) in - let scope_next, ctx = translate_defs ctx next in + let scope_next = translate_defs ctx next in let next_bind = Bindlib.bind_var dvar scope_next in - ( Bindlib.box_apply2 - (fun item next_bind -> Cons (item, next_bind)) - def next_bind, - ctx ) + Bindlib.box_apply2 + (fun item next_bind -> Cons (item, next_bind)) + def next_bind in - let items, ctx = translate_defs top_ctx defs_ordering in + let items = translate_defs top_ctx defs_ordering in { code_items = Bindlib.unbox items; - decl_ctx = ctx.decl_ctx; + decl_ctx; module_name = prgm.Scopelang.Ast.program_module_name; lang = prgm.program_lang; }