mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Correct compilation of tryCatch
This commit is contained in:
parent
84d37d8720
commit
1df2ebda13
@ -20,6 +20,7 @@
|
||||
type backend_lang = En | Fr | Pl
|
||||
type when_enum = Auto | Always | Never
|
||||
type message_format_enum = Human | GNU
|
||||
type compilation_method = Expression | Statement
|
||||
type input_file = FileName of string | Contents of string
|
||||
|
||||
(** Associates a {!type: Cli.backend_lang} with its string represtation. *)
|
||||
@ -30,6 +31,7 @@ let language_code =
|
||||
fun l -> List.assoc l rl
|
||||
|
||||
let message_format_opt = ["human", Human; "gnu", GNU]
|
||||
let compilation_method_opt = ["expression", Expression; "statement", Statement]
|
||||
|
||||
type options = {
|
||||
mutable input_file : input_file;
|
||||
@ -317,6 +319,15 @@ module Flags = struct
|
||||
"Disables the search for counterexamples. Useful when you want a \
|
||||
deterministic output from the Catala compiler, since provers can \
|
||||
have some randomness in them."
|
||||
|
||||
let scalc_try_with_compilation =
|
||||
value
|
||||
& opt (enum compilation_method_opt) Statement
|
||||
& info
|
||||
["scalc_try_with_compilation"]
|
||||
~doc:
|
||||
"How should try ... with ... constructs be compiled from Lcalc to \
|
||||
Scalc ? Choice is between $(i,expression) or $(i,statement)."
|
||||
end
|
||||
|
||||
let version = "0.8.0"
|
||||
|
@ -24,6 +24,11 @@ type message_format_enum =
|
||||
| Human
|
||||
| GNU (** Format of error and warning messages output by the compiler. *)
|
||||
|
||||
type compilation_method =
|
||||
| Expression
|
||||
| Statement
|
||||
(** Whether to compile something as an expression or a statement *)
|
||||
|
||||
type input_file = FileName of string | Contents of string
|
||||
|
||||
val languages : (string * backend_lang) list
|
||||
@ -99,6 +104,7 @@ module Flags : sig
|
||||
val closure_conversion : bool Term.t
|
||||
val link_modules : string list Term.t
|
||||
val disable_counterexamples : bool Term.t
|
||||
val scalc_try_with_compilation : compilation_method Term.t
|
||||
end
|
||||
|
||||
(** {2 Command-line application} *)
|
||||
|
@ -190,7 +190,8 @@ module Passes = struct
|
||||
~optimize
|
||||
~check_invariants
|
||||
~avoid_exceptions
|
||||
~closure_conversion :
|
||||
~closure_conversion
|
||||
~scalc_try_with_compilation :
|
||||
Scalc.Ast.program
|
||||
* Desugared.Name_resolution.context
|
||||
* Scopelang.Dependency.TVertex.t list =
|
||||
@ -199,7 +200,15 @@ module Passes = struct
|
||||
~closure_conversion
|
||||
in
|
||||
Message.emit_debug "Compiling program into statement calculus...";
|
||||
Scalc.From_lcalc.translate_program prg, ctx, type_ordering
|
||||
( Scalc.From_lcalc.translate_program prg
|
||||
{
|
||||
try_catch_type =
|
||||
(match scalc_try_with_compilation with
|
||||
| Cli.Expression -> Scalc.From_lcalc.Expression
|
||||
| Cli.Statement -> Scalc.From_lcalc.Statement);
|
||||
},
|
||||
ctx,
|
||||
type_ordering )
|
||||
end
|
||||
|
||||
module Commands = struct
|
||||
@ -707,10 +716,11 @@ module Commands = struct
|
||||
check_invariants
|
||||
avoid_exceptions
|
||||
closure_conversion
|
||||
ex_scope_opt =
|
||||
ex_scope_opt
|
||||
scalc_try_with_compilation =
|
||||
let prg, ctx, _ =
|
||||
Passes.scalc options ~link_modules ~optimize ~check_invariants
|
||||
~avoid_exceptions ~closure_conversion
|
||||
~avoid_exceptions ~closure_conversion ~scalc_try_with_compilation
|
||||
in
|
||||
let _output_file, with_output = get_output_format options output in
|
||||
with_output
|
||||
@ -744,7 +754,8 @@ module Commands = struct
|
||||
$ Cli.Flags.check_invariants
|
||||
$ Cli.Flags.avoid_exceptions
|
||||
$ Cli.Flags.closure_conversion
|
||||
$ Cli.Flags.ex_scope_opt)
|
||||
$ Cli.Flags.ex_scope_opt
|
||||
$ Cli.Flags.scalc_try_with_compilation)
|
||||
|
||||
let python
|
||||
options
|
||||
@ -757,6 +768,7 @@ module Commands = struct
|
||||
let prg, _, type_ordering =
|
||||
Passes.scalc options ~link_modules ~optimize ~check_invariants
|
||||
~avoid_exceptions ~closure_conversion
|
||||
~scalc_try_with_compilation:Statement
|
||||
in
|
||||
let output_file, with_output =
|
||||
get_output_format options ~ext:".py" output
|
||||
@ -792,6 +804,7 @@ module Commands = struct
|
||||
let prg, _, type_ordering =
|
||||
Passes.scalc options ~link_modules ~optimize ~check_invariants
|
||||
~avoid_exceptions ~closure_conversion
|
||||
~scalc_try_with_compilation:Expression
|
||||
in
|
||||
let output_file, with_output = get_output_format options ~ext:".r" output in
|
||||
Message.emit_debug "Compiling program into R...";
|
||||
|
@ -66,6 +66,7 @@ module Passes : sig
|
||||
check_invariants:bool ->
|
||||
avoid_exceptions:bool ->
|
||||
closure_conversion:bool ->
|
||||
scalc_try_with_compilation:Cli.compilation_method ->
|
||||
Scalc.Ast.program
|
||||
* Desugared.Name_resolution.context
|
||||
* Scopelang.Dependency.TVertex.t list
|
||||
|
@ -41,6 +41,7 @@ and naked_expr =
|
||||
| ELit : lit -> naked_expr
|
||||
| EApp : expr * expr list -> naked_expr
|
||||
| EOp : operator -> naked_expr
|
||||
| ETryExcept : expr * except * expr -> naked_expr
|
||||
|
||||
type stmt =
|
||||
| SInnerFuncDef of VarName.t Mark.pos * func
|
||||
|
@ -20,7 +20,11 @@ module A = Ast
|
||||
module L = Lcalc.Ast
|
||||
module D = Dcalc.Ast
|
||||
|
||||
type compilation_type = Expression | Statement
|
||||
type compilation_options = { try_catch_type : compilation_type }
|
||||
|
||||
type 'm ctxt = {
|
||||
compilation_options : compilation_options;
|
||||
func_dict : ('m L.expr, A.FuncName.t) Var.Map.t;
|
||||
decl_ctx : decl_ctx;
|
||||
var_dict : ('m L.expr, A.VarName.t) Var.Map.t;
|
||||
@ -88,6 +92,12 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
|
||||
args_stmts, (A.EArray new_args, Expr.pos expr)
|
||||
| EOp { op; _ } -> [], (A.EOp (Operator.translate op), Expr.pos expr)
|
||||
| ELit l -> [], (A.ELit l, Expr.pos expr)
|
||||
| ECatch { body; exn; handler }
|
||||
when ctxt.compilation_options.try_catch_type = Expression ->
|
||||
let try_stmts, new_e_try = translate_expr ctxt body in
|
||||
let catch_stmts, new_e_catch = translate_expr ctxt handler in
|
||||
( try_stmts @ catch_stmts,
|
||||
(A.ETryExcept (new_e_try, exn, new_e_catch), Expr.pos expr) )
|
||||
| _ ->
|
||||
let tmp_var =
|
||||
A.VarName.fresh
|
||||
@ -233,7 +243,8 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
|
||||
let s_e_false = translate_statements ctxt efalse in
|
||||
cond_stmts
|
||||
@ [A.SIfThenElse (s_cond, s_e_true, s_e_false), Expr.pos block_expr]
|
||||
| ECatch { body; exn; handler } ->
|
||||
| ECatch { body; exn; handler }
|
||||
when ctxt.compilation_options.try_catch_type = Statement ->
|
||||
let s_e_try = translate_statements ctxt body in
|
||||
let s_e_catch = translate_statements ctxt handler in
|
||||
[A.STryExcept (s_e_try, exn, s_e_catch), Expr.pos block_expr]
|
||||
@ -269,6 +280,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
|
||||
])
|
||||
|
||||
let rec translate_scope_body_expr
|
||||
(options : compilation_options)
|
||||
(scope_name : ScopeName.t)
|
||||
(decl_ctx : decl_ctx)
|
||||
(var_dict : ('m L.expr, A.VarName.t) Var.Map.t)
|
||||
@ -279,6 +291,7 @@ let rec translate_scope_body_expr
|
||||
let block, new_e =
|
||||
translate_expr
|
||||
{
|
||||
compilation_options = options;
|
||||
decl_ctx;
|
||||
func_dict;
|
||||
var_dict;
|
||||
@ -298,6 +311,7 @@ let rec translate_scope_body_expr
|
||||
| Assertion ->
|
||||
translate_statements
|
||||
{
|
||||
compilation_options = options;
|
||||
decl_ctx;
|
||||
func_dict;
|
||||
var_dict;
|
||||
@ -309,6 +323,7 @@ let rec translate_scope_body_expr
|
||||
let let_expr_stmts, new_let_expr =
|
||||
translate_expr
|
||||
{
|
||||
compilation_options = options;
|
||||
decl_ctx;
|
||||
func_dict;
|
||||
var_dict;
|
||||
@ -325,10 +340,11 @@ let rec translate_scope_body_expr
|
||||
( A.SLocalDef ((let_var_id, scope_let.scope_let_pos), new_let_expr),
|
||||
scope_let.scope_let_pos );
|
||||
])
|
||||
@ translate_scope_body_expr scope_name decl_ctx new_var_dict func_dict
|
||||
scope_let_next
|
||||
@ translate_scope_body_expr options scope_name decl_ctx new_var_dict
|
||||
func_dict scope_let_next
|
||||
|
||||
let translate_program (p : 'm L.program) : A.program =
|
||||
let translate_program (p : 'm L.program) (options : compilation_options) :
|
||||
A.program =
|
||||
let _, _, rev_items =
|
||||
Scope.fold_left
|
||||
~f:(fun (func_dict, var_dict, rev_items) code_item var ->
|
||||
@ -345,8 +361,8 @@ let translate_program (p : 'm L.program) : A.program =
|
||||
Var.Map.add scope_input_var scope_input_var_id var_dict
|
||||
in
|
||||
let new_scope_body =
|
||||
translate_scope_body_expr name p.decl_ctx var_dict_local func_dict
|
||||
scope_body_expr
|
||||
translate_scope_body_expr options name p.decl_ctx var_dict_local
|
||||
func_dict scope_body_expr
|
||||
in
|
||||
let func_id = A.FuncName.fresh (Bindlib.name_of var, Pos.no_pos) in
|
||||
( Var.Map.add var func_id func_dict,
|
||||
@ -381,6 +397,7 @@ let translate_program (p : 'm L.program) : A.program =
|
||||
let block, expr =
|
||||
let ctxt =
|
||||
{
|
||||
compilation_options = options;
|
||||
func_dict;
|
||||
decl_ctx = p.decl_ctx;
|
||||
var_dict =
|
||||
@ -410,6 +427,7 @@ let translate_program (p : 'm L.program) : A.program =
|
||||
let block, expr =
|
||||
let ctxt =
|
||||
{
|
||||
compilation_options = options;
|
||||
func_dict;
|
||||
decl_ctx = p.decl_ctx;
|
||||
var_dict;
|
||||
|
@ -16,4 +16,8 @@
|
||||
|
||||
open Shared_ast
|
||||
|
||||
val translate_program : untyped Lcalc.Ast.program -> Ast.program
|
||||
type compilation_type = Expression | Statement
|
||||
type compilation_options = { try_catch_type : compilation_type }
|
||||
|
||||
val translate_program :
|
||||
untyped Lcalc.Ast.program -> compilation_options -> Ast.program
|
||||
|
@ -85,6 +85,9 @@ let rec format_expr
|
||||
format_with_parens)
|
||||
args
|
||||
| EOp op -> Print.operator ~debug fmt op
|
||||
| ETryExcept (e_try, except, e_with) ->
|
||||
Format.fprintf fmt "@[<v 2>%a(%a,@;%a,@;%a)@]" Print.keyword "tryWithExn"
|
||||
format_expr e_try Print.except except format_expr e_with
|
||||
|
||||
let rec format_statement
|
||||
(decl_ctx : decl_ctx)
|
||||
|
@ -381,6 +381,9 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
(format_expression ctx))
|
||||
args
|
||||
| EOp op -> Format.fprintf fmt "%a" format_op (op, Pos.no_pos)
|
||||
| ETryExcept _ ->
|
||||
Message.raise_internal_error
|
||||
"Python needs TryExcept to be compiled as statements and not expressions"
|
||||
|
||||
let rec format_statement
|
||||
(ctx : decl_ctx)
|
||||
|
@ -229,8 +229,8 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
|
||||
StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map;
|
||||
0
|
||||
in
|
||||
if v_str = "_" then Format.fprintf fmt "_"
|
||||
(* special case for the unit pattern *)
|
||||
if v_str = "_" then Format.fprintf fmt "dummy_var"
|
||||
(* special case for the unit pattern TODO escape dummy_var *)
|
||||
else if local_id = 0 then format_name_cleaned fmt v_str
|
||||
else Format.fprintf fmt "%a_%d" format_name_cleaned v_str local_id
|
||||
|
||||
@ -243,23 +243,30 @@ let format_exception (fmt : Format.formatter) (exc : except Mark.pos) : unit =
|
||||
match Mark.remove exc with
|
||||
| ConflictError ->
|
||||
Format.fprintf fmt
|
||||
"ConflictError(@[<hov 0>SourcePosition(@[<hov 0>filename=\"%s\",@ \
|
||||
start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \
|
||||
law_headings=%a)@])@]"
|
||||
"catala_conflict_error(@[<hov 0>SourcePosition(@[<hov \
|
||||
0>filename=\"%s\",@ start_line=%d,@ start_column=%d,@ end_line=%d,@ \
|
||||
end_column=%d,@ law_headings=%a)@])@]"
|
||||
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
|
||||
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
|
||||
(Pos.get_law_info pos)
|
||||
| EmptyError -> Format.fprintf fmt "EmptyError"
|
||||
| Crash -> Format.fprintf fmt "Crash"
|
||||
| EmptyError -> Format.fprintf fmt "catala_empty_error()"
|
||||
| Crash -> Format.fprintf fmt "catala_crash()"
|
||||
| NoValueProvided ->
|
||||
Format.fprintf fmt
|
||||
"NoValueProvided(@[<hov 0>SourcePosition(@[<hov 0>filename=\"%s\",@ \
|
||||
start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \
|
||||
law_headings=%a)@])@]"
|
||||
"catala_no_value_provided_error(@[<hov 0>SourcePosition(@[<hov \
|
||||
0>filename=\"%s\",@ start_line=%d,@ start_column=%d,@ end_line=%d,@ \
|
||||
end_column=%d,@ law_headings=%a)@])@]"
|
||||
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
|
||||
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
|
||||
(Pos.get_law_info pos)
|
||||
|
||||
let format_exception_name (fmt : Format.formatter) (exc : except) : unit =
|
||||
match exc with
|
||||
| ConflictError -> Format.fprintf fmt "catala_conflict_error"
|
||||
| EmptyError -> Format.fprintf fmt "catala_empty_error"
|
||||
| Crash -> Format.fprintf fmt "catala_crash"
|
||||
| NoValueProvided -> Format.fprintf fmt "catala_no_value_provided_error"
|
||||
|
||||
let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
unit =
|
||||
match Mark.remove e with
|
||||
@ -373,6 +380,12 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
(format_expression ctx))
|
||||
args
|
||||
| EOp op -> Format.fprintf fmt "%a" format_op (op, Pos.no_pos)
|
||||
| ETryExcept (e_try, except, e_catch) ->
|
||||
Format.fprintf fmt
|
||||
(* TODO escape dummy__arg*)
|
||||
"tryCatch@[<hov 2>(%a, %a = function(dummy__arg)) @[<hov 2>{@;%a@;}@],@]"
|
||||
(format_expression ctx) e_try format_exception_name except
|
||||
(format_expression ctx) e_catch
|
||||
|
||||
let rec format_statement
|
||||
(ctx : decl_ctx)
|
||||
@ -380,23 +393,22 @@ let rec format_statement
|
||||
(s : stmt Mark.pos) : unit =
|
||||
match Mark.remove s with
|
||||
| SInnerFuncDef (name, { func_params; func_body }) ->
|
||||
Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_var
|
||||
Format.fprintf fmt "@[<hov 2>%a <- function(@\n%a) {@\n%a@]@\n}" format_var
|
||||
(Mark.remove name)
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n,@;")
|
||||
(fun fmt (var, typ) ->
|
||||
Format.fprintf fmt "%a:%a" format_var (Mark.remove var) format_typ
|
||||
typ))
|
||||
Format.fprintf fmt "%a# (%a)@\n" format_var (Mark.remove var)
|
||||
format_typ typ))
|
||||
func_params (format_block ctx) func_body
|
||||
| SLocalDecl _ ->
|
||||
assert false (* We don't need to declare variables in Python *)
|
||||
| SLocalDef (v, e) ->
|
||||
Format.fprintf fmt "@[<hov 4>%a = %a@]" format_var (Mark.remove v)
|
||||
Format.fprintf fmt "@[<hov 2>%a <- %a@]" format_var (Mark.remove v)
|
||||
(format_expression ctx) e
|
||||
| STryExcept (try_b, except, catch_b) ->
|
||||
Format.fprintf fmt "@[<hov 4>try:@\n%a@]@\n@[<hov 4>except %a:@\n%a@]"
|
||||
(format_block ctx) try_b format_exception (except, Pos.no_pos)
|
||||
(format_block ctx) catch_b
|
||||
| STryExcept (_try_b, _except, _catch_b) ->
|
||||
Message.raise_internal_error
|
||||
"R needs TryExcept to be compiled as exceptions and not statements"
|
||||
| SRaise except ->
|
||||
Format.fprintf fmt "@[<hov 4>raise %a@]" format_exception
|
||||
(except, Mark.get s)
|
||||
@ -562,17 +574,18 @@ let format_program
|
||||
(format_ctx type_ordering) p.decl_ctx
|
||||
(Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt -> function
|
||||
| SVar { var; expr } ->
|
||||
Format.fprintf fmt "@[<hv 4>%a = (@,%a@,@])@," format_var var
|
||||
Format.fprintf fmt "@[<hv 2>%a <- (@,%a@,@])@," format_var var
|
||||
(format_expression p.decl_ctx)
|
||||
expr
|
||||
| SFunc { var; func }
|
||||
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
|
||||
let { Ast.func_params; Ast.func_body } = func in
|
||||
Format.fprintf fmt "@[<hv 4>def %a(%a):@\n%a@]@," format_func_name var
|
||||
Format.fprintf fmt "@[<hv 2>%a <- function(@\n%a) {@\n%a@]@\n}@,"
|
||||
format_func_name var
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n,@;")
|
||||
(fun fmt (var, typ) ->
|
||||
Format.fprintf fmt "%a:%a" format_var (Mark.remove var)
|
||||
Format.fprintf fmt "%a# (%a)@\n" format_var (Mark.remove var)
|
||||
format_typ typ))
|
||||
func_params (format_block p.decl_ctx) func_body))
|
||||
p.code_items
|
||||
|
Loading…
Reference in New Issue
Block a user