Switch to use Python enums as a tag for tagged unions

This commit is contained in:
Denis Merigoux 2021-06-24 21:55:20 +02:00
parent fbf60b89bf
commit 95b34937a6
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
2 changed files with 527 additions and 297 deletions

View File

@ -204,8 +204,8 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e
| EStructFieldAccess (e1, field, _) ->
Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field
| EInj (e, cons, enum_name) ->
Format.fprintf fmt "%a_%a(%a)" format_enum_name enum_name format_enum_cons_name cons
(format_expression ctx) e
Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name format_enum_name enum_name
format_enum_cons_name cons (format_expression ctx) e
| EArray es ->
Format.fprintf fmt "[%a]"
(Format.pp_print_list
@ -217,7 +217,7 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e
Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos) (format_expression ctx) arg1
(format_expression ctx) arg2
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "(%a %a %a)" (format_expression ctx) arg1 format_binop (op, Pos.no_pos)
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop (op, Pos.no_pos)
(format_expression ctx) arg2
| EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [ f ]), _), [ arg ])
when !Cli.trace_flag ->
@ -281,7 +281,7 @@ let rec format_statement (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (s
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 4>elif ")
(fun fmt (case_block, payload_var, cons_name) ->
Format.fprintf fmt "%a is %a_%a:@\n%a = %a.value@\n%a" format_var tmp_var
Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" format_var tmp_var
format_enum_name e_name format_enum_cons_name cons_name format_var payload_var
format_var tmp_var (format_block ctx) case_block))
cases
@ -318,21 +318,24 @@ let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Form
struct_fields
in
let format_enum_decl fmt (enum_name, enum_cons) =
if List.length enum_cons = 0 then
Format.fprintf fmt
"class %a(Unit):@\n\tdef __init__(self, value: Any) -> None:@\n\t\tself.value = value@\n@\n"
format_enum_name enum_name
if List.length enum_cons = 0 then failwith "no constructors in the enum"
else
Format.fprintf fmt
"class %a:@\n\tdef __init__(self, value: Any) -> None:@\n\t\tself.value = value@\n@\n%a"
format_enum_name enum_name
"@[<hov 4>class %a_Code(Enum):@\n\
%a@]@\n\
@\n\
class %a:@\n\
\tdef __init__(self, code: %a_Code, value: Any) -> None:@\n\
\t\tself.code = code@\n\
\t\tself.value = value" format_enum_name enum_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n")
(fun _fmt (enum_cons, enum_cons_type) ->
Format.fprintf fmt "class %a_%a(%a):@\n\tpass" format_enum_name enum_name
format_enum_cons_name enum_cons format_enum_name enum_name))
enum_cons
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (i, enum_cons, enum_cons_type) ->
Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i))
(List.mapi (fun i (x, y) -> (i, x, y)) enum_cons)
format_enum_name enum_name format_enum_name enum_name
in
let is_in_type_ordering s =
List.exists
(fun struct_or_enum ->
@ -366,6 +369,7 @@ let format_program (fmt : Format.formatter) (p : Ast.program)
@\n\
from .catala_runtime import *@\n\
from typing import Any, List, Callable, Tuple\n\
from enum import Enum\n\
@\n\
%a@\n\
@\n\

File diff suppressed because it is too large Load Diff