Add externals to scalc, working test with Python backend

This commit is contained in:
Louis Gesbert 2024-02-22 12:14:25 +01:00
parent 589833bca7
commit 97ae62384e
14 changed files with 259 additions and 205 deletions

View File

@ -880,8 +880,8 @@ module Commands = struct
@@ fun fmt ->
match ex_scope_opt with
| Some scope ->
let scope_uid = get_scope_uid prg.decl_ctx scope in
Scalc.Print.format_item ~debug:options.Cli.debug prg.decl_ctx fmt
let scope_uid = get_scope_uid prg.ctx.decl_ctx scope in
Scalc.Print.format_item ~debug:options.Cli.debug prg.ctx.decl_ctx fmt
(List.find
(function
| Scalc.Ast.SScope { scope_body_name; _ } ->
@ -889,7 +889,7 @@ module Commands = struct
| _ -> false)
prg.code_items);
Format.pp_print_newline fmt ()
| None -> Scalc.Print.format_program prg.decl_ctx fmt prg
| None -> Scalc.Print.format_program fmt prg
let scalc_cmd =
Cmd.v

View File

@ -28,13 +28,13 @@ let register info term =
let list () = Hashtbl.to_seq_values backend_plugins |> List.of_seq
let names () = Hashtbl.to_seq_keys backend_plugins |> List.of_seq
let load_failures = Hashtbl.create 17
let print_failures () =
if Hashtbl.length load_failures > 0 then
Message.emit_warning "Some plugins could not be loaded:@,%a"
(Format.pp_print_seq (fun ppf -> Format.fprintf ppf " - %s")) (Hashtbl.to_seq_values load_failures)
(Format.pp_print_seq (fun ppf -> Format.fprintf ppf " - %s"))
(Hashtbl.to_seq_values load_failures)
let load_file f =
try

View File

@ -43,4 +43,5 @@ val load_dir : string -> unit
(** Load all plugins found in the given directory *)
val print_failures : unit -> unit
(** Dynlink errors may be silenced at startup time if not in --debug mode, this prints them as warnings *)
(** Dynlink errors may be silenced at startup time if not in --debug mode, this
prints them as warnings *)

View File

@ -29,7 +29,7 @@ module FuncName =
module VarName =
Uid.Gen
(struct
let style = Ocolor_types.(Fg (C4 hi_green))
let style = Ocolor_types.Default_fg
end)
()
@ -62,6 +62,7 @@ and naked_expr =
| ELit of lit
| EApp of { f : expr; args : expr list }
| EAppOp of { op : operator; args : expr list }
| EExternal of { modname : VarName.t Mark.pos; name : string Mark.pos }
type stmt =
| SInnerFuncDef of { name : VarName.t Mark.pos; func : func }
@ -114,4 +115,10 @@ type code_item =
| SFunc of { var : FuncName.t; func : func }
| SScope of scope_body
type program = { decl_ctx : decl_ctx; code_items : code_item list }
type ctx = { decl_ctx : decl_ctx; modules : VarName.t ModuleName.Map.t }
type program = {
ctx : ctx;
code_items : code_item list;
module_name : ModuleName.t option;
}

View File

@ -32,6 +32,7 @@ type 'm ctxt = {
inside_definition_of : A.VarName.t option;
context_name : string;
config : translation_config;
program_ctx : A.ctx;
}
let unthunk e =
@ -65,7 +66,11 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
(Var.Map.keys ctxt.var_dict))
in
[], (local_var, Expr.pos expr)
| EStruct { fields; name } when not ctxt.config.no_struct_literals ->
| EStruct { fields; name } ->
if ctxt.config.no_struct_literals then
(* In C89, struct literates have to be initialized at variable
definition... *)
raise (NotAnExpr { needs_a_local_decl = false });
let args_stmts, new_args =
StructField.Map.fold
(fun field arg (args_stmts, new_args) ->
@ -76,11 +81,11 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
in
let args_stmts = List.rev args_stmts in
args_stmts, (A.EStruct { fields = new_args; name }, Expr.pos expr)
| EStruct _ when ctxt.config.no_struct_literals ->
(* In C89, struct literates have to be initialized at variable
definition... *)
raise (NotAnExpr { needs_a_local_decl = false })
| EInj { e = e1; cons; name } when not ctxt.config.no_struct_literals ->
| EInj { e = e1; cons; name } ->
if ctxt.config.no_struct_literals then
(* In C89, struct literates have to be initialized at variable
definition... *)
raise (NotAnExpr { needs_a_local_decl = false });
let e1_stmts, new_e1 = translate_expr ctxt e1 in
( e1_stmts,
( A.EInj
@ -91,10 +96,6 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
expr_typ = Expr.maybe_ty (Mark.get expr);
},
Expr.pos expr ) )
| EInj _ when ctxt.config.no_struct_literals ->
(* In C89, struct literates have to be initialized at variable
definition... *)
raise (NotAnExpr { needs_a_local_decl = false })
| ETuple args ->
let args_stmts, new_args =
List.fold_left
@ -212,7 +213,20 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
let new_args = List.rev new_args in
args_stmts, (A.EArray new_args, Expr.pos expr)
| ELit l -> [], (A.ELit l, Expr.pos expr)
| _ -> raise (NotAnExpr { needs_a_local_decl = true })
| EExternal { name } ->
let path, name =
match Mark.remove name with
| External_value name -> TopdefName.(path name, get_info name)
| External_scope name -> ScopeName.(path name, get_info name)
in
let modname =
( ModuleName.Map.find (List.hd (List.rev path)) ctxt.program_ctx.modules,
Expr.pos expr )
in
[], (EExternal { modname; name }, Expr.pos expr)
| ECatch _ | EAbs _ | EIfThenElse _ | EMatch _ | EAssert _ | ERaise _ ->
raise (NotAnExpr { needs_a_local_decl = true })
| _ -> .
with NotAnExpr { needs_a_local_decl } ->
let tmp_var =
A.VarName.fresh
@ -542,8 +556,8 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
},
Expr.pos block_expr );
]
| _ -> (
Message.emit_debug "E: %a" Expr.format block_expr;
| ELit _ | EAppOp _ | EArray _ | EVar _ | EStruct _ | EInj _ | ETuple _
| ETupleAccess _ | EStructAccess _ | EExternal _ | EApp _ -> (
let e_stmts, new_e = translate_expr ctxt block_expr in
e_stmts
@
@ -566,27 +580,28 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
}),
Expr.pos block_expr );
])
| _ -> .
let rec translate_scope_body_expr
~(config : translation_config)
(scope_name : ScopeName.t)
(decl_ctx : decl_ctx)
(program_ctx : A.ctx)
(var_dict : ('m L.expr, A.VarName.t) Var.Map.t)
(func_dict : ('m L.expr, A.FuncName.t) Var.Map.t)
(scope_expr : 'm L.expr scope_body_expr) : A.block =
let ctx =
{
func_dict;
var_dict;
inside_definition_of = None;
context_name = Mark.remove (ScopeName.get_info scope_name);
config;
program_ctx;
}
in
match scope_expr with
| Last e ->
let block, new_e =
translate_expr
{
func_dict;
var_dict;
inside_definition_of = None;
context_name = Mark.remove (ScopeName.get_info scope_name);
config;
}
e
in
let block, new_e = translate_expr ctx e in
block @ [A.SReturn (Mark.remove new_e), Mark.get new_e]
| Cons (scope_let, next_bnd) ->
let let_var, scope_let_next = Bindlib.unbind next_bnd in
@ -597,24 +612,12 @@ let rec translate_scope_body_expr
(match scope_let.scope_let_kind with
| Assertion ->
translate_statements
{
func_dict;
var_dict;
inside_definition_of = Some let_var_id;
context_name = Mark.remove (ScopeName.get_info scope_name);
config;
}
{ ctx with inside_definition_of = Some let_var_id }
scope_let.scope_let_expr
| _ ->
let let_expr_stmts, new_let_expr =
translate_expr
{
func_dict;
var_dict;
inside_definition_of = Some let_var_id;
context_name = Mark.remove (ScopeName.get_info scope_name);
config;
}
{ ctx with inside_definition_of = Some let_var_id }
scope_let.scope_let_expr
in
let_expr_stmts
@ -633,11 +636,19 @@ let rec translate_scope_body_expr
},
scope_let.scope_let_pos );
])
@ translate_scope_body_expr ~config scope_name decl_ctx new_var_dict
@ translate_scope_body_expr ~config scope_name program_ctx new_var_dict
func_dict scope_let_next
let translate_program ~(config : translation_config) (p : 'm L.program) :
A.program =
let modules =
List.fold_left
(fun acc m ->
ModuleName.Map.add m (A.VarName.fresh (ModuleName.get_info m)) acc)
ModuleName.Map.empty
(Program.modules_to_list p.decl_ctx.ctx_modules)
in
let ctx = { A.decl_ctx = p.decl_ctx; A.modules } in
let (_, _, rev_items), () =
BoundList.fold_left
~f:(fun (func_dict, var_dict, rev_items) code_item var ->
@ -654,8 +665,8 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
Var.Map.add scope_input_var scope_input_var_id var_dict
in
let new_scope_body =
translate_scope_body_expr ~config name p.decl_ctx var_dict_local
func_dict scope_body_expr
translate_scope_body_expr ~config name ctx var_dict_local func_dict
scope_body_expr
in
let func_id = A.FuncName.fresh (Bindlib.name_of var, Pos.no_pos) in
( Var.Map.add var func_id func_dict,
@ -700,6 +711,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
inside_definition_of = None;
context_name = Mark.remove (TopdefName.get_info name);
config;
program_ctx = ctx;
}
in
translate_expr ctxt expr
@ -735,6 +747,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
inside_definition_of = None;
context_name = Mark.remove (TopdefName.get_info name);
config;
program_ctx = ctx;
}
in
translate_expr ctxt expr
@ -778,4 +791,4 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
~init:(Var.Map.empty, Var.Map.empty, [])
p.code_items
in
{ decl_ctx = p.decl_ctx; code_items = List.rev rev_items }
{ ctx; code_items = List.rev rev_items; module_name = p.module_name }

View File

@ -21,10 +21,10 @@ open Ast
let needs_parens (_e : expr) : bool = false
let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit =
Format.fprintf fmt "%a_%s" VarName.format v (string_of_int (VarName.hash v))
Format.fprintf fmt "%a_%d" VarName.format v (VarName.hash v)
let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit =
Format.fprintf fmt "%a_%s" FuncName.format v (string_of_int (FuncName.hash v))
Format.fprintf fmt "@{<green>%a_%d@}" FuncName.format v (FuncName.hash v)
let rec format_expr
(decl_ctx : decl_ctx)
@ -99,6 +99,9 @@ let rec format_expr
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens)
args
| EExternal { modname; name } ->
Format.fprintf fmt "%a.%s" format_var_name (Mark.remove modname)
(Mark.remove name)
let rec format_statement
(decl_ctx : decl_ctx)
@ -226,15 +229,22 @@ let format_item decl_ctx ?debug ppf def =
Format.pp_close_box ppf ();
Format.pp_print_cut ppf ()
let format_program decl_ctx ?debug ppf prg =
let format_program ?debug ppf prg =
let decl_ctx =
(* TODO: this is redundant with From_dcalc.add_option_type (which is already
applied in avoid_exceptions mode) *)
{
decl_ctx with
prg.ctx.decl_ctx with
ctx_enums =
EnumName.Map.add Expr.option_enum Expr.option_enum_config
decl_ctx.ctx_enums;
prg.ctx.decl_ctx.ctx_enums;
}
in
Format.pp_open_vbox ppf 0;
ModuleName.Map.iter
(fun m var ->
Format.fprintf ppf "%a %a = %a@," Print.keyword "module" format_var_name
var ModuleName.format m)
prg.ctx.modules;
Format.pp_print_list (format_item decl_ctx ?debug) ppf prg.code_items;
Format.pp_close_box ppf ()

View File

@ -21,5 +21,4 @@ val format_item :
Ast.code_item ->
unit
val format_program :
Shared_ast.decl_ctx -> ?debug:bool -> Format.formatter -> Ast.program -> unit
val format_program : ?debug:bool -> Format.formatter -> Ast.program -> unit

View File

@ -385,6 +385,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
args
| ETuple _ | ETupleAccess _ ->
Message.raise_internal_error "Tuple compilation to R unimplemented!"
| EExternal _ -> failwith "TODO"
let typ_is_array (ctx : decl_ctx) (typ : typ) =
match Mark.remove typ with
@ -604,26 +605,28 @@ let format_program
%a@,\
%a@,\
@]"
(format_ctx type_ordering) p.decl_ctx
(format_ctx type_ordering) p.ctx.decl_ctx
(Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt code_item ->
match code_item with
| SVar { var; expr; typ } ->
Format.fprintf fmt "@[<v 2>%a = %a;@]"
(format_typ p.decl_ctx (fun fmt -> format_var fmt var))
(format_typ p.ctx.decl_ctx (fun fmt -> format_var fmt var))
typ
(format_expression p.decl_ctx)
(format_expression p.ctx.decl_ctx)
expr
| SFunc { var; func }
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
let { func_params; func_body; func_return_typ } = func in
Format.fprintf fmt "@[<v 2>%a(%a) {@,%a@]@,}"
(format_typ p.decl_ctx (fun fmt -> format_func_name fmt var))
(format_typ p.ctx.decl_ctx (fun fmt -> format_func_name fmt var))
func_return_typ
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (var, typ) ->
(format_typ p.decl_ctx (fun fmt ->
(format_typ p.ctx.decl_ctx (fun fmt ->
format_var fmt (Mark.remove var)))
fmt typ))
func_params (format_block p.decl_ctx) func_body))
func_params
(format_block p.ctx.decl_ctx)
func_body))
p.code_items

View File

@ -126,68 +126,13 @@ let avoid_keywords (s : string) : string =
then s ^ "_"
else s
let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit =
Format.fprintf fmt "%s"
(avoid_keywords
(String.to_camel_case
(String.to_ascii (Format.asprintf "%a" StructName.format v))))
module StringMap = String.Map
let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit
=
Format.fprintf fmt "%s"
(avoid_keywords
(String.to_ascii (Format.asprintf "%a" StructField.format v)))
module IntMap = Map.Make (struct
include Int
let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit =
Format.fprintf fmt "%s"
(avoid_keywords
(String.to_camel_case
(String.to_ascii (Format.asprintf "%a" EnumName.format v))))
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
unit =
Format.fprintf fmt "%s"
(avoid_keywords
(String.to_ascii (Format.asprintf "%a" EnumConstructor.format v)))
let typ_needs_parens (e : typ) : bool =
match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false
let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
let format_typ = format_typ in
let format_typ_with_parens (fmt : Format.formatter) (t : typ) =
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
else Format.fprintf fmt "%a" format_typ t
in
match Mark.remove typ with
| TLit TUnit -> Format.fprintf fmt "Unit"
| TLit TMoney -> Format.fprintf fmt "Money"
| TLit TInt -> Format.fprintf fmt "Integer"
| TLit TRat -> Format.fprintf fmt "Decimal"
| TLit TDate -> Format.fprintf fmt "Date"
| TLit TDuration -> Format.fprintf fmt "Duration"
| TLit TBool -> Format.fprintf fmt "bool"
| TTuple ts ->
Format.fprintf fmt "Tuple[%a]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t))
ts
| TStruct s -> Format.fprintf fmt "%a" format_struct_name s
| TOption some_typ ->
(* We translate the option type with an overloading by Python's [None] *)
Format.fprintf fmt "Optional[%a]" format_typ some_typ
| TDefault t -> format_typ fmt t
| TEnum e -> Format.fprintf fmt "%a" format_enum_name e
| TArrow (t1, t2) ->
Format.fprintf fmt "Callable[[%a], %a]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
format_typ_with_parens)
t1 format_typ_with_parens t2
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
| TAny -> Format.fprintf fmt "Any"
| TClosureEnv -> failwith "unimplemented!"
let format ppf i = Format.pp_print_int ppf i
end)
let format_name_cleaned (fmt : Format.formatter) (s : string) : unit =
s
@ -198,14 +143,6 @@ let format_name_cleaned (fmt : Format.formatter) (s : string) : unit =
|> avoid_keywords
|> Format.fprintf fmt "%s"
module StringMap = String.Map
module IntMap = Map.Make (struct
include Int
let format ppf i = Format.pp_print_int ppf i
end)
(** For each `VarName.t` defined by its string and then by its hash, we keep
track of which local integer id we've given it. This is used to keep
variable naming with low indices rather than one global counter for all
@ -244,6 +181,76 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
else if local_id = 0 then format_name_cleaned fmt v_str
else Format.fprintf fmt "%a_%d" format_name_cleaned v_str local_id
let format_path ctx fmt p =
match List.rev p with
| [] -> ()
| m :: _ ->
format_var fmt (ModuleName.Map.find m ctx.modules);
Format.pp_print_char fmt '.'
let format_struct_name ctx (fmt : Format.formatter) (v : StructName.t) : unit =
format_path ctx fmt (StructName.path v);
Format.pp_print_string fmt
(avoid_keywords
(String.to_camel_case
(String.to_ascii (Mark.remove (StructName.get_info v)))))
let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit
=
Format.pp_print_string fmt
(avoid_keywords (String.to_ascii (StructField.to_string v)))
let format_enum_name ctx (fmt : Format.formatter) (v : EnumName.t) : unit =
format_path ctx fmt (EnumName.path v);
Format.pp_print_string fmt
(avoid_keywords
(String.to_camel_case
(String.to_ascii (Mark.remove (EnumName.get_info v)))))
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
unit =
Format.pp_print_string fmt
(avoid_keywords (String.to_ascii (EnumConstructor.to_string v)))
let typ_needs_parens (e : typ) : bool =
match Mark.remove e with TArrow _ | TArray _ -> true | _ -> false
let rec format_typ ctx (fmt : Format.formatter) (typ : typ) : unit =
let format_typ = format_typ ctx in
let format_typ_with_parens (fmt : Format.formatter) (t : typ) =
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
else Format.fprintf fmt "%a" format_typ t
in
match Mark.remove typ with
| TLit TUnit -> Format.fprintf fmt "Unit"
| TLit TMoney -> Format.fprintf fmt "Money"
| TLit TInt -> Format.fprintf fmt "Integer"
| TLit TRat -> Format.fprintf fmt "Decimal"
| TLit TDate -> Format.fprintf fmt "Date"
| TLit TDuration -> Format.fprintf fmt "Duration"
| TLit TBool -> Format.fprintf fmt "bool"
| TTuple ts ->
Format.fprintf fmt "Tuple[%a]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t))
ts
| TStruct s -> Format.fprintf fmt "%a" (format_struct_name ctx) s
| TOption some_typ ->
(* We translate the option type with an overloading by Python's [None] *)
Format.fprintf fmt "Optional[%a]" format_typ some_typ
| TDefault t -> format_typ fmt t
| TEnum e -> Format.fprintf fmt "%a" (format_enum_name ctx) e
| TArrow (t1, t2) ->
Format.fprintf fmt "Callable[[%a], %a]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
format_typ_with_parens)
t1 format_typ_with_parens t2
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
| TAny -> Format.fprintf fmt "Any"
| TClosureEnv -> failwith "unimplemented!"
let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit =
let v_str = Mark.remove (FuncName.get_info v) in
format_name_cleaned fmt v_str
@ -270,13 +277,12 @@ let format_exception (fmt : Format.formatter) (exc : except Mark.pos) : unit =
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos)
let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
unit =
let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit =
match Mark.remove e with
| EVar v -> format_var fmt v
| EFunc f -> format_func_name fmt f
| EStruct { fields = es; name = s } ->
Format.fprintf fmt "%a(%a)" format_struct_name s
Format.fprintf fmt "%a(%a)" (format_struct_name ctx) s
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (struct_field, e) ->
@ -297,8 +303,8 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
(* We translate the option type with an overloading by Python's [None] *)
format_expression ctx fmt e
| EInj { e1 = e; cons; name = enum_name; _ } ->
Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name
format_enum_name enum_name format_enum_cons_name cons
Format.fprintf fmt "%a(%a_Code.%a,@ %a)" (format_enum_name ctx) enum_name
(format_enum_name ctx) enum_name format_enum_cons_name cons
(format_expression ctx) e
| EArray es ->
Format.fprintf fmt "[%a]"
@ -402,11 +408,12 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
es
| ETupleAccess { e1; index } ->
Format.fprintf fmt "%a[%d]" (format_expression ctx) e1 index
| EExternal { modname; name } ->
Format.fprintf fmt "%a.%a" format_var (Mark.remove modname)
format_name_cleaned (Mark.remove name)
let rec format_statement
(ctx : decl_ctx)
(fmt : Format.formatter)
(s : stmt Mark.pos) : unit =
let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
=
match Mark.remove s with
| SInnerFuncDef { name; func = { func_params; func_body; _ } } ->
Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_var
@ -414,8 +421,8 @@ let rec format_statement
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (var, typ) ->
Format.fprintf fmt "%a:%a" format_var (Mark.remove var) format_typ
typ))
Format.fprintf fmt "%a:%a" format_var (Mark.remove var)
(format_typ ctx) typ))
func_params (format_block ctx) func_body
| SLocalDecl _ ->
assert false (* We don't need to declare variables in Python *)
@ -458,7 +465,7 @@ let rec format_statement
(format_block ctx) case_none format_var case_some_var format_var tmp_var
(format_block ctx) case_some
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
let cons_map = EnumName.Map.find e_name ctx.ctx_enums in
let cons_map = EnumName.Map.find e_name ctx.decl_ctx.ctx_enums in
let cases =
List.map2
(fun x (cons, _) -> x, cons)
@ -472,9 +479,9 @@ let rec format_statement
~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 4>elif ")
(fun fmt ({ case_block; payload_var_name; _ }, cons_name) ->
Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a"
format_var tmp_var format_enum_name e_name format_enum_cons_name
cons_name format_var payload_var_name format_var tmp_var
(format_block ctx) case_block))
format_var tmp_var (format_enum_name ctx) e_name
format_enum_cons_name cons_name format_var payload_var_name
format_var tmp_var (format_block ctx) case_block))
cases
| SReturn e1 ->
Format.fprintf fmt "@[<hov 4>return %a@]" (format_expression ctx)
@ -493,7 +500,7 @@ let rec format_statement
(Pos.get_law_info pos)
| SSpecialOp _ -> failwith "should not happen"
and format_block (ctx : decl_ctx) (fmt : Format.formatter) (b : block) : unit =
and format_block ctx (fmt : Format.formatter) (b : block) : unit =
Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(format_statement ctx) fmt
@ -504,7 +511,7 @@ and format_block (ctx : decl_ctx) (fmt : Format.formatter) (b : block) : unit =
let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list)
(fmt : Format.formatter)
(ctx : decl_ctx) : unit =
ctx : unit =
let format_struct_decl fmt (struct_name, struct_fields) =
let fields = StructField.Map.bindings struct_fields in
Format.fprintf fmt
@ -522,13 +529,13 @@ let format_ctx
\ return not (self == other)@\n\
@\n\
\ def __str__(self) -> str:@\n\
\ @[<hov 4>return \"%a(%a)\".format(%a)@]" format_struct_name
\ @[<hov 4>return \"%a(%a)\".format(%a)@]" (format_struct_name ctx)
struct_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "%a: %a" format_struct_field_name struct_field
format_typ struct_field_type))
(format_typ ctx) struct_field_type))
fields
(if StructField.Map.is_empty struct_fields then fun fmt _ ->
Format.fprintf fmt " pass"
@ -538,7 +545,7 @@ let format_ctx
(fun fmt (struct_field, _) ->
Format.fprintf fmt " self.%a = %a" format_struct_field_name
struct_field format_struct_field_name struct_field))
fields format_struct_name struct_name
fields (format_struct_name ctx) struct_name
(if not (StructField.Map.is_empty struct_fields) then
Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ")
@ -546,7 +553,7 @@ let format_ctx
Format.fprintf fmt "self.%a == other.%a" format_struct_field_name
struct_field format_struct_field_name struct_field)
else fun fmt _ -> Format.fprintf fmt "True")
fields format_struct_name struct_name
fields (format_struct_name ctx) struct_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",")
(fun fmt (struct_field, _) ->
@ -585,7 +592,7 @@ let format_ctx
@\n\
\ def __str__(self) -> str:@\n\
\ @[<hov 4>return \"{}({})\".format(self.code, self.value)@]"
format_enum_name enum_name
(format_enum_name ctx) enum_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (i, enum_cons, _enum_cons_type) ->
@ -593,8 +600,8 @@ let format_ctx
(List.mapi
(fun i (x, y) -> i, x, y)
(EnumConstructor.Map.bindings enum_cons))
format_enum_name enum_name format_enum_name enum_name format_enum_name
enum_name
(format_enum_name ctx) enum_name (format_enum_name ctx) enum_name
(format_enum_name ctx) enum_name
in
let is_in_type_ordering s =
@ -611,50 +618,58 @@ let format_ctx
(StructName.Map.bindings
(StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs))
ctx.decl_ctx.ctx_structs))
in
List.iter
(fun struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Struct s ->
Format.fprintf fmt "%a@\n@\n" format_struct_decl
(s, StructName.Map.find s ctx.ctx_structs)
if StructName.path s = [] then
Format.fprintf fmt "%a@\n@\n" format_struct_decl
(s, StructName.Map.find s ctx.decl_ctx.ctx_structs)
| Scopelang.Dependency.TVertex.Enum e ->
Format.fprintf fmt "%a@\n@\n" format_enum_decl
(e, EnumName.Map.find e ctx.ctx_enums))
if EnumName.path e = [] then
Format.fprintf fmt "%a@\n@\n" format_enum_decl
(e, EnumName.Map.find e ctx.decl_ctx.ctx_enums))
(type_ordering @ scope_structs)
let format_code_item ctx fmt = function
| SVar { var; expr; typ = _ } ->
Format.fprintf fmt "@[<hv 4>%a = (@,%a@,@])@," format_var var
(format_expression ctx) expr
| SFunc { var; func }
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
let { Ast.func_params; Ast.func_body; _ } = func in
Format.fprintf fmt "@[<hv 4>def %a(%a):@\n%a@]@," format_func_name var
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (var, typ) ->
Format.fprintf fmt "%a:%a" format_var (Mark.remove var)
(format_typ ctx) typ))
func_params (format_block ctx) func_body
let format_program
(fmt : Format.formatter)
(p : Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
(* We disable the style flag in order to enjoy formatting from the
pretty-printers of Dcalc and Lcalc but without the color terminal
markers. *)
Format.fprintf fmt
"@[<v># This file has been generated by the Catala compiler, do not edit!@,\
@,\
from catala.runtime import *@,\
from typing import Any, List, Callable, Tuple@,\
from enum import Enum@,\
@,\
@[<v>%a@]@,\
@,\
%a@]@?"
(format_ctx type_ordering) p.decl_ctx
(Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt -> function
| SVar { var; expr; typ = _ } ->
Format.fprintf fmt "@[<hv 4>%a = (@,%a@,@])@," format_var var
(format_expression p.decl_ctx)
expr
| SFunc { var; func }
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
let { Ast.func_params; Ast.func_body; _ } = func in
Format.fprintf fmt "@[<hv 4>def %a(%a):@\n%a@]@," format_func_name var
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (var, typ) ->
Format.fprintf fmt "%a:%a" format_var (Mark.remove var)
format_typ typ))
func_params (format_block p.decl_ctx) func_body))
p.code_items
Format.pp_open_vbox fmt 0;
let header =
[
"# This file has been generated by the Catala compiler, do not edit!";
"";
"from catala.runtime import *";
"from typing import Any, List, Callable, Tuple";
"from enum import Enum";
"";
]
in
Format.pp_print_list Format.pp_print_string fmt header;
ModuleName.Map.iter
(fun m v ->
Format.fprintf fmt "import %a as %a@," ModuleName.format m format_var v)
p.ctx.modules;
Format.pp_print_cut fmt ();
format_ctx type_ordering fmt p.ctx;
Format.pp_print_cut fmt ();
Format.pp_print_list (format_code_item p.ctx) fmt p.code_items;
Format.pp_print_flush fmt ()

View File

@ -373,6 +373,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
args
| ETuple _ | ETupleAccess _ ->
Message.raise_internal_error "Tuple compilation to R unimplemented!"
| EExternal _ -> failwith "TODO"
let rec format_statement
(ctx : decl_ctx)
@ -562,11 +563,11 @@ let format_program
@[<v>%a@]@,\
@,\
%a@]@?"
(format_ctx type_ordering) p.decl_ctx
(format_ctx type_ordering) p.ctx.decl_ctx
(Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt -> function
| SVar { var; expr; typ = _ } ->
Format.fprintf fmt "@[<hv 2>%a <- (@,%a@,@])@," format_var var
(format_expression p.decl_ctx)
(format_expression p.ctx.decl_ctx)
expr
| SFunc { var; func }
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
@ -578,5 +579,7 @@ let format_program
(fun fmt (var, typ) ->
Format.fprintf fmt "%a# (%a)@\n" format_var (Mark.remove var)
format_typ typ))
func_params (format_block p.decl_ctx) func_body))
func_params
(format_block p.ctx.decl_ctx)
func_body))
p.code_items

View File

@ -1051,16 +1051,7 @@ let load_runtime_modules prg =
obj_file Format.pp_print_text
(Dynlink.error_message dl_err)
in
let modules_list_topo =
let rec aux acc (M mtree) =
ModuleName.Map.fold
(fun mname sub acc ->
if List.exists (ModuleName.equal mname) acc then acc
else mname :: aux acc sub)
mtree acc
in
List.rev (aux [] prg.decl_ctx.ctx_modules)
in
let modules_list_topo = Program.modules_to_list prg.decl_ctx.ctx_modules in
if modules_list_topo <> [] then
Message.emit_debug "Loading shared modules... %a"
(Format.pp_print_list ~pp_sep:Format.pp_print_space ModuleName.format)

View File

@ -85,3 +85,13 @@ let to_expr p main_scope =
let res = Scope.unfold p.decl_ctx p.code_items main_scope in
Expr.Box.assert_closed (Expr.Box.lift res);
res
let modules_to_list (mt : module_tree) =
let rec aux acc (M mtree) =
ModuleName.Map.fold
(fun mname sub acc ->
if List.exists (ModuleName.equal mname) acc then acc
else mname :: aux acc sub)
mtree acc
in
List.rev (aux [] mt)

View File

@ -52,3 +52,6 @@ val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed
function. *)
val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body
val modules_to_list : module_tree -> ModuleName.t list
(** Returns a list of used modules, in topological order *)

View File

@ -394,7 +394,6 @@ class S4In:
return "S4In()".format()
glob1 = (decimal_of_string("44.12"))
def glob3(x:Money):