From 14a378a33d57b5a8ab1889b065f471b9ebf0d421 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Thu, 8 Aug 2024 12:03:53 +0200 Subject: [PATCH] Translation to scalc: fix renaming in blocks Statements are often flattened, in which case their idents need to be conflict-free. We pass along the renaming context to handle this. --- compiler/scalc/from_lcalc.ml | 270 +++++++++++---------- compiler/scalc/to_python.ml | 3 +- compiler/shared_ast/program.ml | 1 - tests/backends/python_name_clash.catala_en | 6 +- 4 files changed, 146 insertions(+), 134 deletions(-) diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index 997e8321..6cf25081 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -98,16 +98,16 @@ let register_fresh_arg ~pos ctxt (x, _) = ctxt let rec translate_expr_list ctxt args = - let stmts, args = + let stmts, args, ren_ctx = List.fold_left - (fun (args_stmts, new_args) arg -> - let arg_stmts, new_arg = translate_expr ctxt arg in - args_stmts ++ arg_stmts, new_arg :: new_args) - (RevBlock.empty, []) args + (fun (args_stmts, new_args, ren_ctx) arg -> + let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in + args_stmts ++ arg_stmts, new_arg :: new_args, ren_ctx) + (RevBlock.empty, [], ctxt.ren_ctx) args in - stmts, List.rev args + stmts, List.rev args, ren_ctx -and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = +and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr * Renaming.context = try match Mark.remove expr with | EVar v -> @@ -123,27 +123,27 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = Print.var_debug ppf v)) (Var.Map.keys ctxt.var_dict)) in - RevBlock.empty, (local_var, Expr.pos expr) + RevBlock.empty, (local_var, Expr.pos expr), ctxt.ren_ctx | EStruct { fields; name } -> if ctxt.config.no_struct_literals then (* In C89, struct literates have to be initialized at variable definition... *) raise (NotAnExpr { needs_a_local_decl = false }); - let args_stmts, new_args = + let args_stmts, new_args, ren_ctx = StructField.Map.fold - (fun field arg (args_stmts, new_args) -> - let arg_stmts, new_arg = translate_expr ctxt arg in - args_stmts ++ arg_stmts, StructField.Map.add field new_arg new_args) + (fun field arg (args_stmts, new_args, ren_ctx) -> + let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in + args_stmts ++ arg_stmts, StructField.Map.add field new_arg new_args, ren_ctx) fields - (RevBlock.empty, StructField.Map.empty) + (RevBlock.empty, StructField.Map.empty, ctxt.ren_ctx) in - args_stmts, (A.EStruct { fields = new_args; name }, Expr.pos expr) + args_stmts, (A.EStruct { fields = new_args; name }, Expr.pos expr), ren_ctx | EInj { e = e1; cons; name } -> if ctxt.config.no_struct_literals then (* In C89, struct literates have to be initialized at variable definition... *) raise (NotAnExpr { needs_a_local_decl = false }); - let e1_stmts, new_e1 = translate_expr ctxt e1 in + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in ( e1_stmts, ( A.EInj { @@ -152,21 +152,23 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = name; expr_typ = Expr.maybe_ty (Mark.get expr); }, - Expr.pos expr ) ) + Expr.pos expr ), + ren_ctx ) | ETuple args -> - let args_stmts, new_args = translate_expr_list ctxt args in - args_stmts, (A.ETuple new_args, Expr.pos expr) + let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in + args_stmts, (A.ETuple new_args, Expr.pos expr), ren_ctx | EStructAccess { e = e1; field; name } -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in ( e1_stmts, - (A.EStructFieldAccess { e1 = new_e1; field; name }, Expr.pos expr) ) + (A.EStructFieldAccess { e1 = new_e1; field; name }, Expr.pos expr), + ren_ctx) | ETupleAccess { e = e1; index; _ } -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in - e1_stmts, (A.ETupleAccess { e1 = new_e1; index }, Expr.pos expr) + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in + e1_stmts, (A.ETupleAccess { e1 = new_e1; index }, Expr.pos expr), ren_ctx | EAppOp { op; args; tys = _ } -> - let args_stmts, new_args = translate_expr_list ctxt args in + let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in (* FIXME: what happens if [arg] is not a tuple but reduces to one ? *) - args_stmts, (A.EAppOp { op; args = new_args }, Expr.pos expr) + args_stmts, (A.EAppOp { op; args = new_args }, Expr.pos expr), ren_ctx | EApp { f = EAbs { binder; tys }, binder_mark; args; tys = _ } -> (* This defines multiple local variables at the time *) let binder_pos = Expr.mark_pos binder_mark in @@ -190,39 +192,42 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = (Var.Map.find x ctxt.var_dict, binder_pos), tau, arg) vars_tau args in - let def_blocks = + let def_blocks, ren_ctx = List.fold_left - (fun acc (x, _tau, arg) -> + (fun (rblock, ren_ctx) (x, _tau, arg) -> let ctxt = { ctxt with inside_definition_of = Some (Mark.remove x); context_name = Mark.remove (A.VarName.get_info (Mark.remove x)); + ren_ctx; } in - let arg_stmts, new_arg = translate_expr ctxt arg in - RevBlock.append (acc ++ arg_stmts) + let arg_stmts, new_arg, ren_ctx = translate_expr ctxt arg in + RevBlock.append (rblock ++ arg_stmts) ( A.SLocalDef { name = x; expr = new_arg; typ = Expr.maybe_ty (Mark.get arg); }, - binder_pos )) - RevBlock.empty vars_args + binder_pos ), + ren_ctx) + (RevBlock.empty, ctxt.ren_ctx) vars_args in - let rest_of_expr_stmts, rest_of_expr = translate_expr ctxt body in - local_decls ++ def_blocks ++ rest_of_expr_stmts, rest_of_expr + let rest_of_expr_stmts, rest_of_expr, ren_ctx = translate_expr { ctxt with ren_ctx } body in + local_decls ++ def_blocks ++ rest_of_expr_stmts, rest_of_expr, ren_ctx | EApp { f; args; tys = _ } -> - let f_stmts, new_f = translate_expr ctxt f in - let args_stmts, new_args = translate_expr_list ctxt args in + let f_stmts, new_f, ren_ctx = translate_expr ctxt f in + let args_stmts, new_args, ren_ctx = translate_expr_list { ctxt with ren_ctx } args in (* FIXME: what happens if [arg] is not a tuple but reduces to one ? *) ( f_stmts ++ args_stmts, - (A.EApp { f = new_f; args = new_args }, Expr.pos expr) ) + (A.EApp { f = new_f; args = new_args }, Expr.pos expr), + ren_ctx ) | EArray args -> - let args_stmts, new_args = translate_expr_list ctxt args in - args_stmts, (A.EArray new_args, Expr.pos expr) - | ELit l -> RevBlock.empty, (A.ELit l, Expr.pos expr) + let args_stmts, new_args, ren_ctx = translate_expr_list ctxt args in + args_stmts, (A.EArray new_args, Expr.pos expr), ren_ctx + | ELit l -> RevBlock.empty, (A.ELit l, Expr.pos expr), ctxt.ren_ctx | EExternal { name } -> let path, name = match Mark.remove name with @@ -233,7 +238,7 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = ( ModuleName.Map.find (List.hd (List.rev path)) ctxt.program_ctx.modules, Expr.pos expr ) in - RevBlock.empty, (EExternal { modname; name }, Expr.pos expr) + RevBlock.empty, (EExternal { modname; name }, Expr.pos expr), ctxt.ren_ctx | EAbs _ | EIfThenElse _ | EMatch _ | EAssert _ | EFatalError _ -> raise (NotAnExpr { needs_a_local_decl = true }) | _ -> . @@ -253,7 +258,7 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = context_name = Mark.remove (A.VarName.get_info tmp_var); } in - let tmp_stmts = translate_statements ctxt expr in + let tmp_stmts, ren_ctx = translate_statements ctxt expr in ( (if needs_a_local_decl then RevBlock.make (( A.SLocalDecl @@ -264,17 +269,19 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = Expr.pos expr ) :: tmp_stmts) else RevBlock.make tmp_stmts), - (A.EVar tmp_var, Expr.pos expr) ) + (A.EVar tmp_var, Expr.pos expr), + ren_ctx ) -and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = +and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * Renaming.context = match Mark.remove block_expr with | EAssert e -> (* Assertions are always encapsulated in a unit-typed let binding *) - let e_stmts, new_e = translate_expr ctxt e in + let e_stmts, new_e, ren_ctx = translate_expr ctxt e in RevBlock.rebuild ~tail:[A.SAssert (Mark.remove new_e), Expr.pos block_expr] - e_stmts - | EFatalError err -> [SFatalError err, Expr.pos block_expr] + e_stmts, + ren_ctx + | EFatalError err -> [SFatalError err, Expr.pos block_expr], ctxt.ren_ctx (* | EAppOp * { op = Op.HandleDefaultOpt, _; tys = _; args = [exceptions; just; cons] } * when ctxt.config.keep_special_ops -> @@ -351,32 +358,31 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = (Var.Map.find x ctxt.var_dict, binder_pos), tau, arg) vars_tau args in - let def_blocks = - List.map - (fun (x, _tau, arg) -> + let def_blocks, ren_ctx = + List.fold_left + (fun (def_blocks, ren_ctx) (x, _tau, arg) -> let ctxt = { ctxt with inside_definition_of = Some (Mark.remove x); context_name = Mark.remove (A.VarName.get_info (Mark.remove x)); + ren_ctx; } in - let arg_stmts, new_arg = translate_expr ctxt arg in - RevBlock.rebuild arg_stmts - ~tail: - [ + let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in + RevBlock.append (def_blocks ++ arg_stmts) ( A.SLocalDef { name = x; expr = new_arg; typ = Expr.maybe_ty (Mark.get arg); }, - binder_pos ); - ]) - vars_args + binder_pos ), + ren_ctx) + (RevBlock.empty, ctxt.ren_ctx) vars_args in - let rest_of_block = translate_statements ctxt body in - local_decls @ List.flatten def_blocks @ rest_of_block + let rest_of_block, ren_ctx = translate_statements { ctxt with ren_ctx } body in + local_decls @ RevBlock.rebuild def_blocks ~tail:rest_of_block, ren_ctx | EAbs { binder; tys } -> let closure_name, ctxt = match ctxt.inside_definition_of with @@ -392,7 +398,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = { ctxt with inside_definition_of = None } vars_tau in - let new_body = translate_statements ctxt body in + let new_body, _ren_ctx = translate_statements ctxt body in [ ( A.SInnerFuncDef { @@ -413,9 +419,9 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = }; }, binder_pos ); - ] + ], ctxt.ren_ctx | EMatch { e = e1; cases; name } -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in let new_cases = EnumConstructor.Map.fold (fun _ arg new_args -> @@ -427,7 +433,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = let scalc_var, ctxt = register_fresh_var ctxt var ~pos:(Expr.pos arg) in - let new_arg = translate_statements ctxt body in + let new_arg, _ren_ctx = translate_statements ctxt body in { A.case_block = new_arg; payload_var_name = scalc_var; @@ -449,11 +455,12 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = switch_cases = new_args; }, Expr.pos block_expr ); - ] + ], + ren_ctx | EIfThenElse { cond; etrue; efalse } -> - let cond_stmts, s_cond = translate_expr ctxt cond in - let s_e_true = translate_statements ctxt etrue in - let s_e_false = translate_statements ctxt efalse in + let cond_stmts, s_cond, ren_ctx = translate_expr ctxt cond in + let s_e_true, _ = translate_statements ctxt etrue in + let s_e_false, _ = translate_statements ctxt efalse in RevBlock.rebuild cond_stmts ~tail: [ @@ -464,14 +471,14 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = else_block = s_e_false; }, Expr.pos block_expr ); - ] + ], + ren_ctx | EInj { e = e1; cons; name } when ctxt.config.no_struct_literals -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in + let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in let tmp_struct_var_name = match ctxt.inside_definition_of with - | None -> - failwith "should not happen" - (* [translate_expr] should create this [inside_definition_of]*) + | None -> assert false + (* [translate_expr] should create this [inside_definition_of]*) | Some x -> x, Expr.pos block_expr in let inj_expr = @@ -496,15 +503,16 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = Expr.pos block_expr ); }, Expr.pos block_expr ); - ] + ], + ren_ctx | EStruct { fields; name } when ctxt.config.no_struct_literals -> - let args_stmts, new_args = + let args_stmts, new_args, ren_ctx = StructField.Map.fold - (fun field arg (args_stmts, new_args) -> - let arg_stmts, new_arg = translate_expr ctxt arg in - args_stmts ++ arg_stmts, StructField.Map.add field new_arg new_args) + (fun field arg (args_stmts, new_args, ren_ctx) -> + let arg_stmts, new_arg, ren_ctx = translate_expr { ctxt with ren_ctx } arg in + args_stmts ++ arg_stmts, StructField.Map.add field new_arg new_args, ren_ctx) fields - (RevBlock.empty, StructField.Map.empty) + (RevBlock.empty, StructField.Map.empty, ctxt.ren_ctx) in let struct_expr = A.EStruct { fields = new_args; name }, Expr.pos block_expr @@ -526,10 +534,11 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = typ = TStruct name, Expr.pos block_expr; }, Expr.pos block_expr ); - ] + ], + ren_ctx | ELit _ | EAppOp _ | EArray _ | EVar _ | EStruct _ | EInj _ | ETuple _ | ETupleAccess _ | EStructAccess _ | EExternal _ | EApp _ -> - let e_stmts, new_e = translate_expr ctxt block_expr in + let e_stmts, new_e, ren_ctx = translate_expr ctxt block_expr in let tail = match (e_stmts :> (A.stmt * Pos.t) list) with | (A.SRaiseEmpty, _) :: _ -> @@ -551,7 +560,8 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = Expr.pos block_expr ); ] in - RevBlock.rebuild e_stmts ~tail + RevBlock.rebuild e_stmts ~tail, + ren_ctx | _ -> . let rec translate_scope_body_expr ctx (scope_expr : 'm L.expr scope_body_expr) : @@ -559,59 +569,50 @@ let rec translate_scope_body_expr ctx (scope_expr : 'm L.expr scope_body_expr) : let ctx = { ctx with inside_definition_of = None } in match scope_expr with | Last e -> - let block, new_e = translate_expr ctx e in + let block, new_e, _ren_ctx = translate_expr ctx e in RevBlock.rebuild block ~tail:[A.SReturn (Mark.remove new_e), Mark.get new_e] - | Cons (scope_let, next_bnd) -> ( - let let_var, scope_let_next, ctx1 = unbind ctx next_bnd in + | Cons (scope_let, next_bnd) -> + let let_var, scope_let_next, ctx = unbind ctx next_bnd in let let_var_id, ctx = - register_fresh_var ctx1 let_var ~pos:scope_let.scope_let_pos + register_fresh_var ctx let_var ~pos:scope_let.scope_let_pos in - let next = translate_scope_body_expr ctx scope_let_next in - match scope_let.scope_let_kind with - | Assertion -> - translate_statements - { ctx with inside_definition_of = Some let_var_id } - scope_let.scope_let_expr - @ next - | _ -> - let let_expr_stmts, new_let_expr = - translate_expr - { ctx with inside_definition_of = Some let_var_id } - scope_let.scope_let_expr - in - RevBlock.rebuild let_expr_stmts - ~tail: - (( A.SLocalDecl - { - name = let_var_id, scope_let.scope_let_pos; - typ = scope_let.scope_let_typ; - }, - scope_let.scope_let_pos ) - :: ( A.SLocalDef - { - name = let_var_id, scope_let.scope_let_pos; - expr = new_let_expr; - typ = scope_let.scope_let_typ; - }, - scope_let.scope_let_pos ) - :: next)) + let statements, ren_ctx = + match scope_let.scope_let_kind with + | Assertion -> + let stmts, ren_ctx = + translate_statements + { ctx with inside_definition_of = Some let_var_id } + scope_let.scope_let_expr + in + RevBlock.make stmts, ren_ctx + | _ -> + let let_expr_stmts, new_let_expr, ren_ctx = + translate_expr + { ctx with inside_definition_of = Some let_var_id } + scope_let.scope_let_expr + in + let (+>) = RevBlock.append in + let_expr_stmts +> + ( A.SLocalDecl + { + name = let_var_id, scope_let.scope_let_pos; + typ = scope_let.scope_let_typ; + }, + scope_let.scope_let_pos ) +> + ( A.SLocalDef + { + name = let_var_id, scope_let.scope_let_pos; + expr = new_let_expr; + typ = scope_let.scope_let_typ; + }, + scope_let.scope_let_pos ), + ren_ctx + in + let tail = translate_scope_body_expr { ctx with ren_ctx } scope_let_next in + RevBlock.rebuild statements ~tail let translate_program ~(config : translation_config) (p : 'm L.program) : A.program = - let modules = - List.fold_left - (fun acc (m, _) -> - let vname = Mark.map (( ^ ) "Module_") (ModuleName.get_info m) in - (* The "Module_" prefix is a workaround name clashes for same-name - structs and modules, Python in particular mixes everything in one - namespaces. It can be removed once we have full clash-free variable - renaming in the Python backend (requiring all idents to go through - one stage of being bindlib vars) *) - ModuleName.Map.add m (A.VarName.fresh vname) acc) - ModuleName.Map.empty - (Program.modules_to_list p.decl_ctx.ctx_modules) - in - let program_ctx = { A.decl_ctx = p.decl_ctx; A.modules } in let ctxt = { func_dict = Var.Map.empty; @@ -619,10 +620,21 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : inside_definition_of = None; context_name = ""; config; - program_ctx; + program_ctx = { A.decl_ctx = p.decl_ctx; modules = ModuleName.Map.empty}; ren_ctx = config.renaming_context; } in + let modules, ctxt = + List.fold_left + (fun (modules, ctxt) (m, _) -> + let name, pos = ModuleName.get_info m in + let vname, ctxt = get_name ctxt name in + ModuleName.Map.add m (A.VarName.fresh (vname, pos)) modules, ctxt) + (ModuleName.Map.empty, ctxt) + (Program.modules_to_list p.decl_ctx.ctx_modules) + in + let program_ctx = { ctxt.program_ctx with A.modules } in + let ctxt = { ctxt with program_ctx } in let (_, rev_items), _vlist = BoundList.fold_left ~init:(ctxt, []) ~f:(fun (ctxt, rev_items) code_item var -> @@ -661,7 +673,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : :: rev_items ) | Topdef (name, topdef_ty, (EAbs abs, m)) -> (* Toplevel function def *) - let (block, expr), args_id = + let (block, expr, _ren_ctx), args_id = let args_a, expr, ctxt = unmbind ctxt abs.binder in let args = Array.to_list args_a in let rargs_id, ctxt = @@ -705,7 +717,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : :: rev_items ) | Topdef (name, topdef_ty, expr) -> (* Toplevel constant def *) - let block, expr = + let block, expr, _ren_ctx = let ctxt = { ctxt with diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 81e980f6..0769cc94 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -160,6 +160,7 @@ let renaming = ~reset_context_for_closed_terms:false ~skip_constant_binders:false ~constant_binder_name:None ~namespaced_fields_constrs:true ~f_struct:String.to_camel_case + ~f_enum:String.to_camel_case let typ_needs_parens (e : typ) : bool = match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false @@ -413,7 +414,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit let pos = Mark.get s in Format.fprintf fmt "@[if not (%a):@,\ - raise AssertionFailure(@[SourcePosition(@[filename=\"%s\",@ \ + raise AssertionFailed(@[SourcePosition(@[filename=\"%s\",@ \ start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \ law_headings=@[%a@])@])@]@]" (format_expression ctx) diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index 7d92f909..f292cc61 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -15,7 +15,6 @@ License for the specific language governing permissions and limitations under the License. *) -open Catala_utils open Definitions let map_decl_ctx ~f ctx = diff --git a/tests/backends/python_name_clash.catala_en b/tests/backends/python_name_clash.catala_en index 00efc21f..e72ef01d 100644 --- a/tests/backends/python_name_clash.catala_en +++ b/tests/backends/python_name_clash.catala_en @@ -162,11 +162,11 @@ def b(b_in:BIn): arg = perhaps_none_arg result1 = arg result = some_name(SomeNameIn(i_in = result1)) - result1 = SomeName(o = result.o) + result4 = SomeName(o = result.o) if True: - some_name2 = result1 + some_name2 = result4 else: - some_name2 = result1 + some_name2 = result4 some_name1 = some_name2 return B(some_name = some_name1) ```