mirror of
https://github.com/CatalaLang/catala.git
synced 2024-10-07 09:17:31 +03:00
C backend: better ensure consistent declarations and mallocs
This commit is contained in:
parent
ff18ee0267
commit
22a16c2b8a
@ -241,19 +241,22 @@ let _format_string_list (fmt : Format.formatter) (uids : string list) : unit =
|
||||
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
|
||||
uids
|
||||
|
||||
let rec format_expression (ctx : ctx) (fmt : Format.formatter) (e : expr) : unit
|
||||
let rec format_expression (ctx : ctx) (global_vars : VarName.Set.t) (fmt : Format.formatter) (e : expr) : unit
|
||||
=
|
||||
let format_expression = format_expression ctx global_vars in
|
||||
match Mark.remove e with
|
||||
| EVar v -> VarName.format fmt v
|
||||
| EVar v ->
|
||||
if VarName.Set.mem v global_vars then Format.fprintf fmt "%a()" VarName.format v
|
||||
else VarName.format fmt v
|
||||
| EFunc f -> FuncName.format fmt f
|
||||
| EStructFieldAccess { e1; field; _ } ->
|
||||
Format.fprintf fmt "%a->%a" (format_expression ctx) e1 StructField.format
|
||||
Format.fprintf fmt "%a->%a" format_expression e1 StructField.format
|
||||
field
|
||||
| EInj { e1; cons; name = enum_name; _ }
|
||||
when EnumName.equal enum_name Expr.option_enum ->
|
||||
if EnumConstructor.equal cons Expr.none_constr then
|
||||
Format.fprintf fmt "CATALA_NONE"
|
||||
else Format.fprintf fmt "catala_some(%a)" (format_expression ctx) e1
|
||||
else Format.fprintf fmt "catala_some(%a)" format_expression e1
|
||||
| EStruct _ | EInj _ | EArray _ ->
|
||||
Message.error ~internal:true "Unlifted construct found: %a"
|
||||
(Scalc__Print.format_expr ctx.decl_ctx ?debug:None)
|
||||
@ -262,10 +265,10 @@ let rec format_expression (ctx : ctx) (fmt : Format.formatter) (e : expr) : unit
|
||||
| ELit l -> Format.fprintf fmt "%a" format_lit (Mark.copy e l)
|
||||
| EPosLit -> assert false (* Handled only as toplevel definitions *)
|
||||
| EAppOp { op = (ToClosureEnv | FromClosureEnv), _; args = [arg]; _ } ->
|
||||
format_expression ctx fmt arg
|
||||
format_expression fmt arg
|
||||
| EAppOp { op = ((Map | Filter), _) as op; args = [arg1; arg2]; _ } ->
|
||||
Format.fprintf fmt "%a(%a,@ %a)" format_op op (format_expression ctx) arg1
|
||||
(format_expression ctx) arg2
|
||||
Format.fprintf fmt "%a(%a,@ %a)" format_op op format_expression arg1
|
||||
format_expression arg2
|
||||
| EAppOp
|
||||
{
|
||||
op = ((Reduce | Fold), _) as op;
|
||||
@ -275,42 +278,42 @@ let rec format_expression (ctx : ctx) (fmt : Format.formatter) (e : expr) : unit
|
||||
(* Operators with a polymorphic return type need a cast *)
|
||||
Format.fprintf fmt "((%a)%a(%a,@ %a,@ %a))"
|
||||
(format_typ ~const:true ctx.decl_ctx ignore)
|
||||
aty format_op op (format_expression ctx) fct (format_expression ctx) base
|
||||
(format_expression ctx) arr
|
||||
aty format_op op format_expression fct format_expression base
|
||||
format_expression arr
|
||||
| EAppOp { op = Add_dat_dur rounding, _; args; _ } ->
|
||||
Format.fprintf fmt "o_add_dat_dur(%s,@ %a)"
|
||||
(match rounding with
|
||||
| RoundUp -> "dc_date_round_up"
|
||||
| RoundDown -> "dc_date_round_down"
|
||||
| AbortOnRound -> "dc_date_round_abort")
|
||||
(Format.pp_print_list (format_expression ctx) ~pp_sep:(fun ppf () ->
|
||||
(Format.pp_print_list format_expression ~pp_sep:(fun ppf () ->
|
||||
Format.fprintf ppf ",@ "))
|
||||
args
|
||||
| EApp { f; args } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@,(@[<hov 0>%a)@]@]" (format_expression ctx)
|
||||
Format.fprintf fmt "@[<hov 2>%a@,(@[<hov 0>%a)@]@]" format_expression
|
||||
f
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(format_expression ctx))
|
||||
format_expression)
|
||||
args
|
||||
| EAppOp { op; args; _ } ->
|
||||
Format.fprintf fmt "%a(@[<hov 0>%a)@]" format_op op
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(format_expression ctx))
|
||||
format_expression)
|
||||
args
|
||||
| ETuple _ -> assert false (* Must be a statement *)
|
||||
| ETupleAccess { e1; index = 0; typ = (TArrow _, _) as typ } ->
|
||||
(* Closure function *)
|
||||
Format.fprintf fmt "@[<hov 1>((%a)@,%a->funcp)@]"
|
||||
(format_typ ~const:true ctx.decl_ctx ignore)
|
||||
typ (format_expression ctx) e1
|
||||
typ format_expression e1
|
||||
| ETupleAccess { e1; index = 1; typ = TClosureEnv, _ } ->
|
||||
Format.fprintf fmt "%a->env" (format_expression ctx) e1
|
||||
Format.fprintf fmt "%a->env" format_expression e1
|
||||
| ETupleAccess { e1; index; typ } ->
|
||||
Format.fprintf fmt "(%a)%a[%d]"
|
||||
(format_typ ctx.decl_ctx ignore)
|
||||
typ (format_expression ctx) e1 index
|
||||
typ format_expression e1 index
|
||||
| EExternal _ -> failwith "TODO"
|
||||
|
||||
let is_closure_typ = function
|
||||
@ -319,20 +322,13 @@ let is_closure_typ = function
|
||||
|
||||
let rec format_statement
|
||||
(ctx : ctx)
|
||||
(global_vars : VarName.Set.t)
|
||||
(fmt : Format.formatter)
|
||||
(s : stmt Mark.pos) : unit =
|
||||
match Mark.remove s with
|
||||
| SInnerFuncDef _ ->
|
||||
Message.error ~pos:(Mark.get s) ~internal:true
|
||||
"This inner functions should have been hoisted in Scalc"
|
||||
| SLocalDecl { name = v, _; typ = ty } ->
|
||||
if is_dummy_var v then ()
|
||||
else
|
||||
Format.fprintf fmt "@,@[<hov 2>%a@];"
|
||||
(format_typ ctx.decl_ctx ~const:true (fun fmt ->
|
||||
Format.pp_print_space fmt ();
|
||||
VarName.format fmt v))
|
||||
ty
|
||||
| SLocalInit { name = v, _; typ; expr = EStruct { fields; _ }, _ }
|
||||
when StructField.Map.is_empty fields && not (is_dummy_var v) ->
|
||||
Format.fprintf fmt "@,@[<hov 2>%a =@ NULL@];"
|
||||
@ -352,38 +348,14 @@ let rec format_statement
|
||||
%d}};@]"
|
||||
VarName.format v (Pos.get_file pos) (Pos.get_start_line pos)
|
||||
(Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos)
|
||||
| SLocalInit
|
||||
{
|
||||
name = v, _;
|
||||
typ;
|
||||
expr = ((EArray _ | EStruct _ | EInj _ | ETuple _) as expr), _;
|
||||
} ->
|
||||
| SLocalDecl { name = v, _; typ = ty } ->
|
||||
if is_dummy_var v then ()
|
||||
else
|
||||
let const, pp_size =
|
||||
match expr with
|
||||
| EArray _ ->
|
||||
false, fun fmt -> Format.pp_print_string fmt "sizeof(catala_array)"
|
||||
| EStruct { name; _ } ->
|
||||
( false,
|
||||
fun fmt -> Format.fprintf fmt "sizeof(%a)" StructName.format name )
|
||||
| EInj { name; _ } when EnumName.equal name Expr.option_enum ->
|
||||
true, fun fmt -> Format.pp_print_string fmt "sizeof(catala_option)"
|
||||
| EInj { name; _ } ->
|
||||
false, fun fmt -> Format.fprintf fmt "sizeof(%a)" EnumName.format name
|
||||
| ETuple [_fct; _env] when is_closure_typ typ ->
|
||||
false, fun fmt -> Format.pp_print_string fmt "sizeof(catala_closure)"
|
||||
| ETuple elts ->
|
||||
( true,
|
||||
fun fmt ->
|
||||
Format.fprintf fmt "%d * sizeof(void*)" (List.length elts) )
|
||||
| _ -> assert false
|
||||
in
|
||||
Format.fprintf fmt "@,@[<hov 2>%a =@ catala_malloc(%t)@];"
|
||||
(format_typ ~const ctx.decl_ctx (fun fmt ->
|
||||
Format.fprintf fmt "@,@[<hov 2>%a@];"
|
||||
(format_typ ctx.decl_ctx ~const:true (fun fmt ->
|
||||
Format.pp_print_space fmt ();
|
||||
VarName.format fmt v))
|
||||
typ pp_size
|
||||
ty
|
||||
| SLocalDef { name = v, _; expr = EArray elts, _; _ } ->
|
||||
(* We detect array initializations which have special treatment. *)
|
||||
let size = List.length elts in
|
||||
@ -395,20 +367,20 @@ let rec format_statement
|
||||
List.iteri
|
||||
(fun i arg ->
|
||||
Format.fprintf fmt "@,@[<hov 2>%a->elements[%d] =@ %a;@]" VarName.format
|
||||
v i (format_expression ctx) arg)
|
||||
v i (format_expression ctx global_vars) arg)
|
||||
elts
|
||||
| SLocalDef { name = v, _; expr = EStruct { fields; _ }, _; _ } ->
|
||||
StructField.Map.iter
|
||||
(fun field expr ->
|
||||
Format.fprintf fmt "@,@[<hov 2>%a->%a =@ %a;@]" VarName.format v
|
||||
StructField.format field (format_expression ctx) expr)
|
||||
StructField.format field (format_expression ctx global_vars) expr)
|
||||
fields
|
||||
| SLocalDef { name = v, _; expr = EInj { e1; cons; name; _ }, _; _ }
|
||||
when not (EnumName.equal name Expr.option_enum) ->
|
||||
Format.fprintf fmt "@,@[<hov 2>%a->code = %a_%a;@]" VarName.format v
|
||||
EnumName.format name EnumConstructor.format cons;
|
||||
Format.fprintf fmt "@,@[<hov 2>%a->payload.%a = %a;@]" VarName.format v
|
||||
EnumConstructor.format cons (format_expression ctx) e1
|
||||
EnumConstructor.format cons (format_expression ctx global_vars) e1
|
||||
| SLocalDef
|
||||
{
|
||||
name = v, _;
|
||||
@ -417,15 +389,15 @@ let rec format_statement
|
||||
} ->
|
||||
(* We detect closure initializations which have special treatment. *)
|
||||
Format.fprintf fmt "@,@[<hov 2>%a->funcp =@ (void (*)(void))%a;@]"
|
||||
VarName.format v (format_expression ctx) fct;
|
||||
VarName.format v (format_expression ctx global_vars) fct;
|
||||
Format.fprintf fmt "@,@[<hov 2>%a->env =@ %a;@]" VarName.format v
|
||||
(format_expression ctx) env
|
||||
(format_expression ctx global_vars) env
|
||||
| SLocalDef { name = v, _; expr = ETuple elts, _; _ } ->
|
||||
(* We detect tuple initializations which have special treatment. *)
|
||||
List.iteri
|
||||
(fun i arg ->
|
||||
Format.fprintf fmt "@,@[<hov 2>%a[%d] =@ %a;@]" VarName.format v i
|
||||
(format_expression ctx) arg)
|
||||
(format_expression ctx global_vars) arg)
|
||||
elts
|
||||
| SLocalInit
|
||||
{
|
||||
@ -441,35 +413,36 @@ let rec format_statement
|
||||
},
|
||||
_ ) ) as e;
|
||||
} ->
|
||||
Format.fprintf fmt "@,@[<hov 2>%a = %a /*XXX*/;@]"
|
||||
Format.fprintf fmt "@,@[<hov 2>%a = %a;@]"
|
||||
(format_typ ~const:true ctx.decl_ctx (fun fmt ->
|
||||
Format.pp_print_space fmt ();
|
||||
VarName.format fmt (Mark.remove v)))
|
||||
typ (format_expression ctx) e
|
||||
typ (format_expression ctx global_vars) e
|
||||
| SLocalInit { name = v; typ; expr = e } ->
|
||||
(* Handling at the block level guarantees that [e] is supported as initial value *)
|
||||
Format.fprintf fmt "@,@[<hov 2>%a = %a;@]"
|
||||
(format_typ ctx.decl_ctx (fun fmt ->
|
||||
Format.pp_print_space fmt ();
|
||||
VarName.format fmt (Mark.remove v)))
|
||||
typ (format_expression ctx) e
|
||||
typ (format_expression ctx global_vars) e
|
||||
| SLocalDef { name = v; expr = e; _ } ->
|
||||
Format.fprintf fmt "@,@[<hov 2>%a = %a;@]" VarName.format (Mark.remove v)
|
||||
(format_expression ctx) e
|
||||
(format_expression ctx global_vars) e
|
||||
| SFatalError { pos_expr; error } ->
|
||||
Format.fprintf fmt "@,@[<hov 2>catala_error(catala_%s,@ %a);@]"
|
||||
(String.to_snake_case (Runtime.error_to_string error))
|
||||
(format_expression ctx) pos_expr
|
||||
(format_expression ctx global_vars) pos_expr
|
||||
| SIfThenElse { if_expr = ELit (LBool true), _; then_block; _ } ->
|
||||
format_block ctx fmt then_block
|
||||
format_block ctx global_vars fmt then_block
|
||||
| SIfThenElse { if_expr = ELit (LBool false), _; else_block; _ } ->
|
||||
format_block ctx fmt else_block
|
||||
format_block ctx global_vars fmt else_block
|
||||
| SIfThenElse { if_expr = cond; then_block = b1; else_block = b2 } ->
|
||||
Format.fprintf fmt
|
||||
"@,\
|
||||
@[<hv 2>@[<hov 2>if (%a == CATALA_TRUE) {@]%a@;\
|
||||
<1 -2>} else {%a@;\
|
||||
<1 -2>}@]" (format_expression ctx) cond (format_block ctx) b1
|
||||
(format_block ctx) b2
|
||||
<1 -2>}@]" (format_expression ctx global_vars) cond (format_block ctx global_vars) b1
|
||||
(format_block ctx global_vars) b2
|
||||
| SSwitch { switch_var; enum_name = e_name; switch_cases = cases; _ }
|
||||
when EnumName.equal e_name Expr.option_enum ->
|
||||
let cases =
|
||||
@ -491,7 +464,7 @@ let rec format_statement
|
||||
Format.fprintf fmt "@,@[<v 2>if (%a->code == catala_option_some) {"
|
||||
VarName.format switch_var;
|
||||
let pos = Mark.get s in
|
||||
format_block ctx fmt
|
||||
format_block ctx global_vars fmt
|
||||
(Utils.subst_block some_case.payload_var_name
|
||||
(* Not a real catala struct, but will print as <var>->payload *)
|
||||
( EStructFieldAccess
|
||||
@ -503,7 +476,7 @@ let rec format_statement
|
||||
pos )
|
||||
some_case.payload_var_typ pos some_case.case_block);
|
||||
Format.fprintf fmt "@;<1 -2>} else {";
|
||||
format_block ctx fmt none_case.case_block;
|
||||
format_block ctx global_vars fmt none_case.case_block;
|
||||
Format.fprintf fmt "@;<1 -2>}@]"
|
||||
| SSwitch { switch_var; enum_name = e_name; switch_cases = cases; _ } ->
|
||||
Format.fprintf fmt "@,@[<v 2>@[<hov 4>switch (%a->code) {@]" VarName.format
|
||||
@ -523,7 +496,7 @@ let rec format_statement
|
||||
payload_var_typ VarName.format switch_var
|
||||
(* EnumName.format e_name *)
|
||||
EnumConstructor.format cons_name;
|
||||
Format.fprintf fmt "%a@ break;@;<1 -2>}@]" (format_block ctx) case_block)
|
||||
Format.fprintf fmt "%a@ break;@;<1 -2>}@]" (format_block ctx global_vars) case_block)
|
||||
cases
|
||||
(EnumConstructor.Map.bindings
|
||||
(EnumName.Map.find e_name ctx.decl_ctx.ctx_enums));
|
||||
@ -531,73 +504,78 @@ let rec format_statement
|
||||
Format.fprintf fmt "@;<0 -2>}";
|
||||
Format.pp_close_box fmt ()
|
||||
| SReturn e1 ->
|
||||
Format.fprintf fmt "@,@[<hov 2>return %a;@]" (format_expression ctx) e1
|
||||
Format.fprintf fmt "@,@[<hov 2>return %a;@]" (format_expression ctx global_vars) e1
|
||||
| SAssert { pos_expr; expr } ->
|
||||
Format.fprintf fmt
|
||||
"@,\
|
||||
@[<v 2>@[<hov 2>if (%a != CATALA_TRUE) {@]@,\
|
||||
@[<hov 2>catala_error(catala_assertion_failed,@ %a);@]@;\
|
||||
<1 -2>}@]" (format_expression ctx) expr (format_expression ctx) pos_expr
|
||||
<1 -2>}@]" (format_expression ctx global_vars) expr (format_expression ctx global_vars) pos_expr
|
||||
| _ -> .
|
||||
|
||||
and format_block (ctx : ctx) (fmt : Format.formatter) (b : block) : unit =
|
||||
(* C89 doesn't accept initialisations of constructions from non-constants: -
|
||||
for known structures needing malloc, provision the malloc here (turn Decl
|
||||
and format_block (ctx : ctx) (global_vars : VarName.Set.t) (fmt : Format.formatter) (b : block) : unit =
|
||||
(* C89 doesn't accept initialisations of constructions from non-constants:
|
||||
- for known structures needing malloc, provision the malloc here (turn Decl
|
||||
into Init (that will only do the malloc) + def) - for literal constants
|
||||
keep init - otherwise split Init into decl + def *)
|
||||
let find_static_def name =
|
||||
match
|
||||
List.find_opt
|
||||
(function
|
||||
| SLocalDef { name = n; _ }, _ -> Mark.equal VarName.equal n name
|
||||
| _ -> false)
|
||||
b
|
||||
with
|
||||
| Some
|
||||
( SLocalDef
|
||||
{
|
||||
expr = ((EArray _ | EStruct _ | EInj _ | ETuple _), _) as expr;
|
||||
_;
|
||||
},
|
||||
_ ) ->
|
||||
Some expr
|
||||
| _ -> None
|
||||
let requires_malloc = function
|
||||
| (EArray _ | EStruct _ | EInj _ | ETuple _ | ELit _), _ -> true
|
||||
| _ -> false
|
||||
in
|
||||
let revb =
|
||||
List.fold_left
|
||||
(fun acc -> function
|
||||
| (SLocalInit { expr = (ELit _ | EPosLit), _; _ }, _) as st -> st :: acc
|
||||
| ( SLocalInit
|
||||
{
|
||||
name;
|
||||
typ;
|
||||
expr = ((EArray _ | EStruct _ | EInj _ | ETuple _), _) as expr;
|
||||
},
|
||||
m ) ->
|
||||
(* These need malloc and init, split in two since the Init won't
|
||||
actually set them *)
|
||||
(SLocalDef { name; typ; expr }, m)
|
||||
:: (SLocalInit { name; typ; expr }, m)
|
||||
:: acc
|
||||
| (SLocalDecl { name; typ }, m) as decl -> (
|
||||
match find_static_def name with
|
||||
| Some expr -> (SLocalInit { name; typ; expr }, m) :: acc
|
||||
| _ -> decl :: acc)
|
||||
| SLocalInit { name; typ; expr }, m ->
|
||||
(SLocalDef { name; typ; expr }, m)
|
||||
:: (SLocalDecl { name; typ }, m)
|
||||
:: acc
|
||||
| st -> st :: acc)
|
||||
[] b
|
||||
let print_init_malloc fmt v typ =
|
||||
let const, pp_size =
|
||||
match Mark.remove typ with
|
||||
| TArray _ ->
|
||||
false, fun fmt -> Format.pp_print_string fmt "sizeof(catala_array)"
|
||||
| TStruct name ->
|
||||
false, fun fmt -> Format.fprintf fmt "sizeof(%a)" StructName.format name
|
||||
| TEnum name ->
|
||||
true, fun fmt -> Format.fprintf fmt "sizeof(%a)" EnumName.format name
|
||||
| TTuple _ when is_closure_typ typ ->
|
||||
false, fun fmt -> Format.pp_print_string fmt "sizeof(catala_closure)"
|
||||
| TTuple ts ->
|
||||
true, fun fmt -> Format.fprintf fmt "%d * sizeof(void*)" (List.length ts)
|
||||
| _ -> assert false
|
||||
in
|
||||
(* Postfix [const] declares that the pointer is const, but not its contents *)
|
||||
Format.fprintf fmt "@,@[<hov 2>%a =@ catala_malloc(%t)@];"
|
||||
(format_typ ~const ctx.decl_ctx (fun fmt ->
|
||||
Format.fprintf fmt " const@ %a" VarName.format v))
|
||||
typ pp_size
|
||||
in
|
||||
(* C89 requires declarations to be on top of the block *)
|
||||
let decls, others =
|
||||
List.partition
|
||||
(function (SLocalDecl _ | SLocalInit _), _ -> true | _ -> false)
|
||||
revb
|
||||
let rec format_decls defined_vars remaining = function
|
||||
| (SLocalDecl { name; typ }, _) as decl :: r ->
|
||||
let requires_malloc =
|
||||
match typ with
|
||||
| (TArray _ | TStruct _ | TEnum _ | TTuple _), _ ->
|
||||
None <>
|
||||
Utils.find_block (function
|
||||
| SLocalDef { name = n1; expr; _ }, _ | SLocalInit { name = n1; expr; _ }, _ ->
|
||||
Mark.equal VarName.equal name n1 && requires_malloc expr
|
||||
| _ -> false)
|
||||
r
|
||||
| _ -> false
|
||||
in
|
||||
if requires_malloc then print_init_malloc fmt (Mark.remove name) typ
|
||||
else format_statement ctx global_vars fmt decl;
|
||||
format_decls defined_vars remaining r
|
||||
| (SLocalInit { name; typ; expr }, m) as init :: r ->
|
||||
if requires_malloc expr then
|
||||
(print_init_malloc fmt (Mark.remove name) typ;
|
||||
format_decls defined_vars ((SLocalDef { name; typ; expr }, m) :: remaining) r)
|
||||
else if VarName.Set.subset (Utils.get_vars expr) defined_vars then
|
||||
(format_statement ctx global_vars fmt init;
|
||||
format_decls (VarName.Set.add (Mark.remove name) defined_vars) remaining r)
|
||||
else
|
||||
(* The init depends on undefined variables, it can't be moved to the top, so we split it into decl + def *)
|
||||
(format_statement ctx global_vars fmt (SLocalDecl { name; typ }, m);
|
||||
format_decls defined_vars ((SLocalDef { name; typ; expr }, m) :: remaining) r)
|
||||
| stmt :: r -> format_decls defined_vars (stmt :: remaining) r
|
||||
| [] -> List.rev remaining
|
||||
in
|
||||
List.iter (format_statement ctx fmt) (List.rev decls);
|
||||
List.iter (format_statement ctx fmt) (List.rev others)
|
||||
let remaining = format_decls global_vars [] b in
|
||||
List.iter (format_statement ctx global_vars fmt) remaining
|
||||
|
||||
let format_main (fmt : Format.formatter) (p : Ast.program) =
|
||||
Format.fprintf fmt "@,@[<v 2>int main (int argc, char** argv)@;<0 -2>{";
|
||||
@ -655,24 +633,37 @@ let format_program
|
||||
format_ctx type_ordering fmt p.ctx.decl_ctx;
|
||||
Format.pp_print_cut fmt ();
|
||||
let ctx = { decl_ctx = p.ctx.decl_ctx } in
|
||||
Format.pp_print_list
|
||||
(fun fmt code_item ->
|
||||
match code_item with
|
||||
| SVar { var; expr; typ } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a = NULL;@]@,"
|
||||
(format_typ ~const:true p.ctx.decl_ctx (fun fmt ->
|
||||
Format.pp_print_space fmt ();
|
||||
VarName.format fmt var))
|
||||
typ;
|
||||
(* We hide the value below a macro that performs lazy expansion *)
|
||||
Format.fprintf fmt "#define %a (%a ? %a : (%a = %a))@," VarName.format
|
||||
var VarName.format var VarName.format var VarName.format var
|
||||
(format_expression ctx) expr
|
||||
let _global_vars =
|
||||
List.fold_left (fun global_vars code_item ->
|
||||
match code_item with
|
||||
| SVar { var; expr; typ } ->
|
||||
(* Global variables are turned into inline functions without parameters that perform lazy evaluation: {[
|
||||
inline foo_type foo() {
|
||||
static foo_type foo = NULL;
|
||||
return (foo ? foo : foo = foo_init());
|
||||
}
|
||||
]}
|
||||
*)
|
||||
Format.fprintf fmt "@[<v 2>@[<hov 4>inline %a() {@]@,"
|
||||
(format_typ ~const:true p.ctx.decl_ctx (fun fmt ->
|
||||
Format.pp_print_space fmt ();
|
||||
VarName.format fmt var))
|
||||
typ;
|
||||
Format.fprintf fmt "@[<hov 2>static %a = NULL;@]@,"
|
||||
(format_typ ~const:true p.ctx.decl_ctx (fun fmt ->
|
||||
Format.pp_print_space fmt ();
|
||||
VarName.format fmt var))
|
||||
typ;
|
||||
Format.fprintf fmt "@[<hov 2>return (%a ? %a : (%a = %a));"
|
||||
VarName.format var VarName.format var VarName.format var
|
||||
(format_expression ctx global_vars) expr;
|
||||
Format.fprintf fmt "@;<1 -2>}@]@,@,";
|
||||
VarName.Set.add var global_vars
|
||||
| SFunc { var; func }
|
||||
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
|
||||
let { func_params; func_body; func_return_typ } = func in
|
||||
Format.fprintf fmt
|
||||
"@,@[<v 2>@[<hov 4>%a@ @[<hv 1>(%a)@]@]@;<1 -2>{%a@]@,}"
|
||||
"@,@[<v 2>@[<hov 4>%a@ @[<hv 1>(%a)@]@]@;<1 -2>{%a@]@,}@,"
|
||||
(format_typ ~const:true ctx.decl_ctx (fun fmt ->
|
||||
Format.pp_print_space fmt ();
|
||||
FuncName.format fmt var))
|
||||
@ -686,8 +677,10 @@ let format_program
|
||||
VarName.format fmt (Mark.remove var)))
|
||||
fmt typ;
|
||||
Format.pp_close_box fmt ()))
|
||||
func_params (format_block ctx) func_body)
|
||||
fmt p.code_items;
|
||||
func_params (format_block ctx global_vars) func_body;
|
||||
global_vars)
|
||||
VarName.Set.empty p.code_items
|
||||
in
|
||||
Format.pp_print_cut fmt ();
|
||||
format_main fmt p;
|
||||
Format.pp_close_box fmt ()
|
||||
|
@ -21,6 +21,20 @@ module Runtime = Runtime_ocaml.Runtime
|
||||
module D = Dcalc.Ast
|
||||
module L = Lcalc.Ast
|
||||
|
||||
let rec get_vars e =
|
||||
match Mark.remove e with
|
||||
| EVar v -> VarName.Set.singleton v
|
||||
| EFunc _ | ELit _ | EPosLit | EExternal _ -> VarName.Set.empty
|
||||
| EStruct str ->
|
||||
StructField.Map.fold (fun _ e -> VarName.Set.union (get_vars e)) str.fields VarName.Set.empty
|
||||
| EStructFieldAccess { e1; _ } | ETupleAccess { e1; _ } | EInj { e1; _ }->
|
||||
get_vars e1
|
||||
| ETuple el | EArray el | EAppOp { args = el; _ } ->
|
||||
List.fold_left (fun acc e -> VarName.Set.union acc (get_vars e)) VarName.Set.empty el
|
||||
| EApp { f; args } ->
|
||||
List.fold_left (fun acc e -> VarName.Set.union acc (get_vars e))
|
||||
(get_vars f) args
|
||||
|
||||
let rec subst_expr v e within_expr =
|
||||
let m = Mark.get within_expr in
|
||||
match Mark.remove within_expr with
|
||||
@ -86,3 +100,23 @@ and subst_block v e block =
|
||||
let subst_block v expr typ pos block =
|
||||
try subst_block v expr block
|
||||
with Exit -> (SLocalInit { name = v, pos; typ; expr }, pos) :: block
|
||||
|
||||
let rec find_block pred = function
|
||||
| [] -> None
|
||||
| stmt :: _ when pred stmt -> Some stmt
|
||||
| (SIfThenElse { then_block; else_block; _ }, _) :: r ->
|
||||
(match find_block pred then_block with
|
||||
| None ->
|
||||
(match find_block pred else_block with
|
||||
| None -> find_block pred r
|
||||
| some -> some)
|
||||
| some -> some)
|
||||
| (SSwitch { switch_cases; _ }, _) :: r ->
|
||||
(match
|
||||
List.find_map (fun case ->
|
||||
find_block pred case.case_block)
|
||||
switch_cases
|
||||
with
|
||||
| None -> find_block pred r
|
||||
| some -> some)
|
||||
| _ :: r -> find_block pred r
|
||||
|
@ -27,3 +27,8 @@ val subst_block : VarName.t -> expr -> typ -> Pos.t -> block -> block
|
||||
[var] within the given [block]. If not possible (the variable appears in a
|
||||
variable-only position), the block is returned with an initialisation of
|
||||
[var] with [replacement] prepended *)
|
||||
|
||||
val find_block : (stmt Mark.pos -> bool) -> block -> stmt Mark.pos option
|
||||
(** Recurses into branchings, but not function bodies *)
|
||||
|
||||
val get_vars : expr -> VarName.Set.t
|
||||
|
@ -119,8 +119,10 @@ void catala_free(void* ptr, size_t sz)
|
||||
|
||||
/* --- Base types --- */
|
||||
|
||||
const int catala_true = 1;
|
||||
const int catala_false = 0;
|
||||
const int catala_true_value = 1;
|
||||
const int * const catala_true = &catala_true_value;
|
||||
const int catala_false_value = 0;
|
||||
const int * const catala_false = &catala_false_value;
|
||||
const int catala_unitval = 0;
|
||||
|
||||
/* --- Constructors --- */
|
||||
|
@ -79,11 +79,11 @@ typedef struct catala_closure {
|
||||
const CLOSURE_ENV env;
|
||||
} catala_closure;
|
||||
|
||||
extern const int catala_true;
|
||||
#define CATALA_TRUE &catala_true
|
||||
extern const int * const catala_true;
|
||||
#define CATALA_TRUE catala_true
|
||||
|
||||
extern const int catala_false;
|
||||
#define CATALA_FALSE &catala_false
|
||||
extern const int * const catala_false;
|
||||
#define CATALA_FALSE catala_false
|
||||
|
||||
extern const int catala_unitval;
|
||||
#define CATALA_UNITVAL &catala_unitval
|
||||
|
Loading…
Reference in New Issue
Block a user