Refactoring finished

This commit is contained in:
Denis Merigoux 2022-04-26 16:06:36 +02:00
parent 6ea73a4291
commit 764edb6ef0
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
3 changed files with 137 additions and 78 deletions

View File

@ -217,8 +217,8 @@ let driver source_file (options : Cli.options) : int =
(Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx)
( scope_uid,
Option.get
(Dcalc.Ast.fold_scope_defs ~init:None
~f:(fun acc scope_def ->
(Dcalc.Ast.fold_left_scope_defs ~init:None
~f:(fun acc scope_def _ ->
if
Dcalc.Ast.ScopeName.compare scope_def.scope_name
scope_uid
@ -298,17 +298,8 @@ let driver source_file (options : Cli.options) : int =
let prgm =
if options.closure_conversion then (
Cli.debug_print "Performing closure conversion...";
let prgm, closures =
Lcalc.Closure_conversion.closure_conversion prgm
in
let prgm = Lcalc.Closure_conversion.closure_conversion prgm in
let prgm = Bindlib.unbox prgm in
List.iter
(fun closure ->
Cli.debug_format "Closure found:\n%a"
(Lcalc.Print.format_expr ~debug:options.debug
prgm.decl_ctx)
(Bindlib.unbox closure.Lcalc.Closure_conversion.expr))
closures;
prgm)
else prgm
in
@ -323,19 +314,27 @@ let driver source_file (options : Cli.options) : int =
if Option.is_some options.ex_scope then
Format.fprintf fmt "%a\n"
(Lcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx)
(let body =
List.find
(fun body -> body.Lcalc.Ast.scope_body_name = scope_uid)
prgm.scopes
in
body)
( scope_uid,
Option.get
(Dcalc.Ast.fold_left_scope_defs ~init:None
~f:(fun acc scope_def _ ->
if
Dcalc.Ast.ScopeName.compare scope_def.scope_name
scope_uid
= 0
then Some scope_def.scope_body
else acc)
prgm.scopes) )
else
Format.fprintf fmt "%a\n"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
(fun fmt scope ->
(Lcalc.Print.format_scope prgm.decl_ctx) fmt scope))
prgm.scopes;
ignore
(Dcalc.Ast.fold_left_scope_defs ~init:0
~f:(fun i scope_def _ ->
Format.fprintf fmt "%s%a"
(if i = 0 then "" else "\n")
(Lcalc.Print.format_scope prgm.decl_ctx)
(scope_uid, scope_def.scope_body);
i + 1)
prgm.scopes);
at_end ();
exit 0
end;

View File

@ -512,6 +512,40 @@ let format_ctx
Format.fprintf fmt "%a@\n@\n" format_enum_decl (e, find_enum e ctx))
(type_ordering @ scope_structs)
let rec format_scope_body_expr
(ctx : Dcalc.Ast.decl_ctx)
(fmt : Format.formatter)
(scope_lets : Ast.expr Dcalc.Ast.scope_body_expr) : unit =
match scope_lets with
| Dcalc.Ast.Result e -> format_expr ctx fmt e
| Dcalc.Ast.ScopeLet scope_let ->
let scope_let_var, scope_let_next =
Bindlib.unbind scope_let.scope_let_next
in
Format.fprintf fmt "@[<hov 2>let %a: %a = %a in@]@\n%a" format_var
scope_let_var format_typ scope_let.scope_let_typ (format_expr ctx)
scope_let.scope_let_expr
(format_scope_body_expr ctx)
scope_let_next
let rec format_scopes
(ctx : Dcalc.Ast.decl_ctx)
(fmt : Format.formatter)
(scopes : Ast.expr Dcalc.Ast.scopes) : unit =
match scopes with
| Dcalc.Ast.Nil -> ()
| Dcalc.Ast.ScopeDef scope_def ->
let scope_input_var, scope_body_expr =
Bindlib.unbind scope_def.scope_body.scope_body_expr
in
let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in
Format.fprintf fmt "@\n@\n@[<hov 2>let %a (%a: %a) : %a =@\n%a@]%a"
format_var scope_var format_var scope_input_var format_struct_name
scope_def.scope_body.scope_body_input_struct format_struct_name
scope_def.scope_body.scope_body_output_struct
(format_scope_body_expr ctx)
scope_body_expr (format_scopes ctx) scope_next
let format_program
(fmt : Format.formatter)
(p : Ast.program)
@ -524,13 +558,5 @@ let format_program
@\n\
[@@@@@@ocaml.warning \"-4-26-27-32-41-42\"]@\n\
@\n\
%a@\n\
@\n\
%a@?"
(format_ctx type_ordering) p.decl_ctx
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n")
(fun fmt body ->
Format.fprintf fmt "@[<hov 2>let@ %a@ =@ %a@]" format_var
body.scope_body_var (format_expr p.decl_ctx) body.scope_body_expr))
p.scopes
%a%a@?"
(format_ctx type_ordering) p.decl_ctx (format_scopes p.decl_ctx) p.scopes

View File

@ -281,30 +281,16 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) :
Pos.get_position block_expr );
])
let translate_scope
let rec translate_scope_body_expr
(scope_name : D.ScopeName.t)
(decl_ctx : D.decl_ctx)
(var_dict : A.LocalName.t L.VarMap.t)
(func_dict : A.TopLevelName.t L.VarMap.t)
(scope_expr : L.expr Pos.marked) :
(A.LocalName.t Pos.marked * D.typ Pos.marked) list * A.block =
match Pos.unmark scope_expr with
| L.EAbs ((binder, binder_pos), typs) ->
let vars, body = Bindlib.unmbind binder in
let var_dict =
Array.fold_left
(fun var_dict var ->
L.VarMap.add var
(A.LocalName.fresh (Bindlib.name_of var, binder_pos))
var_dict)
L.VarMap.empty vars
in
let param_list =
List.map2
(fun var typ -> ((L.VarMap.find var var_dict, binder_pos), typ))
(Array.to_list vars) typs
in
let new_body =
translate_statements
(scope_expr : L.expr D.scope_body_expr) : A.block =
match scope_expr with
| Result e ->
let block, new_e =
translate_expr
{
decl_ctx;
func_dict;
@ -312,48 +298,96 @@ let translate_scope
inside_definition_of = None;
context_name = Pos.unmark (D.ScopeName.get_info scope_name);
}
body
e
in
(param_list, new_body)
| _ -> assert false
(* should not happen *)
block @ [ (A.SReturn (Pos.unmark new_e), Pos.get_position new_e) ]
| ScopeLet scope_let ->
let let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next in
let let_var_id =
A.LocalName.fresh (Bindlib.name_of let_var, scope_let.scope_let_pos)
in
let let_expr_stmts, new_let_expr =
translate_expr
{
decl_ctx;
func_dict;
var_dict;
inside_definition_of = Some let_var_id;
context_name = Pos.unmark (D.ScopeName.get_info scope_name);
}
scope_let.scope_let_expr
in
let new_var_dict = L.VarMap.add let_var let_var_id var_dict in
let_expr_stmts
@ [
( A.SLocalDecl
((let_var_id, scope_let.scope_let_pos), scope_let.scope_let_typ),
scope_let.scope_let_pos );
( A.SLocalDef ((let_var_id, scope_let.scope_let_pos), new_let_expr),
scope_let.scope_let_pos );
]
@ translate_scope_body_expr scope_name decl_ctx new_var_dict func_dict
scope_let_next
let translate_program (p : L.program) : A.program =
{
decl_ctx = p.L.decl_ctx;
scopes =
(let _, new_scopes =
List.fold_left
(fun (func_dict, new_scopes) body ->
let new_scope_params, new_scope_body =
translate_scope body.L.scope_body_name p.decl_ctx func_dict
body.L.scope_body_expr
D.fold_left_scope_defs
~f:(fun (func_dict, new_scopes) scope_def scope_var ->
let scope_input_var, scope_body_expr =
Bindlib.unbind scope_def.scope_body.scope_body_expr
in
let input_pos =
Pos.get_position (D.ScopeName.get_info scope_def.scope_name)
in
let scope_input_var_id =
A.LocalName.fresh (Bindlib.name_of scope_input_var, input_pos)
in
let var_dict =
L.VarMap.singleton scope_input_var scope_input_var_id
in
let new_scope_body =
translate_scope_body_expr scope_def.D.scope_name p.decl_ctx
var_dict func_dict scope_body_expr
in
let func_id =
A.TopLevelName.fresh
(Bindlib.name_of body.Lcalc.Ast.scope_body_var, Pos.no_pos)
in
let func_dict =
L.VarMap.add body.Lcalc.Ast.scope_body_var func_id func_dict
A.TopLevelName.fresh (Bindlib.name_of scope_var, Pos.no_pos)
in
let func_dict = L.VarMap.add scope_var func_id func_dict in
( func_dict,
{
Ast.scope_body_name = body.Lcalc.Ast.scope_body_name;
Ast.scope_body_name = scope_def.D.scope_name;
Ast.scope_body_var = func_id;
scope_body_func =
{
A.func_params = new_scope_params;
A.func_params =
[
( (scope_input_var_id, input_pos),
( D.TTuple
( List.map snd
(D.StructMap.find
scope_def.D.scope_body
.D.scope_body_input_struct
p.L.decl_ctx.ctx_structs),
Some
scope_def.D.scope_body
.D.scope_body_input_struct ),
input_pos ) );
];
A.func_body = new_scope_body;
};
}
:: new_scopes ))
( (if !Cli.avoid_exceptions_flag then
L.VarMap.singleton L.handle_default_opt
(A.TopLevelName.fresh ("handle_default_opt", Pos.no_pos))
else
L.VarMap.singleton L.handle_default
(A.TopLevelName.fresh ("handle_default", Pos.no_pos))),
[] )
~init:
( (if !Cli.avoid_exceptions_flag then
L.VarMap.singleton L.handle_default_opt
(A.TopLevelName.fresh ("handle_default_opt", Pos.no_pos))
else
L.VarMap.singleton L.handle_default
(A.TopLevelName.fresh ("handle_default", Pos.no_pos))),
[] )
p.L.scopes
in
List.rev new_scopes);