diff --git a/compiler/catala_utils/cli.ml b/compiler/catala_utils/cli.ml index 60aa5572..f77dda46 100644 --- a/compiler/catala_utils/cli.ml +++ b/compiler/catala_utils/cli.ml @@ -398,11 +398,30 @@ module Flags = struct let keep_special_ops = value & flag - & info ["keep_special_ops"] + & info ["keep-special-ops"] ~doc: "During the Lcalc->Scalc translation, uses special AST nodes for \ higher-order operators rather than nested closures (useful for C)." + let dead_value_assignment = + value + & flag + & info ["dead-value-assignment"] + ~doc: + "During the Lcalc->Scalc translation, insert dummy variable \ + assignments before raising terminal exception to please gradual \ + typing tools that check exhaustivity of variable definitions in \ + every code branch." + + let no_struct_literals = + value + & flag + & info ["no-struct-literals"] + ~doc: + "During the Lcalc->Scalc translation, insert temporary variable \ + assignments to hold the result of structure initializations \ + (matches the absence of struct literals of C89)." + let closure_conversion = value & flag diff --git a/compiler/catala_utils/cli.mli b/compiler/catala_utils/cli.mli index bbf9027e..783ec88a 100644 --- a/compiler/catala_utils/cli.mli +++ b/compiler/catala_utils/cli.mli @@ -130,6 +130,8 @@ module Flags : sig val avoid_exceptions : bool Term.t val closure_conversion : bool Term.t val keep_special_ops : bool Term.t + val dead_value_assignment : bool Term.t + val no_struct_literals : bool Term.t val include_dirs : raw_file list Term.t val disable_counterexamples : bool Term.t end diff --git a/compiler/driver.ml b/compiler/driver.ml index 5889c336..105f6907 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -285,7 +285,9 @@ module Passes = struct ~check_invariants ~avoid_exceptions ~closure_conversion - ~keep_special_ops : + ~keep_special_ops + ~dead_value_assignment + ~no_struct_literals : Scalc.Ast.program * Scopelang.Dependency.TVertex.t list = let prg, type_ordering = lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.untyped @@ -294,7 +296,10 @@ module Passes = struct Message.emit_debug "Retyping lambda calculus..."; let prg = Typing.program ~leave_unresolved:true prg in debug_pass_name "scalc"; - Scalc.From_lcalc.translate_program ~keep_special_ops prg, type_ordering + ( Scalc.From_lcalc.translate_program + ~config:{ keep_special_ops; dead_value_assignment; no_struct_literals } + prg, + type_ordering ) end module Commands = struct @@ -839,10 +844,13 @@ module Commands = struct avoid_exceptions closure_conversion keep_special_ops + dead_value_assignment + no_struct_literals ex_scope_opt = let prg, _ = Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion ~keep_special_ops + ~dead_value_assignment ~no_struct_literals in let _output_file, with_output = get_output_format options output in with_output @@ -877,6 +885,8 @@ module Commands = struct $ Cli.Flags.avoid_exceptions $ Cli.Flags.closure_conversion $ Cli.Flags.keep_special_ops + $ Cli.Flags.dead_value_assignment + $ Cli.Flags.no_struct_literals $ Cli.Flags.ex_scope_opt) let python @@ -890,6 +900,7 @@ module Commands = struct let prg, type_ordering = Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion ~keep_special_ops:false + ~dead_value_assignment:true ~no_struct_literals:false in let output_file, with_output = @@ -919,6 +930,7 @@ module Commands = struct let prg, type_ordering = Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions:false ~closure_conversion ~keep_special_ops:false + ~dead_value_assignment:false ~no_struct_literals:false in let output_file, with_output = get_output_format options ~ext:".r" output in @@ -943,6 +955,7 @@ module Commands = struct let prg, type_ordering = Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions:true ~closure_conversion:true ~keep_special_ops:true + ~dead_value_assignment:false ~no_struct_literals:true in let output_file, with_output = get_output_format options ~ext:".c" output in Message.emit_debug "Compiling program into C..."; diff --git a/compiler/driver.mli b/compiler/driver.mli index 9148075e..fa256891 100644 --- a/compiler/driver.mli +++ b/compiler/driver.mli @@ -63,6 +63,8 @@ module Passes : sig avoid_exceptions:bool -> closure_conversion:bool -> keep_special_ops:bool -> + dead_value_assignment:bool -> + no_struct_literals:bool -> Scalc.Ast.program * Scopelang.Dependency.TVertex.t list end diff --git a/compiler/plugins/python.ml b/compiler/plugins/python.ml index c32c1047..4dfb8f7a 100644 --- a/compiler/plugins/python.ml +++ b/compiler/plugins/python.ml @@ -34,6 +34,7 @@ let run let prg, type_ordering = Driver.Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion ~keep_special_ops:false + ~dead_value_assignment:true ~no_struct_literals:false in let output_file, with_output = get_output_format options ~ext:".py" output in diff --git a/compiler/scalc/ast.ml b/compiler/scalc/ast.ml index 75ae6229..2603f13a 100644 --- a/compiler/scalc/ast.ml +++ b/compiler/scalc/ast.ml @@ -69,6 +69,7 @@ type stmt = | SIfThenElse of { if_expr : expr; then_block : block; else_block : block } | SSwitch of { switch_expr : expr; + switch_expr_typ : typ; enum_name : EnumName.t; switch_cases : switch_case list; } diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index 2ffdeb7a..31d10578 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -20,13 +20,19 @@ module A = Ast module L = Lcalc.Ast module D = Dcalc.Ast +type translation_config = { + keep_special_ops : bool; + dead_value_assignment : bool; + no_struct_literals : bool; +} + type 'm ctxt = { func_dict : ('m L.expr, A.FuncName.t) Var.Map.t; decl_ctx : decl_ctx; var_dict : ('m L.expr, A.VarName.t) Var.Map.t; inside_definition_of : A.VarName.t option; context_name : string; - keep_special_ops : bool; + config : translation_config; } let unthunk e = @@ -55,7 +61,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr = (Var.Map.keys ctxt.var_dict)) in [], (local_var, Expr.pos expr) - | EStruct { fields; name } -> + | EStruct { fields; name } when not ctxt.config.no_struct_literals -> let args_stmts, new_args = StructField.Map.fold (fun _ arg (args_stmts, new_args) -> @@ -91,7 +97,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr = f = EOp { op = Op.HandleDefaultOpt; tys = _ }, _binder_mark; args = [_exceptions; _just; _cons]; } - when ctxt.keep_special_ops -> + when ctxt.config.keep_special_ops -> (* This should be translated as a statement *) raise Not_found | EApp { f; args } -> @@ -158,7 +164,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = f = EOp { op = Op.HandleDefaultOpt; tys = _ }, _binder_mark; args = [exceptions; just; cons]; } - when ctxt.keep_special_ops -> + when ctxt.config.keep_special_ops -> let exceptions = match Mark.remove exceptions with | EArray exceptions -> exceptions @@ -304,7 +310,12 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = e1_stmts @ [ ( A.SSwitch - { switch_expr = new_e1; enum_name = name; switch_cases = new_args }, + { + switch_expr = new_e1; + switch_expr_typ = Expr.maybe_ty (Mark.get e1); + enum_name = name; + switch_cases = new_args; + }, Expr.pos block_expr ); ] | EIfThenElse { cond; etrue; efalse } -> @@ -329,8 +340,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = (* Before raising the exception, we still give a dummy definition to the current variable so that tools like mypy don't complain. *) (match ctxt.inside_definition_of with - | None -> [] - | Some x -> + | Some x when ctxt.config.dead_value_assignment -> [ ( A.SLocalDef { @@ -338,8 +348,34 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = expr = Ast.EVar Ast.dead_value, Expr.pos block_expr; }, Expr.pos block_expr ); - ]) + ] + | _ -> []) @ [A.SRaise except, Expr.pos block_expr] + | EStruct { fields; name } when ctxt.config.no_struct_literals -> + let args_stmts, new_args = + StructField.Map.fold + (fun _ arg (args_stmts, new_args) -> + let arg_stmts, new_arg = translate_expr ctxt arg in + arg_stmts @ args_stmts, new_arg :: new_args) + fields ([], []) + in + let new_args = List.rev new_args in + let args_stmts = List.rev args_stmts in + let struct_expr = + A.EStruct { fields = new_args; name }, Expr.pos block_expr + in + let tmp_struct_var_name = + match ctxt.inside_definition_of with + | None -> + failwith "should not happen" + (* [translate_expr] should create this [inside_definition_of]*) + | Some x -> x, Expr.pos block_expr + in + args_stmts + @ [ + ( A.SLocalDef { name = tmp_struct_var_name; expr = struct_expr }, + Expr.pos block_expr ); + ] | _ -> ( let e_stmts, new_e = translate_expr ctxt block_expr in e_stmts @@ -359,7 +395,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = ]) let rec translate_scope_body_expr - ~(keep_special_ops : bool) + ~(config : translation_config) (scope_name : ScopeName.t) (decl_ctx : decl_ctx) (var_dict : ('m L.expr, A.VarName.t) Var.Map.t) @@ -375,7 +411,7 @@ let rec translate_scope_body_expr var_dict; inside_definition_of = None; context_name = Mark.remove (ScopeName.get_info scope_name); - keep_special_ops; + config; } e in @@ -395,7 +431,7 @@ let rec translate_scope_body_expr var_dict; inside_definition_of = Some let_var_id; context_name = Mark.remove (ScopeName.get_info scope_name); - keep_special_ops; + config; } scope_let.scope_let_expr | _ -> @@ -407,7 +443,7 @@ let rec translate_scope_body_expr var_dict; inside_definition_of = Some let_var_id; context_name = Mark.remove (ScopeName.get_info scope_name); - keep_special_ops; + config; } scope_let.scope_let_expr in @@ -426,11 +462,11 @@ let rec translate_scope_body_expr }, scope_let.scope_let_pos ); ]) - @ translate_scope_body_expr ~keep_special_ops scope_name decl_ctx - new_var_dict func_dict scope_let_next + @ translate_scope_body_expr ~config scope_name decl_ctx new_var_dict + func_dict scope_let_next -let translate_program ~(keep_special_ops : bool) (p : 'm L.program) : A.program - = +let translate_program ~(config : translation_config) (p : 'm L.program) : + A.program = let _, _, rev_items = Scope.fold_left ~f:(fun (func_dict, var_dict, rev_items) code_item var -> @@ -447,8 +483,8 @@ let translate_program ~(keep_special_ops : bool) (p : 'm L.program) : A.program Var.Map.add scope_input_var scope_input_var_id var_dict in let new_scope_body = - translate_scope_body_expr ~keep_special_ops name p.decl_ctx - var_dict_local func_dict scope_body_expr + translate_scope_body_expr ~config name p.decl_ctx var_dict_local + func_dict scope_body_expr in let func_id = A.FuncName.fresh (Bindlib.name_of var, Pos.no_pos) in ( Var.Map.add var func_id func_dict, @@ -493,7 +529,7 @@ let translate_program ~(keep_special_ops : bool) (p : 'm L.program) : A.program var_dict args args_id; inside_definition_of = None; context_name = Mark.remove (TopdefName.get_info name); - keep_special_ops; + config; } in translate_expr ctxt expr @@ -529,7 +565,7 @@ let translate_program ~(keep_special_ops : bool) (p : 'm L.program) : A.program var_dict; inside_definition_of = None; context_name = Mark.remove (TopdefName.get_info name); - keep_special_ops; + config; } in translate_expr ctxt expr diff --git a/compiler/scalc/from_lcalc.mli b/compiler/scalc/from_lcalc.mli index 3b639f2e..fdd3b255 100644 --- a/compiler/scalc/from_lcalc.mli +++ b/compiler/scalc/from_lcalc.mli @@ -16,9 +16,23 @@ open Shared_ast -(* When [keep_special_ops] is true, then this translation uses special Scalc AST - nodes for higher-order operators like map, fold, handle_default, etc. This is - useful if the target language after Scalc does not support nested functions - like C. *) +type translation_config = { + keep_special_ops : bool; + (** When [keep_special_ops] is true, then this translation uses special + Scalc AST nodes for higher-order operators like map, fold, + handle_default, etc. This is useful if the target language after Scalc + does not support nested functions like C. *) + dead_value_assignment : bool; + (** When [dead_value_assignment] is true, the translation inserts dummy + assignments of the variable being defined in the current code branch + just before raising a terminal error. This is useful for languages + like Python and their linting tools like mypy. The assignment uses the + polymorphic [Ast.dead_value]. *) + no_struct_literals : bool; + (** When [no_struct_literals] is true, the translation inserts a temporary + variable to hold the initialization of struct literals. This matches + what C89 expects. *) +} + val translate_program : - keep_special_ops:bool -> typed Lcalc.Ast.program -> Ast.program + config:translation_config -> typed Lcalc.Ast.program -> Ast.program diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index aaf1d6d6..ce0ce79d 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -151,7 +151,8 @@ let rec format_statement Format.fprintf fmt "@[%a %a@]" Print.keyword "assert" (format_expr decl_ctx ~debug) (naked_expr, Mark.get stmt) - | SSwitch { switch_expr = e_switch; enum_name = enum; switch_cases = arms } -> + | SSwitch { switch_expr = e_switch; enum_name = enum; switch_cases = arms; _ } + -> let cons = EnumName.Map.find enum decl_ctx.ctx_enums in Format.fprintf fmt "@[%a @[%a@]%a@,@]%a" Print.keyword "switch" (format_expr decl_ctx ~debug) diff --git a/compiler/scalc/to_c.ml b/compiler/scalc/to_c.ml index fc78daaa..115660b3 100644 --- a/compiler/scalc/to_c.ml +++ b/compiler/scalc/to_c.ml @@ -154,7 +154,7 @@ let rec format_typ (* We translate the option type with an overloading to C's [NULL] *) Format.fprintf fmt "@[struct {@ char some_tag;@ @[union {@ void *none;@ %a;@]@,\ - } some_value;@]@,\ + } payload;@]@,\ } /* option %a */ %t" (format_typ decl_ctx (fun fmt -> Format.fprintf fmt "some")) some_typ (Print.typ decl_ctx) some_typ element_name @@ -344,17 +344,16 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit = let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : unit = match Mark.remove e with + | EVar v when v = Ast.dead_value -> Format.fprintf fmt "NULL" | EVar v -> format_var fmt v | EFunc f -> format_func_name fmt f - | EStruct { fields = es; name = s } -> - Format.fprintf fmt "new(\"catala_struct_%a\",@ %a)" format_struct_name s + | EStruct { fields = es; _ } -> + (* These should only appear when initializing a variable definition *) + Format.fprintf fmt "{ %a }" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt (e, (struct_field, _)) -> - Format.fprintf fmt "%a = %a" format_struct_field_name struct_field - (format_expression ctx) e)) - (List.combine es - (StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs))) + (fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e)) + es | EStructFieldAccess { e1; field; _ } -> Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field @@ -467,22 +466,28 @@ let rec format_statement { case_block = case_none; _ }; { case_block = case_some; payload_var_name = case_some_var }; ]; + switch_expr_typ; } when EnumName.equal e_name Expr.option_enum -> (* We translate the option type with an overloading by Python's [None] *) let tmp_var = VarName.fresh ("perhaps_none_arg", Pos.no_pos) in Format.fprintf fmt - "%a <- %a@\n\ - @[if (is.null(%a)) {@\n\ + "%a = %a;@\n\ + @[if (%a.some_tag != 0) {@\n\ %a@]@\n\ @[} else {@\n\ - %a = %a@\n\ + %a = %a.payload.some;@\n\ %a@]@\n\ }" - format_var tmp_var (format_expression ctx) e1 format_var tmp_var - (format_block ctx) case_none format_var case_some_var format_var tmp_var - (format_block ctx) case_some - | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases } -> + (format_typ ctx (fun fmt -> format_var fmt tmp_var)) + switch_expr_typ (format_expression ctx) e1 format_var tmp_var + (format_block ctx) case_none + (format_typ ctx (fun fmt -> format_var fmt case_some_var)) + (match Mark.remove switch_expr_typ with + | TOption tau -> tau + | _ -> failwith "should not happen") + format_var tmp_var (format_block ctx) case_some + | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } -> let cases = List.map2 (fun x (cons, _) -> x, cons) @@ -500,7 +505,7 @@ let rec format_statement payload_var_name format_var tmp_var (format_block ctx) case_block)) cases | SReturn e1 -> - Format.fprintf fmt "@[return(%a)@]" (format_expression ctx) + Format.fprintf fmt "@[return %a;@]" (format_expression ctx) (e1, Mark.get s) | SAssert e1 -> let pos = Mark.get s in @@ -531,6 +536,9 @@ let format_program "@[/* This file has been generated by the Catala compiler, do not edit! \ */@,\ @,\ + #include @,\ + #include @,\ + @,\ %a@,\ %a@,\ @]" diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 8ec516db..d845d85a 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -432,6 +432,7 @@ let rec format_statement { case_block = case_none; _ }; { case_block = case_some; payload_var_name = case_some_var }; ]; + _; } when EnumName.equal e_name Expr.option_enum -> (* We translate the option type with an overloading by Python's [None] *) @@ -446,7 +447,7 @@ let rec format_statement format_var tmp_var (format_expression ctx) e1 format_var tmp_var (format_block ctx) case_none format_var case_some_var format_var tmp_var (format_block ctx) case_some - | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases } -> + | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } -> let cons_map = EnumName.Map.find e_name ctx.ctx_enums in let cases = List.map2 diff --git a/compiler/scalc/to_r.ml b/compiler/scalc/to_r.ml index 9c9e1a6b..7d75b2a3 100644 --- a/compiler/scalc/to_r.ml +++ b/compiler/scalc/to_r.ml @@ -414,6 +414,7 @@ let rec format_statement { case_block = case_none; _ }; { case_block = case_some; payload_var_name = case_some_var }; ]; + _; } when EnumName.equal e_name Expr.option_enum -> (* We translate the option type with an overloading by Python's [None] *) @@ -429,7 +430,7 @@ let rec format_statement format_var tmp_var (format_expression ctx) e1 format_var tmp_var (format_block ctx) case_none format_var case_some_var format_var tmp_var (format_block ctx) case_some - | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases } -> + | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } -> let cases = List.map2 (fun x (cons, _) -> x, cons)