mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Add externals to scalc, working test with Python backend
This commit is contained in:
parent
589833bca7
commit
97ae62384e
@ -880,8 +880,8 @@ module Commands = struct
|
||||
@@ fun fmt ->
|
||||
match ex_scope_opt with
|
||||
| Some scope ->
|
||||
let scope_uid = get_scope_uid prg.decl_ctx scope in
|
||||
Scalc.Print.format_item ~debug:options.Cli.debug prg.decl_ctx fmt
|
||||
let scope_uid = get_scope_uid prg.ctx.decl_ctx scope in
|
||||
Scalc.Print.format_item ~debug:options.Cli.debug prg.ctx.decl_ctx fmt
|
||||
(List.find
|
||||
(function
|
||||
| Scalc.Ast.SScope { scope_body_name; _ } ->
|
||||
@ -889,7 +889,7 @@ module Commands = struct
|
||||
| _ -> false)
|
||||
prg.code_items);
|
||||
Format.pp_print_newline fmt ()
|
||||
| None -> Scalc.Print.format_program prg.decl_ctx fmt prg
|
||||
| None -> Scalc.Print.format_program fmt prg
|
||||
|
||||
let scalc_cmd =
|
||||
Cmd.v
|
||||
|
@ -28,13 +28,13 @@ let register info term =
|
||||
|
||||
let list () = Hashtbl.to_seq_values backend_plugins |> List.of_seq
|
||||
let names () = Hashtbl.to_seq_keys backend_plugins |> List.of_seq
|
||||
|
||||
let load_failures = Hashtbl.create 17
|
||||
|
||||
let print_failures () =
|
||||
if Hashtbl.length load_failures > 0 then
|
||||
Message.emit_warning "Some plugins could not be loaded:@,%a"
|
||||
(Format.pp_print_seq (fun ppf -> Format.fprintf ppf " - %s")) (Hashtbl.to_seq_values load_failures)
|
||||
(Format.pp_print_seq (fun ppf -> Format.fprintf ppf " - %s"))
|
||||
(Hashtbl.to_seq_values load_failures)
|
||||
|
||||
let load_file f =
|
||||
try
|
||||
|
@ -43,4 +43,5 @@ val load_dir : string -> unit
|
||||
(** Load all plugins found in the given directory *)
|
||||
|
||||
val print_failures : unit -> unit
|
||||
(** Dynlink errors may be silenced at startup time if not in --debug mode, this prints them as warnings *)
|
||||
(** Dynlink errors may be silenced at startup time if not in --debug mode, this
|
||||
prints them as warnings *)
|
||||
|
@ -29,7 +29,7 @@ module FuncName =
|
||||
module VarName =
|
||||
Uid.Gen
|
||||
(struct
|
||||
let style = Ocolor_types.(Fg (C4 hi_green))
|
||||
let style = Ocolor_types.Default_fg
|
||||
end)
|
||||
()
|
||||
|
||||
@ -62,6 +62,7 @@ and naked_expr =
|
||||
| ELit of lit
|
||||
| EApp of { f : expr; args : expr list }
|
||||
| EAppOp of { op : operator; args : expr list }
|
||||
| EExternal of { modname : VarName.t Mark.pos; name : string Mark.pos }
|
||||
|
||||
type stmt =
|
||||
| SInnerFuncDef of { name : VarName.t Mark.pos; func : func }
|
||||
@ -114,4 +115,10 @@ type code_item =
|
||||
| SFunc of { var : FuncName.t; func : func }
|
||||
| SScope of scope_body
|
||||
|
||||
type program = { decl_ctx : decl_ctx; code_items : code_item list }
|
||||
type ctx = { decl_ctx : decl_ctx; modules : VarName.t ModuleName.Map.t }
|
||||
|
||||
type program = {
|
||||
ctx : ctx;
|
||||
code_items : code_item list;
|
||||
module_name : ModuleName.t option;
|
||||
}
|
||||
|
@ -32,6 +32,7 @@ type 'm ctxt = {
|
||||
inside_definition_of : A.VarName.t option;
|
||||
context_name : string;
|
||||
config : translation_config;
|
||||
program_ctx : A.ctx;
|
||||
}
|
||||
|
||||
let unthunk e =
|
||||
@ -65,7 +66,11 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
|
||||
(Var.Map.keys ctxt.var_dict))
|
||||
in
|
||||
[], (local_var, Expr.pos expr)
|
||||
| EStruct { fields; name } when not ctxt.config.no_struct_literals ->
|
||||
| EStruct { fields; name } ->
|
||||
if ctxt.config.no_struct_literals then
|
||||
(* In C89, struct literates have to be initialized at variable
|
||||
definition... *)
|
||||
raise (NotAnExpr { needs_a_local_decl = false });
|
||||
let args_stmts, new_args =
|
||||
StructField.Map.fold
|
||||
(fun field arg (args_stmts, new_args) ->
|
||||
@ -76,11 +81,11 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
|
||||
in
|
||||
let args_stmts = List.rev args_stmts in
|
||||
args_stmts, (A.EStruct { fields = new_args; name }, Expr.pos expr)
|
||||
| EStruct _ when ctxt.config.no_struct_literals ->
|
||||
| EInj { e = e1; cons; name } ->
|
||||
if ctxt.config.no_struct_literals then
|
||||
(* In C89, struct literates have to be initialized at variable
|
||||
definition... *)
|
||||
raise (NotAnExpr { needs_a_local_decl = false })
|
||||
| EInj { e = e1; cons; name } when not ctxt.config.no_struct_literals ->
|
||||
raise (NotAnExpr { needs_a_local_decl = false });
|
||||
let e1_stmts, new_e1 = translate_expr ctxt e1 in
|
||||
( e1_stmts,
|
||||
( A.EInj
|
||||
@ -91,10 +96,6 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
|
||||
expr_typ = Expr.maybe_ty (Mark.get expr);
|
||||
},
|
||||
Expr.pos expr ) )
|
||||
| EInj _ when ctxt.config.no_struct_literals ->
|
||||
(* In C89, struct literates have to be initialized at variable
|
||||
definition... *)
|
||||
raise (NotAnExpr { needs_a_local_decl = false })
|
||||
| ETuple args ->
|
||||
let args_stmts, new_args =
|
||||
List.fold_left
|
||||
@ -212,7 +213,20 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
|
||||
let new_args = List.rev new_args in
|
||||
args_stmts, (A.EArray new_args, Expr.pos expr)
|
||||
| ELit l -> [], (A.ELit l, Expr.pos expr)
|
||||
| _ -> raise (NotAnExpr { needs_a_local_decl = true })
|
||||
| EExternal { name } ->
|
||||
let path, name =
|
||||
match Mark.remove name with
|
||||
| External_value name -> TopdefName.(path name, get_info name)
|
||||
| External_scope name -> ScopeName.(path name, get_info name)
|
||||
in
|
||||
let modname =
|
||||
( ModuleName.Map.find (List.hd (List.rev path)) ctxt.program_ctx.modules,
|
||||
Expr.pos expr )
|
||||
in
|
||||
[], (EExternal { modname; name }, Expr.pos expr)
|
||||
| ECatch _ | EAbs _ | EIfThenElse _ | EMatch _ | EAssert _ | ERaise _ ->
|
||||
raise (NotAnExpr { needs_a_local_decl = true })
|
||||
| _ -> .
|
||||
with NotAnExpr { needs_a_local_decl } ->
|
||||
let tmp_var =
|
||||
A.VarName.fresh
|
||||
@ -542,8 +556,8 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
|
||||
},
|
||||
Expr.pos block_expr );
|
||||
]
|
||||
| _ -> (
|
||||
Message.emit_debug "E: %a" Expr.format block_expr;
|
||||
| ELit _ | EAppOp _ | EArray _ | EVar _ | EStruct _ | EInj _ | ETuple _
|
||||
| ETupleAccess _ | EStructAccess _ | EExternal _ | EApp _ -> (
|
||||
let e_stmts, new_e = translate_expr ctxt block_expr in
|
||||
e_stmts
|
||||
@
|
||||
@ -566,27 +580,28 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
|
||||
}),
|
||||
Expr.pos block_expr );
|
||||
])
|
||||
| _ -> .
|
||||
|
||||
let rec translate_scope_body_expr
|
||||
~(config : translation_config)
|
||||
(scope_name : ScopeName.t)
|
||||
(decl_ctx : decl_ctx)
|
||||
(program_ctx : A.ctx)
|
||||
(var_dict : ('m L.expr, A.VarName.t) Var.Map.t)
|
||||
(func_dict : ('m L.expr, A.FuncName.t) Var.Map.t)
|
||||
(scope_expr : 'm L.expr scope_body_expr) : A.block =
|
||||
match scope_expr with
|
||||
| Last e ->
|
||||
let block, new_e =
|
||||
translate_expr
|
||||
let ctx =
|
||||
{
|
||||
func_dict;
|
||||
var_dict;
|
||||
inside_definition_of = None;
|
||||
context_name = Mark.remove (ScopeName.get_info scope_name);
|
||||
config;
|
||||
program_ctx;
|
||||
}
|
||||
e
|
||||
in
|
||||
match scope_expr with
|
||||
| Last e ->
|
||||
let block, new_e = translate_expr ctx e in
|
||||
block @ [A.SReturn (Mark.remove new_e), Mark.get new_e]
|
||||
| Cons (scope_let, next_bnd) ->
|
||||
let let_var, scope_let_next = Bindlib.unbind next_bnd in
|
||||
@ -597,24 +612,12 @@ let rec translate_scope_body_expr
|
||||
(match scope_let.scope_let_kind with
|
||||
| Assertion ->
|
||||
translate_statements
|
||||
{
|
||||
func_dict;
|
||||
var_dict;
|
||||
inside_definition_of = Some let_var_id;
|
||||
context_name = Mark.remove (ScopeName.get_info scope_name);
|
||||
config;
|
||||
}
|
||||
{ ctx with inside_definition_of = Some let_var_id }
|
||||
scope_let.scope_let_expr
|
||||
| _ ->
|
||||
let let_expr_stmts, new_let_expr =
|
||||
translate_expr
|
||||
{
|
||||
func_dict;
|
||||
var_dict;
|
||||
inside_definition_of = Some let_var_id;
|
||||
context_name = Mark.remove (ScopeName.get_info scope_name);
|
||||
config;
|
||||
}
|
||||
{ ctx with inside_definition_of = Some let_var_id }
|
||||
scope_let.scope_let_expr
|
||||
in
|
||||
let_expr_stmts
|
||||
@ -633,11 +636,19 @@ let rec translate_scope_body_expr
|
||||
},
|
||||
scope_let.scope_let_pos );
|
||||
])
|
||||
@ translate_scope_body_expr ~config scope_name decl_ctx new_var_dict
|
||||
@ translate_scope_body_expr ~config scope_name program_ctx new_var_dict
|
||||
func_dict scope_let_next
|
||||
|
||||
let translate_program ~(config : translation_config) (p : 'm L.program) :
|
||||
A.program =
|
||||
let modules =
|
||||
List.fold_left
|
||||
(fun acc m ->
|
||||
ModuleName.Map.add m (A.VarName.fresh (ModuleName.get_info m)) acc)
|
||||
ModuleName.Map.empty
|
||||
(Program.modules_to_list p.decl_ctx.ctx_modules)
|
||||
in
|
||||
let ctx = { A.decl_ctx = p.decl_ctx; A.modules } in
|
||||
let (_, _, rev_items), () =
|
||||
BoundList.fold_left
|
||||
~f:(fun (func_dict, var_dict, rev_items) code_item var ->
|
||||
@ -654,8 +665,8 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
|
||||
Var.Map.add scope_input_var scope_input_var_id var_dict
|
||||
in
|
||||
let new_scope_body =
|
||||
translate_scope_body_expr ~config name p.decl_ctx var_dict_local
|
||||
func_dict scope_body_expr
|
||||
translate_scope_body_expr ~config name ctx var_dict_local func_dict
|
||||
scope_body_expr
|
||||
in
|
||||
let func_id = A.FuncName.fresh (Bindlib.name_of var, Pos.no_pos) in
|
||||
( Var.Map.add var func_id func_dict,
|
||||
@ -700,6 +711,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
|
||||
inside_definition_of = None;
|
||||
context_name = Mark.remove (TopdefName.get_info name);
|
||||
config;
|
||||
program_ctx = ctx;
|
||||
}
|
||||
in
|
||||
translate_expr ctxt expr
|
||||
@ -735,6 +747,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
|
||||
inside_definition_of = None;
|
||||
context_name = Mark.remove (TopdefName.get_info name);
|
||||
config;
|
||||
program_ctx = ctx;
|
||||
}
|
||||
in
|
||||
translate_expr ctxt expr
|
||||
@ -778,4 +791,4 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
|
||||
~init:(Var.Map.empty, Var.Map.empty, [])
|
||||
p.code_items
|
||||
in
|
||||
{ decl_ctx = p.decl_ctx; code_items = List.rev rev_items }
|
||||
{ ctx; code_items = List.rev rev_items; module_name = p.module_name }
|
||||
|
@ -21,10 +21,10 @@ open Ast
|
||||
let needs_parens (_e : expr) : bool = false
|
||||
|
||||
let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit =
|
||||
Format.fprintf fmt "%a_%s" VarName.format v (string_of_int (VarName.hash v))
|
||||
Format.fprintf fmt "%a_%d" VarName.format v (VarName.hash v)
|
||||
|
||||
let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit =
|
||||
Format.fprintf fmt "%a_%s" FuncName.format v (string_of_int (FuncName.hash v))
|
||||
Format.fprintf fmt "@{<green>%a_%d@}" FuncName.format v (FuncName.hash v)
|
||||
|
||||
let rec format_expr
|
||||
(decl_ctx : decl_ctx)
|
||||
@ -99,6 +99,9 @@ let rec format_expr
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
|
||||
format_with_parens)
|
||||
args
|
||||
| EExternal { modname; name } ->
|
||||
Format.fprintf fmt "%a.%s" format_var_name (Mark.remove modname)
|
||||
(Mark.remove name)
|
||||
|
||||
let rec format_statement
|
||||
(decl_ctx : decl_ctx)
|
||||
@ -226,15 +229,22 @@ let format_item decl_ctx ?debug ppf def =
|
||||
Format.pp_close_box ppf ();
|
||||
Format.pp_print_cut ppf ()
|
||||
|
||||
let format_program decl_ctx ?debug ppf prg =
|
||||
let format_program ?debug ppf prg =
|
||||
let decl_ctx =
|
||||
(* TODO: this is redundant with From_dcalc.add_option_type (which is already
|
||||
applied in avoid_exceptions mode) *)
|
||||
{
|
||||
decl_ctx with
|
||||
prg.ctx.decl_ctx with
|
||||
ctx_enums =
|
||||
EnumName.Map.add Expr.option_enum Expr.option_enum_config
|
||||
decl_ctx.ctx_enums;
|
||||
prg.ctx.decl_ctx.ctx_enums;
|
||||
}
|
||||
in
|
||||
Format.pp_open_vbox ppf 0;
|
||||
ModuleName.Map.iter
|
||||
(fun m var ->
|
||||
Format.fprintf ppf "%a %a = %a@," Print.keyword "module" format_var_name
|
||||
var ModuleName.format m)
|
||||
prg.ctx.modules;
|
||||
Format.pp_print_list (format_item decl_ctx ?debug) ppf prg.code_items;
|
||||
Format.pp_close_box ppf ()
|
||||
|
@ -21,5 +21,4 @@ val format_item :
|
||||
Ast.code_item ->
|
||||
unit
|
||||
|
||||
val format_program :
|
||||
Shared_ast.decl_ctx -> ?debug:bool -> Format.formatter -> Ast.program -> unit
|
||||
val format_program : ?debug:bool -> Format.formatter -> Ast.program -> unit
|
||||
|
@ -385,6 +385,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
args
|
||||
| ETuple _ | ETupleAccess _ ->
|
||||
Message.raise_internal_error "Tuple compilation to R unimplemented!"
|
||||
| EExternal _ -> failwith "TODO"
|
||||
|
||||
let typ_is_array (ctx : decl_ctx) (typ : typ) =
|
||||
match Mark.remove typ with
|
||||
@ -604,26 +605,28 @@ let format_program
|
||||
%a@,\
|
||||
%a@,\
|
||||
@]"
|
||||
(format_ctx type_ordering) p.decl_ctx
|
||||
(format_ctx type_ordering) p.ctx.decl_ctx
|
||||
(Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt code_item ->
|
||||
match code_item with
|
||||
| SVar { var; expr; typ } ->
|
||||
Format.fprintf fmt "@[<v 2>%a = %a;@]"
|
||||
(format_typ p.decl_ctx (fun fmt -> format_var fmt var))
|
||||
(format_typ p.ctx.decl_ctx (fun fmt -> format_var fmt var))
|
||||
typ
|
||||
(format_expression p.decl_ctx)
|
||||
(format_expression p.ctx.decl_ctx)
|
||||
expr
|
||||
| SFunc { var; func }
|
||||
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
|
||||
let { func_params; func_body; func_return_typ } = func in
|
||||
Format.fprintf fmt "@[<v 2>%a(%a) {@,%a@]@,}"
|
||||
(format_typ p.decl_ctx (fun fmt -> format_func_name fmt var))
|
||||
(format_typ p.ctx.decl_ctx (fun fmt -> format_func_name fmt var))
|
||||
func_return_typ
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun fmt (var, typ) ->
|
||||
(format_typ p.decl_ctx (fun fmt ->
|
||||
(format_typ p.ctx.decl_ctx (fun fmt ->
|
||||
format_var fmt (Mark.remove var)))
|
||||
fmt typ))
|
||||
func_params (format_block p.decl_ctx) func_body))
|
||||
func_params
|
||||
(format_block p.ctx.decl_ctx)
|
||||
func_body))
|
||||
p.code_items
|
||||
|
@ -126,68 +126,13 @@ let avoid_keywords (s : string) : string =
|
||||
then s ^ "_"
|
||||
else s
|
||||
|
||||
let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit =
|
||||
Format.fprintf fmt "%s"
|
||||
(avoid_keywords
|
||||
(String.to_camel_case
|
||||
(String.to_ascii (Format.asprintf "%a" StructName.format v))))
|
||||
module StringMap = String.Map
|
||||
|
||||
let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit
|
||||
=
|
||||
Format.fprintf fmt "%s"
|
||||
(avoid_keywords
|
||||
(String.to_ascii (Format.asprintf "%a" StructField.format v)))
|
||||
module IntMap = Map.Make (struct
|
||||
include Int
|
||||
|
||||
let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit =
|
||||
Format.fprintf fmt "%s"
|
||||
(avoid_keywords
|
||||
(String.to_camel_case
|
||||
(String.to_ascii (Format.asprintf "%a" EnumName.format v))))
|
||||
|
||||
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
|
||||
unit =
|
||||
Format.fprintf fmt "%s"
|
||||
(avoid_keywords
|
||||
(String.to_ascii (Format.asprintf "%a" EnumConstructor.format v)))
|
||||
|
||||
let typ_needs_parens (e : typ) : bool =
|
||||
match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false
|
||||
|
||||
let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
|
||||
let format_typ = format_typ in
|
||||
let format_typ_with_parens (fmt : Format.formatter) (t : typ) =
|
||||
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
|
||||
else Format.fprintf fmt "%a" format_typ t
|
||||
in
|
||||
match Mark.remove typ with
|
||||
| TLit TUnit -> Format.fprintf fmt "Unit"
|
||||
| TLit TMoney -> Format.fprintf fmt "Money"
|
||||
| TLit TInt -> Format.fprintf fmt "Integer"
|
||||
| TLit TRat -> Format.fprintf fmt "Decimal"
|
||||
| TLit TDate -> Format.fprintf fmt "Date"
|
||||
| TLit TDuration -> Format.fprintf fmt "Duration"
|
||||
| TLit TBool -> Format.fprintf fmt "bool"
|
||||
| TTuple ts ->
|
||||
Format.fprintf fmt "Tuple[%a]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
|
||||
(fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t))
|
||||
ts
|
||||
| TStruct s -> Format.fprintf fmt "%a" format_struct_name s
|
||||
| TOption some_typ ->
|
||||
(* We translate the option type with an overloading by Python's [None] *)
|
||||
Format.fprintf fmt "Optional[%a]" format_typ some_typ
|
||||
| TDefault t -> format_typ fmt t
|
||||
| TEnum e -> Format.fprintf fmt "%a" format_enum_name e
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "Callable[[%a], %a]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
format_typ_with_parens)
|
||||
t1 format_typ_with_parens t2
|
||||
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
|
||||
| TAny -> Format.fprintf fmt "Any"
|
||||
| TClosureEnv -> failwith "unimplemented!"
|
||||
let format ppf i = Format.pp_print_int ppf i
|
||||
end)
|
||||
|
||||
let format_name_cleaned (fmt : Format.formatter) (s : string) : unit =
|
||||
s
|
||||
@ -198,14 +143,6 @@ let format_name_cleaned (fmt : Format.formatter) (s : string) : unit =
|
||||
|> avoid_keywords
|
||||
|> Format.fprintf fmt "%s"
|
||||
|
||||
module StringMap = String.Map
|
||||
|
||||
module IntMap = Map.Make (struct
|
||||
include Int
|
||||
|
||||
let format ppf i = Format.pp_print_int ppf i
|
||||
end)
|
||||
|
||||
(** For each `VarName.t` defined by its string and then by its hash, we keep
|
||||
track of which local integer id we've given it. This is used to keep
|
||||
variable naming with low indices rather than one global counter for all
|
||||
@ -244,6 +181,76 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
|
||||
else if local_id = 0 then format_name_cleaned fmt v_str
|
||||
else Format.fprintf fmt "%a_%d" format_name_cleaned v_str local_id
|
||||
|
||||
let format_path ctx fmt p =
|
||||
match List.rev p with
|
||||
| [] -> ()
|
||||
| m :: _ ->
|
||||
format_var fmt (ModuleName.Map.find m ctx.modules);
|
||||
Format.pp_print_char fmt '.'
|
||||
|
||||
let format_struct_name ctx (fmt : Format.formatter) (v : StructName.t) : unit =
|
||||
format_path ctx fmt (StructName.path v);
|
||||
Format.pp_print_string fmt
|
||||
(avoid_keywords
|
||||
(String.to_camel_case
|
||||
(String.to_ascii (Mark.remove (StructName.get_info v)))))
|
||||
|
||||
let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit
|
||||
=
|
||||
Format.pp_print_string fmt
|
||||
(avoid_keywords (String.to_ascii (StructField.to_string v)))
|
||||
|
||||
let format_enum_name ctx (fmt : Format.formatter) (v : EnumName.t) : unit =
|
||||
format_path ctx fmt (EnumName.path v);
|
||||
Format.pp_print_string fmt
|
||||
(avoid_keywords
|
||||
(String.to_camel_case
|
||||
(String.to_ascii (Mark.remove (EnumName.get_info v)))))
|
||||
|
||||
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
|
||||
unit =
|
||||
Format.pp_print_string fmt
|
||||
(avoid_keywords (String.to_ascii (EnumConstructor.to_string v)))
|
||||
|
||||
let typ_needs_parens (e : typ) : bool =
|
||||
match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false
|
||||
|
||||
let rec format_typ ctx (fmt : Format.formatter) (typ : typ) : unit =
|
||||
let format_typ = format_typ ctx in
|
||||
let format_typ_with_parens (fmt : Format.formatter) (t : typ) =
|
||||
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
|
||||
else Format.fprintf fmt "%a" format_typ t
|
||||
in
|
||||
match Mark.remove typ with
|
||||
| TLit TUnit -> Format.fprintf fmt "Unit"
|
||||
| TLit TMoney -> Format.fprintf fmt "Money"
|
||||
| TLit TInt -> Format.fprintf fmt "Integer"
|
||||
| TLit TRat -> Format.fprintf fmt "Decimal"
|
||||
| TLit TDate -> Format.fprintf fmt "Date"
|
||||
| TLit TDuration -> Format.fprintf fmt "Duration"
|
||||
| TLit TBool -> Format.fprintf fmt "bool"
|
||||
| TTuple ts ->
|
||||
Format.fprintf fmt "Tuple[%a]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
|
||||
(fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t))
|
||||
ts
|
||||
| TStruct s -> Format.fprintf fmt "%a" (format_struct_name ctx) s
|
||||
| TOption some_typ ->
|
||||
(* We translate the option type with an overloading by Python's [None] *)
|
||||
Format.fprintf fmt "Optional[%a]" format_typ some_typ
|
||||
| TDefault t -> format_typ fmt t
|
||||
| TEnum e -> Format.fprintf fmt "%a" (format_enum_name ctx) e
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "Callable[[%a], %a]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
format_typ_with_parens)
|
||||
t1 format_typ_with_parens t2
|
||||
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
|
||||
| TAny -> Format.fprintf fmt "Any"
|
||||
| TClosureEnv -> failwith "unimplemented!"
|
||||
|
||||
let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit =
|
||||
let v_str = Mark.remove (FuncName.get_info v) in
|
||||
format_name_cleaned fmt v_str
|
||||
@ -270,13 +277,12 @@ let format_exception (fmt : Format.formatter) (exc : except Mark.pos) : unit =
|
||||
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
|
||||
(Pos.get_law_info pos)
|
||||
|
||||
let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
unit =
|
||||
let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit =
|
||||
match Mark.remove e with
|
||||
| EVar v -> format_var fmt v
|
||||
| EFunc f -> format_func_name fmt f
|
||||
| EStruct { fields = es; name = s } ->
|
||||
Format.fprintf fmt "%a(%a)" format_struct_name s
|
||||
Format.fprintf fmt "%a(%a)" (format_struct_name ctx) s
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun fmt (struct_field, e) ->
|
||||
@ -297,8 +303,8 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
(* We translate the option type with an overloading by Python's [None] *)
|
||||
format_expression ctx fmt e
|
||||
| EInj { e1 = e; cons; name = enum_name; _ } ->
|
||||
Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name
|
||||
format_enum_name enum_name format_enum_cons_name cons
|
||||
Format.fprintf fmt "%a(%a_Code.%a,@ %a)" (format_enum_name ctx) enum_name
|
||||
(format_enum_name ctx) enum_name format_enum_cons_name cons
|
||||
(format_expression ctx) e
|
||||
| EArray es ->
|
||||
Format.fprintf fmt "[%a]"
|
||||
@ -402,11 +408,12 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
es
|
||||
| ETupleAccess { e1; index } ->
|
||||
Format.fprintf fmt "%a[%d]" (format_expression ctx) e1 index
|
||||
| EExternal { modname; name } ->
|
||||
Format.fprintf fmt "%a.%a" format_var (Mark.remove modname)
|
||||
format_name_cleaned (Mark.remove name)
|
||||
|
||||
let rec format_statement
|
||||
(ctx : decl_ctx)
|
||||
(fmt : Format.formatter)
|
||||
(s : stmt Mark.pos) : unit =
|
||||
let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
|
||||
=
|
||||
match Mark.remove s with
|
||||
| SInnerFuncDef { name; func = { func_params; func_body; _ } } ->
|
||||
Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_var
|
||||
@ -414,8 +421,8 @@ let rec format_statement
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
|
||||
(fun fmt (var, typ) ->
|
||||
Format.fprintf fmt "%a:%a" format_var (Mark.remove var) format_typ
|
||||
typ))
|
||||
Format.fprintf fmt "%a:%a" format_var (Mark.remove var)
|
||||
(format_typ ctx) typ))
|
||||
func_params (format_block ctx) func_body
|
||||
| SLocalDecl _ ->
|
||||
assert false (* We don't need to declare variables in Python *)
|
||||
@ -458,7 +465,7 @@ let rec format_statement
|
||||
(format_block ctx) case_none format_var case_some_var format_var tmp_var
|
||||
(format_block ctx) case_some
|
||||
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
|
||||
let cons_map = EnumName.Map.find e_name ctx.ctx_enums in
|
||||
let cons_map = EnumName.Map.find e_name ctx.decl_ctx.ctx_enums in
|
||||
let cases =
|
||||
List.map2
|
||||
(fun x (cons, _) -> x, cons)
|
||||
@ -472,9 +479,9 @@ let rec format_statement
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 4>elif ")
|
||||
(fun fmt ({ case_block; payload_var_name; _ }, cons_name) ->
|
||||
Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a"
|
||||
format_var tmp_var format_enum_name e_name format_enum_cons_name
|
||||
cons_name format_var payload_var_name format_var tmp_var
|
||||
(format_block ctx) case_block))
|
||||
format_var tmp_var (format_enum_name ctx) e_name
|
||||
format_enum_cons_name cons_name format_var payload_var_name
|
||||
format_var tmp_var (format_block ctx) case_block))
|
||||
cases
|
||||
| SReturn e1 ->
|
||||
Format.fprintf fmt "@[<hov 4>return %a@]" (format_expression ctx)
|
||||
@ -493,7 +500,7 @@ let rec format_statement
|
||||
(Pos.get_law_info pos)
|
||||
| SSpecialOp _ -> failwith "should not happen"
|
||||
|
||||
and format_block (ctx : decl_ctx) (fmt : Format.formatter) (b : block) : unit =
|
||||
and format_block ctx (fmt : Format.formatter) (b : block) : unit =
|
||||
Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(format_statement ctx) fmt
|
||||
@ -504,7 +511,7 @@ and format_block (ctx : decl_ctx) (fmt : Format.formatter) (b : block) : unit =
|
||||
let format_ctx
|
||||
(type_ordering : Scopelang.Dependency.TVertex.t list)
|
||||
(fmt : Format.formatter)
|
||||
(ctx : decl_ctx) : unit =
|
||||
ctx : unit =
|
||||
let format_struct_decl fmt (struct_name, struct_fields) =
|
||||
let fields = StructField.Map.bindings struct_fields in
|
||||
Format.fprintf fmt
|
||||
@ -522,13 +529,13 @@ let format_ctx
|
||||
\ return not (self == other)@\n\
|
||||
@\n\
|
||||
\ def __str__(self) -> str:@\n\
|
||||
\ @[<hov 4>return \"%a(%a)\".format(%a)@]" format_struct_name
|
||||
\ @[<hov 4>return \"%a(%a)\".format(%a)@]" (format_struct_name ctx)
|
||||
struct_name
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
|
||||
(fun fmt (struct_field, struct_field_type) ->
|
||||
Format.fprintf fmt "%a: %a" format_struct_field_name struct_field
|
||||
format_typ struct_field_type))
|
||||
(format_typ ctx) struct_field_type))
|
||||
fields
|
||||
(if StructField.Map.is_empty struct_fields then fun fmt _ ->
|
||||
Format.fprintf fmt " pass"
|
||||
@ -538,7 +545,7 @@ let format_ctx
|
||||
(fun fmt (struct_field, _) ->
|
||||
Format.fprintf fmt " self.%a = %a" format_struct_field_name
|
||||
struct_field format_struct_field_name struct_field))
|
||||
fields format_struct_name struct_name
|
||||
fields (format_struct_name ctx) struct_name
|
||||
(if not (StructField.Map.is_empty struct_fields) then
|
||||
Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ")
|
||||
@ -546,7 +553,7 @@ let format_ctx
|
||||
Format.fprintf fmt "self.%a == other.%a" format_struct_field_name
|
||||
struct_field format_struct_field_name struct_field)
|
||||
else fun fmt _ -> Format.fprintf fmt "True")
|
||||
fields format_struct_name struct_name
|
||||
fields (format_struct_name ctx) struct_name
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",")
|
||||
(fun fmt (struct_field, _) ->
|
||||
@ -585,7 +592,7 @@ let format_ctx
|
||||
@\n\
|
||||
\ def __str__(self) -> str:@\n\
|
||||
\ @[<hov 4>return \"{}({})\".format(self.code, self.value)@]"
|
||||
format_enum_name enum_name
|
||||
(format_enum_name ctx) enum_name
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(fun fmt (i, enum_cons, _enum_cons_type) ->
|
||||
@ -593,8 +600,8 @@ let format_ctx
|
||||
(List.mapi
|
||||
(fun i (x, y) -> i, x, y)
|
||||
(EnumConstructor.Map.bindings enum_cons))
|
||||
format_enum_name enum_name format_enum_name enum_name format_enum_name
|
||||
enum_name
|
||||
(format_enum_name ctx) enum_name (format_enum_name ctx) enum_name
|
||||
(format_enum_name ctx) enum_name
|
||||
in
|
||||
|
||||
let is_in_type_ordering s =
|
||||
@ -611,42 +618,25 @@ let format_ctx
|
||||
(StructName.Map.bindings
|
||||
(StructName.Map.filter
|
||||
(fun s _ -> not (is_in_type_ordering s))
|
||||
ctx.ctx_structs))
|
||||
ctx.decl_ctx.ctx_structs))
|
||||
in
|
||||
List.iter
|
||||
(fun struct_or_enum ->
|
||||
match struct_or_enum with
|
||||
| Scopelang.Dependency.TVertex.Struct s ->
|
||||
if StructName.path s = [] then
|
||||
Format.fprintf fmt "%a@\n@\n" format_struct_decl
|
||||
(s, StructName.Map.find s ctx.ctx_structs)
|
||||
(s, StructName.Map.find s ctx.decl_ctx.ctx_structs)
|
||||
| Scopelang.Dependency.TVertex.Enum e ->
|
||||
if EnumName.path e = [] then
|
||||
Format.fprintf fmt "%a@\n@\n" format_enum_decl
|
||||
(e, EnumName.Map.find e ctx.ctx_enums))
|
||||
(e, EnumName.Map.find e ctx.decl_ctx.ctx_enums))
|
||||
(type_ordering @ scope_structs)
|
||||
|
||||
let format_program
|
||||
(fmt : Format.formatter)
|
||||
(p : Ast.program)
|
||||
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
|
||||
(* We disable the style flag in order to enjoy formatting from the
|
||||
pretty-printers of Dcalc and Lcalc but without the color terminal
|
||||
markers. *)
|
||||
Format.fprintf fmt
|
||||
"@[<v># This file has been generated by the Catala compiler, do not edit!@,\
|
||||
@,\
|
||||
from catala.runtime import *@,\
|
||||
from typing import Any, List, Callable, Tuple@,\
|
||||
from enum import Enum@,\
|
||||
@,\
|
||||
@[<v>%a@]@,\
|
||||
@,\
|
||||
%a@]@?"
|
||||
(format_ctx type_ordering) p.decl_ctx
|
||||
(Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt -> function
|
||||
let format_code_item ctx fmt = function
|
||||
| SVar { var; expr; typ = _ } ->
|
||||
Format.fprintf fmt "@[<hv 4>%a = (@,%a@,@])@," format_var var
|
||||
(format_expression p.decl_ctx)
|
||||
expr
|
||||
(format_expression ctx) expr
|
||||
| SFunc { var; func }
|
||||
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
|
||||
let { Ast.func_params; Ast.func_body; _ } = func in
|
||||
@ -655,6 +645,31 @@ let format_program
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
|
||||
(fun fmt (var, typ) ->
|
||||
Format.fprintf fmt "%a:%a" format_var (Mark.remove var)
|
||||
format_typ typ))
|
||||
func_params (format_block p.decl_ctx) func_body))
|
||||
p.code_items
|
||||
(format_typ ctx) typ))
|
||||
func_params (format_block ctx) func_body
|
||||
|
||||
let format_program
|
||||
(fmt : Format.formatter)
|
||||
(p : Ast.program)
|
||||
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
|
||||
Format.pp_open_vbox fmt 0;
|
||||
let header =
|
||||
[
|
||||
"# This file has been generated by the Catala compiler, do not edit!";
|
||||
"";
|
||||
"from catala.runtime import *";
|
||||
"from typing import Any, List, Callable, Tuple";
|
||||
"from enum import Enum";
|
||||
"";
|
||||
]
|
||||
in
|
||||
Format.pp_print_list Format.pp_print_string fmt header;
|
||||
ModuleName.Map.iter
|
||||
(fun m v ->
|
||||
Format.fprintf fmt "import %a as %a@," ModuleName.format m format_var v)
|
||||
p.ctx.modules;
|
||||
Format.pp_print_cut fmt ();
|
||||
format_ctx type_ordering fmt p.ctx;
|
||||
Format.pp_print_cut fmt ();
|
||||
Format.pp_print_list (format_code_item p.ctx) fmt p.code_items;
|
||||
Format.pp_print_flush fmt ()
|
||||
|
@ -373,6 +373,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
args
|
||||
| ETuple _ | ETupleAccess _ ->
|
||||
Message.raise_internal_error "Tuple compilation to R unimplemented!"
|
||||
| EExternal _ -> failwith "TODO"
|
||||
|
||||
let rec format_statement
|
||||
(ctx : decl_ctx)
|
||||
@ -562,11 +563,11 @@ let format_program
|
||||
@[<v>%a@]@,\
|
||||
@,\
|
||||
%a@]@?"
|
||||
(format_ctx type_ordering) p.decl_ctx
|
||||
(format_ctx type_ordering) p.ctx.decl_ctx
|
||||
(Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt -> function
|
||||
| SVar { var; expr; typ = _ } ->
|
||||
Format.fprintf fmt "@[<hv 2>%a <- (@,%a@,@])@," format_var var
|
||||
(format_expression p.decl_ctx)
|
||||
(format_expression p.ctx.decl_ctx)
|
||||
expr
|
||||
| SFunc { var; func }
|
||||
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
|
||||
@ -578,5 +579,7 @@ let format_program
|
||||
(fun fmt (var, typ) ->
|
||||
Format.fprintf fmt "%a# (%a)@\n" format_var (Mark.remove var)
|
||||
format_typ typ))
|
||||
func_params (format_block p.decl_ctx) func_body))
|
||||
func_params
|
||||
(format_block p.ctx.decl_ctx)
|
||||
func_body))
|
||||
p.code_items
|
||||
|
@ -1051,16 +1051,7 @@ let load_runtime_modules prg =
|
||||
obj_file Format.pp_print_text
|
||||
(Dynlink.error_message dl_err)
|
||||
in
|
||||
let modules_list_topo =
|
||||
let rec aux acc (M mtree) =
|
||||
ModuleName.Map.fold
|
||||
(fun mname sub acc ->
|
||||
if List.exists (ModuleName.equal mname) acc then acc
|
||||
else mname :: aux acc sub)
|
||||
mtree acc
|
||||
in
|
||||
List.rev (aux [] prg.decl_ctx.ctx_modules)
|
||||
in
|
||||
let modules_list_topo = Program.modules_to_list prg.decl_ctx.ctx_modules in
|
||||
if modules_list_topo <> [] then
|
||||
Message.emit_debug "Loading shared modules... %a"
|
||||
(Format.pp_print_list ~pp_sep:Format.pp_print_space ModuleName.format)
|
||||
|
@ -85,3 +85,13 @@ let to_expr p main_scope =
|
||||
let res = Scope.unfold p.decl_ctx p.code_items main_scope in
|
||||
Expr.Box.assert_closed (Expr.Box.lift res);
|
||||
res
|
||||
|
||||
let modules_to_list (mt : module_tree) =
|
||||
let rec aux acc (M mtree) =
|
||||
ModuleName.Map.fold
|
||||
(fun mname sub acc ->
|
||||
if List.exists (ModuleName.equal mname) acc then acc
|
||||
else mname :: aux acc sub)
|
||||
mtree acc
|
||||
in
|
||||
List.rev (aux [] mt)
|
||||
|
@ -52,3 +52,6 @@ val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed
|
||||
function. *)
|
||||
|
||||
val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body
|
||||
|
||||
val modules_to_list : module_tree -> ModuleName.t list
|
||||
(** Returns a list of used modules, in topological order *)
|
||||
|
@ -394,7 +394,6 @@ class S4In:
|
||||
return "S4In()".format()
|
||||
|
||||
|
||||
|
||||
glob1 = (decimal_of_string("44.12"))
|
||||
|
||||
def glob3(x:Money):
|
||||
|
Loading…
Reference in New Issue
Block a user