open Catala_utils
open Definitions
(** [fold_left_lets ~f ~init body] threads an accumulator through the chain of
    [ScopeLet] nodes of [body], from the outermost let to the innermost one.
    Each binder is opened with [Bindlib.unbind] and [f acc scope_let var] gives
    the next accumulator, [var] being the variable bound by that let. The final
    [Result] node is ignored and the accumulator is returned as is. *)
let rec fold_left_lets ~f ~init = function
  | Result _ -> init
  | ScopeLet ({ scope_let_next; _ } as scope_let) ->
    let var, rest = Bindlib.unbind scope_let_next in
    let acc = f init scope_let var in
    fold_left_lets ~f ~init:acc rest
(** [fold_right_lets ~f ~init body] folds over the [ScopeLet] chain of [body]
    starting from its end: [init] is applied to the expression held by the
    final [Result] node, and the value obtained is then combined, innermost let
    first, through [f scope_let var acc], where [var] is the variable bound by
    that let. Not tail-recursive. *)
let rec fold_right_lets ~f ~init = function
  | Result final_expr -> init final_expr
  | ScopeLet ({ scope_let_next; _ } as scope_let) ->
    let var, rest = Bindlib.unbind scope_let_next in
    (* The tail is folded first so that [f] sees its result. *)
    let acc = fold_right_lets ~f ~init rest in
    f scope_let var acc
(** [map_exprs_in_lets ~f ~varf body] rebuilds the let-chain [body], applying
    [f] to the expression of every [ScopeLet] and to the final [Result]
    expression, and [varf] to every let-bound variable before it is re-bound.

    The result is built inside [Bindlib.box] (hence the boxed return type): each
    [ScopeLet] is reassembled from its re-bound continuation and its mapped
    expression, all other fields being copied from the original let. *)
let map_exprs_in_lets :
    f:('expr1 -> 'expr2 boxed) ->
    varf:('expr1 Var.t -> 'expr2 Var.t) ->
    'expr1 scope_body_expr ->
    'expr2 scope_body_expr Bindlib.box =
 fun ~f ~varf scope_body_expr ->
  fold_right_lets
    ~f:(fun scope_let var_next acc ->
      (* [acc] is the already-mapped (boxed) tail; re-bind it under the
         translated variable and pair it with the mapped expression. *)
      Bindlib.box_apply2
        (fun scope_let_next scope_let_expr ->
          ScopeLet { scope_let with scope_let_next; scope_let_expr })
        (Bindlib.bind_var (varf var_next) acc)
        (Expr.Box.lift (f scope_let.scope_let_expr)))
    ~init:(fun res ->
      Bindlib.box_apply (fun res -> Result res) (Expr.Box.lift (f res)))
    scope_body_expr
(** [fold_left ~f ~init items] walks a code-item list from head to tail. For
    each [Cons] cell the binder is opened with [Bindlib.unbind] and
    [f acc item var] yields the next accumulator, [var] being the variable
    bound by that cell. [Nil] returns the accumulator unchanged. *)
let rec fold_left ~f ~init code_items =
  match code_items with
  | Nil -> init
  | Cons (item, binder) ->
    let var, rest = Bindlib.unbind binder in
    let acc = f init item var in
    fold_left ~f ~init:acc rest
(** [fold_right ~f ~init items] folds a code-item list from its tail back to
    its head: [Nil] yields [init], and each [Cons] cell combines the result of
    folding its tail through [f item var acc], [var] being the variable bound
    by that cell. Not tail-recursive. *)
let rec fold_right ~f ~init code_items =
  match code_items with
  | Nil -> init
  | Cons (item, binder) ->
    let var, rest = Bindlib.unbind binder in
    (* Fold the remainder first; [f] consumes its result. *)
    let acc = fold_right ~f ~init rest in
    f item var acc
(** [map ~f ~varf items] rebuilds a code-item list inside [Bindlib.box].

    [f] maps each item to its boxed replacement, and [varf] translates the
    variable bound by each [Cons] cell before the mapped tail is re-bound under
    it. [Nil] is returned boxed so that it can be passed to [Bindlib.bind_var]
    by the enclosing cell. *)
let rec map ~f ~varf = function
  | Nil -> Bindlib.box Nil
  | Cons (item, next_bind) ->
    let item = f item in
    let next_bind =
      let var, next = Bindlib.unbind next_bind in
      Bindlib.bind_var (varf var) (map ~f ~varf next)
    in
    (* Both components are boxed: combine them under the box. *)
    Bindlib.box_apply2
      (fun item next_bind -> Cons (item, next_bind))
      item next_bind
(** [map_ctx ~f ~varf ctx items] is like {!map}, but threads a context from the
    head of the list to its tail: [f ctx item] returns the context to use for
    the following items together with the boxed replacement item. The final
    context is discarded; only the rebuilt (boxed) list is returned. *)
let rec map_ctx ~f ~varf ctx = function
  | Nil -> Bindlib.box Nil
  | Cons (item, next_bind) ->
    let ctx, item = f ctx item in
    let next_bind =
      let var, next = Bindlib.unbind next_bind in
      Bindlib.bind_var (varf var) (map_ctx ~f ~varf ctx next)
    in
    (* Both components are boxed: combine them under the box. *)
    Bindlib.box_apply2
      (fun item next_bind -> Cons (item, next_bind))
      item next_bind
(** [fold_map ~f ~varf ctx items] maps a code-item list while accumulating a
    context, and returns the final context along with the rebuilt boxed list.

    For each [Cons] cell, [f ctx var item] receives the variable bound by the
    cell and returns the updated context and the boxed replacement item; the
    tail is then processed with that updated context and re-bound under
    [varf var]. *)
let rec fold_map ~f ~varf ctx = function
  | Nil -> ctx, Bindlib.box Nil
  | Cons (item, next_bind) ->
    let var, next = Bindlib.unbind next_bind in
    let ctx, item = f ctx var item in
    let ctx, next = fold_map ~f ~varf ctx next in
    let next_bind = Bindlib.bind_var (varf var) next in
    ( ctx,
      (* Both components are boxed: combine them under the box. *)
      Bindlib.box_apply2
        (fun item next_bind -> Cons (item, next_bind))
        item next_bind )
(** [map_exprs ~f ~varf scopes] applies [f] to every expression found in the
    code-item list [scopes] and [varf] to every bound variable, returning the
    rebuilt list inside [Bindlib.box].

    - For a [ScopeDef], the body binder is opened, its let-chain is mapped with
      {!map_exprs_in_lets}, and the result is re-bound under the translated
      input variable; the other fields of the body are kept.
    - For a [Topdef], only the expression is mapped; name and type are kept. *)
let map_exprs ~f ~varf scopes =
  (* This local [f] works on whole items; inside it, [f] still refers to the
     expression-level function received as [~f] (the binding is not
     recursive). *)
  let f = function
    | ScopeDef (name, body) ->
      let scope_input_var, scope_lets = Bindlib.unbind body.scope_body_expr in
      let new_body_expr = map_exprs_in_lets ~f ~varf scope_lets in
      let new_body_expr =
        Bindlib.bind_var (varf scope_input_var) new_body_expr
      in
      Bindlib.box_apply
        (fun scope_body_expr -> ScopeDef (name, { body with scope_body_expr }))
        new_body_expr
    | Topdef (name, typ, expr) ->
      Bindlib.box_apply
        (fun e -> Topdef (name, typ, e))
        (Expr.Box.lift (f expr))
  in
  map ~f ~varf scopes
(* TODO: compute the expected body expr arrow type manually instead of [TAny]
for double-checking types ? *)
(** [get_body_expr_mark body] descends through every [ScopeLet] of [body] down
    to the final [Result] expression, takes the mark of that expression and
    passes it through [Expr.with_ty] with a [TAny] type placed at the mark's
    own position. *)
let rec get_body_expr_mark body_expr =
  match body_expr with
  | Result e ->
    let m = Marked.get_mark e in
    let pos = Expr.mark_pos m in
    Expr.with_ty m (Marked.mark pos TAny)
  | ScopeLet { scope_let_next; _ } ->
    (* Intermediate lets carry no information here: skip to the tail. *)
    let _, rest = Bindlib.unbind scope_let_next in
    get_body_expr_mark rest
(** [get_body_mark scope_body] opens the binder of the scope body and returns
    the mark computed by [get_body_expr_mark] on the let-chain inside it. The
    bound variable is ignored. *)
let get_body_mark { scope_body_expr; _ } =
  let _, lets = Bindlib.unbind scope_body_expr in
  get_body_expr_mark lets
(** [unfold_body_expr ctx body] turns the let-chain [body] into a single boxed
    expression: every [ScopeLet] becomes a let-in built with
    [Expr.make_let_in], binding the let's variable (with its declared type) to
    its expression in the unfolded remainder; the final [Result] expression is
    re-boxed with [Expr.rebox]. [ctx] is only passed along to recursive calls.

    NOTE(review): the position argument of [Expr.make_let_in] is taken from the
    [scope_let_pos] field of the let — field name to be confirmed against the
    record declaration in [Definitions]. *)
let rec unfold_body_expr (ctx : decl_ctx) (scope_let : 'e scope_body_expr) =
  match scope_let with
  | Result e -> Expr.rebox e
  | ScopeLet
      {
        scope_let_kind = _;
        scope_let_typ;
        scope_let_expr;
        scope_let_next;
        scope_let_pos;
      } ->
    let var, next = Bindlib.unbind scope_let_next in
    Expr.make_let_in var scope_let_typ
      (Expr.rebox scope_let_expr)
      (unfold_body_expr ctx next)
      scope_let_pos
(** [build_typ_from_sig _ctx input_struct return_struct pos] builds the arrow
    type going from the struct type [input_struct] (sole argument) to the
    struct type [return_struct]. The arrow and both struct types are marked
    with [pos]. The context argument is unused. *)
let build_typ_from_sig
    (_ctx : decl_ctx)
    (scope_input_struct_name : StructName.t)
    (scope_return_struct_name : StructName.t)
    (pos : Pos.t) : typ =
  let struct_typ name = Marked.mark pos (TStruct name) in
  let arg_typ = struct_typ scope_input_struct_name in
  let ret_typ = struct_typ scope_return_struct_name in
  Marked.mark pos (TArrow ([arg_typ], ret_typ))
(** Designates a scope either by its name ([ScopeName]) or by an expression
    variable ([ScopeVar]). In [unfold], a [ScopeName] is swapped for the
    [ScopeVar] of the matching definition once that definition is reached. *)
type 'e scope_name_or_var = ScopeName of ScopeName.t | ScopeVar of 'e Var.t
(** [to_expr ctx body mark_scope] converts a scope body into a boxed
    one-argument abstraction: the parameter is the variable bound by the body,
    typed as the body's input struct, and the abstraction body is the let-chain
    unfolded by [unfold_body_expr]. The position of [mark_scope] is used both
    for the parameter type and for the abstraction itself. *)
let to_expr (ctx : decl_ctx) (body : 'e scope_body) (mark_scope : 'm mark) :
    'e boxed =
  let pos = Expr.mark_pos mark_scope in
  let input_var, lets = Bindlib.unbind body.scope_body_expr in
  let unfolded = unfold_body_expr ctx lets in
  let input_typ = TStruct body.scope_body_input_struct, pos in
  Expr.make_abs [| input_var |] unfolded [input_typ] pos
(** [format ?debug ctx fmt (name, body)] prints the scope as
    [let <name> = <expr>] on [fmt], where [<expr>] is the scope body converted
    by [to_expr] and printed with [Expr.format ctx ~debug].

    The mark given to [to_expr] is the body mark ([get_body_mark]) with its
    position replaced by the one attached to the scope name and its type left
    untouched.

    NOTE(review): the [Expr.unbox] and [Expr.map_mark] wrappers are restored
    from the shape of the surrounding call (one mark argument expected by
    [to_expr], an unboxed expression expected by the printer) — confirm against
    the [Expr] interface. *)
let format
    ?(debug : bool = false)
    (ctx : decl_ctx)
    (fmt : Format.formatter)
    ((n, s) : ScopeName.t * 'm scope_body) =
  Format.fprintf fmt "@[<hov 2>%a %a =@ %a@]" Print.keyword "let"
    ScopeName.format_t n (Expr.format ctx ~debug)
    (Expr.unbox
       (to_expr ctx s
          (Expr.map_mark
             (fun _ -> Marked.get_mark (ScopeName.get_info n))
             (fun ty -> ty)
             (get_body_mark s))))
(** [unfold ctx s mark main_scope] turns the code-item list [s] into a single
    boxed expression made of nested let-ins, one per item, whose innermost body
    is the variable standing for the main scope (built with [mark]).

    - A [ScopeDef] is bound to the abstraction returned by [to_expr], with the
      arrow type computed by [build_typ_from_sig]. If [main_scope] is a
      [ScopeName] equal to the scope's name, it is replaced by the [ScopeVar]
      of the variable bound by this item for the rest of the traversal.
    - A [Topdef] is bound to its re-boxed expression with its declared type.

    Each let-in is positioned at the position attached to the item's name.

    @raise Failure
      (["should not happen"]) if the end of the list is reached while
      [main_scope] is still a [ScopeName], i.e. no definition with that name
      was encountered. *)
let rec unfold
    (ctx : decl_ctx)
    (s : 'e code_item_list)
    (mark : 'm mark)
    (main_scope : 'expr scope_name_or_var) : 'e boxed =
  match s with
  | Nil -> (
    match main_scope with
    | ScopeVar v -> Expr.make_var v mark
    | ScopeName _ -> failwith "should not happen")
  | Cons (item, next_bind) ->
    let var, next = Bindlib.unbind next_bind in
    let typ, expr, pos, is_main =
      match item with
      | ScopeDef (name, body) ->
        let pos = Marked.get_mark (ScopeName.get_info name) in
        let body_mark = get_body_mark body in
        let is_main =
          match main_scope with
          | ScopeName n -> ScopeName.equal n name
          | ScopeVar _ -> false
        in
        let typ =
          build_typ_from_sig ctx body.scope_body_input_struct
            body.scope_body_output_struct pos
        in
        let expr = to_expr ctx body body_mark in
        typ, expr, pos, is_main
      | Topdef (name, typ, expr) ->
        let pos = Marked.get_mark (TopdefName.get_info name) in
        typ, Expr.rebox expr, pos, false
    in
    (* Once the main scope's definition is found, refer to it by variable. *)
    let main_scope = if is_main then ScopeVar var else main_scope in
    let next = unfold ctx next mark main_scope in
    Expr.make_let_in var typ expr next pos
(** [free_vars_body_expr lets] returns the set of variables occurring free in
    a let-chain: the free variables of each let's expression, plus those of the
    remainder of the chain minus the variable that the let binds. For the final
    [Result], this is [Expr.free_vars] of its expression. *)
let rec free_vars_body_expr = function
  | Result e -> Expr.free_vars e
  | ScopeLet { scope_let_expr; scope_let_next; _ } ->
    let bound, rest = Bindlib.unbind scope_let_next in
    (* [bound] is only in scope in the tail, so remove it from there only. *)
    let rest_fv = Var.Set.remove bound (free_vars_body_expr rest) in
    Var.Set.union (Expr.free_vars scope_let_expr) rest_fv
(** [free_vars_item item] returns the free variables of one code item: for a
    [Topdef], those of its expression; for a [ScopeDef], those of its body
    let-chain minus the variable bound by the body. *)
let free_vars_item item =
  match item with
  | Topdef (_, _, expr) -> Expr.free_vars expr
  | ScopeDef (_, body) ->
    let input_var, lets = Bindlib.unbind body.scope_body_expr in
    free_vars_body_expr lets |> Var.Set.remove input_var
(** [free_vars scopes] returns the free variables of a code-item list: for
    each [Cons] cell, the free variables of its item together with those of
    the tail minus the variable bound by the cell. [Nil] has none. *)
let rec free_vars = function
  | Nil -> Var.Set.empty
  | Cons (item, binder) ->
    let bound, rest = Bindlib.unbind binder in
    let item_fv = free_vars_item item in
    (* [bound] scopes over the tail only, not over [item] itself. *)
    let rest_fv = Var.Set.remove bound (free_vars rest) in
    Var.Set.union rest_fv item_fv