Print algebraic data types as R classes

This commit is contained in:
Denis Merigoux 2023-08-04 18:07:49 +02:00
parent fd89562c8b
commit 84d37d8720
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3

View File

@ -152,23 +152,28 @@ let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
let format_typ = format_typ in
match Mark.remove typ with
| TLit TUnit -> Format.fprintf fmt "catala_unit"
| TLit TMoney -> Format.fprintf fmt "ctala_money"
| TLit TInt -> Format.fprintf fmt "catala_integer"
| TLit TRat -> Format.fprintf fmt "catala_decimal"
| TLit TDate -> Format.fprintf fmt "catala_date"
| TLit TDuration -> Format.fprintf fmt "catala_duration"
| TLit TBool -> Format.fprintf fmt "logical"
| TTuple _ts -> Format.fprintf fmt "list"
| TStruct s -> Format.fprintf fmt "catala_class_%a" format_struct_name s
| TLit TUnit -> Format.fprintf fmt "\"catala_unit\""
| TLit TMoney -> Format.fprintf fmt "\"ctala_money\""
| TLit TInt -> Format.fprintf fmt "\"catala_integer\""
| TLit TRat -> Format.fprintf fmt "\"catala_decimal\""
| TLit TDate -> Format.fprintf fmt "\"catala_date\""
| TLit TDuration -> Format.fprintf fmt "\"catala_duration\""
| TLit TBool -> Format.fprintf fmt "\"logical\""
| TTuple ts ->
Format.fprintf fmt "\"list\"@ # tuple(%a)@\n"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@;")
format_typ)
ts
| TStruct s -> Format.fprintf fmt "\"catala_class_%a\"" format_struct_name s
| TOption some_typ ->
(* We loose track of optional value as they're crammed into NULL *)
format_typ fmt some_typ
| TEnum e -> Format.fprintf fmt "catala_enum_%a" format_enum_name e
| TEnum e -> Format.fprintf fmt "\"catala_enum_%a\"" format_enum_name e
| TArrow (_t1, _t2) ->
Message.raise_internal_error "This type should not be printed out in R: %a"
Print.typ_debug typ
| TArray _t1 -> Format.fprintf fmt "vector"
| TArray t1 -> Format.fprintf fmt "\"list\" # array(%a)@\n" format_typ t1
| TAny ->
Message.raise_internal_error "This type should not be printed out in R: %a"
Print.typ_debug typ
@ -460,53 +465,56 @@ let format_ctx
(ctx : decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) =
let fields = StructField.Map.bindings struct_fields in
let non_func_fields =
List.filter
(fun (_, t) -> match Mark.remove t with TArrow _ -> false | _ -> true)
fields
in
let func_fields =
List.filter
(fun (_, t) -> match Mark.remove t with TArrow _ -> true | _ -> false)
fields
in
Format.fprintf fmt
"%a <- setClass(@\n \"%a\",@\n representation = list@[<hov 2>(%a)@]@\n)"
"@[<hov 2>catala_struct_%a <- setRefClass(@,\
\"catala_struct_%a\",@;\
fields = list@[<hov 2>(%a)@],@,\
methods = list@[<hov 2>(%a)@]@,\
)@]"
format_struct_name struct_name format_struct_name struct_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@;")
(fun fmt (struct_field, typ) ->
Format.fprintf fmt "%a = \"%a\"" format_struct_field_name
struct_field format_typ typ))
fields
Format.fprintf fmt "%a = %a" format_struct_field_name struct_field
format_typ typ))
non_func_fields
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@;")
(fun fmt (struct_field, typ) ->
Format.fprintf fmt
"%a = # %a@\nfunction () @[<hov 2>{@;stop(\"uninitialized\")@;}@]"
format_struct_field_name struct_field Print.typ_debug typ))
func_fields
in
let format_enum_decl fmt (enum_name, enum_cons) =
if EnumConstructor.Map.is_empty enum_cons then
failwith "no constructors in the enum"
else
Format.fprintf fmt
"@[<hov 4>class %a_Code(Enum):@\n\
%a@]@\n\
@\n\
class %a:@\n\
\ def __init__(self, code: %a_Code, value: Any) -> None:@\n\
\ self.code = code@\n\
\ self.value = value@\n\
@\n\
@\n\
\ def __eq__(self, other: object) -> bool:@\n\
\ if isinstance(other, %a):@\n\
\ return self.code == other.code and self.value == \
other.value@\n\
\ else:@\n\
\ return False@\n\
@\n\
@\n\
\ def __ne__(self, other: object) -> bool:@\n\
\ return not (self == other)@\n\
@\n\
\ def __str__(self) -> str:@\n\
\ @[<hov 4>return \"{}({})\".format(self.code, self.value)@]"
format_enum_name enum_name
"# Enum cases: %a@\n\
@[<hov 2>catala_enum_%a <- setRefClass(@,\
\"catala_enum_%a\",@;\
fields = list@[<hov 2>(code =@;\
\"character\",@;\
value =@;\
\"ANY\")@])@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (i, enum_cons, _enum_cons_type) ->
Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i))
(List.mapi
(fun i (x, y) -> i, x, y)
(EnumConstructor.Map.bindings enum_cons))
format_enum_name enum_name format_enum_name enum_name format_enum_name
enum_name
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (enum_cons, enum_cons_type) ->
Format.fprintf fmt "\"%a\" (%a)" format_enum_cons_name enum_cons
format_typ enum_cons_type))
(EnumConstructor.Map.bindings enum_cons)
format_enum_name enum_name format_enum_name enum_name
in
let is_in_type_ordering s =
@ -547,6 +555,7 @@ let format_program
"@[<v># This file has been generated by the Catala compiler, do not edit!@,\
@,\
source(\"runtimes/r/runtime.R\")@,\
@,\
@[<v>%a@]@,\
@,\
%a@]@?"