C backend: better ensure consistent declarations and mallocs

This commit is contained in:
Louis Gesbert 2024-09-13 10:12:34 +02:00
parent ff18ee0267
commit 22a16c2b8a
5 changed files with 180 additions and 146 deletions

@ -241,19 +241,22 @@ let _format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
let rec format_expression (ctx : ctx) (fmt : Format.formatter) (e : expr) : unit
let rec format_expression (ctx : ctx) (global_vars : VarName.Set.t) (fmt : Format.formatter) (e : expr) : unit
let format_expression = format_expression ctx global_vars in
match Mark.remove e with
| EVar v -> VarName.format fmt v
| EVar v ->
if VarName.Set.mem v global_vars then Format.fprintf fmt "%a()" VarName.format v
else VarName.format fmt v
| EFunc f -> FuncName.format fmt f
| EStructFieldAccess { e1; field; _ } ->
Format.fprintf fmt "%a->%a" (format_expression ctx) e1 StructField.format
Format.fprintf fmt "%a->%a" format_expression e1 StructField.format
| EInj { e1; cons; name = enum_name; _ }
when EnumName.equal enum_name Expr.option_enum ->
if EnumConstructor.equal cons Expr.none_constr then
Format.fprintf fmt "CATALA_NONE"
else Format.fprintf fmt "catala_some(%a)" (format_expression ctx) e1
else Format.fprintf fmt "catala_some(%a)" format_expression e1
| EStruct _ | EInj _ | EArray _ ->
Message.error ~internal:true "Unlifted construct found: %a"
(Scalc__Print.format_expr ctx.decl_ctx ?debug:None)
@ -262,10 +265,10 @@ let rec format_expression (ctx : ctx) (fmt : Format.formatter) (e : expr) : unit
| ELit l -> Format.fprintf fmt "%a" format_lit (Mark.copy e l)
| EPosLit -> assert false (* Handled only as toplevel definitions *)
| EAppOp { op = (ToClosureEnv | FromClosureEnv), _; args = [arg]; _ } ->
format_expression ctx fmt arg
format_expression fmt arg
| EAppOp { op = ((Map | Filter), _) as op; args = [arg1; arg2]; _ } ->
Format.fprintf fmt "%a(%a,@ %a)" format_op op (format_expression ctx) arg1
(format_expression ctx) arg2
Format.fprintf fmt "%a(%a,@ %a)" format_op op format_expression arg1
format_expression arg2
| EAppOp
op = ((Reduce | Fold), _) as op;
@ -275,42 +278,42 @@ let rec format_expression (ctx : ctx) (fmt : Format.formatter) (e : expr) : unit
(* Operators with a polymorphic return type need a cast *)
Format.fprintf fmt "((%a)%a(%a,@ %a,@ %a))"
(format_typ ~const:true ctx.decl_ctx ignore)
aty format_op op (format_expression ctx) fct (format_expression ctx) base
(format_expression ctx) arr
aty format_op op format_expression fct format_expression base
format_expression arr
| EAppOp { op = Add_dat_dur rounding, _; args; _ } ->
Format.fprintf fmt "o_add_dat_dur(%s,@ %a)"
(match rounding with
| RoundUp -> "dc_date_round_up"
| RoundDown -> "dc_date_round_down"
| AbortOnRound -> "dc_date_round_abort")
(Format.pp_print_list (format_expression ctx) ~pp_sep:(fun ppf () ->
(Format.pp_print_list format_expression ~pp_sep:(fun ppf () ->
Format.fprintf ppf ",@ "))
| EApp { f; args } ->
Format.fprintf fmt "@[<hov 2>%a@,(@[<hov 0>%a)@]@]" (format_expression ctx)
Format.fprintf fmt "@[<hov 2>%a@,(@[<hov 0>%a)@]@]" format_expression
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
| EAppOp { op; args; _ } ->
Format.fprintf fmt "%a(@[<hov 0>%a)@]" format_op op
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
| ETuple _ -> assert false (* Must be a statement *)
| ETupleAccess { e1; index = 0; typ = (TArrow _, _) as typ } ->
(* Closure function *)
Format.fprintf fmt "@[<hov 1>((%a)@,%a->funcp)@]"
(format_typ ~const:true ctx.decl_ctx ignore)
typ (format_expression ctx) e1
typ format_expression e1
| ETupleAccess { e1; index = 1; typ = TClosureEnv, _ } ->
Format.fprintf fmt "%a->env" (format_expression ctx) e1
Format.fprintf fmt "%a->env" format_expression e1
| ETupleAccess { e1; index; typ } ->
Format.fprintf fmt "(%a)%a[%d]"
(format_typ ctx.decl_ctx ignore)
typ (format_expression ctx) e1 index
typ format_expression e1 index
| EExternal _ -> failwith "TODO"
let is_closure_typ = function
@ -319,20 +322,13 @@ let is_closure_typ = function
let rec format_statement
(ctx : ctx)
(global_vars : VarName.Set.t)
(fmt : Format.formatter)
(s : stmt Mark.pos) : unit =
match Mark.remove s with
| SInnerFuncDef _ ->
Message.error ~pos:(Mark.get s) ~internal:true
"This inner functions should have been hoisted in Scalc"
| SLocalDecl { name = v, _; typ = ty } ->
if is_dummy_var v then ()
Format.fprintf fmt "@,@[<hov 2>%a@];"
(format_typ ctx.decl_ctx ~const:true (fun fmt ->
Format.pp_print_space fmt ();
VarName.format fmt v))
| SLocalInit { name = v, _; typ; expr = EStruct { fields; _ }, _ }
when StructField.Map.is_empty fields && not (is_dummy_var v) ->
Format.fprintf fmt "@,@[<hov 2>%a =@ NULL@];"
@ -352,38 +348,14 @@ let rec format_statement
VarName.format v (Pos.get_file pos) (Pos.get_start_line pos)
(Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos)
| SLocalInit
name = v, _;
expr = ((EArray _ | EStruct _ | EInj _ | ETuple _) as expr), _;
} ->
| SLocalDecl { name = v, _; typ = ty } ->
if is_dummy_var v then ()
let const, pp_size =
match expr with
| EArray _ ->
false, fun fmt -> Format.pp_print_string fmt "sizeof(catala_array)"
| EStruct { name; _ } ->
( false,
fun fmt -> Format.fprintf fmt "sizeof(%a)" StructName.format name )
| EInj { name; _ } when EnumName.equal name Expr.option_enum ->
true, fun fmt -> Format.pp_print_string fmt "sizeof(catala_option)"
| EInj { name; _ } ->
false, fun fmt -> Format.fprintf fmt "sizeof(%a)" EnumName.format name
| ETuple [_fct; _env] when is_closure_typ typ ->
false, fun fmt -> Format.pp_print_string fmt "sizeof(catala_closure)"
| ETuple elts ->
( true,
fun fmt ->
Format.fprintf fmt "%d * sizeof(void*)" (List.length elts) )
| _ -> assert false
Format.fprintf fmt "@,@[<hov 2>%a =@ catala_malloc(%t)@];"
(format_typ ~const ctx.decl_ctx (fun fmt ->
Format.fprintf fmt "@,@[<hov 2>%a@];"
(format_typ ctx.decl_ctx ~const:true (fun fmt ->
Format.pp_print_space fmt ();
VarName.format fmt v))
typ pp_size
| SLocalDef { name = v, _; expr = EArray elts, _; _ } ->
(* We detect array initializations which have special treatment. *)
let size = List.length elts in
@ -395,20 +367,20 @@ let rec format_statement
(fun i arg ->
Format.fprintf fmt "@,@[<hov 2>%a->elements[%d] =@ %a;@]" VarName.format
v i (format_expression ctx) arg)
v i (format_expression ctx global_vars) arg)
| SLocalDef { name = v, _; expr = EStruct { fields; _ }, _; _ } ->
(fun field expr ->
Format.fprintf fmt "@,@[<hov 2>%a->%a =@ %a;@]" VarName.format v
StructField.format field (format_expression ctx) expr)
StructField.format field (format_expression ctx global_vars) expr)
| SLocalDef { name = v, _; expr = EInj { e1; cons; name; _ }, _; _ }
when not (EnumName.equal name Expr.option_enum) ->
Format.fprintf fmt "@,@[<hov 2>%a->code = %a_%a;@]" VarName.format v
EnumName.format name EnumConstructor.format cons;
Format.fprintf fmt "@,@[<hov 2>%a->payload.%a = %a;@]" VarName.format v
EnumConstructor.format cons (format_expression ctx) e1
EnumConstructor.format cons (format_expression ctx global_vars) e1
| SLocalDef
name = v, _;
@ -417,15 +389,15 @@ let rec format_statement
} ->
(* We detect closure initializations which have special treatment. *)
Format.fprintf fmt "@,@[<hov 2>%a->funcp =@ (void (*)(void))%a;@]"
VarName.format v (format_expression ctx) fct;
VarName.format v (format_expression ctx global_vars) fct;
Format.fprintf fmt "@,@[<hov 2>%a->env =@ %a;@]" VarName.format v
(format_expression ctx) env
(format_expression ctx global_vars) env
| SLocalDef { name = v, _; expr = ETuple elts, _; _ } ->
(* We detect tuple initializations which have special treatment. *)
(fun i arg ->
Format.fprintf fmt "@,@[<hov 2>%a[%d] =@ %a;@]" VarName.format v i
(format_expression ctx) arg)
(format_expression ctx global_vars) arg)
| SLocalInit
@ -441,35 +413,36 @@ let rec format_statement
_ ) ) as e;
} ->
Format.fprintf fmt "@,@[<hov 2>%a = %a /*XXX*/;@]"
Format.fprintf fmt "@,@[<hov 2>%a = %a;@]"
(format_typ ~const:true ctx.decl_ctx (fun fmt ->
Format.pp_print_space fmt ();
VarName.format fmt (Mark.remove v)))
typ (format_expression ctx) e
typ (format_expression ctx global_vars) e
| SLocalInit { name = v; typ; expr = e } ->
(* Handling at the block level guarantees that [e] is supported as initial value *)
Format.fprintf fmt "@,@[<hov 2>%a = %a;@]"
(format_typ ctx.decl_ctx (fun fmt ->
Format.pp_print_space fmt ();
VarName.format fmt (Mark.remove v)))
typ (format_expression ctx) e
typ (format_expression ctx global_vars) e
| SLocalDef { name = v; expr = e; _ } ->
Format.fprintf fmt "@,@[<hov 2>%a = %a;@]" VarName.format (Mark.remove v)
(format_expression ctx) e
(format_expression ctx global_vars) e
| SFatalError { pos_expr; error } ->
Format.fprintf fmt "@,@[<hov 2>catala_error(catala_%s,@ %a);@]"
(String.to_snake_case (Runtime.error_to_string error))
(format_expression ctx) pos_expr
(format_expression ctx global_vars) pos_expr
| SIfThenElse { if_expr = ELit (LBool true), _; then_block; _ } ->
format_block ctx fmt then_block
format_block ctx global_vars fmt then_block
| SIfThenElse { if_expr = ELit (LBool false), _; else_block; _ } ->
format_block ctx fmt else_block
format_block ctx global_vars fmt else_block
| SIfThenElse { if_expr = cond; then_block = b1; else_block = b2 } ->
Format.fprintf fmt
@[<hv 2>@[<hov 2>if (%a == CATALA_TRUE) {@]%a@;\
<1 -2>} else {%a@;\
<1 -2>}@]" (format_expression ctx) cond (format_block ctx) b1
(format_block ctx) b2
<1 -2>}@]" (format_expression ctx global_vars) cond (format_block ctx global_vars) b1
(format_block ctx global_vars) b2
| SSwitch { switch_var; enum_name = e_name; switch_cases = cases; _ }
when EnumName.equal e_name Expr.option_enum ->
let cases =
@ -491,7 +464,7 @@ let rec format_statement
Format.fprintf fmt "@,@[<v 2>if (%a->code == catala_option_some) {"
VarName.format switch_var;
let pos = Mark.get s in
format_block ctx fmt
format_block ctx global_vars fmt
(Utils.subst_block some_case.payload_var_name
(* Not a real catala struct, but will print as <var>->payload *)
( EStructFieldAccess
@ -503,7 +476,7 @@ let rec format_statement
pos )
some_case.payload_var_typ pos some_case.case_block);
Format.fprintf fmt "@;<1 -2>} else {";
format_block ctx fmt none_case.case_block;
format_block ctx global_vars fmt none_case.case_block;
Format.fprintf fmt "@;<1 -2>}@]"
| SSwitch { switch_var; enum_name = e_name; switch_cases = cases; _ } ->
Format.fprintf fmt "@,@[<v 2>@[<hov 4>switch (%a->code) {@]" VarName.format
@ -523,7 +496,7 @@ let rec format_statement
payload_var_typ VarName.format switch_var
(* EnumName.format e_name *)
EnumConstructor.format cons_name;
Format.fprintf fmt "%a@ break;@;<1 -2>}@]" (format_block ctx) case_block)
Format.fprintf fmt "%a@ break;@;<1 -2>}@]" (format_block ctx global_vars) case_block)
(EnumName.Map.find e_name ctx.decl_ctx.ctx_enums));
@ -531,73 +504,78 @@ let rec format_statement
Format.fprintf fmt "@;<0 -2>}";
Format.pp_close_box fmt ()
| SReturn e1 ->
Format.fprintf fmt "@,@[<hov 2>return %a;@]" (format_expression ctx) e1
Format.fprintf fmt "@,@[<hov 2>return %a;@]" (format_expression ctx global_vars) e1
| SAssert { pos_expr; expr } ->
Format.fprintf fmt
@[<v 2>@[<hov 2>if (%a != CATALA_TRUE) {@]@,\
@[<hov 2>catala_error(catala_assertion_failed,@ %a);@]@;\
<1 -2>}@]" (format_expression ctx) expr (format_expression ctx) pos_expr
<1 -2>}@]" (format_expression ctx global_vars) expr (format_expression ctx global_vars) pos_expr
| _ -> .
and format_block (ctx : ctx) (fmt : Format.formatter) (b : block) : unit =
(* C89 doesn't accept initialisations of constructions from non-constants: -
for known structures needing malloc, provision the malloc here (turn Decl
and format_block (ctx : ctx) (global_vars : VarName.Set.t) (fmt : Format.formatter) (b : block) : unit =
(* C89 doesn't accept initialisations of constructions from non-constants:
- for known structures needing malloc, provision the malloc here (turn Decl
into Init (that will only do the malloc) + def) - for literal constants
keep init - otherwise split Init into decl + def *)
let find_static_def name =
| SLocalDef { name = n; _ }, _ -> Mark.equal VarName.equal n name
| _ -> false)
| Some
( SLocalDef
expr = ((EArray _ | EStruct _ | EInj _ | ETuple _), _) as expr;
_ ) ->
Some expr
| _ -> None
let requires_malloc = function
| (EArray _ | EStruct _ | EInj _ | ETuple _ | ELit _), _ -> true
| _ -> false
let revb =
(fun acc -> function
| (SLocalInit { expr = (ELit _ | EPosLit), _; _ }, _) as st -> st :: acc
| ( SLocalInit
expr = ((EArray _ | EStruct _ | EInj _ | ETuple _), _) as expr;
m ) ->
(* These need malloc and init, split in two since the Init won't
actually set them *)
(SLocalDef { name; typ; expr }, m)
:: (SLocalInit { name; typ; expr }, m)
:: acc
| (SLocalDecl { name; typ }, m) as decl -> (
match find_static_def name with
| Some expr -> (SLocalInit { name; typ; expr }, m) :: acc
| _ -> decl :: acc)
| SLocalInit { name; typ; expr }, m ->
(SLocalDef { name; typ; expr }, m)
:: (SLocalDecl { name; typ }, m)
:: acc
| st -> st :: acc)
[] b
let print_init_malloc fmt v typ =
let const, pp_size =
match Mark.remove typ with
| TArray _ ->
false, fun fmt -> Format.pp_print_string fmt "sizeof(catala_array)"
| TStruct name ->
false, fun fmt -> Format.fprintf fmt "sizeof(%a)" StructName.format name
| TEnum name ->
true, fun fmt -> Format.fprintf fmt "sizeof(%a)" EnumName.format name
| TTuple _ when is_closure_typ typ ->
false, fun fmt -> Format.pp_print_string fmt "sizeof(catala_closure)"
| TTuple ts ->
true, fun fmt -> Format.fprintf fmt "%d * sizeof(void*)" (List.length ts)
| _ -> assert false
(* Postfix [const] declares that the pointer is const, but not its contents *)
Format.fprintf fmt "@,@[<hov 2>%a =@ catala_malloc(%t)@];"
(format_typ ~const ctx.decl_ctx (fun fmt ->
Format.fprintf fmt " const@ %a" VarName.format v))
typ pp_size
(* C89 requires declarations to be on top of the block *)
let decls, others =
(function (SLocalDecl _ | SLocalInit _), _ -> true | _ -> false)
let rec format_decls defined_vars remaining = function
| (SLocalDecl { name; typ }, _) as decl :: r ->
let requires_malloc =
match typ with
| (TArray _ | TStruct _ | TEnum _ | TTuple _), _ ->
None <>
Utils.find_block (function
| SLocalDef { name = n1; expr; _ }, _ | SLocalInit { name = n1; expr; _ }, _ ->
Mark.equal VarName.equal name n1 && requires_malloc expr
| _ -> false)
| _ -> false
if requires_malloc then print_init_malloc fmt (Mark.remove name) typ
else format_statement ctx global_vars fmt decl;
format_decls defined_vars remaining r
| (SLocalInit { name; typ; expr }, m) as init :: r ->
if requires_malloc expr then
(print_init_malloc fmt (Mark.remove name) typ;
format_decls defined_vars ((SLocalDef { name; typ; expr }, m) :: remaining) r)
else if VarName.Set.subset (Utils.get_vars expr) defined_vars then
(format_statement ctx global_vars fmt init;
format_decls (VarName.Set.add (Mark.remove name) defined_vars) remaining r)
(* The init depends on undefined variables, it can't be moved to the top, so we split it into decl + def *)
(format_statement ctx global_vars fmt (SLocalDecl { name; typ }, m);
format_decls defined_vars ((SLocalDef { name; typ; expr }, m) :: remaining) r)
| stmt :: r -> format_decls defined_vars (stmt :: remaining) r
| [] -> List.rev remaining
List.iter (format_statement ctx fmt) (List.rev decls);
List.iter (format_statement ctx fmt) (List.rev others)
let remaining = format_decls global_vars [] b in
List.iter (format_statement ctx global_vars fmt) remaining
let format_main (fmt : Format.formatter) (p : Ast.program) =
Format.fprintf fmt "@,@[<v 2>int main (int argc, char** argv)@;<0 -2>{";
@ -655,24 +633,37 @@ let format_program
format_ctx type_ordering fmt p.ctx.decl_ctx;
Format.pp_print_cut fmt ();
let ctx = { decl_ctx = p.ctx.decl_ctx } in
(fun fmt code_item ->
match code_item with
| SVar { var; expr; typ } ->
Format.fprintf fmt "@[<hov 2>%a = NULL;@]@,"
(format_typ ~const:true p.ctx.decl_ctx (fun fmt ->
Format.pp_print_space fmt ();
VarName.format fmt var))
(* We hide the value below a macro that performs lazy expansion *)
Format.fprintf fmt "#define %a (%a ? %a : (%a = %a))@," VarName.format
var VarName.format var VarName.format var VarName.format var
(format_expression ctx) expr
let _global_vars =
List.fold_left (fun global_vars code_item ->
match code_item with
| SVar { var; expr; typ } ->
(* Global variables are turned into inline functions without parameters that perform lazy evaluation: {[
inline foo_type foo() {
static foo_type foo = NULL;
return (foo ? foo : foo = foo_init());
Format.fprintf fmt "@[<v 2>@[<hov 4>inline %a() {@]@,"
(format_typ ~const:true p.ctx.decl_ctx (fun fmt ->
Format.pp_print_space fmt ();
VarName.format fmt var))
Format.fprintf fmt "@[<hov 2>static %a = NULL;@]@,"
(format_typ ~const:true p.ctx.decl_ctx (fun fmt ->
Format.pp_print_space fmt ();
VarName.format fmt var))
Format.fprintf fmt "@[<hov 2>return (%a ? %a : (%a = %a));"
VarName.format var VarName.format var VarName.format var
(format_expression ctx global_vars) expr;
Format.fprintf fmt "@;<1 -2>}@]@,@,";
VarName.Set.add var global_vars
| SFunc { var; func }
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
let { func_params; func_body; func_return_typ } = func in
Format.fprintf fmt
"@,@[<v 2>@[<hov 4>%a@ @[<hv 1>(%a)@]@]@;<1 -2>{%a@]@,}"
"@,@[<v 2>@[<hov 4>%a@ @[<hv 1>(%a)@]@]@;<1 -2>{%a@]@,}@,"
(format_typ ~const:true ctx.decl_ctx (fun fmt ->
Format.pp_print_space fmt ();
FuncName.format fmt var))
@ -686,8 +677,10 @@ let format_program
VarName.format fmt (Mark.remove var)))
fmt typ;
Format.pp_close_box fmt ()))
func_params (format_block ctx) func_body)
fmt p.code_items;
func_params (format_block ctx global_vars) func_body;
VarName.Set.empty p.code_items
Format.pp_print_cut fmt ();
format_main fmt p;
Format.pp_close_box fmt ()

@ -21,6 +21,20 @@ module Runtime = Runtime_ocaml.Runtime
module D = Dcalc.Ast
module L = Lcalc.Ast
(** [get_vars e] returns the set of variables ([EVar]) occurring anywhere in
    the expression [e]. Function names, literals, position literals and
    externals contribute no variable. *)
let rec get_vars e =
  (* Adds the variables of every expression in [exprs] to the set [init] *)
  let vars_of_list init exprs =
    List.fold_left
      (fun seen sub -> VarName.Set.union seen (get_vars sub))
      init exprs
  in
  match Mark.remove e with
  | EFunc _ | ELit _ | EPosLit | EExternal _ -> VarName.Set.empty
  | EVar v -> VarName.Set.singleton v
  | EStructFieldAccess { e1; _ } | ETupleAccess { e1; _ } | EInj { e1; _ } ->
    (* Single sub-expression: its variables are the whole result *)
    get_vars e1
  | EStruct { fields; _ } ->
    StructField.Map.fold
      (fun _field sub seen -> VarName.Set.union (get_vars sub) seen)
      fields VarName.Set.empty
  | ETuple el | EArray el | EAppOp { args = el; _ } ->
    vars_of_list VarName.Set.empty el
  | EApp { f; args } ->
    (* The applied expression itself may mention variables (e.g. closures) *)
    vars_of_list (get_vars f) args
let rec subst_expr v e within_expr =
let m = Mark.get within_expr in
match Mark.remove within_expr with
@ -86,3 +100,23 @@ and subst_block v e block =
(** [subst_block v expr typ pos block] substitutes [expr] for the variable [v]
    within [block], using the previously defined recursive [subst_block].
    When that substitution gives up by raising [Exit], [block] is returned
    unchanged with an initialisation of [v] to [expr] (of type [typ], located
    at [pos]) prepended instead. *)
let subst_block v expr typ pos block =
  match subst_block v expr block with
  | substituted -> substituted
  | exception Exit ->
    let init = SLocalInit { name = v, pos; typ; expr }, pos in
    init :: block
(** [find_block pred block] returns the first statement of [block] that
    satisfies [pred], or [None] if there is none.

    A statement is tested before its sub-blocks are explored. The search
    descends into both branches of [SIfThenElse] (the [then] branch first) and
    into every case of [SSwitch] (in order), then resumes with the statements
    that follow. Other statements are not descended into. *)
let rec find_block pred = function
  | [] -> None
  | stmt :: _ when pred stmt -> Some stmt
  | (SIfThenElse { then_block; else_block; _ }, _) :: r -> (
    match find_block pred then_block with
    | Some _ as found -> found
    | None -> (
      match find_block pred else_block with
      | Some _ as found -> found
      | None -> find_block pred r))
  | (SSwitch { switch_cases; _ }, _) :: r -> (
    (* Look through each case body in turn; fall back to the remaining
       statements of the enclosing block when no case contains a match *)
    match
      List.find_map (fun case -> find_block pred case.case_block) switch_cases
    with
    | Some _ as found -> found
    | None -> find_block pred r)
  | _ :: r -> find_block pred r

@ -27,3 +27,8 @@ val subst_block : VarName.t -> expr -> typ -> Pos.t -> block -> block
[var] within the given [block]. If not possible (the variable appears in a
variable-only position), the block is returned with an initialisation of
[var] with [replacement] prepended *)
val find_block : (stmt Mark.pos -> bool) -> block -> stmt Mark.pos option
(** [find_block pred block] returns the first statement of [block] satisfying
    [pred], or [None] if no statement does. Recurses into branchings
    ([if]/[else] blocks and [switch] cases), but not function bodies *)

val get_vars : expr -> VarName.Set.t
(** [get_vars e] returns the set of variables that occur in the expression
    [e]. Function names, literals and externals are not included. *)

@ -119,8 +119,10 @@ void catala_free(void* ptr, size_t sz)
/* --- Base types --- */
const int catala_true = 1;
const int catala_false = 0;
const int catala_true_value = 1;
const int * const catala_true = &catala_true_value;
const int catala_false_value = 0;
const int * const catala_false = &catala_false_value;
const int catala_unitval = 0;
/* --- Constructors --- */

@ -79,11 +79,11 @@ typedef struct catala_closure {
const CLOSURE_ENV env;
} catala_closure;
extern const int catala_true;
#define CATALA_TRUE &catala_true
extern const int * const catala_true;
#define CATALA_TRUE catala_true
extern const int catala_false;
#define CATALA_FALSE &catala_false
extern const int * const catala_false;
#define CATALA_FALSE catala_false
extern const int catala_unitval;
#define CATALA_UNITVAL &catala_unitval