Switch to use Python enums as a tag for tagged unions

This commit is contained in:
Denis Merigoux 2021-06-24 21:55:20 +02:00
parent fbf60b89bf
commit 95b34937a6
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
2 changed files with 527 additions and 297 deletions

View File

@ -204,8 +204,8 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e
| EStructFieldAccess (e1, field, _) -> | EStructFieldAccess (e1, field, _) ->
Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field
| EInj (e, cons, enum_name) -> | EInj (e, cons, enum_name) ->
Format.fprintf fmt "%a_%a(%a)" format_enum_name enum_name format_enum_cons_name cons Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name format_enum_name enum_name
(format_expression ctx) e format_enum_cons_name cons (format_expression ctx) e
| EArray es -> | EArray es ->
Format.fprintf fmt "[%a]" Format.fprintf fmt "[%a]"
(Format.pp_print_list (Format.pp_print_list
@ -217,7 +217,7 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e
Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos) (format_expression ctx) arg1 Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos) (format_expression ctx) arg1
(format_expression ctx) arg2 (format_expression ctx) arg2
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "(%a %a %a)" (format_expression ctx) arg1 format_binop (op, Pos.no_pos) Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop (op, Pos.no_pos)
(format_expression ctx) arg2 (format_expression ctx) arg2
| EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [ f ]), _), [ arg ]) | EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [ f ]), _), [ arg ])
when !Cli.trace_flag -> when !Cli.trace_flag ->
@ -281,7 +281,7 @@ let rec format_statement (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (s
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 4>elif ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 4>elif ")
(fun fmt (case_block, payload_var, cons_name) -> (fun fmt (case_block, payload_var, cons_name) ->
Format.fprintf fmt "%a is %a_%a:@\n%a = %a.value@\n%a" format_var tmp_var Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" format_var tmp_var
format_enum_name e_name format_enum_cons_name cons_name format_var payload_var format_enum_name e_name format_enum_cons_name cons_name format_var payload_var
format_var tmp_var (format_block ctx) case_block)) format_var tmp_var (format_block ctx) case_block))
cases cases
@ -318,21 +318,24 @@ let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Form
struct_fields struct_fields
in in
let format_enum_decl fmt (enum_name, enum_cons) = let format_enum_decl fmt (enum_name, enum_cons) =
if List.length enum_cons = 0 then if List.length enum_cons = 0 then failwith "no constructors in the enum"
Format.fprintf fmt
"class %a(Unit):@\n\tdef __init__(self, value: Any) -> None:@\n\t\tself.value = value@\n@\n"
format_enum_name enum_name
else else
Format.fprintf fmt Format.fprintf fmt
"class %a:@\n\tdef __init__(self, value: Any) -> None:@\n\t\tself.value = value@\n@\n%a" "@[<hov 4>class %a_Code(Enum):@\n\
format_enum_name enum_name %a@]@\n\
@\n\
class %a:@\n\
\tdef __init__(self, code: %a_Code, value: Any) -> None:@\n\
\t\tself.code = code@\n\
\t\tself.value = value" format_enum_name enum_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (enum_cons, enum_cons_type) -> (fun _fmt (i, enum_cons, enum_cons_type) ->
Format.fprintf fmt "class %a_%a(%a):@\n\tpass" format_enum_name enum_name Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i))
format_enum_cons_name enum_cons format_enum_name enum_name)) (List.mapi (fun i (x, y) -> (i, x, y)) enum_cons)
enum_cons format_enum_name enum_name format_enum_name enum_name
in in
let is_in_type_ordering s = let is_in_type_ordering s =
List.exists List.exists
(fun struct_or_enum -> (fun struct_or_enum ->
@ -366,6 +369,7 @@ let format_program (fmt : Format.formatter) (p : Ast.program)
@\n\ @\n\
from .catala_runtime import *@\n\ from .catala_runtime import *@\n\
from typing import Any, List, Callable, Tuple\n\ from typing import Any, List, Callable, Tuple\n\
from enum import Enum\n\
@\n\ @\n\
%a@\n\ %a@\n\
@\n\ @\n\

File diff suppressed because it is too large Load Diff