Correct compilation of tryCatch

This commit is contained in:
Denis Merigoux 2023-08-04 19:03:10 +02:00
parent 84d37d8720
commit 1df2ebda13
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
10 changed files with 108 additions and 35 deletions

View File

@ -20,6 +20,7 @@
type backend_lang = En | Fr | Pl
type when_enum = Auto | Always | Never
type message_format_enum = Human | GNU
type compilation_method = Expression | Statement
type input_file = FileName of string | Contents of string
(** Associates a {!type: Cli.backend_lang} with its string represtation. *)
@ -30,6 +31,7 @@ let language_code =
fun l -> List.assoc l rl
let message_format_opt = ["human", Human; "gnu", GNU]
let compilation_method_opt = ["expression", Expression; "statement", Statement]
type options = {
mutable input_file : input_file;
@ -317,6 +319,15 @@ module Flags = struct
"Disables the search for counterexamples. Useful when you want a \
deterministic output from the Catala compiler, since provers can \
have some randomness in them."
let scalc_try_with_compilation =
value
& opt (enum compilation_method_opt) Statement
& info
["scalc_try_with_compilation"]
~doc:
"How should try ... with ... constructs be compiled from Lcalc to \
Scalc ? Choice is between $(i,expression) or $(i,statement)."
end
let version = "0.8.0"

View File

@ -24,6 +24,11 @@ type message_format_enum =
| Human
| GNU (** Format of error and warning messages output by the compiler. *)
type compilation_method =
| Expression
| Statement
(** Whether to compile something as an expression or a statement *)
type input_file = FileName of string | Contents of string
val languages : (string * backend_lang) list
@ -99,6 +104,7 @@ module Flags : sig
val closure_conversion : bool Term.t
val link_modules : string list Term.t
val disable_counterexamples : bool Term.t
val scalc_try_with_compilation : compilation_method Term.t
end
(** {2 Command-line application} *)

View File

@ -190,7 +190,8 @@ module Passes = struct
~optimize
~check_invariants
~avoid_exceptions
~closure_conversion :
~closure_conversion
~scalc_try_with_compilation :
Scalc.Ast.program
* Desugared.Name_resolution.context
* Scopelang.Dependency.TVertex.t list =
@ -199,7 +200,15 @@ module Passes = struct
~closure_conversion
in
Message.emit_debug "Compiling program into statement calculus...";
Scalc.From_lcalc.translate_program prg, ctx, type_ordering
( Scalc.From_lcalc.translate_program prg
{
try_catch_type =
(match scalc_try_with_compilation with
| Cli.Expression -> Scalc.From_lcalc.Expression
| Cli.Statement -> Scalc.From_lcalc.Statement);
},
ctx,
type_ordering )
end
module Commands = struct
@ -707,10 +716,11 @@ module Commands = struct
check_invariants
avoid_exceptions
closure_conversion
ex_scope_opt =
ex_scope_opt
scalc_try_with_compilation =
let prg, ctx, _ =
Passes.scalc options ~link_modules ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion
~avoid_exceptions ~closure_conversion ~scalc_try_with_compilation
in
let _output_file, with_output = get_output_format options output in
with_output
@ -744,7 +754,8 @@ module Commands = struct
$ Cli.Flags.check_invariants
$ Cli.Flags.avoid_exceptions
$ Cli.Flags.closure_conversion
$ Cli.Flags.ex_scope_opt)
$ Cli.Flags.ex_scope_opt
$ Cli.Flags.scalc_try_with_compilation)
let python
options
@ -757,6 +768,7 @@ module Commands = struct
let prg, _, type_ordering =
Passes.scalc options ~link_modules ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion
~scalc_try_with_compilation:Statement
in
let output_file, with_output =
get_output_format options ~ext:".py" output
@ -792,6 +804,7 @@ module Commands = struct
let prg, _, type_ordering =
Passes.scalc options ~link_modules ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion
~scalc_try_with_compilation:Expression
in
let output_file, with_output = get_output_format options ~ext:".r" output in
Message.emit_debug "Compiling program into R...";

View File

@ -66,6 +66,7 @@ module Passes : sig
check_invariants:bool ->
avoid_exceptions:bool ->
closure_conversion:bool ->
scalc_try_with_compilation:Cli.compilation_method ->
Scalc.Ast.program
* Desugared.Name_resolution.context
* Scopelang.Dependency.TVertex.t list

View File

@ -41,6 +41,7 @@ and naked_expr =
| ELit : lit -> naked_expr
| EApp : expr * expr list -> naked_expr
| EOp : operator -> naked_expr
| ETryExcept : expr * except * expr -> naked_expr
type stmt =
| SInnerFuncDef of VarName.t Mark.pos * func

View File

@ -20,7 +20,11 @@ module A = Ast
module L = Lcalc.Ast
module D = Dcalc.Ast
type compilation_type = Expression | Statement
type compilation_options = { try_catch_type : compilation_type }
type 'm ctxt = {
compilation_options : compilation_options;
func_dict : ('m L.expr, A.FuncName.t) Var.Map.t;
decl_ctx : decl_ctx;
var_dict : ('m L.expr, A.VarName.t) Var.Map.t;
@ -88,6 +92,12 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
args_stmts, (A.EArray new_args, Expr.pos expr)
| EOp { op; _ } -> [], (A.EOp (Operator.translate op), Expr.pos expr)
| ELit l -> [], (A.ELit l, Expr.pos expr)
| ECatch { body; exn; handler }
when ctxt.compilation_options.try_catch_type = Expression ->
let try_stmts, new_e_try = translate_expr ctxt body in
let catch_stmts, new_e_catch = translate_expr ctxt handler in
( try_stmts @ catch_stmts,
(A.ETryExcept (new_e_try, exn, new_e_catch), Expr.pos expr) )
| _ ->
let tmp_var =
A.VarName.fresh
@ -233,7 +243,8 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
let s_e_false = translate_statements ctxt efalse in
cond_stmts
@ [A.SIfThenElse (s_cond, s_e_true, s_e_false), Expr.pos block_expr]
| ECatch { body; exn; handler } ->
| ECatch { body; exn; handler }
when ctxt.compilation_options.try_catch_type = Statement ->
let s_e_try = translate_statements ctxt body in
let s_e_catch = translate_statements ctxt handler in
[A.STryExcept (s_e_try, exn, s_e_catch), Expr.pos block_expr]
@ -269,6 +280,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
])
let rec translate_scope_body_expr
(options : compilation_options)
(scope_name : ScopeName.t)
(decl_ctx : decl_ctx)
(var_dict : ('m L.expr, A.VarName.t) Var.Map.t)
@ -279,6 +291,7 @@ let rec translate_scope_body_expr
let block, new_e =
translate_expr
{
compilation_options = options;
decl_ctx;
func_dict;
var_dict;
@ -298,6 +311,7 @@ let rec translate_scope_body_expr
| Assertion ->
translate_statements
{
compilation_options = options;
decl_ctx;
func_dict;
var_dict;
@ -309,6 +323,7 @@ let rec translate_scope_body_expr
let let_expr_stmts, new_let_expr =
translate_expr
{
compilation_options = options;
decl_ctx;
func_dict;
var_dict;
@ -325,10 +340,11 @@ let rec translate_scope_body_expr
( A.SLocalDef ((let_var_id, scope_let.scope_let_pos), new_let_expr),
scope_let.scope_let_pos );
])
@ translate_scope_body_expr scope_name decl_ctx new_var_dict func_dict
scope_let_next
@ translate_scope_body_expr options scope_name decl_ctx new_var_dict
func_dict scope_let_next
let translate_program (p : 'm L.program) : A.program =
let translate_program (p : 'm L.program) (options : compilation_options) :
A.program =
let _, _, rev_items =
Scope.fold_left
~f:(fun (func_dict, var_dict, rev_items) code_item var ->
@ -345,8 +361,8 @@ let translate_program (p : 'm L.program) : A.program =
Var.Map.add scope_input_var scope_input_var_id var_dict
in
let new_scope_body =
translate_scope_body_expr name p.decl_ctx var_dict_local func_dict
scope_body_expr
translate_scope_body_expr options name p.decl_ctx var_dict_local
func_dict scope_body_expr
in
let func_id = A.FuncName.fresh (Bindlib.name_of var, Pos.no_pos) in
( Var.Map.add var func_id func_dict,
@ -381,6 +397,7 @@ let translate_program (p : 'm L.program) : A.program =
let block, expr =
let ctxt =
{
compilation_options = options;
func_dict;
decl_ctx = p.decl_ctx;
var_dict =
@ -410,6 +427,7 @@ let translate_program (p : 'm L.program) : A.program =
let block, expr =
let ctxt =
{
compilation_options = options;
func_dict;
decl_ctx = p.decl_ctx;
var_dict;

View File

@ -16,4 +16,8 @@
open Shared_ast
val translate_program : untyped Lcalc.Ast.program -> Ast.program
type compilation_type = Expression | Statement
type compilation_options = { try_catch_type : compilation_type }
val translate_program :
untyped Lcalc.Ast.program -> compilation_options -> Ast.program

View File

@ -85,6 +85,9 @@ let rec format_expr
format_with_parens)
args
| EOp op -> Print.operator ~debug fmt op
| ETryExcept (e_try, except, e_with) ->
Format.fprintf fmt "@[<v 2>%a(%a,@;%a,@;%a)@]" Print.keyword "tryWithExn"
format_expr e_try Print.except except format_expr e_with
let rec format_statement
(decl_ctx : decl_ctx)

View File

@ -381,6 +381,9 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
(format_expression ctx))
args
| EOp op -> Format.fprintf fmt "%a" format_op (op, Pos.no_pos)
| ETryExcept _ ->
Message.raise_internal_error
"Python needs TryExcept to be compiled as statements and not expressions"
let rec format_statement
(ctx : decl_ctx)

View File

@ -229,8 +229,8 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map;
0
in
if v_str = "_" then Format.fprintf fmt "_"
(* special case for the unit pattern *)
if v_str = "_" then Format.fprintf fmt "dummy_var"
(* special case for the unit pattern TODO escape dummy_var *)
else if local_id = 0 then format_name_cleaned fmt v_str
else Format.fprintf fmt "%a_%d" format_name_cleaned v_str local_id
@ -243,23 +243,30 @@ let format_exception (fmt : Format.formatter) (exc : except Mark.pos) : unit =
match Mark.remove exc with
| ConflictError ->
Format.fprintf fmt
"ConflictError(@[<hov 0>SourcePosition(@[<hov 0>filename=\"%s\",@ \
start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \
law_headings=%a)@])@]"
"catala_conflict_error(@[<hov 0>SourcePosition(@[<hov \
0>filename=\"%s\",@ start_line=%d,@ start_column=%d,@ end_line=%d,@ \
end_column=%d,@ law_headings=%a)@])@]"
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos)
| EmptyError -> Format.fprintf fmt "EmptyError"
| Crash -> Format.fprintf fmt "Crash"
| EmptyError -> Format.fprintf fmt "catala_empty_error()"
| Crash -> Format.fprintf fmt "catala_crash()"
| NoValueProvided ->
Format.fprintf fmt
"NoValueProvided(@[<hov 0>SourcePosition(@[<hov 0>filename=\"%s\",@ \
start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \
law_headings=%a)@])@]"
"catala_no_value_provided_error(@[<hov 0>SourcePosition(@[<hov \
0>filename=\"%s\",@ start_line=%d,@ start_column=%d,@ end_line=%d,@ \
end_column=%d,@ law_headings=%a)@])@]"
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos)
let format_exception_name (fmt : Format.formatter) (exc : except) : unit =
match exc with
| ConflictError -> Format.fprintf fmt "catala_conflict_error"
| EmptyError -> Format.fprintf fmt "catala_empty_error"
| Crash -> Format.fprintf fmt "catala_crash"
| NoValueProvided -> Format.fprintf fmt "catala_no_value_provided_error"
let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
unit =
match Mark.remove e with
@ -373,6 +380,12 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
(format_expression ctx))
args
| EOp op -> Format.fprintf fmt "%a" format_op (op, Pos.no_pos)
| ETryExcept (e_try, except, e_catch) ->
Format.fprintf fmt
(* TODO escape dummy__arg*)
"tryCatch@[<hov 2>(%a, %a = function(dummy__arg)) @[<hov 2>{@;%a@;}@],@]"
(format_expression ctx) e_try format_exception_name except
(format_expression ctx) e_catch
let rec format_statement
(ctx : decl_ctx)
@ -380,23 +393,22 @@ let rec format_statement
(s : stmt Mark.pos) : unit =
match Mark.remove s with
| SInnerFuncDef (name, { func_params; func_body }) ->
Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_var
Format.fprintf fmt "@[<hov 2>%a <- function(@\n%a) {@\n%a@]@\n}" format_var
(Mark.remove name)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n,@;")
(fun fmt (var, typ) ->
Format.fprintf fmt "%a:%a" format_var (Mark.remove var) format_typ
typ))
Format.fprintf fmt "%a# (%a)@\n" format_var (Mark.remove var)
format_typ typ))
func_params (format_block ctx) func_body
| SLocalDecl _ ->
assert false (* We don't need to declare variables in Python *)
| SLocalDef (v, e) ->
Format.fprintf fmt "@[<hov 4>%a = %a@]" format_var (Mark.remove v)
Format.fprintf fmt "@[<hov 2>%a <- %a@]" format_var (Mark.remove v)
(format_expression ctx) e
| STryExcept (try_b, except, catch_b) ->
Format.fprintf fmt "@[<hov 4>try:@\n%a@]@\n@[<hov 4>except %a:@\n%a@]"
(format_block ctx) try_b format_exception (except, Pos.no_pos)
(format_block ctx) catch_b
| STryExcept (_try_b, _except, _catch_b) ->
Message.raise_internal_error
"R needs TryExcept to be compiled as exceptions and not statements"
| SRaise except ->
Format.fprintf fmt "@[<hov 4>raise %a@]" format_exception
(except, Mark.get s)
@ -562,17 +574,18 @@ let format_program
(format_ctx type_ordering) p.decl_ctx
(Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt -> function
| SVar { var; expr } ->
Format.fprintf fmt "@[<hv 4>%a = (@,%a@,@])@," format_var var
Format.fprintf fmt "@[<hv 2>%a <- (@,%a@,@])@," format_var var
(format_expression p.decl_ctx)
expr
| SFunc { var; func }
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
let { Ast.func_params; Ast.func_body } = func in
Format.fprintf fmt "@[<hv 4>def %a(%a):@\n%a@]@," format_func_name var
Format.fprintf fmt "@[<hv 2>%a <- function(@\n%a) {@\n%a@]@\n}@,"
format_func_name var
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n,@;")
(fun fmt (var, typ) ->
Format.fprintf fmt "%a:%a" format_var (Mark.remove var)
Format.fprintf fmt "%a# (%a)@\n" format_var (Mark.remove var)
format_typ typ))
func_params (format_block p.decl_ctx) func_body))
p.code_items