From 1230f787d661895bf94cc3a5efe1b8caf69c84ea Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Wed, 7 Aug 2024 17:43:14 +0200 Subject: [PATCH 1/9] Renaming: use in the scalc translation and in Python --- compiler/driver.ml | 67 ++- compiler/driver.mli | 9 +- compiler/lcalc/to_ocaml.ml | 23 +- compiler/lcalc/to_ocaml.mli | 2 +- compiler/plugins/api_web.ml | 16 +- compiler/plugins/json_schema.ml | 3 +- compiler/plugins/python.ml | 3 +- compiler/scalc/from_lcalc.ml | 252 ++++---- compiler/scalc/from_lcalc.mli | 4 +- compiler/scalc/print.ml | 6 +- compiler/scalc/to_c.ml | 216 ++----- compiler/scalc/to_c.mli | 6 +- compiler/scalc/to_python.ml | 199 ++----- compiler/scalc/to_python.mli | 4 + compiler/shared_ast/expr.mli | 6 + compiler/shared_ast/program.ml | 45 +- compiler/shared_ast/program.mli | 10 +- tests/backends/output/simple.c | 482 +++++---------- tests/backends/python_name_clash.catala_en | 122 ++-- .../good/toplevel_defs.catala_en | 560 +++++++++--------- tests/scope/good/nothing.catala_en | 10 +- 21 files changed, 875 insertions(+), 1170 deletions(-) diff --git a/compiler/driver.ml b/compiler/driver.ml index 33143622..862858ca 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -228,8 +228,11 @@ module Passes = struct ~check_invariants ~(typed : ty mark) ~closure_conversion - ~monomorphize_types : - typed Lcalc.Ast.program * Scopelang.Dependency.TVertex.t list = + ~monomorphize_types + ~renaming : + typed Lcalc.Ast.program * + Scopelang.Dependency.TVertex.t list * + Expr.Renaming.context option = let prg, type_ordering = dcalc options ~includes ~optimize ~check_invariants ~typed in @@ -275,7 +278,19 @@ module Passes = struct prg, type_ordering) else prg, type_ordering in - prg, type_ordering + match renaming with + | None -> prg, type_ordering, None + | Some renaming -> + let prg, ren_ctx = Program.apply renaming prg in + let type_ordering = + let open Scopelang.Dependency.TVertex in + List.map + (function + | Struct s -> Struct (Expr.Renaming.struct_name ren_ctx s) + | Enum e -> Enum (Expr.Renaming.enum_name ren_ctx e)) + type_ordering + in + prg, type_ordering, Some ren_ctx let scalc options @@ -286,17 +301,30 @@ module Passes = struct ~keep_special_ops ~dead_value_assignment ~no_struct_literals - ~monomorphize_types : - Scalc.Ast.program * Scopelang.Dependency.TVertex.t list = - let prg, type_ordering = + ~monomorphize_types + ~renaming : + Scalc.Ast.program * Scopelang.Dependency.TVertex.t list * Expr.Renaming.context = + let prg, type_ordering, renaming_context = lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed - ~closure_conversion ~monomorphize_types + ~closure_conversion ~monomorphize_types ~renaming + in + let renaming_context = match renaming_context with + | None -> Expr.Renaming.get_ctx { + reserved = []; + sanitize_varname = Fun.id; + reset_context_for_closed_terms = true; + skip_constant_binders = true; + constant_binder_name = None; + } + | Some r -> r in debug_pass_name "scalc"; ( Scalc.From_lcalc.translate_program - ~config:{ keep_special_ops; dead_value_assignment; no_struct_literals } + ~config:{ keep_special_ops; dead_value_assignment; no_struct_literals; + renaming_context } prg, - type_ordering ) + type_ordering, + renaming_context ) end module Commands = struct @@ -711,9 +739,9 @@ module Commands = struct closure_conversion monomorphize_types ex_scope_opt = - let prg, _ = + let prg, _, _ = Passes.lcalc options ~includes ~optimize ~check_invariants - ~closure_conversion ~typed ~monomorphize_types + ~closure_conversion ~typed ~monomorphize_types ~renaming:None in let _output_file, with_output = get_output_format options output in with_output @@ -759,9 +787,9 @@ module Commands = struct optimize check_invariants ex_scope_opt = - let prg, _ = + let prg, _, _ = Passes.lcalc options ~includes ~optimize ~check_invariants - ~closure_conversion ~monomorphize_types ~typed + ~closure_conversion ~monomorphize_types ~typed ~renaming:None in Interpreter.load_runtime_modules ~hashf:(Hash.finalise ~closure_conversion ~monomorphize_types) @@ -809,9 +837,10 @@ module Commands = struct check_invariants closure_conversion ex_scope_opt = - let prg, type_ordering = + let prg, type_ordering, _ = Passes.lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed ~closure_conversion ~monomorphize_types:false + ~renaming:(Some Lcalc.To_ocaml.renaming) in let output_file, with_output = get_output_format options ~ext:".ml" output @@ -851,10 +880,10 @@ module Commands = struct no_struct_literals monomorphize_types ex_scope_opt = - let prg, _ = + let prg, _, _ = Passes.scalc options ~includes ~optimize ~check_invariants ~closure_conversion ~keep_special_ops ~dead_value_assignment - ~no_struct_literals ~monomorphize_types + ~no_struct_literals ~monomorphize_types ~renaming:None in let _output_file, with_output = get_output_format options output in with_output @@ -900,10 +929,11 @@ module Commands = struct optimize check_invariants closure_conversion = - let prg, type_ordering = + let prg, type_ordering, _ren_ctx = Passes.scalc options ~includes ~optimize ~check_invariants ~closure_conversion ~keep_special_ops:false ~dead_value_assignment:true ~no_struct_literals:false ~monomorphize_types:false + ~renaming:(Some Scalc.To_python.renaming) in let output_file, with_output = @@ -929,11 +959,12 @@ module Commands = struct $ Cli.Flags.closure_conversion) let c options includes output optimize check_invariants = - let prg, type_ordering = + let prg, type_ordering, _ren_ctx = Passes.scalc options ~includes ~optimize ~check_invariants ~closure_conversion:true ~keep_special_ops:true ~dead_value_assignment:false ~no_struct_literals:true ~monomorphize_types:true + ~renaming:(Some Scalc.To_c.renaming) in let output_file, with_output = get_output_format options ~ext:".c" output in Message.debug "Compiling program into C..."; diff --git a/compiler/driver.mli b/compiler/driver.mli index 3184a8e7..47f2bfce 100644 --- a/compiler/driver.mli +++ b/compiler/driver.mli @@ -53,7 +53,9 @@ module Passes : sig typed:'m Shared_ast.mark -> closure_conversion:bool -> monomorphize_types:bool -> - Shared_ast.typed Lcalc.Ast.program * Scopelang.Dependency.TVertex.t list + renaming : Shared_ast.Program.renaming option -> + Shared_ast.typed Lcalc.Ast.program * Scopelang.Dependency.TVertex.t list * + Shared_ast.Expr.Renaming.context option val scalc : Global.options -> @@ -65,7 +67,10 @@ module Passes : sig dead_value_assignment:bool -> no_struct_literals:bool -> monomorphize_types:bool -> - Scalc.Ast.program * Scopelang.Dependency.TVertex.t list + renaming: Shared_ast.Program.renaming option -> + Scalc.Ast.program * Scopelang.Dependency.TVertex.t list * + Shared_ast.Expr.Renaming.context + end module Commands : sig diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 3a2199fd..d9e5f68e 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -130,6 +130,13 @@ let ocaml_keywords = "Oper"; ] +let renaming = + Program.renaming () + ~reserved:ocaml_keywords + (* TODO: add catala runtime built-ins as reserved as well ? *) + ~reset_context_for_closed_terms:true ~skip_constant_binders:true + ~constant_binder_name:(Some "_") ~namespaced_fields_constrs:true + let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = (match StructName.path v with | [] -> () @@ -414,6 +421,7 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : Print.runtime_error er format_pos (Expr.pos e) | _ -> . +(* TODO: move [embed_foo] to [Foo.embed] to protect from name clashes *) let format_struct_embedding (fmt : Format.formatter) ((struct_name, struct_fields) : StructName.t * typ StructField.Map.t) = @@ -730,21 +738,6 @@ let format_program ~(hashf : Hash.t -> Hash.full) (p : 'm Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = - let p, ren_ctx = - Program.rename_ids p - ~reserved:ocaml_keywords - (* TODO: add catala runtime built-ins as reserved as well ? *) - ~reset_context_for_closed_terms:true ~skip_constant_binders:true - ~constant_binder_name:(Some "_") ~namespaced_fields_constrs:true - in - let type_ordering = - let open Scopelang.Dependency.TVertex in - List.map - (function - | Struct s -> Struct (Expr.Renaming.struct_name ren_ctx s) - | Enum e -> Enum (Expr.Renaming.enum_name ren_ctx e)) - type_ordering - in Format.pp_open_vbox fmt 0; Format.pp_print_string fmt header; check_and_reexport_used_modules fmt ~hashf diff --git a/compiler/lcalc/to_ocaml.mli b/compiler/lcalc/to_ocaml.mli index ff853298..489343d7 100644 --- a/compiler/lcalc/to_ocaml.mli +++ b/compiler/lcalc/to_ocaml.mli @@ -17,7 +17,7 @@ open Catala_utils open Shared_ast -val ocaml_keywords : string list +val renaming : Program.renaming (** Formats a lambda calculus program into a valid OCaml program *) diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index e892ba3f..11d3128f 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -475,22 +475,10 @@ let run monomorphize_types _options = let options = Global.enforce_options ~trace:true () in - let prg, type_ordering = + let prg, type_ordering, _ = Driver.Passes.lcalc options ~includes ~optimize ~check_invariants ~closure_conversion ~typed:Expr.typed ~monomorphize_types - in - let prg, ren_ctx = - Program.rename_ids prg ~reserved:To_ocaml.ocaml_keywords - ~reset_context_for_closed_terms:true ~skip_constant_binders:true - ~constant_binder_name:None ~namespaced_fields_constrs:true - in - let type_ordering = - let open Scopelang.Dependency.TVertex in - List.map - (function - | Struct s -> Struct (Expr.Renaming.struct_name ren_ctx s) - | Enum e -> Enum (Expr.Renaming.enum_name ren_ctx e)) - type_ordering + ~renaming:(Some Lcalc.To_ocaml.renaming) in let jsoo_output_file, with_formatter = Driver.Commands.get_output_format options ~ext:"_api_web.ml" output diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index bdfc0e2b..51cda7a7 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -213,9 +213,10 @@ let run monomorphize_types ex_scope options = - let prg, _ = + let prg, _, _ = Driver.Passes.lcalc options ~includes ~optimize ~check_invariants ~closure_conversion ~typed:Expr.typed ~monomorphize_types + ~renaming:(Some Lcalc.To_ocaml.renaming) in let output_file, with_output = Driver.Commands.get_output_format options ~ext:"_schema.json" output diff --git a/compiler/plugins/python.ml b/compiler/plugins/python.ml index 1c167290..eedadd49 100644 --- a/compiler/plugins/python.ml +++ b/compiler/plugins/python.ml @@ -24,10 +24,11 @@ open Catala_utils let run includes output optimize check_invariants closure_conversion options = let open Driver.Commands in - let prg, type_ordering = + let prg, type_ordering, _ = Driver.Passes.scalc options ~includes ~optimize ~check_invariants ~closure_conversion ~keep_special_ops:false ~dead_value_assignment:true ~no_struct_literals:false ~monomorphize_types:false + ~renaming:(Some Scalc.To_python.renaming) in let output_file, with_output = get_output_format options ~ext:".py" output in diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index 41ced7ef..c3326954 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -24,6 +24,7 @@ type translation_config = { keep_special_ops : bool; dead_value_assignment : bool; no_struct_literals : bool; + renaming_context : Expr.Renaming.context; } type 'm ctxt = { @@ -33,6 +34,7 @@ type 'm ctxt = { context_name : string; config : translation_config; program_ctx : A.ctx; + ren_ctx : Expr.Renaming.context; } (* Expressions can spill out side effect, hence this function also returns a @@ -65,6 +67,36 @@ end let ( ++ ) = RevBlock.seq +let unbind ctxt bnd = + let v, body, ren_ctx = Expr.Renaming.unbind_in ctxt.ren_ctx bnd in + v, body, { ctxt with ren_ctx } + +let unmbind ctxt bnd = + let vs, body, ren_ctx = Expr.Renaming.unmbind_in ctxt.ren_ctx bnd in + vs, body, { ctxt with ren_ctx } + +let get_name ctxt s = + let name, ren_ctx = Expr.Renaming.new_id ctxt.ren_ctx s in + name, { ctxt with ren_ctx } + +let fresh_var ~pos ctxt name = + let v, ctxt = get_name ctxt name in + A.VarName.fresh (v, pos), ctxt + +let register_fresh_var ~pos ctxt x = + let v = A.VarName.fresh (Bindlib.name_of x, pos) in + let var_dict = Var.Map.add x v ctxt.var_dict in + v, { ctxt with var_dict } + +let register_fresh_func ~pos ctxt x = + let f = A.FuncName.fresh (Bindlib.name_of x, pos) in + let func_dict = Var.Map.add x f ctxt.func_dict in + f, { ctxt with func_dict } + +let register_fresh_arg ~pos ctxt (x, _) = + let _, ctxt = register_fresh_var ~pos ctxt x in + ctxt + let rec translate_expr_list ctxt args = let stmts, args = List.fold_left @@ -138,19 +170,11 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = | EApp { f = EAbs { binder; tys }, binder_mark; args; tys = _ } -> (* This defines multiple local variables at the time *) let binder_pos = Expr.mark_pos binder_mark in - let vars, body = Bindlib.unmbind binder in + let vars, body, ctxt = unmbind ctxt binder in let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in let ctxt = - { - ctxt with - var_dict = - List.fold_left - (fun var_dict (x, _) -> - Var.Map.add x - (A.VarName.fresh (Bindlib.name_of x, binder_pos)) - var_dict) - ctxt.var_dict vars_tau; - } + List.fold_left (register_fresh_arg ~pos:binder_pos) + ctxt vars_tau in let local_decls = List.fold_left @@ -215,18 +239,13 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = raise (NotAnExpr { needs_a_local_decl = true }) | _ -> . with NotAnExpr { needs_a_local_decl } -> - let tmp_var = - A.VarName.fresh - ( (*This piece of logic is used to make the code more readable. TODO: - should be removed when - https://github.com/CatalaLang/catala/issues/240 is fixed. *) - (match ctxt.inside_definition_of with - | None -> ctxt.context_name - | Some v -> - let v = Mark.remove (A.VarName.get_info v) in - let tmp_rex = Re.Pcre.regexp "^temp_" in - if Re.Pcre.pmatch ~rex:tmp_rex v then v else "temp_" ^ v), - Expr.pos expr ) + let tmp_var, ctxt = + let name = + match ctxt.inside_definition_of with + | None -> ctxt.context_name + | Some v -> A.VarName.to_string v + in + fresh_var ctxt name ~pos:(Expr.pos expr) in let ctxt = { @@ -314,19 +333,12 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = | EApp { f = EAbs { binder; tys }, binder_mark; args; _ } -> (* This defines multiple local variables at the time *) let binder_pos = Expr.mark_pos binder_mark in - let vars, body = Bindlib.unmbind binder in + let vars, body, ctxt = unmbind ctxt binder in let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in let ctxt = - { - ctxt with - var_dict = - List.fold_left - (fun var_dict (x, _) -> - Var.Map.add x - (A.VarName.fresh (Bindlib.name_of x, binder_pos)) - var_dict) - ctxt.var_dict vars_tau; - } + List.fold_left + (register_fresh_arg ~pos:binder_pos) + ctxt vars_tau in let local_decls = List.map @@ -369,26 +381,19 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = let rest_of_block = translate_statements ctxt body in local_decls @ List.flatten def_blocks @ rest_of_block | EAbs { binder; tys } -> - let vars, body = Bindlib.unmbind binder in - let binder_pos = Expr.pos block_expr in - let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in - let closure_name = + let closure_name, ctxt = match ctxt.inside_definition_of with - | None -> A.VarName.fresh (ctxt.context_name, Expr.pos block_expr) - | Some x -> x + | None -> fresh_var ctxt ctxt.context_name ~pos:(Expr.pos block_expr) + | Some x -> x, ctxt in + let vars, body, ctxt = unmbind ctxt binder in + let binder_pos = Expr.pos block_expr in + let vars_tau = List.combine (Array.to_list vars) tys in let ctxt = - { - ctxt with - var_dict = - List.fold_left - (fun var_dict (x, _) -> - Var.Map.add x - (A.VarName.fresh (Bindlib.name_of x, binder_pos)) - var_dict) - ctxt.var_dict vars_tau; - inside_definition_of = None; - } + List.fold_left + (register_fresh_arg ~pos:binder_pos) + { ctxt with inside_definition_of = None } + vars_tau in let new_body = translate_statements ctxt body in [ @@ -419,14 +424,11 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = (fun _ arg new_args -> match Mark.remove arg with | EAbs { binder; tys } -> - let vars, body = Bindlib.unmbind binder in + let vars, body, ctxt = unmbind ctxt binder in assert (Array.length vars = 1); let var = vars.(0) in - let scalc_var = - A.VarName.fresh (Bindlib.name_of var, Expr.pos arg) - in - let ctxt = - { ctxt with var_dict = Var.Map.add var scalc_var ctxt.var_dict } + let scalc_var, ctxt = + register_fresh_var ctxt var ~pos:(Expr.pos arg) in let new_arg = translate_statements ctxt body in { @@ -556,35 +558,22 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = | _ -> . let rec translate_scope_body_expr - ~(config : translation_config) - (scope_name : ScopeName.t) - (program_ctx : A.ctx) - (var_dict : ('m L.expr, A.VarName.t) Var.Map.t) - (func_dict : ('m L.expr, A.FuncName.t) Var.Map.t) + ctx (scope_expr : 'm L.expr scope_body_expr) : A.block = let ctx = - { - func_dict; - var_dict; - inside_definition_of = None; - context_name = Mark.remove (ScopeName.get_info scope_name); - config; - program_ctx; - } + { ctx with inside_definition_of = None } in match scope_expr with | Last e -> let block, new_e = translate_expr ctx e in RevBlock.rebuild block ~tail:[A.SReturn (Mark.remove new_e), Mark.get new_e] | Cons (scope_let, next_bnd) -> ( - let let_var, scope_let_next = Bindlib.unbind next_bnd in - let let_var_id = - A.VarName.fresh (Bindlib.name_of let_var, scope_let.scope_let_pos) + let let_var, scope_let_next, ctx1 = unbind ctx next_bnd in + let let_var_id, ctx = + register_fresh_var ctx1 let_var ~pos:scope_let.scope_let_pos in - let new_var_dict = Var.Map.add let_var let_var_id var_dict in let next = - translate_scope_body_expr ~config scope_name program_ctx new_var_dict - func_dict scope_let_next + translate_scope_body_expr ctx scope_let_next in match scope_let.scope_let_kind with | Assertion -> @@ -615,7 +604,7 @@ let rec translate_scope_body_expr scope_let.scope_let_pos ) :: next)) -let translate_program ~(config : translation_config) (p : 'm L.program) : +let translate_program ~(config : translation_config) (p : 'm L.program): A.program = let modules = List.fold_left @@ -630,29 +619,41 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : ModuleName.Map.empty (Program.modules_to_list p.decl_ctx.ctx_modules) in - let ctx = { A.decl_ctx = p.decl_ctx; A.modules } in - let (_, _, rev_items), _vlist = + let program_ctx = { A.decl_ctx = p.decl_ctx; A.modules } in + let ctxt = + { + func_dict = Var.Map.empty; + var_dict = Var.Map.empty; + inside_definition_of = None; + context_name = ""; + config; + program_ctx; + ren_ctx = config.renaming_context; + } + in + let (_, rev_items), _vlist = BoundList.fold_left - ~f:(fun (func_dict, var_dict, rev_items) code_item var -> + ~init:(ctxt, []) + ~f:(fun (ctxt, rev_items) code_item var -> match code_item with | ScopeDef (name, body) -> - let scope_input_var, scope_body_expr = - Bindlib.unbind body.scope_body_expr + let scope_input_var, scope_body_expr, ctxt1 = + unbind ctxt body.scope_body_expr in let input_pos = Mark.get (ScopeName.get_info name) in - let scope_input_var_id = - A.VarName.fresh (Bindlib.name_of scope_input_var, input_pos) - in - let var_dict_local = - Var.Map.add scope_input_var scope_input_var_id var_dict + let scope_input_var_id, ctxt = + register_fresh_var ctxt scope_input_var ~pos:input_pos in let new_scope_body = - translate_scope_body_expr ~config name ctx var_dict_local func_dict + translate_scope_body_expr + { ctxt with + context_name = Mark.remove (ScopeName.get_info name) } scope_body_expr in - let func_id = A.FuncName.fresh (Bindlib.name_of var, Pos.no_pos) in - ( Var.Map.add var func_id func_dict, - var_dict, + let func_id, ctxt1 = + register_fresh_func ctxt1 var ~pos:input_pos + in + ( ctxt1, A.SScope { Ast.scope_body_name = name; @@ -670,40 +671,32 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : }; } :: rev_items ) - | Topdef (name, topdef_ty, (EAbs abs, _)) -> + | Topdef (name, topdef_ty, (EAbs abs, m)) -> (* Toplevel function def *) - let func_id = A.FuncName.fresh (Bindlib.name_of var, Pos.no_pos) in - let args_a, expr = Bindlib.unmbind abs.binder in - let args = Array.to_list args_a in - let args_id = - List.map2 - (fun v ty -> - let pos = Mark.get ty in - (A.VarName.fresh (Bindlib.name_of v, pos), pos), ty) - args abs.tys - in - let block, expr = + let (block, expr), args_id = + let args_a, expr, ctxt = unmbind ctxt abs.binder in + let args = Array.to_list args_a in + let rargs_id, ctxt = + List.fold_left2 + (fun (rargs_id, ctxt) v ty -> + let pos = Mark.get ty in + let id, ctxt = register_fresh_var ctxt v ~pos in + ((id, pos), ty) :: rargs_id, ctxt) + ([], ctxt) args abs.tys + in let ctxt = - { - func_dict; - var_dict = - List.fold_left2 - (fun map arg ((id, _), _) -> Var.Map.add arg id map) - var_dict args args_id; - inside_definition_of = None; + { ctxt with context_name = Mark.remove (TopdefName.get_info name); - config; - program_ctx = ctx; } in - translate_expr ctxt expr + translate_expr ctxt expr, List.rev rargs_id in let body_block = RevBlock.rebuild block ~tail:[A.SReturn (Mark.remove expr), Mark.get expr] in - ( Var.Map.add var func_id func_dict, - var_dict, + let func_id, ctxt = register_fresh_func ctxt var ~pos:(Expr.mark_pos m) in + ( ctxt, A.SFunc { var = func_id; @@ -721,30 +714,28 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : :: rev_items ) | Topdef (name, topdef_ty, expr) -> (* Toplevel constant def *) - let var_id = A.VarName.fresh (Bindlib.name_of var, Pos.no_pos) in let block, expr = let ctxt = - { - func_dict; - var_dict; - inside_definition_of = None; + { ctxt with context_name = Mark.remove (TopdefName.get_info name); - config; - program_ctx = ctx; } in translate_expr ctxt expr in + let var_id, ctxt = + register_fresh_var ctxt var ~pos:(Mark.get (TopdefName.get_info name)) + in (* If the evaluation of the toplevel expr requires preliminary statements, we lift its computation into an auxiliary function *) - let rev_items = + let rev_items, ctxt = if (block :> (A.stmt * Pos.t) list) = [] then - A.SVar { var = var_id; expr; typ = topdef_ty } :: rev_items + A.SVar { var = var_id; expr; typ = topdef_ty } :: rev_items, ctxt else let pos = Mark.get expr in - let func_id = - A.FuncName.fresh (Bindlib.name_of var ^ "_aux", pos) + let func_name, ctxt = + get_name ctxt (A.VarName.to_string var_id ^ "_init") in + let func_id = A.FuncName.fresh (func_name, pos) in (* The list is being built in reverse order *) (* FIXME: find a better way than a function with no parameters... *) A.SVar @@ -765,14 +756,13 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : A.func_return_typ = topdef_ty; }; } - :: rev_items + :: rev_items, + ctxt in - ( func_dict, + ( ctxt, (* No need to add func_id since the function will only be called right here *) - Var.Map.add var var_id var_dict, rev_items )) - ~init:(Var.Map.empty, Var.Map.empty, []) p.code_items in - { ctx; code_items = List.rev rev_items; module_name = p.module_name } + { ctx = program_ctx; code_items = List.rev rev_items; module_name = p.module_name } diff --git a/compiler/scalc/from_lcalc.mli b/compiler/scalc/from_lcalc.mli index fdd3b255..4cd09de7 100644 --- a/compiler/scalc/from_lcalc.mli +++ b/compiler/scalc/from_lcalc.mli @@ -32,7 +32,9 @@ type translation_config = { (** When [no_struct_literals] is true, the translation inserts a temporary variable to hold the initialization of struct literals. This matches what C89 expects. *) + renaming_context : Expr.Renaming.context; } val translate_program : - config:translation_config -> typed Lcalc.Ast.program -> Ast.program + config:translation_config -> typed Lcalc.Ast.program -> + Ast.program diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index 5863678f..a40737c6 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -21,10 +21,12 @@ open Ast let needs_parens (_e : expr) : bool = false let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit = - Format.fprintf fmt "%a_%d" VarName.format v (VarName.id v) + VarName.format fmt v + (* Format.fprintf fmt "%a_%d" VarName.format v (VarName.id v) *) let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit = - Format.fprintf fmt "@{%a_%d@}" FuncName.format v (FuncName.id v) + FuncName.format fmt v + (* Format.fprintf fmt "@{%a_%d@}" FuncName.format v (FuncName.id v) *) let rec format_expr (decl_ctx : decl_ctx) diff --git a/compiler/scalc/to_c.ml b/compiler/scalc/to_c.ml index 3205968d..517db4ae 100644 --- a/compiler/scalc/to_c.ml +++ b/compiler/scalc/to_c.ml @@ -21,109 +21,21 @@ module D = Dcalc.Ast module L = Lcalc.Ast open Ast -let avoid_keywords (s : string) : string = - if - match s with - (* list taken from - https://learn.microsoft.com/en-us/cpp/c-language/c-keywords *) - | "auto" | "break" | "case" | "char" | "const" | "continue" | "default" - | "do" | "double" | "else" | "enum" | "extern" | "float" | "for" | "goto" - | "if" | "inline" | "int" | "long" | "register" | "restrict" | "return" - | "short" | "signed" | "sizeof" | "static" | "struct" | "switch" | "typedef" - | "union" | "unsigned" | "void" | "volatile" | "while" -> - true - | _ -> false - then s ^ "_" - else s +let c_keywords = + [ "auto"; "break"; "case"; "char"; "const"; "continue"; "default"; + "do"; "double"; "else"; "enum"; "extern"; "float"; "for"; "goto"; + "if"; "inline"; "int"; "long"; "register"; "restrict"; "return"; + "short"; "signed"; "sizeof"; "static"; "struct"; "switch"; "typedef"; + "union"; "unsigned"; "void"; "volatile"; "while" ] -let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = - Format.fprintf fmt "%s" - (Format.asprintf "%a_struct" StructName.format v - |> String.to_ascii - |> String.to_snake_case - |> avoid_keywords) - -let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit - = - Format.fprintf fmt "%s" - (Format.asprintf "%a_field" StructField.format v - |> String.to_ascii - |> String.to_snake_case - |> avoid_keywords) - -let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = - Format.fprintf fmt "%s_enum" - (Format.asprintf "%a" EnumName.format v - |> String.to_ascii - |> String.to_snake_case - |> avoid_keywords) - -let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : - unit = - Format.fprintf fmt "%s_cons" - (Format.asprintf "%a" EnumConstructor.format v - |> String.to_ascii - |> String.to_snake_case - |> avoid_keywords) - -let format_name_cleaned (fmt : Format.formatter) (s : string) : unit = - s - |> String.to_ascii - |> String.to_snake_case - |> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") - |> String.to_ascii - |> avoid_keywords - |> Format.fprintf fmt "%s" - -let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit = - let v_str = Mark.remove (FuncName.get_info v) in - Format.fprintf fmt "%a_func" format_name_cleaned v_str - -module StringMap = String.Map - -module IntMap = Map.Make (struct - include Int - - let format ppf i = Format.pp_print_int ppf i -end) - -(** For each `VarName.t` defined by its string and then by its hash, we keep - track of which local integer id we've given it. This is used to keep - variable naming with low indices rather than one global counter for all - variables. TODO: should be removed when - https://github.com/CatalaLang/catala/issues/240 is fixed. *) -let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty - -let format_var (fmt : Format.formatter) (v : VarName.t) : unit = - let v_str = Mark.remove (VarName.get_info v) in - let id = VarName.id v in - let local_id = - match StringMap.find_opt v_str !string_counter_map with - | Some ids -> ( - match IntMap.find_opt id ids with - | None -> - let max_id = - snd - (List.hd - (List.fast_sort - (fun (_, x) (_, y) -> Int.compare y x) - (IntMap.bindings ids))) - in - string_counter_map := - StringMap.add v_str - (IntMap.add id (max_id + 1) ids) - !string_counter_map; - max_id + 1 - | Some local_id -> local_id) - | None -> - string_counter_map := - StringMap.add v_str (IntMap.singleton id 0) !string_counter_map; - 0 - in - if v_str = "_" then Format.fprintf fmt "dummy_var" - (* special case for the unit pattern TODO escape dummy_var *) - else if local_id = 0 then format_name_cleaned fmt v_str - else Format.fprintf fmt "%a_%d" format_name_cleaned v_str local_id +let renaming = + Program.renaming () + ~reserved:c_keywords + (* TODO: add catala runtime built-ins as reserved as well ? *) + ~reset_context_for_closed_terms:true + ~skip_constant_binders:true + ~constant_binder_name:None + ~namespaced_fields_constrs:false module TypMap = Map.Make (struct type t = naked_typ @@ -156,12 +68,12 @@ let rec format_typ (format_typ decl_ctx (fun fmt -> Format.fprintf fmt "arg_%d" i)) t)) (List.mapi (fun x y -> y, x) ts) - | TStruct s -> Format.fprintf fmt "%a %t" format_struct_name s element_name + | TStruct s -> Format.fprintf fmt "%a %t" StructName.format s element_name | TOption _ -> Message.error ~internal:true "All option types should have been monomorphized before compilation to C." | TDefault t -> format_typ decl_ctx element_name fmt t - | TEnum e -> Format.fprintf fmt "%a %t" format_enum_name e element_name + | TEnum e -> Format.fprintf fmt "%a %t" EnumName.format e element_name | TArrow (t1, t2) -> Format.fprintf fmt "%a(%a)" (format_typ decl_ctx (fun fmt -> Format.fprintf fmt "(*%t)" element_name)) @@ -185,41 +97,41 @@ let format_ctx let format_struct_decl fmt (struct_name, struct_fields) = let fields = StructField.Map.bindings struct_fields in Format.fprintf fmt "@[typedef struct %a {@ %a@]@,} %a;" - format_struct_name struct_name + StructName.format struct_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") (fun fmt (struct_field, struct_field_type) -> Format.fprintf fmt "@[%a;@]" (format_typ ctx (fun fmt -> - format_struct_field_name fmt struct_field)) + StructField.format fmt struct_field)) struct_field_type)) - fields format_struct_name struct_name + fields StructName.format struct_name in let format_enum_decl fmt (enum_name, enum_cons) = if EnumConstructor.Map.is_empty enum_cons then failwith "no constructors in the enum" else Format.fprintf fmt "@[enum %a_code {@,%a@]@,} %a_code;@\n@\n" - format_enum_name enum_name + EnumName.format enum_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (enum_cons, _) -> - Format.fprintf fmt "%a_%a" format_enum_name enum_name - format_enum_cons_name enum_cons)) + Format.fprintf fmt "%a_%a" EnumName.format enum_name + EnumConstructor.format enum_cons)) (EnumConstructor.Map.bindings enum_cons) - format_enum_name enum_name; + EnumName.format enum_name; Format.fprintf fmt "@[typedef struct %a {@ enum %a_code code;@ @[union {@ %a@]@,\ } payload;@]@,\ - } %a;" format_enum_name enum_name format_enum_name enum_name + } %a;" EnumName.format enum_name EnumName.format enum_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") (fun fmt (enum_cons, typ) -> Format.fprintf fmt "%a;" - (format_typ ctx (fun fmt -> format_enum_cons_name fmt enum_cons)) + (format_typ ctx (fun fmt -> EnumConstructor.format fmt enum_cons)) typ)) (EnumConstructor.Map.bindings enum_cons) - format_enum_name enum_name + EnumName.format enum_name in let is_in_type_ordering s = @@ -329,8 +241,8 @@ let _format_string_list (fmt : Format.formatter) (uids : string list) : unit = let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : unit = match Mark.remove e with - | EVar v -> format_var fmt v - | EFunc f -> format_func_name fmt f + | EVar v -> VarName.format fmt v + | EFunc f -> FuncName.format fmt f | EStruct { fields = es; _ } -> (* These should only appear when initializing a variable definition *) Format.fprintf fmt "{ %a }" @@ -340,10 +252,10 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : (StructField.Map.bindings es) | EStructFieldAccess { e1; field; _ } -> Format.fprintf fmt "%a.%a" (format_expression ctx) e1 - format_struct_field_name field + StructField.format field | EInj { e1; cons; name = enum_name; _ } -> - Format.fprintf fmt "{%a_%a,@ {%a: %a}}" format_enum_name enum_name - format_enum_cons_name cons format_enum_cons_name cons + Format.fprintf fmt "{%a_%a,@ {%a: %a}}" EnumName.format enum_name + EnumConstructor.format cons EnumConstructor.format cons (format_expression ctx) e1 | EArray _ -> failwith @@ -402,7 +314,7 @@ let rec format_statement "This inner functions should have been hoisted in Scalc" | SLocalDecl { name = v; typ = ty } -> Format.fprintf fmt "@[%a@];" - (format_typ ctx (fun fmt -> format_var fmt (Mark.remove v))) + (format_typ ctx (fun fmt -> VarName.format fmt (Mark.remove v))) ty (* Below we detect array initializations which have special treatment. *) | SLocalInit { name = v; expr = EStruct { fields; name }, _; typ } @@ -421,20 +333,20 @@ let rec format_statement "@[%a;@]@\n\ @[%a.content_field = catala_malloc(sizeof(%a));@]@\n\ %a" - (format_typ ctx (fun fmt -> format_var fmt (Mark.remove v))) - typ format_var (Mark.remove v) format_struct_name name + (format_typ ctx (fun fmt -> VarName.format fmt (Mark.remove v))) + typ VarName.format (Mark.remove v) StructName.format name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (i, arg) -> Format.fprintf fmt "@[%a.content_field[%d] =@ %a;@]" - format_var (Mark.remove v) i (format_expression ctx) arg)) + VarName.format (Mark.remove v) i (format_expression ctx) arg)) (List.mapi (fun i a -> i, a) array_contents) | SLocalInit { name = v; expr = e; typ } -> Format.fprintf fmt "@[%a = %a;@]" - (format_typ ctx (fun fmt -> format_var fmt (Mark.remove v))) + (format_typ ctx (fun fmt -> VarName.format fmt (Mark.remove v))) typ (format_expression ctx) e | SLocalDef { name = v; expr = e; _ } -> - Format.fprintf fmt "@[%a = %a;@]" format_var (Mark.remove v) + Format.fprintf fmt "@[%a = %a;@]" VarName.format (Mark.remove v) (format_expression ctx) e | SRaiseEmpty | STryWEmpty _ -> assert false | SFatalError err -> @@ -457,18 +369,18 @@ let rec format_statement (EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums)) in let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in - Format.fprintf fmt "@[%a %a = %a;@]@," format_enum_name e_name - format_var tmp_var (format_expression ctx) e1; + Format.fprintf fmt "@[%a %a = %a;@]@," EnumName.format e_name + VarName.format tmp_var (format_expression ctx) e1; Format.pp_open_vbox fmt 2; - Format.fprintf fmt "@[switch (%a.code) {@]@," format_var tmp_var; + Format.fprintf fmt "@[switch (%a.code) {@]@," VarName.format tmp_var; Format.pp_print_list (fun fmt ({ case_block; payload_var_name; payload_var_typ }, cons_name) -> - Format.fprintf fmt "@[case %a_%a:@ " format_enum_name e_name - format_enum_cons_name cons_name; + Format.fprintf fmt "@[case %a_%a:@ " EnumName.format e_name + EnumConstructor.format cons_name; if not (Type.equal payload_var_typ (TLit TUnit, Pos.no_pos)) then Format.fprintf fmt "%a = %a.payload.%a;@ " - (format_typ ctx (fun fmt -> format_var fmt payload_var_name)) - payload_var_typ format_var tmp_var format_enum_cons_name cons_name; + (format_typ ctx (fun fmt -> VarName.format fmt payload_var_name)) + payload_var_typ VarName.format tmp_var EnumConstructor.format cons_name; Format.fprintf fmt "%a@ break;@]" (format_block ctx) case_block) fmt cases; (* Do we want to add 'default' case with a failure ? *) @@ -514,13 +426,13 @@ let rec format_statement in if exceptions <> [] then begin Format.fprintf fmt "@[%a = {%a_%a,@ {%a: NULL}};@]@," - (format_typ ctx (fun fmt -> format_var fmt exception_acc_var)) - return_typ format_enum_name e_name format_enum_cons_name none_cons - format_enum_cons_name none_cons; + (format_typ ctx (fun fmt -> VarName.format fmt exception_acc_var)) + return_typ EnumName.format e_name EnumConstructor.format none_cons + EnumConstructor.format none_cons; Format.fprintf fmt "%a;@," - (format_typ ctx (fun fmt -> format_var fmt exception_current)) + (format_typ ctx (fun fmt -> VarName.format fmt exception_current)) return_typ; - Format.fprintf fmt "char %a = 0;@," format_var exception_conflict; + Format.fprintf fmt "char %a = 0;@," VarName.format exception_conflict; List.iter (fun except -> Format.fprintf fmt @@ -532,11 +444,11 @@ let rec format_statement %a = %a;@]@,\ }@]@,\ }@," - format_var exception_current (format_expression ctx) except - format_var exception_current format_enum_name e_name - format_enum_cons_name some_cons format_var exception_acc_var - format_enum_name e_name format_enum_cons_name some_cons format_var - exception_conflict format_var exception_acc_var format_var + VarName.format exception_current (format_expression ctx) except + VarName.format exception_current EnumName.format e_name + EnumConstructor.format some_cons VarName.format exception_acc_var + EnumName.format e_name EnumConstructor.format some_cons VarName.format + exception_conflict VarName.format exception_acc_var VarName.format exception_current) exceptions; Format.fprintf fmt @@ -544,14 +456,14 @@ let rec format_statement @[catala_raise_fatal_error(catala_conflict,@ \"%s\",@ %d, %d, \ %d, %d);@]@;\ <1 -2>}@]@," - format_var exception_conflict (Pos.get_file pos) + VarName.format exception_conflict (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos); Format.fprintf fmt "@[if (%a.code == %a_%a) {@,%a = %a;@]@,@[} else {@," - format_var exception_acc_var format_enum_name e_name - format_enum_cons_name some_cons format_var variable_defined_in_cons - format_var exception_acc_var + VarName.format exception_acc_var EnumName.format e_name + EnumConstructor.format some_cons VarName.format variable_defined_in_cons + VarName.format exception_acc_var end; Format.fprintf fmt "@[if (%a) {@,\ @@ -560,9 +472,9 @@ let rec format_statement %a.code = %a_%a;@,\ %a.payload.%a = NULL;@]@,\ }" - (format_expression ctx) just (format_block ctx) cons format_var - variable_defined_in_cons format_enum_name e_name format_enum_cons_name - none_cons format_var variable_defined_in_cons format_enum_cons_name + (format_expression ctx) just (format_block ctx) cons VarName.format + variable_defined_in_cons EnumName.format e_name EnumConstructor.format + none_cons VarName.format variable_defined_in_cons EnumConstructor.format none_cons; if exceptions <> [] then Format.fprintf fmt "@]@,}" @@ -591,7 +503,7 @@ let format_program match code_item with | SVar { var; expr; typ } -> Format.fprintf fmt "@[%a = %a;@]" - (format_typ p.ctx.decl_ctx (fun fmt -> format_var fmt var)) + (format_typ p.ctx.decl_ctx (fun fmt -> VarName.format fmt var)) typ (format_expression p.ctx.decl_ctx) expr @@ -599,13 +511,13 @@ let format_program | SScope { scope_body_var = var; scope_body_func = func; _ } -> let { func_params; func_body; func_return_typ } = func in Format.fprintf fmt "@[%a(%a) {@,%a@]@,}" - (format_typ p.ctx.decl_ctx (fun fmt -> format_func_name fmt var)) + (format_typ p.ctx.decl_ctx (fun fmt -> FuncName.format fmt var)) func_return_typ (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (var, typ) -> (format_typ p.ctx.decl_ctx (fun fmt -> - format_var fmt (Mark.remove var))) + VarName.format fmt (Mark.remove var))) fmt typ)) func_params (format_block p.ctx.decl_ctx) diff --git a/compiler/scalc/to_c.mli b/compiler/scalc/to_c.mli index 1a9a15e5..2b7c6853 100644 --- a/compiler/scalc/to_c.mli +++ b/compiler/scalc/to_c.mli @@ -14,7 +14,11 @@ License for the specific language governing permissions and limitations under the License. *) -(** Formats a lambda calculus program into a valid C89 program *) +(** Formats a statement calculus program into a valid C89 program *) + +open Shared_ast + +val renaming : Program.renaming val format_program : Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 65be82b6..26f33765 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -110,98 +110,23 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit = (Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info))) uids -let avoid_keywords (s : string) : string = - if - match s with - (* list taken from - https://www.programiz.com/python-programming/keyword-list *) - | "False" | "None" | "True" | "and" | "as" | "assert" | "async" | "await" - | "break" | "class" | "continue" | "def" | "del" | "elif" | "else" - | "except" | "finally" | "for" | "from" | "global" | "if" | "import" | "in" - | "is" | "lambda" | "nonlocal" | "not" | "or" | "pass" | "raise" | "return" - | "try" | "while" | "with" | "yield" -> - true - | _ -> false - then s ^ "_" - else s +let python_keywords = + (* list taken from + https://www.programiz.com/python-programming/keyword-list *) + [ "False"; "None"; "True"; "and"; "as"; "assert"; "async"; "await"; + "break"; "class"; "continue"; "def"; "del"; "elif"; "else"; + "except"; "finally"; "for"; "from"; "global"; "if"; "import"; "in"; + "is"; "lambda"; "nonlocal"; "not"; "or"; "pass"; "raise"; "return"; + "try"; "while"; "with"; "yield" ] +(* todo: reserved names should also include built-in types and everything exposed by the runtime. *) -module StringMap = String.Map - -module IntMap = Map.Make (struct - include Int - - let format ppf i = Format.pp_print_int ppf i -end) - -let clean_name (s : string) : string = - s - |> String.to_snake_case - |> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") - |> avoid_keywords - -let format_name_cleaned (fmt : Format.formatter) (s : string) : unit = - Format.pp_print_string fmt (clean_name s) - -(** For each `VarName.t` defined by its string and then by its hash, we keep - track of which local integer id we've given it. This is used to keep - variable naming with low indices rather than one global counter for all - variables. TODO: should be removed when - https://github.com/CatalaLang/catala/issues/240 is fixed. *) -let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty - -let format_var (fmt : Format.formatter) (v : VarName.t) : unit = - let v_str = clean_name (Mark.remove (VarName.get_info v)) in - let id = VarName.id v in - let local_id = - match StringMap.find_opt v_str !string_counter_map with - | Some ids -> ( - match IntMap.find_opt id ids with - | None -> - let local_id = 1 + IntMap.fold (fun _ -> Int.max) ids 0 in - string_counter_map := - StringMap.add v_str (IntMap.add id local_id ids) !string_counter_map; - local_id - | Some local_id -> local_id) - | None -> - string_counter_map := - StringMap.add v_str (IntMap.singleton id 0) !string_counter_map; - 0 - in - if v_str = "_" then Format.fprintf fmt "_" - (* special case for the unit pattern *) - else if local_id = 0 then Format.pp_print_string fmt v_str - else Format.fprintf fmt "%s_%d" v_str local_id - -let format_path ctx fmt p = - match List.rev p with - | [] -> () - | m :: _ -> - format_var fmt (ModuleName.Map.find m ctx.modules); - Format.pp_print_char fmt '.' - -let format_struct_name ctx (fmt : Format.formatter) (v : StructName.t) : unit = - format_path ctx fmt (StructName.path v); - Format.pp_print_string fmt - (avoid_keywords - (String.to_camel_case - (String.to_ascii (Mark.remove (StructName.get_info v))))) - -let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit - = - Format.pp_print_string fmt - (avoid_keywords (String.to_ascii (StructField.to_string v))) - -let format_enum_name ctx (fmt : Format.formatter) (v : EnumName.t) : unit = - format_path ctx fmt (EnumName.path v); - Format.pp_print_string fmt - (avoid_keywords - (String.to_camel_case - (String.to_ascii (Mark.remove (EnumName.get_info v))))) - -let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : - unit = - Format.pp_print_string fmt - (avoid_keywords (String.to_ascii (EnumConstructor.to_string v))) +let renaming = + Program.renaming () + ~reserved:python_keywords + (* TODO: add catala runtime built-ins as reserved as well ? *) + ~reset_context_for_closed_terms:false ~skip_constant_binders:false + ~constant_binder_name:None ~namespaced_fields_constrs:true + ~f_struct:String.to_camel_case let typ_needs_parens (e : typ) : bool = match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false @@ -226,12 +151,12 @@ let rec format_typ ctx (fmt : Format.formatter) (typ : typ) : unit = ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") (fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t)) ts - | TStruct s -> Format.fprintf fmt "%a" (format_struct_name ctx) s + | TStruct s -> StructName.format fmt s | TOption some_typ -> (* We translate the option type with an overloading by Python's [None] *) Format.fprintf fmt "Optional[%a]" format_typ some_typ | TDefault t -> format_typ fmt t - | TEnum e -> Format.fprintf fmt "%a" (format_enum_name ctx) e + | TEnum e -> EnumName.format fmt e | TArrow (t1, t2) -> Format.fprintf fmt "Callable[[%a], %a]" (Format.pp_print_list @@ -243,8 +168,7 @@ let rec format_typ ctx (fmt : Format.formatter) (typ : typ) : unit = | TClosureEnv -> failwith "unimplemented!" let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit = - let v_str = Mark.remove (FuncName.get_info v) in - format_name_cleaned fmt v_str + FuncName.format fmt v let format_position ppf pos = Format.fprintf ppf @@ -263,19 +187,19 @@ let format_error (ppf : Format.formatter) (err : Runtime.error Mark.pos) : unit let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit = match Mark.remove e with - | EVar v -> format_var fmt v - | EFunc f -> format_func_name fmt f + | EVar v -> VarName.format fmt v + | EFunc f -> FuncName.format fmt f | EStruct { fields = es; name = s } -> - Format.fprintf fmt "%a(%a)" (format_struct_name ctx) s + Format.fprintf fmt "%a(%a)" StructName.format s (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (struct_field, e) -> - Format.fprintf fmt "%a = %a" format_struct_field_name struct_field + Format.fprintf fmt "%a = %a" StructField.format struct_field (format_expression ctx) e)) (StructField.Map.bindings es) | EStructFieldAccess { e1; field; _ } -> Format.fprintf fmt "%a.%a" (format_expression ctx) e1 - format_struct_field_name field + StructField.format field | EInj { cons; name = e_name; _ } when EnumName.equal e_name Expr.option_enum && EnumConstructor.equal cons Expr.none_constr -> @@ -287,8 +211,8 @@ let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit = (* We translate the option type with an overloading by Python's [None] *) format_expression ctx fmt e | EInj { e1 = e; cons; name = enum_name; _ } -> - Format.fprintf fmt "%a(%a_Code.%a,@ %a)" (format_enum_name ctx) enum_name - (format_enum_name ctx) enum_name format_enum_cons_name cons + Format.fprintf fmt "%a(%a_Code.%a,@ %a)" EnumName.format enum_name + EnumName.format enum_name EnumConstructor.format cons (format_expression ctx) e | EArray es -> Format.fprintf fmt "[%a]" @@ -380,26 +304,26 @@ let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit = | ETupleAccess { e1; index } -> Format.fprintf fmt "%a[%d]" (format_expression ctx) e1 index | EExternal { modname; name } -> - Format.fprintf fmt "%a.%a" format_var (Mark.remove modname) - format_name_cleaned (Mark.remove name) + Format.fprintf fmt "%a.%s" VarName.format (Mark.remove modname) + (Mark.remove name) let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit = match Mark.remove s with | SInnerFuncDef { name; func = { func_params; func_body; _ } } -> - Format.fprintf fmt "@[def %a(@[%a@]):@ %a@]" format_var + Format.fprintf fmt "@[def %a(@[%a@]):@ %a@]" VarName.format (Mark.remove name) (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (var, typ) -> - Format.fprintf fmt "%a:%a" format_var (Mark.remove var) + Format.fprintf fmt "%a:%a" VarName.format (Mark.remove var) (format_typ ctx) typ)) func_params (format_block ctx) func_body | SLocalDecl _ -> assert false (* We don't need to declare variables in Python *) | SLocalDef { name = v; expr = e; _ } | SLocalInit { name = v; expr = e; _ } -> - Format.fprintf fmt "@[%a = %a@]" format_var (Mark.remove v) + Format.fprintf fmt "@[%a = %a@]" VarName.format (Mark.remove v) (format_expression ctx) e | STryWEmpty { try_block = try_b; with_block = catch_b } -> Format.fprintf fmt "@[try:@ %a@]@,@[except Empty:@ %a@]" @@ -424,12 +348,12 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit when EnumName.equal e_name Expr.option_enum -> (* We translate the option type with an overloading by Python's [None] *) let tmp_var = VarName.fresh ("perhaps_none_arg", Pos.no_pos) in - Format.fprintf fmt "@[%a = %a@]@," format_var tmp_var + Format.fprintf fmt "@[%a = %a@]@," VarName.format tmp_var (format_expression ctx) e1; - Format.fprintf fmt "@[if %a is None:@ %a@]@," format_var tmp_var + Format.fprintf fmt "@[if %a is None:@ %a@]@," VarName.format tmp_var (format_block ctx) case_none; - Format.fprintf fmt "@[else:@ %a = %a@,%a@]" format_var case_some_var - format_var tmp_var (format_block ctx) case_some + Format.fprintf fmt "@[else:@ %a = %a@,%a@]" VarName.format case_some_var + VarName.format tmp_var (format_block ctx) case_some | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } -> let cons_map = EnumName.Map.find e_name ctx.decl_ctx.ctx_enums in let cases = @@ -439,15 +363,15 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit (EnumConstructor.Map.bindings cons_map) in let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in - Format.fprintf fmt "%a = %a@\n@[if %a@]" format_var tmp_var + Format.fprintf fmt "%a = %a@\n@[if %a@]" VarName.format tmp_var (format_expression ctx) e1 (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") (fun fmt ({ case_block; payload_var_name; _ }, cons_name) -> Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" - format_var tmp_var (format_enum_name ctx) e_name - format_enum_cons_name cons_name format_var payload_var_name - format_var tmp_var (format_block ctx) case_block)) + VarName.format tmp_var (EnumName.format) e_name + EnumConstructor.format cons_name VarName.format payload_var_name + VarName.format tmp_var (format_block ctx) case_block)) cases | SReturn e1 -> Format.fprintf fmt "@[return %a@]" (format_expression ctx) @@ -497,38 +421,38 @@ let format_ctx \ return not (self == other)@,\ @,\ \ def __str__(self) -> str:@,\ - \ @[return \"%a(%a)\".format(%a)@]" (format_struct_name ctx) + \ @[return \"%a(%a)\".format(%a)@]" StructName.format struct_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") (fun fmt (struct_field, struct_field_type) -> - Format.fprintf fmt "%a: %a" format_struct_field_name struct_field + Format.fprintf fmt "%a: %a" StructField.format struct_field (format_typ ctx) struct_field_type)) fields (if StructField.Map.is_empty struct_fields then fun fmt _ -> Format.fprintf fmt " pass" else Format.pp_print_list (fun fmt (struct_field, _) -> - Format.fprintf fmt " self.%a = %a" format_struct_field_name - struct_field format_struct_field_name struct_field)) - fields (format_struct_name ctx) struct_name + Format.fprintf fmt " self.%a = %a" StructField.format + struct_field StructField.format struct_field)) + fields StructName.format struct_name (if not (StructField.Map.is_empty struct_fields) then Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ") (fun fmt (struct_field, _) -> - Format.fprintf fmt "self.%a == other.%a" format_struct_field_name - struct_field format_struct_field_name struct_field) + Format.fprintf fmt "self.%a == other.%a" StructField.format + struct_field StructField.format struct_field) else fun fmt _ -> Format.fprintf fmt "True") - fields (format_struct_name ctx) struct_name + fields StructName.format struct_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",") (fun fmt (struct_field, _) -> - Format.fprintf fmt "%a={}" format_struct_field_name struct_field)) + Format.fprintf fmt "%a={}" StructField.format struct_field)) fields (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (struct_field, _) -> - Format.fprintf fmt "self.%a" format_struct_field_name struct_field)) + Format.fprintf fmt "self.%a" StructField.format struct_field)) fields in let format_enum_decl fmt (enum_name, enum_cons) = @@ -558,14 +482,14 @@ let format_ctx @,\ \ def __str__(self) -> str:@,\ \ @[return \"{}({})\".format(self.code, self.value)@]" - (format_enum_name ctx) enum_name + (EnumName.format) enum_name (Format.pp_print_list (fun fmt (i, enum_cons, _enum_cons_type) -> - Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i)) + Format.fprintf fmt "%a = %d" EnumConstructor.format enum_cons i)) (List.mapi (fun i (x, y) -> i, x, y) (EnumConstructor.Map.bindings enum_cons)) - (format_enum_name ctx) enum_name (format_enum_name ctx) enum_name - (format_enum_name ctx) enum_name + (EnumName.format) enum_name EnumName.format enum_name + EnumName.format enum_name in let is_in_type_ordering s = @@ -597,19 +521,9 @@ let format_ctx (e, EnumName.Map.find e ctx.decl_ctx.ctx_enums)) (type_ordering @ scope_structs) -(* FIXME: this is an ugly (and partial) workaround, Python basically has one - namespace and we reserve the name to avoid clashes between func ids and - variable ids. *) -let reserve_func_name = function - | SVar _ -> () - | SFunc { var = v; _ } | SScope { scope_body_var = v; _ } -> - let v_str = clean_name (Mark.remove (FuncName.get_info v)) in - string_counter_map := - StringMap.add v_str (IntMap.singleton (-1) 0) !string_counter_map - let format_code_item ctx fmt = function | SVar { var; expr; typ = _ } -> - Format.fprintf fmt "@[%a = (@,%a@;<0 -4>)@]@," format_var var + Format.fprintf fmt "@[%a = (@,%a@;<0 -4>)@]@," VarName.format var (format_expression ctx) expr | SFunc { var; func } | SScope { scope_body_var = var; scope_body_func = func; _ } -> @@ -619,7 +533,7 @@ let format_code_item ctx fmt = function (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (var, typ) -> - Format.fprintf fmt "%a:%a" format_var (Mark.remove var) + Format.fprintf fmt "%a:%a" VarName.format (Mark.remove var) (format_typ ctx) typ)) func_params (format_block ctx) func_body @@ -627,7 +541,6 @@ let format_program (fmt : Format.formatter) (p : Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = - List.iter reserve_func_name p.code_items; Format.pp_open_vbox fmt 0; let header = [ @@ -643,7 +556,7 @@ let format_program ModuleName.Map.iter (fun m v -> Format.fprintf fmt "from . import %a as %a@," ModuleName.format m - format_var v) + VarName.format v) p.ctx.modules; Format.pp_print_cut fmt (); format_ctx type_ordering fmt p.ctx; diff --git a/compiler/scalc/to_python.mli b/compiler/scalc/to_python.mli index 1cc4e3fa..d055d0ab 100644 --- a/compiler/scalc/to_python.mli +++ b/compiler/scalc/to_python.mli @@ -16,6 +16,10 @@ (** Formats a lambda calculus program into a valid Python program *) +open Shared_ast + +val renaming : Program.renaming + val format_program : Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit (** Usage [format_program fmt p type_dependencies_ordering] *) diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index 71c15325..28207ae6 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -417,6 +417,12 @@ module Renaming : sig like [String.to_snake_case]). The result is advisory and a numerical suffix may be appended or modified *) + val unmbind_in : + context -> + ?fname:(string -> string) -> + ('e, 'b) Bindlib.mbinder -> + ('e, _) Mark.ed Var.t Array.t * 'b * context + val new_id : context -> string -> string * context val set_rewriters : diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index c5ec07aa..33062c64 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -114,7 +114,8 @@ let rename_ids ?(f_field = uncap) ?(f_enum = cap) ?(f_constr = cap) - p = + p + = let cfg = { Expr.Renaming.reserved; @@ -283,3 +284,45 @@ let rename_ids let decl_ctx = map_decl_ctx ~f:(Expr.Renaming.typ ctx) decl_ctx in let code_items = Scope.rename_ids ctx p.code_items in { p with decl_ctx; code_items }, ctx + +(* This first-class module wrapping is here to allow a polymorphic renaming function to be passed around *) + +module type Renaming = sig + val apply: + 'e program -> + 'e program * Expr.Renaming.context +end + +type renaming = (module Renaming) + +let apply (module R: Renaming) = R.apply + +let renaming + ~reserved + ~reset_context_for_closed_terms + ~skip_constant_binders + ~constant_binder_name + ~namespaced_fields_constrs + ?f_var + ?f_struct + ?f_field + ?f_enum + ?f_constr + () + = + let module M = struct + let apply p = + rename_ids + ~reserved + ~reset_context_for_closed_terms + ~skip_constant_binders + ~constant_binder_name + ~namespaced_fields_constrs + ?f_var + ?f_struct + ?f_field + ?f_enum + ?f_constr + p + end in + (module M: Renaming) diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index 6d15cd46..7fd58a3b 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -57,7 +57,11 @@ val modules_to_list : module_tree -> (ModuleName.t * module_intf_id) list (** Returns a list of used modules, in topological order ; the boolean indicates if the module is external *) -val rename_ids : +type renaming + +val apply: renaming -> 'e program -> 'e program * Expr.Renaming.context + +val renaming : reserved:string list -> reset_context_for_closed_terms:bool -> skip_constant_binders:bool -> @@ -68,8 +72,8 @@ val rename_ids : ?f_field:(string -> string) -> ?f_enum:(string -> string) -> ?f_constr:(string -> string) -> - ('a, 't) gexpr program -> - ('a, 't) gexpr program * Expr.Renaming.context + unit -> + renaming (** Renames all idents (variables, types, struct and enum names, fields and constructors) to dispel ambiguities in the target language. Names in [reserved], typically keywords and built-ins, will be avoided ; the meaning diff --git a/tests/backends/output/simple.c b/tests/backends/output/simple.c index a90cc655..c037f057 100644 --- a/tests/backends/output/simple.c +++ b/tests/backends/output/simple.c @@ -4,413 +4,201 @@ #include #include -enum option_2_enum_code { - option_2_enum_none_2_cons, - option_2_enum_some_2_cons -} option_2_enum_code; +enum Option_2_code { + Option_2_None_2, + Option_2_Some_2 +} Option_2_code; -typedef struct option_2_enum { - enum option_2_enum_code code; +typedef struct Option_2 { + enum Option_2_code code; union { - void* /* unit */ none_2_cons; - double some_2_cons; + void* /* unit */ None_2; + double Some_2; } payload; -} option_2_enum; +} Option_2; -typedef struct foo_struct { - char /* bool */ x_field; - double y_field; -} foo_struct; +typedef struct Foo { + char /* bool */ x; + double y; +} Foo; -typedef struct array_3_struct { - double * content_field; - int length_field; -} array_3_struct; +typedef struct Array_3 { + double * content2; + int length2; +} Array_3; -typedef struct array_2_struct { - option_2_enum * content_field; - int length_field; -} array_2_struct; +typedef struct Array_2 { + Option_2 * content1; + int length1; +} Array_2; -enum bar_enum_code { - bar_enum_no_cons, - bar_enum_yes_cons -} bar_enum_code; +enum Bar_code { + Bar_No, + Bar_Yes +} Bar_code; -typedef struct bar_enum { - enum bar_enum_code code; +typedef struct Bar { + enum Bar_code code; union { - void* /* unit */ no_cons; - foo_struct yes_cons; + void* /* unit */ No; + Foo Yes; } payload; -} bar_enum; +} Bar; -typedef struct baz_struct { - double b_field; - array_3_struct c_field; -} baz_struct; +typedef struct Baz { + double b; + Array_3 c; +} Baz; -enum option_3_enum_code { - option_3_enum_none_3_cons, - option_3_enum_some_3_cons -} option_3_enum_code; +enum Option_3_code { + Option_3_None_3, + Option_3_Some_3 +} Option_3_code; -typedef struct option_3_enum { - enum option_3_enum_code code; +typedef struct Option_3 { + enum Option_3_code code; union { - void* /* unit */ none_3_cons; - array_3_struct some_3_cons; + void* /* unit */ None_3; + Array_3 Some_3; } payload; -} option_3_enum; +} Option_3; -enum option_1_enum_code { - option_1_enum_none_1_cons, - option_1_enum_some_1_cons -} option_1_enum_code; +enum Option_1_code { + Option_1_None_1, + Option_1_Some_1 +} Option_1_code; -typedef struct option_1_enum { - enum option_1_enum_code code; +typedef struct Option_1 { + enum Option_1_code code; union { - void* /* unit */ none_1_cons; - bar_enum some_1_cons; + void* /* unit */ None_1; + Bar Some_1; } payload; -} option_1_enum; +} Option_1; -typedef struct array_4_struct { - option_3_enum * content_field; - int length_field; -} array_4_struct; +typedef struct Array_4 { + Option_3 * content3; + int length3; +} Array_4; -typedef struct array_1_struct { - option_1_enum * content_field; - int length_field; -} array_1_struct; +typedef struct Array_1 { + Option_1 * content; + int length; +} Array_1; -typedef struct tuple_1_struct { - option_1_enum (*elt_0_field)(void * /* closure_env */ arg_0_typ, void* /* unit */ arg_1_typ); - void * /* closure_env */ elt_1_field; -} tuple_1_struct; +typedef struct Tuple_1 { + Option_1 (*elt_0)(void * /* closure_env */ arg_0_typ, void* /* unit */ arg_1_typ); + void * /* closure_env */ elt_1; +} Tuple_1; -typedef struct baz_in_struct { - tuple_1_struct a_in_field; -} baz_in_struct; +typedef struct Baz_in { + Tuple_1 a_in; +} Baz_in; -baz_struct baz_func(baz_in_struct baz_in) { - tuple_1_struct a; - a = baz_in.a_in_field; - bar_enum temp_a; - option_1_enum temp_a_1; - tuple_1_struct code_and_env; +Baz baz(Baz_in baz_in) { + Tuple_1 a; + a = baz_in.a_in; + Bar a2; + option_1 a3; + Tuple_1 code_and_env; code_and_env = a; - option_1_enum (*code)(void * /* closure_env */ arg_0_typ, void* /* unit */ arg_1_typ); + Option_1 (*code)(void * /* closure_env */ arg_0_typ, void* /* unit */ arg_1_typ); void * /* closure_env */ env; - code = code_and_env.elt_0_field; - env = code_and_env.elt_1_field; - array_1_struct temp_a_2; - temp_a_2.content_field = catala_malloc(sizeof(array_1_struct)); - temp_a_2.content_field[0] = code(env, NULL); - option_1_enum match_arg = catala_handle_exceptions(temp_a_2); + code = code_and_env.elt_0; + env = code_and_env.elt_1; + Array_1 a4; + a4.content_field = catala_malloc(sizeof(Array_1)); + a4.content_field[0] = code(env, NULL); + Option_1 match_arg = catala_handle_exceptions(a4); switch (match_arg.code) { - case option_1_enum_none_1_cons: + case Option_1_None_1: if (1 /* TRUE */) { - bar_enum temp_a_3; - option_1_enum temp_a_4; - option_1_enum temp_a_5; - array_1_struct temp_a_6; - temp_a_6.content_field = catala_malloc(sizeof(array_1_struct)); + Bar a3; + option_1 a4; + option_1 a6; + Array_1 a7; + a7.content_field = catala_malloc(sizeof(Array_1)); - option_1_enum match_arg_1 = catala_handle_exceptions(temp_a_6); - switch (match_arg_1.code) { - case option_1_enum_none_1_cons: + Option_1 match_arg = catala_handle_exceptions(a7); + switch (match_arg.code) { + case Option_1_None_1: if (1 /* TRUE */) { - bar_enum temp_a_7 = {bar_enum_no_cons, {no_cons: NULL}}; - option_1_enum temp_a_5 = {option_1_enum_some_1_cons, - {some_1_cons: temp_a_7}}; + Bar a6 = {Bar_No, {No: NULL}}; + option_1 a6 = {Option_1_Some_1, {Some_1: a6}}; } else { - option_1_enum temp_a_5 = {option_1_enum_none_1_cons, - {none_1_cons: NULL}}; + option_1 a6 = {Option_1_None_1, {None_1: NULL}}; } break; - case option_1_enum_some_1_cons: - bar_enum x = match_arg_1.payload.some_1_cons; - option_1_enum temp_a_5 = {option_1_enum_some_1_cons, - {some_1_cons: x}}; + case Option_1_Some_1: + Bar x1 = match_arg.payload.Some_1; + option_1 a6 = {Option_1_Some_1, {Some_1: x1}}; break; } - array_1_struct temp_a_8; - temp_a_8.content_field = catala_malloc(sizeof(array_1_struct)); - temp_a_8.content_field[0] = temp_a_5; - option_1_enum match_arg_2 = catala_handle_exceptions(temp_a_8); - switch (match_arg_2.code) { - case option_1_enum_none_1_cons: + Array_1 a5; + a5.content_field = catala_malloc(sizeof(Array_1)); + a5.content_field[0] = a6; + Option_1 match_arg = catala_handle_exceptions(a5); + switch (match_arg.code) { + case Option_1_None_1: if (0 /* FALSE */) { - option_1_enum temp_a_4 = {option_1_enum_none_1_cons, - {none_1_cons: NULL}}; + option_1 a4 = {Option_1_None_1, {None_1: NULL}}; } else { - option_1_enum temp_a_4 = {option_1_enum_none_1_cons, - {none_1_cons: NULL}}; + option_1 a4 = {Option_1_None_1, {None_1: NULL}}; } break; - case option_1_enum_some_1_cons: - bar_enum x_1 = match_arg_2.payload.some_1_cons; - option_1_enum temp_a_4 = {option_1_enum_some_1_cons, - {some_1_cons: x_1}}; + case Option_1_Some_1: + Bar x1 = match_arg.payload.Some_1; + option_1 a4 = {Option_1_Some_1, {Some_1: x1}}; break; } - option_1_enum match_arg_3 = temp_a_4; - switch (match_arg_3.code) { - case option_1_enum_none_1_cons: + Option_1 match_arg = a4; + switch (match_arg.code) { + case Option_1_None_1: catala_raise_fatal_error (catala_no_value, "tests/backends/simple.catala_en", 11, 11, 11, 12); break; - case option_1_enum_some_1_cons: - bar_enum arg = match_arg_3.payload.some_1_cons; - temp_a_3 = arg; + case Option_1_Some_1: + Bar arg = match_arg.payload.Some_1; + a3 = arg; break; } - option_1_enum temp_a_1 = {option_1_enum_some_1_cons, - {some_1_cons: temp_a_3}}; + option_1 a3 = {Option_1_Some_1, {Some_1: a3}}; } else { - option_1_enum temp_a_1 = {option_1_enum_none_1_cons, - {none_1_cons: NULL}}; + option_1 a3 = {Option_1_None_1, {None_1: NULL}}; } break; - case option_1_enum_some_1_cons: - bar_enum x_2 = match_arg.payload.some_1_cons; - option_1_enum temp_a_1 = {option_1_enum_some_1_cons, - {some_1_cons: x_2}}; + case Option_1_Some_1: + Bar x1 = match_arg.payload.Some_1; + option_1 a3 = {Option_1_Some_1, {Some_1: x1}}; break; } - option_1_enum match_arg_4 = temp_a_1; - switch (match_arg_4.code) { - case option_1_enum_none_1_cons: + Option_1 match_arg = a3; + switch (match_arg.code) { + case Option_1_None_1: catala_raise_fatal_error (catala_no_value, "tests/backends/simple.catala_en", 11, 11, 11, 12); break; - case option_1_enum_some_1_cons: - bar_enum arg_1 = match_arg_4.payload.some_1_cons; - temp_a = arg_1; + case Option_1_Some_1: + Bar arg = match_arg.payload.Some_1; + a2 = arg; break; } - bar_enum a_1; - a_1 = temp_a; - double temp_b; - option_2_enum temp_b_1; - option_2_enum temp_b_2; - option_2_enum temp_b_3; - array_2_struct temp_b_4; - temp_b_4.content_field = catala_malloc(sizeof(array_2_struct)); - - option_2_enum match_arg_5 = catala_handle_exceptions(temp_b_4); - switch (match_arg_5.code) { - case option_2_enum_none_2_cons: - char /* bool */ temp_b_5; - bar_enum match_arg_6 = a_1; - switch (match_arg_6.code) { - case bar_enum_no_cons: temp_b_5 = 1 /* TRUE */; break; - case bar_enum_yes_cons: - foo_struct dummy_var = match_arg_6.payload.yes_cons; - temp_b_5 = 0 /* FALSE */; - break; - } - if (temp_b_5) { - option_2_enum temp_b_3 = {option_2_enum_some_2_cons, - {some_2_cons: 42.}}; - - } else { - option_2_enum temp_b_3 = {option_2_enum_none_2_cons, - {none_2_cons: NULL}}; - - } - break; - case option_2_enum_some_2_cons: - double x_3 = match_arg_5.payload.some_2_cons; - option_2_enum temp_b_3 = {option_2_enum_some_2_cons, - {some_2_cons: x_3}}; - break; - } - array_2_struct temp_b_6; - temp_b_6.content_field = catala_malloc(sizeof(array_2_struct)); - temp_b_6.content_field[0] = temp_b_3; - option_2_enum match_arg_7 = catala_handle_exceptions(temp_b_6); - switch (match_arg_7.code) { - case option_2_enum_none_2_cons: - if (0 /* FALSE */) { - option_2_enum temp_b_2 = {option_2_enum_none_2_cons, - {none_2_cons: NULL}}; - - } else { - option_2_enum temp_b_2 = {option_2_enum_none_2_cons, - {none_2_cons: NULL}}; - - } - break; - case option_2_enum_some_2_cons: - double x_4 = match_arg_7.payload.some_2_cons; - option_2_enum temp_b_2 = {option_2_enum_some_2_cons, - {some_2_cons: x_4}}; - break; - } - array_2_struct temp_b_7; - temp_b_7.content_field = catala_malloc(sizeof(array_2_struct)); - temp_b_7.content_field[0] = temp_b_2; - option_2_enum match_arg_8 = catala_handle_exceptions(temp_b_7); - switch (match_arg_8.code) { - case option_2_enum_none_2_cons: - if (1 /* TRUE */) { - option_2_enum temp_b_8; - array_2_struct temp_b_9; - temp_b_9.content_field = catala_malloc(sizeof(array_2_struct)); - - option_2_enum match_arg_9 = catala_handle_exceptions(temp_b_9); - switch (match_arg_9.code) { - case option_2_enum_none_2_cons: - if (1 /* TRUE */) { - double temp_b_10; - bar_enum match_arg_10 = a_1; - switch (match_arg_10.code) { - case bar_enum_no_cons: temp_b_10 = 0.; break; - case bar_enum_yes_cons: - foo_struct foo = match_arg_10.payload.yes_cons; - double temp_b_11; - if (foo.x_field) {temp_b_11 = 1.; } else {temp_b_11 = 0.; } - temp_b_10 = (foo.y_field + temp_b_11); - break; - } - option_2_enum temp_b_8 = {option_2_enum_some_2_cons, - {some_2_cons: temp_b_10}}; - - } else { - option_2_enum temp_b_8 = {option_2_enum_none_2_cons, - {none_2_cons: NULL}}; - - } - break; - case option_2_enum_some_2_cons: - double x_5 = match_arg_9.payload.some_2_cons; - option_2_enum temp_b_8 = {option_2_enum_some_2_cons, - {some_2_cons: x_5}}; - break; - } - array_2_struct temp_b_12; - temp_b_12.content_field = catala_malloc(sizeof(array_2_struct)); - temp_b_12.content_field[0] = temp_b_8; - option_2_enum match_arg_11 = catala_handle_exceptions(temp_b_12); - switch (match_arg_11.code) { - case option_2_enum_none_2_cons: - if (0 /* FALSE */) { - option_2_enum temp_b_1 = {option_2_enum_none_2_cons, - {none_2_cons: NULL}}; - - } else { - option_2_enum temp_b_1 = {option_2_enum_none_2_cons, - {none_2_cons: NULL}}; - - } - break; - case option_2_enum_some_2_cons: - double x_6 = match_arg_11.payload.some_2_cons; - option_2_enum temp_b_1 = {option_2_enum_some_2_cons, - {some_2_cons: x_6}}; - break; - } - - } else { - option_2_enum temp_b_1 = {option_2_enum_none_2_cons, - {none_2_cons: NULL}}; - - } - break; - case option_2_enum_some_2_cons: - double x_7 = match_arg_8.payload.some_2_cons; - option_2_enum temp_b_1 = {option_2_enum_some_2_cons, - {some_2_cons: x_7}}; - break; - } - option_2_enum match_arg_12 = temp_b_1; - switch (match_arg_12.code) { - case option_2_enum_none_2_cons: - catala_raise_fatal_error (catala_no_value, - "tests/backends/simple.catala_en", 12, 10, 12, 11); - break; - case option_2_enum_some_2_cons: - double arg_2 = match_arg_12.payload.some_2_cons; - temp_b = arg_2; - break; - } - double b; - b = temp_b; - array_3_struct temp_c; - option_3_enum temp_c_1; - option_3_enum temp_c_2; - array_4_struct temp_c_3; - temp_c_3.content_field = catala_malloc(sizeof(array_4_struct)); - - option_3_enum match_arg_13 = catala_handle_exceptions(temp_c_3); - switch (match_arg_13.code) { - case option_3_enum_none_3_cons: - if (1 /* TRUE */) { - array_3_struct temp_c_4; - temp_c_4.content_field = catala_malloc(sizeof(array_3_struct)); - temp_c_4.content_field[0] = b; - temp_c_4.content_field[1] = b; - option_3_enum temp_c_2 = {option_3_enum_some_3_cons, - {some_3_cons: temp_c_4}}; - - } else { - option_3_enum temp_c_2 = {option_3_enum_none_3_cons, - {none_3_cons: NULL}}; - - } - break; - case option_3_enum_some_3_cons: - array_3_struct x_8 = match_arg_13.payload.some_3_cons; - option_3_enum temp_c_2 = {option_3_enum_some_3_cons, - {some_3_cons: x_8}}; - break; - } - array_4_struct temp_c_5; - temp_c_5.content_field = catala_malloc(sizeof(array_4_struct)); - temp_c_5.content_field[0] = temp_c_2; - option_3_enum match_arg_14 = catala_handle_exceptions(temp_c_5); - switch (match_arg_14.code) { - case option_3_enum_none_3_cons: - if (0 /* FALSE */) { - option_3_enum temp_c_1 = {option_3_enum_none_3_cons, - {none_3_cons: NULL}}; - - } else { - option_3_enum temp_c_1 = {option_3_enum_none_3_cons, - {none_3_cons: NULL}}; - - } - break; - case option_3_enum_some_3_cons: - array_3_struct x_9 = match_arg_14.payload.some_3_cons; - option_3_enum temp_c_1 = {option_3_enum_some_3_cons, - {some_3_cons: x_9}}; - break; - } - option_3_enum match_arg_15 = temp_c_1; - switch (match_arg_15.code) { - case option_3_enum_none_3_cons: - catala_raise_fatal_error (catala_no_value, - "tests/backends/simple.catala_en", 13, 10, 13, 11); - break; - case option_3_enum_some_3_cons: - array_3_struct arg_3 = match_arg_15.payload.some_3_cons; - temp_c = arg_3; - break; - } - array_3_struct c; - c = temp_c; - baz_struct baz = { b, c }; - return baz; -} + Bar a1; + a1 = a2; + double b2; + option_2 b3; + option_2 b5; + option_2 b7; + ┌─[ERROR]─ +│ +│ Unexpected error: Not_found +│ +└─ +#return code 125# diff --git a/tests/backends/python_name_clash.catala_en b/tests/backends/python_name_clash.catala_en index 16311a4c..62dab911 100644 --- a/tests/backends/python_name_clash.catala_en +++ b/tests/backends/python_name_clash.catala_en @@ -91,85 +91,89 @@ class BIn: def some_name(some_name_in:SomeNameIn): i = some_name_in.i_in - perhaps_none_arg = handle_exceptions([], []) - if perhaps_none_arg is None: + match_arg = handle_exceptions([], []) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if True: - temp_o = (i + integer_of_string("1")) + o3 = Eoption(Eoption_Code.ESome, (i + integer_of_string("1"))) else: - temp_o = None - else: - x = perhaps_none_arg - temp_o = x - perhaps_none_arg_1 = handle_exceptions( - [SourcePosition( - filename="tests/backends/python_name_clash.catala_en", - start_line=10, start_column=23, - end_line=10, end_column=28, law_headings=[] - )], - [temp_o] - ) - if perhaps_none_arg_1 is None: + o3 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + o3 = Eoption(Eoption_Code.ESome, x) + match_arg = handle_exceptions( + [SourcePosition( + filename="tests/backends/python_name_clash.catala_en", + start_line=10, start_column=23, + end_line=10, end_column=28, law_headings=[])], + [o3] + ) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if False: - temp_o_1 = None + o2 = Eoption(Eoption_Code.ENone, Unit()) else: - temp_o_1 = None - else: - x_1 = perhaps_none_arg_1 - temp_o_1 = x_1 - perhaps_none_arg_2 = temp_o_1 - if perhaps_none_arg_2 is None: + o2 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + o2 = Eoption(Eoption_Code.ESome, x) + match_arg = o2 + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value raise NoValue(SourcePosition( filename="tests/backends/python_name_clash.catala_en", start_line=7, start_column=10, end_line=7, end_column=11, law_headings=[])) - else: - arg = perhaps_none_arg_2 - temp_o_2 = arg - o = temp_o_2 + elif match_arg.code == Eoption_Code.ESome: + arg = match_arg.value + o1 = arg + o = o1 return SomeName(o = o) def b(b_in:BIn): - perhaps_none_arg_3 = handle_exceptions([], []) - if perhaps_none_arg_3 is None: + match_arg = handle_exceptions([], []) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if True: - temp_result = integer_of_string("1") + result3 = Eoption(Eoption_Code.ESome, integer_of_string("1")) else: - temp_result = None - else: - x_2 = perhaps_none_arg_3 - temp_result = x_2 - perhaps_none_arg_4 = handle_exceptions( - [SourcePosition( - filename="tests/backends/python_name_clash.catala_en", - start_line=16, start_column=33, - end_line=16, end_column=34, law_headings=[] - )], - [temp_result] - ) - if perhaps_none_arg_4 is None: + result3 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + result3 = Eoption(Eoption_Code.ESome, x) + match_arg = handle_exceptions( + [SourcePosition( + filename="tests/backends/python_name_clash.catala_en", + start_line=16, start_column=33, + end_line=16, end_column=34, law_headings=[])], + [result3] + ) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if False: - temp_result_1 = None + result2 = Eoption(Eoption_Code.ENone, Unit()) else: - temp_result_1 = None - else: - x_3 = perhaps_none_arg_4 - temp_result_1 = x_3 - perhaps_none_arg_5 = temp_result_1 - if perhaps_none_arg_5 is None: + result2 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + result2 = Eoption(Eoption_Code.ESome, x) + match_arg = result2 + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value raise NoValue(SourcePosition( filename="tests/backends/python_name_clash.catala_en", start_line=16, start_column=14, end_line=16, end_column=25, law_headings=[])) - else: - arg_1 = perhaps_none_arg_5 - temp_result_2 = arg_1 - result = some_name(SomeNameIn(i_in = temp_result_2)) - result_1 = SomeName(o = result.o) + elif match_arg.code == Eoption_Code.ESome: + arg = match_arg.value + result1 = arg + result = some_name(SomeNameIn(i_in = result1)) + result1 = SomeName(o = result.o) if True: - temp_some_name = result_1 + some_name2 = result1 else: - temp_some_name = result_1 - some_name_1 = temp_some_name - return B(some_name = some_name_1) + some_name2 = result1 + some_name1 = some_name2 + return B(some_name = some_name1) ``` The above should *not* show `some_name = temp_some_name`, but instead `some_name_1 = ...` diff --git a/tests/name_resolution/good/toplevel_defs.catala_en b/tests/name_resolution/good/toplevel_defs.catala_en index c32d6301..d0ad73e2 100644 --- a/tests/name_resolution/good/toplevel_defs.catala_en +++ b/tests/name_resolution/good/toplevel_defs.catala_en @@ -107,164 +107,164 @@ $ catala test-scope S4 ```catala-test-inline $ catala scalc -let glob1_1 = 44.12 +let glob1 = 44.12 -let glob3_1 (x_2: money) = return to_rat x_2 + 10. +let glob3 (x: money) = return to_rat x + 10. -let glob4_2 (x_3: money) (y_4: decimal) = return to_rat x_3 * y_4 + 10. +let glob4 (x: money) (y: decimal) = return to_rat x * y + 10. -let glob5_aux_3 = - decl x_6 : decimal; - x_6 = to_rat 2 * 3.; - decl y_7 : decimal; - y_7 = 1000.; - return x_6 * y_7 +let glob5_init = + decl x : decimal; + x = to_rat 2 * 3.; + decl y : decimal; + y = 1000.; + return x * y -let glob5_5 = glob5_aux_3 () +let glob5 = glob5_init () -let glob2_8 = A {"y": glob1_1 >= 30., "z": 123. * 17.} +let glob2 = A {"y": glob1 >= 30., "z": 123. * 17.} -let S2_4 (S2_in_9: S2_in) = - decl temp_a_11 : decimal; - decl temp_a_12 : option decimal; - decl temp_a_13 : option decimal; +let S2 (S2_in: S2_in) = + decl a1 : decimal; + decl a2 : option decimal; + decl a3 : option decimal; switch handle_exceptions []: - | ENone __14 → + | ENone _ → if true: - temp_a_13 = ESome glob3_1 ¤44.00 + 100. + a3 = ESome glob3 ¤44.00 + 100. else: - temp_a_13 = ENone () - | ESome x_15 → - temp_a_13 = ESome x_15; - switch handle_exceptions [temp_a_13]: - | ENone __16 → + a3 = ENone () + | ESome x → + a3 = ESome x; + switch handle_exceptions [a3]: + | ENone _ → if false: - temp_a_12 = ENone () + a2 = ENone () else: - temp_a_12 = ENone () - | ESome x_17 → - temp_a_12 = ESome x_17; - switch temp_a_12: - | ENone __18 → + a2 = ENone () + | ESome x → + a2 = ESome x; + switch a2: + | ENone _ → fatal NoValue - | ESome arg_19 → - temp_a_11 = arg_19; - decl a_10 : decimal; - a_10 = temp_a_11; - return S2 {"a": a_10} + | ESome arg → + a1 = arg; + decl a : decimal; + a = a1; + return S2 {"a": a} -let S3_5 (S3_in_20: S3_in) = - decl temp_a_22 : decimal; - decl temp_a_23 : option decimal; - decl temp_a_24 : option decimal; +let S3 (S3_in: S3_in) = + decl a1 : decimal; + decl a2 : option decimal; + decl a3 : option decimal; switch handle_exceptions []: - | ENone __25 → + | ENone _ → if true: - temp_a_24 = ESome 50. + glob4_2 ¤44.00 55. + a3 = ESome 50. + glob4 ¤44.00 55. else: - temp_a_24 = ENone () - | ESome x_26 → - temp_a_24 = ESome x_26; - switch handle_exceptions [temp_a_24]: - | ENone __27 → + a3 = ENone () + | ESome x → + a3 = ESome x; + switch handle_exceptions [a3]: + | ENone _ → if false: - temp_a_23 = ENone () + a2 = ENone () else: - temp_a_23 = ENone () - | ESome x_28 → - temp_a_23 = ESome x_28; - switch temp_a_23: - | ENone __29 → + a2 = ENone () + | ESome x → + a2 = ESome x; + switch a2: + | ENone _ → fatal NoValue - | ESome arg_30 → - temp_a_22 = arg_30; - decl a_21 : decimal; - a_21 = temp_a_22; - return S3 {"a": a_21} + | ESome arg → + a1 = arg; + decl a : decimal; + a = a1; + return S3 {"a": a} -let S4_6 (S4_in_31: S4_in) = - decl temp_a_33 : decimal; - decl temp_a_34 : option decimal; - decl temp_a_35 : option decimal; +let S4 (S4_in: S4_in) = + decl a1 : decimal; + decl a2 : option decimal; + decl a3 : option decimal; switch handle_exceptions []: - | ENone __36 → + | ENone _ → if true: - temp_a_35 = ESome glob5_5 + 1. + a3 = ESome glob5 + 1. else: - temp_a_35 = ENone () - | ESome x_37 → - temp_a_35 = ESome x_37; - switch handle_exceptions [temp_a_35]: - | ENone __38 → + a3 = ENone () + | ESome x → + a3 = ESome x; + switch handle_exceptions [a3]: + | ENone _ → if false: - temp_a_34 = ENone () + a2 = ENone () else: - temp_a_34 = ENone () - | ESome x_39 → - temp_a_34 = ESome x_39; - switch temp_a_34: - | ENone __40 → + a2 = ENone () + | ESome x → + a2 = ESome x; + switch a2: + | ENone _ → fatal NoValue - | ESome arg_41 → - temp_a_33 = arg_41; - decl a_32 : decimal; - a_32 = temp_a_33; - return S4 {"a": a_32} + | ESome arg → + a1 = arg; + decl a : decimal; + a = a1; + return S4 {"a": a} -let S_7 (S_in_42: S_in) = - decl temp_a_54 : decimal; - decl temp_a_55 : option decimal; - decl temp_a_56 : option decimal; +let S (S_in: S_in) = + decl a1 : decimal; + decl a2 : option decimal; + decl a3 : option decimal; switch handle_exceptions []: - | ENone __57 → + | ENone _ → if true: - temp_a_56 = ESome glob1_1 * glob1_1 + a3 = ESome glob1 * glob1 else: - temp_a_56 = ENone () - | ESome x_58 → - temp_a_56 = ESome x_58; - switch handle_exceptions [temp_a_56]: - | ENone __59 → + a3 = ENone () + | ESome x → + a3 = ESome x; + switch handle_exceptions [a3]: + | ENone _ → if false: - temp_a_55 = ENone () + a2 = ENone () else: - temp_a_55 = ENone () - | ESome x_60 → - temp_a_55 = ESome x_60; - switch temp_a_55: - | ENone __61 → + a2 = ENone () + | ESome x → + a2 = ESome x; + switch a2: + | ENone _ → fatal NoValue - | ESome arg_62 → - temp_a_54 = arg_62; - decl a_43 : decimal; - a_43 = temp_a_54; - decl temp_b_45 : A {y: bool; z: decimal}; - decl temp_b_46 : option A {y: bool; z: decimal}; - decl temp_b_47 : option A {y: bool; z: decimal}; + | ESome arg → + a1 = arg; + decl a : decimal; + a = a1; + decl b1 : A {y: bool; z: decimal}; + decl b2 : option A {y: bool; z: decimal}; + decl b3 : option A {y: bool; z: decimal}; switch handle_exceptions []: - | ENone __48 → + | ENone _ → if true: - temp_b_47 = ESome glob2_8 + b3 = ESome glob2 else: - temp_b_47 = ENone () - | ESome x_49 → - temp_b_47 = ESome x_49; - switch handle_exceptions [temp_b_47]: - | ENone __50 → + b3 = ENone () + | ESome x → + b3 = ESome x; + switch handle_exceptions [b3]: + | ENone _ → if false: - temp_b_46 = ENone () + b2 = ENone () else: - temp_b_46 = ENone () - | ESome x_51 → - temp_b_46 = ESome x_51; - switch temp_b_46: - | ENone __52 → + b2 = ENone () + | ESome x → + b2 = ESome x; + switch b2: + | ENone _ → fatal NoValue - | ESome arg_53 → - temp_b_45 = arg_53; - decl b_44 : A {y: bool; z: decimal}; - b_44 = temp_b_45; - return S {"a": a_43, "b": b_44} + | ESome arg → + b1 = arg; + decl b : A {y: bool; z: decimal}; + b = b1; + return S {"a": a, "b": b} ``` ```catala-test-inline @@ -427,18 +427,18 @@ glob1 = (decimal_of_string("44.12")) def glob3(x:Money): return (decimal_of_money(x) + decimal_of_string("10.")) -def glob4(x_1:Money, y:Decimal): - return ((decimal_of_money(x_1) * y) + decimal_of_string("10.")) +def glob4(x:Money, y:Decimal): + return ((decimal_of_money(x) * y) + decimal_of_string("10.")) -def glob5_aux(): - x_2 = (decimal_of_integer(integer_of_string("2")) * +def glob5_init(): + x = (decimal_of_integer(integer_of_string("2")) * decimal_of_string("3.")) - y_1 = decimal_of_string("1000.") - return (x_2 * y_1) + y = decimal_of_string("1000.") + return (x * y) -glob5 = (glob5_aux()) +glob5 = (glob5_init()) -glob2 = ( +glob6 = ( A(y = (glob1 >= decimal_of_string("30.")), z = (decimal_of_string("123.") * @@ -446,202 +446,216 @@ glob2 = ( ) def s2(s2_in:S2In): - perhaps_none_arg = handle_exceptions([], []) - if perhaps_none_arg is None: + match_arg = handle_exceptions([], []) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if True: - temp_a = (glob3(money_of_cents_string("4400")) + - decimal_of_string("100.")) + a3 = Eoption(Eoption_Code.ESome, + (glob3(money_of_cents_string("4400")) + + decimal_of_string("100."))) else: - temp_a = None - else: - x_3 = perhaps_none_arg - temp_a = x_3 - perhaps_none_arg_1 = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=53, start_column=24, - end_line=53, end_column=43, - law_headings=["Test toplevel function defs"] - )], - [temp_a] - ) - if perhaps_none_arg_1 is None: + a3 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + a3 = Eoption(Eoption_Code.ESome, x) + match_arg = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=53, start_column=24, + end_line=53, end_column=43, + law_headings=["Test toplevel function defs"])], + [a3] + ) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if False: - temp_a_1 = None + a2 = Eoption(Eoption_Code.ENone, Unit()) else: - temp_a_1 = None - else: - x_4 = perhaps_none_arg_1 - temp_a_1 = x_4 - perhaps_none_arg_2 = temp_a_1 - if perhaps_none_arg_2 is None: + a2 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + a2 = Eoption(Eoption_Code.ESome, x) + match_arg = a2 + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=50, start_column=10, end_line=50, end_column=11, law_headings=["Test toplevel function defs"])) - else: - arg = perhaps_none_arg_2 - temp_a_2 = arg - a = temp_a_2 + elif match_arg.code == Eoption_Code.ESome: + arg = match_arg.value + a1 = arg + a = a1 return S2(a = a) def s3(s3_in:S3In): - perhaps_none_arg_3 = handle_exceptions([], []) - if perhaps_none_arg_3 is None: + match_arg = handle_exceptions([], []) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if True: - temp_a_3 = (decimal_of_string("50.") + + a3 = Eoption(Eoption_Code.ESome, + (decimal_of_string("50.") + glob4(money_of_cents_string("4400"), - decimal_of_string("55."))) + decimal_of_string("55.")))) else: - temp_a_3 = None - else: - x_5 = perhaps_none_arg_3 - temp_a_3 = x_5 - perhaps_none_arg_4 = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=74, start_column=24, - end_line=74, end_column=47, - law_headings=["Test function def with two args"] - )], - [temp_a_3] - ) - if perhaps_none_arg_4 is None: + a3 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + a3 = Eoption(Eoption_Code.ESome, x) + match_arg = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=74, start_column=24, + end_line=74, end_column=47, + law_headings=["Test function def with two args"])], + [a3] + ) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if False: - temp_a_4 = None + a2 = Eoption(Eoption_Code.ENone, Unit()) else: - temp_a_4 = None - else: - x_6 = perhaps_none_arg_4 - temp_a_4 = x_6 - perhaps_none_arg_5 = temp_a_4 - if perhaps_none_arg_5 is None: + a2 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + a2 = Eoption(Eoption_Code.ESome, x) + match_arg = a2 + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=71, start_column=10, end_line=71, end_column=11, law_headings=["Test function def with two args"])) - else: - arg_1 = perhaps_none_arg_5 - temp_a_5 = arg_1 - a_1 = temp_a_5 - return S3(a = a_1) + elif match_arg.code == Eoption_Code.ESome: + arg = match_arg.value + a1 = arg + a = a1 + return S3(a = a) def s4(s4_in:S4In): - perhaps_none_arg_6 = handle_exceptions([], []) - if perhaps_none_arg_6 is None: + match_arg = handle_exceptions([], []) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if True: - temp_a_6 = (glob5 + decimal_of_string("1.")) + a3 = Eoption(Eoption_Code.ESome, + (glob5 + + decimal_of_string("1."))) else: - temp_a_6 = None - else: - x_7 = perhaps_none_arg_6 - temp_a_6 = x_7 - perhaps_none_arg_7 = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=98, start_column=24, - end_line=98, end_column=34, - law_headings=["Test inline defs in toplevel defs"] - )], - [temp_a_6] - ) - if perhaps_none_arg_7 is None: + a3 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + a3 = Eoption(Eoption_Code.ESome, x) + match_arg = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=98, start_column=24, + end_line=98, end_column=34, + law_headings=["Test inline defs in toplevel defs"])], + [a3] + ) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if False: - temp_a_7 = None + a2 = Eoption(Eoption_Code.ENone, Unit()) else: - temp_a_7 = None - else: - x_8 = perhaps_none_arg_7 - temp_a_7 = x_8 - perhaps_none_arg_8 = temp_a_7 - if perhaps_none_arg_8 is None: + a2 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + a2 = Eoption(Eoption_Code.ESome, x) + match_arg = a2 + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=95, start_column=10, end_line=95, end_column=11, law_headings=["Test inline defs in toplevel defs"])) - else: - arg_2 = perhaps_none_arg_8 - temp_a_8 = arg_2 - a_2 = temp_a_8 - return S4(a = a_2) + elif match_arg.code == Eoption_Code.ESome: + arg = match_arg.value + a1 = arg + a = a1 + return S4(a = a) -def s(s_in:SIn): - perhaps_none_arg_9 = handle_exceptions([], []) - if perhaps_none_arg_9 is None: +def s5(s_in:SIn): + match_arg = handle_exceptions([], []) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if True: - temp_a_9 = (glob1 * glob1) + a3 = Eoption(Eoption_Code.ESome, (glob1 * glob1)) else: - temp_a_9 = None - else: - x_9 = perhaps_none_arg_9 - temp_a_9 = x_9 - perhaps_none_arg_10 = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=18, start_column=24, - end_line=18, end_column=37, - law_headings=["Test basic toplevel values defs"] - )], - [temp_a_9] - ) - if perhaps_none_arg_10 is None: + a3 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + a3 = Eoption(Eoption_Code.ESome, x) + match_arg = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=18, start_column=24, + end_line=18, end_column=37, + law_headings=["Test basic toplevel values defs"])], + [a3] + ) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if False: - temp_a_10 = None + a2 = Eoption(Eoption_Code.ENone, Unit()) else: - temp_a_10 = None - else: - x_10 = perhaps_none_arg_10 - temp_a_10 = x_10 - perhaps_none_arg_11 = temp_a_10 - if perhaps_none_arg_11 is None: + a2 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + a2 = Eoption(Eoption_Code.ESome, x) + match_arg = a2 + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=7, start_column=10, end_line=7, end_column=11, law_headings=["Test basic toplevel values defs"])) - else: - arg_3 = perhaps_none_arg_11 - temp_a_11 = arg_3 - a_3 = temp_a_11 - perhaps_none_arg_12 = handle_exceptions([], []) - if perhaps_none_arg_12 is None: + elif match_arg.code == Eoption_Code.ESome: + arg = match_arg.value + a1 = arg + a = a1 + match_arg = handle_exceptions([], []) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if True: - temp_b = glob2 + b3 = Eoption(Eoption_Code.ESome, glob6) else: - temp_b = None - else: - x_11 = perhaps_none_arg_12 - temp_b = x_11 - perhaps_none_arg_13 = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=19, start_column=24, - end_line=19, end_column=29, - law_headings=["Test basic toplevel values defs"] - )], - [temp_b] - ) - if perhaps_none_arg_13 is None: + b3 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + b3 = Eoption(Eoption_Code.ESome, x) + match_arg = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=19, start_column=24, + end_line=19, end_column=29, + law_headings=["Test basic toplevel values defs"])], + [b3] + ) + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value if False: - temp_b_1 = None + b2 = Eoption(Eoption_Code.ENone, Unit()) else: - temp_b_1 = None - else: - x_12 = perhaps_none_arg_13 - temp_b_1 = x_12 - perhaps_none_arg_14 = temp_b_1 - if perhaps_none_arg_14 is None: + b2 = Eoption(Eoption_Code.ENone, Unit()) + elif match_arg.code == Eoption_Code.ESome: + x = match_arg.value + b2 = Eoption(Eoption_Code.ESome, x) + match_arg = b2 + if match_arg.code == Eoption_Code.ENone: + _ = match_arg.value raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=8, start_column=10, end_line=8, end_column=11, law_headings=["Test basic toplevel values defs"])) - else: - arg_4 = perhaps_none_arg_14 - temp_b_2 = arg_4 - b = temp_b_2 - return S(a = a_3, b = b) + elif match_arg.code == Eoption_Code.ESome: + arg = match_arg.value + b1 = arg + b = b1 + return S(a = a, b = b) ``` diff --git a/tests/scope/good/nothing.catala_en b/tests/scope/good/nothing.catala_en index 88548eaa..a8463b5d 100644 --- a/tests/scope/good/nothing.catala_en +++ b/tests/scope/good/nothing.catala_en @@ -39,11 +39,11 @@ $ catala Scalc -s Foo2 -O -t │ 5 │ output bar content integer │ │ ‾‾‾ └─ Test -let Foo2_1 (Foo2_in_1: Foo2_in) = - decl temp_bar_3 : integer; +let Foo2 (Foo2_in: Foo2_in) = + decl bar1 : integer; fatal NoValue; - decl bar_2 : integer; - bar_2 = temp_bar_3; - return Foo2 {"bar": bar_2} + decl bar : integer; + bar = bar1; + return Foo2 {"bar": bar} ``` From 1b6da0b5720e9168aca1452e9ed40ad7e0bd737a Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Wed, 7 Aug 2024 17:44:39 +0200 Subject: [PATCH 2/9] reformat (renaming in scalc) --- compiler/driver.ml | 41 ++++++++------ compiler/driver.mli | 15 ++--- compiler/lcalc/to_ocaml.ml | 2 +- compiler/scalc/from_lcalc.ml | 98 ++++++++++++++++----------------- compiler/scalc/from_lcalc.mli | 3 +- compiler/scalc/print.ml | 4 +- compiler/scalc/to_c.ml | 65 ++++++++++++++++------ compiler/scalc/to_python.ml | 63 ++++++++++++++++----- compiler/shared_ast/program.ml | 32 ++++------- compiler/shared_ast/program.mli | 2 +- 10 files changed, 191 insertions(+), 134 deletions(-) diff --git a/compiler/driver.ml b/compiler/driver.ml index 862858ca..852707b1 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -230,9 +230,9 @@ module Passes = struct ~closure_conversion ~monomorphize_types ~renaming : - typed Lcalc.Ast.program * - Scopelang.Dependency.TVertex.t list * - Expr.Renaming.context option = + typed Lcalc.Ast.program + * Scopelang.Dependency.TVertex.t list + * Expr.Renaming.context option = let prg, type_ordering = dcalc options ~includes ~optimize ~check_invariants ~typed in @@ -303,25 +303,35 @@ module Passes = struct ~no_struct_literals ~monomorphize_types ~renaming : - Scalc.Ast.program * Scopelang.Dependency.TVertex.t list * Expr.Renaming.context = + Scalc.Ast.program + * Scopelang.Dependency.TVertex.t list + * Expr.Renaming.context = let prg, type_ordering, renaming_context = lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed ~closure_conversion ~monomorphize_types ~renaming in - let renaming_context = match renaming_context with - | None -> Expr.Renaming.get_ctx { - reserved = []; - sanitize_varname = Fun.id; - reset_context_for_closed_terms = true; - skip_constant_binders = true; - constant_binder_name = None; - } + let renaming_context = + match renaming_context with + | None -> + Expr.Renaming.get_ctx + { + reserved = []; + sanitize_varname = Fun.id; + reset_context_for_closed_terms = true; + skip_constant_binders = true; + constant_binder_name = None; + } | Some r -> r in debug_pass_name "scalc"; ( Scalc.From_lcalc.translate_program - ~config:{ keep_special_ops; dead_value_assignment; no_struct_literals; - renaming_context } + ~config: + { + keep_special_ops; + dead_value_assignment; + no_struct_literals; + renaming_context; + } prg, type_ordering, renaming_context ) @@ -963,8 +973,7 @@ module Commands = struct Passes.scalc options ~includes ~optimize ~check_invariants ~closure_conversion:true ~keep_special_ops:true ~dead_value_assignment:false ~no_struct_literals:true - ~monomorphize_types:true - ~renaming:(Some Scalc.To_c.renaming) + ~monomorphize_types:true ~renaming:(Some Scalc.To_c.renaming) in let output_file, with_output = get_output_format options ~ext:".c" output in Message.debug "Compiling program into C..."; diff --git a/compiler/driver.mli b/compiler/driver.mli index 47f2bfce..372b19e1 100644 --- a/compiler/driver.mli +++ b/compiler/driver.mli @@ -53,9 +53,10 @@ module Passes : sig typed:'m Shared_ast.mark -> closure_conversion:bool -> monomorphize_types:bool -> - renaming : Shared_ast.Program.renaming option -> - Shared_ast.typed Lcalc.Ast.program * Scopelang.Dependency.TVertex.t list * - Shared_ast.Expr.Renaming.context option + renaming:Shared_ast.Program.renaming option -> + Shared_ast.typed Lcalc.Ast.program + * Scopelang.Dependency.TVertex.t list + * Shared_ast.Expr.Renaming.context option val scalc : Global.options -> @@ -67,10 +68,10 @@ module Passes : sig dead_value_assignment:bool -> no_struct_literals:bool -> monomorphize_types:bool -> - renaming: Shared_ast.Program.renaming option -> - Scalc.Ast.program * Scopelang.Dependency.TVertex.t list * - Shared_ast.Expr.Renaming.context - + renaming:Shared_ast.Program.renaming option -> + Scalc.Ast.program + * Scopelang.Dependency.TVertex.t list + * Shared_ast.Expr.Renaming.context end module Commands : sig diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index d9e5f68e..d1a783ba 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -133,7 +133,7 @@ let ocaml_keywords = let renaming = Program.renaming () ~reserved:ocaml_keywords - (* TODO: add catala runtime built-ins as reserved as well ? *) + (* TODO: add catala runtime built-ins as reserved as well ? *) ~reset_context_for_closed_terms:true ~skip_constant_binders:true ~constant_binder_name:(Some "_") ~namespaced_fields_constrs:true diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index c3326954..cbf65fe6 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -173,8 +173,7 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = let vars, body, ctxt = unmbind ctxt binder in let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in let ctxt = - List.fold_left (register_fresh_arg ~pos:binder_pos) - ctxt vars_tau + List.fold_left (register_fresh_arg ~pos:binder_pos) ctxt vars_tau in let local_decls = List.fold_left @@ -336,9 +335,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = let vars, body, ctxt = unmbind ctxt binder in let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in let ctxt = - List.fold_left - (register_fresh_arg ~pos:binder_pos) - ctxt vars_tau + List.fold_left (register_fresh_arg ~pos:binder_pos) ctxt vars_tau in let local_decls = List.map @@ -557,12 +554,9 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = RevBlock.rebuild e_stmts ~tail | _ -> . -let rec translate_scope_body_expr - ctx - (scope_expr : 'm L.expr scope_body_expr) : A.block = - let ctx = - { ctx with inside_definition_of = None } - in +let rec translate_scope_body_expr ctx (scope_expr : 'm L.expr scope_body_expr) : + A.block = + let ctx = { ctx with inside_definition_of = None } in match scope_expr with | Last e -> let block, new_e = translate_expr ctx e in @@ -572,9 +566,7 @@ let rec translate_scope_body_expr let let_var_id, ctx = register_fresh_var ctx1 let_var ~pos:scope_let.scope_let_pos in - let next = - translate_scope_body_expr ctx scope_let_next - in + let next = translate_scope_body_expr ctx scope_let_next in match scope_let.scope_let_kind with | Assertion -> translate_statements @@ -604,7 +596,7 @@ let rec translate_scope_body_expr scope_let.scope_let_pos ) :: next)) -let translate_program ~(config : translation_config) (p : 'm L.program): +let translate_program ~(config : translation_config) (p : 'm L.program) : A.program = let modules = List.fold_left @@ -632,8 +624,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program): } in let (_, rev_items), _vlist = - BoundList.fold_left - ~init:(ctxt, []) + BoundList.fold_left ~init:(ctxt, []) ~f:(fun (ctxt, rev_items) code_item var -> match code_item with | ScopeDef (name, body) -> @@ -646,13 +637,10 @@ let translate_program ~(config : translation_config) (p : 'm L.program): in let new_scope_body = translate_scope_body_expr - { ctxt with - context_name = Mark.remove (ScopeName.get_info name) } + { ctxt with context_name = Mark.remove (ScopeName.get_info name) } scope_body_expr in - let func_id, ctxt1 = - register_fresh_func ctxt1 var ~pos:input_pos - in + let func_id, ctxt1 = register_fresh_func ctxt1 var ~pos:input_pos in ( ctxt1, A.SScope { @@ -679,13 +667,14 @@ let translate_program ~(config : translation_config) (p : 'm L.program): let rargs_id, ctxt = List.fold_left2 (fun (rargs_id, ctxt) v ty -> - let pos = Mark.get ty in - let id, ctxt = register_fresh_var ctxt v ~pos in - ((id, pos), ty) :: rargs_id, ctxt) + let pos = Mark.get ty in + let id, ctxt = register_fresh_var ctxt v ~pos in + ((id, pos), ty) :: rargs_id, ctxt) ([], ctxt) args abs.tys in let ctxt = - { ctxt with + { + ctxt with context_name = Mark.remove (TopdefName.get_info name); } in @@ -695,7 +684,9 @@ let translate_program ~(config : translation_config) (p : 'm L.program): RevBlock.rebuild block ~tail:[A.SReturn (Mark.remove expr), Mark.get expr] in - let func_id, ctxt = register_fresh_func ctxt var ~pos:(Expr.mark_pos m) in + let func_id, ctxt = + register_fresh_func ctxt var ~pos:(Expr.mark_pos m) + in ( ctxt, A.SFunc { @@ -716,14 +707,16 @@ let translate_program ~(config : translation_config) (p : 'm L.program): (* Toplevel constant def *) let block, expr = let ctxt = - { ctxt with + { + ctxt with context_name = Mark.remove (TopdefName.get_info name); } in translate_expr ctxt expr in let var_id, ctxt = - register_fresh_var ctxt var ~pos:(Mark.get (TopdefName.get_info name)) + register_fresh_var ctxt var + ~pos:(Mark.get (TopdefName.get_info name)) in (* If the evaluation of the toplevel expr requires preliminary statements, we lift its computation into an auxiliary function *) @@ -738,26 +731,27 @@ let translate_program ~(config : translation_config) (p : 'm L.program): let func_id = A.FuncName.fresh (func_name, pos) in (* The list is being built in reverse order *) (* FIXME: find a better way than a function with no parameters... *) - A.SVar - { - var = var_id; - expr = A.EApp { f = EFunc func_id, pos; args = [] }, pos; - typ = topdef_ty; - } - :: A.SFunc - { - var = func_id; - func = - { - A.func_params = []; - A.func_body = - RevBlock.rebuild block - ~tail:[A.SReturn (Mark.remove expr), Mark.get expr]; - A.func_return_typ = topdef_ty; - }; - } - :: rev_items, - ctxt + ( A.SVar + { + var = var_id; + expr = A.EApp { f = EFunc func_id, pos; args = [] }, pos; + typ = topdef_ty; + } + :: A.SFunc + { + var = func_id; + func = + { + A.func_params = []; + A.func_body = + RevBlock.rebuild block + ~tail: + [A.SReturn (Mark.remove expr), Mark.get expr]; + A.func_return_typ = topdef_ty; + }; + } + :: rev_items, + ctxt ) in ( ctxt, (* No need to add func_id since the function will only be called @@ -765,4 +759,8 @@ let translate_program ~(config : translation_config) (p : 'm L.program): rev_items )) p.code_items in - { ctx = program_ctx; code_items = List.rev rev_items; module_name = p.module_name } + { + ctx = program_ctx; + code_items = List.rev rev_items; + module_name = p.module_name; + } diff --git a/compiler/scalc/from_lcalc.mli b/compiler/scalc/from_lcalc.mli index 4cd09de7..7ab7f417 100644 --- a/compiler/scalc/from_lcalc.mli +++ b/compiler/scalc/from_lcalc.mli @@ -36,5 +36,4 @@ type translation_config = { } val translate_program : - config:translation_config -> typed Lcalc.Ast.program -> - Ast.program + config:translation_config -> typed Lcalc.Ast.program -> Ast.program diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index a40737c6..03bf5fa2 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -22,11 +22,11 @@ let needs_parens (_e : expr) : bool = false let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit = VarName.format fmt v - (* Format.fprintf fmt "%a_%d" VarName.format v (VarName.id v) *) +(* Format.fprintf fmt "%a_%d" VarName.format v (VarName.id v) *) let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit = FuncName.format fmt v - (* Format.fprintf fmt "@{%a_%d@}" FuncName.format v (FuncName.id v) *) +(* Format.fprintf fmt "@{%a_%d@}" FuncName.format v (FuncName.id v) *) let rec format_expr (decl_ctx : decl_ctx) diff --git a/compiler/scalc/to_c.ml b/compiler/scalc/to_c.ml index 517db4ae..77841ed7 100644 --- a/compiler/scalc/to_c.ml +++ b/compiler/scalc/to_c.ml @@ -22,20 +22,49 @@ module L = Lcalc.Ast open Ast let c_keywords = - [ "auto"; "break"; "case"; "char"; "const"; "continue"; "default"; - "do"; "double"; "else"; "enum"; "extern"; "float"; "for"; "goto"; - "if"; "inline"; "int"; "long"; "register"; "restrict"; "return"; - "short"; "signed"; "sizeof"; "static"; "struct"; "switch"; "typedef"; - "union"; "unsigned"; "void"; "volatile"; "while" ] + [ + "auto"; + "break"; + "case"; + "char"; + "const"; + "continue"; + "default"; + "do"; + "double"; + "else"; + "enum"; + "extern"; + "float"; + "for"; + "goto"; + "if"; + "inline"; + "int"; + "long"; + "register"; + "restrict"; + "return"; + "short"; + "signed"; + "sizeof"; + "static"; + "struct"; + "switch"; + "typedef"; + "union"; + "unsigned"; + "void"; + "volatile"; + "while"; + ] let renaming = Program.renaming () ~reserved:c_keywords - (* TODO: add catala runtime built-ins as reserved as well ? *) - ~reset_context_for_closed_terms:true - ~skip_constant_binders:true - ~constant_binder_name:None - ~namespaced_fields_constrs:false + (* TODO: add catala runtime built-ins as reserved as well ? *) + ~reset_context_for_closed_terms:true ~skip_constant_binders:true + ~constant_binder_name:None ~namespaced_fields_constrs:false module TypMap = Map.Make (struct type t = naked_typ @@ -102,8 +131,7 @@ let format_ctx ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") (fun fmt (struct_field, struct_field_type) -> Format.fprintf fmt "@[%a;@]" - (format_typ ctx (fun fmt -> - StructField.format fmt struct_field)) + (format_typ ctx (fun fmt -> StructField.format fmt struct_field)) struct_field_type)) fields StructName.format struct_name in @@ -251,8 +279,8 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : (fun fmt (_, e) -> Format.fprintf fmt "%a" (format_expression ctx) e)) (StructField.Map.bindings es) | EStructFieldAccess { e1; field; _ } -> - Format.fprintf fmt "%a.%a" (format_expression ctx) e1 - StructField.format field + Format.fprintf fmt "%a.%a" (format_expression ctx) e1 StructField.format + field | EInj { e1; cons; name = enum_name; _ } -> Format.fprintf fmt "{%a_%a,@ {%a: %a}}" EnumName.format enum_name EnumConstructor.format cons EnumConstructor.format cons @@ -380,7 +408,8 @@ let rec format_statement if not (Type.equal payload_var_typ (TLit TUnit, Pos.no_pos)) then Format.fprintf fmt "%a = %a.payload.%a;@ " (format_typ ctx (fun fmt -> VarName.format fmt payload_var_name)) - payload_var_typ VarName.format tmp_var EnumConstructor.format cons_name; + payload_var_typ VarName.format tmp_var EnumConstructor.format + cons_name; Format.fprintf fmt "%a@ break;@]" (format_block ctx) case_block) fmt cases; (* Do we want to add 'default' case with a failure ? *) @@ -447,9 +476,9 @@ let rec format_statement VarName.format exception_current (format_expression ctx) except VarName.format exception_current EnumName.format e_name EnumConstructor.format some_cons VarName.format exception_acc_var - EnumName.format e_name EnumConstructor.format some_cons VarName.format - exception_conflict VarName.format exception_acc_var VarName.format - exception_current) + EnumName.format e_name EnumConstructor.format some_cons + VarName.format exception_conflict VarName.format exception_acc_var + VarName.format exception_current) exceptions; Format.fprintf fmt "@[if (%a) {@,\ diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 26f33765..2391a4f1 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -113,17 +113,50 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit = let python_keywords = (* list taken from https://www.programiz.com/python-programming/keyword-list *) - [ "False"; "None"; "True"; "and"; "as"; "assert"; "async"; "await"; - "break"; "class"; "continue"; "def"; "del"; "elif"; "else"; - "except"; "finally"; "for"; "from"; "global"; "if"; "import"; "in"; - "is"; "lambda"; "nonlocal"; "not"; "or"; "pass"; "raise"; "return"; - "try"; "while"; "with"; "yield" ] -(* todo: reserved names should also include built-in types and everything exposed by the runtime. *) + [ + "False"; + "None"; + "True"; + "and"; + "as"; + "assert"; + "async"; + "await"; + "break"; + "class"; + "continue"; + "def"; + "del"; + "elif"; + "else"; + "except"; + "finally"; + "for"; + "from"; + "global"; + "if"; + "import"; + "in"; + "is"; + "lambda"; + "nonlocal"; + "not"; + "or"; + "pass"; + "raise"; + "return"; + "try"; + "while"; + "with"; + "yield"; + ] +(* todo: reserved names should also include built-in types and everything + exposed by the runtime. *) let renaming = Program.renaming () ~reserved:python_keywords - (* TODO: add catala runtime built-ins as reserved as well ? *) + (* TODO: add catala runtime built-ins as reserved as well ? *) ~reset_context_for_closed_terms:false ~skip_constant_binders:false ~constant_binder_name:None ~namespaced_fields_constrs:true ~f_struct:String.to_camel_case @@ -198,8 +231,8 @@ let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit = (format_expression ctx) e)) (StructField.Map.bindings es) | EStructFieldAccess { e1; field; _ } -> - Format.fprintf fmt "%a.%a" (format_expression ctx) e1 - StructField.format field + Format.fprintf fmt "%a.%a" (format_expression ctx) e1 StructField.format + field | EInj { cons; name = e_name; _ } when EnumName.equal e_name Expr.option_enum && EnumConstructor.equal cons Expr.none_constr -> @@ -352,8 +385,8 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit (format_expression ctx) e1; Format.fprintf fmt "@[if %a is None:@ %a@]@," VarName.format tmp_var (format_block ctx) case_none; - Format.fprintf fmt "@[else:@ %a = %a@,%a@]" VarName.format case_some_var - VarName.format tmp_var (format_block ctx) case_some + Format.fprintf fmt "@[else:@ %a = %a@,%a@]" VarName.format + case_some_var VarName.format tmp_var (format_block ctx) case_some | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } -> let cons_map = EnumName.Map.find e_name ctx.decl_ctx.ctx_enums in let cases = @@ -369,7 +402,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") (fun fmt ({ case_block; payload_var_name; _ }, cons_name) -> Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" - VarName.format tmp_var (EnumName.format) e_name + VarName.format tmp_var EnumName.format e_name EnumConstructor.format cons_name VarName.format payload_var_name VarName.format tmp_var (format_block ctx) case_block)) cases @@ -482,14 +515,14 @@ let format_ctx @,\ \ def __str__(self) -> str:@,\ \ @[return \"{}({})\".format(self.code, self.value)@]" - (EnumName.format) enum_name + EnumName.format enum_name (Format.pp_print_list (fun fmt (i, enum_cons, _enum_cons_type) -> Format.fprintf fmt "%a = %d" EnumConstructor.format enum_cons i)) (List.mapi (fun i (x, y) -> i, x, y) (EnumConstructor.Map.bindings enum_cons)) - (EnumName.format) enum_name EnumName.format enum_name - EnumName.format enum_name + EnumName.format enum_name EnumName.format enum_name EnumName.format + enum_name in let is_in_type_ordering s = diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index 33062c64..b76d1e84 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -114,8 +114,7 @@ let rename_ids ?(f_field = uncap) ?(f_enum = cap) ?(f_constr = cap) - p - = + p = let cfg = { Expr.Renaming.reserved; @@ -285,17 +284,16 @@ let rename_ids let code_items = Scope.rename_ids ctx p.code_items in { p with decl_ctx; code_items }, ctx -(* This first-class module wrapping is here to allow a polymorphic renaming function to be passed around *) +(* This first-class module wrapping is here to allow a polymorphic renaming + function to be passed around *) module type Renaming = sig - val apply: - 'e program -> - 'e program * Expr.Renaming.context + val apply : 'e program -> 'e program * Expr.Renaming.context end type renaming = (module Renaming) -let apply (module R: Renaming) = R.apply +let apply (module R : Renaming) = R.apply let renaming ~reserved @@ -308,21 +306,11 @@ let renaming ?f_field ?f_enum ?f_constr - () - = + () = let module M = struct let apply p = - rename_ids - ~reserved - ~reset_context_for_closed_terms - ~skip_constant_binders - ~constant_binder_name - ~namespaced_fields_constrs - ?f_var - ?f_struct - ?f_field - ?f_enum - ?f_constr - p + rename_ids ~reserved ~reset_context_for_closed_terms + ~skip_constant_binders ~constant_binder_name ~namespaced_fields_constrs + ?f_var ?f_struct ?f_field ?f_enum ?f_constr p end in - (module M: Renaming) + (module M : Renaming) diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index 7fd58a3b..41880a03 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -59,7 +59,7 @@ val modules_to_list : module_tree -> (ModuleName.t * module_intf_id) list type renaming -val apply: renaming -> 'e program -> 'e program * Expr.Renaming.context +val apply : renaming -> 'e program -> 'e program * Expr.Renaming.context val renaming : reserved:string list -> From 081e07378a0de36b58500528f1c3ee930bcfd4f7 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Wed, 7 Aug 2024 18:03:10 +0200 Subject: [PATCH 3/9] Renaming: move to its own module --- compiler/driver.ml | 15 +- compiler/driver.mli | 8 +- compiler/lcalc/to_ocaml.ml | 2 +- compiler/lcalc/to_ocaml.mli | 2 +- compiler/plugins/explain.ml | 6 +- compiler/scalc/from_lcalc.ml | 10 +- compiler/scalc/from_lcalc.mli | 2 +- compiler/scalc/to_c.ml | 2 +- compiler/scalc/to_c.mli | 2 +- compiler/scalc/to_python.ml | 2 +- compiler/scalc/to_python.mli | 2 +- compiler/shared_ast/expr.ml | 192 ------------ compiler/shared_ast/expr.mli | 54 +--- compiler/shared_ast/program.ml | 218 -------------- compiler/shared_ast/program.mli | 33 -- compiler/shared_ast/renaming.ml | 482 ++++++++++++++++++++++++++++++ compiler/shared_ast/renaming.mli | 105 +++++++ compiler/shared_ast/scope.ml | 58 ---- compiler/shared_ast/scope.mli | 5 - compiler/shared_ast/shared_ast.ml | 1 + 20 files changed, 615 insertions(+), 586 deletions(-) create mode 100644 compiler/shared_ast/renaming.ml create mode 100644 compiler/shared_ast/renaming.mli diff --git a/compiler/driver.ml b/compiler/driver.ml index 852707b1..0a0469d8 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -232,7 +232,7 @@ module Passes = struct ~renaming : typed Lcalc.Ast.program * Scopelang.Dependency.TVertex.t list - * Expr.Renaming.context option = + * Renaming.context option = let prg, type_ordering = dcalc options ~includes ~optimize ~check_invariants ~typed in @@ -281,13 +281,13 @@ module Passes = struct match renaming with | None -> prg, type_ordering, None | Some renaming -> - let prg, ren_ctx = Program.apply renaming prg in + let prg, ren_ctx = Renaming.apply renaming prg in let type_ordering = let open Scopelang.Dependency.TVertex in List.map (function - | Struct s -> Struct (Expr.Renaming.struct_name ren_ctx s) - | Enum e -> Enum (Expr.Renaming.enum_name ren_ctx e)) + | Struct s -> Struct (Renaming.struct_name ren_ctx s) + | Enum e -> Enum (Renaming.enum_name ren_ctx e)) type_ordering in prg, type_ordering, Some ren_ctx @@ -303,9 +303,8 @@ module Passes = struct ~no_struct_literals ~monomorphize_types ~renaming : - Scalc.Ast.program - * Scopelang.Dependency.TVertex.t list - * Expr.Renaming.context = + Scalc.Ast.program * Scopelang.Dependency.TVertex.t list * Renaming.context + = let prg, type_ordering, renaming_context = lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed ~closure_conversion ~monomorphize_types ~renaming @@ -313,7 +312,7 @@ module Passes = struct let renaming_context = match renaming_context with | None -> - Expr.Renaming.get_ctx + Renaming.get_ctx { reserved = []; sanitize_varname = Fun.id; diff --git a/compiler/driver.mli b/compiler/driver.mli index 372b19e1..29a40832 100644 --- a/compiler/driver.mli +++ b/compiler/driver.mli @@ -53,10 +53,10 @@ module Passes : sig typed:'m Shared_ast.mark -> closure_conversion:bool -> monomorphize_types:bool -> - renaming:Shared_ast.Program.renaming option -> + renaming:Shared_ast.Renaming.t option -> Shared_ast.typed Lcalc.Ast.program * Scopelang.Dependency.TVertex.t list - * Shared_ast.Expr.Renaming.context option + * Shared_ast.Renaming.context option val scalc : Global.options -> @@ -68,10 +68,10 @@ module Passes : sig dead_value_assignment:bool -> no_struct_literals:bool -> monomorphize_types:bool -> - renaming:Shared_ast.Program.renaming option -> + renaming:Shared_ast.Renaming.t option -> Scalc.Ast.program * Scopelang.Dependency.TVertex.t list - * Shared_ast.Expr.Renaming.context + * Shared_ast.Renaming.context end module Commands : sig diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index d1a783ba..02f238b7 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -131,7 +131,7 @@ let ocaml_keywords = ] let renaming = - Program.renaming () + Renaming.program () ~reserved:ocaml_keywords (* TODO: add catala runtime built-ins as reserved as well ? *) ~reset_context_for_closed_terms:true ~skip_constant_binders:true diff --git a/compiler/lcalc/to_ocaml.mli b/compiler/lcalc/to_ocaml.mli index 489343d7..9611ad32 100644 --- a/compiler/lcalc/to_ocaml.mli +++ b/compiler/lcalc/to_ocaml.mli @@ -17,7 +17,7 @@ open Catala_utils open Shared_ast -val renaming : Program.renaming +val renaming : Renaming.t (** Formats a lambda calculus program into a valid OCaml program *) diff --git a/compiler/plugins/explain.ml b/compiler/plugins/explain.ml index 068ced80..1ed7cc20 100644 --- a/compiler/plugins/explain.ml +++ b/compiler/plugins/explain.ml @@ -620,10 +620,10 @@ let program_to_graph let e = customize (Expr.unbox e) in let e = Expr.remove_logging_calls (Expr.unbox e) in let e = - Expr.Renaming.expr - (Expr.Renaming.get_ctx + Renaming.expr + (Renaming.get_ctx { - Expr.Renaming.reserved = []; + Renaming.reserved = []; sanitize_varname = String.to_snake_case; reset_context_for_closed_terms = false; skip_constant_binders = false; diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index cbf65fe6..997e8321 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -24,7 +24,7 @@ type translation_config = { keep_special_ops : bool; dead_value_assignment : bool; no_struct_literals : bool; - renaming_context : Expr.Renaming.context; + renaming_context : Renaming.context; } type 'm ctxt = { @@ -34,7 +34,7 @@ type 'm ctxt = { context_name : string; config : translation_config; program_ctx : A.ctx; - ren_ctx : Expr.Renaming.context; + ren_ctx : Renaming.context; } (* Expressions can spill out side effect, hence this function also returns a @@ -68,15 +68,15 @@ end let ( ++ ) = RevBlock.seq let unbind ctxt bnd = - let v, body, ren_ctx = Expr.Renaming.unbind_in ctxt.ren_ctx bnd in + let v, body, ren_ctx = Renaming.unbind_in ctxt.ren_ctx bnd in v, body, { ctxt with ren_ctx } let unmbind ctxt bnd = - let vs, body, ren_ctx = Expr.Renaming.unmbind_in ctxt.ren_ctx bnd in + let vs, body, ren_ctx = Renaming.unmbind_in ctxt.ren_ctx bnd in vs, body, { ctxt with ren_ctx } let get_name ctxt s = - let name, ren_ctx = Expr.Renaming.new_id ctxt.ren_ctx s in + let name, ren_ctx = Renaming.new_id ctxt.ren_ctx s in name, { ctxt with ren_ctx } let fresh_var ~pos ctxt name = diff --git a/compiler/scalc/from_lcalc.mli b/compiler/scalc/from_lcalc.mli index 7ab7f417..c7871441 100644 --- a/compiler/scalc/from_lcalc.mli +++ b/compiler/scalc/from_lcalc.mli @@ -32,7 +32,7 @@ type translation_config = { (** When [no_struct_literals] is true, the translation inserts a temporary variable to hold the initialization of struct literals. This matches what C89 expects. *) - renaming_context : Expr.Renaming.context; + renaming_context : Renaming.context; } val translate_program : diff --git a/compiler/scalc/to_c.ml b/compiler/scalc/to_c.ml index 77841ed7..784b0b97 100644 --- a/compiler/scalc/to_c.ml +++ b/compiler/scalc/to_c.ml @@ -60,7 +60,7 @@ let c_keywords = ] let renaming = - Program.renaming () + Renaming.program () ~reserved:c_keywords (* TODO: add catala runtime built-ins as reserved as well ? *) ~reset_context_for_closed_terms:true ~skip_constant_binders:true diff --git a/compiler/scalc/to_c.mli b/compiler/scalc/to_c.mli index 2b7c6853..efab8798 100644 --- a/compiler/scalc/to_c.mli +++ b/compiler/scalc/to_c.mli @@ -18,7 +18,7 @@ open Shared_ast -val renaming : Program.renaming +val renaming : Renaming.t val format_program : Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 2391a4f1..81e980f6 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -154,7 +154,7 @@ let python_keywords = exposed by the runtime. *) let renaming = - Program.renaming () + Renaming.program () ~reserved:python_keywords (* TODO: add catala runtime built-ins as reserved as well ? *) ~reset_context_for_closed_terms:false ~skip_constant_binders:false diff --git a/compiler/scalc/to_python.mli b/compiler/scalc/to_python.mli index d055d0ab..84988908 100644 --- a/compiler/scalc/to_python.mli +++ b/compiler/scalc/to_python.mli @@ -18,7 +18,7 @@ open Shared_ast -val renaming : Program.renaming +val renaming : Renaming.t val format_program : Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index 5441d144..32d7a5e7 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -819,198 +819,6 @@ let remove_logging_calls e = in f e -module Renaming = struct - module DefaultBindlibCtxRename : Bindlib.Renaming = struct - (* This code is a copy-paste from Bindlib, they forgot to expose the default - implementation ! *) - type ctxt = int String.Map.t - - let empty_ctxt = String.Map.empty - - let split_name : string -> string * int = - fun name -> - let len = String.length name in - (* [i] is the index of the first first character of the suffix. *) - let i = - let is_digit c = '0' <= c && c <= '9' in - let first_digit = ref len in - let first_non_0 = ref len in - while !first_digit > 0 && is_digit name.[!first_digit - 1] do - decr first_digit; - if name.[!first_digit] <> '0' then first_non_0 := !first_digit - done; - !first_non_0 - in - if i = len then name, 0 - else String.sub name 0 i, int_of_string (String.sub name i (len - i)) - - let get_suffix : string -> int -> ctxt -> int * ctxt = - fun name suffix ctxt -> - let n = - try String.Map.find name ctxt with String.Map.Not_found _ -> -1 - in - let suffix = if suffix > n then suffix else n + 1 in - suffix, String.Map.add name suffix ctxt - - let merge_name : string -> int -> string = - fun prefix suffix -> - if suffix > 0 then prefix ^ string_of_int suffix else prefix - - let new_name : string -> ctxt -> string * ctxt = - fun name ctxt -> - let prefix, suffix = split_name name in - let suffix, ctxt = get_suffix prefix suffix ctxt in - merge_name prefix suffix, ctxt - - let reserve_name : string -> ctxt -> ctxt = - fun name ctxt -> - let prefix, suffix = split_name name in - try - let n = String.Map.find prefix ctxt in - if suffix <= n then ctxt else String.Map.add prefix suffix ctxt - with String.Map.Not_found _ -> String.Map.add prefix suffix ctxt - - let reset_context_for_closed_terms = false - let skip_constant_binders = false - let constant_binder_name = None - end - - module type BindlibCtxt = module type of Bindlib.Ctxt (DefaultBindlibCtxRename) - - type config = { - reserved : string list; - sanitize_varname : string -> string; - reset_context_for_closed_terms : bool; - skip_constant_binders : bool; - constant_binder_name : string option; - } - - type context = { - bindCtx : (module BindlibCtxt); - bcontext : DefaultBindlibCtxRename.ctxt; - vars : string -> string; - scopes : ScopeName.t -> ScopeName.t; - topdefs : TopdefName.t -> TopdefName.t; - structs : StructName.t -> StructName.t; - fields : StructField.t -> StructField.t; - enums : EnumName.t -> EnumName.t; - constrs : EnumConstructor.t -> EnumConstructor.t; - } - - let unbind_in ctx ?fname b = - let module BindCtx = (val ctx.bindCtx) in - match fname with - | Some fn -> - let name = fn (Bindlib.binder_name b) in - let v, bcontext = - BindCtx.new_var_in ctx.bcontext (fun v -> EVar v) name - in - let e = Bindlib.subst b (EVar v) in - v, e, { ctx with bcontext } - | None -> - let v, e, bcontext = BindCtx.unbind_in ctx.bcontext b in - v, e, { ctx with bcontext } - - let unmbind_in ctx ?fname b = - let module BindCtx = (val ctx.bindCtx) in - match fname with - | Some fn -> - let names = Array.map fn (Bindlib.mbinder_names b) in - let rvs, bcontext = - Array.fold_left - (fun (rvs, bcontext) n -> - let v, bcontext = BindCtx.new_var_in bcontext (fun v -> EVar v) n in - v :: rvs, bcontext) - ([], ctx.bcontext) names - in - let vs = Array.of_list (List.rev rvs) in - let e = Bindlib.msubst b (Array.map (fun v -> EVar v) vs) in - vs, e, { ctx with bcontext } - | None -> - let vs, e, bcontext = BindCtx.unmbind_in ctx.bcontext b in - vs, e, { ctx with bcontext } - - let set_rewriters ?scopes ?topdefs ?structs ?fields ?enums ?constrs ctx = - (fun ?(scopes = ctx.scopes) ?(topdefs = ctx.topdefs) - ?(structs = ctx.structs) ?(fields = ctx.fields) ?(enums = ctx.enums) - ?(constrs = ctx.constrs) () -> - { ctx with scopes; topdefs; structs; fields; enums; constrs }) - ?scopes ?topdefs ?structs ?fields ?enums ?constrs () - - let new_id ctx name = - let module BindCtx = (val ctx.bindCtx) in - let var, bcontext = - BindCtx.new_var_in ctx.bcontext (fun _ -> assert false) name - in - Bindlib.name_of var, { ctx with bcontext } - - let get_ctx cfg = - let module BindCtx = Bindlib.Ctxt (struct - include DefaultBindlibCtxRename - - let reset_context_for_closed_terms = cfg.reset_context_for_closed_terms - let skip_constant_binders = cfg.skip_constant_binders - let constant_binder_name = cfg.constant_binder_name - end) in - { - bindCtx = (module BindCtx); - bcontext = - List.fold_left - (fun ctx name -> DefaultBindlibCtxRename.reserve_name name ctx) - BindCtx.empty_ctxt cfg.reserved; - vars = cfg.sanitize_varname; - scopes = Fun.id; - topdefs = Fun.id; - structs = Fun.id; - fields = Fun.id; - enums = Fun.id; - constrs = Fun.id; - } - - let rec typ ctx = function - | TStruct n, m -> TStruct (ctx.structs n), m - | TEnum n, m -> TEnum (ctx.enums n), m - | ty -> Type.map (typ ctx) ty - - let rec expr : type k. context -> (k, 'm) gexpr -> (k, 'm) gexpr boxed = - fun ctx -> function - | EExternal { name = External_scope s, pos }, m -> - eexternal ~name:(External_scope (ctx.scopes s), pos) m - | EExternal { name = External_value d, pos }, m -> - eexternal ~name:(External_value (ctx.topdefs d), pos) m - | EAbs { binder; tys }, m -> - let vars, body, ctx = unmbind_in ctx ~fname:ctx.vars binder in - let body = expr ctx body in - let binder = bind vars body in - eabs binder (List.map (typ ctx) tys) m - | EStruct { name; fields }, m -> - estruct ~name:(ctx.structs name) - ~fields: - (StructField.Map.fold - (fun fld e -> StructField.Map.add (ctx.fields fld) (expr ctx e)) - fields StructField.Map.empty) - m - | EStructAccess { name; field; e }, m -> - estructaccess ~name:(ctx.structs name) ~field:(ctx.fields field) - ~e:(expr ctx e) m - | EInj { name; e; cons }, m -> - einj ~name:(ctx.enums name) ~cons:(ctx.constrs cons) ~e:(expr ctx e) m - | EMatch { name; e; cases }, m -> - ematch ~name:(ctx.enums name) - ~cases: - (EnumConstructor.Map.fold - (fun cons e -> - EnumConstructor.Map.add (ctx.constrs cons) (expr ctx e)) - cases EnumConstructor.Map.empty) - ~e:(expr ctx e) m - | e -> map ~typ:(typ ctx) ~f:(expr ctx) ~op:Fun.id e - - let scope_name ctx s = ctx.scopes s - let topdef_name ctx s = ctx.topdefs s - let struct_name ctx s = ctx.structs s - let enum_name ctx e = ctx.enums e -end - let format ppf e = Print.expr ~debug:false () ppf e let rec size : type a. (a, 't) gexpr -> int = diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index 28207ae6..6b73ab97 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -393,59 +393,7 @@ val remove_logging_calls : (** Removes all calls to [Log] unary operators in the AST, replacing them by their argument. *) -(** {2 Renamings and formatting} *) - -module Renaming : sig - type config = { - reserved : string list; (** Use for keywords and built-ins *) - sanitize_varname : string -> string; (** Typically String.to_snake_case *) - reset_context_for_closed_terms : bool; (** See [Bindlib.Renaming] *) - skip_constant_binders : bool; (** See [Bindlib.Renaming] *) - constant_binder_name : string option; (** See [Bindlib.Renaming] *) - } - - type context - - val get_ctx : config -> context - - val unbind_in : - context -> - ?fname:(string -> string) -> - ('e, 'b) Bindlib.binder -> - ('e, _) Mark.ed Var.t * 'b * context - (* [fname] applies a transformation on the variable name (typically something - like [String.to_snake_case]). The result is advisory and a numerical suffix - may be appended or modified *) - - val unmbind_in : - context -> - ?fname:(string -> string) -> - ('e, 'b) Bindlib.mbinder -> - ('e, _) Mark.ed Var.t Array.t * 'b * context - - val new_id : context -> string -> string * context - - val set_rewriters : - ?scopes:(ScopeName.t -> ScopeName.t) -> - ?topdefs:(TopdefName.t -> TopdefName.t) -> - ?structs:(StructName.t -> StructName.t) -> - ?fields:(StructField.t -> StructField.t) -> - ?enums:(EnumName.t -> EnumName.t) -> - ?constrs:(EnumConstructor.t -> EnumConstructor.t) -> - context -> - context - - val typ : context -> typ -> typ - - val expr : context -> ('a any, 'm) gexpr -> ('a, 'm) boxed_gexpr - (** Disambiguates all variable names in [e], and renames structs, fields, - enums and constrs according to the given context configuration *) - - val scope_name : context -> ScopeName.t -> ScopeName.t - val topdef_name : context -> TopdefName.t -> TopdefName.t - val struct_name : context -> StructName.t -> StructName.t - val enum_name : context -> EnumName.t -> EnumName.t -end +(** {2 Formatting} *) val format : Format.formatter -> ('a, 'm) gexpr -> unit (** Simple printing without debug, use [Print.expr ()] instead to follow the diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index b76d1e84..7d92f909 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -96,221 +96,3 @@ let modules_to_list (mt : module_tree) = mtree acc in List.rev (aux [] mt) - -let cap s = String.to_ascii s |> String.capitalize_ascii -let uncap s = String.to_ascii s |> String.uncapitalize_ascii - -(* Todo? - handle separate namespaces ? (e.g. allow a field and var to have the - same name for backends that support it) - register module names as reserved - names *) -let rename_ids - ~reserved - ~reset_context_for_closed_terms - ~skip_constant_binders - ~constant_binder_name - ~namespaced_fields_constrs - ?(f_var = String.to_snake_case) - ?(f_struct = cap) - ?(f_field = uncap) - ?(f_enum = cap) - ?(f_constr = cap) - p = - let cfg = - { - Expr.Renaming.reserved; - sanitize_varname = f_var; - reset_context_for_closed_terms; - skip_constant_binders; - constant_binder_name; - } - in - let ctx = Expr.Renaming.get_ctx cfg in - (* Each module needs its separate ctx since resolution is qualified ; and name - resolution in a given module must be processed consistently independently - on the current context. *) - let ctx0 = ctx in - let module PathMap = Map.Make (Uid.Path) in - let pctxmap = PathMap.singleton [] ctx in - let pctxmap, structs_map, fields_map, ctx_structs = - (* Warning: the folding order matters here, if a module contains e.g. two - fields with the same name. This fold relies on UIDs, and is thus - dependent on the definition order. Another possibility would be to fold - lexicographically, but the result would be "less intuitive" *) - StructName.Map.fold - (fun name fields (pctxmap, structs_map, fields_map, ctx_structs) -> - let path = StructName.path name in - let str, pos = StructName.get_info name in - let pctxmap, ctx = - try pctxmap, PathMap.find path pctxmap - with PathMap.Not_found _ -> PathMap.add path ctx pctxmap, ctx - in - let id, ctx = Expr.Renaming.new_id ctx (f_struct str) in - let new_name = StructName.fresh path (id, pos) in - let ctx1, fields_map, ctx_fields = - StructField.Map.fold - (fun name ty (ctx, fields_map, ctx_fields) -> - let str, pos = StructField.get_info name in - let id, ctx = Expr.Renaming.new_id ctx (f_field str) in - let new_name = StructField.fresh (id, pos) in - ( ctx, - StructField.Map.add name new_name fields_map, - StructField.Map.add new_name ty ctx_fields )) - fields - ( (if namespaced_fields_constrs then ctx0 else ctx), - fields_map, - StructField.Map.empty ) - in - let ctx = if namespaced_fields_constrs then ctx else ctx1 in - ( PathMap.add path ctx pctxmap, - StructName.Map.add name new_name structs_map, - fields_map, - StructName.Map.add new_name ctx_fields ctx_structs )) - p.decl_ctx.ctx_structs - ( pctxmap, - StructName.Map.empty, - StructField.Map.empty, - StructName.Map.empty ) - in - let pctxmap, enums_map, constrs_map, ctx_enums = - EnumName.Map.fold - (fun name constrs (pctxmap, enums_map, constrs_map, ctx_enums) -> - let path = EnumName.path name in - let str, pos = EnumName.get_info name in - let pctxmap, ctx = - try pctxmap, PathMap.find path pctxmap - with Not_found -> PathMap.add path ctx pctxmap, ctx - in - let id, ctx = Expr.Renaming.new_id ctx (f_enum str) in - let new_name = EnumName.fresh path (id, pos) in - let ctx1, constrs_map, ctx_constrs = - EnumConstructor.Map.fold - (fun name ty (ctx, constrs_map, ctx_constrs) -> - let str, pos = EnumConstructor.get_info name in - let id, ctx = Expr.Renaming.new_id ctx (f_constr str) in - let new_name = EnumConstructor.fresh (id, pos) in - ( ctx, - EnumConstructor.Map.add name new_name constrs_map, - EnumConstructor.Map.add new_name ty ctx_constrs )) - constrs - ( (if namespaced_fields_constrs then ctx0 else ctx), - constrs_map, - EnumConstructor.Map.empty ) - in - let ctx = if namespaced_fields_constrs then ctx else ctx1 in - ( PathMap.add path ctx pctxmap, - EnumName.Map.add name new_name enums_map, - constrs_map, - EnumName.Map.add new_name ctx_constrs ctx_enums )) - p.decl_ctx.ctx_enums - ( pctxmap, - EnumName.Map.empty, - EnumConstructor.Map.empty, - EnumName.Map.empty ) - in - let pctxmap, scopes_map, ctx_scopes = - ScopeName.Map.fold - (fun name info (pctxmap, scopes_map, ctx_scopes) -> - let info = - { - in_struct_name = StructName.Map.find info.in_struct_name structs_map; - out_struct_name = - StructName.Map.find info.out_struct_name structs_map; - out_struct_fields = - ScopeVar.Map.map - (fun fld -> StructField.Map.find fld fields_map) - info.out_struct_fields; - } - in - let path = ScopeName.path name in - if path = [] then - (* Scopes / topdefs in the root module will be renamed through the - variables binding them in the code_items *) - ( pctxmap, - ScopeName.Map.add name name scopes_map, - ScopeName.Map.add name info ctx_scopes ) - else - let str, pos = ScopeName.get_info name in - let pctxmap, ctx = - try pctxmap, PathMap.find path pctxmap - with Not_found -> PathMap.add path ctx pctxmap, ctx - in - let id, ctx = Expr.Renaming.new_id ctx (f_var str) in - let new_name = ScopeName.fresh path (id, pos) in - ( PathMap.add path ctx pctxmap, - ScopeName.Map.add name new_name scopes_map, - ScopeName.Map.add new_name info ctx_scopes )) - p.decl_ctx.ctx_scopes - (pctxmap, ScopeName.Map.empty, ScopeName.Map.empty) - in - let pctxmap, topdefs_map, ctx_topdefs = - TopdefName.Map.fold - (fun name typ (pctxmap, topdefs_map, ctx_topdefs) -> - let path = TopdefName.path name in - if path = [] then - (* Topdefs / topdefs in the root module will be renamed through the - variables binding them in the code_items *) - ( pctxmap, - TopdefName.Map.add name name topdefs_map, - TopdefName.Map.add name typ ctx_topdefs ) - (* [typ] is rewritten later on *) - else - let str, pos = TopdefName.get_info name in - let pctxmap, ctx = - try pctxmap, PathMap.find path pctxmap - with Not_found -> PathMap.add path ctx pctxmap, ctx - in - let id, ctx = Expr.Renaming.new_id ctx (f_var str) in - let new_name = TopdefName.fresh path (id, pos) in - ( PathMap.add path ctx pctxmap, - TopdefName.Map.add name new_name topdefs_map, - TopdefName.Map.add new_name typ ctx_topdefs )) - p.decl_ctx.ctx_topdefs - (pctxmap, TopdefName.Map.empty, TopdefName.Map.empty) - in - let ctx = PathMap.find [] pctxmap in - let ctx = - Expr.Renaming.set_rewriters ctx - ~scopes:(fun n -> ScopeName.Map.find n scopes_map) - ~topdefs:(fun n -> TopdefName.Map.find n topdefs_map) - ~structs:(fun n -> StructName.Map.find n structs_map) - ~fields:(fun n -> StructField.Map.find n fields_map) - ~enums:(fun n -> EnumName.Map.find n enums_map) - ~constrs:(fun n -> EnumConstructor.Map.find n constrs_map) - in - let decl_ctx = - { p.decl_ctx with ctx_enums; ctx_structs; ctx_scopes; ctx_topdefs } - in - let decl_ctx = map_decl_ctx ~f:(Expr.Renaming.typ ctx) decl_ctx in - let code_items = Scope.rename_ids ctx p.code_items in - { p with decl_ctx; code_items }, ctx - -(* This first-class module wrapping is here to allow a polymorphic renaming - function to be passed around *) - -module type Renaming = sig - val apply : 'e program -> 'e program * Expr.Renaming.context -end - -type renaming = (module Renaming) - -let apply (module R : Renaming) = R.apply - -let renaming - ~reserved - ~reset_context_for_closed_terms - ~skip_constant_binders - ~constant_binder_name - ~namespaced_fields_constrs - ?f_var - ?f_struct - ?f_field - ?f_enum - ?f_constr - () = - let module M = struct - let apply p = - rename_ids ~reserved ~reset_context_for_closed_terms - ~skip_constant_binders ~constant_binder_name ~namespaced_fields_constrs - ?f_var ?f_struct ?f_field ?f_enum ?f_constr p - end in - (module M : Renaming) diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index 41880a03..071b7873 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -56,36 +56,3 @@ val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body val modules_to_list : module_tree -> (ModuleName.t * module_intf_id) list (** Returns a list of used modules, in topological order ; the boolean indicates if the module is external *) - -type renaming - -val apply : renaming -> 'e program -> 'e program * Expr.Renaming.context - -val renaming : - reserved:string list -> - reset_context_for_closed_terms:bool -> - skip_constant_binders:bool -> - constant_binder_name:string option -> - namespaced_fields_constrs:bool -> - ?f_var:(string -> string) -> - ?f_struct:(string -> string) -> - ?f_field:(string -> string) -> - ?f_enum:(string -> string) -> - ?f_constr:(string -> string) -> - unit -> - renaming -(** Renames all idents (variables, types, struct and enum names, fields and - constructors) to dispel ambiguities in the target language. Names in - [reserved], typically keywords and built-ins, will be avoided ; the meaning - of the following three flags is described in [Bindlib.Renaming]. - - if [namespaced_fields_constrs] is true, then struct fields and enum - constructors can reuse names from other fields/constructors or other idents. - - The [f_*] optional arguments sanitize the different kinds of ids. The - default is what is used for OCaml: project to ASCII, capitalise structs, - enums (both modules in the backend) and constructors, lowercase fields, and - rewrite variables to snake case. - - In the returned program, it is safe to directly use `Bindlib.name_of` on - variables for printing. The same is true for `StructName.get_info` etc. *) diff --git a/compiler/shared_ast/renaming.ml b/compiler/shared_ast/renaming.ml new file mode 100644 index 00000000..a310fabc --- /dev/null +++ b/compiler/shared_ast/renaming.ml @@ -0,0 +1,482 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2024 Inria, contributor: + Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +open Catala_utils +open Definitions + +module DefaultBindlibCtxRename : Bindlib.Renaming = struct + (* This code is a copy-paste from Bindlib, they forgot to expose the default + implementation ! *) + type ctxt = int String.Map.t + + let empty_ctxt = String.Map.empty + + let split_name : string -> string * int = + fun name -> + let len = String.length name in + (* [i] is the index of the first first character of the suffix. *) + let i = + let is_digit c = '0' <= c && c <= '9' in + let first_digit = ref len in + let first_non_0 = ref len in + while !first_digit > 0 && is_digit name.[!first_digit - 1] do + decr first_digit; + if name.[!first_digit] <> '0' then first_non_0 := !first_digit + done; + !first_non_0 + in + if i = len then name, 0 + else String.sub name 0 i, int_of_string (String.sub name i (len - i)) + + let get_suffix : string -> int -> ctxt -> int * ctxt = + fun name suffix ctxt -> + let n = try String.Map.find name ctxt with String.Map.Not_found _ -> -1 in + let suffix = if suffix > n then suffix else n + 1 in + suffix, String.Map.add name suffix ctxt + + let merge_name : string -> int -> string = + fun prefix suffix -> + if suffix > 0 then prefix ^ string_of_int suffix else prefix + + let new_name : string -> ctxt -> string * ctxt = + fun name ctxt -> + let prefix, suffix = split_name name in + let suffix, ctxt = get_suffix prefix suffix ctxt in + merge_name prefix suffix, ctxt + + let reserve_name : string -> ctxt -> ctxt = + fun name ctxt -> + let prefix, suffix = split_name name in + try + let n = String.Map.find prefix ctxt in + if suffix <= n then ctxt else String.Map.add prefix suffix ctxt + with String.Map.Not_found _ -> String.Map.add prefix suffix ctxt + + let reset_context_for_closed_terms = false + let skip_constant_binders = false + let constant_binder_name = None +end + +module type BindlibCtxt = module type of Bindlib.Ctxt (DefaultBindlibCtxRename) + +type config = { + reserved : string list; + sanitize_varname : string -> string; + reset_context_for_closed_terms : bool; + skip_constant_binders : bool; + constant_binder_name : string option; +} + +type context = { + bindCtx : (module BindlibCtxt); + bcontext : DefaultBindlibCtxRename.ctxt; + vars : string -> string; + scopes : ScopeName.t -> ScopeName.t; + topdefs : TopdefName.t -> TopdefName.t; + structs : StructName.t -> StructName.t; + fields : StructField.t -> StructField.t; + enums : EnumName.t -> EnumName.t; + constrs : EnumConstructor.t -> EnumConstructor.t; +} + +let unbind_in ctx ?fname b = + let module BindCtx = (val ctx.bindCtx) in + match fname with + | Some fn -> + let name = fn (Bindlib.binder_name b) in + let v, bcontext = BindCtx.new_var_in ctx.bcontext (fun v -> EVar v) name in + let e = Bindlib.subst b (EVar v) in + v, e, { ctx with bcontext } + | None -> + let v, e, bcontext = BindCtx.unbind_in ctx.bcontext b in + v, e, { ctx with bcontext } + +let unmbind_in ctx ?fname b = + let module BindCtx = (val ctx.bindCtx) in + match fname with + | Some fn -> + let names = Array.map fn (Bindlib.mbinder_names b) in + let rvs, bcontext = + Array.fold_left + (fun (rvs, bcontext) n -> + let v, bcontext = BindCtx.new_var_in bcontext (fun v -> EVar v) n in + v :: rvs, bcontext) + ([], ctx.bcontext) names + in + let vs = Array.of_list (List.rev rvs) in + let e = Bindlib.msubst b (Array.map (fun v -> EVar v) vs) in + vs, e, { ctx with bcontext } + | None -> + let vs, e, bcontext = BindCtx.unmbind_in ctx.bcontext b in + vs, e, { ctx with bcontext } + +let set_rewriters ?scopes ?topdefs ?structs ?fields ?enums ?constrs ctx = + (fun ?(scopes = ctx.scopes) ?(topdefs = ctx.topdefs) ?(structs = ctx.structs) + ?(fields = ctx.fields) ?(enums = ctx.enums) ?(constrs = ctx.constrs) () -> + { ctx with scopes; topdefs; structs; fields; enums; constrs }) + ?scopes ?topdefs ?structs ?fields ?enums ?constrs () + +let new_id ctx name = + let module BindCtx = (val ctx.bindCtx) in + let var, bcontext = + BindCtx.new_var_in ctx.bcontext (fun _ -> assert false) name + in + Bindlib.name_of var, { ctx with bcontext } + +let get_ctx cfg = + let module BindCtx = Bindlib.Ctxt (struct + include DefaultBindlibCtxRename + + let reset_context_for_closed_terms = cfg.reset_context_for_closed_terms + let skip_constant_binders = cfg.skip_constant_binders + let constant_binder_name = cfg.constant_binder_name + end) in + { + bindCtx = (module BindCtx); + bcontext = + List.fold_left + (fun ctx name -> DefaultBindlibCtxRename.reserve_name name ctx) + BindCtx.empty_ctxt cfg.reserved; + vars = cfg.sanitize_varname; + scopes = Fun.id; + topdefs = Fun.id; + structs = Fun.id; + fields = Fun.id; + enums = Fun.id; + constrs = Fun.id; + } + +let rec typ ctx = function + | TStruct n, m -> TStruct (ctx.structs n), m + | TEnum n, m -> TEnum (ctx.enums n), m + | ty -> Type.map (typ ctx) ty + +(* {2 Handling expressions} *) + +let rec expr : type k. context -> (k, 'm) gexpr -> (k, 'm) gexpr boxed = + fun ctx -> function + | EExternal { name = External_scope s, pos }, m -> + Expr.eexternal ~name:(External_scope (ctx.scopes s), pos) m + | EExternal { name = External_value d, pos }, m -> + Expr.eexternal ~name:(External_value (ctx.topdefs d), pos) m + | EAbs { binder; tys }, m -> + let vars, body, ctx = unmbind_in ctx ~fname:ctx.vars binder in + let body = expr ctx body in + let binder = Expr.bind vars body in + Expr.eabs binder (List.map (typ ctx) tys) m + | EStruct { name; fields }, m -> + Expr.estruct ~name:(ctx.structs name) + ~fields: + (StructField.Map.fold + (fun fld e -> StructField.Map.add (ctx.fields fld) (expr ctx e)) + fields StructField.Map.empty) + m + | EStructAccess { name; field; e }, m -> + Expr.estructaccess ~name:(ctx.structs name) ~field:(ctx.fields field) + ~e:(expr ctx e) m + | EInj { name; e; cons }, m -> + Expr.einj ~name:(ctx.enums name) ~cons:(ctx.constrs cons) ~e:(expr ctx e) m + | EMatch { name; e; cases }, m -> + Expr.ematch ~name:(ctx.enums name) + ~cases: + (EnumConstructor.Map.fold + (fun cons e -> + EnumConstructor.Map.add (ctx.constrs cons) (expr ctx e)) + cases EnumConstructor.Map.empty) + ~e:(expr ctx e) m + | e -> Expr.map ~typ:(typ ctx) ~f:(expr ctx) ~op:Fun.id e + +let scope_name ctx s = ctx.scopes s +let topdef_name ctx s = ctx.topdefs s +let struct_name ctx s = ctx.structs s +let enum_name ctx e = ctx.enums e + +(* {2 Handling scopes} *) + +(** Maps carrying around a naming context, enriched at each [unbind] *) +let rec boundlist_map_ctx ~f ~fname ~last ~ctx = function + | Last l -> Bindlib.box_apply (fun l -> Last l) (last ctx l) + | Cons (item, next_bind) -> + let item = f ctx item in + let var, next, ctx = unbind_in ctx ~fname next_bind in + let next = boundlist_map_ctx ~f ~fname ~last ~ctx next in + let next_bind = Bindlib.bind_var var next in + Bindlib.box_apply2 + (fun item next_bind -> Cons (item, next_bind)) + item next_bind + +let rename_vars_in_lets ctx scope_body_expr = + boundlist_map_ctx scope_body_expr ~ctx ~fname:String.to_snake_case + ~last:(fun ctx e -> Expr.Box.lift (expr ctx e)) + ~f:(fun ctx scope_let -> + Bindlib.box_apply + (fun scope_let_expr -> + { + scope_let with + scope_let_expr; + scope_let_typ = typ ctx scope_let.scope_let_typ; + }) + (Expr.Box.lift (expr ctx scope_let.scope_let_expr))) + +let code_items ctx (scopes : 'e code_item_list) = + let f ctx = function + | ScopeDef (name, body) -> + let name = scope_name ctx name in + let scope_input_var, scope_lets, ctx = + unbind_in ctx ~fname:String.to_snake_case body.scope_body_expr + in + let scope_lets = rename_vars_in_lets ctx scope_lets in + let scope_body_expr = Bindlib.bind_var scope_input_var scope_lets in + Bindlib.box_apply + (fun scope_body_expr -> + let body = + { + scope_body_input_struct = + struct_name ctx body.scope_body_input_struct; + scope_body_output_struct = + struct_name ctx body.scope_body_output_struct; + scope_body_expr; + } + in + ScopeDef (name, body)) + scope_body_expr + | Topdef (name, ty, e) -> + Bindlib.box_apply + (fun e -> Topdef (name, typ ctx ty, e)) + (Expr.Box.lift (expr ctx e)) + in + Bindlib.unbox + @@ boundlist_map_ctx ~ctx ~f ~fname:String.to_snake_case + ~last:(fun _ctx -> Bindlib.box) + scopes + +let cap s = String.to_ascii s |> String.capitalize_ascii +let uncap s = String.to_ascii s |> String.uncapitalize_ascii + +(* Todo? - handle separate namespaces ? (e.g. allow a field and var to have the + same name for backends that support it) - register module names as reserved + names *) +let program + ~reserved + ~reset_context_for_closed_terms + ~skip_constant_binders + ~constant_binder_name + ~namespaced_fields_constrs + ?(f_var = String.to_snake_case) + ?(f_struct = cap) + ?(f_field = uncap) + ?(f_enum = cap) + ?(f_constr = cap) + p = + let cfg = + { + reserved; + sanitize_varname = f_var; + reset_context_for_closed_terms; + skip_constant_binders; + constant_binder_name; + } + in + let ctx = get_ctx cfg in + (* Each module needs its separate ctx since resolution is qualified ; and name + resolution in a given module must be processed consistently independently + on the current context. *) + let ctx0 = ctx in + let module PathMap = Map.Make (Uid.Path) in + let pctxmap = PathMap.singleton [] ctx in + let pctxmap, structs_map, fields_map, ctx_structs = + (* Warning: the folding order matters here, if a module contains e.g. two + fields with the same name. This fold relies on UIDs, and is thus + dependent on the definition order. Another possibility would be to fold + lexicographically, but the result would be "less intuitive" *) + StructName.Map.fold + (fun name fields (pctxmap, structs_map, fields_map, ctx_structs) -> + let path = StructName.path name in + let str, pos = StructName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with PathMap.Not_found _ -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = new_id ctx (f_struct str) in + let new_name = StructName.fresh path (id, pos) in + let ctx1, fields_map, ctx_fields = + StructField.Map.fold + (fun name ty (ctx, fields_map, ctx_fields) -> + let str, pos = StructField.get_info name in + let id, ctx = new_id ctx (f_field str) in + let new_name = StructField.fresh (id, pos) in + ( ctx, + StructField.Map.add name new_name fields_map, + StructField.Map.add new_name ty ctx_fields )) + fields + ( (if namespaced_fields_constrs then ctx0 else ctx), + fields_map, + StructField.Map.empty ) + in + let ctx = if namespaced_fields_constrs then ctx else ctx1 in + ( PathMap.add path ctx pctxmap, + StructName.Map.add name new_name structs_map, + fields_map, + StructName.Map.add new_name ctx_fields ctx_structs )) + p.decl_ctx.ctx_structs + ( pctxmap, + StructName.Map.empty, + StructField.Map.empty, + StructName.Map.empty ) + in + let pctxmap, enums_map, constrs_map, ctx_enums = + EnumName.Map.fold + (fun name constrs (pctxmap, enums_map, constrs_map, ctx_enums) -> + let path = EnumName.path name in + let str, pos = EnumName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with Not_found -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = new_id ctx (f_enum str) in + let new_name = EnumName.fresh path (id, pos) in + let ctx1, constrs_map, ctx_constrs = + EnumConstructor.Map.fold + (fun name ty (ctx, constrs_map, ctx_constrs) -> + let str, pos = EnumConstructor.get_info name in + let id, ctx = new_id ctx (f_constr str) in + let new_name = EnumConstructor.fresh (id, pos) in + ( ctx, + EnumConstructor.Map.add name new_name constrs_map, + EnumConstructor.Map.add new_name ty ctx_constrs )) + constrs + ( (if namespaced_fields_constrs then ctx0 else ctx), + constrs_map, + EnumConstructor.Map.empty ) + in + let ctx = if namespaced_fields_constrs then ctx else ctx1 in + ( PathMap.add path ctx pctxmap, + EnumName.Map.add name new_name enums_map, + constrs_map, + EnumName.Map.add new_name ctx_constrs ctx_enums )) + p.decl_ctx.ctx_enums + ( pctxmap, + EnumName.Map.empty, + EnumConstructor.Map.empty, + EnumName.Map.empty ) + in + let pctxmap, scopes_map, ctx_scopes = + ScopeName.Map.fold + (fun name info (pctxmap, scopes_map, ctx_scopes) -> + let info = + { + in_struct_name = StructName.Map.find info.in_struct_name structs_map; + out_struct_name = + StructName.Map.find info.out_struct_name structs_map; + out_struct_fields = + ScopeVar.Map.map + (fun fld -> StructField.Map.find fld fields_map) + info.out_struct_fields; + } + in + let path = ScopeName.path name in + if path = [] then + (* Scopes / topdefs in the root module will be renamed through the + variables binding them in the code_items *) + ( pctxmap, + ScopeName.Map.add name name scopes_map, + ScopeName.Map.add name info ctx_scopes ) + else + let str, pos = ScopeName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with Not_found -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = new_id ctx (f_var str) in + let new_name = ScopeName.fresh path (id, pos) in + ( PathMap.add path ctx pctxmap, + ScopeName.Map.add name new_name scopes_map, + ScopeName.Map.add new_name info ctx_scopes )) + p.decl_ctx.ctx_scopes + (pctxmap, ScopeName.Map.empty, ScopeName.Map.empty) + in + let pctxmap, topdefs_map, ctx_topdefs = + TopdefName.Map.fold + (fun name typ (pctxmap, topdefs_map, ctx_topdefs) -> + let path = TopdefName.path name in + if path = [] then + (* Topdefs / topdefs in the root module will be renamed through the + variables binding them in the code_items *) + ( pctxmap, + TopdefName.Map.add name name topdefs_map, + TopdefName.Map.add name typ ctx_topdefs ) + (* [typ] is rewritten later on *) + else + let str, pos = TopdefName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with Not_found -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = new_id ctx (f_var str) in + let new_name = TopdefName.fresh path (id, pos) in + ( PathMap.add path ctx pctxmap, + TopdefName.Map.add name new_name topdefs_map, + TopdefName.Map.add new_name typ ctx_topdefs )) + p.decl_ctx.ctx_topdefs + (pctxmap, TopdefName.Map.empty, TopdefName.Map.empty) + in + let ctx = PathMap.find [] pctxmap in + let ctx = + set_rewriters ctx + ~scopes:(fun n -> ScopeName.Map.find n scopes_map) + ~topdefs:(fun n -> TopdefName.Map.find n topdefs_map) + ~structs:(fun n -> StructName.Map.find n structs_map) + ~fields:(fun n -> StructField.Map.find n fields_map) + ~enums:(fun n -> EnumName.Map.find n enums_map) + ~constrs:(fun n -> EnumConstructor.Map.find n constrs_map) + in + let decl_ctx = + { p.decl_ctx with ctx_enums; ctx_structs; ctx_scopes; ctx_topdefs } + in + let decl_ctx = Program.map_decl_ctx ~f:(typ ctx) decl_ctx in + let code_items = code_items ctx p.code_items in + { p with decl_ctx; code_items }, ctx + +(* This first-class module wrapping is here to allow a polymorphic renaming + function to be passed around *) + +module type Renaming = sig + val apply : 'e program -> 'e program * context +end + +type t = (module Renaming) + +let apply (module R : Renaming) = R.apply + +let program + ~reserved + ~reset_context_for_closed_terms + ~skip_constant_binders + ~constant_binder_name + ~namespaced_fields_constrs + ?f_var + ?f_struct + ?f_field + ?f_enum + ?f_constr + () = + let module M = struct + let apply p = + program ~reserved ~reset_context_for_closed_terms ~skip_constant_binders + ~constant_binder_name ~namespaced_fields_constrs ?f_var ?f_struct + ?f_field ?f_enum ?f_constr p + end in + (module M : Renaming) diff --git a/compiler/shared_ast/renaming.mli b/compiler/shared_ast/renaming.mli new file mode 100644 index 00000000..966be016 --- /dev/null +++ b/compiler/shared_ast/renaming.mli @@ -0,0 +1,105 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2024 Inria, contributor: + Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +open Catala_utils +open Definitions + +type config = { + reserved : string list; (** Use for keywords and built-ins *) + sanitize_varname : string -> string; (** Typically String.to_snake_case *) + reset_context_for_closed_terms : bool; (** See [Bindlib.Renaming] *) + skip_constant_binders : bool; (** See [Bindlib.Renaming] *) + constant_binder_name : string option; (** See [Bindlib.Renaming] *) +} + +type context + +val get_ctx : config -> context + +val unbind_in : + context -> + ?fname:(string -> string) -> + ('e, 'b) Bindlib.binder -> + ('e, _) Mark.ed Var.t * 'b * context +(* [fname] applies a transformation on the variable name (typically something + like [String.to_snake_case]). The result is advisory and a numerical suffix + may be appended or modified *) + +val unmbind_in : + context -> + ?fname:(string -> string) -> + ('e, 'b) Bindlib.mbinder -> + ('e, _) Mark.ed Var.t Array.t * 'b * context + +val new_id : context -> string -> string * context + +val set_rewriters : + ?scopes:(ScopeName.t -> ScopeName.t) -> + ?topdefs:(TopdefName.t -> TopdefName.t) -> + ?structs:(StructName.t -> StructName.t) -> + ?fields:(StructField.t -> StructField.t) -> + ?enums:(EnumName.t -> EnumName.t) -> + ?constrs:(EnumConstructor.t -> EnumConstructor.t) -> + context -> + context + +val typ : context -> typ -> typ + +val expr : context -> ('a any, 'm) gexpr -> ('a, 'm) boxed_gexpr +(** Disambiguates all variable names in [e], and renames structs, fields, enums + and constrs according to the given context configuration *) + +val scope_name : context -> ScopeName.t -> ScopeName.t +val topdef_name : context -> TopdefName.t -> TopdefName.t +val struct_name : context -> StructName.t -> StructName.t +val enum_name : context -> EnumName.t -> EnumName.t + +val code_items : + context -> ((_ any, 'm) gexpr as 'e) code_item_list -> 'e code_item_list + +type t +(** Enclosing of a polymorphic renaming function, to be used by [apply] *) + +val apply : t -> 'e program -> 'e program * context + +val program : + reserved:string list -> + reset_context_for_closed_terms:bool -> + skip_constant_binders:bool -> + constant_binder_name:string option -> + namespaced_fields_constrs:bool -> + ?f_var:(string -> string) -> + ?f_struct:(string -> string) -> + ?f_field:(string -> string) -> + ?f_enum:(string -> string) -> + ?f_constr:(string -> string) -> + unit -> + t +(** Renames all idents (variables, types, struct and enum names, fields and + constructors) to dispel ambiguities in the target language. Names in + [reserved], typically keywords and built-ins, will be avoided ; the meaning + of the following three flags is described in [Bindlib.Renaming]. + + if [namespaced_fields_constrs] is true, then struct fields and enum + constructors can reuse names from other fields/constructors or other idents. + + The [f_*] optional arguments sanitize the different kinds of ids. The + default is what is used for OCaml: project to ASCII, capitalise structs, + enums (both modules in the backend) and constructors, lowercase fields, and + rewrite variables to snake case. + + In the returned program, it is safe to directly use `Bindlib.name_of` on + variables for printing. The same is true for `StructName.get_info` etc. *) diff --git a/compiler/shared_ast/scope.ml b/compiler/shared_ast/scope.ml index bf43c64d..d7514bc3 100644 --- a/compiler/shared_ast/scope.ml +++ b/compiler/shared_ast/scope.ml @@ -146,61 +146,3 @@ let free_vars scopes = ~init:(fun _vlist -> Var.Set.empty) ~f:(fun item v acc -> Var.Set.union (Var.Set.remove v acc) (free_vars_item item)) - -(** Maps carrying around a naming context, enriched at each [unbind] *) -let rec boundlist_map_ctx ~f ~fname ~last ~ctx = function - | Last l -> Bindlib.box_apply (fun l -> Last l) (last ctx l) - | Cons (item, next_bind) -> - let item = f ctx item in - let var, next, ctx = Expr.Renaming.unbind_in ctx ~fname next_bind in - let next = boundlist_map_ctx ~f ~fname ~last ~ctx next in - let next_bind = Bindlib.bind_var var next in - Bindlib.box_apply2 - (fun item next_bind -> Cons (item, next_bind)) - item next_bind - -let rename_vars_in_lets ctx scope_body_expr = - boundlist_map_ctx scope_body_expr ~ctx ~fname:String.to_snake_case - ~last:(fun ctx e -> Expr.Box.lift (Expr.Renaming.expr ctx e)) - ~f:(fun ctx scope_let -> - Bindlib.box_apply - (fun scope_let_expr -> - { - scope_let with - scope_let_expr; - scope_let_typ = Expr.Renaming.typ ctx scope_let.scope_let_typ; - }) - (Expr.Box.lift (Expr.Renaming.expr ctx scope_let.scope_let_expr))) - -let rename_ids ctx (scopes : 'e code_item_list) = - let f ctx = function - | ScopeDef (name, body) -> - let name = Expr.Renaming.scope_name ctx name in - let scope_input_var, scope_lets, ctx = - Expr.Renaming.unbind_in ctx ~fname:String.to_snake_case - body.scope_body_expr - in - let scope_lets = rename_vars_in_lets ctx scope_lets in - let scope_body_expr = Bindlib.bind_var scope_input_var scope_lets in - Bindlib.box_apply - (fun scope_body_expr -> - let body = - { - scope_body_input_struct = - Expr.Renaming.struct_name ctx body.scope_body_input_struct; - scope_body_output_struct = - Expr.Renaming.struct_name ctx body.scope_body_output_struct; - scope_body_expr; - } - in - ScopeDef (name, body)) - scope_body_expr - | Topdef (name, ty, expr) -> - Bindlib.box_apply - (fun e -> Topdef (name, Expr.Renaming.typ ctx ty, e)) - (Expr.Box.lift (Expr.Renaming.expr ctx expr)) - in - Bindlib.unbox - @@ boundlist_map_ctx ~ctx ~f ~fname:String.to_snake_case - ~last:(fun _ctx -> Bindlib.box) - scopes diff --git a/compiler/shared_ast/scope.mli b/compiler/shared_ast/scope.mli index d63b06d1..30d14699 100644 --- a/compiler/shared_ast/scope.mli +++ b/compiler/shared_ast/scope.mli @@ -77,11 +77,6 @@ val input_type : typ -> Runtime.io_input Mark.pos -> typ this doesn't take thunking into account (thunking is added during the scopelang->dcalc translation) *) -val rename_ids : - Expr.Renaming.context -> - ((_ any, 'm) gexpr as 'e) code_item_list -> - 'e code_item_list - (** {2 Analysis and tests} *) val free_vars_body_expr : 'e scope_body_expr -> 'e Var.Set.t diff --git a/compiler/shared_ast/shared_ast.ml b/compiler/shared_ast/shared_ast.ml index 739066cc..c854c71d 100644 --- a/compiler/shared_ast/shared_ast.ml +++ b/compiler/shared_ast/shared_ast.ml @@ -23,6 +23,7 @@ module Expr = Expr module BoundList = BoundList module Scope = Scope module Program = Program +module Renaming = Renaming module Print = Print module Typing = Typing module Interpreter = Interpreter From f565e84daef8f6a65d6619817d525ced8a4dfae7 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Wed, 7 Aug 2024 18:33:22 +0200 Subject: [PATCH 4/9] Restore special handling of the option enum in scalc backends it's not very satisfying, but we make it pass through renaming for now. A better approach could be to make this special handling structural, or to have something more specific than an id to hold onto. --- compiler/shared_ast/renaming.ml | 66 +++-- tests/backends/python_name_clash.catala_en | 102 +++---- .../good/toplevel_defs.catala_en | 277 +++++++++--------- 3 files changed, 218 insertions(+), 227 deletions(-) diff --git a/compiler/shared_ast/renaming.ml b/compiler/shared_ast/renaming.ml index a310fabc..ba3113e2 100644 --- a/compiler/shared_ast/renaming.ml +++ b/compiler/shared_ast/renaming.ml @@ -340,33 +340,45 @@ let program let pctxmap, enums_map, constrs_map, ctx_enums = EnumName.Map.fold (fun name constrs (pctxmap, enums_map, constrs_map, ctx_enums) -> - let path = EnumName.path name in - let str, pos = EnumName.get_info name in - let pctxmap, ctx = - try pctxmap, PathMap.find path pctxmap - with Not_found -> PathMap.add path ctx pctxmap, ctx - in - let id, ctx = new_id ctx (f_enum str) in - let new_name = EnumName.fresh path (id, pos) in - let ctx1, constrs_map, ctx_constrs = - EnumConstructor.Map.fold - (fun name ty (ctx, constrs_map, ctx_constrs) -> - let str, pos = EnumConstructor.get_info name in - let id, ctx = new_id ctx (f_constr str) in - let new_name = EnumConstructor.fresh (id, pos) in - ( ctx, - EnumConstructor.Map.add name new_name constrs_map, - EnumConstructor.Map.add new_name ty ctx_constrs )) - constrs - ( (if namespaced_fields_constrs then ctx0 else ctx), - constrs_map, - EnumConstructor.Map.empty ) - in - let ctx = if namespaced_fields_constrs then ctx else ctx1 in - ( PathMap.add path ctx pctxmap, - EnumName.Map.add name new_name enums_map, - constrs_map, - EnumName.Map.add new_name ctx_constrs ctx_enums )) + if EnumName.equal name Expr.option_enum then + (* The option type shouldn't be renamed, it has special handling in + backends. FIXME: could the fact that it's special be detected + differently from id comparison ? Structure maybe, or a more + specific construct ? *) + ( pctxmap, + EnumName.Map.add name name enums_map, + EnumConstructor.Map.fold + (fun c _ constrs_map -> EnumConstructor.Map.add c c constrs_map) + Expr.option_enum_config constrs_map, + ctx_enums ) + else + let path = EnumName.path name in + let str, pos = EnumName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with Not_found -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = new_id ctx (f_enum str) in + let new_name = EnumName.fresh path (id, pos) in + let ctx1, constrs_map, ctx_constrs = + EnumConstructor.Map.fold + (fun name ty (ctx, constrs_map, ctx_constrs) -> + let str, pos = EnumConstructor.get_info name in + let id, ctx = new_id ctx (f_constr str) in + let new_name = EnumConstructor.fresh (id, pos) in + ( ctx, + EnumConstructor.Map.add name new_name constrs_map, + EnumConstructor.Map.add new_name ty ctx_constrs )) + constrs + ( (if namespaced_fields_constrs then ctx0 else ctx), + constrs_map, + EnumConstructor.Map.empty ) + in + let ctx = if namespaced_fields_constrs then ctx else ctx1 in + ( PathMap.add path ctx pctxmap, + EnumName.Map.add name new_name enums_map, + constrs_map, + EnumName.Map.add new_name ctx_constrs ctx_enums )) p.decl_ctx.ctx_enums ( pctxmap, EnumName.Map.empty, diff --git a/tests/backends/python_name_clash.catala_en b/tests/backends/python_name_clash.catala_en index 62dab911..00efc21f 100644 --- a/tests/backends/python_name_clash.catala_en +++ b/tests/backends/python_name_clash.catala_en @@ -91,81 +91,75 @@ class BIn: def some_name(some_name_in:SomeNameIn): i = some_name_in.i_in - match_arg = handle_exceptions([], []) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + perhaps_none_arg = handle_exceptions([], []) + if perhaps_none_arg is None: if True: - o3 = Eoption(Eoption_Code.ESome, (i + integer_of_string("1"))) + o3 = (i + integer_of_string("1")) else: - o3 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - o3 = Eoption(Eoption_Code.ESome, x) - match_arg = handle_exceptions( - [SourcePosition( - filename="tests/backends/python_name_clash.catala_en", - start_line=10, start_column=23, - end_line=10, end_column=28, law_headings=[])], - [o3] - ) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + o3 = None + else: + x = perhaps_none_arg + o3 = x + perhaps_none_arg = handle_exceptions( + [SourcePosition( + filename="tests/backends/python_name_clash.catala_en", + start_line=10, start_column=23, + end_line=10, end_column=28, law_headings=[])], + [o3] + ) + if perhaps_none_arg is None: if False: - o2 = Eoption(Eoption_Code.ENone, Unit()) + o2 = None else: - o2 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - o2 = Eoption(Eoption_Code.ESome, x) - match_arg = o2 - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + o2 = None + else: + x = perhaps_none_arg + o2 = x + perhaps_none_arg = o2 + if perhaps_none_arg is None: raise NoValue(SourcePosition( filename="tests/backends/python_name_clash.catala_en", start_line=7, start_column=10, end_line=7, end_column=11, law_headings=[])) - elif match_arg.code == Eoption_Code.ESome: - arg = match_arg.value + else: + arg = perhaps_none_arg o1 = arg o = o1 return SomeName(o = o) def b(b_in:BIn): - match_arg = handle_exceptions([], []) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + perhaps_none_arg = handle_exceptions([], []) + if perhaps_none_arg is None: if True: - result3 = Eoption(Eoption_Code.ESome, integer_of_string("1")) + result3 = integer_of_string("1") else: - result3 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - result3 = Eoption(Eoption_Code.ESome, x) - match_arg = handle_exceptions( - [SourcePosition( - filename="tests/backends/python_name_clash.catala_en", - start_line=16, start_column=33, - end_line=16, end_column=34, law_headings=[])], - [result3] - ) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + result3 = None + else: + x = perhaps_none_arg + result3 = x + perhaps_none_arg = handle_exceptions( + [SourcePosition( + filename="tests/backends/python_name_clash.catala_en", + start_line=16, start_column=33, + end_line=16, end_column=34, law_headings=[])], + [result3] + ) + if perhaps_none_arg is None: if False: - result2 = Eoption(Eoption_Code.ENone, Unit()) + result2 = None else: - result2 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - result2 = Eoption(Eoption_Code.ESome, x) - match_arg = result2 - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + result2 = None + else: + x = perhaps_none_arg + result2 = x + perhaps_none_arg = result2 + if perhaps_none_arg is None: raise NoValue(SourcePosition( filename="tests/backends/python_name_clash.catala_en", start_line=16, start_column=14, end_line=16, end_column=25, law_headings=[])) - elif match_arg.code == Eoption_Code.ESome: - arg = match_arg.value + else: + arg = perhaps_none_arg result1 = arg result = some_name(SomeNameIn(i_in = result1)) result1 = SomeName(o = result.o) diff --git a/tests/name_resolution/good/toplevel_defs.catala_en b/tests/name_resolution/good/toplevel_defs.catala_en index d0ad73e2..16091bff 100644 --- a/tests/name_resolution/good/toplevel_defs.catala_en +++ b/tests/name_resolution/good/toplevel_defs.catala_en @@ -446,215 +446,200 @@ glob6 = ( ) def s2(s2_in:S2In): - match_arg = handle_exceptions([], []) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + perhaps_none_arg = handle_exceptions([], []) + if perhaps_none_arg is None: if True: - a3 = Eoption(Eoption_Code.ESome, - (glob3(money_of_cents_string("4400")) + - decimal_of_string("100."))) + a3 = (glob3(money_of_cents_string("4400")) + + decimal_of_string("100.")) else: - a3 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - a3 = Eoption(Eoption_Code.ESome, x) - match_arg = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=53, start_column=24, - end_line=53, end_column=43, - law_headings=["Test toplevel function defs"])], - [a3] - ) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + a3 = None + else: + x = perhaps_none_arg + a3 = x + perhaps_none_arg = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=53, start_column=24, + end_line=53, end_column=43, + law_headings=["Test toplevel function defs"])], + [a3] + ) + if perhaps_none_arg is None: if False: - a2 = Eoption(Eoption_Code.ENone, Unit()) + a2 = None else: - a2 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - a2 = Eoption(Eoption_Code.ESome, x) - match_arg = a2 - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + a2 = None + else: + x = perhaps_none_arg + a2 = x + perhaps_none_arg = a2 + if perhaps_none_arg is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=50, start_column=10, end_line=50, end_column=11, law_headings=["Test toplevel function defs"])) - elif match_arg.code == Eoption_Code.ESome: - arg = match_arg.value + else: + arg = perhaps_none_arg a1 = arg a = a1 return S2(a = a) def s3(s3_in:S3In): - match_arg = handle_exceptions([], []) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + perhaps_none_arg = handle_exceptions([], []) + if perhaps_none_arg is None: if True: - a3 = Eoption(Eoption_Code.ESome, - (decimal_of_string("50.") + + a3 = (decimal_of_string("50.") + glob4(money_of_cents_string("4400"), - decimal_of_string("55.")))) + decimal_of_string("55."))) else: - a3 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - a3 = Eoption(Eoption_Code.ESome, x) - match_arg = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=74, start_column=24, - end_line=74, end_column=47, - law_headings=["Test function def with two args"])], - [a3] - ) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + a3 = None + else: + x = perhaps_none_arg + a3 = x + perhaps_none_arg = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=74, start_column=24, + end_line=74, end_column=47, + law_headings=["Test function def with two args"] + )], + [a3] + ) + if perhaps_none_arg is None: if False: - a2 = Eoption(Eoption_Code.ENone, Unit()) + a2 = None else: - a2 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - a2 = Eoption(Eoption_Code.ESome, x) - match_arg = a2 - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + a2 = None + else: + x = perhaps_none_arg + a2 = x + perhaps_none_arg = a2 + if perhaps_none_arg is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=71, start_column=10, end_line=71, end_column=11, law_headings=["Test function def with two args"])) - elif match_arg.code == Eoption_Code.ESome: - arg = match_arg.value + else: + arg = perhaps_none_arg a1 = arg a = a1 return S3(a = a) def s4(s4_in:S4In): - match_arg = handle_exceptions([], []) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + perhaps_none_arg = handle_exceptions([], []) + if perhaps_none_arg is None: if True: - a3 = Eoption(Eoption_Code.ESome, - (glob5 + - decimal_of_string("1."))) + a3 = (glob5 + decimal_of_string("1.")) else: - a3 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - a3 = Eoption(Eoption_Code.ESome, x) - match_arg = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=98, start_column=24, - end_line=98, end_column=34, - law_headings=["Test inline defs in toplevel defs"])], - [a3] - ) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + a3 = None + else: + x = perhaps_none_arg + a3 = x + perhaps_none_arg = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=98, start_column=24, + end_line=98, end_column=34, + law_headings=["Test inline defs in toplevel defs"] + )], + [a3] + ) + if perhaps_none_arg is None: if False: - a2 = Eoption(Eoption_Code.ENone, Unit()) + a2 = None else: - a2 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - a2 = Eoption(Eoption_Code.ESome, x) - match_arg = a2 - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + a2 = None + else: + x = perhaps_none_arg + a2 = x + perhaps_none_arg = a2 + if perhaps_none_arg is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=95, start_column=10, end_line=95, end_column=11, law_headings=["Test inline defs in toplevel defs"])) - elif match_arg.code == Eoption_Code.ESome: - arg = match_arg.value + else: + arg = perhaps_none_arg a1 = arg a = a1 return S4(a = a) def s5(s_in:SIn): - match_arg = handle_exceptions([], []) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + perhaps_none_arg = handle_exceptions([], []) + if perhaps_none_arg is None: if True: - a3 = Eoption(Eoption_Code.ESome, (glob1 * glob1)) + a3 = (glob1 * glob1) else: - a3 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - a3 = Eoption(Eoption_Code.ESome, x) - match_arg = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=18, start_column=24, - end_line=18, end_column=37, - law_headings=["Test basic toplevel values defs"])], - [a3] - ) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + a3 = None + else: + x = perhaps_none_arg + a3 = x + perhaps_none_arg = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=18, start_column=24, + end_line=18, end_column=37, + law_headings=["Test basic toplevel values defs"] + )], + [a3] + ) + if perhaps_none_arg is None: if False: - a2 = Eoption(Eoption_Code.ENone, Unit()) + a2 = None else: - a2 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - a2 = Eoption(Eoption_Code.ESome, x) - match_arg = a2 - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + a2 = None + else: + x = perhaps_none_arg + a2 = x + perhaps_none_arg = a2 + if perhaps_none_arg is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=7, start_column=10, end_line=7, end_column=11, law_headings=["Test basic toplevel values defs"])) - elif match_arg.code == Eoption_Code.ESome: - arg = match_arg.value + else: + arg = perhaps_none_arg a1 = arg a = a1 - match_arg = handle_exceptions([], []) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + perhaps_none_arg = handle_exceptions([], []) + if perhaps_none_arg is None: if True: - b3 = Eoption(Eoption_Code.ESome, glob6) + b3 = glob6 else: - b3 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - b3 = Eoption(Eoption_Code.ESome, x) - match_arg = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=19, start_column=24, - end_line=19, end_column=29, - law_headings=["Test basic toplevel values defs"])], - [b3] - ) - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + b3 = None + else: + x = perhaps_none_arg + b3 = x + perhaps_none_arg = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=19, start_column=24, + end_line=19, end_column=29, + law_headings=["Test basic toplevel values defs"] + )], + [b3] + ) + if perhaps_none_arg is None: if False: - b2 = Eoption(Eoption_Code.ENone, Unit()) + b2 = None else: - b2 = Eoption(Eoption_Code.ENone, Unit()) - elif match_arg.code == Eoption_Code.ESome: - x = match_arg.value - b2 = Eoption(Eoption_Code.ESome, x) - match_arg = b2 - if match_arg.code == Eoption_Code.ENone: - _ = match_arg.value + b2 = None + else: + x = perhaps_none_arg + b2 = x + perhaps_none_arg = b2 + if perhaps_none_arg is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=8, start_column=10, end_line=8, end_column=11, law_headings=["Test basic toplevel values defs"])) - elif match_arg.code == Eoption_Code.ESome: - arg = match_arg.value + else: + arg = perhaps_none_arg b1 = arg b = b1 return S(a = a, b = b) From 14a378a33d57b5a8ab1889b065f471b9ebf0d421 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Thu, 8 Aug 2024 12:03:53 +0200 Subject: [PATCH 5/9] Translation to scalc: fix renaming in blocks Statements are often flattened, in which case their idents need to be conflict-free. We pass along the renaming context to handle this. --- compiler/scalc/from_lcalc.ml | 270 +++++++++++---------- compiler/scalc/to_python.ml | 3 +- compiler/shared_ast/program.ml | 1 - tests/backends/python_name_clash.catala_en | 6 +- 4 files changed, 146 insertions(+), 134 deletions(-) diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index 997e8321..6cf25081 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -98,16 +98,16 @@ let register_fresh_arg ~pos ctxt (x, _) = ctxt let rec translate_expr_list ctxt args = - let stmts, args = + let stmts, args, ren_ctx = List.fold_left - (fun (args_stmts, new_args) arg -> - let arg_stmts, new_arg = translate_expr ctxt arg in - args_stmts ++ arg_stmts, new_arg :: new_args) - (RevBlock.empty, []) args + (fun (args_stmts, new_args, ren_ctx) arg -> + let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in + args_stmts ++ arg_stmts, new_arg :: new_args, ren_ctx) + (RevBlock.empty, [], ctxt.ren_ctx) args in - stmts, List.rev args + stmts, List.rev args, ren_ctx -and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = +and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr * Renaming.context = try match Mark.remove expr with | EVar v -> @@ -123,27 +123,27 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = Print.var_debug ppf v)) (Var.Map.keys ctxt.var_dict)) in - RevBlock.empty, (local_var, Expr.pos expr) + RevBlock.empty, (local_var, Expr.pos expr), ctxt.ren_ctx | EStruct { fields; name } -> if ctxt.config.no_struct_literals then (* In C89, struct literates have to be initialized at variable definition... *) raise (NotAnExpr { needs_a_local_decl = false }); - let args_stmts, new_args = + let args_stmts, new_args, ren_ctx = StructField.Map.fold - (fun field arg (args_stmts, new_args) -> - let arg_stmts, new_arg = translate_expr ctxt arg in - args_stmts ++ arg_stmts, StructField.Map.add field new_arg new_args) + (fun field arg (args_stmts, new_args, ren_ctx) -> + let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in + args_stmts ++ arg_stmts, StructField.Map.add field new_arg new_args, ren_ctx) fields - (RevBlock.empty, StructField.Map.empty) + (RevBlock.empty, StructField.Map.empty, ctxt.ren_ctx) in - args_stmts, (A.EStruct { fields = new_args; name }, Expr.pos expr) + args_stmts, (A.EStruct { fields = new_args; name }, Expr.pos expr), ren_ctx | EInj { e = e1; cons; name } -> if ctxt.config.no_struct_literals then (* In C89, struct literates have to be initialized at variable definition... *) raise (NotAnExpr { needs_a_local_decl = false }); - let e1_stmts, new_e1 = translate_expr ctxt e1 in + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in ( e1_stmts, ( A.EInj { @@ -152,21 +152,23 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = name; expr_typ = Expr.maybe_ty (Mark.get expr); }, - Expr.pos expr ) ) + Expr.pos expr ), + ren_ctx ) | ETuple args -> - let args_stmts, new_args = translate_expr_list ctxt args in - args_stmts, (A.ETuple new_args, Expr.pos expr) + let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in + args_stmts, (A.ETuple new_args, Expr.pos expr), ren_ctx | EStructAccess { e = e1; field; name } -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in ( e1_stmts, - (A.EStructFieldAccess { e1 = new_e1; field; name }, Expr.pos expr) ) + (A.EStructFieldAccess { e1 = new_e1; field; name }, Expr.pos expr), + ren_ctx) | ETupleAccess { e = e1; index; _ } -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in - e1_stmts, (A.ETupleAccess { e1 = new_e1; index }, Expr.pos expr) + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in + e1_stmts, (A.ETupleAccess { e1 = new_e1; index }, Expr.pos expr), ren_ctx | EAppOp { op; args; tys = _ } -> - let args_stmts, new_args = translate_expr_list ctxt args in + let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in (* FIXME: what happens if [arg] is not a tuple but reduces to one ? *) - args_stmts, (A.EAppOp { op; args = new_args }, Expr.pos expr) + args_stmts, (A.EAppOp { op; args = new_args }, Expr.pos expr), ren_ctx | EApp { f = EAbs { binder; tys }, binder_mark; args; tys = _ } -> (* This defines multiple local variables at the time *) let binder_pos = Expr.mark_pos binder_mark in @@ -190,39 +192,42 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = (Var.Map.find x ctxt.var_dict, binder_pos), tau, arg) vars_tau args in - let def_blocks = + let def_blocks, ren_ctx = List.fold_left - (fun acc (x, _tau, arg) -> + (fun (rblock, ren_ctx) (x, _tau, arg) -> let ctxt = { ctxt with inside_definition_of = Some (Mark.remove x); context_name = Mark.remove (A.VarName.get_info (Mark.remove x)); + ren_ctx; } in - let arg_stmts, new_arg = translate_expr ctxt arg in - RevBlock.append (acc ++ arg_stmts) + let arg_stmts, new_arg, ren_ctx = translate_expr ctxt arg in + RevBlock.append (rblock ++ arg_stmts) ( A.SLocalDef { name = x; expr = new_arg; typ = Expr.maybe_ty (Mark.get arg); }, - binder_pos )) - RevBlock.empty vars_args + binder_pos ), + ren_ctx) + (RevBlock.empty, ctxt.ren_ctx) vars_args in - let rest_of_expr_stmts, rest_of_expr = translate_expr ctxt body in - local_decls ++ def_blocks ++ rest_of_expr_stmts, rest_of_expr + let rest_of_expr_stmts, rest_of_expr, ren_ctx = translate_expr { ctxt with ren_ctx } body in + local_decls ++ def_blocks ++ rest_of_expr_stmts, rest_of_expr, ren_ctx | EApp { f; args; tys = _ } -> - let f_stmts, new_f = translate_expr ctxt f in - let args_stmts, new_args = translate_expr_list ctxt args in + let f_stmts, new_f, ren_ctx = translate_expr ctxt f in + let args_stmts, new_args, ren_ctx = translate_expr_list { ctxt with ren_ctx } args in (* FIXME: what happens if [arg] is not a tuple but reduces to one ? *) ( f_stmts ++ args_stmts, - (A.EApp { f = new_f; args = new_args }, Expr.pos expr) ) + (A.EApp { f = new_f; args = new_args }, Expr.pos expr), + ren_ctx ) | EArray args -> - let args_stmts, new_args = translate_expr_list ctxt args in - args_stmts, (A.EArray new_args, Expr.pos expr) - | ELit l -> RevBlock.empty, (A.ELit l, Expr.pos expr) + let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in + args_stmts, (A.EArray new_args, Expr.pos expr), ren_ctx + | ELit l -> RevBlock.empty, (A.ELit l, Expr.pos expr), ctxt.ren_ctx | EExternal { name } -> let path, name = match Mark.remove name with @@ -233,7 +238,7 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = ( ModuleName.Map.find (List.hd (List.rev path)) ctxt.program_ctx.modules, Expr.pos expr ) in - RevBlock.empty, (EExternal { modname; name }, Expr.pos expr) + RevBlock.empty, (EExternal { modname; name }, Expr.pos expr), ctxt.ren_ctx | EAbs _ | EIfThenElse _ | EMatch _ | EAssert _ | EFatalError _ -> raise (NotAnExpr { needs_a_local_decl = true }) | _ -> . @@ -253,7 +258,7 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = context_name = Mark.remove (A.VarName.get_info tmp_var); } in - let tmp_stmts = translate_statements ctxt expr in + let tmp_stmts, ren_ctx = translate_statements ctxt expr in ( (if needs_a_local_decl then RevBlock.make (( A.SLocalDecl @@ -264,17 +269,19 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = Expr.pos expr ) :: tmp_stmts) else RevBlock.make tmp_stmts), - (A.EVar tmp_var, Expr.pos expr) ) + (A.EVar tmp_var, Expr.pos expr), + ren_ctx ) -and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = +and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * Renaming.context = match Mark.remove block_expr with | EAssert e -> (* Assertions are always encapsulated in a unit-typed let binding *) - let e_stmts, new_e = translate_expr ctxt e in + let e_stmts, new_e, ren_ctx = translate_expr ctxt e in RevBlock.rebuild ~tail:[A.SAssert (Mark.remove new_e), Expr.pos block_expr] - e_stmts - | EFatalError err -> [SFatalError err, Expr.pos block_expr] + e_stmts, + ren_ctx + | EFatalError err -> [SFatalError err, Expr.pos block_expr], ctxt.ren_ctx (* | EAppOp * { op = Op.HandleDefaultOpt, _; tys = _; args = [exceptions; just; cons] } * when ctxt.config.keep_special_ops -> @@ -351,32 +358,31 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = (Var.Map.find x ctxt.var_dict, binder_pos), tau, arg) vars_tau args in - let def_blocks = - List.map - (fun (x, _tau, arg) -> + let def_blocks, ren_ctx = + List.fold_left + (fun (def_blocks, ren_ctx) (x, _tau, arg) -> let ctxt = { ctxt with inside_definition_of = Some (Mark.remove x); context_name = Mark.remove (A.VarName.get_info (Mark.remove x)); + ren_ctx; } in - let arg_stmts, new_arg = translate_expr ctxt arg in - RevBlock.rebuild arg_stmts - ~tail: - [ + let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in + RevBlock.append (def_blocks ++ arg_stmts) ( A.SLocalDef { name = x; expr = new_arg; typ = Expr.maybe_ty (Mark.get arg); }, - binder_pos ); - ]) - vars_args + binder_pos ), + ren_ctx) + (RevBlock.empty, ctxt.ren_ctx) vars_args in - let rest_of_block = translate_statements ctxt body in - local_decls @ List.flatten def_blocks @ rest_of_block + let rest_of_block, ren_ctx = translate_statements { ctxt with ren_ctx } body in + local_decls @ RevBlock.rebuild def_blocks ~tail:rest_of_block, ren_ctx | EAbs { binder; tys } -> let closure_name, ctxt = match ctxt.inside_definition_of with @@ -392,7 +398,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = { ctxt with inside_definition_of = None } vars_tau in - let new_body = translate_statements ctxt body in + let new_body, _ren_ctx = translate_statements ctxt body in [ ( A.SInnerFuncDef { @@ -413,9 +419,9 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = }; }, binder_pos ); - ] + ], ctxt.ren_ctx | EMatch { e = e1; cases; name } -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in let new_cases = EnumConstructor.Map.fold (fun _ arg new_args -> @@ -427,7 +433,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = let scalc_var, ctxt = register_fresh_var ctxt var ~pos:(Expr.pos arg) in - let new_arg = translate_statements ctxt body in + let new_arg, _ren_ctx = translate_statements ctxt body in { A.case_block = new_arg; payload_var_name = scalc_var; @@ -449,11 +455,12 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = switch_cases = new_args; }, Expr.pos block_expr ); - ] + ], + ren_ctx | EIfThenElse { cond; etrue; efalse } -> - let cond_stmts, s_cond = translate_expr ctxt cond in - let s_e_true = translate_statements ctxt etrue in - let s_e_false = translate_statements ctxt efalse in + let cond_stmts, s_cond, ren_ctx = translate_expr ctxt cond in + let s_e_true, _ = translate_statements ctxt etrue in + let s_e_false, _ = translate_statements ctxt efalse in RevBlock.rebuild cond_stmts ~tail: [ @@ -464,14 +471,14 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = else_block = s_e_false; }, Expr.pos block_expr ); - ] + ], + ren_ctx | EInj { e = e1; cons; name } when ctxt.config.no_struct_literals -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in let tmp_struct_var_name = match ctxt.inside_definition_of with - | None -> - failwith "should not happen" - (* [translate_expr] should create this [inside_definition_of]*) + | None -> assert false + (* [translate_expr] should create this [inside_definition_of]*) | Some x -> x, Expr.pos block_expr in let inj_expr = @@ -496,15 +503,16 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = Expr.pos block_expr ); }, Expr.pos block_expr ); - ] + ], + ren_ctx | EStruct { fields; name } when ctxt.config.no_struct_literals -> - let args_stmts, new_args = + let args_stmts, new_args, ren_ctx = StructField.Map.fold - (fun field arg (args_stmts, new_args) -> - let arg_stmts, new_arg = translate_expr ctxt arg in - args_stmts ++ arg_stmts, StructField.Map.add field new_arg new_args) + (fun field arg (args_stmts, new_args, ren_ctx) -> + let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in + args_stmts ++ arg_stmts, StructField.Map.add field new_arg new_args, ren_ctx) fields - (RevBlock.empty, StructField.Map.empty) + (RevBlock.empty, StructField.Map.empty, ctxt.ren_ctx) in let struct_expr = A.EStruct { fields = new_args; name }, Expr.pos block_expr @@ -526,10 +534,11 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = typ = TStruct name, Expr.pos block_expr; }, Expr.pos block_expr ); - ] + ], + ren_ctx | ELit _ | EAppOp _ | EArray _ | EVar _ | EStruct _ | EInj _ | ETuple _ | ETupleAccess _ | EStructAccess _ | EExternal _ | EApp _ -> - let e_stmts, new_e = translate_expr ctxt block_expr in + let e_stmts, new_e, ren_ctx = translate_expr ctxt block_expr in let tail = match (e_stmts :> (A.stmt * Pos.t) list) with | (A.SRaiseEmpty, _) :: _ -> @@ -551,7 +560,8 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = Expr.pos block_expr ); ] in - RevBlock.rebuild e_stmts ~tail + RevBlock.rebuild e_stmts ~tail, + ren_ctx | _ -> . let rec translate_scope_body_expr ctx (scope_expr : 'm L.expr scope_body_expr) : @@ -559,59 +569,50 @@ let rec translate_scope_body_expr ctx (scope_expr : 'm L.expr scope_body_expr) : let ctx = { ctx with inside_definition_of = None } in match scope_expr with | Last e -> - let block, new_e = translate_expr ctx e in + let block, new_e, _ren_ctx = translate_expr ctx e in RevBlock.rebuild block ~tail:[A.SReturn (Mark.remove new_e), Mark.get new_e] - | Cons (scope_let, next_bnd) -> ( - let let_var, scope_let_next, ctx1 = unbind ctx next_bnd in + | Cons (scope_let, next_bnd) -> + let let_var, scope_let_next, ctx = unbind ctx next_bnd in let let_var_id, ctx = - register_fresh_var ctx1 let_var ~pos:scope_let.scope_let_pos + register_fresh_var ctx let_var ~pos:scope_let.scope_let_pos in - let next = translate_scope_body_expr ctx scope_let_next in - match scope_let.scope_let_kind with - | Assertion -> - translate_statements - { ctx with inside_definition_of = Some let_var_id } - scope_let.scope_let_expr - @ next - | _ -> - let let_expr_stmts, new_let_expr = - translate_expr - { ctx with inside_definition_of = Some let_var_id } - scope_let.scope_let_expr - in - RevBlock.rebuild let_expr_stmts - ~tail: - (( A.SLocalDecl - { - name = let_var_id, scope_let.scope_let_pos; - typ = scope_let.scope_let_typ; - }, - scope_let.scope_let_pos ) - :: ( A.SLocalDef - { - name = let_var_id, scope_let.scope_let_pos; - expr = new_let_expr; - typ = scope_let.scope_let_typ; - }, - scope_let.scope_let_pos ) - :: next)) + let statements, ren_ctx = + match scope_let.scope_let_kind with + | Assertion -> + let stmts, ren_ctx = + translate_statements + { ctx with inside_definition_of = Some let_var_id } + scope_let.scope_let_expr + in + RevBlock.make stmts, ren_ctx + | _ -> + let let_expr_stmts, new_let_expr, ren_ctx = + translate_expr + { ctx with inside_definition_of = Some let_var_id } + scope_let.scope_let_expr + in + let (+>) = RevBlock.append in + let_expr_stmts +> + ( A.SLocalDecl + { + name = let_var_id, scope_let.scope_let_pos; + typ = scope_let.scope_let_typ; + }, + scope_let.scope_let_pos ) +> + ( A.SLocalDef + { + name = let_var_id, scope_let.scope_let_pos; + expr = new_let_expr; + typ = scope_let.scope_let_typ; + }, + scope_let.scope_let_pos ), + ren_ctx + in + let tail = translate_scope_body_expr { ctx with ren_ctx } scope_let_next in + RevBlock.rebuild statements ~tail let translate_program ~(config : translation_config) (p : 'm L.program) : A.program = - let modules = - List.fold_left - (fun acc (m, _) -> - let vname = Mark.map (( ^ ) "Module_") (ModuleName.get_info m) in - (* The "Module_" prefix is a workaround name clashes for same-name - structs and modules, Python in particular mixes everything in one - namespaces. It can be removed once we have full clash-free variable - renaming in the Python backend (requiring all idents to go through - one stage of being bindlib vars) *) - ModuleName.Map.add m (A.VarName.fresh vname) acc) - ModuleName.Map.empty - (Program.modules_to_list p.decl_ctx.ctx_modules) - in - let program_ctx = { A.decl_ctx = p.decl_ctx; A.modules } in let ctxt = { func_dict = Var.Map.empty; @@ -619,10 +620,21 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : inside_definition_of = None; context_name = ""; config; - program_ctx; + program_ctx = { A.decl_ctx = p.decl_ctx; modules = ModuleName.Map.empty}; ren_ctx = config.renaming_context; } in + let modules, ctxt = + List.fold_left + (fun (modules, ctxt) (m, _) -> + let name, pos = ModuleName.get_info m in + let vname, ctxt = get_name ctxt name in + ModuleName.Map.add m (A.VarName.fresh (vname, pos)) modules, ctxt) + (ModuleName.Map.empty, ctxt) + (Program.modules_to_list p.decl_ctx.ctx_modules) + in + let program_ctx = { ctxt.program_ctx with A.modules } in + let ctxt = { ctxt with program_ctx } in let (_, rev_items), _vlist = BoundList.fold_left ~init:(ctxt, []) ~f:(fun (ctxt, rev_items) code_item var -> @@ -661,7 +673,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : :: rev_items ) | Topdef (name, topdef_ty, (EAbs abs, m)) -> (* Toplevel function def *) - let (block, expr), args_id = + let (block, expr, _ren_ctx), args_id = let args_a, expr, ctxt = unmbind ctxt abs.binder in let args = Array.to_list args_a in let rargs_id, ctxt = @@ -705,7 +717,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : :: rev_items ) | Topdef (name, topdef_ty, expr) -> (* Toplevel constant def *) - let block, expr = + let block, expr, _ren_ctx = let ctxt = { ctxt with diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 81e980f6..0769cc94 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -160,6 +160,7 @@ let renaming = ~reset_context_for_closed_terms:false ~skip_constant_binders:false ~constant_binder_name:None ~namespaced_fields_constrs:true ~f_struct:String.to_camel_case + ~f_enum:String.to_camel_case let typ_needs_parens (e : typ) : bool = match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false @@ -413,7 +414,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit let pos = Mark.get s in Format.fprintf fmt "@[if not (%a):@,\ - raise AssertionFailure(@[SourcePosition(@[filename=\"%s\",@ \ + raise AssertionFailed(@[SourcePosition(@[filename=\"%s\",@ \ start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \ law_headings=@[%a@])@])@]@]" (format_expression ctx) diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index 7d92f909..f292cc61 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -15,7 +15,6 @@ License for the specific language governing permissions and limitations under the License. *) -open Catala_utils open Definitions let map_decl_ctx ~f ctx = diff --git a/tests/backends/python_name_clash.catala_en b/tests/backends/python_name_clash.catala_en index 00efc21f..e72ef01d 100644 --- a/tests/backends/python_name_clash.catala_en +++ b/tests/backends/python_name_clash.catala_en @@ -162,11 +162,11 @@ def b(b_in:BIn): arg = perhaps_none_arg result1 = arg result = some_name(SomeNameIn(i_in = result1)) - result1 = SomeName(o = result.o) + result4 = SomeName(o = result.o) if True: - some_name2 = result1 + some_name2 = result4 else: - some_name2 = result1 + some_name2 = result4 some_name1 = some_name2 return B(some_name = some_name1) ``` From e9abbf9bd8be12b085c61bc6586f60e424151cb2 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Thu, 8 Aug 2024 15:06:03 +0200 Subject: [PATCH 6/9] Scalc change: switch only on variables matches can bind, but switches cannot, so we can assume the switch argument should always be bound to a name ; this allow the intermediate variable to be better renamed. --- compiler/scalc/ast.ml | 4 +- compiler/scalc/from_lcalc.ml | 25 ++- compiler/scalc/print.ml | 6 +- compiler/scalc/to_c.ml | 9 +- compiler/scalc/to_python.ml | 19 +- tests/backends/output/simple.c | 34 ++-- tests/backends/python_name_clash.catala_en | 64 +++--- .../good/toplevel_defs.catala_en | 184 +++++++++--------- 8 files changed, 169 insertions(+), 176 deletions(-) diff --git a/compiler/scalc/ast.ml b/compiler/scalc/ast.ml index b5908a2e..3ee7cf04 100644 --- a/compiler/scalc/ast.ml +++ b/compiler/scalc/ast.ml @@ -70,8 +70,8 @@ type stmt = | SFatalError of Runtime.error | SIfThenElse of { if_expr : expr; then_block : block; else_block : block } | SSwitch of { - switch_expr : expr; - switch_expr_typ : typ; + switch_var : VarName.t; + switch_var_typ : typ; enum_name : EnumName.t; switch_cases : switch_case list; } diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index 6cf25081..300991e7 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -421,7 +421,23 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R binder_pos ); ], ctxt.ren_ctx | EMatch { e = e1; cases; name } -> + let typ = Expr.maybe_ty (Mark.get e1) in let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in + let ctxt = { ctxt with ren_ctx } in + let e1_stmts, switch_var, ctxt = + match new_e1 with + | A.EVar v, _ -> e1_stmts, v, ctxt + | _ -> + let v, ctxt = fresh_var ctxt ctxt.context_name ~pos:(Expr.pos e1) in + RevBlock.append e1_stmts + ( A.SLocalInit + { name = v, Expr.pos e1; + expr = new_e1; + typ }, + Expr.pos e1 ), + v, + ctxt + in let new_cases = EnumConstructor.Map.fold (fun _ arg new_args -> @@ -443,20 +459,19 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R | _ -> assert false) cases [] in - let new_args = List.rev new_cases in RevBlock.rebuild e1_stmts ~tail: [ ( A.SSwitch { - switch_expr = new_e1; - switch_expr_typ = Expr.maybe_ty (Mark.get e1); + switch_var; + switch_var_typ = typ; enum_name = name; - switch_cases = new_args; + switch_cases = List.rev new_cases; }, Expr.pos block_expr ); ], - ren_ctx + ctxt.ren_ctx | EIfThenElse { cond; etrue; efalse } -> let cond_stmts, s_cond, ren_ctx = translate_expr ctxt cond in let s_e_true, _ = translate_statements ctxt etrue in diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index 03bf5fa2..d7086ce4 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -169,12 +169,12 @@ let rec format_statement Format.fprintf fmt "@[%a %a@]" Print.keyword "assert" (format_expr decl_ctx ~debug) (naked_expr, Mark.get stmt) - | SSwitch { switch_expr = e_switch; enum_name = enum; switch_cases = arms; _ } + | SSwitch { switch_var = v_switch; enum_name = enum; switch_cases = arms; _ } -> let cons = EnumName.Map.find enum decl_ctx.ctx_enums in Format.fprintf fmt "@[%a @[%a@]%a@,@]%a" Print.keyword "switch" - (format_expr decl_ctx ~debug) - e_switch Print.punctuation ":" + format_var_name v_switch + Print.punctuation ":" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt ((case, _), switch_case_data) -> diff --git a/compiler/scalc/to_c.ml b/compiler/scalc/to_c.ml index 784b0b97..9da9e8be 100644 --- a/compiler/scalc/to_c.ml +++ b/compiler/scalc/to_c.ml @@ -389,18 +389,15 @@ let rec format_statement Format.fprintf fmt "@[@[if (%a) {@]@,%a@,@;<1 -2>} else {@,%a@,@;<1 -2>}@]" (format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2 - | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } -> + | SSwitch { switch_var; enum_name = e_name; switch_cases = cases; _ } -> let cases = List.map2 (fun x (cons, _) -> x, cons) cases (EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums)) in - let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in - Format.fprintf fmt "@[%a %a = %a;@]@," EnumName.format e_name - VarName.format tmp_var (format_expression ctx) e1; Format.pp_open_vbox fmt 2; - Format.fprintf fmt "@[switch (%a.code) {@]@," VarName.format tmp_var; + Format.fprintf fmt "@[switch (%a.code) {@]@," VarName.format switch_var; Format.pp_print_list (fun fmt ({ case_block; payload_var_name; payload_var_typ }, cons_name) -> Format.fprintf fmt "@[case %a_%a:@ " EnumName.format e_name @@ -408,7 +405,7 @@ let rec format_statement if not (Type.equal payload_var_typ (TLit TUnit, Pos.no_pos)) then Format.fprintf fmt "%a = %a.payload.%a;@ " (format_typ ctx (fun fmt -> VarName.format fmt payload_var_name)) - payload_var_typ VarName.format tmp_var EnumConstructor.format + payload_var_typ VarName.format switch_var EnumConstructor.format cons_name; Format.fprintf fmt "%a@ break;@]" (format_block ctx) case_block) fmt cases; diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 0769cc94..8abd1f79 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -370,7 +370,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit (format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2 | SSwitch { - switch_expr = e1; + switch_var; enum_name = e_name; switch_cases = [ @@ -381,14 +381,11 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit } when EnumName.equal e_name Expr.option_enum -> (* We translate the option type with an overloading by Python's [None] *) - let tmp_var = VarName.fresh ("perhaps_none_arg", Pos.no_pos) in - Format.fprintf fmt "@[%a = %a@]@," VarName.format tmp_var - (format_expression ctx) e1; - Format.fprintf fmt "@[if %a is None:@ %a@]@," VarName.format tmp_var + Format.fprintf fmt "@[if %a is None:@ %a@]@," VarName.format switch_var (format_block ctx) case_none; Format.fprintf fmt "@[else:@ %a = %a@,%a@]" VarName.format - case_some_var VarName.format tmp_var (format_block ctx) case_some - | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } -> + case_some_var VarName.format switch_var (format_block ctx) case_some + | SSwitch { switch_var; enum_name = e_name; switch_cases = cases; _ } -> let cons_map = EnumName.Map.find e_name ctx.decl_ctx.ctx_enums in let cases = List.map2 @@ -396,16 +393,14 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit cases (EnumConstructor.Map.bindings cons_map) in - let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in - Format.fprintf fmt "%a = %a@\n@[if %a@]" VarName.format tmp_var - (format_expression ctx) e1 + Format.fprintf fmt "@[if %a@]" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") (fun fmt ({ case_block; payload_var_name; _ }, cons_name) -> Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" - VarName.format tmp_var EnumName.format e_name + VarName.format switch_var EnumName.format e_name EnumConstructor.format cons_name VarName.format payload_var_name - VarName.format tmp_var (format_block ctx) case_block)) + VarName.format switch_var (format_block ctx) case_block)) cases | SReturn e1 -> Format.fprintf fmt "@[return %a@]" (format_expression ctx) diff --git a/tests/backends/output/simple.c b/tests/backends/output/simple.c index c037f057..156bb04d 100644 --- a/tests/backends/output/simple.c +++ b/tests/backends/output/simple.c @@ -110,8 +110,8 @@ Baz baz(Baz_in baz_in) { Array_1 a4; a4.content_field = catala_malloc(sizeof(Array_1)); a4.content_field[0] = code(env, NULL); - Option_1 match_arg = catala_handle_exceptions(a4); - switch (match_arg.code) { + Option_1 a3 = catala_handle_exceptions(a4); + switch (a3.code) { case Option_1_None_1: if (1 /* TRUE */) { Bar a3; @@ -120,8 +120,8 @@ Baz baz(Baz_in baz_in) { Array_1 a7; a7.content_field = catala_malloc(sizeof(Array_1)); - Option_1 match_arg = catala_handle_exceptions(a7); - switch (match_arg.code) { + Option_1 a8 = catala_handle_exceptions(a7); + switch (a8.code) { case Option_1_None_1: if (1 /* TRUE */) { Bar a6 = {Bar_No, {No: NULL}}; @@ -133,15 +133,15 @@ Baz baz(Baz_in baz_in) { } break; case Option_1_Some_1: - Bar x1 = match_arg.payload.Some_1; + Bar x1 = a8.payload.Some_1; option_1 a6 = {Option_1_Some_1, {Some_1: x1}}; break; } Array_1 a5; a5.content_field = catala_malloc(sizeof(Array_1)); a5.content_field[0] = a6; - Option_1 match_arg = catala_handle_exceptions(a5); - switch (match_arg.code) { + Option_1 a9 = catala_handle_exceptions(a5); + switch (a9.code) { case Option_1_None_1: if (0 /* FALSE */) { option_1 a4 = {Option_1_None_1, {None_1: NULL}}; @@ -152,20 +152,16 @@ Baz baz(Baz_in baz_in) { } break; case Option_1_Some_1: - Bar x1 = match_arg.payload.Some_1; + Bar x1 = a9.payload.Some_1; option_1 a4 = {Option_1_Some_1, {Some_1: x1}}; break; } - Option_1 match_arg = a4; - switch (match_arg.code) { + switch (a4.code) { case Option_1_None_1: catala_raise_fatal_error (catala_no_value, "tests/backends/simple.catala_en", 11, 11, 11, 12); break; - case Option_1_Some_1: - Bar arg = match_arg.payload.Some_1; - a3 = arg; - break; + case Option_1_Some_1: Bar arg = a4.payload.Some_1; a3 = arg; break; } option_1 a3 = {Option_1_Some_1, {Some_1: a3}}; @@ -175,20 +171,16 @@ Baz baz(Baz_in baz_in) { } break; case Option_1_Some_1: - Bar x1 = match_arg.payload.Some_1; + Bar x1 = a3.payload.Some_1; option_1 a3 = {Option_1_Some_1, {Some_1: x1}}; break; } - Option_1 match_arg = a3; - switch (match_arg.code) { + switch (a3.code) { case Option_1_None_1: catala_raise_fatal_error (catala_no_value, "tests/backends/simple.catala_en", 11, 11, 11, 12); break; - case Option_1_Some_1: - Bar arg = match_arg.payload.Some_1; - a2 = arg; - break; + case Option_1_Some_1: Bar arg = a3.payload.Some_1; a2 = arg; break; } Bar a1; a1 = a2; diff --git a/tests/backends/python_name_clash.catala_en b/tests/backends/python_name_clash.catala_en index e72ef01d..5423b029 100644 --- a/tests/backends/python_name_clash.catala_en +++ b/tests/backends/python_name_clash.catala_en @@ -91,82 +91,80 @@ class BIn: def some_name(some_name_in:SomeNameIn): i = some_name_in.i_in - perhaps_none_arg = handle_exceptions([], []) - if perhaps_none_arg is None: + o4 = handle_exceptions([], []) + if o4 is None: if True: o3 = (i + integer_of_string("1")) else: o3 = None else: - x = perhaps_none_arg + x = o4 o3 = x - perhaps_none_arg = handle_exceptions( - [SourcePosition( - filename="tests/backends/python_name_clash.catala_en", - start_line=10, start_column=23, - end_line=10, end_column=28, law_headings=[])], - [o3] - ) - if perhaps_none_arg is None: + o5 = handle_exceptions( + [SourcePosition( + filename="tests/backends/python_name_clash.catala_en", + start_line=10, start_column=23, end_line=10, end_column=28, + law_headings=[])], + [o3] + ) + if o5 is None: if False: o2 = None else: o2 = None else: - x = perhaps_none_arg + x = o5 o2 = x - perhaps_none_arg = o2 - if perhaps_none_arg is None: + if o2 is None: raise NoValue(SourcePosition( filename="tests/backends/python_name_clash.catala_en", start_line=7, start_column=10, end_line=7, end_column=11, law_headings=[])) else: - arg = perhaps_none_arg + arg = o2 o1 = arg o = o1 return SomeName(o = o) def b(b_in:BIn): - perhaps_none_arg = handle_exceptions([], []) - if perhaps_none_arg is None: + result4 = handle_exceptions([], []) + if result4 is None: if True: result3 = integer_of_string("1") else: result3 = None else: - x = perhaps_none_arg + x = result4 result3 = x - perhaps_none_arg = handle_exceptions( - [SourcePosition( - filename="tests/backends/python_name_clash.catala_en", - start_line=16, start_column=33, - end_line=16, end_column=34, law_headings=[])], - [result3] - ) - if perhaps_none_arg is None: + result5 = handle_exceptions( + [SourcePosition( + filename="tests/backends/python_name_clash.catala_en", + start_line=16, start_column=33, + end_line=16, end_column=34, law_headings=[])], + [result3] + ) + if result5 is None: if False: result2 = None else: result2 = None else: - x = perhaps_none_arg + x = result5 result2 = x - perhaps_none_arg = result2 - if perhaps_none_arg is None: + if result2 is None: raise NoValue(SourcePosition( filename="tests/backends/python_name_clash.catala_en", start_line=16, start_column=14, end_line=16, end_column=25, law_headings=[])) else: - arg = perhaps_none_arg + arg = result2 result1 = arg result = some_name(SomeNameIn(i_in = result1)) - result4 = SomeName(o = result.o) + result6 = SomeName(o = result.o) if True: - some_name2 = result4 + some_name2 = result6 else: - some_name2 = result4 + some_name2 = result6 some_name1 = some_name2 return B(some_name = some_name1) ``` diff --git a/tests/name_resolution/good/toplevel_defs.catala_en b/tests/name_resolution/good/toplevel_defs.catala_en index 16091bff..9232b6a7 100644 --- a/tests/name_resolution/good/toplevel_defs.catala_en +++ b/tests/name_resolution/good/toplevel_defs.catala_en @@ -128,7 +128,8 @@ let S2 (S2_in: S2_in) = decl a1 : decimal; decl a2 : option decimal; decl a3 : option decimal; - switch handle_exceptions []: + a4 : option decimal = handle_exceptions []; + switch a4: | ENone _ → if true: a3 = ESome glob3 ¤44.00 + 100. @@ -136,7 +137,8 @@ let S2 (S2_in: S2_in) = a3 = ENone () | ESome x → a3 = ESome x; - switch handle_exceptions [a3]: + a5 : option decimal = handle_exceptions [a3]; + switch a5: | ENone _ → if false: a2 = ENone () @@ -157,7 +159,8 @@ let S3 (S3_in: S3_in) = decl a1 : decimal; decl a2 : option decimal; decl a3 : option decimal; - switch handle_exceptions []: + a4 : option decimal = handle_exceptions []; + switch a4: | ENone _ → if true: a3 = ESome 50. + glob4 ¤44.00 55. @@ -165,7 +168,8 @@ let S3 (S3_in: S3_in) = a3 = ENone () | ESome x → a3 = ESome x; - switch handle_exceptions [a3]: + a5 : option decimal = handle_exceptions [a3]; + switch a5: | ENone _ → if false: a2 = ENone () @@ -186,7 +190,8 @@ let S4 (S4_in: S4_in) = decl a1 : decimal; decl a2 : option decimal; decl a3 : option decimal; - switch handle_exceptions []: + a4 : option decimal = handle_exceptions []; + switch a4: | ENone _ → if true: a3 = ESome glob5 + 1. @@ -194,7 +199,8 @@ let S4 (S4_in: S4_in) = a3 = ENone () | ESome x → a3 = ESome x; - switch handle_exceptions [a3]: + a5 : option decimal = handle_exceptions [a3]; + switch a5: | ENone _ → if false: a2 = ENone () @@ -215,7 +221,8 @@ let S (S_in: S_in) = decl a1 : decimal; decl a2 : option decimal; decl a3 : option decimal; - switch handle_exceptions []: + a4 : option decimal = handle_exceptions []; + switch a4: | ENone _ → if true: a3 = ESome glob1 * glob1 @@ -223,7 +230,8 @@ let S (S_in: S_in) = a3 = ENone () | ESome x → a3 = ESome x; - switch handle_exceptions [a3]: + a5 : option decimal = handle_exceptions [a3]; + switch a5: | ENone _ → if false: a2 = ENone () @@ -241,7 +249,8 @@ let S (S_in: S_in) = decl b1 : A {y: bool; z: decimal}; decl b2 : option A {y: bool; z: decimal}; decl b3 : option A {y: bool; z: decimal}; - switch handle_exceptions []: + b4 : option A {y: bool; z: decimal} = handle_exceptions []; + switch b4: | ENone _ → if true: b3 = ESome glob2 @@ -249,7 +258,8 @@ let S (S_in: S_in) = b3 = ENone () | ESome x → b3 = ESome x; - switch handle_exceptions [b3]: + b5 : option A {y: bool; z: decimal} = handle_exceptions [b3]; + switch b5: | ENone _ → if false: b2 = ENone () @@ -446,48 +456,46 @@ glob6 = ( ) def s2(s2_in:S2In): - perhaps_none_arg = handle_exceptions([], []) - if perhaps_none_arg is None: + a4 = handle_exceptions([], []) + if a4 is None: if True: a3 = (glob3(money_of_cents_string("4400")) + decimal_of_string("100.")) else: a3 = None else: - x = perhaps_none_arg + x = a4 a3 = x - perhaps_none_arg = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=53, start_column=24, - end_line=53, end_column=43, - law_headings=["Test toplevel function defs"])], - [a3] - ) - if perhaps_none_arg is None: + a5 = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=53, start_column=24, end_line=53, end_column=43, + law_headings=["Test toplevel function defs"])], + [a3] + ) + if a5 is None: if False: a2 = None else: a2 = None else: - x = perhaps_none_arg + x = a5 a2 = x - perhaps_none_arg = a2 - if perhaps_none_arg is None: + if a2 is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=50, start_column=10, end_line=50, end_column=11, law_headings=["Test toplevel function defs"])) else: - arg = perhaps_none_arg + arg = a2 a1 = arg a = a1 return S2(a = a) def s3(s3_in:S3In): - perhaps_none_arg = handle_exceptions([], []) - if perhaps_none_arg is None: + a4 = handle_exceptions([], []) + if a4 is None: if True: a3 = (decimal_of_string("50.") + glob4(money_of_cents_string("4400"), @@ -495,151 +503,139 @@ def s3(s3_in:S3In): else: a3 = None else: - x = perhaps_none_arg + x = a4 a3 = x - perhaps_none_arg = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=74, start_column=24, - end_line=74, end_column=47, - law_headings=["Test function def with two args"] - )], - [a3] - ) - if perhaps_none_arg is None: + a5 = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=74, start_column=24, end_line=74, end_column=47, + law_headings=["Test function def with two args"])], + [a3] + ) + if a5 is None: if False: a2 = None else: a2 = None else: - x = perhaps_none_arg + x = a5 a2 = x - perhaps_none_arg = a2 - if perhaps_none_arg is None: + if a2 is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=71, start_column=10, end_line=71, end_column=11, law_headings=["Test function def with two args"])) else: - arg = perhaps_none_arg + arg = a2 a1 = arg a = a1 return S3(a = a) def s4(s4_in:S4In): - perhaps_none_arg = handle_exceptions([], []) - if perhaps_none_arg is None: + a4 = handle_exceptions([], []) + if a4 is None: if True: a3 = (glob5 + decimal_of_string("1.")) else: a3 = None else: - x = perhaps_none_arg + x = a4 a3 = x - perhaps_none_arg = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=98, start_column=24, - end_line=98, end_column=34, - law_headings=["Test inline defs in toplevel defs"] - )], - [a3] - ) - if perhaps_none_arg is None: + a5 = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=98, start_column=24, end_line=98, end_column=34, + law_headings=["Test inline defs in toplevel defs"])], + [a3] + ) + if a5 is None: if False: a2 = None else: a2 = None else: - x = perhaps_none_arg + x = a5 a2 = x - perhaps_none_arg = a2 - if perhaps_none_arg is None: + if a2 is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=95, start_column=10, end_line=95, end_column=11, law_headings=["Test inline defs in toplevel defs"])) else: - arg = perhaps_none_arg + arg = a2 a1 = arg a = a1 return S4(a = a) def s5(s_in:SIn): - perhaps_none_arg = handle_exceptions([], []) - if perhaps_none_arg is None: + a4 = handle_exceptions([], []) + if a4 is None: if True: a3 = (glob1 * glob1) else: a3 = None else: - x = perhaps_none_arg + x = a4 a3 = x - perhaps_none_arg = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=18, start_column=24, - end_line=18, end_column=37, - law_headings=["Test basic toplevel values defs"] - )], - [a3] - ) - if perhaps_none_arg is None: + a5 = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=18, start_column=24, end_line=18, end_column=37, + law_headings=["Test basic toplevel values defs"])], + [a3] + ) + if a5 is None: if False: a2 = None else: a2 = None else: - x = perhaps_none_arg + x = a5 a2 = x - perhaps_none_arg = a2 - if perhaps_none_arg is None: + if a2 is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=7, start_column=10, end_line=7, end_column=11, law_headings=["Test basic toplevel values defs"])) else: - arg = perhaps_none_arg + arg = a2 a1 = arg a = a1 - perhaps_none_arg = handle_exceptions([], []) - if perhaps_none_arg is None: + b4 = handle_exceptions([], []) + if b4 is None: if True: b3 = glob6 else: b3 = None else: - x = perhaps_none_arg + x = b4 b3 = x - perhaps_none_arg = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=19, start_column=24, - end_line=19, end_column=29, - law_headings=["Test basic toplevel values defs"] - )], - [b3] - ) - if perhaps_none_arg is None: + b5 = handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=19, start_column=24, end_line=19, end_column=29, + law_headings=["Test basic toplevel values defs"])], + [b3] + ) + if b5 is None: if False: b2 = None else: b2 = None else: - x = perhaps_none_arg + x = b5 b2 = x - perhaps_none_arg = b2 - if perhaps_none_arg is None: + if b2 is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", start_line=8, start_column=10, end_line=8, end_column=11, law_headings=["Test basic toplevel values defs"])) else: - arg = perhaps_none_arg + arg = b2 b1 = arg b = b1 return S(a = a, b = b) From 5d61963a93a0b735d9419b43ef05c676f5169d9a Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Thu, 8 Aug 2024 15:51:52 +0200 Subject: [PATCH 7/9] Reformat --- compiler/scalc/from_lcalc.ml | 551 ++++++++++++++++++----------------- compiler/scalc/print.ml | 3 +- compiler/scalc/to_c.ml | 3 +- compiler/scalc/to_python.ml | 3 +- 4 files changed, 288 insertions(+), 272 deletions(-) diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index 300991e7..26104250 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -37,13 +37,6 @@ type 'm ctxt = { ren_ctx : Renaming.context; } -(* Expressions can spill out side effect, hence this function also returns a - list of statements to be prepended before the expression is evaluated *) - -exception NotAnExpr of { needs_a_local_decl : bool } -(** Contains the LocalDecl of the temporary variable that will be defined by the - next block is it's here *) - (** Blocks are constructed as reverse ordered lists. This module abstracts this and avoids confusion in ordering of statements (also opening the opportunity for more optimisations) *) @@ -101,48 +94,59 @@ let rec translate_expr_list ctxt args = let stmts, args, ren_ctx = List.fold_left (fun (args_stmts, new_args, ren_ctx) arg -> - let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in + let arg_stmts, new_arg, ren_ctx = + translate_expr { ctxt with ren_ctx } arg + in args_stmts ++ arg_stmts, new_arg :: new_args, ren_ctx) - (RevBlock.empty, [], ctxt.ren_ctx) args + (RevBlock.empty, [], ctxt.ren_ctx) + args in stmts, List.rev args, ren_ctx -and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr * Renaming.context = - try - match Mark.remove expr with - | EVar v -> - let local_var = - try A.EVar (Var.Map.find v ctxt.var_dict) - with Var.Map.Not_found _ -> ( - try A.EFunc (Var.Map.find v ctxt.func_dict) - with Var.Map.Not_found _ -> - Message.error ~pos:(Expr.pos expr) - "Var not found in lambda→scalc: %a@\nknown: @[%a@]@\n" - Print.var_debug v - (Format.pp_print_list ~pp_sep:Format.pp_print_space (fun ppf v -> - Print.var_debug ppf v)) - (Var.Map.keys ctxt.var_dict)) - in - RevBlock.empty, (local_var, Expr.pos expr), ctxt.ren_ctx - | EStruct { fields; name } -> - if ctxt.config.no_struct_literals then - (* In C89, struct literates have to be initialized at variable - definition... *) - raise (NotAnExpr { needs_a_local_decl = false }); +and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : + RevBlock.t * A.expr * Renaming.context = + match Mark.remove expr with + | EVar v -> + let local_var = + try A.EVar (Var.Map.find v ctxt.var_dict) + with Var.Map.Not_found _ -> ( + try A.EFunc (Var.Map.find v ctxt.func_dict) + with Var.Map.Not_found _ -> + Message.error ~pos:(Expr.pos expr) + "Var not found in lambda→scalc: %a@\nknown: @[%a@]@\n" + Print.var_debug v + (Format.pp_print_list ~pp_sep:Format.pp_print_space (fun ppf v -> + Print.var_debug ppf v)) + (Var.Map.keys ctxt.var_dict)) + in + RevBlock.empty, (local_var, Expr.pos expr), ctxt.ren_ctx + | EStruct { fields; name } -> + if ctxt.config.no_struct_literals then + (* In C89, struct literates have to be initialized at variable + definition... *) + spill_expr ~needs_a_local_decl:false ctxt expr + else let args_stmts, new_args, ren_ctx = StructField.Map.fold (fun field arg (args_stmts, new_args, ren_ctx) -> - let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in - args_stmts ++ arg_stmts, StructField.Map.add field new_arg new_args, ren_ctx) + let arg_stmts, new_arg, ren_ctx = + translate_expr { ctxt with ren_ctx } arg + in + ( args_stmts ++ arg_stmts, + StructField.Map.add field new_arg new_args, + ren_ctx )) fields (RevBlock.empty, StructField.Map.empty, ctxt.ren_ctx) in - args_stmts, (A.EStruct { fields = new_args; name }, Expr.pos expr), ren_ctx - | EInj { e = e1; cons; name } -> - if ctxt.config.no_struct_literals then - (* In C89, struct literates have to be initialized at variable - definition... *) - raise (NotAnExpr { needs_a_local_decl = false }); + ( args_stmts, + (A.EStruct { fields = new_args; name }, Expr.pos expr), + ren_ctx ) + | EInj { e = e1; cons; name } -> + if ctxt.config.no_struct_literals then + (* In C89, struct literates have to be initialized at variable + definition... *) + spill_expr ~needs_a_local_decl:false ctxt expr + else let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in ( e1_stmts, ( A.EInj @@ -154,57 +158,57 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr * R }, Expr.pos expr ), ren_ctx ) - | ETuple args -> - let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in - args_stmts, (A.ETuple new_args, Expr.pos expr), ren_ctx - | EStructAccess { e = e1; field; name } -> - let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in - ( e1_stmts, - (A.EStructFieldAccess { e1 = new_e1; field; name }, Expr.pos expr), - ren_ctx) - | ETupleAccess { e = e1; index; _ } -> - let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in - e1_stmts, (A.ETupleAccess { e1 = new_e1; index }, Expr.pos expr), ren_ctx - | EAppOp { op; args; tys = _ } -> - let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in - (* FIXME: what happens if [arg] is not a tuple but reduces to one ? *) - args_stmts, (A.EAppOp { op; args = new_args }, Expr.pos expr), ren_ctx - | EApp { f = EAbs { binder; tys }, binder_mark; args; tys = _ } -> - (* This defines multiple local variables at the time *) - let binder_pos = Expr.mark_pos binder_mark in - let vars, body, ctxt = unmbind ctxt binder in - let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in - let ctxt = - List.fold_left (register_fresh_arg ~pos:binder_pos) ctxt vars_tau - in - let local_decls = - List.fold_left - (fun acc (x, tau) -> - RevBlock.append acc - ( A.SLocalDecl - { name = Var.Map.find x ctxt.var_dict, binder_pos; typ = tau }, - binder_pos )) - RevBlock.empty vars_tau - in - let vars_args = - List.map2 - (fun (x, tau) arg -> - (Var.Map.find x ctxt.var_dict, binder_pos), tau, arg) - vars_tau args - in - let def_blocks, ren_ctx = - List.fold_left - (fun (rblock, ren_ctx) (x, _tau, arg) -> - let ctxt = - { - ctxt with - inside_definition_of = Some (Mark.remove x); - context_name = Mark.remove (A.VarName.get_info (Mark.remove x)); - ren_ctx; - } - in - let arg_stmts, new_arg, ren_ctx = translate_expr ctxt arg in - RevBlock.append (rblock ++ arg_stmts) + | ETuple args -> + let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in + args_stmts, (A.ETuple new_args, Expr.pos expr), ren_ctx + | EStructAccess { e = e1; field; name } -> + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in + ( e1_stmts, + (A.EStructFieldAccess { e1 = new_e1; field; name }, Expr.pos expr), + ren_ctx ) + | ETupleAccess { e = e1; index; _ } -> + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in + e1_stmts, (A.ETupleAccess { e1 = new_e1; index }, Expr.pos expr), ren_ctx + | EAppOp { op; args; tys = _ } -> + let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in + (* FIXME: what happens if [arg] is not a tuple but reduces to one ? *) + args_stmts, (A.EAppOp { op; args = new_args }, Expr.pos expr), ren_ctx + | EApp { f = EAbs { binder; tys }, binder_mark; args; tys = _ } -> + (* This defines multiple local variables at the time *) + let binder_pos = Expr.mark_pos binder_mark in + let vars, body, ctxt = unmbind ctxt binder in + let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in + let ctxt = + List.fold_left (register_fresh_arg ~pos:binder_pos) ctxt vars_tau + in + let local_decls = + List.fold_left + (fun acc (x, tau) -> + RevBlock.append acc + ( A.SLocalDecl + { name = Var.Map.find x ctxt.var_dict, binder_pos; typ = tau }, + binder_pos )) + RevBlock.empty vars_tau + in + let vars_args = + List.map2 + (fun (x, tau) arg -> + (Var.Map.find x ctxt.var_dict, binder_pos), tau, arg) + vars_tau args + in + let def_blocks, ren_ctx = + List.fold_left + (fun (rblock, ren_ctx) (x, _tau, arg) -> + let ctxt = + { + ctxt with + inside_definition_of = Some (Mark.remove x); + context_name = Mark.remove (A.VarName.get_info (Mark.remove x)); + ren_ctx; + } + in + let arg_stmts, new_arg, ren_ctx = translate_expr ctxt arg in + ( RevBlock.append (rblock ++ arg_stmts) ( A.SLocalDef { name = x; @@ -212,75 +216,82 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr * R typ = Expr.maybe_ty (Mark.get arg); }, binder_pos ), - ren_ctx) - (RevBlock.empty, ctxt.ren_ctx) vars_args - in - let rest_of_expr_stmts, rest_of_expr, ren_ctx = translate_expr { ctxt with ren_ctx } body in - local_decls ++ def_blocks ++ rest_of_expr_stmts, rest_of_expr, ren_ctx - | EApp { f; args; tys = _ } -> - let f_stmts, new_f, ren_ctx = translate_expr ctxt f in - let args_stmts, new_args, ren_ctx = translate_expr_list { ctxt with ren_ctx } args in - (* FIXME: what happens if [arg] is not a tuple but reduces to one ? *) - ( f_stmts ++ args_stmts, - (A.EApp { f = new_f; args = new_args }, Expr.pos expr), - ren_ctx ) - | EArray args -> - let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in - args_stmts, (A.EArray new_args, Expr.pos expr), ren_ctx - | ELit l -> RevBlock.empty, (A.ELit l, Expr.pos expr), ctxt.ren_ctx - | EExternal { name } -> - let path, name = - match Mark.remove name with - | External_value name -> TopdefName.(path name, get_info name) - | External_scope name -> ScopeName.(path name, get_info name) - in - let modname = - ( ModuleName.Map.find (List.hd (List.rev path)) ctxt.program_ctx.modules, - Expr.pos expr ) - in - RevBlock.empty, (EExternal { modname; name }, Expr.pos expr), ctxt.ren_ctx - | EAbs _ | EIfThenElse _ | EMatch _ | EAssert _ | EFatalError _ -> - raise (NotAnExpr { needs_a_local_decl = true }) - | _ -> . - with NotAnExpr { needs_a_local_decl } -> - let tmp_var, ctxt = - let name = - match ctxt.inside_definition_of with - | None -> ctxt.context_name - | Some v -> A.VarName.to_string v - in - fresh_var ctxt name ~pos:(Expr.pos expr) + ren_ctx )) + (RevBlock.empty, ctxt.ren_ctx) + vars_args in - let ctxt = - { - ctxt with - inside_definition_of = Some tmp_var; - context_name = Mark.remove (A.VarName.get_info tmp_var); - } + let rest_of_expr_stmts, rest_of_expr, ren_ctx = + translate_expr { ctxt with ren_ctx } body in - let tmp_stmts, ren_ctx = translate_statements ctxt expr in - ( (if needs_a_local_decl then - RevBlock.make - (( A.SLocalDecl - { - name = tmp_var, Expr.pos expr; - typ = Expr.maybe_ty (Mark.get expr); - }, - Expr.pos expr ) - :: tmp_stmts) - else RevBlock.make tmp_stmts), - (A.EVar tmp_var, Expr.pos expr), + local_decls ++ def_blocks ++ rest_of_expr_stmts, rest_of_expr, ren_ctx + | EApp { f; args; tys = _ } -> + let f_stmts, new_f, ren_ctx = translate_expr ctxt f in + let args_stmts, new_args, ren_ctx = + translate_expr_list { ctxt with ren_ctx } args + in + (* FIXME: what happens if [arg] is not a tuple but reduces to one ? *) + ( f_stmts ++ args_stmts, + (A.EApp { f = new_f; args = new_args }, Expr.pos expr), ren_ctx ) + | EArray args -> + let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in + args_stmts, (A.EArray new_args, Expr.pos expr), ren_ctx + | ELit l -> RevBlock.empty, (A.ELit l, Expr.pos expr), ctxt.ren_ctx + | EExternal { name } -> + let path, name = + match Mark.remove name with + | External_value name -> TopdefName.(path name, get_info name) + | External_scope name -> ScopeName.(path name, get_info name) + in + let modname = + ( ModuleName.Map.find (List.hd (List.rev path)) ctxt.program_ctx.modules, + Expr.pos expr ) + in + RevBlock.empty, (EExternal { modname; name }, Expr.pos expr), ctxt.ren_ctx + | EAbs _ | EIfThenElse _ | EMatch _ | EAssert _ | EFatalError _ -> + spill_expr ~needs_a_local_decl:true ctxt expr + | _ -> . -and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * Renaming.context = +and spill_expr ~needs_a_local_decl ctxt expr = + let tmp_var, ctxt = + let name = + match ctxt.inside_definition_of with + | None -> ctxt.context_name + | Some v -> A.VarName.to_string v + in + fresh_var ctxt name ~pos:(Expr.pos expr) + in + let ctxt = + { + ctxt with + inside_definition_of = Some tmp_var; + context_name = Mark.remove (A.VarName.get_info tmp_var); + } + in + let tmp_stmts, ren_ctx = translate_statements ctxt expr in + ( (if needs_a_local_decl then + RevBlock.make + (( A.SLocalDecl + { + name = tmp_var, Expr.pos expr; + typ = Expr.maybe_ty (Mark.get expr); + }, + Expr.pos expr ) + :: tmp_stmts) + else RevBlock.make tmp_stmts), + (A.EVar tmp_var, Expr.pos expr), + ren_ctx ) + +and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : + A.block * Renaming.context = match Mark.remove block_expr with | EAssert e -> (* Assertions are always encapsulated in a unit-typed let binding *) let e_stmts, new_e, ren_ctx = translate_expr ctxt e in - RevBlock.rebuild - ~tail:[A.SAssert (Mark.remove new_e), Expr.pos block_expr] - e_stmts, - ren_ctx + ( RevBlock.rebuild + ~tail:[A.SAssert (Mark.remove new_e), Expr.pos block_expr] + e_stmts, + ren_ctx ) | EFatalError err -> [SFatalError err, Expr.pos block_expr], ctxt.ren_ctx (* | EAppOp * { op = Op.HandleDefaultOpt, _; tys = _; args = [exceptions; just; cons] } @@ -369,19 +380,24 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R ren_ctx; } in - let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in - RevBlock.append (def_blocks ++ arg_stmts) - ( A.SLocalDef - { - name = x; - expr = new_arg; - typ = Expr.maybe_ty (Mark.get arg); - }, - binder_pos ), - ren_ctx) - (RevBlock.empty, ctxt.ren_ctx) vars_args + let arg_stmts, new_arg, ren_ctx = + translate_expr { ctxt with ren_ctx } arg + in + ( RevBlock.append (def_blocks ++ arg_stmts) + ( A.SLocalDef + { + name = x; + expr = new_arg; + typ = Expr.maybe_ty (Mark.get arg); + }, + binder_pos ), + ren_ctx )) + (RevBlock.empty, ctxt.ren_ctx) + vars_args + in + let rest_of_block, ren_ctx = + translate_statements { ctxt with ren_ctx } body in - let rest_of_block, ren_ctx = translate_statements { ctxt with ren_ctx } body in local_decls @ RevBlock.rebuild def_blocks ~tail:rest_of_block, ren_ctx | EAbs { binder; tys } -> let closure_name, ctxt = @@ -399,27 +415,28 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R vars_tau in let new_body, _ren_ctx = translate_statements ctxt body in - [ - ( A.SInnerFuncDef - { - name = closure_name, binder_pos; - func = - { - func_params = - List.map - (fun (var, tau) -> - (Var.Map.find var ctxt.var_dict, binder_pos), tau) - vars_tau; - func_body = new_body; - func_return_typ = - (match Expr.maybe_ty (Mark.get block_expr) with - | TArrow (_, t2), _ -> t2 - | TAny, pos_any -> TAny, pos_any - | _ -> assert false); - }; - }, - binder_pos ); - ], ctxt.ren_ctx + ( [ + ( A.SInnerFuncDef + { + name = closure_name, binder_pos; + func = + { + func_params = + List.map + (fun (var, tau) -> + (Var.Map.find var ctxt.var_dict, binder_pos), tau) + vars_tau; + func_body = new_body; + func_return_typ = + (match Expr.maybe_ty (Mark.get block_expr) with + | TArrow (_, t2), _ -> t2 + | TAny, pos_any -> TAny, pos_any + | _ -> assert false); + }; + }, + binder_pos ); + ], + ctxt.ren_ctx ) | EMatch { e = e1; cases; name } -> let typ = Expr.maybe_ty (Mark.get e1) in let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in @@ -429,14 +446,11 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R | A.EVar v, _ -> e1_stmts, v, ctxt | _ -> let v, ctxt = fresh_var ctxt ctxt.context_name ~pos:(Expr.pos e1) in - RevBlock.append e1_stmts - ( A.SLocalInit - { name = v, Expr.pos e1; - expr = new_e1; - typ }, - Expr.pos e1 ), - v, - ctxt + ( RevBlock.append e1_stmts + ( A.SLocalInit { name = v, Expr.pos e1; expr = new_e1; typ }, + Expr.pos e1 ), + v, + ctxt ) in let new_cases = EnumConstructor.Map.fold @@ -459,35 +473,35 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R | _ -> assert false) cases [] in - RevBlock.rebuild e1_stmts - ~tail: - [ - ( A.SSwitch - { - switch_var; - switch_var_typ = typ; - enum_name = name; - switch_cases = List.rev new_cases; - }, - Expr.pos block_expr ); - ], - ctxt.ren_ctx + ( RevBlock.rebuild e1_stmts + ~tail: + [ + ( A.SSwitch + { + switch_var; + switch_var_typ = typ; + enum_name = name; + switch_cases = List.rev new_cases; + }, + Expr.pos block_expr ); + ], + ctxt.ren_ctx ) | EIfThenElse { cond; etrue; efalse } -> let cond_stmts, s_cond, ren_ctx = translate_expr ctxt cond in let s_e_true, _ = translate_statements ctxt etrue in let s_e_false, _ = translate_statements ctxt efalse in - RevBlock.rebuild cond_stmts - ~tail: - [ - ( A.SIfThenElse - { - if_expr = s_cond; - then_block = s_e_true; - else_block = s_e_false; - }, - Expr.pos block_expr ); - ], - ren_ctx + ( RevBlock.rebuild cond_stmts + ~tail: + [ + ( A.SIfThenElse + { + if_expr = s_cond; + then_block = s_e_true; + else_block = s_e_false; + }, + Expr.pos block_expr ); + ], + ren_ctx ) | EInj { e = e1; cons; name } when ctxt.config.no_struct_literals -> let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in let tmp_struct_var_name = @@ -506,26 +520,30 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R }, Expr.pos block_expr ) in - RevBlock.rebuild e1_stmts - ~tail: - [ - ( A.SLocalInit - { - name = tmp_struct_var_name; - expr = inj_expr; - typ = - ( Mark.remove (Expr.maybe_ty (Mark.get block_expr)), - Expr.pos block_expr ); - }, - Expr.pos block_expr ); - ], - ren_ctx + ( RevBlock.rebuild e1_stmts + ~tail: + [ + ( A.SLocalInit + { + name = tmp_struct_var_name; + expr = inj_expr; + typ = + ( Mark.remove (Expr.maybe_ty (Mark.get block_expr)), + Expr.pos block_expr ); + }, + Expr.pos block_expr ); + ], + ren_ctx ) | EStruct { fields; name } when ctxt.config.no_struct_literals -> let args_stmts, new_args, ren_ctx = StructField.Map.fold (fun field arg (args_stmts, new_args, ren_ctx) -> - let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in - args_stmts ++ arg_stmts, StructField.Map.add field new_arg new_args, ren_ctx) + let arg_stmts, new_arg, ren_ctx = + translate_expr { ctxt with ren_ctx } arg + in + ( args_stmts ++ arg_stmts, + StructField.Map.add field new_arg new_args, + ren_ctx )) fields (RevBlock.empty, StructField.Map.empty, ctxt.ren_ctx) in @@ -539,18 +557,18 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R (* [translate_expr] should create this [inside_definition_of]*) | Some x -> x, Expr.pos block_expr in - RevBlock.rebuild args_stmts - ~tail: - [ - ( A.SLocalInit - { - name = tmp_struct_var_name; - expr = struct_expr; - typ = TStruct name, Expr.pos block_expr; - }, - Expr.pos block_expr ); - ], - ren_ctx + ( RevBlock.rebuild args_stmts + ~tail: + [ + ( A.SLocalInit + { + name = tmp_struct_var_name; + expr = struct_expr; + typ = TStruct name, Expr.pos block_expr; + }, + Expr.pos block_expr ); + ], + ren_ctx ) | ELit _ | EAppOp _ | EArray _ | EVar _ | EStruct _ | EInj _ | ETuple _ | ETupleAccess _ | EStructAccess _ | EExternal _ | EApp _ -> let e_stmts, new_e, ren_ctx = translate_expr ctxt block_expr in @@ -575,8 +593,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R Expr.pos block_expr ); ] in - RevBlock.rebuild e_stmts ~tail, - ren_ctx + RevBlock.rebuild e_stmts ~tail, ren_ctx | _ -> . let rec translate_scope_body_expr ctx (scope_expr : 'm L.expr scope_body_expr) : @@ -606,22 +623,22 @@ let rec translate_scope_body_expr ctx (scope_expr : 'm L.expr scope_body_expr) : { ctx with inside_definition_of = Some let_var_id } scope_let.scope_let_expr in - let (+>) = RevBlock.append in - let_expr_stmts +> - ( A.SLocalDecl - { - name = let_var_id, scope_let.scope_let_pos; - typ = scope_let.scope_let_typ; - }, - scope_let.scope_let_pos ) +> - ( A.SLocalDef - { - name = let_var_id, scope_let.scope_let_pos; - expr = new_let_expr; - typ = scope_let.scope_let_typ; - }, - scope_let.scope_let_pos ), - ren_ctx + let ( +> ) = RevBlock.append in + ( let_expr_stmts + +> ( A.SLocalDecl + { + name = let_var_id, scope_let.scope_let_pos; + typ = scope_let.scope_let_typ; + }, + scope_let.scope_let_pos ) + +> ( A.SLocalDef + { + name = let_var_id, scope_let.scope_let_pos; + expr = new_let_expr; + typ = scope_let.scope_let_typ; + }, + scope_let.scope_let_pos ), + ren_ctx ) in let tail = translate_scope_body_expr { ctx with ren_ctx } scope_let_next in RevBlock.rebuild statements ~tail @@ -635,16 +652,16 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : inside_definition_of = None; context_name = ""; config; - program_ctx = { A.decl_ctx = p.decl_ctx; modules = ModuleName.Map.empty}; + program_ctx = { A.decl_ctx = p.decl_ctx; modules = ModuleName.Map.empty }; ren_ctx = config.renaming_context; } in let modules, ctxt = List.fold_left (fun (modules, ctxt) (m, _) -> - let name, pos = ModuleName.get_info m in - let vname, ctxt = get_name ctxt name in - ModuleName.Map.add m (A.VarName.fresh (vname, pos)) modules, ctxt) + let name, pos = ModuleName.get_info m in + let vname, ctxt = get_name ctxt name in + ModuleName.Map.add m (A.VarName.fresh (vname, pos)) modules, ctxt) (ModuleName.Map.empty, ctxt) (Program.modules_to_list p.decl_ctx.ctx_modules) in diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index d7086ce4..a3703cdc 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -173,8 +173,7 @@ let rec format_statement -> let cons = EnumName.Map.find enum decl_ctx.ctx_enums in Format.fprintf fmt "@[%a @[%a@]%a@,@]%a" Print.keyword "switch" - format_var_name v_switch - Print.punctuation ":" + format_var_name v_switch Print.punctuation ":" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt ((case, _), switch_case_data) -> diff --git a/compiler/scalc/to_c.ml b/compiler/scalc/to_c.ml index 9da9e8be..20313146 100644 --- a/compiler/scalc/to_c.ml +++ b/compiler/scalc/to_c.ml @@ -397,7 +397,8 @@ let rec format_statement (EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums)) in Format.pp_open_vbox fmt 2; - Format.fprintf fmt "@[switch (%a.code) {@]@," VarName.format switch_var; + Format.fprintf fmt "@[switch (%a.code) {@]@," VarName.format + switch_var; Format.pp_print_list (fun fmt ({ case_block; payload_var_name; payload_var_typ }, cons_name) -> Format.fprintf fmt "@[case %a_%a:@ " EnumName.format e_name diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 8abd1f79..189271e5 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -159,8 +159,7 @@ let renaming = (* TODO: add catala runtime built-ins as reserved as well ? *) ~reset_context_for_closed_terms:false ~skip_constant_binders:false ~constant_binder_name:None ~namespaced_fields_constrs:true - ~f_struct:String.to_camel_case - ~f_enum:String.to_camel_case + ~f_struct:String.to_camel_case ~f_enum:String.to_camel_case let typ_needs_parens (e : typ) : bool = match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false From 9fa5f91e3aa19a6c3d562ae419f7a82c13c5eb1d Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Fri, 9 Aug 2024 11:10:47 +0200 Subject: [PATCH 8/9] Python printer: add some parens to be safe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I can't find where the line cut triggering the error at https://github.com/CatalaLang/catala/actions/runs/10304111306/job/28522272547?pr=666 came from: ``` /home/ocaml/french-law/_python_venv/lib/python3.12/site-packages/french_law/Aides_logement.py:27785: error: invalid syntax [syntax] ``` the file at this point contains: ``` def traitement_aide_finale_montee_en_charge_saint_pierre_miquelon1( aide_finale4:Money): → traitement_aide_finale_montee_en_charge_saint_pierre_miquelon4 = handle_exceptions( [], [] ) ``` This workaround adds parens after `=`, which ensures the syntax will be correct. --- compiler/scalc/to_python.ml | 6 +- tests/backends/python_name_clash.catala_en | 74 ++++---- .../good/toplevel_defs.catala_en | 175 +++++++++--------- 3 files changed, 129 insertions(+), 126 deletions(-) diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 189271e5..0a02f0f8 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -356,7 +356,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit assert false (* We don't need to declare variables in Python *) | SLocalDef { name = v; expr = e; _ } | SLocalInit { name = v; expr = e; _ } -> - Format.fprintf fmt "@[%a = %a@]" VarName.format (Mark.remove v) + Format.fprintf fmt "@[%a = (%a)@]" VarName.format (Mark.remove v) (format_expression ctx) e | STryWEmpty { try_block = try_b; with_block = catch_b } -> Format.fprintf fmt "@[try:@ %a@]@,@[except Empty:@ %a@]" @@ -420,9 +420,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit and format_block ctx (fmt : Format.formatter) (b : block) : unit = Format.pp_open_vbox fmt 0; - Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@,") - (format_statement ctx) fmt + Format.pp_print_list (format_statement ctx) fmt (List.filter (fun s -> match Mark.remove s with SLocalDecl _ -> false | _ -> true) b); diff --git a/tests/backends/python_name_clash.catala_en b/tests/backends/python_name_clash.catala_en index 5423b029..d3d31afc 100644 --- a/tests/backends/python_name_clash.catala_en +++ b/tests/backends/python_name_clash.catala_en @@ -90,31 +90,31 @@ class BIn: def some_name(some_name_in:SomeNameIn): - i = some_name_in.i_in - o4 = handle_exceptions([], []) + i = (some_name_in.i_in) + o4 = (handle_exceptions([], [])) if o4 is None: if True: - o3 = (i + integer_of_string("1")) + o3 = ((i + integer_of_string("1"))) else: - o3 = None + o3 = (None) else: x = o4 - o3 = x - o5 = handle_exceptions( - [SourcePosition( - filename="tests/backends/python_name_clash.catala_en", - start_line=10, start_column=23, end_line=10, end_column=28, - law_headings=[])], - [o3] - ) + o3 = (x) + o5 = (handle_exceptions( + [SourcePosition( + filename="tests/backends/python_name_clash.catala_en", + start_line=10, start_column=23, + end_line=10, end_column=28, law_headings=[])], + [o3] + )) if o5 is None: if False: - o2 = None + o2 = (None) else: - o2 = None + o2 = (None) else: x = o5 - o2 = x + o2 = (x) if o2 is None: raise NoValue(SourcePosition( filename="tests/backends/python_name_clash.catala_en", @@ -122,35 +122,35 @@ def some_name(some_name_in:SomeNameIn): end_line=7, end_column=11, law_headings=[])) else: arg = o2 - o1 = arg - o = o1 + o1 = (arg) + o = (o1) return SomeName(o = o) def b(b_in:BIn): - result4 = handle_exceptions([], []) + result4 = (handle_exceptions([], [])) if result4 is None: if True: - result3 = integer_of_string("1") + result3 = (integer_of_string("1")) else: - result3 = None + result3 = (None) else: x = result4 - result3 = x - result5 = handle_exceptions( - [SourcePosition( - filename="tests/backends/python_name_clash.catala_en", - start_line=16, start_column=33, - end_line=16, end_column=34, law_headings=[])], - [result3] - ) + result3 = (x) + result5 = (handle_exceptions( + [SourcePosition( + filename="tests/backends/python_name_clash.catala_en", + start_line=16, start_column=33, + end_line=16, end_column=34, law_headings=[])], + [result3] + )) if result5 is None: if False: - result2 = None + result2 = (None) else: - result2 = None + result2 = (None) else: x = result5 - result2 = x + result2 = (x) if result2 is None: raise NoValue(SourcePosition( filename="tests/backends/python_name_clash.catala_en", @@ -158,14 +158,14 @@ def b(b_in:BIn): end_line=16, end_column=25, law_headings=[])) else: arg = result2 - result1 = arg - result = some_name(SomeNameIn(i_in = result1)) - result6 = SomeName(o = result.o) + result1 = (arg) + result = (some_name(SomeNameIn(i_in = result1))) + result6 = (SomeName(o = result.o)) if True: - some_name2 = result6 + some_name2 = (result6) else: - some_name2 = result6 - some_name1 = some_name2 + some_name2 = (result6) + some_name1 = (some_name2) return B(some_name = some_name1) ``` The above should *not* show `some_name = temp_some_name`, but instead `some_name_1 = ...` diff --git a/tests/name_resolution/good/toplevel_defs.catala_en b/tests/name_resolution/good/toplevel_defs.catala_en index 9232b6a7..c0249c91 100644 --- a/tests/name_resolution/good/toplevel_defs.catala_en +++ b/tests/name_resolution/good/toplevel_defs.catala_en @@ -441,9 +441,9 @@ def glob4(x:Money, y:Decimal): return ((decimal_of_money(x) * y) + decimal_of_string("10.")) def glob5_init(): - x = (decimal_of_integer(integer_of_string("2")) * - decimal_of_string("3.")) - y = decimal_of_string("1000.") + x = ((decimal_of_integer(integer_of_string("2")) * + decimal_of_string("3."))) + y = (decimal_of_string("1000.")) return (x * y) glob5 = (glob5_init()) @@ -456,31 +456,32 @@ glob6 = ( ) def s2(s2_in:S2In): - a4 = handle_exceptions([], []) + a4 = (handle_exceptions([], [])) if a4 is None: if True: - a3 = (glob3(money_of_cents_string("4400")) + - decimal_of_string("100.")) + a3 = ((glob3(money_of_cents_string("4400")) + + decimal_of_string("100."))) else: - a3 = None + a3 = (None) else: x = a4 - a3 = x - a5 = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=53, start_column=24, end_line=53, end_column=43, - law_headings=["Test toplevel function defs"])], - [a3] - ) + a3 = (x) + a5 = (handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=53, start_column=24, + end_line=53, end_column=43, + law_headings=["Test toplevel function defs"])], + [a3] + )) if a5 is None: if False: - a2 = None + a2 = (None) else: - a2 = None + a2 = (None) else: x = a5 - a2 = x + a2 = (x) if a2 is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", @@ -489,37 +490,38 @@ def s2(s2_in:S2In): law_headings=["Test toplevel function defs"])) else: arg = a2 - a1 = arg - a = a1 + a1 = (arg) + a = (a1) return S2(a = a) def s3(s3_in:S3In): - a4 = handle_exceptions([], []) + a4 = (handle_exceptions([], [])) if a4 is None: if True: - a3 = (decimal_of_string("50.") + + a3 = ((decimal_of_string("50.") + glob4(money_of_cents_string("4400"), - decimal_of_string("55."))) + decimal_of_string("55.")))) else: - a3 = None + a3 = (None) else: x = a4 - a3 = x - a5 = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=74, start_column=24, end_line=74, end_column=47, - law_headings=["Test function def with two args"])], - [a3] - ) + a3 = (x) + a5 = (handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=74, start_column=24, + end_line=74, end_column=47, + law_headings=["Test function def with two args"])], + [a3] + )) if a5 is None: if False: - a2 = None + a2 = (None) else: - a2 = None + a2 = (None) else: x = a5 - a2 = x + a2 = (x) if a2 is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", @@ -528,35 +530,36 @@ def s3(s3_in:S3In): law_headings=["Test function def with two args"])) else: arg = a2 - a1 = arg - a = a1 + a1 = (arg) + a = (a1) return S3(a = a) def s4(s4_in:S4In): - a4 = handle_exceptions([], []) + a4 = (handle_exceptions([], [])) if a4 is None: if True: - a3 = (glob5 + decimal_of_string("1.")) + a3 = ((glob5 + decimal_of_string("1."))) else: - a3 = None + a3 = (None) else: x = a4 - a3 = x - a5 = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=98, start_column=24, end_line=98, end_column=34, - law_headings=["Test inline defs in toplevel defs"])], - [a3] - ) + a3 = (x) + a5 = (handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=98, start_column=24, + end_line=98, end_column=34, + law_headings=["Test inline defs in toplevel defs"])], + [a3] + )) if a5 is None: if False: - a2 = None + a2 = (None) else: - a2 = None + a2 = (None) else: x = a5 - a2 = x + a2 = (x) if a2 is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", @@ -565,35 +568,36 @@ def s4(s4_in:S4In): law_headings=["Test inline defs in toplevel defs"])) else: arg = a2 - a1 = arg - a = a1 + a1 = (arg) + a = (a1) return S4(a = a) def s5(s_in:SIn): - a4 = handle_exceptions([], []) + a4 = (handle_exceptions([], [])) if a4 is None: if True: - a3 = (glob1 * glob1) + a3 = ((glob1 * glob1)) else: - a3 = None + a3 = (None) else: x = a4 - a3 = x - a5 = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=18, start_column=24, end_line=18, end_column=37, - law_headings=["Test basic toplevel values defs"])], - [a3] - ) + a3 = (x) + a5 = (handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=18, start_column=24, + end_line=18, end_column=37, + law_headings=["Test basic toplevel values defs"])], + [a3] + )) if a5 is None: if False: - a2 = None + a2 = (None) else: - a2 = None + a2 = (None) else: x = a5 - a2 = x + a2 = (x) if a2 is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", @@ -602,32 +606,33 @@ def s5(s_in:SIn): law_headings=["Test basic toplevel values defs"])) else: arg = a2 - a1 = arg - a = a1 - b4 = handle_exceptions([], []) + a1 = (arg) + a = (a1) + b4 = (handle_exceptions([], [])) if b4 is None: if True: - b3 = glob6 + b3 = (glob6) else: - b3 = None + b3 = (None) else: x = b4 - b3 = x - b5 = handle_exceptions( - [SourcePosition( - filename="tests/name_resolution/good/toplevel_defs.catala_en", - start_line=19, start_column=24, end_line=19, end_column=29, - law_headings=["Test basic toplevel values defs"])], - [b3] - ) + b3 = (x) + b5 = (handle_exceptions( + [SourcePosition( + filename="tests/name_resolution/good/toplevel_defs.catala_en", + start_line=19, start_column=24, + end_line=19, end_column=29, + law_headings=["Test basic toplevel values defs"])], + [b3] + )) if b5 is None: if False: - b2 = None + b2 = (None) else: - b2 = None + b2 = (None) else: x = b5 - b2 = x + b2 = (x) if b2 is None: raise NoValue(SourcePosition( filename="tests/name_resolution/good/toplevel_defs.catala_en", @@ -636,7 +641,7 @@ def s5(s_in:SIn): law_headings=["Test basic toplevel values defs"])) else: arg = b2 - b1 = arg - b = b1 + b1 = (arg) + b = (b1) return S(a = a, b = b) ``` From 8e91dcb281a982f3da257ec2d8d76010ebe53360 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Wed, 28 Aug 2024 14:09:02 +0200 Subject: [PATCH 9/9] Apply suggestions from code review Thanks @vincent-botbol Co-authored-by: vbot --- compiler/scalc/print.ml | 2 -- 1 file changed, 2 deletions(-) diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index a3703cdc..fd80fe69 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -22,11 +22,9 @@ let needs_parens (_e : expr) : bool = false let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit = VarName.format fmt v -(* Format.fprintf fmt "%a_%d" VarName.format v (VarName.id v) *) let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit = FuncName.format fmt v -(* Format.fprintf fmt "@{%a_%d@}" FuncName.format v (FuncName.id v) *) let rec format_expr (decl_ctx : decl_ctx)