Almost there with structs

This commit is contained in:
Denis Merigoux 2023-12-11 17:08:32 +01:00
parent aca1d0e712
commit 37ab4187bd
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
12 changed files with 146 additions and 47 deletions

View File

@ -398,11 +398,30 @@ module Flags = struct
let keep_special_ops =
value
& flag
& info ["keep_special_ops"]
& info ["keep-special-ops"]
~doc:
"During the Lcalc->Scalc translation, uses special AST nodes for \
higher-order operators rather than nested closures (useful for C)."
let dead_value_assignment =
value
& flag
& info ["dead-value-assignment"]
~doc:
"During the Lcalc->Scalc translation, insert dummy variable \
assignments before raising terminal exception to please gradual \
typing tools that check exhaustivity of variable definitions in \
every code branch."
let no_struct_literals =
value
& flag
& info ["no-struct-literals"]
~doc:
"During the Lcalc->Scalc translation, insert temporary variable \
assignments to hold the result of structure initializations \
(matches the absence of struct literals of C89)."
let closure_conversion =
value
& flag

View File

@ -130,6 +130,8 @@ module Flags : sig
val avoid_exceptions : bool Term.t
val closure_conversion : bool Term.t
val keep_special_ops : bool Term.t
val dead_value_assignment : bool Term.t
val no_struct_literals : bool Term.t
val include_dirs : raw_file list Term.t
val disable_counterexamples : bool Term.t
end

View File

@ -285,7 +285,9 @@ module Passes = struct
~check_invariants
~avoid_exceptions
~closure_conversion
~keep_special_ops :
~keep_special_ops
~dead_value_assignment
~no_struct_literals :
Scalc.Ast.program * Scopelang.Dependency.TVertex.t list =
let prg, type_ordering =
lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.untyped
@ -294,7 +296,10 @@ module Passes = struct
Message.emit_debug "Retyping lambda calculus...";
let prg = Typing.program ~leave_unresolved:true prg in
debug_pass_name "scalc";
Scalc.From_lcalc.translate_program ~keep_special_ops prg, type_ordering
( Scalc.From_lcalc.translate_program
~config:{ keep_special_ops; dead_value_assignment; no_struct_literals }
prg,
type_ordering )
end
module Commands = struct
@ -839,10 +844,13 @@ module Commands = struct
avoid_exceptions
closure_conversion
keep_special_ops
dead_value_assignment
no_struct_literals
ex_scope_opt =
let prg, _ =
Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~keep_special_ops
~dead_value_assignment ~no_struct_literals
in
let _output_file, with_output = get_output_format options output in
with_output
@ -877,6 +885,8 @@ module Commands = struct
$ Cli.Flags.avoid_exceptions
$ Cli.Flags.closure_conversion
$ Cli.Flags.keep_special_ops
$ Cli.Flags.dead_value_assignment
$ Cli.Flags.no_struct_literals
$ Cli.Flags.ex_scope_opt)
let python
@ -890,6 +900,7 @@ module Commands = struct
let prg, type_ordering =
Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~keep_special_ops:false
~dead_value_assignment:true ~no_struct_literals:false
in
let output_file, with_output =
@ -919,6 +930,7 @@ module Commands = struct
let prg, type_ordering =
Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions:false ~closure_conversion ~keep_special_ops:false
~dead_value_assignment:false ~no_struct_literals:false
in
let output_file, with_output = get_output_format options ~ext:".r" output in
@ -943,6 +955,7 @@ module Commands = struct
let prg, type_ordering =
Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions:true ~closure_conversion:true ~keep_special_ops:true
~dead_value_assignment:false ~no_struct_literals:true
in
let output_file, with_output = get_output_format options ~ext:".c" output in
Message.emit_debug "Compiling program into C...";

View File

@ -63,6 +63,8 @@ module Passes : sig
avoid_exceptions:bool ->
closure_conversion:bool ->
keep_special_ops:bool ->
dead_value_assignment:bool ->
no_struct_literals:bool ->
Scalc.Ast.program * Scopelang.Dependency.TVertex.t list
end

View File

@ -34,6 +34,7 @@ let run
let prg, type_ordering =
Driver.Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~keep_special_ops:false
~dead_value_assignment:true ~no_struct_literals:false
in
let output_file, with_output = get_output_format options ~ext:".py" output in

View File

@ -69,6 +69,7 @@ type stmt =
| SIfThenElse of { if_expr : expr; then_block : block; else_block : block }
| SSwitch of {
switch_expr : expr;
switch_expr_typ : typ;
enum_name : EnumName.t;
switch_cases : switch_case list;
}

View File

@ -20,13 +20,19 @@ module A = Ast
module L = Lcalc.Ast
module D = Dcalc.Ast
type translation_config = {
keep_special_ops : bool;
dead_value_assignment : bool;
no_struct_literals : bool;
}
type 'm ctxt = {
func_dict : ('m L.expr, A.FuncName.t) Var.Map.t;
decl_ctx : decl_ctx;
var_dict : ('m L.expr, A.VarName.t) Var.Map.t;
inside_definition_of : A.VarName.t option;
context_name : string;
keep_special_ops : bool;
config : translation_config;
}
let unthunk e =
@ -55,7 +61,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
(Var.Map.keys ctxt.var_dict))
in
[], (local_var, Expr.pos expr)
| EStruct { fields; name } ->
| EStruct { fields; name } when not ctxt.config.no_struct_literals ->
let args_stmts, new_args =
StructField.Map.fold
(fun _ arg (args_stmts, new_args) ->
@ -91,7 +97,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
f = EOp { op = Op.HandleDefaultOpt; tys = _ }, _binder_mark;
args = [_exceptions; _just; _cons];
}
when ctxt.keep_special_ops ->
when ctxt.config.keep_special_ops ->
(* This should be translated as a statement *)
raise Not_found
| EApp { f; args } ->
@ -158,7 +164,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
f = EOp { op = Op.HandleDefaultOpt; tys = _ }, _binder_mark;
args = [exceptions; just; cons];
}
when ctxt.keep_special_ops ->
when ctxt.config.keep_special_ops ->
let exceptions =
match Mark.remove exceptions with
| EArray exceptions -> exceptions
@ -304,7 +310,12 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
e1_stmts
@ [
( A.SSwitch
{ switch_expr = new_e1; enum_name = name; switch_cases = new_args },
{
switch_expr = new_e1;
switch_expr_typ = Expr.maybe_ty (Mark.get e1);
enum_name = name;
switch_cases = new_args;
},
Expr.pos block_expr );
]
| EIfThenElse { cond; etrue; efalse } ->
@ -329,8 +340,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
(* Before raising the exception, we still give a dummy definition to the
current variable so that tools like mypy don't complain. *)
(match ctxt.inside_definition_of with
| None -> []
| Some x ->
| Some x when ctxt.config.dead_value_assignment ->
[
( A.SLocalDef
{
@ -338,8 +348,34 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
expr = Ast.EVar Ast.dead_value, Expr.pos block_expr;
},
Expr.pos block_expr );
])
]
| _ -> [])
@ [A.SRaise except, Expr.pos block_expr]
| EStruct { fields; name } when ctxt.config.no_struct_literals ->
let args_stmts, new_args =
StructField.Map.fold
(fun _ arg (args_stmts, new_args) ->
let arg_stmts, new_arg = translate_expr ctxt arg in
arg_stmts @ args_stmts, new_arg :: new_args)
fields ([], [])
in
let new_args = List.rev new_args in
let args_stmts = List.rev args_stmts in
let struct_expr =
A.EStruct { fields = new_args; name }, Expr.pos block_expr
in
let tmp_struct_var_name =
match ctxt.inside_definition_of with
| None ->
failwith "should not happen"
(* [translate_expr] should create this [inside_definition_of]*)
| Some x -> x, Expr.pos block_expr
in
args_stmts
@ [
( A.SLocalDef { name = tmp_struct_var_name; expr = struct_expr },
Expr.pos block_expr );
]
| _ -> (
let e_stmts, new_e = translate_expr ctxt block_expr in
e_stmts
@ -359,7 +395,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
])
let rec translate_scope_body_expr
~(keep_special_ops : bool)
~(config : translation_config)
(scope_name : ScopeName.t)
(decl_ctx : decl_ctx)
(var_dict : ('m L.expr, A.VarName.t) Var.Map.t)
@ -375,7 +411,7 @@ let rec translate_scope_body_expr
var_dict;
inside_definition_of = None;
context_name = Mark.remove (ScopeName.get_info scope_name);
keep_special_ops;
config;
}
e
in
@ -395,7 +431,7 @@ let rec translate_scope_body_expr
var_dict;
inside_definition_of = Some let_var_id;
context_name = Mark.remove (ScopeName.get_info scope_name);
keep_special_ops;
config;
}
scope_let.scope_let_expr
| _ ->
@ -407,7 +443,7 @@ let rec translate_scope_body_expr
var_dict;
inside_definition_of = Some let_var_id;
context_name = Mark.remove (ScopeName.get_info scope_name);
keep_special_ops;
config;
}
scope_let.scope_let_expr
in
@ -426,11 +462,11 @@ let rec translate_scope_body_expr
},
scope_let.scope_let_pos );
])
@ translate_scope_body_expr ~keep_special_ops scope_name decl_ctx
new_var_dict func_dict scope_let_next
@ translate_scope_body_expr ~config scope_name decl_ctx new_var_dict
func_dict scope_let_next
let translate_program ~(keep_special_ops : bool) (p : 'm L.program) : A.program
=
let translate_program ~(config : translation_config) (p : 'm L.program) :
A.program =
let _, _, rev_items =
Scope.fold_left
~f:(fun (func_dict, var_dict, rev_items) code_item var ->
@ -447,8 +483,8 @@ let translate_program ~(keep_special_ops : bool) (p : 'm L.program) : A.program
Var.Map.add scope_input_var scope_input_var_id var_dict
in
let new_scope_body =
translate_scope_body_expr ~keep_special_ops name p.decl_ctx
var_dict_local func_dict scope_body_expr
translate_scope_body_expr ~config name p.decl_ctx var_dict_local
func_dict scope_body_expr
in
let func_id = A.FuncName.fresh (Bindlib.name_of var, Pos.no_pos) in
( Var.Map.add var func_id func_dict,
@ -493,7 +529,7 @@ let translate_program ~(keep_special_ops : bool) (p : 'm L.program) : A.program
var_dict args args_id;
inside_definition_of = None;
context_name = Mark.remove (TopdefName.get_info name);
keep_special_ops;
config;
}
in
translate_expr ctxt expr
@ -529,7 +565,7 @@ let translate_program ~(keep_special_ops : bool) (p : 'm L.program) : A.program
var_dict;
inside_definition_of = None;
context_name = Mark.remove (TopdefName.get_info name);
keep_special_ops;
config;
}
in
translate_expr ctxt expr

View File

@ -16,9 +16,23 @@
open Shared_ast
(* When [keep_special_ops] is true, then this translation uses special Scalc AST
nodes for higher-order operators like map, fold, handle_default, etc. This is
useful if the target language after Scalc does not support nested functions
like C. *)
type translation_config = {
keep_special_ops : bool;
(** When [keep_special_ops] is true, then this translation uses special
Scalc AST nodes for higher-order operators like map, fold,
handle_default, etc. This is useful if the target language after Scalc
does not support nested functions like C. *)
dead_value_assignment : bool;
(** When [dead_value_assignment] is true, the translation inserts dummy
assignments of the variable being defined in the current code branch
just before raising a terminal error. This is useful for languages
like Python and their linting tools like mypy. The assignment uses the
polymorphic [Ast.dead_value]. *)
no_struct_literals : bool;
(** When [no_struct_literals] is true, the translation inserts a temporary
variable to hold the initialization of struct literals. This matches
what C89 expects. *)
}
val translate_program :
keep_special_ops:bool -> typed Lcalc.Ast.program -> Ast.program
config:translation_config -> typed Lcalc.Ast.program -> Ast.program

View File

@ -151,7 +151,8 @@ let rec format_statement
Format.fprintf fmt "@[<hov 2>%a %a@]" Print.keyword "assert"
(format_expr decl_ctx ~debug)
(naked_expr, Mark.get stmt)
| SSwitch { switch_expr = e_switch; enum_name = enum; switch_cases = arms } ->
| SSwitch { switch_expr = e_switch; enum_name = enum; switch_cases = arms; _ }
->
let cons = EnumName.Map.find enum decl_ctx.ctx_enums in
Format.fprintf fmt "@[<v 0>%a @[<hov 2>%a@]%a@,@]%a" Print.keyword "switch"
(format_expr decl_ctx ~debug)

View File

@ -154,7 +154,7 @@ let rec format_typ
(* We translate the option type with an overloading to C's [NULL] *)
Format.fprintf fmt
"@[<v 2>struct {@ char some_tag;@ @[<v 2>union {@ void *none;@ %a;@]@,\
} some_value;@]@,\
} payload;@]@,\
} /* option %a */ %t"
(format_typ decl_ctx (fun fmt -> Format.fprintf fmt "some"))
some_typ (Print.typ decl_ctx) some_typ element_name
@ -344,17 +344,16 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
unit =
match Mark.remove e with
| EVar v when v = Ast.dead_value -> Format.fprintf fmt "NULL"
| EVar v -> format_var fmt v
| EFunc f -> format_func_name fmt f
| EStruct { fields = es; name = s } ->
Format.fprintf fmt "new(\"catala_struct_%a\",@ %a)" format_struct_name s
| EStruct { fields = es; _ } ->
(* These should only appear when initializing a variable definition *)
Format.fprintf fmt "{ %a }"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, (struct_field, _)) ->
Format.fprintf fmt "%a = %a" format_struct_field_name struct_field
(format_expression ctx) e))
(List.combine es
(StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs)))
(fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e))
es
| EStructFieldAccess { e1; field; _ } ->
Format.fprintf fmt "%a.%a" (format_expression ctx) e1
format_struct_field_name field
@ -467,22 +466,28 @@ let rec format_statement
{ case_block = case_none; _ };
{ case_block = case_some; payload_var_name = case_some_var };
];
switch_expr_typ;
}
when EnumName.equal e_name Expr.option_enum ->
(* We translate the option type with an overloading by Python's [None] *)
let tmp_var = VarName.fresh ("perhaps_none_arg", Pos.no_pos) in
Format.fprintf fmt
"%a <- %a@\n\
@[<hov 2>if (is.null(%a)) {@\n\
"%a = %a;@\n\
@[<hov 2>if (%a.some_tag != 0) {@\n\
%a@]@\n\
@[<hov 2>} else {@\n\
%a = %a@\n\
%a = %a.payload.some;@\n\
%a@]@\n\
}"
format_var tmp_var (format_expression ctx) e1 format_var tmp_var
(format_block ctx) case_none format_var case_some_var format_var tmp_var
(format_block ctx) case_some
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases } ->
(format_typ ctx (fun fmt -> format_var fmt tmp_var))
switch_expr_typ (format_expression ctx) e1 format_var tmp_var
(format_block ctx) case_none
(format_typ ctx (fun fmt -> format_var fmt case_some_var))
(match Mark.remove switch_expr_typ with
| TOption tau -> tau
| _ -> failwith "should not happen")
format_var tmp_var (format_block ctx) case_some
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
let cases =
List.map2
(fun x (cons, _) -> x, cons)
@ -500,7 +505,7 @@ let rec format_statement
payload_var_name format_var tmp_var (format_block ctx) case_block))
cases
| SReturn e1 ->
Format.fprintf fmt "@[<hov 2>return(%a)@]" (format_expression ctx)
Format.fprintf fmt "@[<hov 2>return %a;@]" (format_expression ctx)
(e1, Mark.get s)
| SAssert e1 ->
let pos = Mark.get s in
@ -531,6 +536,9 @@ let format_program
"@[<v>/* This file has been generated by the Catala compiler, do not edit! \
*/@,\
@,\
#include <stdio.h>@,\
#include <stdlib.h>@,\
@,\
%a@,\
%a@,\
@]"

View File

@ -432,6 +432,7 @@ let rec format_statement
{ case_block = case_none; _ };
{ case_block = case_some; payload_var_name = case_some_var };
];
_;
}
when EnumName.equal e_name Expr.option_enum ->
(* We translate the option type with an overloading by Python's [None] *)
@ -446,7 +447,7 @@ let rec format_statement
format_var tmp_var (format_expression ctx) e1 format_var tmp_var
(format_block ctx) case_none format_var case_some_var format_var tmp_var
(format_block ctx) case_some
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases } ->
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
let cons_map = EnumName.Map.find e_name ctx.ctx_enums in
let cases =
List.map2

View File

@ -414,6 +414,7 @@ let rec format_statement
{ case_block = case_none; _ };
{ case_block = case_some; payload_var_name = case_some_var };
];
_;
}
when EnumName.equal e_name Expr.option_enum ->
(* We translate the option type with an overloading by Python's [None] *)
@ -429,7 +430,7 @@ let rec format_statement
format_var tmp_var (format_expression ctx) e1 format_var tmp_var
(format_block ctx) case_none format_var case_some_var format_var tmp_var
(format_block ctx) case_some
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases } ->
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
let cases =
List.map2
(fun x (cons, _) -> x, cons)