mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Refactoring finished
This commit is contained in:
parent
6ea73a4291
commit
764edb6ef0
@ -217,8 +217,8 @@ let driver source_file (options : Cli.options) : int =
|
||||
(Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx)
|
||||
( scope_uid,
|
||||
Option.get
|
||||
(Dcalc.Ast.fold_scope_defs ~init:None
|
||||
~f:(fun acc scope_def ->
|
||||
(Dcalc.Ast.fold_left_scope_defs ~init:None
|
||||
~f:(fun acc scope_def _ ->
|
||||
if
|
||||
Dcalc.Ast.ScopeName.compare scope_def.scope_name
|
||||
scope_uid
|
||||
@ -298,17 +298,8 @@ let driver source_file (options : Cli.options) : int =
|
||||
let prgm =
|
||||
if options.closure_conversion then (
|
||||
Cli.debug_print "Performing closure conversion...";
|
||||
let prgm, closures =
|
||||
Lcalc.Closure_conversion.closure_conversion prgm
|
||||
in
|
||||
let prgm = Lcalc.Closure_conversion.closure_conversion prgm in
|
||||
let prgm = Bindlib.unbox prgm in
|
||||
List.iter
|
||||
(fun closure ->
|
||||
Cli.debug_format "Closure found:\n%a"
|
||||
(Lcalc.Print.format_expr ~debug:options.debug
|
||||
prgm.decl_ctx)
|
||||
(Bindlib.unbox closure.Lcalc.Closure_conversion.expr))
|
||||
closures;
|
||||
prgm)
|
||||
else prgm
|
||||
in
|
||||
@ -323,19 +314,27 @@ let driver source_file (options : Cli.options) : int =
|
||||
if Option.is_some options.ex_scope then
|
||||
Format.fprintf fmt "%a\n"
|
||||
(Lcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx)
|
||||
(let body =
|
||||
List.find
|
||||
(fun body -> body.Lcalc.Ast.scope_body_name = scope_uid)
|
||||
prgm.scopes
|
||||
in
|
||||
body)
|
||||
( scope_uid,
|
||||
Option.get
|
||||
(Dcalc.Ast.fold_left_scope_defs ~init:None
|
||||
~f:(fun acc scope_def _ ->
|
||||
if
|
||||
Dcalc.Ast.ScopeName.compare scope_def.scope_name
|
||||
scope_uid
|
||||
= 0
|
||||
then Some scope_def.scope_body
|
||||
else acc)
|
||||
prgm.scopes) )
|
||||
else
|
||||
Format.fprintf fmt "%a\n"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
|
||||
(fun fmt scope ->
|
||||
(Lcalc.Print.format_scope prgm.decl_ctx) fmt scope))
|
||||
prgm.scopes;
|
||||
ignore
|
||||
(Dcalc.Ast.fold_left_scope_defs ~init:0
|
||||
~f:(fun i scope_def _ ->
|
||||
Format.fprintf fmt "%s%a"
|
||||
(if i = 0 then "" else "\n")
|
||||
(Lcalc.Print.format_scope prgm.decl_ctx)
|
||||
(scope_uid, scope_def.scope_body);
|
||||
i + 1)
|
||||
prgm.scopes);
|
||||
at_end ();
|
||||
exit 0
|
||||
end;
|
||||
|
@ -512,6 +512,40 @@ let format_ctx
|
||||
Format.fprintf fmt "%a@\n@\n" format_enum_decl (e, find_enum e ctx))
|
||||
(type_ordering @ scope_structs)
|
||||
|
||||
let rec format_scope_body_expr
|
||||
(ctx : Dcalc.Ast.decl_ctx)
|
||||
(fmt : Format.formatter)
|
||||
(scope_lets : Ast.expr Dcalc.Ast.scope_body_expr) : unit =
|
||||
match scope_lets with
|
||||
| Dcalc.Ast.Result e -> format_expr ctx fmt e
|
||||
| Dcalc.Ast.ScopeLet scope_let ->
|
||||
let scope_let_var, scope_let_next =
|
||||
Bindlib.unbind scope_let.scope_let_next
|
||||
in
|
||||
Format.fprintf fmt "@[<hov 2>let %a: %a = %a in@]@\n%a" format_var
|
||||
scope_let_var format_typ scope_let.scope_let_typ (format_expr ctx)
|
||||
scope_let.scope_let_expr
|
||||
(format_scope_body_expr ctx)
|
||||
scope_let_next
|
||||
|
||||
let rec format_scopes
|
||||
(ctx : Dcalc.Ast.decl_ctx)
|
||||
(fmt : Format.formatter)
|
||||
(scopes : Ast.expr Dcalc.Ast.scopes) : unit =
|
||||
match scopes with
|
||||
| Dcalc.Ast.Nil -> ()
|
||||
| Dcalc.Ast.ScopeDef scope_def ->
|
||||
let scope_input_var, scope_body_expr =
|
||||
Bindlib.unbind scope_def.scope_body.scope_body_expr
|
||||
in
|
||||
let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in
|
||||
Format.fprintf fmt "@\n@\n@[<hov 2>let %a (%a: %a) : %a =@\n%a@]%a"
|
||||
format_var scope_var format_var scope_input_var format_struct_name
|
||||
scope_def.scope_body.scope_body_input_struct format_struct_name
|
||||
scope_def.scope_body.scope_body_output_struct
|
||||
(format_scope_body_expr ctx)
|
||||
scope_body_expr (format_scopes ctx) scope_next
|
||||
|
||||
let format_program
|
||||
(fmt : Format.formatter)
|
||||
(p : Ast.program)
|
||||
@ -524,13 +558,5 @@ let format_program
|
||||
@\n\
|
||||
[@@@@@@ocaml.warning \"-4-26-27-32-41-42\"]@\n\
|
||||
@\n\
|
||||
%a@\n\
|
||||
@\n\
|
||||
%a@?"
|
||||
(format_ctx type_ordering) p.decl_ctx
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n")
|
||||
(fun fmt body ->
|
||||
Format.fprintf fmt "@[<hov 2>let@ %a@ =@ %a@]" format_var
|
||||
body.scope_body_var (format_expr p.decl_ctx) body.scope_body_expr))
|
||||
p.scopes
|
||||
%a%a@?"
|
||||
(format_ctx type_ordering) p.decl_ctx (format_scopes p.decl_ctx) p.scopes
|
||||
|
@ -281,30 +281,16 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) :
|
||||
Pos.get_position block_expr );
|
||||
])
|
||||
|
||||
let translate_scope
|
||||
let rec translate_scope_body_expr
|
||||
(scope_name : D.ScopeName.t)
|
||||
(decl_ctx : D.decl_ctx)
|
||||
(var_dict : A.LocalName.t L.VarMap.t)
|
||||
(func_dict : A.TopLevelName.t L.VarMap.t)
|
||||
(scope_expr : L.expr Pos.marked) :
|
||||
(A.LocalName.t Pos.marked * D.typ Pos.marked) list * A.block =
|
||||
match Pos.unmark scope_expr with
|
||||
| L.EAbs ((binder, binder_pos), typs) ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let var_dict =
|
||||
Array.fold_left
|
||||
(fun var_dict var ->
|
||||
L.VarMap.add var
|
||||
(A.LocalName.fresh (Bindlib.name_of var, binder_pos))
|
||||
var_dict)
|
||||
L.VarMap.empty vars
|
||||
in
|
||||
let param_list =
|
||||
List.map2
|
||||
(fun var typ -> ((L.VarMap.find var var_dict, binder_pos), typ))
|
||||
(Array.to_list vars) typs
|
||||
in
|
||||
let new_body =
|
||||
translate_statements
|
||||
(scope_expr : L.expr D.scope_body_expr) : A.block =
|
||||
match scope_expr with
|
||||
| Result e ->
|
||||
let block, new_e =
|
||||
translate_expr
|
||||
{
|
||||
decl_ctx;
|
||||
func_dict;
|
||||
@ -312,48 +298,96 @@ let translate_scope
|
||||
inside_definition_of = None;
|
||||
context_name = Pos.unmark (D.ScopeName.get_info scope_name);
|
||||
}
|
||||
body
|
||||
e
|
||||
in
|
||||
(param_list, new_body)
|
||||
| _ -> assert false
|
||||
(* should not happen *)
|
||||
block @ [ (A.SReturn (Pos.unmark new_e), Pos.get_position new_e) ]
|
||||
| ScopeLet scope_let ->
|
||||
let let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next in
|
||||
let let_var_id =
|
||||
A.LocalName.fresh (Bindlib.name_of let_var, scope_let.scope_let_pos)
|
||||
in
|
||||
let let_expr_stmts, new_let_expr =
|
||||
translate_expr
|
||||
{
|
||||
decl_ctx;
|
||||
func_dict;
|
||||
var_dict;
|
||||
inside_definition_of = Some let_var_id;
|
||||
context_name = Pos.unmark (D.ScopeName.get_info scope_name);
|
||||
}
|
||||
scope_let.scope_let_expr
|
||||
in
|
||||
let new_var_dict = L.VarMap.add let_var let_var_id var_dict in
|
||||
let_expr_stmts
|
||||
@ [
|
||||
( A.SLocalDecl
|
||||
((let_var_id, scope_let.scope_let_pos), scope_let.scope_let_typ),
|
||||
scope_let.scope_let_pos );
|
||||
( A.SLocalDef ((let_var_id, scope_let.scope_let_pos), new_let_expr),
|
||||
scope_let.scope_let_pos );
|
||||
]
|
||||
@ translate_scope_body_expr scope_name decl_ctx new_var_dict func_dict
|
||||
scope_let_next
|
||||
|
||||
let translate_program (p : L.program) : A.program =
|
||||
{
|
||||
decl_ctx = p.L.decl_ctx;
|
||||
scopes =
|
||||
(let _, new_scopes =
|
||||
List.fold_left
|
||||
(fun (func_dict, new_scopes) body ->
|
||||
let new_scope_params, new_scope_body =
|
||||
translate_scope body.L.scope_body_name p.decl_ctx func_dict
|
||||
body.L.scope_body_expr
|
||||
D.fold_left_scope_defs
|
||||
~f:(fun (func_dict, new_scopes) scope_def scope_var ->
|
||||
let scope_input_var, scope_body_expr =
|
||||
Bindlib.unbind scope_def.scope_body.scope_body_expr
|
||||
in
|
||||
let input_pos =
|
||||
Pos.get_position (D.ScopeName.get_info scope_def.scope_name)
|
||||
in
|
||||
let scope_input_var_id =
|
||||
A.LocalName.fresh (Bindlib.name_of scope_input_var, input_pos)
|
||||
in
|
||||
let var_dict =
|
||||
L.VarMap.singleton scope_input_var scope_input_var_id
|
||||
in
|
||||
let new_scope_body =
|
||||
translate_scope_body_expr scope_def.D.scope_name p.decl_ctx
|
||||
var_dict func_dict scope_body_expr
|
||||
in
|
||||
let func_id =
|
||||
A.TopLevelName.fresh
|
||||
(Bindlib.name_of body.Lcalc.Ast.scope_body_var, Pos.no_pos)
|
||||
in
|
||||
let func_dict =
|
||||
L.VarMap.add body.Lcalc.Ast.scope_body_var func_id func_dict
|
||||
A.TopLevelName.fresh (Bindlib.name_of scope_var, Pos.no_pos)
|
||||
in
|
||||
let func_dict = L.VarMap.add scope_var func_id func_dict in
|
||||
( func_dict,
|
||||
{
|
||||
Ast.scope_body_name = body.Lcalc.Ast.scope_body_name;
|
||||
Ast.scope_body_name = scope_def.D.scope_name;
|
||||
Ast.scope_body_var = func_id;
|
||||
scope_body_func =
|
||||
{
|
||||
A.func_params = new_scope_params;
|
||||
A.func_params =
|
||||
[
|
||||
( (scope_input_var_id, input_pos),
|
||||
( D.TTuple
|
||||
( List.map snd
|
||||
(D.StructMap.find
|
||||
scope_def.D.scope_body
|
||||
.D.scope_body_input_struct
|
||||
p.L.decl_ctx.ctx_structs),
|
||||
Some
|
||||
scope_def.D.scope_body
|
||||
.D.scope_body_input_struct ),
|
||||
input_pos ) );
|
||||
];
|
||||
A.func_body = new_scope_body;
|
||||
};
|
||||
}
|
||||
:: new_scopes ))
|
||||
( (if !Cli.avoid_exceptions_flag then
|
||||
L.VarMap.singleton L.handle_default_opt
|
||||
(A.TopLevelName.fresh ("handle_default_opt", Pos.no_pos))
|
||||
else
|
||||
L.VarMap.singleton L.handle_default
|
||||
(A.TopLevelName.fresh ("handle_default", Pos.no_pos))),
|
||||
[] )
|
||||
~init:
|
||||
( (if !Cli.avoid_exceptions_flag then
|
||||
L.VarMap.singleton L.handle_default_opt
|
||||
(A.TopLevelName.fresh ("handle_default_opt", Pos.no_pos))
|
||||
else
|
||||
L.VarMap.singleton L.handle_default
|
||||
(A.TopLevelName.fresh ("handle_default", Pos.no_pos))),
|
||||
[] )
|
||||
p.L.scopes
|
||||
in
|
||||
List.rev new_scopes);
|
||||
|
Loading…
Reference in New Issue
Block a user