diff --git a/compiler/driver.ml b/compiler/driver.ml index bacac301..bac1b235 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -880,8 +880,8 @@ module Commands = struct @@ fun fmt -> match ex_scope_opt with | Some scope -> - let scope_uid = get_scope_uid prg.decl_ctx scope in - Scalc.Print.format_item ~debug:options.Cli.debug prg.decl_ctx fmt + let scope_uid = get_scope_uid prg.ctx.decl_ctx scope in + Scalc.Print.format_item ~debug:options.Cli.debug prg.ctx.decl_ctx fmt (List.find (function | Scalc.Ast.SScope { scope_body_name; _ } -> @@ -889,7 +889,7 @@ module Commands = struct | _ -> false) prg.code_items); Format.pp_print_newline fmt () - | None -> Scalc.Print.format_program prg.decl_ctx fmt prg + | None -> Scalc.Print.format_program fmt prg let scalc_cmd = Cmd.v diff --git a/compiler/plugin.ml b/compiler/plugin.ml index a2e4ff71..e125fab1 100644 --- a/compiler/plugin.ml +++ b/compiler/plugin.ml @@ -28,13 +28,13 @@ let register info term = let list () = Hashtbl.to_seq_values backend_plugins |> List.of_seq let names () = Hashtbl.to_seq_keys backend_plugins |> List.of_seq - let load_failures = Hashtbl.create 17 let print_failures () = if Hashtbl.length load_failures > 0 then Message.emit_warning "Some plugins could not be loaded:@,%a" - (Format.pp_print_seq (fun ppf -> Format.fprintf ppf " - %s")) (Hashtbl.to_seq_values load_failures) + (Format.pp_print_seq (fun ppf -> Format.fprintf ppf " - %s")) + (Hashtbl.to_seq_values load_failures) let load_file f = try diff --git a/compiler/plugin.mli b/compiler/plugin.mli index 1e94dba3..f216623f 100644 --- a/compiler/plugin.mli +++ b/compiler/plugin.mli @@ -43,4 +43,5 @@ val load_dir : string -> unit (** Load all plugins found in the given directory *) val print_failures : unit -> unit -(** Dynlink errors may be silenced at startup time if not in --debug mode, this prints them as warnings *) +(** Dynlink errors may be silenced at startup time if not in --debug mode, this + prints them as warnings *) diff --git a/compiler/scalc/ast.ml b/compiler/scalc/ast.ml index 20a90fc4..9d023772 100644 --- a/compiler/scalc/ast.ml +++ b/compiler/scalc/ast.ml @@ -29,7 +29,7 @@ module FuncName = module VarName = Uid.Gen (struct - let style = Ocolor_types.(Fg (C4 hi_green)) + let style = Ocolor_types.Default_fg end) () @@ -62,6 +62,7 @@ and naked_expr = | ELit of lit | EApp of { f : expr; args : expr list } | EAppOp of { op : operator; args : expr list } + | EExternal of { modname : VarName.t Mark.pos; name : string Mark.pos } type stmt = | SInnerFuncDef of { name : VarName.t Mark.pos; func : func } @@ -114,4 +115,10 @@ type code_item = | SFunc of { var : FuncName.t; func : func } | SScope of scope_body -type program = { decl_ctx : decl_ctx; code_items : code_item list } +type ctx = { decl_ctx : decl_ctx; modules : VarName.t ModuleName.Map.t } + +type program = { + ctx : ctx; + code_items : code_item list; + module_name : ModuleName.t option; +} diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index 66fd5d1f..b4b6de1e 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -32,6 +32,7 @@ type 'm ctxt = { inside_definition_of : A.VarName.t option; context_name : string; config : translation_config; + program_ctx : A.ctx; } let unthunk e = @@ -65,7 +66,11 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr = (Var.Map.keys ctxt.var_dict)) in [], (local_var, Expr.pos expr) - | EStruct { fields; name } when not ctxt.config.no_struct_literals -> + | EStruct { fields; name } -> + if ctxt.config.no_struct_literals then + (* In C89, struct literates have to be initialized at variable + definition... *) + raise (NotAnExpr { needs_a_local_decl = false }); let args_stmts, new_args = StructField.Map.fold (fun field arg (args_stmts, new_args) -> @@ -76,11 +81,11 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr = in let args_stmts = List.rev args_stmts in args_stmts, (A.EStruct { fields = new_args; name }, Expr.pos expr) - | EStruct _ when ctxt.config.no_struct_literals -> - (* In C89, struct literates have to be initialized at variable - definition... *) - raise (NotAnExpr { needs_a_local_decl = false }) - | EInj { e = e1; cons; name } when not ctxt.config.no_struct_literals -> + | EInj { e = e1; cons; name } -> + if ctxt.config.no_struct_literals then + (* In C89, struct literates have to be initialized at variable + definition... *) + raise (NotAnExpr { needs_a_local_decl = false }); let e1_stmts, new_e1 = translate_expr ctxt e1 in ( e1_stmts, ( A.EInj @@ -91,10 +96,6 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr = expr_typ = Expr.maybe_ty (Mark.get expr); }, Expr.pos expr ) ) - | EInj _ when ctxt.config.no_struct_literals -> - (* In C89, struct literates have to be initialized at variable - definition... *) - raise (NotAnExpr { needs_a_local_decl = false }) | ETuple args -> let args_stmts, new_args = List.fold_left @@ -212,7 +213,20 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr = let new_args = List.rev new_args in args_stmts, (A.EArray new_args, Expr.pos expr) | ELit l -> [], (A.ELit l, Expr.pos expr) - | _ -> raise (NotAnExpr { needs_a_local_decl = true }) + | EExternal { name } -> + let path, name = + match Mark.remove name with + | External_value name -> TopdefName.(path name, get_info name) + | External_scope name -> ScopeName.(path name, get_info name) + in + let modname = + ( ModuleName.Map.find (List.hd (List.rev path)) ctxt.program_ctx.modules, + Expr.pos expr ) + in + [], (EExternal { modname; name }, Expr.pos expr) + | ECatch _ | EAbs _ | EIfThenElse _ | EMatch _ | EAssert _ | ERaise _ -> + raise (NotAnExpr { needs_a_local_decl = true }) + | _ -> . with NotAnExpr { needs_a_local_decl } -> let tmp_var = A.VarName.fresh @@ -542,8 +556,8 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = }, Expr.pos block_expr ); ] - | _ -> ( - Message.emit_debug "E: %a" Expr.format block_expr; + | ELit _ | EAppOp _ | EArray _ | EVar _ | EStruct _ | EInj _ | ETuple _ + | ETupleAccess _ | EStructAccess _ | EExternal _ | EApp _ -> ( let e_stmts, new_e = translate_expr ctxt block_expr in e_stmts @ @@ -566,27 +580,28 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = }), Expr.pos block_expr ); ]) + | _ -> . let rec translate_scope_body_expr ~(config : translation_config) (scope_name : ScopeName.t) - (decl_ctx : decl_ctx) + (program_ctx : A.ctx) (var_dict : ('m L.expr, A.VarName.t) Var.Map.t) (func_dict : ('m L.expr, A.FuncName.t) Var.Map.t) (scope_expr : 'm L.expr scope_body_expr) : A.block = + let ctx = + { + func_dict; + var_dict; + inside_definition_of = None; + context_name = Mark.remove (ScopeName.get_info scope_name); + config; + program_ctx; + } + in match scope_expr with | Last e -> - let block, new_e = - translate_expr - { - func_dict; - var_dict; - inside_definition_of = None; - context_name = Mark.remove (ScopeName.get_info scope_name); - config; - } - e - in + let block, new_e = translate_expr ctx e in block @ [A.SReturn (Mark.remove new_e), Mark.get new_e] | Cons (scope_let, next_bnd) -> let let_var, scope_let_next = Bindlib.unbind next_bnd in @@ -597,24 +612,12 @@ let rec translate_scope_body_expr (match scope_let.scope_let_kind with | Assertion -> translate_statements - { - func_dict; - var_dict; - inside_definition_of = Some let_var_id; - context_name = Mark.remove (ScopeName.get_info scope_name); - config; - } + { ctx with inside_definition_of = Some let_var_id } scope_let.scope_let_expr | _ -> let let_expr_stmts, new_let_expr = translate_expr - { - func_dict; - var_dict; - inside_definition_of = Some let_var_id; - context_name = Mark.remove (ScopeName.get_info scope_name); - config; - } + { ctx with inside_definition_of = Some let_var_id } scope_let.scope_let_expr in let_expr_stmts @@ -633,11 +636,19 @@ let rec translate_scope_body_expr }, scope_let.scope_let_pos ); ]) - @ translate_scope_body_expr ~config scope_name decl_ctx new_var_dict + @ translate_scope_body_expr ~config scope_name program_ctx new_var_dict func_dict scope_let_next let translate_program ~(config : translation_config) (p : 'm L.program) : A.program = + let modules = + List.fold_left + (fun acc m -> + ModuleName.Map.add m (A.VarName.fresh (ModuleName.get_info m)) acc) + ModuleName.Map.empty + (Program.modules_to_list p.decl_ctx.ctx_modules) + in + let ctx = { A.decl_ctx = p.decl_ctx; A.modules } in let (_, _, rev_items), () = BoundList.fold_left ~f:(fun (func_dict, var_dict, rev_items) code_item var -> @@ -654,8 +665,8 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : Var.Map.add scope_input_var scope_input_var_id var_dict in let new_scope_body = - translate_scope_body_expr ~config name p.decl_ctx var_dict_local - func_dict scope_body_expr + translate_scope_body_expr ~config name ctx var_dict_local func_dict + scope_body_expr in let func_id = A.FuncName.fresh (Bindlib.name_of var, Pos.no_pos) in ( Var.Map.add var func_id func_dict, @@ -700,6 +711,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : inside_definition_of = None; context_name = Mark.remove (TopdefName.get_info name); config; + program_ctx = ctx; } in translate_expr ctxt expr @@ -735,6 +747,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : inside_definition_of = None; context_name = Mark.remove (TopdefName.get_info name); config; + program_ctx = ctx; } in translate_expr ctxt expr @@ -778,4 +791,4 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : ~init:(Var.Map.empty, Var.Map.empty, []) p.code_items in - { decl_ctx = p.decl_ctx; code_items = List.rev rev_items } + { ctx; code_items = List.rev rev_items; module_name = p.module_name } diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index e7f1cb64..6c0c9069 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -21,10 +21,10 @@ open Ast let needs_parens (_e : expr) : bool = false let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit = - Format.fprintf fmt "%a_%s" VarName.format v (string_of_int (VarName.hash v)) + Format.fprintf fmt "%a_%d" VarName.format v (VarName.hash v) let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit = - Format.fprintf fmt "%a_%s" FuncName.format v (string_of_int (FuncName.hash v)) + Format.fprintf fmt "@{%a_%d@}" FuncName.format v (FuncName.hash v) let rec format_expr (decl_ctx : decl_ctx) @@ -99,6 +99,9 @@ let rec format_expr ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") format_with_parens) args + | EExternal { modname; name } -> + Format.fprintf fmt "%a.%s" format_var_name (Mark.remove modname) + (Mark.remove name) let rec format_statement (decl_ctx : decl_ctx) @@ -226,15 +229,22 @@ let format_item decl_ctx ?debug ppf def = Format.pp_close_box ppf (); Format.pp_print_cut ppf () -let format_program decl_ctx ?debug ppf prg = +let format_program ?debug ppf prg = let decl_ctx = + (* TODO: this is redundant with From_dcalc.add_option_type (which is already + applied in avoid_exceptions mode) *) { - decl_ctx with + prg.ctx.decl_ctx with ctx_enums = EnumName.Map.add Expr.option_enum Expr.option_enum_config - decl_ctx.ctx_enums; + prg.ctx.decl_ctx.ctx_enums; } in Format.pp_open_vbox ppf 0; + ModuleName.Map.iter + (fun m var -> + Format.fprintf ppf "%a %a = %a@," Print.keyword "module" format_var_name + var ModuleName.format m) + prg.ctx.modules; Format.pp_print_list (format_item decl_ctx ?debug) ppf prg.code_items; Format.pp_close_box ppf () diff --git a/compiler/scalc/print.mli b/compiler/scalc/print.mli index 66cf7a19..c5ab9bee 100644 --- a/compiler/scalc/print.mli +++ b/compiler/scalc/print.mli @@ -21,5 +21,4 @@ val format_item : Ast.code_item -> unit -val format_program : - Shared_ast.decl_ctx -> ?debug:bool -> Format.formatter -> Ast.program -> unit +val format_program : ?debug:bool -> Format.formatter -> Ast.program -> unit diff --git a/compiler/scalc/to_c.ml b/compiler/scalc/to_c.ml index 66b3aace..8064b6dd 100644 --- a/compiler/scalc/to_c.ml +++ b/compiler/scalc/to_c.ml @@ -385,6 +385,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : args | ETuple _ | ETupleAccess _ -> Message.raise_internal_error "Tuple compilation to R unimplemented!" + | EExternal _ -> failwith "TODO" let typ_is_array (ctx : decl_ctx) (typ : typ) = match Mark.remove typ with @@ -604,26 +605,28 @@ let format_program %a@,\ %a@,\ @]" - (format_ctx type_ordering) p.decl_ctx + (format_ctx type_ordering) p.ctx.decl_ctx (Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt code_item -> match code_item with | SVar { var; expr; typ } -> Format.fprintf fmt "@[%a = %a;@]" - (format_typ p.decl_ctx (fun fmt -> format_var fmt var)) + (format_typ p.ctx.decl_ctx (fun fmt -> format_var fmt var)) typ - (format_expression p.decl_ctx) + (format_expression p.ctx.decl_ctx) expr | SFunc { var; func } | SScope { scope_body_var = var; scope_body_func = func; _ } -> let { func_params; func_body; func_return_typ } = func in Format.fprintf fmt "@[%a(%a) {@,%a@]@,}" - (format_typ p.decl_ctx (fun fmt -> format_func_name fmt var)) + (format_typ p.ctx.decl_ctx (fun fmt -> format_func_name fmt var)) func_return_typ (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (var, typ) -> - (format_typ p.decl_ctx (fun fmt -> + (format_typ p.ctx.decl_ctx (fun fmt -> format_var fmt (Mark.remove var))) fmt typ)) - func_params (format_block p.decl_ctx) func_body)) + func_params + (format_block p.ctx.decl_ctx) + func_body)) p.code_items diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 9cd70873..1606a1a9 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -126,68 +126,13 @@ let avoid_keywords (s : string) : string = then s ^ "_" else s -let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = - Format.fprintf fmt "%s" - (avoid_keywords - (String.to_camel_case - (String.to_ascii (Format.asprintf "%a" StructName.format v)))) +module StringMap = String.Map -let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit - = - Format.fprintf fmt "%s" - (avoid_keywords - (String.to_ascii (Format.asprintf "%a" StructField.format v))) +module IntMap = Map.Make (struct + include Int -let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = - Format.fprintf fmt "%s" - (avoid_keywords - (String.to_camel_case - (String.to_ascii (Format.asprintf "%a" EnumName.format v)))) - -let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : - unit = - Format.fprintf fmt "%s" - (avoid_keywords - (String.to_ascii (Format.asprintf "%a" EnumConstructor.format v))) - -let typ_needs_parens (e : typ) : bool = - match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false - -let rec format_typ (fmt : Format.formatter) (typ : typ) : unit = - let format_typ = format_typ in - let format_typ_with_parens (fmt : Format.formatter) (t : typ) = - if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t - else Format.fprintf fmt "%a" format_typ t - in - match Mark.remove typ with - | TLit TUnit -> Format.fprintf fmt "Unit" - | TLit TMoney -> Format.fprintf fmt "Money" - | TLit TInt -> Format.fprintf fmt "Integer" - | TLit TRat -> Format.fprintf fmt "Decimal" - | TLit TDate -> Format.fprintf fmt "Date" - | TLit TDuration -> Format.fprintf fmt "Duration" - | TLit TBool -> Format.fprintf fmt "bool" - | TTuple ts -> - Format.fprintf fmt "Tuple[%a]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") - (fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t)) - ts - | TStruct s -> Format.fprintf fmt "%a" format_struct_name s - | TOption some_typ -> - (* We translate the option type with an overloading by Python's [None] *) - Format.fprintf fmt "Optional[%a]" format_typ some_typ - | TDefault t -> format_typ fmt t - | TEnum e -> Format.fprintf fmt "%a" format_enum_name e - | TArrow (t1, t2) -> - Format.fprintf fmt "Callable[[%a], %a]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - format_typ_with_parens) - t1 format_typ_with_parens t2 - | TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1 - | TAny -> Format.fprintf fmt "Any" - | TClosureEnv -> failwith "unimplemented!" + let format ppf i = Format.pp_print_int ppf i +end) let format_name_cleaned (fmt : Format.formatter) (s : string) : unit = s @@ -198,14 +143,6 @@ let format_name_cleaned (fmt : Format.formatter) (s : string) : unit = |> avoid_keywords |> Format.fprintf fmt "%s" -module StringMap = String.Map - -module IntMap = Map.Make (struct - include Int - - let format ppf i = Format.pp_print_int ppf i -end) - (** For each `VarName.t` defined by its string and then by its hash, we keep track of which local integer id we've given it. This is used to keep variable naming with low indices rather than one global counter for all @@ -244,6 +181,76 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit = else if local_id = 0 then format_name_cleaned fmt v_str else Format.fprintf fmt "%a_%d" format_name_cleaned v_str local_id +let format_path ctx fmt p = + match List.rev p with + | [] -> () + | m :: _ -> + format_var fmt (ModuleName.Map.find m ctx.modules); + Format.pp_print_char fmt '.' + +let format_struct_name ctx (fmt : Format.formatter) (v : StructName.t) : unit = + format_path ctx fmt (StructName.path v); + Format.pp_print_string fmt + (avoid_keywords + (String.to_camel_case + (String.to_ascii (Mark.remove (StructName.get_info v))))) + +let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit + = + Format.pp_print_string fmt + (avoid_keywords (String.to_ascii (StructField.to_string v))) + +let format_enum_name ctx (fmt : Format.formatter) (v : EnumName.t) : unit = + format_path ctx fmt (EnumName.path v); + Format.pp_print_string fmt + (avoid_keywords + (String.to_camel_case + (String.to_ascii (Mark.remove (EnumName.get_info v))))) + +let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : + unit = + Format.pp_print_string fmt + (avoid_keywords (String.to_ascii (EnumConstructor.to_string v))) + +let typ_needs_parens (e : typ) : bool = + match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false + +let rec format_typ ctx (fmt : Format.formatter) (typ : typ) : unit = + let format_typ = format_typ ctx in + let format_typ_with_parens (fmt : Format.formatter) (t : typ) = + if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t + else Format.fprintf fmt "%a" format_typ t + in + match Mark.remove typ with + | TLit TUnit -> Format.fprintf fmt "Unit" + | TLit TMoney -> Format.fprintf fmt "Money" + | TLit TInt -> Format.fprintf fmt "Integer" + | TLit TRat -> Format.fprintf fmt "Decimal" + | TLit TDate -> Format.fprintf fmt "Date" + | TLit TDuration -> Format.fprintf fmt "Duration" + | TLit TBool -> Format.fprintf fmt "bool" + | TTuple ts -> + Format.fprintf fmt "Tuple[%a]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") + (fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t)) + ts + | TStruct s -> Format.fprintf fmt "%a" (format_struct_name ctx) s + | TOption some_typ -> + (* We translate the option type with an overloading by Python's [None] *) + Format.fprintf fmt "Optional[%a]" format_typ some_typ + | TDefault t -> format_typ fmt t + | TEnum e -> Format.fprintf fmt "%a" (format_enum_name ctx) e + | TArrow (t1, t2) -> + Format.fprintf fmt "Callable[[%a], %a]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + format_typ_with_parens) + t1 format_typ_with_parens t2 + | TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1 + | TAny -> Format.fprintf fmt "Any" + | TClosureEnv -> failwith "unimplemented!" + let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit = let v_str = Mark.remove (FuncName.get_info v) in format_name_cleaned fmt v_str @@ -270,13 +277,12 @@ let format_exception (fmt : Format.formatter) (exc : except Mark.pos) : unit = (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos) -let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : - unit = +let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit = match Mark.remove e with | EVar v -> format_var fmt v | EFunc f -> format_func_name fmt f | EStruct { fields = es; name = s } -> - Format.fprintf fmt "%a(%a)" format_struct_name s + Format.fprintf fmt "%a(%a)" (format_struct_name ctx) s (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (struct_field, e) -> @@ -297,8 +303,8 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : (* We translate the option type with an overloading by Python's [None] *) format_expression ctx fmt e | EInj { e1 = e; cons; name = enum_name; _ } -> - Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name - format_enum_name enum_name format_enum_cons_name cons + Format.fprintf fmt "%a(%a_Code.%a,@ %a)" (format_enum_name ctx) enum_name + (format_enum_name ctx) enum_name format_enum_cons_name cons (format_expression ctx) e | EArray es -> Format.fprintf fmt "[%a]" @@ -402,11 +408,12 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : es | ETupleAccess { e1; index } -> Format.fprintf fmt "%a[%d]" (format_expression ctx) e1 index + | EExternal { modname; name } -> + Format.fprintf fmt "%a.%a" format_var (Mark.remove modname) + format_name_cleaned (Mark.remove name) -let rec format_statement - (ctx : decl_ctx) - (fmt : Format.formatter) - (s : stmt Mark.pos) : unit = +let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit + = match Mark.remove s with | SInnerFuncDef { name; func = { func_params; func_body; _ } } -> Format.fprintf fmt "@[def %a(%a):@\n%a@]" format_var @@ -414,8 +421,8 @@ let rec format_statement (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") (fun fmt (var, typ) -> - Format.fprintf fmt "%a:%a" format_var (Mark.remove var) format_typ - typ)) + Format.fprintf fmt "%a:%a" format_var (Mark.remove var) + (format_typ ctx) typ)) func_params (format_block ctx) func_body | SLocalDecl _ -> assert false (* We don't need to declare variables in Python *) @@ -458,7 +465,7 @@ let rec format_statement (format_block ctx) case_none format_var case_some_var format_var tmp_var (format_block ctx) case_some | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } -> - let cons_map = EnumName.Map.find e_name ctx.ctx_enums in + let cons_map = EnumName.Map.find e_name ctx.decl_ctx.ctx_enums in let cases = List.map2 (fun x (cons, _) -> x, cons) @@ -472,9 +479,9 @@ let rec format_statement ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") (fun fmt ({ case_block; payload_var_name; _ }, cons_name) -> Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" - format_var tmp_var format_enum_name e_name format_enum_cons_name - cons_name format_var payload_var_name format_var tmp_var - (format_block ctx) case_block)) + format_var tmp_var (format_enum_name ctx) e_name + format_enum_cons_name cons_name format_var payload_var_name + format_var tmp_var (format_block ctx) case_block)) cases | SReturn e1 -> Format.fprintf fmt "@[return %a@]" (format_expression ctx) @@ -493,7 +500,7 @@ let rec format_statement (Pos.get_law_info pos) | SSpecialOp _ -> failwith "should not happen" -and format_block (ctx : decl_ctx) (fmt : Format.formatter) (b : block) : unit = +and format_block ctx (fmt : Format.formatter) (b : block) : unit = Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (format_statement ctx) fmt @@ -504,7 +511,7 @@ and format_block (ctx : decl_ctx) (fmt : Format.formatter) (b : block) : unit = let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Format.formatter) - (ctx : decl_ctx) : unit = + ctx : unit = let format_struct_decl fmt (struct_name, struct_fields) = let fields = StructField.Map.bindings struct_fields in Format.fprintf fmt @@ -522,13 +529,13 @@ let format_ctx \ return not (self == other)@\n\ @\n\ \ def __str__(self) -> str:@\n\ - \ @[return \"%a(%a)\".format(%a)@]" format_struct_name + \ @[return \"%a(%a)\".format(%a)@]" (format_struct_name ctx) struct_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") (fun fmt (struct_field, struct_field_type) -> Format.fprintf fmt "%a: %a" format_struct_field_name struct_field - format_typ struct_field_type)) + (format_typ ctx) struct_field_type)) fields (if StructField.Map.is_empty struct_fields then fun fmt _ -> Format.fprintf fmt " pass" @@ -538,7 +545,7 @@ let format_ctx (fun fmt (struct_field, _) -> Format.fprintf fmt " self.%a = %a" format_struct_field_name struct_field format_struct_field_name struct_field)) - fields format_struct_name struct_name + fields (format_struct_name ctx) struct_name (if not (StructField.Map.is_empty struct_fields) then Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ") @@ -546,7 +553,7 @@ let format_ctx Format.fprintf fmt "self.%a == other.%a" format_struct_field_name struct_field format_struct_field_name struct_field) else fun fmt _ -> Format.fprintf fmt "True") - fields format_struct_name struct_name + fields (format_struct_name ctx) struct_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",") (fun fmt (struct_field, _) -> @@ -585,7 +592,7 @@ let format_ctx @\n\ \ def __str__(self) -> str:@\n\ \ @[return \"{}({})\".format(self.code, self.value)@]" - format_enum_name enum_name + (format_enum_name ctx) enum_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (i, enum_cons, _enum_cons_type) -> @@ -593,8 +600,8 @@ let format_ctx (List.mapi (fun i (x, y) -> i, x, y) (EnumConstructor.Map.bindings enum_cons)) - format_enum_name enum_name format_enum_name enum_name format_enum_name - enum_name + (format_enum_name ctx) enum_name (format_enum_name ctx) enum_name + (format_enum_name ctx) enum_name in let is_in_type_ordering s = @@ -611,50 +618,58 @@ let format_ctx (StructName.Map.bindings (StructName.Map.filter (fun s _ -> not (is_in_type_ordering s)) - ctx.ctx_structs)) + ctx.decl_ctx.ctx_structs)) in List.iter (fun struct_or_enum -> match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> - Format.fprintf fmt "%a@\n@\n" format_struct_decl - (s, StructName.Map.find s ctx.ctx_structs) + if StructName.path s = [] then + Format.fprintf fmt "%a@\n@\n" format_struct_decl + (s, StructName.Map.find s ctx.decl_ctx.ctx_structs) | Scopelang.Dependency.TVertex.Enum e -> - Format.fprintf fmt "%a@\n@\n" format_enum_decl - (e, EnumName.Map.find e ctx.ctx_enums)) + if EnumName.path e = [] then + Format.fprintf fmt "%a@\n@\n" format_enum_decl + (e, EnumName.Map.find e ctx.decl_ctx.ctx_enums)) (type_ordering @ scope_structs) +let format_code_item ctx fmt = function + | SVar { var; expr; typ = _ } -> + Format.fprintf fmt "@[%a = (@,%a@,@])@," format_var var + (format_expression ctx) expr + | SFunc { var; func } + | SScope { scope_body_var = var; scope_body_func = func; _ } -> + let { Ast.func_params; Ast.func_body; _ } = func in + Format.fprintf fmt "@[def %a(%a):@\n%a@]@," format_func_name var + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") + (fun fmt (var, typ) -> + Format.fprintf fmt "%a:%a" format_var (Mark.remove var) + (format_typ ctx) typ)) + func_params (format_block ctx) func_body + let format_program (fmt : Format.formatter) (p : Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = - (* We disable the style flag in order to enjoy formatting from the - pretty-printers of Dcalc and Lcalc but without the color terminal - markers. *) - Format.fprintf fmt - "@[# This file has been generated by the Catala compiler, do not edit!@,\ - @,\ - from catala.runtime import *@,\ - from typing import Any, List, Callable, Tuple@,\ - from enum import Enum@,\ - @,\ - @[%a@]@,\ - @,\ - %a@]@?" - (format_ctx type_ordering) p.decl_ctx - (Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt -> function - | SVar { var; expr; typ = _ } -> - Format.fprintf fmt "@[%a = (@,%a@,@])@," format_var var - (format_expression p.decl_ctx) - expr - | SFunc { var; func } - | SScope { scope_body_var = var; scope_body_func = func; _ } -> - let { Ast.func_params; Ast.func_body; _ } = func in - Format.fprintf fmt "@[def %a(%a):@\n%a@]@," format_func_name var - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") - (fun fmt (var, typ) -> - Format.fprintf fmt "%a:%a" format_var (Mark.remove var) - format_typ typ)) - func_params (format_block p.decl_ctx) func_body)) - p.code_items + Format.pp_open_vbox fmt 0; + let header = + [ + "# This file has been generated by the Catala compiler, do not edit!"; + ""; + "from catala.runtime import *"; + "from typing import Any, List, Callable, Tuple"; + "from enum import Enum"; + ""; + ] + in + Format.pp_print_list Format.pp_print_string fmt header; + ModuleName.Map.iter + (fun m v -> + Format.fprintf fmt "import %a as %a@," ModuleName.format m format_var v) + p.ctx.modules; + Format.pp_print_cut fmt (); + format_ctx type_ordering fmt p.ctx; + Format.pp_print_cut fmt (); + Format.pp_print_list (format_code_item p.ctx) fmt p.code_items; + Format.pp_print_flush fmt () diff --git a/compiler/scalc/to_r.ml b/compiler/scalc/to_r.ml index bc0131ab..f56ceb07 100644 --- a/compiler/scalc/to_r.ml +++ b/compiler/scalc/to_r.ml @@ -373,6 +373,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : args | ETuple _ | ETupleAccess _ -> Message.raise_internal_error "Tuple compilation to R unimplemented!" + | EExternal _ -> failwith "TODO" let rec format_statement (ctx : decl_ctx) @@ -562,11 +563,11 @@ let format_program @[%a@]@,\ @,\ %a@]@?" - (format_ctx type_ordering) p.decl_ctx + (format_ctx type_ordering) p.ctx.decl_ctx (Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt -> function | SVar { var; expr; typ = _ } -> Format.fprintf fmt "@[%a <- (@,%a@,@])@," format_var var - (format_expression p.decl_ctx) + (format_expression p.ctx.decl_ctx) expr | SFunc { var; func } | SScope { scope_body_var = var; scope_body_func = func; _ } -> @@ -578,5 +579,7 @@ let format_program (fun fmt (var, typ) -> Format.fprintf fmt "%a# (%a)@\n" format_var (Mark.remove var) format_typ typ)) - func_params (format_block p.decl_ctx) func_body)) + func_params + (format_block p.ctx.decl_ctx) + func_body)) p.code_items diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index 7824706d..1a510bda 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -1051,16 +1051,7 @@ let load_runtime_modules prg = obj_file Format.pp_print_text (Dynlink.error_message dl_err) in - let modules_list_topo = - let rec aux acc (M mtree) = - ModuleName.Map.fold - (fun mname sub acc -> - if List.exists (ModuleName.equal mname) acc then acc - else mname :: aux acc sub) - mtree acc - in - List.rev (aux [] prg.decl_ctx.ctx_modules) - in + let modules_list_topo = Program.modules_to_list prg.decl_ctx.ctx_modules in if modules_list_topo <> [] then Message.emit_debug "Loading shared modules... %a" (Format.pp_print_list ~pp_sep:Format.pp_print_space ModuleName.format) diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index 6ce67cd8..d3a7f1f1 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -85,3 +85,13 @@ let to_expr p main_scope = let res = Scope.unfold p.decl_ctx p.code_items main_scope in Expr.Box.assert_closed (Expr.Box.lift res); res + +let modules_to_list (mt : module_tree) = + let rec aux acc (M mtree) = + ModuleName.Map.fold + (fun mname sub acc -> + if List.exists (ModuleName.equal mname) acc then acc + else mname :: aux acc sub) + mtree acc + in + List.rev (aux [] mt) diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index a702e2b4..54a95047 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -52,3 +52,6 @@ val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed function. *) val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body + +val modules_to_list : module_tree -> ModuleName.t list +(** Returns a list of used modules, in topological order *) diff --git a/tests/name_resolution/good/toplevel_defs.catala_en b/tests/name_resolution/good/toplevel_defs.catala_en index ff8a507a..edee69b8 100644 --- a/tests/name_resolution/good/toplevel_defs.catala_en +++ b/tests/name_resolution/good/toplevel_defs.catala_en @@ -394,7 +394,6 @@ class S4In: return "S4In()".format() - glob1 = (decimal_of_string("44.12")) def glob3(x:Money):