Scalc change: switch only on variables

matches can bind, but switches cannot, so we can assume the switch argument
should always be bound to a name ; this allow the intermediate variable to be
better renamed.
This commit is contained in:
Louis Gesbert 2024-08-08 15:06:03 +02:00
parent 14a378a33d
commit e9abbf9bd8
8 changed files with 169 additions and 176 deletions

View File

@ -70,8 +70,8 @@ type stmt =
| SFatalError of Runtime.error
| SIfThenElse of { if_expr : expr; then_block : block; else_block : block }
| SSwitch of {
switch_expr : expr;
switch_expr_typ : typ;
switch_var : VarName.t;
switch_var_typ : typ;
enum_name : EnumName.t;
switch_cases : switch_case list;
}

View File

@ -421,7 +421,23 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R
binder_pos );
], ctxt.ren_ctx
| EMatch { e = e1; cases; name } ->
let typ = Expr.maybe_ty (Mark.get e1) in
let e1_stmts, new_e1, ren_ctx = translate_expr ctxt e1 in
let ctxt = { ctxt with ren_ctx } in
let e1_stmts, switch_var, ctxt =
match new_e1 with
| A.EVar v, _ -> e1_stmts, v, ctxt
| _ ->
let v, ctxt = fresh_var ctxt ctxt.context_name ~pos:(Expr.pos e1) in
RevBlock.append e1_stmts
( A.SLocalInit
{ name = v, Expr.pos e1;
expr = new_e1;
typ },
Expr.pos e1 ),
v,
ctxt
in
let new_cases =
EnumConstructor.Map.fold
(fun _ arg new_args ->
@ -443,20 +459,19 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block * R
| _ -> assert false)
cases []
in
let new_args = List.rev new_cases in
RevBlock.rebuild e1_stmts
~tail:
[
( A.SSwitch
{
switch_expr = new_e1;
switch_expr_typ = Expr.maybe_ty (Mark.get e1);
switch_var;
switch_var_typ = typ;
enum_name = name;
switch_cases = new_args;
switch_cases = List.rev new_cases;
},
Expr.pos block_expr );
],
ren_ctx
ctxt.ren_ctx
| EIfThenElse { cond; etrue; efalse } ->
let cond_stmts, s_cond, ren_ctx = translate_expr ctxt cond in
let s_e_true, _ = translate_statements ctxt etrue in

View File

@ -169,12 +169,12 @@ let rec format_statement
Format.fprintf fmt "@[<hov 2>%a %a@]" Print.keyword "assert"
(format_expr decl_ctx ~debug)
(naked_expr, Mark.get stmt)
| SSwitch { switch_expr = e_switch; enum_name = enum; switch_cases = arms; _ }
| SSwitch { switch_var = v_switch; enum_name = enum; switch_cases = arms; _ }
->
let cons = EnumName.Map.find enum decl_ctx.ctx_enums in
Format.fprintf fmt "@[<v 0>%a @[<hov 2>%a@]%a@,@]%a" Print.keyword "switch"
(format_expr decl_ctx ~debug)
e_switch Print.punctuation ":"
format_var_name v_switch
Print.punctuation ":"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt ((case, _), switch_case_data) ->

View File

@ -389,18 +389,15 @@ let rec format_statement
Format.fprintf fmt
"@[<hv 2>@[<hov 2>if (%a) {@]@,%a@,@;<1 -2>} else {@,%a@,@;<1 -2>}@]"
(format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
| SSwitch { switch_var; enum_name = e_name; switch_cases = cases; _ } ->
let cases =
List.map2
(fun x (cons, _) -> x, cons)
cases
(EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums))
in
let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in
Format.fprintf fmt "@[<hov 2>%a %a = %a;@]@," EnumName.format e_name
VarName.format tmp_var (format_expression ctx) e1;
Format.pp_open_vbox fmt 2;
Format.fprintf fmt "@[<hov 4>switch (%a.code) {@]@," VarName.format tmp_var;
Format.fprintf fmt "@[<hov 4>switch (%a.code) {@]@," VarName.format switch_var;
Format.pp_print_list
(fun fmt ({ case_block; payload_var_name; payload_var_typ }, cons_name) ->
Format.fprintf fmt "@[<hv 2>case %a_%a:@ " EnumName.format e_name
@ -408,7 +405,7 @@ let rec format_statement
if not (Type.equal payload_var_typ (TLit TUnit, Pos.no_pos)) then
Format.fprintf fmt "%a = %a.payload.%a;@ "
(format_typ ctx (fun fmt -> VarName.format fmt payload_var_name))
payload_var_typ VarName.format tmp_var EnumConstructor.format
payload_var_typ VarName.format switch_var EnumConstructor.format
cons_name;
Format.fprintf fmt "%a@ break;@]" (format_block ctx) case_block)
fmt cases;

View File

@ -370,7 +370,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
(format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2
| SSwitch
{
switch_expr = e1;
switch_var;
enum_name = e_name;
switch_cases =
[
@ -381,14 +381,11 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
}
when EnumName.equal e_name Expr.option_enum ->
(* We translate the option type with an overloading by Python's [None] *)
let tmp_var = VarName.fresh ("perhaps_none_arg", Pos.no_pos) in
Format.fprintf fmt "@[<hv 4>%a = %a@]@," VarName.format tmp_var
(format_expression ctx) e1;
Format.fprintf fmt "@[<v 4>if %a is None:@ %a@]@," VarName.format tmp_var
Format.fprintf fmt "@[<v 4>if %a is None:@ %a@]@," VarName.format switch_var
(format_block ctx) case_none;
Format.fprintf fmt "@[<v 4>else:@ %a = %a@,%a@]" VarName.format
case_some_var VarName.format tmp_var (format_block ctx) case_some
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
case_some_var VarName.format switch_var (format_block ctx) case_some
| SSwitch { switch_var; enum_name = e_name; switch_cases = cases; _ } ->
let cons_map = EnumName.Map.find e_name ctx.decl_ctx.ctx_enums in
let cases =
List.map2
@ -396,16 +393,14 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
cases
(EnumConstructor.Map.bindings cons_map)
in
let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in
Format.fprintf fmt "%a = %a@\n@[<hov 4>if %a@]" VarName.format tmp_var
(format_expression ctx) e1
Format.fprintf fmt "@[<hov 4>if %a@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 4>elif ")
(fun fmt ({ case_block; payload_var_name; _ }, cons_name) ->
Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a"
VarName.format tmp_var EnumName.format e_name
VarName.format switch_var EnumName.format e_name
EnumConstructor.format cons_name VarName.format payload_var_name
VarName.format tmp_var (format_block ctx) case_block))
VarName.format switch_var (format_block ctx) case_block))
cases
| SReturn e1 ->
Format.fprintf fmt "@[<hov 4>return %a@]" (format_expression ctx)

View File

@ -110,8 +110,8 @@ Baz baz(Baz_in baz_in) {
Array_1 a4;
a4.content_field = catala_malloc(sizeof(Array_1));
a4.content_field[0] = code(env, NULL);
Option_1 match_arg = catala_handle_exceptions(a4);
switch (match_arg.code) {
Option_1 a3 = catala_handle_exceptions(a4);
switch (a3.code) {
case Option_1_None_1:
if (1 /* TRUE */) {
Bar a3;
@ -120,8 +120,8 @@ Baz baz(Baz_in baz_in) {
Array_1 a7;
a7.content_field = catala_malloc(sizeof(Array_1));
Option_1 match_arg = catala_handle_exceptions(a7);
switch (match_arg.code) {
Option_1 a8 = catala_handle_exceptions(a7);
switch (a8.code) {
case Option_1_None_1:
if (1 /* TRUE */) {
Bar a6 = {Bar_No, {No: NULL}};
@ -133,15 +133,15 @@ Baz baz(Baz_in baz_in) {
}
break;
case Option_1_Some_1:
Bar x1 = match_arg.payload.Some_1;
Bar x1 = a8.payload.Some_1;
option_1 a6 = {Option_1_Some_1, {Some_1: x1}};
break;
}
Array_1 a5;
a5.content_field = catala_malloc(sizeof(Array_1));
a5.content_field[0] = a6;
Option_1 match_arg = catala_handle_exceptions(a5);
switch (match_arg.code) {
Option_1 a9 = catala_handle_exceptions(a5);
switch (a9.code) {
case Option_1_None_1:
if (0 /* FALSE */) {
option_1 a4 = {Option_1_None_1, {None_1: NULL}};
@ -152,20 +152,16 @@ Baz baz(Baz_in baz_in) {
}
break;
case Option_1_Some_1:
Bar x1 = match_arg.payload.Some_1;
Bar x1 = a9.payload.Some_1;
option_1 a4 = {Option_1_Some_1, {Some_1: x1}};
break;
}
Option_1 match_arg = a4;
switch (match_arg.code) {
switch (a4.code) {
case Option_1_None_1:
catala_raise_fatal_error (catala_no_value,
"tests/backends/simple.catala_en", 11, 11, 11, 12);
break;
case Option_1_Some_1:
Bar arg = match_arg.payload.Some_1;
a3 = arg;
break;
case Option_1_Some_1: Bar arg = a4.payload.Some_1; a3 = arg; break;
}
option_1 a3 = {Option_1_Some_1, {Some_1: a3}};
@ -175,20 +171,16 @@ Baz baz(Baz_in baz_in) {
}
break;
case Option_1_Some_1:
Bar x1 = match_arg.payload.Some_1;
Bar x1 = a3.payload.Some_1;
option_1 a3 = {Option_1_Some_1, {Some_1: x1}};
break;
}
Option_1 match_arg = a3;
switch (match_arg.code) {
switch (a3.code) {
case Option_1_None_1:
catala_raise_fatal_error (catala_no_value,
"tests/backends/simple.catala_en", 11, 11, 11, 12);
break;
case Option_1_Some_1:
Bar arg = match_arg.payload.Some_1;
a2 = arg;
break;
case Option_1_Some_1: Bar arg = a3.payload.Some_1; a2 = arg; break;
}
Bar a1;
a1 = a2;

View File

@ -91,82 +91,80 @@ class BIn:
def some_name(some_name_in:SomeNameIn):
i = some_name_in.i_in
perhaps_none_arg = handle_exceptions([], [])
if perhaps_none_arg is None:
o4 = handle_exceptions([], [])
if o4 is None:
if True:
o3 = (i + integer_of_string("1"))
else:
o3 = None
else:
x = perhaps_none_arg
x = o4
o3 = x
perhaps_none_arg = handle_exceptions(
o5 = handle_exceptions(
[SourcePosition(
filename="tests/backends/python_name_clash.catala_en",
start_line=10, start_column=23,
end_line=10, end_column=28, law_headings=[])],
start_line=10, start_column=23, end_line=10, end_column=28,
law_headings=[])],
[o3]
)
if perhaps_none_arg is None:
if o5 is None:
if False:
o2 = None
else:
o2 = None
else:
x = perhaps_none_arg
x = o5
o2 = x
perhaps_none_arg = o2
if perhaps_none_arg is None:
if o2 is None:
raise NoValue(SourcePosition(
filename="tests/backends/python_name_clash.catala_en",
start_line=7, start_column=10,
end_line=7, end_column=11, law_headings=[]))
else:
arg = perhaps_none_arg
arg = o2
o1 = arg
o = o1
return SomeName(o = o)
def b(b_in:BIn):
perhaps_none_arg = handle_exceptions([], [])
if perhaps_none_arg is None:
result4 = handle_exceptions([], [])
if result4 is None:
if True:
result3 = integer_of_string("1")
else:
result3 = None
else:
x = perhaps_none_arg
x = result4
result3 = x
perhaps_none_arg = handle_exceptions(
result5 = handle_exceptions(
[SourcePosition(
filename="tests/backends/python_name_clash.catala_en",
start_line=16, start_column=33,
end_line=16, end_column=34, law_headings=[])],
[result3]
)
if perhaps_none_arg is None:
if result5 is None:
if False:
result2 = None
else:
result2 = None
else:
x = perhaps_none_arg
x = result5
result2 = x
perhaps_none_arg = result2
if perhaps_none_arg is None:
if result2 is None:
raise NoValue(SourcePosition(
filename="tests/backends/python_name_clash.catala_en",
start_line=16, start_column=14,
end_line=16, end_column=25, law_headings=[]))
else:
arg = perhaps_none_arg
arg = result2
result1 = arg
result = some_name(SomeNameIn(i_in = result1))
result4 = SomeName(o = result.o)
result6 = SomeName(o = result.o)
if True:
some_name2 = result4
some_name2 = result6
else:
some_name2 = result4
some_name2 = result6
some_name1 = some_name2
return B(some_name = some_name1)
```

View File

@ -128,7 +128,8 @@ let S2 (S2_in: S2_in) =
decl a1 : decimal;
decl a2 : option decimal;
decl a3 : option decimal;
switch handle_exceptions []:
a4 : option decimal = handle_exceptions [];
switch a4:
| ENone _ →
if true:
a3 = ESome glob3 ¤44.00 + 100.
@ -136,7 +137,8 @@ let S2 (S2_in: S2_in) =
a3 = ENone ()
| ESome x →
a3 = ESome x;
switch handle_exceptions [a3]:
a5 : option decimal = handle_exceptions [a3];
switch a5:
| ENone _ →
if false:
a2 = ENone ()
@ -157,7 +159,8 @@ let S3 (S3_in: S3_in) =
decl a1 : decimal;
decl a2 : option decimal;
decl a3 : option decimal;
switch handle_exceptions []:
a4 : option decimal = handle_exceptions [];
switch a4:
| ENone _ →
if true:
a3 = ESome 50. + glob4 ¤44.00 55.
@ -165,7 +168,8 @@ let S3 (S3_in: S3_in) =
a3 = ENone ()
| ESome x →
a3 = ESome x;
switch handle_exceptions [a3]:
a5 : option decimal = handle_exceptions [a3];
switch a5:
| ENone _ →
if false:
a2 = ENone ()
@ -186,7 +190,8 @@ let S4 (S4_in: S4_in) =
decl a1 : decimal;
decl a2 : option decimal;
decl a3 : option decimal;
switch handle_exceptions []:
a4 : option decimal = handle_exceptions [];
switch a4:
| ENone _ →
if true:
a3 = ESome glob5 + 1.
@ -194,7 +199,8 @@ let S4 (S4_in: S4_in) =
a3 = ENone ()
| ESome x →
a3 = ESome x;
switch handle_exceptions [a3]:
a5 : option decimal = handle_exceptions [a3];
switch a5:
| ENone _ →
if false:
a2 = ENone ()
@ -215,7 +221,8 @@ let S (S_in: S_in) =
decl a1 : decimal;
decl a2 : option decimal;
decl a3 : option decimal;
switch handle_exceptions []:
a4 : option decimal = handle_exceptions [];
switch a4:
| ENone _ →
if true:
a3 = ESome glob1 * glob1
@ -223,7 +230,8 @@ let S (S_in: S_in) =
a3 = ENone ()
| ESome x →
a3 = ESome x;
switch handle_exceptions [a3]:
a5 : option decimal = handle_exceptions [a3];
switch a5:
| ENone _ →
if false:
a2 = ENone ()
@ -241,7 +249,8 @@ let S (S_in: S_in) =
decl b1 : A {y: bool; z: decimal};
decl b2 : option A {y: bool; z: decimal};
decl b3 : option A {y: bool; z: decimal};
switch handle_exceptions []:
b4 : option A {y: bool; z: decimal} = handle_exceptions [];
switch b4:
| ENone _ →
if true:
b3 = ESome glob2
@ -249,7 +258,8 @@ let S (S_in: S_in) =
b3 = ENone ()
| ESome x →
b3 = ESome x;
switch handle_exceptions [b3]:
b5 : option A {y: bool; z: decimal} = handle_exceptions [b3];
switch b5:
| ENone _ →
if false:
b2 = ENone ()
@ -446,48 +456,46 @@ glob6 = (
)
def s2(s2_in:S2In):
perhaps_none_arg = handle_exceptions([], [])
if perhaps_none_arg is None:
a4 = handle_exceptions([], [])
if a4 is None:
if True:
a3 = (glob3(money_of_cents_string("4400")) +
decimal_of_string("100."))
else:
a3 = None
else:
x = perhaps_none_arg
x = a4
a3 = x
perhaps_none_arg = handle_exceptions(
a5 = handle_exceptions(
[SourcePosition(
filename="tests/name_resolution/good/toplevel_defs.catala_en",
start_line=53, start_column=24,
end_line=53, end_column=43,
start_line=53, start_column=24, end_line=53, end_column=43,
law_headings=["Test toplevel function defs"])],
[a3]
)
if perhaps_none_arg is None:
if a5 is None:
if False:
a2 = None
else:
a2 = None
else:
x = perhaps_none_arg
x = a5
a2 = x
perhaps_none_arg = a2
if perhaps_none_arg is None:
if a2 is None:
raise NoValue(SourcePosition(
filename="tests/name_resolution/good/toplevel_defs.catala_en",
start_line=50, start_column=10,
end_line=50, end_column=11,
law_headings=["Test toplevel function defs"]))
else:
arg = perhaps_none_arg
arg = a2
a1 = arg
a = a1
return S2(a = a)
def s3(s3_in:S3In):
perhaps_none_arg = handle_exceptions([], [])
if perhaps_none_arg is None:
a4 = handle_exceptions([], [])
if a4 is None:
if True:
a3 = (decimal_of_string("50.") +
glob4(money_of_cents_string("4400"),
@ -495,151 +503,139 @@ def s3(s3_in:S3In):
else:
a3 = None
else:
x = perhaps_none_arg
x = a4
a3 = x
perhaps_none_arg = handle_exceptions(
a5 = handle_exceptions(
[SourcePosition(
filename="tests/name_resolution/good/toplevel_defs.catala_en",
start_line=74, start_column=24,
end_line=74, end_column=47,
law_headings=["Test function def with two args"]
)],
start_line=74, start_column=24, end_line=74, end_column=47,
law_headings=["Test function def with two args"])],
[a3]
)
if perhaps_none_arg is None:
if a5 is None:
if False:
a2 = None
else:
a2 = None
else:
x = perhaps_none_arg
x = a5
a2 = x
perhaps_none_arg = a2
if perhaps_none_arg is None:
if a2 is None:
raise NoValue(SourcePosition(
filename="tests/name_resolution/good/toplevel_defs.catala_en",
start_line=71, start_column=10,
end_line=71, end_column=11,
law_headings=["Test function def with two args"]))
else:
arg = perhaps_none_arg
arg = a2
a1 = arg
a = a1
return S3(a = a)
def s4(s4_in:S4In):
perhaps_none_arg = handle_exceptions([], [])
if perhaps_none_arg is None:
a4 = handle_exceptions([], [])
if a4 is None:
if True:
a3 = (glob5 + decimal_of_string("1."))
else:
a3 = None
else:
x = perhaps_none_arg
x = a4
a3 = x
perhaps_none_arg = handle_exceptions(
a5 = handle_exceptions(
[SourcePosition(
filename="tests/name_resolution/good/toplevel_defs.catala_en",
start_line=98, start_column=24,
end_line=98, end_column=34,
law_headings=["Test inline defs in toplevel defs"]
)],
start_line=98, start_column=24, end_line=98, end_column=34,
law_headings=["Test inline defs in toplevel defs"])],
[a3]
)
if perhaps_none_arg is None:
if a5 is None:
if False:
a2 = None
else:
a2 = None
else:
x = perhaps_none_arg
x = a5
a2 = x
perhaps_none_arg = a2
if perhaps_none_arg is None:
if a2 is None:
raise NoValue(SourcePosition(
filename="tests/name_resolution/good/toplevel_defs.catala_en",
start_line=95, start_column=10,
end_line=95, end_column=11,
law_headings=["Test inline defs in toplevel defs"]))
else:
arg = perhaps_none_arg
arg = a2
a1 = arg
a = a1
return S4(a = a)
def s5(s_in:SIn):
perhaps_none_arg = handle_exceptions([], [])
if perhaps_none_arg is None:
a4 = handle_exceptions([], [])
if a4 is None:
if True:
a3 = (glob1 * glob1)
else:
a3 = None
else:
x = perhaps_none_arg
x = a4
a3 = x
perhaps_none_arg = handle_exceptions(
a5 = handle_exceptions(
[SourcePosition(
filename="tests/name_resolution/good/toplevel_defs.catala_en",
start_line=18, start_column=24,
end_line=18, end_column=37,
law_headings=["Test basic toplevel values defs"]
)],
start_line=18, start_column=24, end_line=18, end_column=37,
law_headings=["Test basic toplevel values defs"])],
[a3]
)
if perhaps_none_arg is None:
if a5 is None:
if False:
a2 = None
else:
a2 = None
else:
x = perhaps_none_arg
x = a5
a2 = x
perhaps_none_arg = a2
if perhaps_none_arg is None:
if a2 is None:
raise NoValue(SourcePosition(
filename="tests/name_resolution/good/toplevel_defs.catala_en",
start_line=7, start_column=10,
end_line=7, end_column=11,
law_headings=["Test basic toplevel values defs"]))
else:
arg = perhaps_none_arg
arg = a2
a1 = arg
a = a1
perhaps_none_arg = handle_exceptions([], [])
if perhaps_none_arg is None:
b4 = handle_exceptions([], [])
if b4 is None:
if True:
b3 = glob6
else:
b3 = None
else:
x = perhaps_none_arg
x = b4
b3 = x
perhaps_none_arg = handle_exceptions(
b5 = handle_exceptions(
[SourcePosition(
filename="tests/name_resolution/good/toplevel_defs.catala_en",
start_line=19, start_column=24,
end_line=19, end_column=29,
law_headings=["Test basic toplevel values defs"]
)],
start_line=19, start_column=24, end_line=19, end_column=29,
law_headings=["Test basic toplevel values defs"])],
[b3]
)
if perhaps_none_arg is None:
if b5 is None:
if False:
b2 = None
else:
b2 = None
else:
x = perhaps_none_arg
x = b5
b2 = x
perhaps_none_arg = b2
if perhaps_none_arg is None:
if b2 is None:
raise NoValue(SourcePosition(
filename="tests/name_resolution/good/toplevel_defs.catala_en",
start_line=8, start_column=10,
end_line=8, end_column=11,
law_headings=["Test basic toplevel values defs"]))
else:
arg = perhaps_none_arg
arg = b2
b1 = arg
b = b1
return S(a = a, b = b)