From f71db385d5592f4df1c3347b24182c8385c13b0d Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Fri, 5 Apr 2024 14:32:34 +0200 Subject: [PATCH] Python backend: workaround func/var name clash --- compiler/scalc/to_python.ml | 31 ++-- tests/backends/python_name_clash.catala_en | 158 +++++++++++++++++++++ 2 files changed, 176 insertions(+), 13 deletions(-) create mode 100644 tests/backends/python_name_clash.catala_en diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 8e963d3d..17eb0984 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -134,12 +134,14 @@ module IntMap = Map.Make (struct let format ppf i = Format.pp_print_int ppf i end) -let format_name_cleaned (fmt : Format.formatter) (s : string) : unit = +let clean_name (s : string) : string = s |> String.to_snake_case |> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") |> avoid_keywords - |> Format.pp_print_string fmt + +let format_name_cleaned (fmt : Format.formatter) (s : string) : unit = + Format.pp_print_string fmt (clean_name s) (** For each `VarName.t` defined by its string and then by its hash, we keep track of which local integer id we've given it. This is used to keep @@ -156,18 +158,10 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit = | Some ids -> ( match IntMap.find_opt hash ids with | None -> - let max_id = - snd - (List.hd - (List.fast_sort - (fun (_, x) (_, y) -> Int.compare y x) - (IntMap.bindings ids))) - in + let id = 1 + IntMap.fold (fun _ -> Int.max) ids 0 in string_counter_map := - StringMap.add v_str - (IntMap.add hash (max_id + 1) ids) - !string_counter_map; - max_id + 1 + StringMap.add v_str (IntMap.add hash id ids) !string_counter_map; + id | Some local_id -> local_id) | None -> string_counter_map := @@ -632,6 +626,16 @@ let format_ctx (e, EnumName.Map.find e ctx.decl_ctx.ctx_enums)) (type_ordering @ scope_structs) +(* FIXME: this is an ugly (and partial) workaround, Python basically has one + namespace and we reserve the name to avoid clashes between func ids and + variable ids. *) +let reserve_func_name = function + | SVar _ -> () + | SFunc { var = v; _ } | SScope { scope_body_var = v; _ } -> + let v_str = clean_name (Mark.remove (FuncName.get_info v)) in + string_counter_map := + StringMap.add v_str (IntMap.singleton (-1) 0) !string_counter_map + let format_code_item ctx fmt = function | SVar { var; expr; typ = _ } -> Format.fprintf fmt "@[%a = (@,%a@,@])@," format_var var @@ -651,6 +655,7 @@ let format_program (fmt : Format.formatter) (p : Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = + List.iter reserve_func_name p.code_items; Format.pp_open_vbox fmt 0; let header = [ diff --git a/tests/backends/python_name_clash.catala_en b/tests/backends/python_name_clash.catala_en new file mode 100644 index 00000000..1a5108ce --- /dev/null +++ b/tests/backends/python_name_clash.catala_en @@ -0,0 +1,158 @@ +This test exposes a name clash between the scope function (`ScopeName,` +rewritten to `scope_name`) and the scope variable `scope_name`. + +```catala +declaration scope SomeName: + input i content integer + output o content integer + +scope SomeName: + definition o equals i + 1 + +declaration scope B: + output some_name scope SomeName + +scope B: + definition some_name.i equals 1 +``` + +```catala-test-inline +$ catala python +# This file has been generated by the Catala compiler, do not edit! + +from catala.runtime import * +from typing import Any, List, Callable, Tuple +from enum import Enum + +class SomeName: + def __init__(self, o: Integer) -> None: + self.o = o + + def __eq__(self, other: object) -> bool: + if isinstance(other, SomeName): + return (self.o == other.o) + else: + return False + + def __ne__(self, other: object) -> bool: + return not (self == other) + + def __str__(self) -> str: + return "SomeName(o={})".format(self.o) + +class B: + def __init__(self, some_name: SomeName) -> None: + self.some_name = some_name + + def __eq__(self, other: object) -> bool: + if isinstance(other, B): + return (self.some_name == other.some_name) + else: + return False + + def __ne__(self, other: object) -> bool: + return not (self == other) + + def __str__(self) -> str: + return "B(some_name={})".format(self.some_name) + +class SomeNameIn: + def __init__(self, i_in: Integer) -> None: + self.i_in = i_in + + def __eq__(self, other: object) -> bool: + if isinstance(other, SomeNameIn): + return (self.i_in == other.i_in) + else: + return False + + def __ne__(self, other: object) -> bool: + return not (self == other) + + def __str__(self) -> str: + return "SomeNameIn(i_in={})".format(self.i_in) + +class BIn: + def __init__(self, ) -> None: + pass + + def __eq__(self, other: object) -> bool: + if isinstance(other, BIn): + return (True) + else: + return False + + def __ne__(self, other: object) -> bool: + return not (self == other) + + def __str__(self) -> str: + return "BIn()".format() + + +def some_name(some_name_in:SomeNameIn): + i = some_name_in.i_in + try: + def temp_o(_:Unit): + raise EmptyError + def temp_o_1(_:Unit): + return False + def temp_o_2(_:Unit): + def temp_o_3(_:Unit): + return (i + integer_of_string("1")) + def temp_o_4(_:Unit): + return True + return handle_default(SourcePosition(filename="tests/backends/python_name_clash.catala_en", + start_line=7, start_column=10, + end_line=7, end_column=11, + law_headings=[]), [], temp_o_4, temp_o_3) + temp_o_5 = handle_default(SourcePosition(filename="tests/backends/python_name_clash.catala_en", + start_line=7, start_column=10, + end_line=7, end_column=11, + law_headings=[]), [temp_o_2], temp_o_1, + temp_o) + except EmptyError: + temp_o_5 = dead_value + raise NoValueProvided(SourcePosition(filename="tests/backends/python_name_clash.catala_en", + start_line=7, start_column=10, + end_line=7, end_column=11, + law_headings=[])) + o = temp_o_5 + return SomeName(o = o) + +def b(b_in:BIn): + try: + def temp_result(_:Unit): + raise EmptyError + def temp_result_1(_:Unit): + return False + def temp_result_2(_:Unit): + def temp_result_3(_:Unit): + return integer_of_string("1") + def temp_result_4(_:Unit): + return True + return handle_default(SourcePosition(filename="tests/backends/python_name_clash.catala_en", + start_line=16, start_column=14, + end_line=16, end_column=25, + law_headings=[]), [], temp_result_4, + temp_result_3) + temp_result_5 = handle_default(SourcePosition(filename="tests/backends/python_name_clash.catala_en", + start_line=16, start_column=14, + end_line=16, end_column=25, + law_headings=[]), [temp_result_2], + temp_result_1, temp_result) + except EmptyError: + temp_result_5 = dead_value + raise NoValueProvided(SourcePosition(filename="tests/backends/python_name_clash.catala_en", + start_line=16, start_column=14, + end_line=16, end_column=25, + law_headings=[])) + result = some_name(SomeNameIn(i_in = temp_result_5)) + result_1 = SomeName(o = result.o) + if True: + temp_some_name = result_1 + else: + temp_some_name = result_1 + some_name_1 = temp_some_name + return B(some_name = some_name_1) +``` +The above should *not* show `some_name = temp_some_name`, but instead `some_name_1 = ...`