Tweaks in scalc

This makes the C runtime interface for date calculations work properly (although
we have placeholders instead of `dates_calc` at the moment)

Includes a few changes:
- use exprs instead of naked_exprs in statements (it's just easier to
  manipulate: passing a naked_expr when an expr is expected is annoying, while
  the opposite is trivial)
- add position to the `Add_dat_dur` operator, which can fail if no specific
  rounding mode is set (in the OCaml and C backends)
- inline closure calls a bit more in `closure_conversion`, for readability
This commit is contained in:
Louis Gesbert 2024-07-31 18:04:53 +02:00
parent 84651a33f2
commit 3f6d8bf358
12 changed files with 126 additions and 76 deletions

View File

@ -308,7 +308,7 @@ let rec transform_closures_expr :
let tys = List.map translate_type tys in
let pos = Expr.mark_pos m in
let env_arg_ty = TClosureEnv, Expr.pos new_e1 in
let fun_ty = TArrow (env_arg_ty :: tys, Expr.maybe_ty m), pos in
(* let fun_ty = TArrow (env_arg_ty :: tys, Expr.maybe_ty m), pos in *)
let code_env_var = Var.make "code_and_env" in
let code_env_expr =
let pos = Expr.pos e1 in
@ -321,8 +321,8 @@ let rec transform_closures_expr :
],
pos ))
in
let env_var = Var.make "env" in
let code_var = Var.make "code" in
(* let env_var = Var.make "env" in
* let code_var = Var.make "code" in *)
let free_vars, new_args =
List.fold_right
(fun arg (free_vars, new_args) ->
@ -331,19 +331,42 @@ let rec transform_closures_expr :
args (free_vars, [])
in
let call_expr =
let m1 = Mark.get new_e1 in
Expr.make_multiple_let_in [| code_var; env_var |] [fun_ty; env_arg_ty]
[
Expr.make_tupleaccess code_env_expr 0 2 pos;
Expr.make_tupleaccess code_env_expr 1 2 pos;
]
(Expr.make_app
(Bindlib.box_var code_var, Expr.with_ty m1 fun_ty)
((Bindlib.box_var env_var, Expr.with_ty m1 env_arg_ty) :: new_args)
(env_arg_ty
:: (* List.map (fun (_, m) -> Expr.maybe_ty m) new_args *) tys)
pos)
pos
(Expr.make_app
(Expr.make_tupleaccess code_env_expr 0 2 pos)
(Expr.make_tupleaccess code_env_expr 1 2 pos :: new_args)
(env_arg_ty
:: (* List.map (fun (_, m) -> Expr.maybe_ty m) new_args *) tys)
pos)
(* let m1 = Mark.get new_e1 in
* if Var.Map.is_empty free_vars then
* Expr.make_let_in [| code_var; env_var |] [fun_ty; env_arg_ty]
* [
* Expr.make_tupleaccess code_env_expr 0 2 pos;
* Expr.make_tupleaccess code_env_expr 1 2 pos;
* ]
* (Expr.make_app
* (Bindlib.box_var code_var, Expr.with_ty m1 fun_ty)
* ((Bindlib.box_var env_var, Expr.with_ty m1 env_arg_ty) :: new_args)
* (env_arg_ty
* :: (\* List.map (fun (_, m) -> Expr.maybe_ty m) new_args *\) tys)
* pos)
* pos
* else
* Expr.make_multiple_let_in [| code_var; env_var |] [fun_ty; env_arg_ty]
* [
* Expr.make_tupleaccess code_env_expr 0 2 pos;
* Expr.make_tupleaccess code_env_expr 1 2 pos;
* ]
* (Expr.make_app
* (Bindlib.box_var code_var, Expr.with_ty m1 fun_ty)
* ((Bindlib.box_var env_var, Expr.with_ty m1 env_arg_ty) :: new_args)
* (env_arg_ty
* :: (\* List.map (fun (_, m) -> Expr.maybe_ty m) new_args *\) tys)
* pos)
* pos *)
in
free_vars, Expr.make_let_in code_env_var (TAny, pos) new_e1 call_expr pos
| _ -> .

View File

@ -423,7 +423,7 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
Format.fprintf fmt "@[<hov 2>%s@ %t%a@]" (Operator.name op)
(fun ppf ->
match op with
| Map2 | Lt_dur_dur | Lte_dur_dur | Gt_dur_dur | Gte_dur_dur
| Map2 | Add_dat_dur _ | Lt_dur_dur | Lte_dur_dur | Gt_dur_dur | Gte_dur_dur
| Eq_dur_dur ->
Format.fprintf ppf "%a@ " format_pos pos
| Div_int_int | Div_rat_rat | Div_mon_mon | Div_mon_rat | Div_dur_dur ->

View File

@ -73,8 +73,8 @@ type stmt =
enum_name : EnumName.t;
switch_cases : switch_case list;
}
| SReturn of naked_expr
| SAssert of naked_expr
| SReturn of expr
| SAssert of expr
| SSpecialOp of special_operator
and special_operator =

View File

@ -295,7 +295,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
(* Assertions are always encapsulated in a unit-typed let binding *)
let e_stmts, new_e = translate_expr ctxt e in
RevBlock.rebuild
~tail:[A.SAssert (Mark.remove new_e), Expr.pos block_expr]
~tail:[A.SAssert new_e, Expr.pos block_expr]
e_stmts
| EFatalError err -> [SFatalError err, Expr.pos block_expr]
| EApp { f = EAbs { binder; tys }, binder_mark; args; _ } ->
@ -599,7 +599,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
let tail =
[
( (match ctxt.inside_definition_of with
| None -> A.SReturn (Mark.remove new_e)
| None -> A.SReturn new_e
| Some x ->
A.SLocalDef
{
@ -633,7 +633,7 @@ let rec translate_scope_body_expr
match scope_expr with
| Last e ->
let block, new_e = translate_expr ctx e in
RevBlock.rebuild block ~tail:[A.SReturn (Mark.remove new_e), Mark.get new_e]
RevBlock.rebuild block ~tail:[A.SReturn new_e, Mark.get new_e]
| Cons (scope_let, next_bnd) -> (
let let_var, scope_let_next = Bindlib.unbind next_bnd in
let let_var_id =
@ -758,7 +758,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
in
let body_block =
RevBlock.rebuild block
~tail:[A.SReturn (Mark.remove expr), Mark.get expr]
~tail:[A.SReturn expr, Mark.get expr]
in
( Var.Map.add var func_id func_dict,
var_dict,
@ -819,7 +819,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
A.func_params = [];
A.func_body =
RevBlock.rebuild block
~tail:[A.SReturn (Mark.remove expr), Mark.get expr];
~tail:[A.SReturn expr, Mark.get expr];
A.func_return_typ = topdef_ty;
};
}

View File

@ -152,11 +152,11 @@ let rec format_statement
| SReturn ret ->
Format.fprintf fmt "@[<hov 2>%a %a@]" Print.keyword "return"
(format_expr decl_ctx ~debug)
(ret, Mark.get stmt)
| SAssert naked_expr ->
ret
| SAssert expr ->
Format.fprintf fmt "@[<hov 2>%a %a@]" Print.keyword "assert"
(format_expr decl_ctx ~debug)
(naked_expr, Mark.get stmt)
expr
| SSwitch { switch_expr = e_switch; enum_name = enum; switch_cases = arms; _ }
->
let cons = EnumName.Map.find enum decl_ctx.ctx_enums in

View File

@ -285,6 +285,7 @@ let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit =
match Mark.remove op with
| Log (_entry, _infos) -> assert false
| FromClosureEnv | ToClosureEnv -> assert false
| Add_dat_dur _ -> assert false (* needs specific printing *)
| op -> Format.fprintf fmt "@{<blue;bold>%s@}" (Operator.name op)
let _format_string_list (fmt : Format.formatter) (uids : string list) : unit =
@ -297,7 +298,7 @@ let _format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
uids
(* TODO: move this to a shared place *)
(* TODO: move these to a shared place *)
let shallow_fold_expr f e acc =
let lfold x acc = List.fold_left (fun acc x -> f x acc) acc x in
match Mark.remove e with
@ -318,6 +319,24 @@ let shallow_fold_expr f e acc =
let rec fold_expr f e acc =
shallow_fold_expr (fold_expr f) e (f e acc)
(* Folds through direct expr childs, not subblocks *)
let fold_expr_stmt f st acc = match Mark.remove st with
| SInnerFuncDef _ | SLocalDecl _ | SFatalError _ -> acc
| SLocalInit { expr; _ } | SLocalDef { expr; _ } | SIfThenElse { if_expr = expr; _ }
| SSwitch { switch_expr = expr ; _ }
| SReturn expr | SAssert expr -> fold_expr f expr acc
| SSpecialOp _ -> .
let fold_expr_block f b acc =
List.fold_left (fun acc st -> fold_expr_stmt f st acc) acc b
(* These operators, since they can raise, have an added first argument giving the position of the error if it happens, so they need special treatment *)
let op_can_raise op =
let open Op in
match Mark.remove op with
| HandleExceptions | Div_int_int | Div_rat_rat | Div_mon_mon | Div_mon_rat | Div_dur_dur | Add_dat_dur _ | Gte_dur_dur | Gt_dur_dur | Lte_dur_dur | Lt_dur_dur -> true
| _ -> false
let rec format_expression (ctx : ctx) (fmt : Format.formatter) (e : expr) :
unit =
match Mark.remove e with
@ -344,17 +363,6 @@ let rec format_expression (ctx : ctx) (fmt : Format.formatter) (e : expr) :
Format.fprintf fmt "%a(%a,@ %a)"
format_op op (format_expression ctx) arg1
(format_expression ctx) arg2
| EAppOp {
op = (HandleExceptions | Div_int_int | Div_rat_rat | Div_mon_mon | Div_mon_rat | Div_dur_dur), pos as op;
args;
_ } ->
(* Operators that can raise, and take a position as argument *)
Format.fprintf fmt "%a(%a,@ %a)"
format_op op
format_var (Pos.Map.find pos ctx.lifted_pos)
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EAppOp {
op = (Reduce | Fold), _ as op;
args = [fct; base; arr];
@ -367,21 +375,35 @@ let rec format_expression (ctx : ctx) (fmt : Format.formatter) (e : expr) :
(format_expression ctx) fct
(format_expression ctx) base
(format_expression ctx) arr
| EAppOp { op; args = [arg1; arg2]; _ } ->
Format.fprintf fmt "%a(%a,@ %a)"
format_op op
| EAppOp { op = Add_dat_dur rounding, pos; args = [arg1; arg2]; _ } ->
Format.fprintf fmt "o_add_dat_dur(%s,@ %a,@ %a,@ %a)"
(match rounding with
| RoundUp -> "catala_date_round_up"
| RoundDown -> "catala_date_round_down"
| AbortOnRound -> "catala_date_round_abort")
format_var (Pos.Map.find pos ctx.lifted_pos)
(format_expression ctx) arg1
(format_expression ctx) arg2
| EAppOp { op; args = [arg1]; _ } ->
Format.fprintf fmt "%a(%a)" format_op op (format_expression ctx) arg1
(* | EAppOp { op; args = [arg1; arg2]; _ } ->
* Format.fprintf fmt "%a(%a,@ %a)"
* format_op op
* (format_expression ctx) arg1
* (format_expression ctx) arg2
* | EAppOp { op; args = [arg1]; _ } ->
* Format.fprintf fmt "%a(%a)" format_op op (format_expression ctx) arg1 *)
| EApp { f; args } ->
Format.fprintf fmt "%a(@[<hov 0>%a)@]" (format_expression ctx) f
Format.fprintf fmt "@[<hov 2>%a@,(@[<hov 0>%a)@]@]" (format_expression ctx) f
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EAppOp { op; args; _ } ->
Format.fprintf fmt "%a(@[<hov 0>%a)@]" format_op op
| EAppOp { op = _, pos as op; args; _ } ->
Format.fprintf fmt "%a(@[<hov 0>%t%a)@]"
format_op op
(fun ppf ->
if op_can_raise op then
Format.fprintf ppf "%a,@ "
format_var (Pos.Map.find pos ctx.lifted_pos))
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
@ -389,7 +411,7 @@ let rec format_expression (ctx : ctx) (fmt : Format.formatter) (e : expr) :
| ETuple _ -> assert false (* Must be a statement *)
| ETupleAccess {e1; index=0; typ=TArrow _, _ as typ} ->
(* Closure function *)
Format.fprintf fmt "(%a)%a->funcp"
Format.fprintf fmt "@[<hov 1>((%a)@,%a->funcp)@]"
(format_typ ~const:true ctx.decl_ctx ignore) typ
(format_expression ctx) e1
| ETupleAccess {e1; index=1; typ=TClosureEnv, _} ->
@ -592,14 +614,13 @@ let rec format_statement
| SSwitch _ -> assert false
(* switches should have been rewritten to only match on variables *)
| SReturn e1 ->
Format.fprintf fmt "@,@[<hov 2>return %a;@]" (format_expression ctx)
(e1, Mark.get s)
Format.fprintf fmt "@,@[<hov 2>return %a;@]" (format_expression ctx) e1
| SAssert e1 ->
Format.fprintf fmt
"@,@[<v 2>@[<hov 2>if (%a != CATALA_TRUE) {@]\
@,@[<hov 2>catala_error(catala_assertion_failed,@ %a);@]\
@;<1 -2>}@]" (format_expression ctx)
(e1, Mark.get s)
e1
format_var (Pos.Map.find (Mark.get s) ctx.lifted_pos)
| _ -> .
(* | SSpecialOp (OHandleDefaultOpt { exceptions; just; cons; return_typ }) ->
@ -683,19 +704,13 @@ let rec format_statement
and format_block (ctx : ctx) (fmt : Format.formatter) (b : block) : unit =
let new_pos =
List.fold_left (fun pmap -> function
| (SAssert _ | SFatalError _), pos ->
let v = VarName.fresh ("pos",pos) in
Pos.Map.add pos v pmap
| (SLocalInit { expr; _} | SLocalDef { expr; _ }), _ ->
fold_expr (fun e pmap -> match e with
| EAppOp { op = (HandleExceptions | Div_int_int | Div_rat_rat | Div_mon_mon | Div_mon_rat | Div_dur_dur), pos; _ }, _ ->
let v = VarName.fresh ("pos",pos) in
Pos.Map.add pos v pmap
| _ -> pmap)
expr pmap
| _ -> pmap)
Pos.Map.empty b
fold_expr_block
(fun e pmap -> match e with
| EAppOp { op = _, pos as op; _ }, _ when op_can_raise op ->
let v = VarName.fresh ("pos", pos) in
Pos.Map.add pos v pmap
| _ -> pmap)
b Pos.Map.empty
in
let new_pos =
Pos.Map.merge (fun _ v1 v2 -> match v1 with None -> v2 | _ -> None)
@ -794,7 +809,7 @@ let format_main
Format.fprintf fmt "@,printf(\"Executing scope %a...\\n\");"
ScopeName.format name;
Format.fprintf fmt "@,%a (NULL);" format_func_name var;
Format.fprintf fmt "@,printf(\"Scope %a executed successfully.\\n\");"
Format.fprintf fmt "@,printf(\"\\x1b[32m[RESULT]\\x1b[m Scope %a executed successfully.\\n\");"
ScopeName.format name)
scopes_with_no_input;
Format.fprintf fmt "@,return 0;@;<1 -2>}@]"
@ -849,4 +864,5 @@ let format_program
p.code_items;
Format.pp_print_cut fmt ();
format_main fmt p;
Format.pp_close_box fmt ()
Format.pp_close_box fmt ();
Format.pp_print_newline fmt ()

View File

@ -447,7 +447,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
cases
| SReturn e1 ->
Format.fprintf fmt "@[<hov 4>return %a@]" (format_expression ctx)
(e1, Mark.get s)
e1
| SAssert e1 ->
let pos = Mark.get s in
Format.fprintf fmt
@ -456,7 +456,7 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \
law_headings=@[<hv>%a@])@])@]@]"
(format_expression ctx)
(e1, Mark.get s)
e1
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos)

View File

@ -337,7 +337,7 @@ let evaluate_operator
| Add_mon_mon, [(ELit (LMoney x), _); (ELit (LMoney y), _)] ->
ELit (LMoney (o_add_mon_mon x y))
| Add_dat_dur r, [(ELit (LDate x), _); (ELit (LDuration y), _)] ->
ELit (LDate (o_add_dat_dur r x y))
ELit (LDate (o_add_dat_dur r (rpos()) x y))
| Add_dur_dur, [(ELit (LDuration x), _); (ELit (LDuration y), _)] ->
ELit (LDuration (o_add_dur_dur x y))
| Sub_int_int, [(ELit (LInt x), _); (ELit (LInt y), _)] ->

View File

@ -123,8 +123,9 @@ let date_of_year (year : int) = Runtime.date_of_numbers year 1 1
(** Returns the date (as a string) corresponding to nb days after the base day,
defined here as Jan 1, 1900 **)
let nb_days_to_date (nb : int) : string =
let dummy_pos = { Runtime.filename = ""; start_line = 0; start_column = 0; end_line = 0; end_column = 0; law_headings = [] } in
Runtime.date_to_string
(Runtime.Oper.o_add_dat_dur AbortOnRound base_day
(Runtime.Oper.o_add_dat_dur AbortOnRound dummy_pos base_day
(Runtime.duration_of_numbers 0 0 nb))
(** [print_z3model_expr] pretty-prints the value [e] given by a Z3 model

View File

@ -444,13 +444,14 @@ typedef enum catala_date_rounding
} catala_date_rounding;
CATALA_DATE o_add_dat_dur (catala_date_rounding mode,
CATALA_DATE x1,
CATALA_DURATION x2)
const catala_code_position* pos,
CATALA_DATE x1,
CATALA_DURATION x2)
{
/* TODO */
return catala_new_date(x1->year + x2->years,
x1->month + x2->months,
x1->day + x2->days);
x1->month + x2->months,
x1->day + x2->days);
}
CATALA_DURATION o_add_dur_dur (CATALA_DURATION x1, CATALA_DURATION x2)

View File

@ -52,6 +52,7 @@ type error =
| DivisionByZero
| NotSameLength
| UncomparableDurations
| AmbiguousDateRounding
| IndivisibleDurations
let error_to_string = function
@ -61,6 +62,7 @@ let error_to_string = function
| DivisionByZero -> "DivisionByZero"
| NotSameLength -> "NotSameLength"
| UncomparableDurations -> "UncomparableDurations"
| AmbiguousDateRounding -> "AmbiguousDateRounding"
| IndivisibleDurations -> "IndivisibleDurations"
let error_message = function
@ -75,6 +77,8 @@ let error_message = function
| UncomparableDurations ->
"ambiguous comparison between durations in different units (e.g. months \
vs. days)"
| AmbiguousDateRounding ->
"ambiguous date computation, and rounding mode was not specified"
| IndivisibleDurations -> "dividing durations that are not in days"
exception Error of error * source_position list
@ -791,7 +795,10 @@ module Oper = struct
let o_add_int_int i1 i2 = Z.add i1 i2
let o_add_rat_rat i1 i2 = Q.add i1 i2
let o_add_mon_mon m1 m2 = Z.add m1 m2
let o_add_dat_dur r da du = Dates_calc.Dates.add_dates ~round:r da du
let o_add_dat_dur r pos da du =
try Dates_calc.Dates.add_dates ~round:r da du
with Dates_calc.Dates.AmbiguousComputation ->
error AmbiguousDateRounding [pos]
let o_add_dur_dur = Dates_calc.Dates.add_periods
let o_sub_int_int i1 i2 = Z.sub i1 i2
let o_sub_rat_rat i1 i2 = Q.sub i1 i2

View File

@ -77,6 +77,8 @@ type error =
| NotSameLength (** Traversing multiple lists of different lengths *)
| UncomparableDurations
(** Comparing durations in different units (e.g. months vs. days) *)
| AmbiguousDateRounding
(** ambiguous date computation, and rounding mode was not specified *)
| IndivisibleDurations (** Dividing durations that are not in days *)
val error_to_string : error -> string
@ -376,7 +378,7 @@ module Oper : sig
val o_add_int_int : integer -> integer -> integer
val o_add_rat_rat : decimal -> decimal -> decimal
val o_add_mon_mon : money -> money -> money
val o_add_dat_dur : date_rounding -> date -> duration -> date
val o_add_dat_dur : date_rounding -> source_position -> date -> duration -> date
val o_add_dur_dur : duration -> duration -> duration
val o_sub_int_int : integer -> integer -> integer
val o_sub_rat_rat : decimal -> decimal -> decimal