diff --git a/compiler/catala_utils/uid.mli b/compiler/catala_utils/uid.mli index a51e6cf4..d377fc69 100644 --- a/compiler/catala_utils/uid.mli +++ b/compiler/catala_utils/uid.mli @@ -103,6 +103,7 @@ module Gen_qualified (_ : Style) () : sig val fresh : Path.t -> MarkedString.info -> t val path : t -> Path.t val get_info : t -> MarkedString.info + val hash : strip:Path.t -> t -> Hash.t - (* [strip] strips that prefix from the start of the path before hashing *) + (** [strip] strips that prefix from the start of the path before hashing *) end diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index 8761ca4a..53bdf73b 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -941,8 +941,11 @@ let translate_program (prgm : 'm S.program) : 'm Ast.program = (* the resulting expression is the list of definitions of all the scopes, ending with the top-level scope. The decl_ctx is filled in left-to-right order, then the chained scopes aggregated from the right. *) - let rec translate_defs = function - | [] -> Bindlib.box (Last ()) + let rec translate_defs vlist = function + | [] -> + Bindlib.box_apply + (fun vl -> Last vl) + (Bindlib.box_rev_list (List.map Bindlib.box_var vlist)) | def :: next -> let dvar, def = match def with @@ -971,13 +974,13 @@ let translate_program (prgm : 'm S.program) : 'm Ast.program = (fun body -> ScopeDef (scope_name, body)) scope_body ) in - let scope_next = translate_defs next in + let scope_next = translate_defs (dvar :: vlist) next in let next_bind = Bindlib.bind_var dvar scope_next in Bindlib.box_apply2 (fun item next_bind -> Cons (item, next_bind)) def next_bind in - let items = translate_defs defs_ordering in + let items = translate_defs [] defs_ordering in Expr.Box.assert_closed items; { code_items = Bindlib.unbox items; diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 961861f9..277ea829 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -440,7 +440,7 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = Bindlib.box_apply (fun e -> Topdef (name, (TAny, Mark.get ty), e)) (Expr.Box.lift new_expr) )) - ~last:(fun _ () -> (), Bindlib.box ()) + ~last:(fun _ vlist -> (), Scope.map_last_item ~varf:Fun.id vlist) ~init:Var.Map.empty p.code_items in (* Now we need to further tweak [decl_ctx] because some of the user-defined @@ -612,7 +612,8 @@ let rec hoist_closures_code_item_list (code_items : (lcalc, 'm) gexpr code_item_list) : (lcalc, 'm) gexpr code_item_list Bindlib.box = match code_items with - | Last () -> Bindlib.box (Last ()) + | Last vlist -> + Bindlib.box_apply (fun l -> Last l) (Scope.map_last_item ~varf:Fun.id vlist) | Cons (code_item, next_code_items) -> let code_item_var, next_code_items = Bindlib.unbind next_code_items in let hoisted_closures, new_code_item = diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index fe6dd665..a98bbf19 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -130,38 +130,23 @@ let ocaml_keywords = "Oper"; ] -let ocaml_keywords_set = String.Set.of_list ocaml_keywords - -let avoid_keywords (s : string) : string = - if String.Set.mem s ocaml_keywords_set then s ^ "_user" else s -(* Fixme: this could cause clashes if the user program contains both e.g. [new] - and [new_user] *) - -let ppclean fmt str = - str |> String.to_ascii |> avoid_keywords |> Format.pp_print_string fmt - -let ppsnake fmt str = - str - |> String.to_ascii - |> String.to_snake_case - |> avoid_keywords - |> Format.pp_print_string fmt - let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = (match StructName.path v with | [] -> () | path -> - ppclean fmt (Uid.Path.to_string path); + Uid.Path.format fmt path; Format.pp_print_char fmt '.'); - ppsnake fmt (Mark.remove (StructName.get_info v)) + assert ( + let n = Mark.remove (StructName.get_info v) in + n = String.capitalize_ascii n); + Format.pp_print_string fmt (Mark.remove (StructName.get_info v)) let format_to_module_name (fmt : Format.formatter) (name : [< `Ename of EnumName.t | `Sname of StructName.t ]) = - ppclean fmt - (match name with - | `Ename v -> EnumName.to_string v - | `Sname v -> StructName.to_string v) + match name with + | `Ename v -> EnumName.format fmt v + | `Sname v -> StructName.format fmt v let format_struct_field_name (fmt : Format.formatter) @@ -171,20 +156,16 @@ let format_struct_field_name format_to_module_name fmt (`Sname sname); Format.pp_print_char fmt '.') sname_opt; - ppclean fmt (StructField.to_string v) + StructField.format fmt v let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = - (match EnumName.path v with - | [] -> () - | path -> - ppclean fmt (Uid.Path.to_string path); - Format.pp_print_char fmt '.'); - ppsnake fmt (Mark.remove (EnumName.get_info v)) + EnumName.format fmt v let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : unit = - ppclean fmt (EnumConstructor.to_string v) + EnumConstructor.format fmt v +(* TODO: these names should be properly registered before renaming *) let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit = match Mark.remove ty with | TLit TUnit -> Format.pp_print_string fmt "embed_unit" @@ -195,16 +176,12 @@ let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit = | TLit TDate -> Format.pp_print_string fmt "embed_date" | TLit TDuration -> Format.pp_print_string fmt "embed_duration" | TStruct s_name -> - Format.fprintf fmt "%a%sembed_%a" ppclean - (Uid.Path.to_string (StructName.path s_name)) - (if StructName.path s_name = [] then "" else ".") - ppsnake + Format.fprintf fmt "%aembed_%a" Uid.Path.format (StructName.path s_name) + Format.pp_print_string (Uid.MarkedString.to_string (StructName.get_info s_name)) | TEnum e_name -> - Format.fprintf fmt "%a%sembed_%a" ppclean - (Uid.Path.to_string (EnumName.path e_name)) - (if EnumName.path e_name = [] then "" else ".") - ppsnake + Format.fprintf fmt "%aembed_%a" Uid.Path.format (EnumName.path e_name) + Format.pp_print_string (Uid.MarkedString.to_string (EnumName.get_info e_name)) | TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty | _ -> Format.pp_print_string fmt "unembeddable" @@ -243,20 +220,7 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit = | TClosureEnv -> Format.fprintf fmt "Obj.t" let format_var_str (fmt : Format.formatter) (v : string) : unit = - let lowercase_name = String.to_snake_case (String.to_ascii v) in - let lowercase_name = - Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") - ~subst:(fun _ -> "_dot_") - lowercase_name - in - let lowercase_name = String.to_ascii lowercase_name in - if - List.mem lowercase_name ["handle_default"; "handle_default_opt"] - (* O_O *) - || String.begins_with_uppercase v - then Format.pp_print_string fmt lowercase_name - else if lowercase_name = "_" then Format.pp_print_string fmt lowercase_name - else Format.fprintf fmt "%s_" lowercase_name + Format.pp_print_string fmt v let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit = format_var_str fmt (Bindlib.name_of v) @@ -561,15 +525,9 @@ let format_ctx Format.fprintf fmt "%a@\n" format_enum_decl (e, def)) (type_ordering @ scope_structs) -let rename_vars e = - Expr.( - unbox - (rename_vars ~exclude:ocaml_keywords ~reset_context_for_closed_terms:true - ~skip_constant_binders:true ~constant_binder_name:(Some "_") e)) - let format_expr ctx fmt e = Format.pp_open_vbox fmt 0; - format_expr ctx fmt (rename_vars e); + format_expr ctx fmt e; Format.pp_close_box fmt () let format_scope_body_expr @@ -594,7 +552,7 @@ let format_code_items (code_items : 'm Ast.expr code_item_list) : ('m Ast.expr Var.t * 'm Ast.expr code_item) String.Map.t = Format.pp_open_vbox fmt 0; - let var_bindings, () = + let var_bindings, _ = BoundList.fold_left ~f:(fun bnd item var -> match item with @@ -761,14 +719,9 @@ let format_module_registration Format.pp_print_newline fmt () let header = - {ocaml| -(** This file has been generated by the Catala compiler, do not edit! *) - -open Runtime_ocaml.Runtime - -[@@@ocaml.warning "-4-26-27-32-41-42"] - -|ocaml} + "(** This file has been generated by the Catala compiler, do not edit! *)\n\n\ + open Runtime_ocaml.Runtime\n\n\ + [@@@ocaml.warning \"-4-26-27-32-41-42\"]\n\n" let format_program (fmt : Format.formatter) @@ -777,6 +730,22 @@ let format_program ~(hashf : Hash.t -> Hash.full) (p : 'm Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = + let p, ren_ctx = + Program.rename_ids p + ~reserved:ocaml_keywords + (* TODO: add catala runtime built-ins as reserved as well ? *) + ~reset_context_for_closed_terms:true ~skip_constant_binders:true + ~constant_binder_name:(Some "_") + in + let type_ordering = + let open Scopelang.Dependency.TVertex in + List.map + (function + | Struct s -> Struct (Expr.Renaming.struct_name ren_ctx s) + | Enum e -> Enum (Expr.Renaming.enum_name ren_ctx e)) + type_ordering + in + (* Print.program fmt p; *) Format.pp_open_vbox fmt 0; Format.pp_print_string fmt header; check_and_reexport_used_modules fmt ~hashf diff --git a/compiler/lcalc/to_ocaml.mli b/compiler/lcalc/to_ocaml.mli index 85f49490..abee6f03 100644 --- a/compiler/lcalc/to_ocaml.mli +++ b/compiler/lcalc/to_ocaml.mli @@ -19,7 +19,6 @@ open Shared_ast (** Formats a lambda calculus program into a valid OCaml program *) -val avoid_keywords : string -> string val typ_needs_parens : typ -> bool (* val needs_parens : 'm expr -> bool *) diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index b7fb1e74..784af796 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -31,7 +31,6 @@ module To_jsoo = struct StructField.to_string v |> String.to_camel_case |> String.uncapitalize_ascii - |> avoid_keywords |> Format.pp_print_string ppf (* Supersedes [To_ocaml.format_struct_name], which can refer to enums from @@ -40,7 +39,6 @@ module To_jsoo = struct StructName.to_string name |> String.map (function '.' -> '_' | c -> c) |> String.to_snake_case - |> avoid_keywords |> Format.pp_print_string ppf (* Supersedes [To_ocaml.format_enum_name], which can refer to enums from other @@ -49,7 +47,6 @@ module To_jsoo = struct EnumName.to_string name |> String.map (function '.' -> '_' | c -> c) |> String.to_snake_case - |> avoid_keywords |> Format.pp_print_string ppf let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit = @@ -160,7 +157,6 @@ module To_jsoo = struct |> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") |> String.uncapitalize_ascii - |> avoid_keywords in if List.mem lowercase_name ["handle_default"; "handle_default_opt"] @@ -388,41 +384,43 @@ module To_jsoo = struct (_ctx : decl_ctx) (fmt : Format.formatter) (scopes : 'e code_item_list) = - BoundList.iter - ~f:(fun var code_item -> - match code_item with - | Topdef _ -> () - | ScopeDef (_name, body) -> - let fmt_fun_call fmt _ = - Format.fprintf fmt - "@[@[execute_or_throw_error@ (@[fun () ->@ %a@ \ - |> %a_of_js@ |> %a@ |> %a_to_js@])@]@]" - fmt_input_struct_name body fmt_input_struct_name body format_var - var fmt_output_struct_name body - in - Format.fprintf fmt - "@\n@\n@[let %a@ (%a : %a Js.t)@ : %a Js.t =@\n%a@]@\n" - format_var var fmt_input_struct_name body fmt_input_struct_name body - fmt_output_struct_name body fmt_fun_call ()) - scopes + ignore + @@ BoundList.iter + ~f:(fun var code_item -> + match code_item with + | Topdef _ -> () + | ScopeDef (_name, body) -> + let fmt_fun_call fmt _ = + Format.fprintf fmt + "@[@[execute_or_throw_error@ (@[fun () ->@ \ + %a@ |> %a_of_js@ |> %a@ |> %a_to_js@])@]@]" + fmt_input_struct_name body fmt_input_struct_name body + format_var var fmt_output_struct_name body + in + Format.fprintf fmt + "@\n@\n@[let %a@ (%a : %a Js.t)@ : %a Js.t =@\n%a@]@\n" + format_var var fmt_input_struct_name body fmt_input_struct_name + body fmt_output_struct_name body fmt_fun_call ()) + scopes let format_scopes_to_callbacks (_ctx : decl_ctx) (fmt : Format.formatter) (scopes : 'e code_item_list) : unit = - BoundList.iter - ~f:(fun var code_item -> - match code_item with - | Topdef _ -> () - | ScopeDef (_name, body) -> - let fmt_meth_name fmt _ = - Format.fprintf fmt "method %a : (%a Js.t -> %a Js.t) Js.callback" - format_var_camel_case var fmt_input_struct_name body - fmt_output_struct_name body - in - Format.fprintf fmt "@,@[%a =@ Js.wrap_callback@ %a@]@," - fmt_meth_name () format_var var) - scopes + ignore + @@ BoundList.iter + ~f:(fun var code_item -> + match code_item with + | Topdef _ -> () + | ScopeDef (_name, body) -> + let fmt_meth_name fmt _ = + Format.fprintf fmt "method %a : (%a Js.t -> %a Js.t) Js.callback" + format_var_camel_case var fmt_input_struct_name body + fmt_output_struct_name body + in + Format.fprintf fmt "@,@[%a =@ Js.wrap_callback@ %a@]@," + fmt_meth_name () format_var var) + scopes let format_program (fmt : Format.formatter) diff --git a/compiler/plugins/explain.ml b/compiler/plugins/explain.ml index afed754b..fa194ac5 100644 --- a/compiler/plugins/explain.ml +++ b/compiler/plugins/explain.ml @@ -436,7 +436,7 @@ let result_level base_vars = let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : ('t, 'm) gexpr * Env.t = let ctx = prg.decl_ctx in - let (all_env, scopes), () = + let (all_env, scopes), _ = BoundList.fold_left prg.code_items ~init:(Env.empty, ScopeName.Map.empty) ~f:(fun (env, scopes) item v -> match item with @@ -611,7 +611,7 @@ let program_to_graph Expr.map_marks ~f:(fun m -> Custom { pos = Expr.mark_pos m; custom = { conditions = [] } }) in - let (all_env, scopes), () = + let (all_env, scopes), _ = BoundList.fold_left prg.code_items ~init:(Env.empty, ScopeName.Map.empty) ~f:(fun (env, scopes) item v -> match item with @@ -619,7 +619,17 @@ let program_to_graph let e = Scope.to_expr ctx body in let e = customize (Expr.unbox e) in let e = Expr.remove_logging_calls (Expr.unbox e) in - let e = Expr.rename_vars (Expr.unbox e) in + let e = + Expr.Renaming.expr + (Expr.Renaming.get_ctx + { + Expr.Renaming.reserved = []; + reset_context_for_closed_terms = false; + skip_constant_binders = false; + constant_binder_name = None; + }) + (Expr.unbox e) + in ( Env.add (Var.translate v) (Expr.unbox e) env env, ScopeName.Map.add name (v, body.scope_body_input_struct) scopes ) | Topdef (_, _, e) -> diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index d187e59e..bdfc0e2b 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -35,7 +35,6 @@ module To_json = struct Format.asprintf "%a" StructField.format v |> String.to_ascii |> String.to_snake_case - |> avoid_keywords |> to_camel_case in Format.fprintf fmt "%s" s diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml index 52d549c9..5f7e69d5 100644 --- a/compiler/plugins/lazy_interp.ml +++ b/compiler/plugins/lazy_interp.ml @@ -228,7 +228,7 @@ let rec lazy_eval : let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : ('t, 'm) gexpr * 'm Env.t = let ctx = prg.decl_ctx in - let (all_env, scopes), () = + let (all_env, scopes), _ = BoundList.fold_left prg.code_items ~init:(Env.empty, ScopeName.Map.empty) ~f:(fun (env, scopes) item v -> match item with diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index a5d08406..41ced7ef 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -631,7 +631,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : (Program.modules_to_list p.decl_ctx.ctx_modules) in let ctx = { A.decl_ctx = p.decl_ctx; A.modules } in - let (_, _, rev_items), () = + let (_, _, rev_items), _vlist = BoundList.fold_left ~f:(fun (func_dict, var_dict, rev_items) code_item var -> match code_item with diff --git a/compiler/shared_ast/boundList.ml b/compiler/shared_ast/boundList.ml index ac08591f..cafaad73 100644 --- a/compiler/shared_ast/boundList.ml +++ b/compiler/shared_ast/boundList.ml @@ -21,7 +21,7 @@ type ('e, 'elt, 'last) t = ('e, 'elt, 'last) bound_list = | Cons of 'elt * ('e, ('e, 'elt, 'last) t) binder let rec to_seq = function - | Last () -> Seq.empty + | Last _ -> Seq.empty | Cons (item, next_bind) -> fun () -> let v, next = Bindlib.unbind next_bind in diff --git a/compiler/shared_ast/boundList.mli b/compiler/shared_ast/boundList.mli index 2bd1c524..bd89e119 100644 --- a/compiler/shared_ast/boundList.mli +++ b/compiler/shared_ast/boundList.mli @@ -30,7 +30,9 @@ type ('e, 'elt, 'last) t = ('e, 'elt, 'last) bound_list = | Last of 'last | Cons of 'elt * ('e, ('e, 'elt, 'last) t) binder -val to_seq : (((_, _) gexpr as 'e), 'elt, unit) t -> ('e Var.t * 'elt) Seq.t +val to_seq : (((_, _) gexpr as 'e), 'elt, _) t -> ('e Var.t * 'elt) Seq.t +(** Note that the boundlist terminator is ignored in the resulting sequence *) + val last : (_, _, 'a) t -> 'a val iter : f:('e Var.t -> 'elt -> unit) -> ('e, 'elt, 'last) t -> 'last val find : f:('elt -> 'a option) -> (_, 'elt, _) t -> 'a diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 83c00386..e02cf9d1 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -645,7 +645,12 @@ type 'e code_item = | ScopeDef of ScopeName.t * 'e scope_body | Topdef of TopdefName.t * typ * 'e -type 'e code_item_list = ('e, 'e code_item, unit) bound_list +type 'e code_item_list = ('e, 'e code_item, 'naked_e list) bound_list + constraint 'e = ('naked_e, _) Mark.ed +(* The bound_list terminator is a naked expression list that is not part of the + program: it contains the list of exported variables, so that Bindlib + correctly understands these variables as being used *) + type struct_ctx = typ StructField.Map.t StructName.Map.t type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index 9f8a0d3f..c7ce9dec 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -72,39 +72,26 @@ module Box = struct let lift_scope_vars = LiftScopeVars.lift_box - module Ren = struct - module Set = Set.Make (String) - - type ctxt = Set.t - - let skip_constant_binders = true - let reset_context_for_closed_terms = true - let constant_binder_name = None - let empty_ctxt = Set.empty - let reserve_name n s = Set.add n s - let new_name n s = n, Set.add n s - end - - module Ctx = Bindlib.Ctxt (Ren) - - let fv b = Ren.Set.elements (Ctx.free_vars b) - let assert_closed b = - match fv b with - | [] -> () - | [h] -> + if not (Bindlib.is_closed b) then + (* This is a bit convoluted, but we just want to extract the free + variables names for debug *) + let module Ctx = Bindlib.Ctxt (struct + type ctxt = String.Set.t + + let skip_constant_binders = true + let reset_context_for_closed_terms = true + let constant_binder_name = None + let empty_ctxt = String.Set.empty + let reserve_name n s = String.Set.add n s + let new_name n s = n, String.Set.add n s + end) in Message.error ~internal:true - "The boxed term is not closed the variable %s is free in the global \ - context" - h - | l -> - Message.error ~internal:true - "The boxed term is not closed the variables %a is free in the global \ - context" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.pp_print_string fmt "; ") + "The boxed term is not closed, these variables are free in it:@ \ + @[%a@]" + (Format.pp_print_list ~pp_sep:Format.pp_print_space Format.pp_print_string) - l + (String.Set.elements (Ctx.free_vars b)) end let bind vars e = Bindlib.bind_mvar vars (Box.lift e) @@ -828,84 +815,194 @@ let remove_logging_calls e = in f e -module DefaultBindlibCtxRename = struct - (* This code is a copy-paste from Bindlib, they forgot to expose the default - implementation ! *) - type ctxt = int String.Map.t +module Renaming = struct + module DefaultBindlibCtxRename : Bindlib.Renaming = struct + (* This code is a copy-paste from Bindlib, they forgot to expose the default + implementation ! *) + type ctxt = int String.Map.t - let empty_ctxt = String.Map.empty + let empty_ctxt = String.Map.empty - let split_name : string -> string * int = - fun name -> - let len = String.length name in - (* [i] is the index of the first first character of the suffix. *) - let i = - let is_digit c = '0' <= c && c <= '9' in - let first_digit = ref len in - let first_non_0 = ref len in - while !first_digit > 0 && is_digit name.[!first_digit - 1] do - decr first_digit; - if name.[!first_digit] <> '0' then first_non_0 := !first_digit - done; - !first_non_0 + let split_name : string -> string * int = + fun name -> + let len = String.length name in + (* [i] is the index of the first first character of the suffix. *) + let i = + let is_digit c = '0' <= c && c <= '9' in + let first_digit = ref len in + let first_non_0 = ref len in + while !first_digit > 0 && is_digit name.[!first_digit - 1] do + decr first_digit; + if name.[!first_digit] <> '0' then first_non_0 := !first_digit + done; + !first_non_0 + in + if i = len then name, 0 + else String.sub name 0 i, int_of_string (String.sub name i (len - i)) + + let get_suffix : string -> int -> ctxt -> int * ctxt = + fun name suffix ctxt -> + let n = + try String.Map.find name ctxt with String.Map.Not_found _ -> -1 + in + let suffix = if suffix > n then suffix else n + 1 in + suffix, String.Map.add name suffix ctxt + + let merge_name : string -> int -> string = + fun prefix suffix -> + if suffix > 0 then prefix ^ string_of_int suffix else prefix + + let new_name : string -> ctxt -> string * ctxt = + fun name ctxt -> + let prefix, suffix = split_name name in + let suffix, ctxt = get_suffix prefix suffix ctxt in + merge_name prefix suffix, ctxt + + let reserve_name : string -> ctxt -> ctxt = + fun name ctxt -> + let prefix, suffix = split_name name in + try + let n = String.Map.find prefix ctxt in + if suffix <= n then ctxt else String.Map.add prefix suffix ctxt + with String.Map.Not_found _ -> String.Map.add prefix suffix ctxt + + let reset_context_for_closed_terms = false + let skip_constant_binders = false + let constant_binder_name = None + end + + module type BindlibCtxt = module type of Bindlib.Ctxt (DefaultBindlibCtxRename) + + type config = { + reserved : string list; + reset_context_for_closed_terms : bool; + skip_constant_binders : bool; + constant_binder_name : string option; + } + + type context = { + bindCtx : (module BindlibCtxt); + bcontext : DefaultBindlibCtxRename.ctxt; + scopes : ScopeName.t -> ScopeName.t; + topdefs : TopdefName.t -> TopdefName.t; + structs : StructName.t -> StructName.t; + fields : StructField.t -> StructField.t; + enums : EnumName.t -> EnumName.t; + constrs : EnumConstructor.t -> EnumConstructor.t; + } + + let unbind_in ctx ?fname b = + let module BindCtx = (val ctx.bindCtx) in + match fname with + | Some fn -> + let name = fn (Bindlib.binder_name b) in + let v, bcontext = + BindCtx.new_var_in ctx.bcontext (fun v -> EVar v) name + in + let e = Bindlib.subst b (EVar v) in + v, e, { ctx with bcontext } + | None -> + let v, e, bcontext = BindCtx.unbind_in ctx.bcontext b in + v, e, { ctx with bcontext } + + let unmbind_in ctx ?fname b = + let module BindCtx = (val ctx.bindCtx) in + match fname with + | Some fn -> + let names = Array.map fn (Bindlib.mbinder_names b) in + let rvs, bcontext = + Array.fold_left + (fun (rvs, bcontext) n -> + let v, bcontext = BindCtx.new_var_in bcontext (fun v -> EVar v) n in + v :: rvs, bcontext) + ([], ctx.bcontext) names + in + let vs = Array.of_list (List.rev rvs) in + let e = Bindlib.msubst b (Array.map (fun v -> EVar v) vs) in + vs, e, { ctx with bcontext } + | None -> + let vs, e, bcontext = BindCtx.unmbind_in ctx.bcontext b in + vs, e, { ctx with bcontext } + + let set_rewriters ?scopes ?topdefs ?structs ?fields ?enums ?constrs ctx = + (fun ?(scopes = ctx.scopes) ?(topdefs = ctx.topdefs) + ?(structs = ctx.structs) ?(fields = ctx.fields) ?(enums = ctx.enums) + ?(constrs = ctx.constrs) () -> + { ctx with scopes; topdefs; structs; fields; enums; constrs }) + ?scopes ?topdefs ?structs ?fields ?enums ?constrs () + + let new_id ctx name = + let module BindCtx = (val ctx.bindCtx) in + let var, bcontext = + BindCtx.new_var_in ctx.bcontext (fun _ -> assert false) name in - if i = len then name, 0 - else String.sub name 0 i, int_of_string (String.sub name i (len - i)) + Bindlib.name_of var, { ctx with bcontext } - let get_suffix : string -> int -> ctxt -> int * ctxt = - fun name suffix ctxt -> - let n = try String.Map.find name ctxt with String.Map.Not_found _ -> -1 in - let suffix = if suffix > n then suffix else n + 1 in - suffix, String.Map.add name suffix ctxt + let get_ctx cfg = + let module BindCtx = Bindlib.Ctxt (struct + include DefaultBindlibCtxRename - let merge_name : string -> int -> string = - fun prefix suffix -> - if suffix > 0 then prefix ^ string_of_int suffix else prefix + let reset_context_for_closed_terms = cfg.reset_context_for_closed_terms + let skip_constant_binders = cfg.skip_constant_binders + let constant_binder_name = cfg.constant_binder_name + end) in + { + bindCtx = (module BindCtx); + bcontext = + List.fold_left + (fun ctx name -> DefaultBindlibCtxRename.reserve_name name ctx) + BindCtx.empty_ctxt cfg.reserved; + scopes = Fun.id; + topdefs = Fun.id; + structs = Fun.id; + fields = Fun.id; + enums = Fun.id; + constrs = Fun.id; + } - let new_name : string -> ctxt -> string * ctxt = - fun name ctxt -> - let prefix, suffix = split_name name in - let suffix, ctxt = get_suffix prefix suffix ctxt in - merge_name prefix suffix, ctxt + let rec typ ctx = function + | TStruct n, m -> TStruct (ctx.structs n), m + | TEnum n, m -> TEnum (ctx.enums n), m + | ty -> Type.map (typ ctx) ty - let reserve_name : string -> ctxt -> ctxt = - fun name ctxt -> - let prefix, suffix = split_name name in - try - let n = String.Map.find prefix ctxt in - if suffix <= n then ctxt else String.Map.add prefix suffix ctxt - with String.Map.Not_found _ -> String.Map.add prefix suffix ctxt -end - -let rename_vars - ?(exclude = ([] : string list)) - ?(reset_context_for_closed_terms = false) - ?(skip_constant_binders = false) - ?(constant_binder_name = None) - e = - let module BindCtx = Bindlib.Ctxt (struct - include DefaultBindlibCtxRename - - let reset_context_for_closed_terms = reset_context_for_closed_terms - let skip_constant_binders = skip_constant_binders - let constant_binder_name = constant_binder_name - end) in - let rec aux : type a. BindCtx.ctxt -> (a, 't) gexpr -> (a, 't) gexpr boxed = - fun ctx e -> - match e with + let rec expr : type k. context -> (k, 'm) gexpr -> (k, 'm) gexpr boxed = + fun ctx -> function + | EExternal { name = External_scope s, pos }, m -> + eexternal ~name:(External_scope (ctx.scopes s), pos) m + | EExternal { name = External_value d, pos }, m -> + eexternal ~name:(External_value (ctx.topdefs d), pos) m | EAbs { binder; tys }, m -> - let vars, body, ctx = BindCtx.unmbind_in ctx binder in - let body = aux ctx body in + let vars, body, ctx = unmbind_in ctx ~fname:String.to_snake_case binder in + let body = expr ctx body in let binder = bind vars body in - eabs binder tys m - | e -> map ~f:(aux ctx) ~op:Fun.id e - in - let ctx = - List.fold_left - (fun ctx name -> DefaultBindlibCtxRename.reserve_name name ctx) - BindCtx.empty_ctxt exclude - in - aux ctx e + eabs binder (List.map (typ ctx) tys) m + | EStruct { name; fields }, m -> + estruct ~name:(ctx.structs name) + ~fields: + (StructField.Map.fold + (fun fld e -> StructField.Map.add (ctx.fields fld) (expr ctx e)) + fields StructField.Map.empty) + m + | EStructAccess { name; field; e }, m -> + estructaccess ~name:(ctx.structs name) ~field:(ctx.fields field) + ~e:(expr ctx e) m + | EInj { name; e; cons }, m -> + einj ~name:(ctx.enums name) ~cons:(ctx.constrs cons) ~e:(expr ctx e) m + | EMatch { name; e; cases }, m -> + ematch ~name:(ctx.enums name) + ~cases: + (EnumConstructor.Map.fold + (fun cons e -> + EnumConstructor.Map.add (ctx.constrs cons) (expr ctx e)) + cases EnumConstructor.Map.empty) + ~e:(expr ctx e) m + | e -> map ~typ:(typ ctx) ~f:(expr ctx) ~op:Fun.id e + + let scope_name ctx s = ctx.scopes s + let topdef_name ctx s = ctx.topdefs s + let struct_name ctx s = ctx.structs s + let enum_name ctx e = ctx.enums e +end let format ppf e = Print.expr ~debug:false () ppf e diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index 4ebf78ab..1d534a7b 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -393,16 +393,52 @@ val remove_logging_calls : (** Removes all calls to [Log] unary operators in the AST, replacing them by their argument. *) -val rename_vars : - ?exclude:string list -> - ?reset_context_for_closed_terms:bool -> - ?skip_constant_binders:bool -> - ?constant_binder_name:string option -> - ('a, 'm) gexpr -> - ('a, 'm) boxed_gexpr -(** Disambiguates all variable names in [e]. [exclude] will blacklist the given - names (useful for keywords or built-in names) ; the other flags behave as - defined in the bindlib documentation for module type [Rename] *) +(** {2 Renamings and formatting} *) + +module Renaming : sig + type config = { + reserved : string list; (** Use for keywords and built-ins *) + reset_context_for_closed_terms : bool; (** See [Bindlib.Rename] *) + skip_constant_binders : bool; (** See [Bindlib.Rename] *) + constant_binder_name : string option; (** See [Bindlib.Rename] *) + } + + type context + + val get_ctx : config -> context + + val unbind_in : + context -> + ?fname:(string -> string) -> + ('e, 'b) Bindlib.binder -> + ('e, _) Mark.ed Var.t * 'b * context + (* [fname] applies a transformation on the variable name (typically something + like [String.to_snake_case]). The result is advisory and a numerical suffix + may be appended or modified *) + + val new_id : context -> string -> string * context + + val set_rewriters : + ?scopes:(ScopeName.t -> ScopeName.t) -> + ?topdefs:(TopdefName.t -> TopdefName.t) -> + ?structs:(StructName.t -> StructName.t) -> + ?fields:(StructField.t -> StructField.t) -> + ?enums:(EnumName.t -> EnumName.t) -> + ?constrs:(EnumConstructor.t -> EnumConstructor.t) -> + context -> + context + + val typ : context -> typ -> typ + + val expr : context -> ('a any, 'm) gexpr -> ('a, 'm) boxed_gexpr + (** Disambiguates all variable names in [e], and renames structs, fields, + enums and constrs according to the given context configuration *) + + val scope_name : context -> ScopeName.t -> ScopeName.t + val topdef_name : context -> TopdefName.t -> TopdefName.t + val struct_name : context -> StructName.t -> StructName.t + val enum_name : context -> EnumName.t -> EnumName.t +end val format : Format.formatter -> ('a, 'm) gexpr -> unit (** Simple printing without debug, use [Print.expr ()] instead to follow the @@ -496,9 +532,6 @@ module Box : sig 'm mark -> ('a, 'm) boxed_gexpr - val fv : 'b Bindlib.box -> string list - (** [fv] return the list of free variables from a boxed term. *) - val assert_closed : 'b Bindlib.box -> unit (** [assert_closed b] check there is no free variables in then [b] boxed term. It raises an internal error if it not the case, printing all free diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index 753a1c74..2dd81cd3 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -15,6 +15,7 @@ License for the specific language governing permissions and limitations under the License. *) +open Catala_utils open Definitions let map_decl_ctx ~f ctx = @@ -35,7 +36,7 @@ let map_exprs ?typ ~f ~varf { code_items; decl_ctx; lang; module_name } = { code_items; decl_ctx; lang; module_name }) (Scope.map_exprs ?typ ~f ~varf code_items) in - assert (Bindlib.is_closed boxed_prg); + Expr.Box.assert_closed boxed_prg; Bindlib.unbox boxed_prg let fold_left ~f ~init { code_items; _ } = @@ -46,7 +47,7 @@ let fold_exprs ~f ~init prg = Scope.fold_exprs ~f ~init prg.code_items let fold_right ~f ~init { code_items; _ } = BoundList.fold_right ~f:(fun e _ acc -> f e acc) - ~init:(fun () -> init) + ~init:(fun _vlist -> init) code_items let empty_ctx = @@ -95,3 +96,170 @@ let modules_to_list (mt : module_tree) = mtree acc in List.rev (aux [] mt) + +(* Todo? - add handling for specific naming constraints (automatically convert + to camel/snake-case, etc.) - register module names as reserved names *) +let rename_ids + ~reserved + ~reset_context_for_closed_terms + ~skip_constant_binders + ~constant_binder_name + p = + let cap s = String.to_camel_case s in + let uncap s = String.to_snake_case s in + let cfg = + { + Expr.Renaming.reserved; + reset_context_for_closed_terms; + skip_constant_binders; + constant_binder_name; + } + in + let ctx = Expr.Renaming.get_ctx cfg in + (* Each module needs its separate ctx since resolution is qualified ; and name + resolution in a given module must be processed consistently independently + on the current context. *) + let module PathMap = Map.Make (Uid.Path) in + let pctxmap = PathMap.singleton [] ctx in + let pctxmap, structs_map, fields_map, ctx_structs = + StructName.Map.fold + (fun name fields (pctxmap, structs_map, fields_map, ctx_structs) -> + let path = StructName.path name in + let str, pos = StructName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with PathMap.Not_found _ -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = Expr.Renaming.new_id ctx (cap str) in + let new_name = StructName.fresh path (id, pos) in + let ctx, fields_map, ctx_fields = + StructField.Map.fold + (fun name ty (ctx, fields_map, ctx_fields) -> + let str, pos = StructField.get_info name in + let id, ctx = Expr.Renaming.new_id ctx (uncap str) in + let new_name = StructField.fresh (id, pos) in + ( ctx, + StructField.Map.add name new_name fields_map, + StructField.Map.add new_name ty ctx_fields )) + fields + (ctx, fields_map, StructField.Map.empty) + in + ( PathMap.add path ctx pctxmap, + StructName.Map.add name new_name structs_map, + fields_map, + StructName.Map.add new_name ctx_fields ctx_structs )) + p.decl_ctx.ctx_structs + ( pctxmap, + StructName.Map.empty, + StructField.Map.empty, + StructName.Map.empty ) + in + let pctxmap, enums_map, constrs_map, ctx_enums = + EnumName.Map.fold + (fun name constrs (pctxmap, enums_map, constrs_map, ctx_enums) -> + let path = EnumName.path name in + let str, pos = EnumName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with Not_found -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = Expr.Renaming.new_id ctx (cap str) in + let new_name = EnumName.fresh path (id, pos) in + let ctx, constrs_map, ctx_constrs = + EnumConstructor.Map.fold + (fun name ty (ctx, constrs_map, ctx_constrs) -> + let str, pos = EnumConstructor.get_info name in + let id, ctx = Expr.Renaming.new_id ctx (cap str) in + let new_name = EnumConstructor.fresh (id, pos) in + ( ctx, + EnumConstructor.Map.add name new_name constrs_map, + EnumConstructor.Map.add new_name ty ctx_constrs )) + constrs + (ctx, constrs_map, EnumConstructor.Map.empty) + in + ( PathMap.add path ctx pctxmap, + EnumName.Map.add name new_name enums_map, + constrs_map, + EnumName.Map.add new_name ctx_constrs ctx_enums )) + p.decl_ctx.ctx_enums + ( pctxmap, + EnumName.Map.empty, + EnumConstructor.Map.empty, + EnumName.Map.empty ) + in + let pctxmap, scopes_map, ctx_scopes = + ScopeName.Map.fold + (fun name info (pctxmap, scopes_map, ctx_scopes) -> + let info = + { + in_struct_name = StructName.Map.find info.in_struct_name structs_map; + out_struct_name = + StructName.Map.find info.out_struct_name structs_map; + out_struct_fields = + ScopeVar.Map.map + (fun fld -> StructField.Map.find fld fields_map) + info.out_struct_fields; + } + in + let path = ScopeName.path name in + if path = [] then + (* Scopes / topdefs in the root module will be renamed through the + variables binding them in the code_items *) + ( pctxmap, + ScopeName.Map.add name name scopes_map, + ScopeName.Map.add name info ctx_scopes ) + else + let str, pos = ScopeName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with Not_found -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = Expr.Renaming.new_id ctx (uncap str) in + let new_name = ScopeName.fresh path (id, pos) in + ( PathMap.add path ctx pctxmap, + ScopeName.Map.add name new_name scopes_map, + ScopeName.Map.add new_name info ctx_scopes )) + p.decl_ctx.ctx_scopes + (pctxmap, ScopeName.Map.empty, ScopeName.Map.empty) + in + let pctxmap, topdefs_map, ctx_topdefs = + TopdefName.Map.fold + (fun name typ (pctxmap, topdefs_map, ctx_topdefs) -> + let path = TopdefName.path name in + if path = [] then + (* Topdefs / topdefs in the root module will be renamed through the + variables binding them in the code_items *) + ( pctxmap, + TopdefName.Map.add name name topdefs_map, + TopdefName.Map.add name typ ctx_topdefs ) + (* [typ] is rewritten later on *) + else + let str, pos = TopdefName.get_info name in + let pctxmap, ctx = + try pctxmap, PathMap.find path pctxmap + with Not_found -> PathMap.add path ctx pctxmap, ctx + in + let id, ctx = Expr.Renaming.new_id ctx (uncap str) in + let new_name = TopdefName.fresh path (id, pos) in + ( PathMap.add path ctx pctxmap, + TopdefName.Map.add name new_name topdefs_map, + TopdefName.Map.add new_name typ ctx_topdefs )) + p.decl_ctx.ctx_topdefs + (pctxmap, TopdefName.Map.empty, TopdefName.Map.empty) + in + let ctx = PathMap.find [] pctxmap in + let ctx = + Expr.Renaming.set_rewriters ctx + ~scopes:(fun n -> ScopeName.Map.find n scopes_map) + ~topdefs:(fun n -> TopdefName.Map.find n topdefs_map) + ~structs:(fun n -> StructName.Map.find n structs_map) + ~fields:(fun n -> StructField.Map.find n fields_map) + ~enums:(fun n -> EnumName.Map.find n enums_map) + ~constrs:(fun n -> EnumConstructor.Map.find n constrs_map) + in + let decl_ctx = + { p.decl_ctx with ctx_enums; ctx_structs; ctx_scopes; ctx_topdefs } + in + let decl_ctx = map_decl_ctx ~f:(Expr.Renaming.typ ctx) decl_ctx in + let code_items = Scope.rename_ids ctx p.code_items in + { p with decl_ctx; code_items }, ctx diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index 071b7873..98e669e3 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -56,3 +56,18 @@ val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body val modules_to_list : module_tree -> (ModuleName.t * module_intf_id) list (** Returns a list of used modules, in topological order ; the boolean indicates if the module is external *) + +val rename_ids : + reserved:string list -> + reset_context_for_closed_terms:bool -> + skip_constant_binders:bool -> + constant_binder_name:string option -> + ('a, 't) gexpr program -> + ('a, 't) gexpr program * Expr.Renaming.context +(** Renames all idents (variables, types, struct and enum names, fields and + constructors) to dispel ambiguities in the target language. Names in + [reserved], typically keywords and built-ins, will be avoided ; the meaning + of the flags is described in [Bindlib.Renaming]. + + In the returned program, it is safe to directly use `Bindlib.name_of` on + variables for printing. The same is true for `StructName.get_info` etc. *) diff --git a/compiler/shared_ast/scope.ml b/compiler/shared_ast/scope.ml index 59bcbcb8..bf43c64d 100644 --- a/compiler/shared_ast/scope.ml +++ b/compiler/shared_ast/scope.ml @@ -39,6 +39,12 @@ let map_exprs_in_lets : (f scope_let.scope_let_expr) )) scope_body_expr +let map_last_item ~varf last = + Bindlib.box_list + @@ List.map + (function EVar v -> Bindlib.box_var (varf v) | _ -> assert false) + last + let map_exprs ?(typ = Fun.id) ~f ~varf scopes = let f v = function | ScopeDef (name, body) -> @@ -58,7 +64,7 @@ let map_exprs ?(typ = Fun.id) ~f ~varf scopes = (fun e -> Topdef (name, typ ty, e)) (Expr.Box.lift (f expr)) ) in - BoundList.map ~f ~last:Bindlib.box scopes + BoundList.map ~f ~last:(map_last_item ~varf) scopes let fold_exprs ~f ~init scopes = let f acc def _ = @@ -116,7 +122,7 @@ let unfold (ctx : decl_ctx) (s : 'e code_item_list) (main_scope : ScopeName.t) : | None, ScopeDef (name, body) when ScopeName.equal name main_scope -> Some (Expr.make_var v (get_body_mark body)) | r, _ -> r) - ~bottom:(fun () -> function Some v -> v | None -> raise Not_found) + ~bottom:(fun _vlist -> function Some v -> v | None -> raise Not_found) ~up:(fun var item next -> let e, typ = match item with @@ -137,6 +143,64 @@ let free_vars_item = function let free_vars scopes = BoundList.fold_right scopes - ~init:(fun () -> Var.Set.empty) + ~init:(fun _vlist -> Var.Set.empty) ~f:(fun item v acc -> Var.Set.union (Var.Set.remove v acc) (free_vars_item item)) + +(** Maps carrying around a naming context, enriched at each [unbind] *) +let rec boundlist_map_ctx ~f ~fname ~last ~ctx = function + | Last l -> Bindlib.box_apply (fun l -> Last l) (last ctx l) + | Cons (item, next_bind) -> + let item = f ctx item in + let var, next, ctx = Expr.Renaming.unbind_in ctx ~fname next_bind in + let next = boundlist_map_ctx ~f ~fname ~last ~ctx next in + let next_bind = Bindlib.bind_var var next in + Bindlib.box_apply2 + (fun item next_bind -> Cons (item, next_bind)) + item next_bind + +let rename_vars_in_lets ctx scope_body_expr = + boundlist_map_ctx scope_body_expr ~ctx ~fname:String.to_snake_case + ~last:(fun ctx e -> Expr.Box.lift (Expr.Renaming.expr ctx e)) + ~f:(fun ctx scope_let -> + Bindlib.box_apply + (fun scope_let_expr -> + { + scope_let with + scope_let_expr; + scope_let_typ = Expr.Renaming.typ ctx scope_let.scope_let_typ; + }) + (Expr.Box.lift (Expr.Renaming.expr ctx scope_let.scope_let_expr))) + +let rename_ids ctx (scopes : 'e code_item_list) = + let f ctx = function + | ScopeDef (name, body) -> + let name = Expr.Renaming.scope_name ctx name in + let scope_input_var, scope_lets, ctx = + Expr.Renaming.unbind_in ctx ~fname:String.to_snake_case + body.scope_body_expr + in + let scope_lets = rename_vars_in_lets ctx scope_lets in + let scope_body_expr = Bindlib.bind_var scope_input_var scope_lets in + Bindlib.box_apply + (fun scope_body_expr -> + let body = + { + scope_body_input_struct = + Expr.Renaming.struct_name ctx body.scope_body_input_struct; + scope_body_output_struct = + Expr.Renaming.struct_name ctx body.scope_body_output_struct; + scope_body_expr; + } + in + ScopeDef (name, body)) + scope_body_expr + | Topdef (name, ty, expr) -> + Bindlib.box_apply + (fun e -> Topdef (name, Expr.Renaming.typ ctx ty, e)) + (Expr.Box.lift (Expr.Renaming.expr ctx expr)) + in + Bindlib.unbox + @@ boundlist_map_ctx ~ctx ~f ~fname:String.to_snake_case + ~last:(fun _ctx -> Bindlib.box) + scopes diff --git a/compiler/shared_ast/scope.mli b/compiler/shared_ast/scope.mli index 696e04de..d63b06d1 100644 --- a/compiler/shared_ast/scope.mli +++ b/compiler/shared_ast/scope.mli @@ -47,6 +47,14 @@ val map_exprs : (** This is the main map visitor for all the expressions inside all the scopes of the program. *) +val map_last_item : + varf:(('a, 'm) naked_gexpr Bindlib.var -> 'e2 Bindlib.var) -> + ('a, 'm) naked_gexpr list -> + 'e2 list Bindlib.box + +(** Helper function to handle the [code_item_list] terminator when manually + mapping on [code_item_list] *) + val fold_exprs : f:('acc -> 'expr -> typ -> 'acc) -> init:'acc -> 'expr code_item_list -> 'acc @@ -69,6 +77,11 @@ val input_type : typ -> Runtime.io_input Mark.pos -> typ this doesn't take thunking into account (thunking is added during the scopelang->dcalc translation) *) +val rename_ids : + Expr.Renaming.context -> + ((_ any, 'm) gexpr as 'e) code_item_list -> + 'e code_item_list + (** {2 Analysis and tests} *) val free_vars_body_expr : 'e scope_body_expr -> 'e Var.Set.t diff --git a/compiler/shared_ast/type.ml b/compiler/shared_ast/type.ml index a792e489..9cc553bf 100644 --- a/compiler/shared_ast/type.ml +++ b/compiler/shared_ast/type.ml @@ -93,6 +93,21 @@ let rec compare ty1 ty2 = | TClosureEnv, _ -> -1 | _, TClosureEnv -> 1 +let map f ty = + Mark.map + (function + | TLit l -> TLit l + | TTuple tl -> TTuple (List.map f tl) + | TStruct n -> TStruct n + | TEnum n -> TEnum n + | TOption ty -> TOption (f ty) + | TArrow (tl, ty) -> TArrow (List.map f tl, f ty) + | TArray ty -> TArray (f ty) + | TDefault ty -> TDefault (f ty) + | TAny -> TAny + | TClosureEnv -> TClosureEnv) + ty + let rec hash ~strip ty = let open Hash.Op in match Mark.remove ty with diff --git a/compiler/shared_ast/type.mli b/compiler/shared_ast/type.mli index 5d026da7..fa8c29c1 100644 --- a/compiler/shared_ast/type.mli +++ b/compiler/shared_ast/type.mli @@ -26,6 +26,9 @@ val equal : t -> t -> bool val equal_list : t list -> t list -> bool val compare : t -> t -> int +val map : (t -> t) -> t -> t +(** Shallow mapping on types *) + val hash : strip:Uid.Path.t -> t -> Hash.t (** The [strip] argument strips the given leading path components in included identifiers before hashing *) diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index 5b0c02a3..cef73e1f 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -1040,7 +1040,7 @@ let scope_body ctx env body = let scopes ctx env = BoundList.fold_map ~init:env - ~last:(fun ctx () -> ctx, Bindlib.box ()) + ~last:(fun ctx el -> ctx, Scope.map_last_item ~varf:Var.translate el) ~f:(fun env var item -> match item with | A.ScopeDef (name, body) -> diff --git a/compiler/verification/conditions.ml b/compiler/verification/conditions.ml index 5c121a1c..5e261e50 100644 --- a/compiler/verification/conditions.ml +++ b/compiler/verification/conditions.ml @@ -382,7 +382,7 @@ let generate_verification_conditions_code_items (decl_ctx : decl_ctx) (code_items : 'm expr code_item_list) (s : ScopeName.t option) : verification_condition list = - let conditions, () = + let conditions, _ = BoundList.fold_left ~f:(fun vcs item _ -> match item with diff --git a/tests/modules/good/output/mod_def.ml b/tests/modules/good/output/mod_def.ml index 443b777d..1e95d261 100644 --- a/tests/modules/good/output/mod_def.ml +++ b/tests/modules/good/output/mod_def.ml @@ -1,4 +1,3 @@ - (** This file has been generated by the Catala compiler, do not edit! *) open Runtime_ocaml.Runtime @@ -20,13 +19,13 @@ module Str1 = struct type t = {fld1: Enum1.t; fld2: integer} end -module S_in = struct +module SIn = struct type t = unit end -let s (s_in: S_in.t) : S.t = - let sr_: money = +let s (s_in: SIn.t) : S.t = + let sr1: money = match (match (handle_exceptions @@ -45,19 +44,19 @@ let s (s_in: S_in.t) : S.t = ( if true then (Eoption.ESome (money_of_cents_string "100000")) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_))|])) + | Eoption.ESome x -> (Eoption.ESome x))|])) with | Eoption.ENone _ -> ( if false then (Eoption.ENone ()) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_)) + | Eoption.ESome x -> (Eoption.ESome x)) with | Eoption.ENone _ -> (raise (Runtime_ocaml.Runtime.Error (NoValue, [{filename="tests/modules/good/mod_def.catala_en"; start_line=16; start_column=10; end_line=16; end_column=12; law_headings=["Test modules + inclusions 1"]}]))) - | Eoption.ESome arg_ -> arg_ in - let e1_: Enum1.t = + | Eoption.ESome arg -> arg in + let e2: Enum1.t = match (match (handle_exceptions @@ -75,34 +74,34 @@ let s (s_in: S_in.t) : S.t = | Eoption.ENone _ -> ( if true then (Eoption.ESome (Enum1.Maybe ())) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_))|])) + | Eoption.ESome x -> (Eoption.ESome x))|])) with | Eoption.ENone _ -> ( if false then (Eoption.ENone ()) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_)) + | Eoption.ESome x -> (Eoption.ESome x)) with | Eoption.ENone _ -> (raise (Runtime_ocaml.Runtime.Error (NoValue, [{filename="tests/modules/good/mod_def.catala_en"; start_line=17; start_column=10; end_line=17; end_column=12; law_headings=["Test modules + inclusions 1"]}]))) - | Eoption.ESome arg_ -> arg_ in - {S.sr = sr_; S.e1 = e1_} + | Eoption.ESome arg -> arg in + {S.sr = sr1; S.e1 = e2} -let half_ : integer -> decimal = - fun (x_: integer) -> +let half : integer -> decimal = + fun (x: integer) -> o_div_int_int {filename="tests/modules/good/mod_def.catala_en"; start_line=21; start_column=14; end_line=21; end_column=15; - law_headings=["Test modules + inclusions 1"]} x_ (integer_of_string + law_headings=["Test modules + inclusions 1"]} x (integer_of_string "2") -let maybe_ : Enum1.t -> Enum1.t = - fun (_: Enum1.t) -> Enum1.Maybe () +let maybe : Enum1.t -> Enum1.t = + fun (x: Enum1.t) -> Enum1.Maybe () let () = Runtime_ocaml.Runtime.register_module "Mod_def" [ "S", Obj.repr s; - "half", Obj.repr half_; - "maybe", Obj.repr maybe_ ] + "half", Obj.repr half; + "maybe", Obj.repr maybe ] "CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX" diff --git a/tests/name_resolution/good/let_in2.catala_en b/tests/name_resolution/good/let_in2.catala_en index 09eca7af..2d20264d 100644 --- a/tests/name_resolution/good/let_in2.catala_en +++ b/tests/name_resolution/good/let_in2.catala_en @@ -34,7 +34,6 @@ $ catala test-scope S ```catala-test-inline $ catala ocaml - (** This file has been generated by the Catala compiler, do not edit! *) open Runtime_ocaml.Runtime @@ -46,20 +45,20 @@ module S = struct type t = {a: bool} end -module S_in = struct +module SIn = struct type t = {a_in: unit -> (bool) Eoption.t} end -let s (s_in: S_in.t) : S.t = - let a_: unit -> (bool) Eoption.t = s_in.S_in.a_in in - let a_: bool = +let s (s_in: SIn.t) : S.t = + let a1: unit -> (bool) Eoption.t = s_in.SIn.a_in in + let a2: bool = match (match (handle_exceptions [|{filename="tests/name_resolution/good/let_in2.catala_en"; start_line=7; start_column=18; end_line=7; end_column=19; - law_headings=["Article"]}|] ([|(a_ ())|])) + law_headings=["Article"]}|] ([|(a1 ())|])) with | Eoption.ENone _ -> ( if true then @@ -78,36 +77,36 @@ let s (s_in: S_in.t) : S.t = end_line=13; end_column=6; law_headings=["Article"]}|] ([||])) with - | Eoption.ENone _ -> + | Eoption.ENone _1 -> ( if true then - (Eoption.ESome (let a_ : bool = false + (Eoption.ESome (let a2 : bool = false in - (let a_ : bool = (o_or a_ true) + (let a3 : bool = (o_or a2 true) in - a_))) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_))|])) + a3))) else (Eoption.ENone ())) + | Eoption.ESome x -> (Eoption.ESome x))|])) with - | Eoption.ENone _ -> + | Eoption.ENone _1 -> ( if false then (Eoption.ENone ()) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_)) + | Eoption.ESome x -> (Eoption.ESome x)) with - | Eoption.ENone _ -> (raise + | Eoption.ENone _1 -> (raise (Runtime_ocaml.Runtime.Error (NoValue, [{filename="tests/name_resolution/good/let_in2.catala_en"; start_line=7; start_column=18; end_line=7; end_column=19; law_headings= ["Article"]}]))) - | Eoption.ESome arg_ -> arg_)) else (Eoption.ENone ())) - | Eoption.ESome x_ -> (Eoption.ESome x_)) + | Eoption.ESome arg -> arg)) else (Eoption.ENone ())) + | Eoption.ESome x -> (Eoption.ESome x)) with | Eoption.ENone _ -> (raise (Runtime_ocaml.Runtime.Error (NoValue, [{filename="tests/name_resolution/good/let_in2.catala_en"; start_line=7; start_column=18; end_line=7; end_column=19; law_headings=["Article"]}]))) - | Eoption.ESome arg_ -> arg_ in - {S.a = a_} + | Eoption.ESome arg -> arg in + {S.a = a2} let () = Runtime_ocaml.Runtime.register_module "Let_in2" diff --git a/tests/scope/good/191_fix_record_name_confusion.catala_en b/tests/scope/good/191_fix_record_name_confusion.catala_en index 7a4a8450..f0ce665c 100644 --- a/tests/scope/good/191_fix_record_name_confusion.catala_en +++ b/tests/scope/good/191_fix_record_name_confusion.catala_en @@ -29,7 +29,6 @@ $ catala Typecheck --check-invariants ```catala-test-inline $ catala OCaml -O - (** This file has been generated by the Catala compiler, do not edit! *) open Runtime_ocaml.Runtime @@ -42,26 +41,26 @@ module ScopeA = struct end module ScopeB = struct - type t = {a: bool} + type t = {a1: bool} end -module ScopeA_in = struct +module ScopeAIn = struct type t = unit end -module ScopeB_in = struct +module ScopeBIn = struct type t = unit end -let scope_a (scope_a_in: ScopeA_in.t) : ScopeA.t = - let a_: bool = true in - {ScopeA.a = a_} +let scope_a (scope_a_in: ScopeAIn.t) : ScopeA.t = + let a2: bool = true in + {ScopeA.a = a2} -let scope_b (scope_b_in: ScopeB_in.t) : ScopeB.t = - let scope_a_: ScopeA.t = {ScopeA.a = ((scope_a (())).ScopeA.a)} in - let a_: bool = scope_a_.ScopeA.a in - {ScopeB.a = a_} +let scope_b (scope_b_in: ScopeBIn.t) : ScopeB.t = + let scope_a1: ScopeA.t = {ScopeA.a = ((scope_a (())).ScopeA.a)} in + let a2: bool = scope_a1.ScopeA.a in + {ScopeB.a1 = a2} let entry_scopes = [