From b9156bb60ed7233e25de70ccb0a38ae24446ec97 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Mon, 5 Aug 2024 17:08:36 +0200 Subject: [PATCH] Implement safe renaming of idents for backend printing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously we had some heuristics in the backends trying to achieve this with a lot of holes ; this should be much more solid, relying on `Bindlib` to do the correct renamings. **Note1**: it's not plugged into the backends other than OCaml at the moment. **Note2**: the related, obsolete heuristics haven't been cleaned out yet **Note3**: we conservatively suppose a single namespace at the moment. This is required for e.g. Python, but it forces vars named like struct fields to be renamed, which is more verbose in e.g. OCaml. The renaming engine could be improved to support different namespaces, with a way to select how to route the different kinds of identifiers into them. Similarly, customisation for what needs to be uppercase or lowercase is not available yet. **Note4**: besides excluding keywords, we should also be careful to exclude (or namespace): - the idents used in the runtime (e.g. `o_add_int_int`) - the dynamically generated idents (e.g. `embed_*`) **Note5**: module names themselves aren't handled yet. The reason is that they must be discoverable by the user, and even need to match the filenames, etc. In other words, imagine that `Mod` is a keyword in the target language. You can't rename a module called `Mod` to `Mod1` without knowing the whole module context, because that would destroy the mapping for a module already called `Mod1`. A reliable solution would be to translate all module names to e.g. `CatalaModule_*`, which we can assume will never conflict with any built-in, and forbid idents starting with that prefix. We may also want to restrict their names to ASCII ? Currently we use a projection, but what if I have two modules called `Là` and `La` ? --- compiler/catala_utils/uid.mli | 3 +- compiler/dcalc/from_scopelang.ml | 11 +- compiler/lcalc/closure_conversion.ml | 5 +- compiler/lcalc/to_ocaml.ml | 107 +++---- compiler/lcalc/to_ocaml.mli | 1 - compiler/plugins/api_web.ml | 66 ++-- compiler/plugins/explain.ml | 16 +- compiler/plugins/json_schema.ml | 1 - compiler/plugins/lazy_interp.ml | 2 +- compiler/scalc/from_lcalc.ml | 2 +- compiler/shared_ast/boundList.ml | 2 +- compiler/shared_ast/boundList.mli | 4 +- compiler/shared_ast/definitions.ml | 7 +- compiler/shared_ast/expr.ml | 295 ++++++++++++------ compiler/shared_ast/expr.mli | 59 +++- compiler/shared_ast/program.ml | 172 +++++++++- compiler/shared_ast/program.mli | 15 + compiler/shared_ast/scope.ml | 70 ++++- compiler/shared_ast/scope.mli | 13 + compiler/shared_ast/type.ml | 15 + compiler/shared_ast/type.mli | 3 + compiler/shared_ast/typing.ml | 2 +- compiler/verification/conditions.ml | 2 +- tests/modules/good/output/mod_def.ml | 37 ++- tests/name_resolution/good/let_in2.catala_en | 35 +-- .../191_fix_record_name_confusion.catala_en | 21 +- 26 files changed, 679 insertions(+), 287 deletions(-) diff --git a/compiler/catala_utils/uid.mli b/compiler/catala_utils/uid.mli index a51e6cf4..d377fc69 100644 --- a/compiler/catala_utils/uid.mli +++ b/compiler/catala_utils/uid.mli @@ -103,6 +103,7 @@ module Gen_qualified (_ : Style) () : sig val fresh : Path.t -> MarkedString.info -> t val path : t -> Path.t val get_info : t -> MarkedString.info + val hash : strip:Path.t -> t -> Hash.t - (* [strip] strips that prefix from the start of the path before hashing *) + (** [strip] strips that prefix from the start of the path before hashing *) end diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index 8761ca4a..53bdf73b 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -941,8 +941,11 @@ let translate_program (prgm : 'm S.program) : 'm Ast.program = (* the resulting expression is the list of definitions of all the scopes, ending with the top-level scope. The decl_ctx is filled in left-to-right order, then the chained scopes aggregated from the right. *) - let rec translate_defs = function - | [] -> Bindlib.box (Last ()) + let rec translate_defs vlist = function + | [] -> + Bindlib.box_apply + (fun vl -> Last vl) + (Bindlib.box_rev_list (List.map Bindlib.box_var vlist)) | def :: next -> let dvar, def = match def with @@ -971,13 +974,13 @@ let translate_program (prgm : 'm S.program) : 'm Ast.program = (fun body -> ScopeDef (scope_name, body)) scope_body ) in - let scope_next = translate_defs next in + let scope_next = translate_defs (dvar :: vlist) next in let next_bind = Bindlib.bind_var dvar scope_next in Bindlib.box_apply2 (fun item next_bind -> Cons (item, next_bind)) def next_bind in - let items = translate_defs defs_ordering in + let items = translate_defs [] defs_ordering in Expr.Box.assert_closed items; { code_items = Bindlib.unbox items; diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 961861f9..277ea829 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -440,7 +440,7 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = Bindlib.box_apply (fun e -> Topdef (name, (TAny, Mark.get ty), e)) (Expr.Box.lift new_expr) )) - ~last:(fun _ () -> (), Bindlib.box ()) + ~last:(fun _ vlist -> (), Scope.map_last_item ~varf:Fun.id vlist) ~init:Var.Map.empty p.code_items in (* Now we need to further tweak [decl_ctx] because some of the user-defined @@ -612,7 +612,8 @@ let rec hoist_closures_code_item_list (code_items : (lcalc, 'm) gexpr code_item_list) : (lcalc, 'm) gexpr code_item_list Bindlib.box = match code_items with - | Last () -> Bindlib.box (Last ()) + | Last vlist -> + Bindlib.box_apply (fun l -> Last l) (Scope.map_last_item ~varf:Fun.id vlist) | Cons (code_item, next_code_items) -> let code_item_var, next_code_items = Bindlib.unbind next_code_items in let hoisted_closures, new_code_item = diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index fe6dd665..a98bbf19 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -130,38 +130,23 @@ let ocaml_keywords = "Oper"; ] -let ocaml_keywords_set = String.Set.of_list ocaml_keywords - -let avoid_keywords (s : string) : string = - if String.Set.mem s ocaml_keywords_set then s ^ "_user" else s -(* Fixme: this could cause clashes if the user program contains both e.g. [new] - and [new_user] *) - -let ppclean fmt str = - str |> String.to_ascii |> avoid_keywords |> Format.pp_print_string fmt - -let ppsnake fmt str = - str - |> String.to_ascii - |> String.to_snake_case - |> avoid_keywords - |> Format.pp_print_string fmt - let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = (match StructName.path v with | [] -> () | path -> - ppclean fmt (Uid.Path.to_string path); + Uid.Path.format fmt path; Format.pp_print_char fmt '.'); - ppsnake fmt (Mark.remove (StructName.get_info v)) + assert ( + let n = Mark.remove (StructName.get_info v) in + n = String.capitalize_ascii n); + Format.pp_print_string fmt (Mark.remove (StructName.get_info v)) let format_to_module_name (fmt : Format.formatter) (name : [< `Ename of EnumName.t | `Sname of StructName.t ]) = - ppclean fmt - (match name with - | `Ename v -> EnumName.to_string v - | `Sname v -> StructName.to_string v) + match name with + | `Ename v -> EnumName.format fmt v + | `Sname v -> StructName.format fmt v let format_struct_field_name (fmt : Format.formatter) @@ -171,20 +156,16 @@ let format_struct_field_name format_to_module_name fmt (`Sname sname); Format.pp_print_char fmt '.') sname_opt; - ppclean fmt (StructField.to_string v) + StructField.format fmt v let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = - (match EnumName.path v with - | [] -> () - | path -> - ppclean fmt (Uid.Path.to_string path); - Format.pp_print_char fmt '.'); - ppsnake fmt (Mark.remove (EnumName.get_info v)) + EnumName.format fmt v let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : unit = - ppclean fmt (EnumConstructor.to_string v) + EnumConstructor.format fmt v +(* TODO: these names should be properly registered before renaming *) let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit = match Mark.remove ty with | TLit TUnit -> Format.pp_print_string fmt "embed_unit" @@ -195,16 +176,12 @@ let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit = | TLit TDate -> Format.pp_print_string fmt "embed_date" | TLit TDuration -> Format.pp_print_string fmt "embed_duration" | TStruct s_name -> - Format.fprintf fmt "%a%sembed_%a" ppclean - (Uid.Path.to_string (StructName.path s_name)) - (if StructName.path s_name = [] then "" else ".") - ppsnake + Format.fprintf fmt "%aembed_%a" Uid.Path.format (StructName.path s_name) + Format.pp_print_string (Uid.MarkedString.to_string (StructName.get_info s_name)) | TEnum e_name -> - Format.fprintf fmt "%a%sembed_%a" ppclean - (Uid.Path.to_string (EnumName.path e_name)) - (if EnumName.path e_name = [] then "" else ".") - ppsnake + Format.fprintf fmt "%aembed_%a" Uid.Path.format (EnumName.path e_name) + Format.pp_print_string (Uid.MarkedString.to_string (EnumName.get_info e_name)) | TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty | _ -> Format.pp_print_string fmt "unembeddable" @@ -243,20 +220,7 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit = | TClosureEnv -> Format.fprintf fmt "Obj.t" let format_var_str (fmt : Format.formatter) (v : string) : unit = - let lowercase_name = String.to_snake_case (String.to_ascii v) in - let lowercase_name = - Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") - ~subst:(fun _ -> "_dot_") - lowercase_name - in - let lowercase_name = String.to_ascii lowercase_name in - if - List.mem lowercase_name ["handle_default"; "handle_default_opt"] - (* O_O *) - || String.begins_with_uppercase v - then Format.pp_print_string fmt lowercase_name - else if lowercase_name = "_" then Format.pp_print_string fmt lowercase_name - else Format.fprintf fmt "%s_" lowercase_name + Format.pp_print_string fmt v let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit = format_var_str fmt (Bindlib.name_of v) @@ -561,15 +525,9 @@ let format_ctx Format.fprintf fmt "%a@\n" format_enum_decl (e, def)) (type_ordering @ scope_structs) -let rename_vars e = - Expr.( - unbox - (rename_vars ~exclude:ocaml_keywords ~reset_context_for_closed_terms:true - ~skip_constant_binders:true ~constant_binder_name:(Some "_") e)) - let format_expr ctx fmt e = Format.pp_open_vbox fmt 0; - format_expr ctx fmt (rename_vars e); + format_expr ctx fmt e; Format.pp_close_box fmt () let format_scope_body_expr @@ -594,7 +552,7 @@ let format_code_items (code_items : 'm Ast.expr code_item_list) : ('m Ast.expr Var.t * 'm Ast.expr code_item) String.Map.t = Format.pp_open_vbox fmt 0; - let var_bindings, () = + let var_bindings, _ = BoundList.fold_left ~f:(fun bnd item var -> match item with @@ -761,14 +719,9 @@ let format_module_registration Format.pp_print_newline fmt () let header = - {ocaml| -(** This file has been generated by the Catala compiler, do not edit! *) - -open Runtime_ocaml.Runtime - -[@@@ocaml.warning "-4-26-27-32-41-42"] - -|ocaml} + "(** This file has been generated by the Catala compiler, do not edit! *)\n\n\ + open Runtime_ocaml.Runtime\n\n\ + [@@@ocaml.warning \"-4-26-27-32-41-42\"]\n\n" let format_program (fmt : Format.formatter) @@ -777,6 +730,22 @@ let format_program ~(hashf : Hash.t -> Hash.full) (p : 'm Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = + let p, ren_ctx = + Program.rename_ids p + ~reserved:ocaml_keywords + (* TODO: add catala runtime built-ins as reserved as well ? *) + ~reset_context_for_closed_terms:true ~skip_constant_binders:true + ~constant_binder_name:(Some "_") + in + let type_ordering = + let open Scopelang.Dependency.TVertex in + List.map + (function + | Struct s -> Struct (Expr.Renaming.struct_name ren_ctx s) + | Enum e -> Enum (Expr.Renaming.enum_name ren_ctx e)) + type_ordering + in + (* Print.program fmt p; *) Format.pp_open_vbox fmt 0; Format.pp_print_string fmt header; check_and_reexport_used_modules fmt ~hashf diff --git a/compiler/lcalc/to_ocaml.mli b/compiler/lcalc/to_ocaml.mli index 85f49490..abee6f03 100644 --- a/compiler/lcalc/to_ocaml.mli +++ b/compiler/lcalc/to_ocaml.mli @@ -19,7 +19,6 @@ open Shared_ast (** Formats a lambda calculus program into a valid OCaml program *) -val avoid_keywords : string -> string val typ_needs_parens : typ -> bool (* val needs_parens : 'm expr -> bool *) diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index b7fb1e74..784af796 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -31,7 +31,6 @@ module To_jsoo = struct StructField.to_string v |> String.to_camel_case |> String.uncapitalize_ascii - |> avoid_keywords |> Format.pp_print_string ppf (* Supersedes [To_ocaml.format_struct_name], which can refer to enums from @@ -40,7 +39,6 @@ module To_jsoo = struct StructName.to_string name |> String.map (function '.' -> '_' | c -> c) |> String.to_snake_case - |> avoid_keywords |> Format.pp_print_string ppf (* Supersedes [To_ocaml.format_enum_name], which can refer to enums from other @@ -49,7 +47,6 @@ module To_jsoo = struct EnumName.to_string name |> String.map (function '.' -> '_' | c -> c) |> String.to_snake_case - |> avoid_keywords |> Format.pp_print_string ppf let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit = @@ -160,7 +157,6 @@ module To_jsoo = struct |> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") |> String.uncapitalize_ascii - |> avoid_keywords in if List.mem lowercase_name ["handle_default"; "handle_default_opt"] @@ -388,41 +384,43 @@ module To_jsoo = struct (_ctx : decl_ctx) (fmt : Format.formatter) (scopes : 'e code_item_list) = - BoundList.iter - ~f:(fun var code_item -> - match code_item with - | Topdef _ -> () - | ScopeDef (_name, body) -> - let fmt_fun_call fmt _ = - Format.fprintf fmt - "@[@[execute_or_throw_error@ (@[fun () ->@ %a@ \ - |> %a_of_js@ |> %a@ |> %a_to_js@])@]@]" - fmt_input_struct_name body fmt_input_struct_name body format_var - var fmt_output_struct_name body - in - Format.fprintf fmt - "@\n@\n@[let %a@ (%a : %a Js.t)@ : %a Js.t =@\n%a@]@\n" - format_var var fmt_input_struct_name body fmt_input_struct_name body - fmt_output_struct_name body fmt_fun_call ()) - scopes + ignore + @@ BoundList.iter + ~f:(fun var code_item -> + match code_item with + | Topdef _ -> () + | ScopeDef (_name, body) -> + let fmt_fun_call fmt _ = + Format.fprintf fmt + "@[@[execute_or_throw_error@ (@[fun () ->@ \ + %a@ |> %a_of_js@ |> %a@ |> %a_to_js@])@]@]" + fmt_input_struct_name body fmt_input_struct_name body + format_var var fmt_output_struct_name body + in + Format.fprintf fmt + "@\n@\n@[let %a@ (%a : %a Js.t)@ : %a Js.t =@\n%a@]@\n" + format_var var fmt_input_struct_name body fmt_input_struct_name + body fmt_output_struct_name body fmt_fun_call ()) + scopes let format_scopes_to_callbacks (_ctx : decl_ctx) (fmt : Format.formatter) (scopes : 'e code_item_list) : unit = - BoundList.iter - ~f:(fun var code_item -> - match code_item with - | Topdef _ -> () - | ScopeDef (_name, body) -> - let fmt_meth_name fmt _ = - Format.fprintf fmt "method %a : (%a Js.t -> %a Js.t) Js.callback" - format_var_camel_case var fmt_input_struct_name body - fmt_output_struct_name body - in - Format.fprintf fmt "@,@[%a =@ Js.wrap_callback@ %a@]@," - fmt_meth_name () format_var var) - scopes + ignore + @@ BoundList.iter + ~f:(fun var code_item -> + match code_item with + | Topdef _ -> () + | ScopeDef (_name, body) -> + let fmt_meth_name fmt _ = + Format.fprintf fmt "method %a : (%a Js.t -> %a Js.t) Js.callback" + format_var_camel_case var fmt_input_struct_name body + fmt_output_struct_name body + in + Format.fprintf fmt "@,@[%a =@ Js.wrap_callback@ %a@]@," + fmt_meth_name () format_var var) + scopes let format_program (fmt : Format.formatter) diff --git a/compiler/plugins/explain.ml b/compiler/plugins/explain.ml index afed754b..fa194ac5 100644 --- a/compiler/plugins/explain.ml +++ b/compiler/plugins/explain.ml @@ -436,7 +436,7 @@ let result_level base_vars = let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : ('t, 'm) gexpr * Env.t = let ctx = prg.decl_ctx in - let (all_env, scopes), () = + let (all_env, scopes), _ = BoundList.fold_left prg.code_items ~init:(Env.empty, ScopeName.Map.empty) ~f:(fun (env, scopes) item v -> match item with @@ -611,7 +611,7 @@ let program_to_graph Expr.map_marks ~f:(fun m -> Custom { pos = Expr.mark_pos m; custom = { conditions = [] } }) in - let (all_env, scopes), () = + let (all_env, scopes), _ = BoundList.fold_left prg.code_items ~init:(Env.empty, ScopeName.Map.empty) ~f:(fun (env, scopes) item v -> match item with @@ -619,7 +619,17 @@ let program_to_graph let e = Scope.to_expr ctx body in let e = customize (Expr.unbox e) in let e = Expr.remove_logging_calls (Expr.unbox e) in - let e = Expr.rename_vars (Expr.unbox e) in + let e = + Expr.Renaming.expr + (Expr.Renaming.get_ctx + { + Expr.Renaming.reserved = []; + reset_context_for_closed_terms = false; + skip_constant_binders = false; + constant_binder_name = None; + }) + (Expr.unbox e) + in ( Env.add (Var.translate v) (Expr.unbox e) env env, ScopeName.Map.add name (v, body.scope_body_input_struct) scopes ) | Topdef (_, _, e) -> diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index d187e59e..bdfc0e2b 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -35,7 +35,6 @@ module To_json = struct Format.asprintf "%a" StructField.format v |> String.to_ascii |> String.to_snake_case - |> avoid_keywords |> to_camel_case in Format.fprintf fmt "%s" s diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml index 52d549c9..5f7e69d5 100644 --- a/compiler/plugins/lazy_interp.ml +++ b/compiler/plugins/lazy_interp.ml @@ -228,7 +228,7 @@ let rec lazy_eval : let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : ('t, 'm) gexpr * 'm Env.t = let ctx = prg.decl_ctx in - let (all_env, scopes), () = + let (all_env, scopes), _ = BoundList.fold_left prg.code_items ~init:(Env.empty, ScopeName.Map.empty) ~f:(fun (env, scopes) item v -> match item with diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index a5d08406..41ced7ef 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -631,7 +631,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : (Program.modules_to_list p.decl_ctx.ctx_modules) in let ctx = { A.decl_ctx = p.decl_ctx; A.modules } in - let (_, _, rev_items), () = + let (_, _, rev_items), _vlist = BoundList.fold_left ~f:(fun (func_dict, var_dict, rev_items) code_item var -> match code_item with diff --git a/compiler/shared_ast/boundList.ml b/compiler/shared_ast/boundList.ml index ac08591f..cafaad73 100644 --- a/compiler/shared_ast/boundList.ml +++ b/compiler/shared_ast/boundList.ml @@ -21,7 +21,7 @@ type ('e, 'elt, 'last) t = ('e, 'elt, 'last) bound_list = | Cons of 'elt * ('e, ('e, 'elt, 'last) t) binder let rec to_seq = function - | Last () -> Seq.empty + | Last _ -> Seq.empty | Cons (item, next_bind) -> fun () -> let v, next = Bindlib.unbind next_bind in diff --git a/compiler/shared_ast/boundList.mli b/compiler/shared_ast/boundList.mli index 2bd1c524..bd89e119 100644 --- a/compiler/shared_ast/boundList.mli +++ b/compiler/shared_ast/boundList.mli @@ -30,7 +30,9 @@ type ('e, 'elt, 'last) t = ('e, 'elt, 'last) bound_list = | Last of 'last | Cons of 'elt * ('e, ('e, 'elt, 'last) t) binder -val to_seq : (((_, _) gexpr as 'e), 'elt, unit) t -> ('e Var.t * 'elt) Seq.t +val to_seq : (((_, _) gexpr as 'e), 'elt, _) t -> ('e Var.t * 'elt) Seq.t +(** Note that the boundlist terminator is ignored in the resulting sequence *) + val last : (_, _, 'a) t -> 'a val iter : f:('e Var.t -> 'elt -> unit) -> ('e, 'elt, 'last) t -> 'last val find : f:('elt -> 'a option) -> (_, 'elt, _) t -> 'a diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 83c00386..e02cf9d1 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -645,7 +645,12 @@ type 'e code_item = | ScopeDef of ScopeName.t * 'e scope_body | Topdef of TopdefName.t * typ * 'e -type 'e code_item_list = ('e, 'e code_item, unit) bound_list +type 'e code_item_list = ('e, 'e code_item, 'naked_e list) bound_list + constraint 'e = ('naked_e, _) Mark.ed +(* The bound_list terminator is a naked expression list that is not part of the + program: it contains the list of exported variables, so that Bindlib + correctly understands these variables as being used *) + type struct_ctx = typ StructField.Map.t StructName.Map.t type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index 9f8a0d3f..c7ce9dec 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -72,39 +72,26 @@ module Box = struct let lift_scope_vars = LiftScopeVars.lift_box - module Ren = struct - module Set = Set.Make (String) - - type ctxt = Set.t - - let skip_constant_binders = true - let reset_context_for_closed_terms = true - let constant_binder_name = None - let empty_ctxt = Set.empty - let reserve_name n s = Set.add n s - let new_name n s = n, Set.add n s - end - - module Ctx = Bindlib.Ctxt (Ren) - - let fv b = Ren.Set.elements (Ctx.free_vars b) - let assert_closed b = - match fv b with - | [] -> () - | [h] -> + if not (Bindlib.is_closed b) then + (* This is a bit convoluted, but we just want to extract the free + variables names for debug *) + let module Ctx = Bindlib.Ctxt (struct + type ctxt = String.Set.t + + let skip_constant_binders = true + let reset_context_for_closed_terms = true + let constant_binder_name = None + let empty_ctxt = String.Set.empty + let reserve_name n s = String.Set.add n s + let new_name n s = n, String.Set.add n s + end) in Message.error ~internal:true - "The boxed term is not closed the variable %s is free in the global \ - context" - h - | l -> - Message.error ~internal:true - "The boxed term is not closed the variables %a is free in the global \ - context" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.pp_print_string fmt "; ") + "The boxed term is not closed, these variables are free in it:@ \ + @[%a@]" + (Format.pp_print_list ~pp_sep:Format.pp_print_space Format.pp_print_string) - l + (String.Set.elements (Ctx.free_vars b)) end let bind vars e = Bindlib.bind_mvar vars (Box.lift e) @@ -828,84 +815,194 @@ let remove_logging_calls e = in f e -module DefaultBindlibCtxRename = struct - (* This code is a copy-paste from Bindlib, they forgot to expose the default - implementation ! *) - type ctxt = int String.Map.t +module Renaming = struct + module DefaultBindlibCtxRename : Bindlib.Renaming = struct + (* This code is a copy-paste from Bindlib, they forgot to expose the default + implementation ! *) + type ctxt = int String.Map.t - let empty_ctxt = String.Map.empty + let empty_ctxt = String.Map.empty - let split_name : string -> string * int = - fun name -> - let len = String.length name in - (* [i] is the index of the first first character of the suffix. *) - let i = - let is_digit c = '0' <= c && c <= '9' in - let first_digit = ref len in - let first_non_0 = ref len in - while !first_digit > 0 && is_digit name.[!first_digit - 1] do - decr first_digit; - if name.[!first_digit] <> '0' then first_non_0 := !first_digit - done; - !first_non_0 + let split_name : string -> string * int = + fun name -> + let len = String.length name in + (* [i] is the index of the first first character of the suffix. *) + let i = + let is_digit c = '0' <= c && c <= '9' in + let first_digit = ref len in + let first_non_0 = ref len in + while !first_digit > 0 && is_digit name.[!first_digit - 1] do + decr first_digit; + if name.[!first_digit] <> '0' then first_non_0 := !first_digit + done; + !first_non_0 + in + if i = len then name, 0 + else String.sub name 0 i, int_of_string (String.sub name i (len - i)) + + let get_suffix : string -> int -> ctxt -> int * ctxt = + fun name suffix ctxt -> + let n = + try String.Map.find name ctxt with String.Map.Not_found _ -> -1 + in + let suffix = if suffix > n then suffix else n + 1 in + suffix, String.Map.add name suffix ctxt + + let merge_name : string -> int -> string = + fun prefix suffix -> + if suffix > 0 then prefix ^ string_of_int suffix else prefix + + let new_name : string -> ctxt -> string * ctxt = + fun name ctxt -> + let prefix, suffix = split_name name in + let suffix, ctxt = get_suffix prefix suffix ctxt in + merge_name prefix suffix, ctxt + + let reserve_name : string -> ctxt -> ctxt = + fun name ctxt -> + let prefix, suffix = split_name name in + try + let n = String.Map.find prefix ctxt in + if suffix <= n then ctxt else String.Map.add prefix suffix ctxt + with String.Map.Not_found _ -> String.Map.add prefix suffix ctxt + + let reset_context_for_closed_terms = false + let skip_constant_binders = false + let constant_binder_name = None + end + + module type BindlibCtxt = module type of Bindlib.Ctxt (DefaultBindlibCtxRename) + + type config = { + reserved : string list; + reset_context_for_closed_terms : bool; + skip_constant_binders : bool; + constant_binder_name : string option; + } + + type context = { + bindCtx : (module BindlibCtxt); + bcontext : DefaultBindlibCtxRename.ctxt; + scopes : ScopeName.t -> ScopeName.t; + topdefs : TopdefName.t -> TopdefName.t; + structs : StructName.t -> StructName.t; + fields : StructField.t -> StructField.t; + enums : EnumName.t -> EnumName.t; + constrs : EnumConstructor.t -> EnumConstructor.t; + } + + let unbind_in ctx ?fname b = + let module BindCtx = (val ctx.bindCtx) in + match fname with + | Some fn -> + let name = fn (Bindlib.binder_name b) in + let v, bcontext = + BindCtx.new_var_in ctx.bcontext (fun v -> EVar v) name + in + let e = Bindlib.subst b (EVar v) in + v, e, { ctx with bcontext } + | None -> + let v, e, bcontext = BindCtx.unbind_in ctx.bcontext b in + v, e, { ctx with bcontext } + + let unmbind_in ctx ?fname b = + let module BindCtx = (val ctx.bindCtx) in + match fname with + | Some fn -> + let names = Array.map fn (Bindlib.mbinder_names b) in + let rvs, bcontext = + Array.fold_left + (fun (rvs, bcontext) n -> + let v, bcontext = BindCtx.new_var_in bcontext (fun v -> EVar v) n in + v :: rvs, bcontext) + ([], ctx.bcontext) names + in + let vs = Array.of_list (List.rev rvs) in + let e = Bindlib.msubst b (Array.map (fun v -> EVar v) vs) in + vs, e, { ctx with bcontext } + | None -> + let vs, e, bcontext = BindCtx.unmbind_in ctx.bcontext b in + vs, e, { ctx with bcontext } + + let set_rewriters ?scopes ?topdefs ?structs ?fields ?enums ?constrs ctx = + (fun ?(scopes = ctx.scopes) ?(topdefs = ctx.topdefs) + ?(structs = ctx.structs) ?(fields = ctx.fields) ?(enums = ctx.enums) + ?(constrs = ctx.constrs) () -> + { ctx with scopes; topdefs; structs; fields; enums; constrs }) + ?scopes ?topdefs ?structs ?fields ?enums ?constrs () + + let new_id ctx name = + let module BindCtx = (val ctx.bindCtx) in + let var, bcontext = + BindCtx.new_var_in ctx.bcontext (fun _ -> assert false) name in - if i = len then name, 0 - else String.sub name 0 i, int_of_string (String.sub name i (len - i)) + Bindlib.name_of var, { ctx with bcontext } - let get_suffix : string -> int -> ctxt -> int * ctxt = - fun name suffix ctxt -> - let n = try String.Map.find name ctxt with String.Map.Not_found _ -> -1 in - let suffix = if suffix > n then suffix else n + 1 in - suffix, String.Map.add name suffix ctxt + let get_ctx cfg = + let module BindCtx = Bindlib.Ctxt (struct + include DefaultBindlibCtxRename - let merge_name : string -> int -> string = - fun prefix suffix -> - if suffix > 0 then prefix ^ string_of_int suffix else prefix + let reset_context_for_closed_terms = cfg.reset_context_for_closed_terms + let skip_constant_binders = cfg.skip_constant_binders + let constant_binder_name = cfg.constant_binder_name + end) in + { + bindCtx = (module BindCtx); + bcontext = + List.fold_left + (fun ctx name -> DefaultBindlibCtxRename.reserve_name name ctx) + BindCtx.empty_ctxt cfg.reserved; + scopes = Fun.id; + topdefs = Fun.id; + structs = Fun.id; + fields = Fun.id; + enums = Fun.id; + constrs = Fun.id; + } - let new_name : string -> ctxt -> string * ctxt = - fun name ctxt -> - let prefix, suffix = split_name name in - let suffix, ctxt = get_suffix prefix suffix ctxt in - merge_name prefix suffix, ctxt + let rec typ ctx = function + | TStruct n, m -> TStruct (ctx.structs n), m + | TEnum n, m -> TEnum (ctx.enums n), m + | ty -> Type.map (typ ctx) ty - let reserve_name : string -> ctxt -> ctxt = - fun name ctxt -> - let prefix, suffix = split_name name in - try - let n = String.Map.find prefix ctxt in - if suffix <= n then ctxt else String.Map.add prefix suffix ctxt - with String.Map.Not_found _ -> String.Map.add prefix suffix ctxt -end - -let rename_vars - ?(exclude = ([] : string list)) - ?(reset_context_for_closed_terms = false) - ?(skip_constant_binders = false) - ?(constant_binder_name = None) - e = - let module BindCtx = Bindlib.Ctxt (struct - include DefaultBindlibCtxRename - - let reset_context_for_closed_terms = reset_context_for_closed_terms - let skip_constant_binders = skip_constant_binders - let constant_binder_name = constant_binder_name - end) in - let rec aux : type a. BindCtx.ctxt -> (a, 't) gexpr -> (a, 't) gexpr boxed = - fun ctx e -> - match e with + let rec expr : type k. context -> (k, 'm) gexpr -> (k, 'm) gexpr boxed = + fun ctx -> function + | EExternal { name = External_scope s, pos }, m -> + eexternal ~name:(External_scope (ctx.scopes s), pos) m + | EExternal { name = External_value d, pos }, m -> + eexternal ~name:(External_value (ctx.topdefs d), pos) m | EAbs { binder; tys }, m -> - let vars, body, ctx = BindCtx.unmbind_in ctx binder in - let body = aux ctx body in + let vars, body, ctx = unmbind_in ctx ~fname:String.to_snake_case binder in + let body = expr ctx body in let binder = bind vars body in - eabs binder tys m - | e -> map ~f:(aux ctx) ~op:Fun.id e - in - let ctx = - List.fold_left - (fun ctx name -> DefaultBindlibCtxRename.reserve_name name ctx) - BindCtx.empty_ctxt exclude - in - aux ctx e + eabs binder (List.map (typ ctx) tys) m + | EStruct { name; fields }, m -> + estruct ~name:(ctx.structs name) + ~fields: + (StructField.Map.fold + (fun fld e -> StructField.Map.add (ctx.fields fld) (expr ctx e)) + fields StructField.Map.empty) + m + | EStructAccess { name; field; e }, m -> + estructaccess ~name:(ctx.structs name) ~field:(ctx.fields field) + ~e:(expr ctx e) m + | EInj { name; e; cons }, m -> + einj ~name:(ctx.enums name) ~cons:(ctx.constrs cons) ~e:(expr ctx e) m + | EMatch { name; e; cases }, m -> + ematch ~name:(ctx.enums name) + ~cases: + (EnumConstructor.Map.fold + (fun cons e -> + EnumConstructor.Map.add (ctx.constrs cons) (expr ctx e)) + cases EnumConstructor.Map.empty) + ~e:(expr ctx e) m + | e -> map ~typ:(typ ctx) ~f:(expr ctx) ~op:Fun.id e + + let scope_name ctx s = ctx.scopes s + let topdef_name ctx s = ctx.topdefs s + let struct_name ctx s = ctx.structs s + let enum_name ctx e = ctx.enums e +end let format ppf e = Print.expr ~debug:false () ppf e diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index 4ebf78ab..1d534a7b 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -393,16 +393,52 @@ val remove_logging_calls : (** Removes all calls to [Log] unary operators in the AST, replacing them by their argument. *) -val rename_vars : - ?exclude:string list -> - ?reset_context_for_closed_terms:bool -> - ?skip_constant_binders:bool -> - ?constant_binder_name:string option -> - ('a, 'm) gexpr -> - ('a, 'm) boxed_gexpr -(** Disambiguates all variable names in [e]. [exclude] will blacklist the given - names (useful for keywords or built-in names) ; the other flags behave as - defined in the bindlib documentation for module type [Rename] *) +(** {2 Renamings and formatting} *) + +module Renaming : sig + type config = { + reserved : string list; (** Use for keywords and built-ins *) + reset_context_for_closed_terms : bool; (** See [Bindlib.Rename] *) + skip_constant_binders : bool; (** See [Bindlib.Rename] *) + constant_binder_name : string option; (** See [Bindlib.Rename] *) + } + + type context + + val get_ctx : config -> context + + val unbind_in : + context -> + ?fname:(string -> string) -> + ('e, 'b) Bindlib.binder -> + ('e, _) Mark.ed Var.t * 'b * context + (* [fname] applies a transformation on the variable name (typically something + like [String.to_snake_case]). The result is advisory and a numerical suffix + may be appended or modified *) + + val new_id : context -> string -> string * context + + val set_rewriters : + ?scopes:(ScopeName.t -> ScopeName.t) -> + ?topdefs:(TopdefName.t -> TopdefName.t) -> + ?structs:(StructName.t -> StructName.t) -> + ?fields:(StructField.t -> StructField.t) -> + ?enums:(EnumName.t -> EnumName.t) -> + ?constrs:(EnumConstructor.t -> EnumConstructor.t) -> + context -> + context + + val typ : context -> typ -> typ + + val expr : context -> ('a any, 'm) gexpr -> ('a, 'm) boxed_gexpr + (** Disambiguates all variable names in [e], and renames structs, fields, + enums and constrs according to the given context configuration *) + + val scope_name : context -> ScopeName.t -> ScopeName.t + val topdef_name : context -> TopdefName.t -> TopdefName.t + val struct_name : context -> StructName.t -> StructName.t + val enum_name : context -> EnumName.t -> EnumName.t +end val format : Format.formatter -> ('a, 'm) gexpr -> unit (** Simple printing without debug, use [Print.expr ()] instead to follow the @@ -496,9 +532,6 @@ module Box : sig 'm mark -> ('a, 'm) boxed_gexpr - val fv : 'b Bindlib.box -> string list - (** [fv] return the list of free variables from a boxed term. *) - val assert_closed : 'b Bindlib.box -> unit (** [assert_closed b] check there is no free variables in then [b] boxed term. It raises an internal error if it not the case, printing all free diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index 753a1c74..2dd81cd3 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -15,6 +15,7 @@ License for the specific language governing permissions and limitations under the License. *) +open Catala_utils open Definitions let map_decl_ctx ~f ctx = @@ -35,7 +36,7 @@ let map_exprs ?typ ~f ~varf { code_items; decl_ctx; lang; module_name } = { code_items; decl_ctx; lang; module_name }) (Scope.map_exprs ?typ ~f ~varf code_items) in - assert (Bindlib.is_closed boxed_prg); + Expr.Box.assert_closed boxed_prg; Bindlib.unbox boxed_prg let fold_left ~f ~init { code_items; _ } = @@ -46,7 +47,7 @@ let fold_exprs ~f ~init prg = Scope.fold_exprs ~f ~init prg.code_items let fold_right ~f ~init { code_items; _ } = BoundList.fold_right ~f:(fun e _ acc -> f e acc) - ~init:(fun () -> init) + ~init:(fun _vlist -> init) code_items let empty_ctx = @@ -95,3 +96,170 @@ let modules_to_list (mt : module_tree) = mtree acc in List.rev (aux [] mt) + +(* Todo? - add handling for specific naming constraints (automatically convert + to camel/snake-case, etc.) - register module names as reserved names *) +let rename_ids + ~reserved + ~reset_context_for_closed_terms + ~skip_constant_binders + ~constant_binder_name + p = + let cap s = String.to_camel_case s in + let uncap s = String.to_snake_case s in + let cfg = + { + Expr.Renaming.reserved; + reset_context_for_closed_terms; + skip_constant_binders; + constant_binder_name; + } + in + let ctx = Expr.Renaming.get_ctx cfg in + (* Each module needs its separate ctx since resolution is qualified ; and name + resolution in a given module must be processed consistently independently + on the current context. *) + let module PathMap = Map.Make (Uid.Path) in + let pctxmap = PathMap.singleton [] ctx in + let pctxmap, structs_map, fields_map, ctx_structs = + StructName.Map.fold + (fun name fields (pctxmap, structs_map, fields_map, ctx_structs) -> + let path = StructName.path name in + let str, pos = StructName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with PathMap.Not_found _ -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = Expr.Renaming.new_id ctx (cap str) in + let new_name = StructName.fresh path (id, pos) in + let ctx, fields_map, ctx_fields = + StructField.Map.fold + (fun name ty (ctx, fields_map, ctx_fields) -> + let str, pos = StructField.get_info name in + let id, ctx = Expr.Renaming.new_id ctx (uncap str) in + let new_name = StructField.fresh (id, pos) in + ( ctx, + StructField.Map.add name new_name fields_map, + StructField.Map.add new_name ty ctx_fields )) + fields + (ctx, fields_map, StructField.Map.empty) + in + ( PathMap.add path ctx pctxmap, + StructName.Map.add name new_name structs_map, + fields_map, + StructName.Map.add new_name ctx_fields ctx_structs )) + p.decl_ctx.ctx_structs + ( pctxmap, + StructName.Map.empty, + StructField.Map.empty, + StructName.Map.empty ) + in + let pctxmap, enums_map, constrs_map, ctx_enums = + EnumName.Map.fold + (fun name constrs (pctxmap, enums_map, constrs_map, ctx_enums) -> + let path = EnumName.path name in + let str, pos = EnumName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with Not_found -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = Expr.Renaming.new_id ctx (cap str) in + let new_name = EnumName.fresh path (id, pos) in + let ctx, constrs_map, ctx_constrs = + EnumConstructor.Map.fold + (fun name ty (ctx, constrs_map, ctx_constrs) -> + let str, pos = EnumConstructor.get_info name in + let id, ctx = Expr.Renaming.new_id ctx (cap str) in + let new_name = EnumConstructor.fresh (id, pos) in + ( ctx, + EnumConstructor.Map.add name new_name constrs_map, + EnumConstructor.Map.add new_name ty ctx_constrs )) + constrs + (ctx, constrs_map, EnumConstructor.Map.empty) + in + ( PathMap.add path ctx pctxmap, + EnumName.Map.add name new_name enums_map, + constrs_map, + EnumName.Map.add new_name ctx_constrs ctx_enums )) + p.decl_ctx.ctx_enums + ( pctxmap, + EnumName.Map.empty, + EnumConstructor.Map.empty, + EnumName.Map.empty ) + in + let pctxmap, scopes_map, ctx_scopes = + ScopeName.Map.fold + (fun name info (pctxmap, scopes_map, ctx_scopes) -> + let info = + { + in_struct_name = StructName.Map.find info.in_struct_name structs_map; + out_struct_name = + StructName.Map.find info.out_struct_name structs_map; + out_struct_fields = + ScopeVar.Map.map + (fun fld -> StructField.Map.find fld fields_map) + info.out_struct_fields; + } + in + let path = ScopeName.path name in + if path = [] then + (* Scopes / topdefs in the root module will be renamed through the + variables binding them in the code_items *) + ( pctxmap, + ScopeName.Map.add name name scopes_map, + ScopeName.Map.add name info ctx_scopes ) + else + let str, pos = ScopeName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with Not_found -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = Expr.Renaming.new_id ctx (uncap str) in + let new_name = ScopeName.fresh path (id, pos) in + ( PathMap.add path ctx pctxmap, + ScopeName.Map.add name new_name scopes_map, + ScopeName.Map.add new_name info ctx_scopes )) + p.decl_ctx.ctx_scopes + (pctxmap, ScopeName.Map.empty, ScopeName.Map.empty) + in + let pctxmap, topdefs_map, ctx_topdefs = + TopdefName.Map.fold + (fun name typ (pctxmap, topdefs_map, ctx_topdefs) -> + let path = TopdefName.path name in + if path = [] then + (* Topdefs / topdefs in the root module will be renamed through the + variables binding them in the code_items *) + ( pctxmap, + TopdefName.Map.add name name topdefs_map, + TopdefName.Map.add name typ ctx_topdefs ) + (* [typ] is rewritten later on *) + else + let str, pos = TopdefName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with Not_found -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = Expr.Renaming.new_id ctx (uncap str) in + let new_name = TopdefName.fresh path (id, pos) in + ( PathMap.add path ctx pctxmap, + TopdefName.Map.add name new_name topdefs_map, + TopdefName.Map.add new_name typ ctx_topdefs )) + p.decl_ctx.ctx_topdefs + (pctxmap, TopdefName.Map.empty, TopdefName.Map.empty) + in + let ctx = PathMap.find [] pctxmap in + let ctx = + Expr.Renaming.set_rewriters ctx + ~scopes:(fun n -> ScopeName.Map.find n scopes_map) + ~topdefs:(fun n -> TopdefName.Map.find n topdefs_map) + ~structs:(fun n -> StructName.Map.find n structs_map) + ~fields:(fun n -> StructField.Map.find n fields_map) + ~enums:(fun n -> EnumName.Map.find n enums_map) + ~constrs:(fun n -> EnumConstructor.Map.find n constrs_map) + in + let decl_ctx = + { p.decl_ctx with ctx_enums; ctx_structs; ctx_scopes; ctx_topdefs } + in + let decl_ctx = map_decl_ctx ~f:(Expr.Renaming.typ ctx) decl_ctx in + let code_items = Scope.rename_ids ctx p.code_items in + { p with decl_ctx; code_items }, ctx diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index 071b7873..98e669e3 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -56,3 +56,18 @@ val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body val modules_to_list : module_tree -> (ModuleName.t * module_intf_id) list (** Returns a list of used modules, in topological order ; the boolean indicates if the module is external *) + +val rename_ids : + reserved:string list -> + reset_context_for_closed_terms:bool -> + skip_constant_binders:bool -> + constant_binder_name:string option -> + ('a, 't) gexpr program -> + ('a, 't) gexpr program * Expr.Renaming.context +(** Renames all idents (variables, types, struct and enum names, fields and + constructors) to dispel ambiguities in the target language. Names in + [reserved], typically keywords and built-ins, will be avoided ; the meaning + of the flags is described in [Bindlib.Renaming]. + + In the returned program, it is safe to directly use `Bindlib.name_of` on + variables for printing. The same is true for `StructName.get_info` etc. *) diff --git a/compiler/shared_ast/scope.ml b/compiler/shared_ast/scope.ml index 59bcbcb8..bf43c64d 100644 --- a/compiler/shared_ast/scope.ml +++ b/compiler/shared_ast/scope.ml @@ -39,6 +39,12 @@ let map_exprs_in_lets : (f scope_let.scope_let_expr) )) scope_body_expr +let map_last_item ~varf last = + Bindlib.box_list + @@ List.map + (function EVar v -> Bindlib.box_var (varf v) | _ -> assert false) + last + let map_exprs ?(typ = Fun.id) ~f ~varf scopes = let f v = function | ScopeDef (name, body) -> @@ -58,7 +64,7 @@ let map_exprs ?(typ = Fun.id) ~f ~varf scopes = (fun e -> Topdef (name, typ ty, e)) (Expr.Box.lift (f expr)) ) in - BoundList.map ~f ~last:Bindlib.box scopes + BoundList.map ~f ~last:(map_last_item ~varf) scopes let fold_exprs ~f ~init scopes = let f acc def _ = @@ -116,7 +122,7 @@ let unfold (ctx : decl_ctx) (s : 'e code_item_list) (main_scope : ScopeName.t) : | None, ScopeDef (name, body) when ScopeName.equal name main_scope -> Some (Expr.make_var v (get_body_mark body)) | r, _ -> r) - ~bottom:(fun () -> function Some v -> v | None -> raise Not_found) + ~bottom:(fun _vlist -> function Some v -> v | None -> raise Not_found) ~up:(fun var item next -> let e, typ = match item with @@ -137,6 +143,64 @@ let free_vars_item = function let free_vars scopes = BoundList.fold_right scopes - ~init:(fun () -> Var.Set.empty) + ~init:(fun _vlist -> Var.Set.empty) ~f:(fun item v acc -> Var.Set.union (Var.Set.remove v acc) (free_vars_item item)) + +(** Maps carrying around a naming context, enriched at each [unbind] *) +let rec boundlist_map_ctx ~f ~fname ~last ~ctx = function + | Last l -> Bindlib.box_apply (fun l -> Last l) (last ctx l) + | Cons (item, next_bind) -> + let item = f ctx item in + let var, next, ctx = Expr.Renaming.unbind_in ctx ~fname next_bind in + let next = boundlist_map_ctx ~f ~fname ~last ~ctx next in + let next_bind = Bindlib.bind_var var next in + Bindlib.box_apply2 + (fun item next_bind -> Cons (item, next_bind)) + item next_bind + +let rename_vars_in_lets ctx scope_body_expr = + boundlist_map_ctx scope_body_expr ~ctx ~fname:String.to_snake_case + ~last:(fun ctx e -> Expr.Box.lift (Expr.Renaming.expr ctx e)) + ~f:(fun ctx scope_let -> + Bindlib.box_apply + (fun scope_let_expr -> + { + scope_let with + scope_let_expr; + scope_let_typ = Expr.Renaming.typ ctx scope_let.scope_let_typ; + }) + (Expr.Box.lift (Expr.Renaming.expr ctx scope_let.scope_let_expr))) + +let rename_ids ctx (scopes : 'e code_item_list) = + let f ctx = function + | ScopeDef (name, body) -> + let name = Expr.Renaming.scope_name ctx name in + let scope_input_var, scope_lets, ctx = + Expr.Renaming.unbind_in ctx ~fname:String.to_snake_case + body.scope_body_expr + in + let scope_lets = rename_vars_in_lets ctx scope_lets in + let scope_body_expr = Bindlib.bind_var scope_input_var scope_lets in + Bindlib.box_apply + (fun scope_body_expr -> + let body = + { + scope_body_input_struct = + Expr.Renaming.struct_name ctx body.scope_body_input_struct; + scope_body_output_struct = + Expr.Renaming.struct_name ctx body.scope_body_output_struct; + scope_body_expr; + } + in + ScopeDef (name, body)) + scope_body_expr + | Topdef (name, ty, expr) -> + Bindlib.box_apply + (fun e -> Topdef (name, Expr.Renaming.typ ctx ty, e)) + (Expr.Box.lift (Expr.Renaming.expr ctx expr)) + in + Bindlib.unbox + @@ boundlist_map_ctx ~ctx ~f ~fname:String.to_snake_case + ~last:(fun _ctx -> Bindlib.box) + scopes diff --git a/compiler/shared_ast/scope.mli b/compiler/shared_ast/scope.mli index 696e04de..d63b06d1 100644 --- a/compiler/shared_ast/scope.mli +++ b/compiler/shared_ast/scope.mli @@ -47,6 +47,14 @@ val map_exprs : (** This is the main map visitor for all the expressions inside all the scopes of the program. *) +val map_last_item : + varf:(('a, 'm) naked_gexpr Bindlib.var -> 'e2 Bindlib.var) -> + ('a, 'm) naked_gexpr list -> + 'e2 list Bindlib.box + +(** Helper function to handle the [code_item_list] terminator when manually + mapping on [code_item_list] *) + val fold_exprs : f:('acc -> 'expr -> typ -> 'acc) -> init:'acc -> 'expr code_item_list -> 'acc @@ -69,6 +77,11 @@ val input_type : typ -> Runtime.io_input Mark.pos -> typ this doesn't take thunking into account (thunking is added during the scopelang->dcalc translation) *) +val rename_ids : + Expr.Renaming.context -> + ((_ any, 'm) gexpr as 'e) code_item_list -> + 'e code_item_list + (** {2 Analysis and tests} *) val free_vars_body_expr : 'e scope_body_expr -> 'e Var.Set.t diff --git a/compiler/shared_ast/type.ml b/compiler/shared_ast/type.ml index a792e489..9cc553bf 100644 --- a/compiler/shared_ast/type.ml +++ b/compiler/shared_ast/type.ml @@ -93,6 +93,21 @@ let rec compare ty1 ty2 = | TClosureEnv, _ -> -1 | _, TClosureEnv -> 1 +let map f ty = + Mark.map + (function + | TLit l -> TLit l + | TTuple tl -> TTuple (List.map f tl) + | TStruct n -> TStruct n + | TEnum n -> TEnum n + | TOption ty -> TOption (f ty) + | TArrow (tl, ty) -> TArrow (List.map f tl, f ty) + | TArray ty -> TArray (f ty) + | TDefault ty -> TDefault (f ty) + | TAny -> TAny + | TClosureEnv -> TClosureEnv) + ty + let rec hash ~strip ty = let open Hash.Op in match Mark.remove ty with diff --git a/compiler/shared_ast/type.mli b/compiler/shared_ast/type.mli index 5d026da7..fa8c29c1 100644 --- a/compiler/shared_ast/type.mli +++ b/compiler/shared_ast/type.mli @@ -26,6 +26,9 @@ val equal : t -> t -> bool val equal_list : t list -> t list -> bool val compare : t -> t -> int +val map : (t -> t) -> t -> t +(** Shallow mapping on types *) + val hash : strip:Uid.Path.t -> t -> Hash.t (** The [strip] argument strips the given leading path components in included identifiers before hashing *) diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index 5b0c02a3..cef73e1f 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -1040,7 +1040,7 @@ let scope_body ctx env body = let scopes ctx env = BoundList.fold_map ~init:env - ~last:(fun ctx () -> ctx, Bindlib.box ()) + ~last:(fun ctx el -> ctx, Scope.map_last_item ~varf:Var.translate el) ~f:(fun env var item -> match item with | A.ScopeDef (name, body) -> diff --git a/compiler/verification/conditions.ml b/compiler/verification/conditions.ml index 5c121a1c..5e261e50 100644 --- a/compiler/verification/conditions.ml +++ b/compiler/verification/conditions.ml @@ -382,7 +382,7 @@ let generate_verification_conditions_code_items (decl_ctx : decl_ctx) (code_items : 'm expr code_item_list) (s : ScopeName.t option) : verification_condition list = - let conditions, () = + let conditions, _ = BoundList.fold_left ~f:(fun vcs item _ -> match item with diff --git a/tests/modules/good/output/mod_def.ml b/tests/modules/good/output/mod_def.ml index 443b777d..1e95d261 100644 --- a/tests/modules/good/output/mod_def.ml +++ b/tests/modules/good/output/mod_def.ml @@ -1,4 +1,3 @@ - (** This file has been generated by the Catala compiler, do not edit! *) open Runtime_ocaml.Runtime @@ -20,13 +19,13 @@ module Str1 = struct type t = {fld1: Enum1.t; fld2: integer} end -module S_in = struct +module SIn = struct type t = unit end -let s (s_in: S_in.t) : S.t = - let sr_: money = +let s (s_in: SIn.t) : S.t = + let sr1: money = match (match (handle_exceptions @@ -45,19 +44,19 @@ let s (s_in: S_in.t) : S.t = ( if true then (Eoption.ESome (money_of_cents_string "100000")) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_))|])) + | Eoption.ESome x -> (Eoption.ESome x))|])) with | Eoption.ENone _ -> ( if false then (Eoption.ENone ()) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_)) + | Eoption.ESome x -> (Eoption.ESome x)) with | Eoption.ENone _ -> (raise (Runtime_ocaml.Runtime.Error (NoValue, [{filename="tests/modules/good/mod_def.catala_en"; start_line=16; start_column=10; end_line=16; end_column=12; law_headings=["Test modules + inclusions 1"]}]))) - | Eoption.ESome arg_ -> arg_ in - let e1_: Enum1.t = + | Eoption.ESome arg -> arg in + let e2: Enum1.t = match (match (handle_exceptions @@ -75,34 +74,34 @@ let s (s_in: S_in.t) : S.t = | Eoption.ENone _ -> ( if true then (Eoption.ESome (Enum1.Maybe ())) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_))|])) + | Eoption.ESome x -> (Eoption.ESome x))|])) with | Eoption.ENone _ -> ( if false then (Eoption.ENone ()) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_)) + | Eoption.ESome x -> (Eoption.ESome x)) with | Eoption.ENone _ -> (raise (Runtime_ocaml.Runtime.Error (NoValue, [{filename="tests/modules/good/mod_def.catala_en"; start_line=17; start_column=10; end_line=17; end_column=12; law_headings=["Test modules + inclusions 1"]}]))) - | Eoption.ESome arg_ -> arg_ in - {S.sr = sr_; S.e1 = e1_} + | Eoption.ESome arg -> arg in + {S.sr = sr1; S.e1 = e2} -let half_ : integer -> decimal = - fun (x_: integer) -> +let half : integer -> decimal = + fun (x: integer) -> o_div_int_int {filename="tests/modules/good/mod_def.catala_en"; start_line=21; start_column=14; end_line=21; end_column=15; - law_headings=["Test modules + inclusions 1"]} x_ (integer_of_string + law_headings=["Test modules + inclusions 1"]} x (integer_of_string "2") -let maybe_ : Enum1.t -> Enum1.t = - fun (_: Enum1.t) -> Enum1.Maybe () +let maybe : Enum1.t -> Enum1.t = + fun (x: Enum1.t) -> Enum1.Maybe () let () = Runtime_ocaml.Runtime.register_module "Mod_def" [ "S", Obj.repr s; - "half", Obj.repr half_; - "maybe", Obj.repr maybe_ ] + "half", Obj.repr half; + "maybe", Obj.repr maybe ] "CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX" diff --git a/tests/name_resolution/good/let_in2.catala_en b/tests/name_resolution/good/let_in2.catala_en index 09eca7af..2d20264d 100644 --- a/tests/name_resolution/good/let_in2.catala_en +++ b/tests/name_resolution/good/let_in2.catala_en @@ -34,7 +34,6 @@ $ catala test-scope S ```catala-test-inline $ catala ocaml - (** This file has been generated by the Catala compiler, do not edit! *) open Runtime_ocaml.Runtime @@ -46,20 +45,20 @@ module S = struct type t = {a: bool} end -module S_in = struct +module SIn = struct type t = {a_in: unit -> (bool) Eoption.t} end -let s (s_in: S_in.t) : S.t = - let a_: unit -> (bool) Eoption.t = s_in.S_in.a_in in - let a_: bool = +let s (s_in: SIn.t) : S.t = + let a1: unit -> (bool) Eoption.t = s_in.SIn.a_in in + let a2: bool = match (match (handle_exceptions [|{filename="tests/name_resolution/good/let_in2.catala_en"; start_line=7; start_column=18; end_line=7; end_column=19; - law_headings=["Article"]}|] ([|(a_ ())|])) + law_headings=["Article"]}|] ([|(a1 ())|])) with | Eoption.ENone _ -> ( if true then @@ -78,36 +77,36 @@ let s (s_in: S_in.t) : S.t = end_line=13; end_column=6; law_headings=["Article"]}|] ([||])) with - | Eoption.ENone _ -> + | Eoption.ENone _1 -> ( if true then - (Eoption.ESome (let a_ : bool = false + (Eoption.ESome (let a2 : bool = false in - (let a_ : bool = (o_or a_ true) + (let a3 : bool = (o_or a2 true) in - a_))) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_))|])) + a3))) else (Eoption.ENone ())) + | Eoption.ESome x -> (Eoption.ESome x))|])) with - | Eoption.ENone _ -> + | Eoption.ENone _1 -> ( if false then (Eoption.ENone ()) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_)) + | Eoption.ESome x -> (Eoption.ESome x)) with - | Eoption.ENone _ -> (raise + | Eoption.ENone _1 -> (raise (Runtime_ocaml.Runtime.Error (NoValue, [{filename="tests/name_resolution/good/let_in2.catala_en"; start_line=7; start_column=18; end_line=7; end_column=19; law_headings= ["Article"]}]))) - | Eoption.ESome arg_ -> arg_)) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_)) + | Eoption.ESome arg -> arg)) else (Eoption.ENone ())) + | Eoption.ESome x -> (Eoption.ESome x)) with | Eoption.ENone _ -> (raise (Runtime_ocaml.Runtime.Error (NoValue, [{filename="tests/name_resolution/good/let_in2.catala_en"; start_line=7; start_column=18; end_line=7; end_column=19; law_headings=["Article"]}]))) - | Eoption.ESome arg_ -> arg_ in - {S.a = a_} + | Eoption.ESome arg -> arg in + {S.a = a2} let () = Runtime_ocaml.Runtime.register_module "Let_in2" diff --git a/tests/scope/good/191_fix_record_name_confusion.catala_en b/tests/scope/good/191_fix_record_name_confusion.catala_en index 7a4a8450..f0ce665c 100644 --- a/tests/scope/good/191_fix_record_name_confusion.catala_en +++ b/tests/scope/good/191_fix_record_name_confusion.catala_en @@ -29,7 +29,6 @@ $ catala Typecheck --check-invariants ```catala-test-inline $ catala OCaml -O - (** This file has been generated by the Catala compiler, do not edit! *) open Runtime_ocaml.Runtime @@ -42,26 +41,26 @@ module ScopeA = struct end module ScopeB = struct - type t = {a: bool} + type t = {a1: bool} end -module ScopeA_in = struct +module ScopeAIn = struct type t = unit end -module ScopeB_in = struct +module ScopeBIn = struct type t = unit end -let scope_a (scope_a_in: ScopeA_in.t) : ScopeA.t = - let a_: bool = true in - {ScopeA.a = a_} +let scope_a (scope_a_in: ScopeAIn.t) : ScopeA.t = + let a2: bool = true in + {ScopeA.a = a2} -let scope_b (scope_b_in: ScopeB_in.t) : ScopeB.t = - let scope_a_: ScopeA.t = {ScopeA.a = ((scope_a (())).ScopeA.a)} in - let a_: bool = scope_a_.ScopeA.a in - {ScopeB.a = a_} +let scope_b (scope_b_in: ScopeBIn.t) : ScopeB.t = + let scope_a1: ScopeA.t = {ScopeA.a = ((scope_a (())).ScopeA.a)} in + let a2: bool = scope_a1.ScopeA.a in + {ScopeB.a1 = a2} let entry_scopes = [