Typing across closure conversion (#627)

This commit is contained in:
Denis Merigoux 2024-06-03 09:39:17 +02:00 committed by GitHub
commit a0c982a6c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 230 additions and 249 deletions

View File

@ -24,7 +24,31 @@ type 'm ctx = {
globally_bound_vars : ('m expr, typ) Var.Map.t; globally_bound_vars : ('m expr, typ) Var.Map.t;
} }
let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys (** Function types will be transformed in this way throughout, including in
[decl_ctx] *)
let rec translate_type t =
let pos = Mark.get t in
match Mark.remove t with
| TArrow (t1, t2) ->
( TTuple
[
( TArrow
( (TClosureEnv, Pos.no_pos) :: List.map translate_type t1,
translate_type t2 ),
Pos.no_pos );
TClosureEnv, Pos.no_pos;
],
pos )
| TDefault t' -> TDefault (translate_type t'), pos
| TOption t' -> TOption (translate_type t'), pos
| TAny | TClosureEnv | TLit _ | TEnum _ | TStruct _ -> t
| TArray ts -> TArray (translate_type ts), pos
| TTuple ts -> TTuple (List.map translate_type ts), pos
let translate_mark e = Mark.map_mark (Expr.map_ty translate_type) e
let join_vars : ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t =
fun m1 m2 -> Var.Map.union (fun _ a _ -> Some a) m1 m2
(** {1 Transforming closures}*) (** {1 Transforming closures}*)
@ -33,19 +57,20 @@ let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys
http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=10 http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=10
(environment-passing closure conversion). *) (environment-passing closure conversion). *)
let rec transform_closures_expr : let rec transform_closures_expr :
type m. m ctx -> m expr -> m expr Var.Set.t * m expr boxed = type m. m ctx -> m expr -> (m expr, m mark) Var.Map.t * m expr boxed =
fun ctx e -> fun ctx e ->
let e = translate_mark e in
let m = Mark.get e in let m = Mark.get e in
match Mark.remove e with match Mark.remove e with
| EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EArray _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EArray _
| ELit _ | EExternal _ | EAssert _ | EFatalError _ | EIfThenElse _ | ELit _ | EExternal _ | EAssert _ | EFatalError _ | EIfThenElse _
| ERaiseEmpty | ECatchEmpty _ -> | ERaiseEmpty | ECatchEmpty _ ->
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union Expr.map_gather ~acc:Var.Map.empty ~join:join_vars
~f:(transform_closures_expr ctx) ~f:(transform_closures_expr ctx)
e e
| EVar v -> ( | EVar v -> (
match Var.Map.find_opt v ctx.globally_bound_vars with match Var.Map.find_opt v ctx.globally_bound_vars with
| None -> Var.Set.singleton v, (Bindlib.box_var v, m) | None -> Var.Map.singleton v m, (Bindlib.box_var v, m)
| Some (TArrow (targs, tret), _) -> | Some (TArrow (targs, tret), _) ->
(* Here we eta-expand the argument to make sure function pointers are (* Here we eta-expand the argument to make sure function pointers are
correctly casted as closures *) correctly casted as closures *)
@ -69,13 +94,13 @@ let rec transform_closures_expr :
{ {
ctx with ctx with
globally_bound_vars = globally_bound_vars =
Var.Map.add v (TAny, Pos.no_pos) ctx.globally_bound_vars; Var.Map.add v (Expr.maybe_ty m) ctx.globally_bound_vars;
} }
in in
Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e) Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e)
in in
Bindlib.unbox boxed Bindlib.unbox boxed
| Some _ -> Var.Set.empty, (Bindlib.box_var v, m)) | Some _ -> Var.Map.empty, (Bindlib.box_var v, m))
| EMatch { e; cases; name } -> | EMatch { e; cases; name } ->
let free_vars, new_e = (transform_closures_expr ctx) e in let free_vars, new_e = (transform_closures_expr ctx) e in
(* We do not close the clotures inside the arms of the match expression, (* We do not close the clotures inside the arms of the match expression,
@ -89,13 +114,11 @@ let rec transform_closures_expr :
let new_free_vars, new_body = (transform_closures_expr ctx) body in let new_free_vars, new_body = (transform_closures_expr ctx) body in
let new_free_vars = let new_free_vars =
Array.fold_left Array.fold_left
(fun acc v -> Var.Set.remove v acc) (fun acc v -> Var.Map.remove v acc)
new_free_vars vars new_free_vars vars
in in
let new_binder = Expr.bind vars new_body in let new_binder = Expr.bind vars new_body in
( Var.Set.union free_vars ( join_vars free_vars new_free_vars,
(Var.Set.diff new_free_vars
(Var.Set.of_list (Array.to_list vars))),
EnumConstructor.Map.add cons EnumConstructor.Map.add cons
(Expr.eabs new_binder tys (Mark.get e1)) (Expr.eabs new_binder tys (Mark.get e1))
new_cases ) new_cases )
@ -109,54 +132,58 @@ let rec transform_closures_expr :
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let free_vars, new_body = (transform_closures_expr ctx) body in let free_vars, new_body = (transform_closures_expr ctx) body in
let free_vars = let free_vars =
Array.fold_left (fun acc v -> Var.Set.remove v acc) free_vars vars Array.fold_left (fun acc v -> Var.Map.remove v acc) free_vars vars
in in
let new_binder = Expr.bind vars new_body in let new_binder = Expr.bind vars new_body in
let free_vars, new_args = let free_vars, new_args =
List.fold_right List.fold_right
(fun arg (free_vars, new_args) -> (fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args) join_vars free_vars new_free_vars, new_arg :: new_args)
args (free_vars, []) args (free_vars, [])
in in
( free_vars, ( free_vars,
Expr.eapp Expr.eapp
~f:(Expr.eabs new_binder (tys_as_tanys tys) e1_pos) ~f:(Expr.eabs new_binder (List.map translate_type tys) e1_pos)
~args:new_args ~tys m ) ~args:new_args ~tys m )
| EAbs { binder; tys } -> | EAbs { binder; tys } ->
(* λ x.t *) (* λ x.t *)
let binder_mark = Expr.with_ty m (TAny, Expr.mark_pos m) in let binder_pos = Expr.mark_pos m in
let binder_pos = Expr.mark_pos binder_mark in let mark_ty ty = Expr.with_ty m ty in
(* Converting the closure. *) (* Converting the closure. *)
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
(* t *) (* t *)
let body_vars, new_body = (transform_closures_expr ctx) body in let body_vars, new_body = (transform_closures_expr ctx) body in
(* [[t]] *) (* [[t]] *)
let extra_vars = let extra_vars =
Var.Set.diff body_vars (Var.Set.of_list (Array.to_list vars)) Array.fold_left (fun m v -> Var.Map.remove v m) body_vars vars
in
let extra_vars_list = Var.Map.bindings extra_vars in
let extra_vars_types =
List.map (fun (_, m) -> Expr.maybe_ty m) extra_vars_list
in in
let extra_vars_list = Var.Set.elements extra_vars in
(* x1, ..., xn *) (* x1, ..., xn *)
let code_var = Var.make ctx.name_context in let code_var = Var.make ctx.name_context in
(* code *) (* code *)
let closure_env_arg_var = Var.make "env" in let closure_env_arg_var = Var.make "env" in
let closure_env_var = Var.make "env" in let closure_env_var = Var.make "env" in
let any_ty = TAny, binder_pos in let env_ty = TTuple extra_vars_types, binder_pos in
(* let env = from_closure_env env in let arg0 = env.0 in ... *) (* let env = from_closure_env env in let arg0 = env.0 in ... *)
let new_closure_body = let new_closure_body =
Expr.make_let_in closure_env_var any_ty Expr.make_let_in closure_env_var env_ty
(Expr.eappop (Expr.eappop
~op:(Operator.FromClosureEnv, binder_pos) ~op:(Operator.FromClosureEnv, binder_pos)
~tys:[TClosureEnv, binder_pos] ~tys:[TClosureEnv, binder_pos]
~args:[Expr.evar closure_env_arg_var binder_mark] ~args:
binder_mark) [Expr.evar closure_env_arg_var (mark_ty (TClosureEnv, binder_pos))]
(mark_ty env_ty))
(Expr.make_multiple_let_in (Expr.make_multiple_let_in
(Array.of_list extra_vars_list) (Array.of_list (List.map fst extra_vars_list))
(List.map (fun _ -> any_ty) extra_vars_list) extra_vars_types
(List.mapi (List.mapi
(fun i _ -> (fun i _ ->
Expr.make_tupleaccess Expr.make_tupleaccess
(Expr.evar closure_env_var binder_mark) (Expr.evar closure_env_var (mark_ty env_ty))
i i
(List.length extra_vars_list) (List.length extra_vars_list)
binder_pos) binder_pos)
@ -167,33 +194,39 @@ let rec transform_closures_expr :
(* fun env arg0 ... -> new_closure_body *) (* fun env arg0 ... -> new_closure_body *)
let new_closure = let new_closure =
Expr.make_abs Expr.make_abs
(Array.concat [Array.make 1 closure_env_arg_var; vars]) (Array.append [| closure_env_arg_var |] vars)
new_closure_body new_closure_body
((TClosureEnv, binder_pos) :: tys) ((TClosureEnv, binder_pos) :: tys)
(Expr.pos e) (Expr.pos e)
in in
let new_closure_ty = Expr.maybe_ty (Mark.get new_closure) in
( extra_vars, ( extra_vars,
Expr.make_let_in code_var Expr.make_let_in code_var new_closure_ty new_closure
(TAny, Expr.pos e)
new_closure
(Expr.make_tuple (Expr.make_tuple
((Bindlib.box_var code_var, binder_mark) ((Bindlib.box_var code_var, mark_ty new_closure_ty)
:: [ :: [
Expr.eappop Expr.eappop
~op:(Operator.ToClosureEnv, binder_pos) ~op:(Operator.ToClosureEnv, binder_pos)
~tys:[TAny, Expr.pos e] ~tys:
[
( (if extra_vars_list = [] then TLit TUnit
else TTuple extra_vars_types),
binder_pos );
]
~args: ~args:
[ [
(if extra_vars_list = [] then Expr.elit LUnit binder_mark (if extra_vars_list = [] then
Expr.elit LUnit (mark_ty (TLit TUnit, binder_pos))
else else
Expr.etuple Expr.etuple
(List.map (List.map
(fun extra_var -> (fun (extra_var, m) ->
Bindlib.box_var extra_var, binder_mark) ( Bindlib.box_var extra_var,
Expr.with_pos binder_pos m ))
extra_vars_list) extra_vars_list)
m); (mark_ty (TTuple extra_vars_types, binder_pos)));
] ]
(Mark.get e); (mark_ty (TClosureEnv, binder_pos));
]) ])
m) m)
(Expr.pos e) ) (Expr.pos e) )
@ -219,16 +252,16 @@ let rec transform_closures_expr :
let new_arg = let new_arg =
Expr.make_abs vars new_arg tys (Expr.mark_pos m_arg) Expr.make_abs vars new_arg tys (Expr.mark_pos m_arg)
in in
Var.Set.union free_vars new_free_vars, new_arg :: new_args join_vars free_vars new_free_vars, new_arg :: new_args
| _ -> | _ ->
let new_free_vars, new_arg = transform_closures_expr ctx arg in let new_free_vars, new_arg = transform_closures_expr ctx arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args) join_vars free_vars new_free_vars, new_arg :: new_args)
args (Var.Set.empty, []) args (Var.Map.empty, [])
in in
free_vars, Expr.eappop ~op ~tys ~args:new_args (Mark.get e) free_vars, Expr.eappop ~op ~tys ~args:new_args (Mark.get e)
| EAppOp _ -> | EAppOp _ ->
(* This corresponds to an operator call, which we don't want to transform *) (* This corresponds to an operator call, which we don't want to transform *)
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union Expr.map_gather ~acc:Var.Map.empty ~join:join_vars
~f:(transform_closures_expr ctx) ~f:(transform_closures_expr ctx)
e e
| EApp { f = EVar v, f_m; args; tys } | EApp { f = EVar v, f_m; args; tys }
@ -239,12 +272,16 @@ let rec transform_closures_expr :
List.fold_right List.fold_right
(fun arg (free_vars, new_args) -> (fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args) join_vars free_vars new_free_vars, new_arg :: new_args)
args (Var.Set.empty, []) args (Var.Map.empty, [])
in in
free_vars, Expr.eapp ~f:(Expr.evar v f_m) ~args:new_args ~tys m free_vars, Expr.eapp ~f:(Expr.evar v f_m) ~args:new_args ~tys m
| EApp { f = e1; args; tys } -> | EApp { f = e1; args; tys } ->
let free_vars, new_e1 = (transform_closures_expr ctx) e1 in let free_vars, new_e1 = (transform_closures_expr ctx) e1 in
let tys = List.map translate_type tys in
let pos = Expr.mark_pos m in
let env_arg_ty = TClosureEnv, Expr.pos new_e1 in
let fun_ty = TArrow (env_arg_ty :: tys, Expr.maybe_ty m), pos in
let code_env_var = Var.make "code_and_env" in let code_env_var = Var.make "code_and_env" in
let code_env_expr = let code_env_expr =
let pos = Expr.pos e1 in let pos = Expr.pos e1 in
@ -252,8 +289,7 @@ let rec transform_closures_expr :
(Expr.with_ty (Mark.get e1) (Expr.with_ty (Mark.get e1)
( TTuple ( TTuple
[ [
( TArrow ((TClosureEnv, pos) :: tys, (TAny, Expr.pos e)), TArrow ((TClosureEnv, pos) :: tys, Expr.maybe_ty m), Expr.pos e;
Expr.pos e );
TClosureEnv, pos; TClosureEnv, pos;
], ],
pos )) pos ))
@ -264,24 +300,23 @@ let rec transform_closures_expr :
List.fold_right List.fold_right
(fun arg (free_vars, new_args) -> (fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args) join_vars free_vars new_free_vars, new_arg :: new_args)
args (free_vars, []) args (free_vars, [])
in in
let call_expr = let call_expr =
let m1 = Mark.get e1 in let m1 = Mark.get new_e1 in
let pos = Expr.mark_pos m in
let env_arg_ty = TClosureEnv, Expr.pos e1 in
let fun_ty = TArrow (env_arg_ty :: tys, (TAny, Expr.pos e)), Expr.pos e in
Expr.make_multiple_let_in [| code_var; env_var |] [fun_ty; env_arg_ty] Expr.make_multiple_let_in [| code_var; env_var |] [fun_ty; env_arg_ty]
[ [
Expr.make_tupleaccess code_env_expr 0 2 pos; Expr.make_tupleaccess code_env_expr 0 2 pos;
Expr.make_tupleaccess code_env_expr 1 2 pos; Expr.make_tupleaccess code_env_expr 1 2 pos;
] ]
(Expr.eapp (Expr.make_app
~f:(Bindlib.box_var code_var, m1) (Bindlib.box_var code_var, Expr.with_ty m1 fun_ty)
~args:((Bindlib.box_var env_var, m1) :: new_args) ((Bindlib.box_var env_var, Expr.with_ty m1 env_arg_ty) :: new_args)
~tys:(env_arg_ty :: tys) m) (env_arg_ty
(Expr.pos e) :: (* List.map (fun (_, m) -> Expr.maybe_ty m) new_args *) tys)
pos)
pos
in in
( free_vars, ( free_vars,
Expr.make_let_in code_env_var Expr.make_let_in code_env_var
@ -393,33 +428,15 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
capture footprint. See capture footprint. See
[tests/tests_func/good/scope_call_func_struct_closure.catala_en]. *) [tests/tests_func/good/scope_call_func_struct_closure.catala_en]. *)
let new_decl_ctx = let new_decl_ctx =
let rec replace_fun_typs t =
match Mark.remove t with
| TArrow (t1, t2) ->
( TTuple
[
( TArrow
( (TClosureEnv, Pos.no_pos) :: List.map replace_fun_typs t1,
replace_fun_typs t2 ),
Pos.no_pos );
TClosureEnv, Pos.no_pos;
],
Mark.get t )
| TDefault t' -> TDefault (replace_fun_typs t'), Mark.get t
| TOption t' -> TOption (replace_fun_typs t'), Mark.get t
| TAny | TClosureEnv | TLit _ | TEnum _ | TStruct _ -> t
| TArray ts -> TArray (replace_fun_typs ts), Mark.get t
| TTuple ts -> TTuple (List.map replace_fun_typs ts), Mark.get t
in
{ {
p.decl_ctx with p.decl_ctx with
ctx_structs = ctx_structs =
StructName.Map.map StructName.Map.map
(StructField.Map.map replace_fun_typs) (StructField.Map.map translate_type)
p.decl_ctx.ctx_structs; p.decl_ctx.ctx_structs;
ctx_enums = ctx_enums =
EnumName.Map.map EnumName.Map.map
(EnumConstructor.Map.map replace_fun_typs) (EnumConstructor.Map.map translate_type)
p.decl_ctx.ctx_enums; p.decl_ctx.ctx_enums;
(* Toplevel definitions may not contain scope calls or take functions as (* Toplevel definitions may not contain scope calls or take functions as
arguments at the moment, which ensures that their interfaces aren't arguments at the moment, which ensures that their interfaces aren't
@ -489,9 +506,7 @@ let rec hoist_closures_expr :
args (collected_closures, []) args (collected_closures, [])
in in
( collected_closures, ( collected_closures,
Expr.eapp Expr.eapp ~f:(Expr.eabs new_binder tys e1_pos) ~args:new_args ~tys m )
~f:(Expr.eabs new_binder (tys_as_tanys tys) e1_pos)
~args:new_args ~tys m )
| EAppOp | EAppOp
{ {
op = ((HandleDefaultOpt | Fold | Map | Filter | Reduce), _) as op; op = ((HandleDefaultOpt | Fold | Map | Filter | Reduce), _) as op;
@ -525,20 +540,16 @@ let rec hoist_closures_expr :
in in
collected_closures, Expr.eappop ~op ~args:new_args ~tys (Mark.get e) collected_closures, Expr.eappop ~op ~args:new_args ~tys (Mark.get e)
| EAbs { tys; _ } -> | EAbs { tys; _ } ->
(* this is the closure we want to hoist*) (* this is the closure we want to hoist *)
let closure_var = Var.make ("closure_" ^ name_context) in let closure_var = Var.make ("closure_" ^ name_context) in
(* TODO: This will end up as a toplevel name. However for now we assume (* TODO: This will end up as a toplevel name. However for now we assume
toplevel names are unique, but this breaks this assertions and can lead toplevel names are unique, but this breaks this assertions and can lead
to name wrangling in the backends. We need to have a better system for to name wrangling in the backends. We need to have a better system for
name disambiguation when for instance printing to Dcalc/Lcalc/Scalc but name disambiguation when for instance printing to Dcalc/Lcalc/Scalc but
also OCaml, Python, etc. *) also OCaml, Python, etc. *)
( [ let pos = Expr.mark_pos m in
{ let ty = Expr.maybe_ty ~typ:(TArrow (tys, (TAny, pos))) m in
name = closure_var; ( [{ name = closure_var; ty; closure = Expr.rebox e }],
ty = TArrow (tys, (TAny, Expr.mark_pos m)), Expr.mark_pos m;
closure = Expr.rebox e;
};
],
Expr.make_var closure_var m ) Expr.make_var closure_var m )
| EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
| EArray _ | ELit _ | EAssert _ | EFatalError _ | EAppOp _ | EIfThenElse _ | EArray _ | ELit _ | EAssert _ | EFatalError _ | EAppOp _ | EIfThenElse _
@ -660,9 +671,9 @@ let hoist_closures_program (p : 'm program) : 'm program Bindlib.box =
(** {1 Closure conversion}*) (** {1 Closure conversion}*)
let closure_conversion (p : 'm program) : untyped program = let closure_conversion (p : 'm program) : 'm program =
let new_p = transform_closures_program p in let new_p = transform_closures_program p in
let new_p = hoist_closures_program (Bindlib.unbox new_p) in let new_p = hoist_closures_program (Bindlib.unbox new_p) in
(* FIXME: either fix the types of the marks, or remove the types annotations (* FIXME: either fix the types of the marks, or remove the types annotations
during the main processing (rather than requiring a new traversal) *) during the main processing (rather than requiring a new traversal) *)
Program.untype (Bindlib.unbox new_p) Bindlib.unbox new_p

View File

@ -21,4 +21,4 @@
After closure conversion, closure hoisting is perform and all closures end After closure conversion, closure hoisting is perform and all closures end
up as toplevel definitions. *) up as toplevel definitions. *)
val closure_conversion : 'm Ast.program -> Shared_ast.untyped Ast.program val closure_conversion : 'm Ast.program -> 'm Ast.program

View File

@ -350,6 +350,8 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
failwith failwith
"should not happen, array initialization is caught at the statement level" "should not happen, array initialization is caught at the statement level"
| ELit l -> Format.fprintf fmt "%a" format_lit (Mark.copy e l) | ELit l -> Format.fprintf fmt "%a" format_lit (Mark.copy e l)
| EAppOp { op = (ToClosureEnv | FromClosureEnv), _; args = [arg] } ->
format_expression ctx fmt arg
| EAppOp { op = ((Map | Filter), _) as op; args = [arg1; arg2] } -> | EAppOp { op = ((Map | Filter), _) as op; args = [arg1; arg2] } ->
Format.fprintf fmt "%a(%a,@ %a)" format_op op (format_expression ctx) arg1 Format.fprintf fmt "%a(%a,@ %a)" format_op op (format_expression ctx) arg1
(format_expression ctx) arg2 (format_expression ctx) arg2
@ -441,19 +443,14 @@ let rec format_statement
| SFatalError err -> | SFatalError err ->
let pos = Mark.get s in let pos = Mark.get s in
Format.fprintf fmt Format.fprintf fmt
"catala_fatal_error_raised.code = catala_%s;@,\ "@[<hov 2>catala_raise_fatal_error (catala_%s,@ \"%s\",@ %d, %d, %d, \
catala_fatal_error_raised.position.filename = \"%s\";@,\ %d);@]"
catala_fatal_error_raised.position.start_line = %d;@,\
catala_fatal_error_raised.position.start_column = %d;@,\
catala_fatal_error_raised.position.end_line = %d;@,\
catala_fatal_error_raised.position.end_column = %d;@,\
longjmp(catala_fatal_error_jump_buffer, 0);"
(String.to_snake_case (Runtime.error_to_string err)) (String.to_snake_case (Runtime.error_to_string err))
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos)
| SIfThenElse { if_expr = cond; then_block = b1; else_block = b2 } -> | SIfThenElse { if_expr = cond; then_block = b1; else_block = b2 } ->
Format.fprintf fmt Format.fprintf fmt
"@[<hov 2>if (%a) {@\n%a@]@\n@[<hov 2>} else {@\n%a@]@\n}" "@[<hv 2>@[<hov 2>if (%a) {@]@,%a@,@;<1 -2>} else {@,%a@,@;<1 -2>}@]"
(format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2 (format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } -> | SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
let cases = let cases =
@ -463,34 +460,33 @@ let rec format_statement
(EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums)) (EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums))
in in
let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in
Format.fprintf fmt "@[<hov 2>%a %a = %a;@]@\n@[<hov 2>if %a@]@\n}" Format.fprintf fmt "@[<hov 2>%a %a = %a;@]@," format_enum_name e_name
format_enum_name e_name format_var tmp_var (format_expression ctx) e1 format_var tmp_var (format_expression ctx) e1;
(Format.pp_print_list Format.pp_open_vbox fmt 2;
~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 2>} else if ") Format.fprintf fmt "@[<hov 4>switch (%a.code) {@]@," format_var tmp_var;
(fun fmt ({ case_block; payload_var_name; payload_var_typ }, cons_name) -> Format.pp_print_list
Format.fprintf fmt "(%a.code == %a_%a) {@\n%a = %a.payload.%a;@\n%a" (fun fmt ({ case_block; payload_var_name; payload_var_typ }, cons_name) ->
format_var tmp_var format_enum_name e_name format_enum_cons_name Format.fprintf fmt "@[<hv 2>case %a_%a:@ " format_enum_name e_name
cons_name format_enum_cons_name cons_name;
(format_typ ctx (fun fmt -> format_var fmt payload_var_name)) if not (Type.equal payload_var_typ (TLit TUnit, Pos.no_pos)) then
payload_var_typ format_var tmp_var format_enum_cons_name cons_name Format.fprintf fmt "%a = %a.payload.%a;@ "
(format_block ctx) case_block)) (format_typ ctx (fun fmt -> format_var fmt payload_var_name))
cases payload_var_typ format_var tmp_var format_enum_cons_name cons_name;
Format.fprintf fmt "%a@ break;@]" (format_block ctx) case_block)
fmt cases;
(* Do we want to add 'default' case with a failure ? *)
Format.fprintf fmt "@;<0 -2>}";
Format.pp_close_box fmt ()
| SReturn e1 -> | SReturn e1 ->
Format.fprintf fmt "@[<hov 2>return %a;@]" (format_expression ctx) Format.fprintf fmt "@[<hov 2>return %a;@]" (format_expression ctx)
(e1, Mark.get s) (e1, Mark.get s)
| SAssert e1 -> | SAssert e1 ->
let pos = Mark.get s in let pos = Mark.get s in
Format.fprintf fmt Format.fprintf fmt
"@[<hov 2>if (!(%a)) {@\n\ "@[<v 2>@[<hov 2>if (!(%a)) {@]@,\
catala_fatal_error_raised.code = catala_assertion_failure;@,\ @[<hov 2>catala_raise_fatal_error (catala_assertion_failed,@ \"%s\",@ \
catala_fatal_error_raised.position.filename = \"%s\";@,\ %d, %d, %d, %d);@]@;\
catala_fatal_error_raised.position.start_line = %d;@,\ <1 -2>}@]" (format_expression ctx)
catala_fatal_error_raised.position.start_column = %d;@,\
catala_fatal_error_raised.position.end_line = %d;@,\
catala_fatal_error_raised.position.end_column = %d;@,\
longjmp(catala_fatal_error_jump_buffer, 0);@,\
}"
(format_expression ctx)
(e1, Mark.get s) (e1, Mark.get s)
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos)
@ -548,14 +544,9 @@ let rec format_statement
exceptions; exceptions;
Format.fprintf fmt Format.fprintf fmt
"@[<v 2>if (%a) {@,\ "@[<v 2>if (%a) {@,\
catala_fatal_error_raised.code = catala_conflict;@,\ @[<hov 2>catala_raise_fatal_error(catala_conflict,@ \"%s\",@ %d, %d, \
catala_fatal_error_raised.position.filename = \"%s\";@,\ %d, %d);@]@;\
catala_fatal_error_raised.position.start_line = %d;@,\ <1 -2>}@]@,"
catala_fatal_error_raised.position.start_column = %d;@,\
catala_fatal_error_raised.position.end_line = %d;@,\
catala_fatal_error_raised.position.end_column = %d;@,\
longjmp(catala_fatal_error_jump_buffer, 0);@]@,\
}@,"
format_var exception_conflict (Pos.get_file pos) format_var exception_conflict (Pos.get_file pos)
(Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos); (Pos.get_end_line pos) (Pos.get_end_column pos);

View File

@ -100,6 +100,7 @@ module Map = struct
let empty = empty let empty = empty
let singleton v x = singleton (t v) x let singleton v x = singleton (t v) x
let add v x m = add (t v) x m let add v x m = add (t v) x m
let remove v m = remove (t v) m
let update v f m = update (t v) f m let update v f m = update (t v) f m
let find v m = find (t v) m let find v m = find (t v) m
let find_opt v m = find_opt (t v) m let find_opt v m = find_opt (t v) m

View File

@ -64,6 +64,7 @@ module Map : sig
val empty : ('e, 'x) t val empty : ('e, 'x) t
val singleton : 'e var -> 'x -> ('e, 'x) t val singleton : 'e var -> 'x -> ('e, 'x) t
val add : 'e var -> 'x -> ('e, 'x) t -> ('e, 'x) t val add : 'e var -> 'x -> ('e, 'x) t -> ('e, 'x) t
val remove : 'e var -> ('e, 'x) t -> ('e, 'x) t
val update : 'e var -> ('x option -> 'x option) -> ('e, 'x) t -> ('e, 'x) t val update : 'e var -> ('x option -> 'x option) -> ('e, 'x) t -> ('e, 'x) t
val find : 'e var -> ('e, 'x) t -> 'x val find : 'e var -> ('e, 'x) t -> 'x
val find_opt : 'e var -> ('e, 'x) t -> 'x option val find_opt : 'e var -> ('e, 'x) t -> 'x option

4
dune
View File

@ -1,6 +1,6 @@
(dirs runtimes compiler build_system) (dirs runtimes compiler build_system tests)
(data_only_dirs tests syntax_highlighting) (data_only_dirs syntax_highlighting)
(vendored_dirs catala-examples.tmp french-law.tmp) (vendored_dirs catala-examples.tmp french-law.tmp)

View File

@ -33,6 +33,22 @@ catala_fatal_error catala_fatal_error_raised;
jmp_buf catala_fatal_error_jump_buffer; jmp_buf catala_fatal_error_jump_buffer;
void catala_raise_fatal_error(catala_fatal_error_code code,
char *filename,
unsigned int start_line,
unsigned int start_column,
unsigned int end_line,
unsigned int end_column)
{
catala_fatal_error_raised.code = code;
catala_fatal_error_raised.position.filename = filename;
catala_fatal_error_raised.position.start_line = start_line;
catala_fatal_error_raised.position.start_column = start_column;
catala_fatal_error_raised.position.end_line = end_line;
catala_fatal_error_raised.position.end_column = end_column;
longjmp(catala_fatal_error_jump_buffer, 0);
}
typedef struct pointer_list pointer_list; typedef struct pointer_list pointer_list;
struct pointer_list struct pointer_list
{ {

View File

@ -1,7 +1,7 @@
(documentation (documentation
(package catala)) (package catala))
(dirs jsoo ocaml python r rescript) (dirs jsoo ocaml python r rescript c)
; Installation is done as source under catala lib directory ; Installation is done as source under catala lib directory
; For dev version this makes it easy to install the proper runtime with just ; For dev version this makes it easy to install the proper runtime with just

1
tests/.ocamlformat Normal file
View File

@ -0,0 +1 @@
disable

View File

@ -52,25 +52,31 @@ int main()
{ {
char *error_kind; char *error_kind;
switch (catala_fatal_error_raised.code) switch (catala_fatal_error_raised.code)
{ {
case catala_no_value_provided: case catala_assertion_failed:
error_kind = "No value provided"; error_kind = "an assertion doesn't hold";
break; break;
case catala_conflict: case catala_no_value:
error_kind = "Conflict between exceptions"; error_kind = "no applicable rule to define this variable in this situation";
break; break;
case catala_crash: case catala_conflict:
error_kind = "Crash"; error_kind = "conflict between multiple valid consequences for assigning the same variable";
break; break;
case catala_empty: case catala_division_by_zero:
error_kind = "Empty error not caught"; error_kind = "a value is being used as denominator in a division and it computed to zero";
break; break;
case catala_assertion_failure: case catala_not_same_length:
error_kind = "Asssertion failure"; error_kind = "traversing multiple lists of different lengths";
break; break;
case catala_malloc_error: case catala_uncomparable_durations:
error_kind = "ambiguous comparison between durations in different units (e.g. months vs. days)";
break;
case catala_indivisible_durations:
error_kind = "dividing durations that are not in days";
break;
case catala_malloc_error:
error_kind = "Malloc error"; error_kind = "Malloc error";
} }
printf("\033[1;31m[ERROR]\033[0m %s in file %s:%d.%d-%d.%d\n", printf("\033[1;31m[ERROR]\033[0m %s in file %s:%d.%d-%d.%d\n",
error_kind, error_kind,
catala_fatal_error_raised.position.filename, catala_fatal_error_raised.position.filename,

View File

@ -120,13 +120,8 @@ baz_struct baz_func(baz_in_struct baz_in) {
} }
} }
if (exception_conflict) { if (exception_conflict) {
catala_fatal_error_raised.code = catala_conflict; catala_raise_fatal_error(catala_conflict,
catala_fatal_error_raised.position.filename = "tests/backends/simple.catala_en"; "tests/backends/simple.catala_en", 11, 11, 11, 12);
catala_fatal_error_raised.position.start_line = 11;
catala_fatal_error_raised.position.start_column = 11;
catala_fatal_error_raised.position.end_line = 11;
catala_fatal_error_raised.position.end_column = 12;
longjmp(catala_fatal_error_jump_buffer, 0);
} }
if (exception_acc.code == option_1_enum_some_1_cons) { if (exception_acc.code == option_1_enum_some_1_cons) {
temp_a_1 = exception_acc; temp_a_1 = exception_acc;
@ -157,13 +152,8 @@ baz_struct baz_func(baz_in_struct baz_in) {
} }
} }
if (exception_conflict_1) { if (exception_conflict_1) {
catala_fatal_error_raised.code = catala_conflict; catala_raise_fatal_error(catala_conflict,
catala_fatal_error_raised.position.filename = "tests/backends/simple.catala_en"; "tests/backends/simple.catala_en", 11, 11, 11, 12);
catala_fatal_error_raised.position.start_line = 11;
catala_fatal_error_raised.position.start_column = 11;
catala_fatal_error_raised.position.end_line = 11;
catala_fatal_error_raised.position.end_column = 12;
longjmp(catala_fatal_error_jump_buffer, 0);
} }
if (exception_acc_1.code == option_1_enum_some_1_cons) { if (exception_acc_1.code == option_1_enum_some_1_cons) {
temp_a_3 = exception_acc_1; temp_a_3 = exception_acc_1;
@ -178,18 +168,15 @@ baz_struct baz_func(baz_in_struct baz_in) {
} }
} }
option_1_enum match_arg = temp_a_3; option_1_enum match_arg = temp_a_3;
if (match_arg.code == option_1_enum_none_1_cons) { switch (match_arg.code) {
void* /* unit */ dummy_var = match_arg.payload.none_1_cons; case option_1_enum_none_1_cons:
catala_fatal_error_raised.code = catala_no_value; catala_raise_fatal_error (catala_no_value,
catala_fatal_error_raised.position.filename = "tests/backends/simple.catala_en"; "tests/backends/simple.catala_en", 11, 11, 11, 12);
catala_fatal_error_raised.position.start_line = 11; break;
catala_fatal_error_raised.position.start_column = 11; case option_1_enum_some_1_cons:
catala_fatal_error_raised.position.end_line = 11; bar_enum arg = match_arg.payload.some_1_cons;
catala_fatal_error_raised.position.end_column = 12; temp_a_2 = arg;
longjmp(catala_fatal_error_jump_buffer, 0); break;
} else if (match_arg.code == option_1_enum_some_1_cons) {
bar_enum arg = match_arg.payload.some_1_cons;
temp_a_2 = arg;
} }
option_1_enum temp_a_8 = {option_1_enum_some_1_cons, option_1_enum temp_a_8 = {option_1_enum_some_1_cons,
{some_1_cons: temp_a_2}}; {some_1_cons: temp_a_2}};
@ -200,18 +187,15 @@ baz_struct baz_func(baz_in_struct baz_in) {
} }
} }
option_1_enum match_arg_1 = temp_a_1; option_1_enum match_arg_1 = temp_a_1;
if (match_arg_1.code == option_1_enum_none_1_cons) { switch (match_arg_1.code) {
void* /* unit */ dummy_var = match_arg_1.payload.none_1_cons; case option_1_enum_none_1_cons:
catala_fatal_error_raised.code = catala_no_value; catala_raise_fatal_error (catala_no_value,
catala_fatal_error_raised.position.filename = "tests/backends/simple.catala_en"; "tests/backends/simple.catala_en", 11, 11, 11, 12);
catala_fatal_error_raised.position.start_line = 11; break;
catala_fatal_error_raised.position.start_column = 11; case option_1_enum_some_1_cons:
catala_fatal_error_raised.position.end_line = 11; bar_enum arg_1 = match_arg_1.payload.some_1_cons;
catala_fatal_error_raised.position.end_column = 12; temp_a = arg_1;
longjmp(catala_fatal_error_jump_buffer, 0); break;
} else if (match_arg_1.code == option_1_enum_some_1_cons) {
bar_enum arg_1 = match_arg_1.payload.some_1_cons;
temp_a = arg_1;
} }
bar_enum a_1; bar_enum a_1;
a_1 = temp_a; a_1 = temp_a;
@ -221,12 +205,12 @@ baz_struct baz_func(baz_in_struct baz_in) {
option_2_enum temp_b_3; option_2_enum temp_b_3;
char /* bool */ temp_b_4; char /* bool */ temp_b_4;
bar_enum match_arg_2 = a_1; bar_enum match_arg_2 = a_1;
if (match_arg_2.code == bar_enum_no_cons) { switch (match_arg_2.code) {
void* /* unit */ dummy_var = match_arg_2.payload.no_cons; case bar_enum_no_cons: temp_b_4 = 1 /* TRUE */; break;
temp_b_4 = 1 /* TRUE */; case bar_enum_yes_cons:
} else if (match_arg_2.code == bar_enum_yes_cons) { foo_struct dummy_var = match_arg_2.payload.yes_cons;
foo_struct dummy_var = match_arg_2.payload.yes_cons; temp_b_4 = 0 /* FALSE */;
temp_b_4 = 0 /* FALSE */; break;
} }
if (temp_b_4) { if (temp_b_4) {
option_2_enum temp_b_5 = {option_2_enum_some_2_cons, {some_2_cons: 42.}}; option_2_enum temp_b_5 = {option_2_enum_some_2_cons, {some_2_cons: 42.}};
@ -248,13 +232,8 @@ baz_struct baz_func(baz_in_struct baz_in) {
} }
} }
if (exception_conflict_2) { if (exception_conflict_2) {
catala_fatal_error_raised.code = catala_conflict; catala_raise_fatal_error(catala_conflict,
catala_fatal_error_raised.position.filename = "tests/backends/simple.catala_en"; "tests/backends/simple.catala_en", 12, 10, 12, 11);
catala_fatal_error_raised.position.start_line = 12;
catala_fatal_error_raised.position.start_column = 10;
catala_fatal_error_raised.position.end_line = 12;
catala_fatal_error_raised.position.end_column = 11;
longjmp(catala_fatal_error_jump_buffer, 0);
} }
if (exception_acc_2.code == option_2_enum_some_2_cons) { if (exception_acc_2.code == option_2_enum_some_2_cons) {
temp_b_2 = exception_acc_2; temp_b_2 = exception_acc_2;
@ -281,13 +260,8 @@ baz_struct baz_func(baz_in_struct baz_in) {
} }
} }
if (exception_conflict_3) { if (exception_conflict_3) {
catala_fatal_error_raised.code = catala_conflict; catala_raise_fatal_error(catala_conflict,
catala_fatal_error_raised.position.filename = "tests/backends/simple.catala_en"; "tests/backends/simple.catala_en", 12, 10, 12, 11);
catala_fatal_error_raised.position.start_line = 12;
catala_fatal_error_raised.position.start_column = 10;
catala_fatal_error_raised.position.end_line = 12;
catala_fatal_error_raised.position.end_column = 11;
longjmp(catala_fatal_error_jump_buffer, 0);
} }
if (exception_acc_3.code == option_2_enum_some_2_cons) { if (exception_acc_3.code == option_2_enum_some_2_cons) {
temp_b_1 = exception_acc_3; temp_b_1 = exception_acc_3;
@ -298,18 +272,14 @@ baz_struct baz_func(baz_in_struct baz_in) {
if (1 /* TRUE */) { if (1 /* TRUE */) {
double temp_b_9; double temp_b_9;
bar_enum match_arg_3 = a_1; bar_enum match_arg_3 = a_1;
if (match_arg_3.code == bar_enum_no_cons) { switch (match_arg_3.code) {
void* /* unit */ dummy_var = match_arg_3.payload.no_cons; case bar_enum_no_cons: temp_b_9 = 0.; break;
temp_b_9 = 0.; case bar_enum_yes_cons:
} else if (match_arg_3.code == bar_enum_yes_cons) { foo_struct foo = match_arg_3.payload.yes_cons;
foo_struct foo = match_arg_3.payload.yes_cons; double temp_b_10;
double temp_b_10; if (foo.x_field) {temp_b_10 = 1.; } else {temp_b_10 = 0.; }
if (foo.x_field) { temp_b_9 = (foo.y_field + temp_b_10);
temp_b_10 = 1.; break;
} else {
temp_b_10 = 0.;
}
temp_b_9 = (foo.y_field + temp_b_10);
} }
option_2_enum temp_b_11 = {option_2_enum_some_2_cons, option_2_enum temp_b_11 = {option_2_enum_some_2_cons,
{some_2_cons: temp_b_9}}; {some_2_cons: temp_b_9}};
@ -331,13 +301,8 @@ baz_struct baz_func(baz_in_struct baz_in) {
} }
} }
if (exception_conflict_4) { if (exception_conflict_4) {
catala_fatal_error_raised.code = catala_conflict; catala_raise_fatal_error(catala_conflict,
catala_fatal_error_raised.position.filename = "tests/backends/simple.catala_en"; "tests/backends/simple.catala_en", 12, 10, 12, 11);
catala_fatal_error_raised.position.start_line = 12;
catala_fatal_error_raised.position.start_column = 10;
catala_fatal_error_raised.position.end_line = 12;
catala_fatal_error_raised.position.end_column = 11;
longjmp(catala_fatal_error_jump_buffer, 0);
} }
if (exception_acc_4.code == option_2_enum_some_2_cons) { if (exception_acc_4.code == option_2_enum_some_2_cons) {
temp_b_7 = exception_acc_4; temp_b_7 = exception_acc_4;
@ -358,18 +323,15 @@ baz_struct baz_func(baz_in_struct baz_in) {
} }
} }
option_2_enum match_arg_4 = temp_b_1; option_2_enum match_arg_4 = temp_b_1;
if (match_arg_4.code == option_2_enum_none_2_cons) { switch (match_arg_4.code) {
void* /* unit */ dummy_var = match_arg_4.payload.none_2_cons; case option_2_enum_none_2_cons:
catala_fatal_error_raised.code = catala_no_value; catala_raise_fatal_error (catala_no_value,
catala_fatal_error_raised.position.filename = "tests/backends/simple.catala_en"; "tests/backends/simple.catala_en", 12, 10, 12, 11);
catala_fatal_error_raised.position.start_line = 12; break;
catala_fatal_error_raised.position.start_column = 10; case option_2_enum_some_2_cons:
catala_fatal_error_raised.position.end_line = 12; double arg_2 = match_arg_4.payload.some_2_cons;
catala_fatal_error_raised.position.end_column = 11; temp_b = arg_2;
longjmp(catala_fatal_error_jump_buffer, 0); break;
} else if (match_arg_4.code == option_2_enum_some_2_cons) {
double arg_2 = match_arg_4.payload.some_2_cons;
temp_b = arg_2;
} }
double b; double b;
b = temp_b; b = temp_b;
@ -401,13 +363,8 @@ baz_struct baz_func(baz_in_struct baz_in) {
} }
} }
if (exception_conflict_5) { if (exception_conflict_5) {
catala_fatal_error_raised.code = catala_conflict; catala_raise_fatal_error(catala_conflict,
catala_fatal_error_raised.position.filename = "tests/backends/simple.catala_en"; "tests/backends/simple.catala_en", 13, 10, 13, 11);
catala_fatal_error_raised.position.start_line = 13;
catala_fatal_error_raised.position.start_column = 10;
catala_fatal_error_raised.position.end_line = 13;
catala_fatal_error_raised.position.end_column = 11;
longjmp(catala_fatal_error_jump_buffer, 0);
} }
if (exception_acc_5.code == option_3_enum_some_3_cons) { if (exception_acc_5.code == option_3_enum_some_3_cons) {
temp_c_1 = exception_acc_5; temp_c_1 = exception_acc_5;
@ -422,18 +379,15 @@ baz_struct baz_func(baz_in_struct baz_in) {
} }
} }
option_3_enum match_arg_5 = temp_c_1; option_3_enum match_arg_5 = temp_c_1;
if (match_arg_5.code == option_3_enum_none_3_cons) { switch (match_arg_5.code) {
void* /* unit */ dummy_var = match_arg_5.payload.none_3_cons; case option_3_enum_none_3_cons:
catala_fatal_error_raised.code = catala_no_value; catala_raise_fatal_error (catala_no_value,
catala_fatal_error_raised.position.filename = "tests/backends/simple.catala_en"; "tests/backends/simple.catala_en", 13, 10, 13, 11);
catala_fatal_error_raised.position.start_line = 13; break;
catala_fatal_error_raised.position.start_column = 10; case option_3_enum_some_3_cons:
catala_fatal_error_raised.position.end_line = 13; array_3_struct arg_3 = match_arg_5.payload.some_3_cons;
catala_fatal_error_raised.position.end_column = 11; temp_c = arg_3;
longjmp(catala_fatal_error_jump_buffer, 0); break;
} else if (match_arg_5.code == option_3_enum_some_3_cons) {
array_3_struct arg_3 = match_arg_5.payload.some_3_cons;
temp_c = arg_3;
} }
array_3_struct c; array_3_struct c;
c = temp_c; c = temp_c;