Merge pull request #315 from AltGr/ast-factorisation

Factorise ASTs (between dcalc and lcalc)
This commit is contained in:
Denis Merigoux 2022-08-16 11:42:51 +02:00 committed by GitHub
commit efa7cec4c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 1205 additions and 1138 deletions

View File

@ -1,7 +1,7 @@
# Reformatting commits to be skipped when running 'git blame'
# Use `git config --global blame.ignoreRevsFile .git-blame-ignore-revs` to use it
# Add new reformatting commits at the top
99b6fc33b508c879f669172005b6c359d7d4f596
ba620fca280338139e015e316894a7cf49c450d5
7485c7f2ce726f59f1ec66ddfe1d3f7d640201d8

View File

@ -15,240 +15,23 @@
License for the specific language governing permissions and limitations under
the License. *)
[@@@ocaml.warning "-7-34"]
open Utils
module Runtime = Runtime_ocaml.Runtime
include Astgen
include Astgen_utils
module ScopeName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
type lit = dcalc glit
module StructName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module StructMap : Map.S with type key = StructName.t = Map.Make (StructName)
module EnumName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module EnumMap : Map.S with type key = EnumName.t = Map.Make (EnumName)
type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration
type marked_typ = typ Marked.pos
and typ =
| TLit of typ_lit
| TTuple of marked_typ list * StructName.t option
| TEnum of marked_typ list * EnumName.t
| TArrow of marked_typ * marked_typ
| TArray of marked_typ
| TAny
type date = Runtime.date
type duration = Runtime.duration
type integer = Runtime.integer
type decimal = Runtime.decimal
type money = Runtime.money
type lit =
| LBool of bool
| LEmptyError
| LInt of integer
| LRat of decimal
| LMoney of money
| LUnit
| LDate of date
| LDuration of duration
type op_kind = KInt | KRat | KMoney | KDate | KDuration
type ternop = Fold
type binop =
| And
| Or
| Xor
| Add of op_kind
| Sub of op_kind
| Mult of op_kind
| Div of op_kind
| Lt of op_kind
| Lte of op_kind
| Gt of op_kind
| Gte of op_kind
| Eq
| Neq
| Map
| Concat
| Filter
type log_entry = VarDef of typ | BeginCall | EndCall | PosRecordIfTrueBool
type unop =
| Not
| Minus of op_kind
| Log of log_entry * Utils.Uid.MarkedString.info list
| Length
| IntToRat
| MoneyToRat
| RatToMoney
| GetDay
| GetMonth
| GetYear
| FirstDayOfMonth
| LastDayOfMonth
| RoundMoney
| RoundDecimal
type operator = Ternop of ternop | Binop of binop | Unop of unop
(** Some structures used for type inference *)
module Infer = struct
module Any =
Utils.Uid.Make
(struct
type info = unit
let format_info fmt () = Format.fprintf fmt "any"
end)
()
type unionfind_typ = typ Marked.pos UnionFind.elem
(** We do not reuse {!type: Dcalc.Ast.typ} because we have to include a new
[TAny] variant. Indeed, error terms can have any type and this has to be
captured by the type sytem. *)
and typ =
| TLit of typ_lit
| TArrow of unionfind_typ * unionfind_typ
| TTuple of unionfind_typ list * StructName.t option
| TEnum of unionfind_typ list * EnumName.t
| TArray of unionfind_typ
| TAny of Any.t
let rec typ_to_ast (ty : unionfind_typ) : marked_typ =
let ty, pos = UnionFind.get (UnionFind.find ty) in
match ty with
| TLit l -> TLit l, pos
| TTuple (ts, s) -> TTuple (List.map typ_to_ast ts, s), pos
| TEnum (ts, e) -> TEnum (List.map typ_to_ast ts, e), pos
| TArrow (t1, t2) -> TArrow (typ_to_ast t1, typ_to_ast t2), pos
| TAny _ -> TAny, pos
| TArray t1 -> TArray (typ_to_ast t1), pos
let rec ast_to_typ (ty : marked_typ) : unionfind_typ =
let ty' =
match Marked.unmark ty with
| TLit l -> TLit l
| TArrow (t1, t2) -> TArrow (ast_to_typ t1, ast_to_typ t2)
| TTuple (ts, s) -> TTuple (List.map (fun t -> ast_to_typ t) ts, s)
| TEnum (ts, e) -> TEnum (List.map (fun t -> ast_to_typ t) ts, e)
| TArray t -> TArray (ast_to_typ t)
| TAny -> TAny (Any.fresh ())
in
UnionFind.make (Marked.same_mark_as ty' ty)
end
type untyped = { pos : Pos.t } [@@ocaml.unboxed]
type typed = { pos : Pos.t; ty : marked_typ }
type inferring = { pos : Pos.t; uf : Infer.unionfind_typ }
(** The generic type of AST markings. Using a GADT allows functions to be
polymorphic in the marking, but still do transformations on types when
appropriate *)
type _ mark =
| Untyped : untyped -> untyped mark
| Typed : typed -> typed mark
| Inferring : inferring -> inferring mark
type ('a, 'm) marked = ('a, 'm mark) Marked.t
type 'm marked_expr = ('m expr, 'm) marked
and 'm expr =
| EVar of 'm expr Bindlib.var
| ETuple of 'm marked_expr list * StructName.t option
| ETupleAccess of
'm marked_expr * int * StructName.t option * typ Marked.pos list
| EInj of 'm marked_expr * int * EnumName.t * typ Marked.pos list
| EMatch of 'm marked_expr * 'm marked_expr list * EnumName.t
| EArray of 'm marked_expr list
| ELit of lit
| EAbs of
(('m expr, 'm marked_expr) Bindlib.mbinder[@opaque]) * typ Marked.pos list
| EApp of 'm marked_expr * 'm marked_expr list
| EAssert of 'm marked_expr
| EOp of operator
| EDefault of 'm marked_expr list * 'm marked_expr * 'm marked_expr
| EIfThenElse of 'm marked_expr * 'm marked_expr * 'm marked_expr
| ErrorOnEmpty of 'm marked_expr
type typed_expr = typed marked_expr
type struct_ctx = (StructFieldName.t * typ Marked.pos) list StructMap.t
type enum_ctx = (EnumConstructor.t * typ Marked.pos) list EnumMap.t
type decl_ctx = { ctx_enums : enum_ctx; ctx_structs : struct_ctx }
type 'm binder = ('m expr, 'm marked_expr) Bindlib.binder
type scope_let_kind =
| DestructuringInputStruct
| ScopeVarDefinition
| SubScopeVarDefinition
| CallingSubScope
| DestructuringSubScopeResults
| Assertion
type ('expr, 'm) scope_let = {
scope_let_kind : scope_let_kind;
scope_let_typ : typ Marked.pos;
scope_let_expr : ('expr, 'm) marked;
scope_let_next : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
scope_let_pos : Pos.t;
}
and ('expr, 'm) scope_body_expr =
| Result of ('expr, 'm) marked
| ScopeLet of ('expr, 'm) scope_let
type ('expr, 'm) scope_body = {
scope_body_input_struct : StructName.t;
scope_body_output_struct : StructName.t;
scope_body_expr : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
}
type ('expr, 'm) scope_def = {
scope_name : ScopeName.t;
scope_body : ('expr, 'm) scope_body;
scope_next : ('expr, ('expr, 'm) scopes) Bindlib.binder;
}
and ('expr, 'm) scopes = Nil | ScopeDef of ('expr, 'm) scope_def
type ('expr, 'm) program_generic = {
decl_ctx : decl_ctx;
scopes : ('expr, 'm) scopes;
}
type 'm expr = (dcalc, 'm mark) gexpr
and 'm marked_expr = (dcalc, 'm mark) marked_gexpr
type 'm program = ('m expr, 'm) program_generic
let no_mark (type m) : m mark -> m mark = function
| Untyped _ -> Untyped { pos = Pos.no_pos }
| Typed _ -> Typed { pos = Pos.no_pos; ty = Marked.mark Pos.no_pos TAny }
| Inferring _ ->
Inferring
{
pos = Pos.no_pos;
uf = UnionFind.make Infer.(TAny (Any.fresh ()), Pos.no_pos);
}
let mark_pos (type m) (m : m mark) : Pos.t =
match m with
| Untyped { pos } | Typed { pos; _ } | Inferring { pos; _ } -> pos
match m with Untyped { pos } | Typed { pos; _ } -> pos
let pos (type m) (x : ('a, m) marked) : Pos.t = mark_pos (Marked.get_mark x)
let ty (_, m) : marked_typ = match m with Typed { ty; _ } -> ty
@ -257,78 +40,11 @@ let with_ty (type m) (ty : marked_typ) (x : ('a, m) marked) : ('a, typed) marked
=
Marked.mark
(match Marked.get_mark x with
| Untyped { pos } | Inferring { pos; _ } -> Typed { pos; ty }
| Untyped { pos } -> Typed { pos; ty }
| Typed m -> Typed { m with ty })
(Marked.unmark x)
let evar v mark = Bindlib.box_apply (Marked.mark mark) (Bindlib.box_var v)
let etuple args s mark =
Bindlib.box_apply (fun args -> ETuple (args, s), mark) (Bindlib.box_list args)
let etupleaccess e1 i s typs mark =
Bindlib.box_apply (fun e1 -> ETupleAccess (e1, i, s, typs), mark) e1
let einj e1 i e_name typs mark =
Bindlib.box_apply (fun e1 -> EInj (e1, i, e_name, typs), mark) e1
let ematch arg arms e_name mark =
Bindlib.box_apply2
(fun arg arms -> EMatch (arg, arms, e_name), mark)
arg (Bindlib.box_list arms)
let earray args mark =
Bindlib.box_apply (fun args -> EArray args, mark) (Bindlib.box_list args)
let elit l mark = Bindlib.box (ELit l, mark)
let eabs binder typs mark =
Bindlib.box_apply (fun binder -> EAbs (binder, typs), mark) binder
let eapp e1 args mark =
Bindlib.box_apply2
(fun e1 args -> EApp (e1, args), mark)
e1 (Bindlib.box_list args)
let eassert e1 mark = Bindlib.box_apply (fun e1 -> EAssert e1, mark) e1
let eop op mark = Bindlib.box (EOp op, mark)
let edefault excepts just cons mark =
Bindlib.box_apply3
(fun excepts just cons -> EDefault (excepts, just, cons), mark)
(Bindlib.box_list excepts) just cons
let eifthenelse e1 e2 e3 mark =
Bindlib.box_apply3 (fun e1 e2 e3 -> EIfThenElse (e1, e2, e3), mark) e1 e2 e3
let eerroronempty e1 mark =
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) e1
let translate_var v = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
let map_expr ctx ~f e =
let m = Marked.get_mark e in
match Marked.unmark e with
| EVar v -> evar (translate_var v) m
| EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) m
| EAbs (binder, typs) ->
let vars, body = Bindlib.unmbind binder in
eabs (Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body)) typs m
| ETuple (args, s) -> etuple (List.map (f ctx) args) s m
| ETupleAccess (e1, n, s_name, typs) ->
etupleaccess ((f ctx) e1) n s_name typs m
| EInj (e1, i, e_name, typs) -> einj ((f ctx) e1) i e_name typs m
| EMatch (arg, arms, e_name) ->
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name m
| EArray args -> earray (List.map (f ctx) args) m
| ELit l -> elit l m
| EAssert e1 -> eassert ((f ctx) e1) m
| EOp op -> Bindlib.box (EOp op, m)
| EDefault (excepts, just, cons) ->
edefault (List.map (f ctx) excepts) ((f ctx) just) ((f ctx) cons) m
| EIfThenElse (e1, e2, e3) ->
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m
let map_expr ctx ~f e = Astgen_utils.map_gexpr ctx ~f e
let rec map_expr_top_down ~f e =
map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e)
@ -347,79 +63,7 @@ let box_expr : ('m expr, 'm) box_expr_sig =
let rec id_t () e = map_expr () ~f:id_t e in
id_t () e
let rec fold_left_scope_lets ~f ~init scope_body_expr =
match scope_body_expr with
| Result _ -> init
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
fold_left_scope_lets ~f ~init:(f init scope_let var) next
let rec fold_right_scope_lets ~f ~init scope_body_expr =
match scope_body_expr with
| Result result -> init result
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
let next_result = fold_right_scope_lets ~f ~init next in
f scope_let var next_result
let map_exprs_in_scope_lets ~f ~varf scope_body_expr =
fold_right_scope_lets
~f:(fun scope_let var_next acc ->
Bindlib.box_apply2
(fun scope_let_next scope_let_expr ->
ScopeLet { scope_let with scope_let_next; scope_let_expr })
(Bindlib.bind_var (varf var_next) acc)
(f scope_let.scope_let_expr))
~init:(fun res -> Bindlib.box_apply (fun res -> Result res) (f res))
scope_body_expr
let rec fold_left_scope_defs ~f ~init scopes =
match scopes with
| Nil -> init
| ScopeDef scope_def ->
let var, next = Bindlib.unbind scope_def.scope_next in
fold_left_scope_defs ~f ~init:(f init scope_def var) next
let rec fold_right_scope_defs ~f ~init scopes =
match scopes with
| Nil -> init
| ScopeDef scope_def ->
let var_next, next = Bindlib.unbind scope_def.scope_next in
let result_next = fold_right_scope_defs ~f ~init next in
f scope_def var_next result_next
let map_scope_defs ~f scopes =
fold_right_scope_defs
~f:(fun scope_def var_next acc ->
let new_scope_def = f scope_def in
let new_next = Bindlib.bind_var var_next acc in
Bindlib.box_apply2
(fun new_scope_def new_next ->
ScopeDef { new_scope_def with scope_next = new_next })
new_scope_def new_next)
~init:(Bindlib.box Nil) scopes
let map_exprs_in_scopes ~f ~varf scopes =
fold_right_scope_defs
~f:(fun scope_def var_next acc ->
let scope_input_var, scope_lets =
Bindlib.unbind scope_def.scope_body.scope_body_expr
in
let new_scope_body_expr = map_exprs_in_scope_lets ~f ~varf scope_lets in
let new_scope_body_expr =
Bindlib.bind_var (varf scope_input_var) new_scope_body_expr
in
let new_next = Bindlib.bind_var (varf var_next) acc in
Bindlib.box_apply2
(fun scope_body_expr scope_next ->
ScopeDef
{
scope_def with
scope_body = { scope_def.scope_body with scope_body_expr };
scope_next;
})
new_scope_body_expr new_next)
~init:(Bindlib.box Nil) scopes
open Astgen_utils
let untype_program prg =
{
@ -428,35 +72,17 @@ let untype_program prg =
Bindlib.unbox
(map_exprs_in_scopes
~f:(fun e -> untype_expr e)
~varf:translate_var prg.scopes);
~varf:Var.translate prg.scopes);
}
type 'm var = 'm expr Bindlib.var
type 'm vars = 'm expr Bindlib.mvar
type 'm var = 'm expr Var.t
type 'm vars = 'm expr Var.vars
let new_var s = Bindlib.new_var (fun x -> EVar x) s
module Var = struct
type t = V : 'a expr Bindlib.var -> t
(* We use this trivial GADT to make the 'm parameter disappear under an
existential. It's fine for a use as keys only. (bindlib defines [any_var]
similarly but it's not exported) todo: add [@@ocaml.unboxed] once it's
possible through abstract types *)
let t v = V v
let get (V v) = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
let compare (V x) (V y) = Bindlib.compare_vars x y
let eq (V x) (V y) = Bindlib.eq_vars x y
end
module VarSet = Set.Make (Var)
module VarMap = Map.Make (Var)
let rec free_vars_expr (e : 'm marked_expr) : VarSet.t =
let rec free_vars_expr (e : 'm marked_expr) : 'm expr Var.Set.t =
match Marked.unmark e with
| EVar v -> VarSet.singleton (Var.t v)
| EVar v -> Var.Set.singleton v
| ETuple (es, _) | EArray es ->
es |> List.map free_vars_expr |> List.fold_left VarSet.union VarSet.empty
es |> List.map free_vars_expr |> List.fold_left Var.Set.union Var.Set.empty
| ETupleAccess (e1, _, _, _)
| EAssert e1
| ErrorOnEmpty e1
@ -465,43 +91,43 @@ let rec free_vars_expr (e : 'm marked_expr) : VarSet.t =
| EApp (e1, es) | EMatch (e1, es, _) ->
e1 :: es
|> List.map free_vars_expr
|> List.fold_left VarSet.union VarSet.empty
|> List.fold_left Var.Set.union Var.Set.empty
| EDefault (es, ejust, econs) ->
ejust :: econs :: es
|> List.map free_vars_expr
|> List.fold_left VarSet.union VarSet.empty
| EOp _ | ELit _ -> VarSet.empty
|> List.fold_left Var.Set.union Var.Set.empty
| EOp _ | ELit _ -> Var.Set.empty
| EIfThenElse (e1, e2, e3) ->
[e1; e2; e3]
|> List.map free_vars_expr
|> List.fold_left VarSet.union VarSet.empty
|> List.fold_left Var.Set.union Var.Set.empty
| EAbs (binder, _) ->
let vs, body = Bindlib.unmbind binder in
Array.fold_right VarSet.remove (Array.map Var.t vs) (free_vars_expr body)
Array.fold_right Var.Set.remove vs (free_vars_expr body)
let rec free_vars_scope_body_expr (scope_lets : ('m expr, 'm) scope_body_expr) :
VarSet.t =
'm expr Var.Set.t =
match scope_lets with
| Result e -> free_vars_expr e
| ScopeLet { scope_let_expr = e; scope_let_next = next; _ } ->
let v, body = Bindlib.unbind next in
VarSet.union (free_vars_expr e)
(VarSet.remove (Var.t v) (free_vars_scope_body_expr body))
Var.Set.union (free_vars_expr e)
(Var.Set.remove v (free_vars_scope_body_expr body))
let free_vars_scope_body (scope_body : ('m expr, 'm) scope_body) : VarSet.t =
let free_vars_scope_body (scope_body : ('m expr, 'm) scope_body) :
'm expr Var.Set.t =
let { scope_body_expr = binder; _ } = scope_body in
let v, body = Bindlib.unbind binder in
VarSet.remove (Var.t v) (free_vars_scope_body_expr body)
Var.Set.remove v (free_vars_scope_body_expr body)
let rec free_vars_scopes (scopes : ('m expr, 'm) scopes) : VarSet.t =
let rec free_vars_scopes (scopes : ('m expr, 'm) scopes) : 'm expr Var.Set.t =
match scopes with
| Nil -> VarSet.empty
| Nil -> Var.Set.empty
| ScopeDef { scope_body = body; scope_next = next; _ } ->
let v, next = Bindlib.unbind next in
VarSet.union
(VarSet.remove (Var.t v) (free_vars_scopes next))
Var.Set.union
(Var.Set.remove v (free_vars_scopes next))
(free_vars_scope_body body)
(* type vars = expr Bindlib.mvar *)
let make_var ((x, mark) : ('m expr Bindlib.var, 'm) marked) :
'm marked_expr Bindlib.box =
@ -542,11 +168,6 @@ let map_mark
match m with
| Untyped { pos } -> Untyped { pos = pos_f pos }
| Typed { pos; ty } -> Typed { pos = pos_f pos; ty = ty_f ty }
| Inferring { pos; uf } ->
Inferring
{ pos = pos_f pos; uf = Infer.ast_to_typ (ty_f (Infer.typ_to_ast uf)) }
let resolve_inferring { uf; pos } = { ty = Infer.typ_to_ast uf; pos }
let map_mark2
(type m)
@ -557,13 +178,6 @@ let map_mark2
match m1, m2 with
| Untyped m1, Untyped m2 -> Untyped { pos = pos_f m1.pos m2.pos }
| Typed m1, Typed m2 -> Typed { pos = pos_f m1.pos m2.pos; ty = ty_f m1 m2 }
| Inferring m1, Inferring m2 ->
Inferring
{
pos = pos_f m1.pos m2.pos;
uf =
Infer.ast_to_typ (ty_f (resolve_inferring m1) (resolve_inferring m2));
}
let fold_marks
(type m)
@ -580,17 +194,9 @@ let fold_marks
pos = pos_f (List.map (function Typed { pos; _ } -> pos) ms);
ty = ty_f (List.map (function Typed m -> m) ms);
}
| Inferring _ :: _ ->
Inferring
{
pos = pos_f (List.map (function Inferring { pos; _ } -> pos) ms);
uf =
Infer.ast_to_typ
(ty_f (List.map (function Inferring m -> resolve_inferring m) ms));
}
let empty_thunked_term mark : 'm marked_expr =
let silent = new_var "_" in
let silent = Var.make "_" in
let pos = mark_pos mark in
Bindlib.unbox
(make_abs [| silent |]

View File

@ -18,224 +18,32 @@
(** Abstract syntax tree of the default calculus intermediate representation *)
open Utils
module Runtime = Runtime_ocaml.Runtime
module ScopeName : Uid.Id with type info = Uid.MarkedString.info
module StructName : Uid.Id with type info = Uid.MarkedString.info
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info
module StructMap : Map.S with type key = StructName.t
module EnumName : Uid.Id with type info = Uid.MarkedString.info
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info
module EnumMap : Map.S with type key = EnumName.t
include module type of Astgen
include module type of Astgen_utils
(** Abstract syntax tree for the default calculus *)
type lit = dcalc glit
(** {1 Abstract syntax tree} *)
type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration
type marked_typ = typ Marked.pos
and typ =
| TLit of typ_lit
| TTuple of marked_typ list * StructName.t option
| TEnum of marked_typ list * EnumName.t
| TArrow of marked_typ * marked_typ
| TArray of marked_typ
| TAny
type date = Runtime.date
type duration = Runtime.duration
type lit =
| LBool of bool
| LEmptyError
| LInt of Runtime.integer
| LRat of Runtime.decimal
| LMoney of Runtime.money
| LUnit
| LDate of date
| LDuration of duration
type op_kind =
| KInt
| KRat
| KMoney
| KDate
| KDuration (** All ops don't have a KDate and KDuration. *)
type ternop = Fold
type binop =
| And
| Or
| Xor
| Add of op_kind
| Sub of op_kind
| Mult of op_kind
| Div of op_kind
| Lt of op_kind
| Lte of op_kind
| Gt of op_kind
| Gte of op_kind
| Eq
| Neq
| Map
| Concat
| Filter
type log_entry =
| VarDef of typ
(** During code generation, we need to know the type of the variable being
logged for embedding *)
| BeginCall
| EndCall
| PosRecordIfTrueBool
type unop =
| Not
| Minus of op_kind
| Log of log_entry * Utils.Uid.MarkedString.info list
| Length
| IntToRat
| MoneyToRat
| RatToMoney
| GetDay
| GetMonth
| GetYear
| FirstDayOfMonth
| LastDayOfMonth
| RoundMoney
| RoundDecimal
type operator = Ternop of ternop | Binop of binop | Unop of unop
(** Contains some structures used for type inference *)
module Infer : sig
module Any : Utils.Uid.Id with type info = unit
type unionfind_typ = typ Marked.pos UnionFind.elem
(** We do not reuse {!type: typ} because we have to include a new [TAny]
variant. Indeed, error terms can have any type and this has to be captured
by the type sytem. *)
and typ =
| TLit of typ_lit
| TArrow of unionfind_typ * unionfind_typ
| TTuple of unionfind_typ list * StructName.t option
| TEnum of unionfind_typ list * EnumName.t
| TArray of unionfind_typ
| TAny of Any.t
val typ_to_ast : unionfind_typ -> marked_typ
val ast_to_typ : marked_typ -> unionfind_typ
end
type untyped = { pos : Pos.t } [@@unboxed]
type typed = { pos : Pos.t; ty : marked_typ }
type inferring = { pos : Pos.t; uf : Infer.unionfind_typ }
(** The generic type of AST markings. Using a GADT allows functions to be
polymorphic in the marking, but still do transformations on types when
appropriate *)
type _ mark =
| Untyped : untyped -> untyped mark
| Typed : typed -> typed mark
| Inferring : inferring -> inferring mark
type ('a, 'm) marked = ('a, 'm mark) Marked.t
type 'm marked_expr = ('m expr, 'm) marked
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
library, based on higher-order abstract syntax*)
and 'm expr =
| EVar of 'm expr Bindlib.var
| ETuple of 'm marked_expr list * StructName.t option
(** The [MarkedString.info] is the former struct field name*)
| ETupleAccess of 'm marked_expr * int * StructName.t option * marked_typ list
(** The [MarkedString.info] is the former struct field name *)
| EInj of 'm marked_expr * int * EnumName.t * marked_typ list
(** The [MarkedString.info] is the former enum case name *)
| EMatch of 'm marked_expr * 'm marked_expr list * EnumName.t
(** The [MarkedString.info] is the former enum case name *)
| EArray of 'm marked_expr list
| ELit of lit
| EAbs of
(('m expr, 'm marked_expr) Bindlib.mbinder[@opaque]) * marked_typ list
| EApp of 'm marked_expr * 'm marked_expr list
| EAssert of 'm marked_expr
| EOp of operator
| EDefault of 'm marked_expr list * 'm marked_expr * 'm marked_expr
| EIfThenElse of 'm marked_expr * 'm marked_expr * 'm marked_expr
| ErrorOnEmpty of 'm marked_expr
(** {3 Expression annotations ([Marked.t])} *)
type typed_expr = typed marked_expr
type struct_ctx = (StructFieldName.t * marked_typ) list StructMap.t
type enum_ctx = (EnumConstructor.t * marked_typ) list EnumMap.t
type decl_ctx = { ctx_enums : enum_ctx; ctx_structs : struct_ctx }
type 'm binder = ('m expr, 'm marked_expr) Bindlib.binder
(** This kind annotation signals that the let-binding respects a structural
invariant. These invariants concern the shape of the expression in the
let-binding, and are documented below. *)
type scope_let_kind =
| DestructuringInputStruct (** [let x = input.field]*)
| ScopeVarDefinition (** [let x = error_on_empty e]*)
| SubScopeVarDefinition
(** [let s.x = fun _ -> e] or [let s.x = error_on_empty e] for input-only
subscope variables. *)
| CallingSubScope (** [let result = s ({ x = s.x; y = s.x; ...}) ]*)
| DestructuringSubScopeResults (** [let s.x = result.x ]**)
| Assertion (** [let _ = assert e]*)
type ('expr, 'm) scope_let = {
scope_let_kind : scope_let_kind;
scope_let_typ : marked_typ;
scope_let_expr : ('expr, 'm) marked;
scope_let_next : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
scope_let_pos : Pos.t;
}
(** This type is parametrized by the expression type so it can be reused in
later intermediate representations. *)
(** A scope let-binding has all the information necessary to make a proper
let-binding expression, plus an annotation for the kind of the let-binding
that comes from the compilation of a {!module: Scopelang.Ast} statement. *)
and ('expr, 'm) scope_body_expr =
| Result of ('expr, 'm) marked
| ScopeLet of ('expr, 'm) scope_let
type ('expr, 'm) scope_body = {
scope_body_input_struct : StructName.t;
scope_body_output_struct : StructName.t;
scope_body_expr : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
}
(** Instead of being a single expression, we give a little more ad-hoc structure
to the scope body by decomposing it in an ordered list of let-bindings, and
a result expression that uses the let-binded variables. The first binder is
the argument of type [scope_body_input_struct]. *)
type ('expr, 'm) scope_def = {
scope_name : ScopeName.t;
scope_body : ('expr, 'm) scope_body;
scope_next : ('expr, ('expr, 'm) scopes) Bindlib.binder;
}
(** Finally, we do the same transformation for the whole program for the kinded
lets. This permit us to use bindlib variables for scopes names. *)
and ('expr, 'm) scopes = Nil | ScopeDef of ('expr, 'm) scope_def
type ('expr, 'm) program_generic = {
decl_ctx : decl_ctx;
scopes : ('expr, 'm) scopes;
}
type 'm expr = (dcalc, 'm mark) gexpr
and 'm marked_expr = (dcalc, 'm mark) marked_gexpr
type 'm program = ('m expr, 'm) program_generic
(** {1 Helpers} *)
(** {2 Variables} *)
type 'm var = 'm expr Var.t
type 'm vars = 'm expr Var.vars
val free_vars_expr : 'm marked_expr -> 'm expr Var.Set.t
val free_vars_scope_body_expr :
('m expr, 'm) scope_body_expr -> 'm expr Var.Set.t
val free_vars_scope_body : ('m expr, 'm) scope_body -> 'm expr Var.Set.t
val free_vars_scopes : ('m expr, 'm) scopes -> 'm expr Var.Set.t
val make_var : ('m var, 'm) marked -> 'm marked_expr Bindlib.box
(** {2 Manipulation of marks} *)
val no_mark : 'm mark -> 'm mark
@ -379,101 +187,6 @@ val map_expr_top_down :
val map_expr_marks :
f:('m1 mark -> 'm2 mark) -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box
val fold_left_scope_lets :
f:('a -> ('expr, 'm) scope_let -> 'expr Bindlib.var -> 'a) ->
init:'a ->
('expr, 'm) scope_body_expr ->
'a
(** Usage:
[fold_left_scope_lets ~f:(fun acc scope_let scope_let_var -> ...) ~init scope_lets],
where [scope_let_var] is the variable bound to the scope let in the next
scope lets to be examined. *)
val fold_right_scope_lets :
f:(('expr1, 'm1) scope_let -> 'expr1 Bindlib.var -> 'a -> 'a) ->
init:(('expr1, 'm1) marked -> 'a) ->
('expr1, 'm1) scope_body_expr ->
'a
(** Usage:
[fold_right_scope_lets ~f:(fun scope_let scope_let_var acc -> ...) ~init scope_lets],
where [scope_let_var] is the variable bound to the scope let in the next
scope lets to be examined (which are before in the program order). *)
val map_exprs_in_scope_lets :
f:(('expr1, 'm1) marked -> ('expr2, 'm2) marked Bindlib.box) ->
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
('expr1, 'm1) scope_body_expr ->
('expr2, 'm2) scope_body_expr Bindlib.box
val fold_left_scope_defs :
f:('a -> ('expr1, 'm1) scope_def -> 'expr1 Bindlib.var -> 'a) ->
init:'a ->
('expr1, 'm1) scopes ->
'a
(** Usage:
[fold_left_scope_defs ~f:(fun acc scope_def scope_var -> ...) ~init scope_def],
where [scope_var] is the variable bound to the scope in the next scopes to
be examined. *)
val fold_right_scope_defs :
f:(('expr1, 'm1) scope_def -> 'expr1 Bindlib.var -> 'a -> 'a) ->
init:'a ->
('expr1, 'm1) scopes ->
'a
(** Usage:
[fold_right_scope_defs ~f:(fun scope_def scope_var acc -> ...) ~init scope_def],
where [scope_var] is the variable bound to the scope in the next scopes to
be examined (which are before in the program order). *)
val map_scope_defs :
f:(('expr, 'm) scope_def -> ('expr, 'm) scope_def Bindlib.box) ->
('expr, 'm) scopes ->
('expr, 'm) scopes Bindlib.box
val map_exprs_in_scopes :
f:(('expr1, 'm1) marked -> ('expr2, 'm2) marked Bindlib.box) ->
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
('expr1, 'm1) scopes ->
('expr2, 'm2) scopes Bindlib.box
(** This is the main map visitor for all the expressions inside all the scopes
of the program. *)
(** {2 Variables} *)
type 'm var = 'm expr Bindlib.var
val new_var : string -> 'm var
val translate_var : 'm1 var -> 'm2 var
(** used to convert between e.g. [untyped expr var] into a [typed expr var] *)
module Var : sig
type t
val t : 'm expr Bindlib.var -> t
(** Hides the marking type parameter annotation behind an existential type so
that variables can be stored in non-polymorphic sets and maps *)
val get : t -> 'm expr Bindlib.var
(** Be careful with this, it breaks the type abstraction by casting the
existential type annotation. See [!Bindlib.copy_var] for more detail. *)
val compare : t -> t -> int
val eq : t -> t -> bool
end
module VarMap : Map.S with type key = Var.t
module VarSet : Set.S with type elt = Var.t
val free_vars_expr : 'm marked_expr -> VarSet.t
val free_vars_scope_body_expr : ('m expr, 'm) scope_body_expr -> VarSet.t
val free_vars_scope_body : ('m expr, 'm) scope_body -> VarSet.t
val free_vars_scopes : ('m expr, 'm) scopes -> VarSet.t
(* type vars = expr Bindlib.mvar *)
val make_var : ('m var, 'm) marked -> 'm marked_expr Bindlib.box
(** {2 Boxed term constructors} *)
type ('e, 'm) make_abs_sig =

View File

@ -18,7 +18,7 @@ open Utils
open Ast
type partial_evaluation_ctx = {
var_values : typed marked_expr Ast.VarMap.t;
var_values : (typed expr, typed marked_expr) Var.Map.t;
decl_ctx : decl_ctx;
}
@ -184,7 +184,7 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm marked_expr) :
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, pos) (rec_helper e1)
let optimize_expr (decl_ctx : decl_ctx) (e : 'm marked_expr) =
partial_evaluation { var_values = VarMap.empty; decl_ctx } e
partial_evaluation { var_values = Var.Map.empty; decl_ctx } e
let rec scope_lets_map
(t : 'a -> 'm marked_expr -> 'm marked_expr Bindlib.box)
@ -250,6 +250,6 @@ let program_map
let optimize_program (p : 'm program) : untyped program =
Bindlib.unbox
(program_map partial_evaluation
{ var_values = VarMap.empty; decl_ctx = p.decl_ctx }
{ var_values = Var.Map.empty; decl_ctx = p.decl_ctx }
p)
|> untype_program

View File

@ -18,8 +18,51 @@
inference using the classical W algorithm with union-find unification. *)
open Utils
module A = Ast
open A.Infer
module A = Astgen
module Any =
Utils.Uid.Make
(struct
type info = unit
let format_info fmt () = Format.fprintf fmt "any"
end)
()
type unionfind_typ = typ Marked.pos UnionFind.elem
(** We do not reuse {!type: Dcalc.Ast.typ} because we have to include a new
[TAny] variant. Indeed, error terms can have any type and this has to be
captured by the type sytem. *)
and typ =
| TLit of A.typ_lit
| TArrow of unionfind_typ * unionfind_typ
| TTuple of unionfind_typ list * A.StructName.t option
| TEnum of unionfind_typ list * A.EnumName.t
| TArray of unionfind_typ
| TAny of Any.t
let rec typ_to_ast (ty : unionfind_typ) : A.marked_typ =
let ty, pos = UnionFind.get (UnionFind.find ty) in
match ty with
| TLit l -> TLit l, pos
| TTuple (ts, s) -> TTuple (List.map typ_to_ast ts, s), pos
| TEnum (ts, e) -> TEnum (List.map typ_to_ast ts, e), pos
| TArrow (t1, t2) -> TArrow (typ_to_ast t1, typ_to_ast t2), pos
| TAny _ -> TAny, pos
| TArray t1 -> TArray (typ_to_ast t1), pos
let rec ast_to_typ (ty : A.marked_typ) : unionfind_typ =
let ty' =
match Marked.unmark ty with
| TLit l -> TLit l
| TArrow (t1, t2) -> TArrow (ast_to_typ t1, ast_to_typ t2)
| TTuple (ts, s) -> TTuple (List.map (fun t -> ast_to_typ t) ts, s)
| TEnum (ts, e) -> TEnum (List.map (fun t -> ast_to_typ t) ts, e)
| TArray t -> TArray (ast_to_typ t)
| TAny -> TAny (Any.fresh ())
in
UnionFind.make (Marked.same_mark_as ty' ty)
(** {1 Types and unification} *)
@ -57,14 +100,16 @@ let rec format_typ
exception
Type_error of
A.untyped A.marked_expr
A.any_marked_expr
* typ Marked.pos UnionFind.elem
* typ Marked.pos UnionFind.elem
type mark = { pos : Pos.t; uf : unionfind_typ }
(** Raises an error if unification cannot be performed *)
let rec unify
(ctx : Ast.decl_ctx)
(e : 'm A.marked_expr) (* used for error context *)
(e : ('a, 'm A.mark) Ast.marked_gexpr) (* used for error context *)
(t1 : typ Marked.pos UnionFind.elem)
(t2 : typ Marked.pos UnionFind.elem) : unit =
let unify = unify ctx in
@ -72,9 +117,7 @@ let rec unify
t2; *)
let t1_repr = UnionFind.get (UnionFind.find t1) in
let t2_repr = UnionFind.get (UnionFind.find t2) in
let raise_type_error () =
raise (Type_error (Bindlib.unbox (A.untype_expr e), t1, t2))
in
let raise_type_error () = raise (Type_error (A.AnyExpr e, t1, t2)) in
let repr =
match Marked.unmark t1_repr, Marked.unmark t2_repr with
| TLit tl1, TLit tl2 when tl1 = tl2 -> None
@ -108,6 +151,11 @@ let rec unify
let handle_type_error ctx e t1 t2 =
(* TODO: if we get weird error messages, then it means that we should use the
persistent version of the union-find data structure. *)
let pos =
match e with
| A.AnyExpr e -> (
match Marked.get_mark e with Untyped { pos } | Typed { pos; _ } -> pos)
in
let t1_repr = UnionFind.get (UnionFind.find t1) in
let t2_repr = UnionFind.get (UnionFind.find t2) in
let t1_pos = Marked.get_mark t1_repr in
@ -132,7 +180,7 @@ let handle_type_error ctx e t1 t2 =
( Some
(Format.asprintf
"Error coming from typechecking the following expression:"),
A.pos e );
pos );
Some (Format.asprintf "Type %a coming from expression:" t1_s ()), t1_pos;
Some (Format.asprintf "Type %a coming from expression:" t2_s ()), t2_pos;
]
@ -213,11 +261,10 @@ let op_type (op : A.operator Marked.pos) : typ Marked.pos UnionFind.elem =
(** {1 Double-directed typing} *)
type env = typ Marked.pos UnionFind.elem A.VarMap.t
type 'e env = ('e, typ Marked.pos UnionFind.elem) Var.Map.t
let translate_var v = Bindlib.copy_var v (fun x -> A.EVar x) (Bindlib.name_of v)
let add_pos e ty = Marked.mark (A.pos e) ty
let ty (_, A.Inferring { A.uf; _ }) = uf
let add_pos e ty = Marked.mark (Ast.pos e) ty
let ty (_, { uf; _ }) = uf
let ( let+ ) x f = Bindlib.box_apply f x
let ( and+ ) x1 x2 = Bindlib.box_pair x1 x2
@ -244,24 +291,24 @@ let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e)
(** Infers the most permissive type from an expression *)
let rec typecheck_expr_bottom_up
(ctx : Ast.decl_ctx)
(env : env)
(e : 'm A.marked_expr) : A.inferring A.marked_expr Bindlib.box =
(env : 'm Ast.expr env)
(e : 'm Ast.marked_expr) : (A.dcalc, mark) A.marked_gexpr Bindlib.box =
(* Cli.debug_format "Looking for type of %a" (Print.format_expr ~debug:true
ctx) e; *)
let pos_e = A.pos e in
let mark (e : A.inferring A.expr) uf =
Marked.mark (A.Inferring { A.uf; pos = pos_e }) e
let pos_e = Ast.pos e in
let mark (e : (A.dcalc, mark) A.gexpr) uf =
Marked.mark { uf; pos = pos_e } e
in
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
let mark_with_uf e1 ?pos ty = mark e1 (unionfind_make ?pos ty) in
match Marked.unmark e with
| A.EVar v -> begin
match A.VarMap.find_opt (A.Var.t v) env with
match Var.Map.find_opt v env with
| Some t ->
let+ v' = Bindlib.box_var (translate_var v) in
let+ v' = Bindlib.box_var (Var.translate v) in
mark v' t
| None ->
Errors.raise_spanned_error (A.pos e)
Errors.raise_spanned_error (Ast.pos e)
"Variable %s not found in the current context." (Bindlib.name_of v)
end
| A.ELit (LBool _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TBool)
@ -285,7 +332,7 @@ let rec typecheck_expr_bottom_up
match List.nth_opt utyps n with
| Some t' -> mark (ETupleAccess (e1, n, s, typs)) t'
| None ->
Errors.raise_spanned_error (A.pos e1)
Errors.raise_spanned_error (Marked.get_mark e1).pos
"Expression should have a tuple type with at least %d elements but \
only has %d"
n (List.length typs)
@ -296,7 +343,7 @@ let rec typecheck_expr_bottom_up
match List.nth_opt ts' n with
| Some ts_n -> ts_n
| None ->
Errors.raise_spanned_error (A.pos e)
Errors.raise_spanned_error (Ast.pos e)
"Expression should have a sum type with at least %d cases but only \
has %d"
n (List.length ts')
@ -321,18 +368,16 @@ let rec typecheck_expr_bottom_up
mark (EMatch (e1', es', e_name)) t_ret
| A.EAbs (binder, taus) ->
if Bindlib.mbinder_arity binder <> List.length taus then
Errors.raise_spanned_error (A.pos e)
Errors.raise_spanned_error (Ast.pos e)
"function has %d variables but was supplied %d types"
(Bindlib.mbinder_arity binder)
(List.length taus)
else
let xs, body = Bindlib.unmbind binder in
let xs' = Array.map translate_var xs in
let xstaus = List.mapi (fun i tau -> xs'.(i), ast_to_typ tau) taus in
let xs' = Array.map Var.translate xs in
let xstaus = List.mapi (fun i tau -> xs.(i), ast_to_typ tau) taus in
let env =
List.fold_left
(fun env (x, tau) -> A.VarMap.add (A.Var.t x) tau env)
env xstaus
List.fold_left (fun env (x, tau) -> Var.Map.add x tau env) env xstaus
in
let body' = typecheck_expr_bottom_up ctx env body in
let t_func =
@ -402,27 +447,26 @@ let rec typecheck_expr_bottom_up
(** Checks whether the expression can be typed with the provided type *)
and typecheck_expr_top_down
(ctx : Ast.decl_ctx)
(env : env)
(env : 'm Ast.expr env)
(tau : typ Marked.pos UnionFind.elem)
(e : 'm A.marked_expr) : A.inferring A.marked_expr Bindlib.box =
(e : 'm Ast.marked_expr) : (A.dcalc, mark) A.marked_gexpr Bindlib.box =
(* Cli.debug_format "Propagating type %a for expr %a" (format_typ ctx) tau
(Print.format_expr ctx) e; *)
let pos_e = A.pos e in
let mark e = Marked.mark (A.Inferring { uf = tau; pos = pos_e }) e in
let unify_and_mark (e : A.inferring A.expr) tau' =
let e = Marked.mark (A.Inferring { uf = tau'; pos = pos_e }) e in
unify ctx (Bindlib.unbox (A.untype_expr e)) tau tau';
e
let pos_e = Ast.pos e in
let mark e = Marked.mark { uf = tau; pos = pos_e } e in
let unify_and_mark (e' : (A.dcalc, mark) A.gexpr) tau' =
unify ctx e tau tau';
Marked.mark { uf = tau; pos = pos_e } e'
in
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
match Marked.unmark e with
| A.EVar v -> begin
match A.VarMap.find_opt (A.Var.t v) env with
match Var.Map.find_opt v env with
| Some tau' ->
let+ v' = Bindlib.box_var (translate_var v) in
let+ v' = Bindlib.box_var (Var.translate v) in
unify_and_mark v' tau'
| None ->
Errors.raise_spanned_error (A.pos e)
Errors.raise_spanned_error pos_e
"Variable %s not found in the current context" (Bindlib.name_of v)
end
| A.ELit (LBool _) as e1 ->
@ -465,7 +509,7 @@ and typecheck_expr_top_down
match List.nth_opt ts' n with
| Some ts_n -> ts_n
| None ->
Errors.raise_spanned_error (A.pos e)
Errors.raise_spanned_error (Ast.pos e)
"Expression should have a sum type with at least %d cases but only \
has %d"
n (List.length ts)
@ -496,19 +540,19 @@ and typecheck_expr_top_down
unify_and_mark (EMatch (e1', es', e_name)) t_ret
| A.EAbs (binder, t_args) ->
if Bindlib.mbinder_arity binder <> List.length t_args then
Errors.raise_spanned_error (A.pos e)
Errors.raise_spanned_error (Ast.pos e)
"function has %d variables but was supplied %d types"
(Bindlib.mbinder_arity binder)
(List.length t_args)
else
let xs, body = Bindlib.unmbind binder in
let xs' = Array.map translate_var xs in
let xs' = Array.map Var.translate xs in
let xstaus =
List.map2 (fun x t_arg -> x, ast_to_typ t_arg) (Array.to_list xs) t_args
in
let env =
List.fold_left
(fun env (x, t_arg) -> A.VarMap.add (A.Var.t x) t_arg env)
(fun env (x, t_arg) -> Var.Map.add x t_arg env)
env xstaus
in
let body' = typecheck_expr_bottom_up ctx env body in
@ -577,30 +621,28 @@ let wrap ctx f e =
(** {1 API} *)
let get_ty_mark (A.Inferring { uf; pos }) =
A.Typed { ty = A.Infer.typ_to_ast uf; pos }
let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos }
(* Infer the type of an expression *)
let infer_types (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) :
let infer_types (ctx : Ast.decl_ctx) (e : 'm Ast.marked_expr) :
Ast.typed Ast.marked_expr Bindlib.box =
A.map_expr_marks ~f:get_ty_mark
Astgen_utils.map_gexpr_marks ~f:get_ty_mark
@@ Bindlib.unbox
@@ wrap ctx (typecheck_expr_bottom_up ctx A.VarMap.empty) e
@@ wrap ctx (typecheck_expr_bottom_up ctx Var.Map.empty) e
let infer_type (type m) ctx (e : m A.marked_expr) =
let infer_type (type m) ctx (e : m Ast.marked_expr) =
match Marked.get_mark e with
| A.Typed { ty; _ } -> ty
| A.Inferring { uf; _ } -> typ_to_ast uf
| A.Untyped _ -> A.ty (Bindlib.unbox (infer_types ctx e))
| A.Untyped _ -> Ast.ty (Bindlib.unbox (infer_types ctx e))
(** Typechecks an expression given an expected type *)
let check_type
(ctx : Ast.decl_ctx)
(e : 'm A.marked_expr)
(e : 'm Ast.marked_expr)
(tau : A.typ Marked.pos) =
(* todo: consider using the already inferred type if ['m] = [typed] *)
ignore
@@ wrap ctx (typecheck_expr_top_down ctx A.VarMap.empty (ast_to_typ tau)) e
@@ wrap ctx (typecheck_expr_top_down ctx Var.Map.empty (ast_to_typ tau)) e
let infer_types_program prg =
let ctx = prg.A.decl_ctx in
@ -633,32 +675,34 @@ let infer_types_program prg =
| A.Result e ->
let e' = typecheck_expr_bottom_up ctx env e in
Bindlib.box_apply
(fun e ->
unify ctx e (ty e) ty_out;
A.Result e)
(fun e1 ->
unify ctx e (ty e1) ty_out;
let e1 = Astgen_utils.map_gexpr_marks ~f:get_ty_mark e1 in
A.Result (Bindlib.unbox e1))
e'
| A.ScopeLet
{
scope_let_kind;
scope_let_typ;
scope_let_expr = e;
scope_let_expr = e0;
scope_let_next;
scope_let_pos;
} ->
let ty_e = ast_to_typ scope_let_typ in
let e = typecheck_expr_bottom_up ctx env e in
let e = typecheck_expr_bottom_up ctx env e0 in
let var, next = Bindlib.unbind scope_let_next in
let env = A.VarMap.add (A.Var.t var) ty_e env in
let env = Var.Map.add var ty_e env in
let next = process_scope_body_expr env next in
let scope_let_next = Bindlib.bind_var (translate_var var) next in
let scope_let_next = Bindlib.bind_var (Var.translate var) next in
Bindlib.box_apply2
(fun scope_let_expr scope_let_next ->
unify ctx scope_let_expr (ty scope_let_expr) ty_e;
(fun e scope_let_next ->
unify ctx e0 (ty e) ty_e;
let e = Astgen_utils.map_gexpr_marks ~f:get_ty_mark e in
A.ScopeLet
{
scope_let_kind;
scope_let_typ;
scope_let_expr;
scope_let_expr = Bindlib.unbox e;
scope_let_next;
scope_let_pos;
})
@ -666,26 +710,15 @@ let infer_types_program prg =
in
let scope_body_expr =
let var, e = Bindlib.unbind body in
let env = A.VarMap.add (A.Var.t var) ty_in env in
let env = Var.Map.add var ty_in env in
let e' = process_scope_body_expr env e in
let e' =
Bindlib.box_apply
(fun e ->
Bindlib.unbox
@@ A.map_exprs_in_scope_lets ~varf:translate_var
~f:
(A.map_expr_top_down ~f:(fun e ->
Marked.(mark (get_ty_mark (get_mark e)) (unmark e))))
e)
e'
in
Bindlib.bind_var (translate_var var) e'
Bindlib.bind_var (Var.translate var) e'
in
let scope_next =
let scope_var, next = Bindlib.unbind scope_next in
let env = A.VarMap.add (A.Var.t scope_var) ty_scope env in
let env = Var.Map.add scope_var ty_scope env in
let next' = process_scopes env next in
Bindlib.bind_var (translate_var scope_var) next'
Bindlib.bind_var (Var.translate scope_var) next'
in
Bindlib.box_apply2
(fun scope_body_expr scope_next ->
@ -702,6 +735,6 @@ let infer_types_program prg =
})
scope_body_expr scope_next
in
let scopes = wrap ctx (process_scopes A.VarMap.empty) prg.scopes in
let scopes = wrap ctx (process_scopes Var.Map.empty) prg.scopes in
Bindlib.box_apply (fun scopes -> { A.decl_ctx = ctx; scopes }) scopes
|> Bindlib.unbox

View File

@ -15,45 +15,17 @@
the License. *)
open Utils
module Runtime = Runtime_ocaml.Runtime
include Astgen
module D = Dcalc.Ast
type lit =
| LBool of bool
| LInt of Runtime.integer
| LRat of Runtime.decimal
| LMoney of Runtime.money
| LUnit
| LDate of Runtime.date
| LDuration of Runtime.duration
type lit = lcalc glit
type except = ConflictError | EmptyError | NoValueProvided | Crash
type 'm mark = 'm D.mark
type 'm marked_expr = ('m expr, 'm) D.marked
and 'm expr =
| EVar of 'm expr Bindlib.var
| ETuple of 'm marked_expr list * D.StructName.t option
(** The [MarkedString.info] is the former struct field name*)
| ETupleAccess of
'm marked_expr * int * D.StructName.t option * D.typ Marked.pos list
(** The [MarkedString.info] is the former struct field name *)
| EInj of 'm marked_expr * int * D.EnumName.t * D.typ Marked.pos list
(** The [MarkedString.info] is the former enum case name *)
| EMatch of 'm marked_expr * 'm marked_expr list * D.EnumName.t
(** The [MarkedString.info] is the former enum case name *)
| EArray of 'm marked_expr list
| ELit of lit
| EAbs of ('m expr, 'm marked_expr) Bindlib.mbinder * D.typ Marked.pos list
| EApp of 'm marked_expr * 'm marked_expr list
| EAssert of 'm marked_expr
| EOp of D.operator
| EIfThenElse of 'm marked_expr * 'm marked_expr * 'm marked_expr
| ERaise of except
| ECatch of 'm marked_expr * except * 'm marked_expr
type 'm expr = (lcalc, 'm mark) gexpr
and 'm marked_expr = (lcalc, 'm mark) marked_gexpr
type 'm program = ('m expr, 'm) Dcalc.Ast.program_generic
type 'm var = 'm expr Var.t
type 'm vars = 'm expr Var.vars
(* <copy-paste from dcalc/ast.ml> *)
@ -92,23 +64,6 @@ let eop op mark = Bindlib.box (EOp op, mark)
let eifthenelse e1 e2 e3 pos =
Bindlib.box_apply3 (fun e1 e2 e3 -> EIfThenElse (e1, e2, e3), pos) e1 e2 e3
type 'm var = 'm expr Bindlib.var
type 'm vars = 'm expr Bindlib.mvar
let new_var s = Bindlib.new_var (fun x -> EVar x) s
module Var = struct
type t = V : 'a var -> t
(* See Dcalc.Ast.var *)
let t v = V v
let get (V v) = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
let compare (V x) (V y) = Bindlib.compare_vars x y
end
module VarSet = Set.Make (Var)
module VarMap = Map.Make (Var)
(* </copy-paste> *)
let eraise e1 pos = Bindlib.box (ERaise e1, pos)
@ -116,32 +71,7 @@ let eraise e1 pos = Bindlib.box (ERaise e1, pos)
let ecatch e1 exn e2 pos =
Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2
let translate_var v = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
let map_expr ctx ~f e =
let m = Marked.get_mark e in
match Marked.unmark e with
| EVar v -> evar (translate_var v) (Marked.get_mark e)
| EApp (e1, args) ->
eapp (f ctx e1) (List.map (f ctx) args) (Marked.get_mark e)
| EAbs (binder, typs) ->
let vars, body = Bindlib.unmbind binder in
eabs (Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body)) typs m
| ETuple (args, s) -> etuple (List.map (f ctx) args) s (Marked.get_mark e)
| ETupleAccess (e1, n, s_name, typs) ->
etupleaccess ((f ctx) e1) n s_name typs (Marked.get_mark e)
| EInj (e1, i, e_name, typs) ->
einj ((f ctx) e1) i e_name typs (Marked.get_mark e)
| EMatch (arg, arms, e_name) ->
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name (Marked.get_mark e)
| EArray args -> earray (List.map (f ctx) args) (Marked.get_mark e)
| ELit l -> elit l (Marked.get_mark e)
| EAssert e1 -> eassert ((f ctx) e1) (Marked.get_mark e)
| EOp op -> Bindlib.box (EOp op, Marked.get_mark e)
| ERaise exn -> eraise exn (Marked.get_mark e)
| EIfThenElse (e1, e2, e3) ->
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) (Marked.get_mark e)
| ECatch (e1, exn, e2) -> ecatch (f ctx e1) exn (f ctx e2) (Marked.get_mark e)
let map_expr ctx ~f e = Astgen_utils.map_gexpr ctx ~f e
let rec map_expr_top_down ~f e =
map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e)
@ -159,7 +89,7 @@ let untype_program prg =
Bindlib.unbox
(D.map_exprs_in_scopes
~f:(fun e -> untype_expr e)
~varf:translate_var prg.D.scopes);
~varf:Var.translate prg.D.scopes);
}
(** See [Bindlib.box_term] documentation for why we are doing that. *)
@ -254,13 +184,13 @@ let make_matchopt_with_abs_arms arg e_none e_some =
e_some, permitting it to be used inside the expression. There is no
requirements on the form of both e_some and e_none. *)
let make_matchopt m v tau arg e_none e_some =
let x = new_var "_" in
let x = Var.make "_" in
make_matchopt_with_abs_arms arg
(make_abs (Array.of_list [x]) e_none [D.TLit D.TUnit, D.mark_pos m] m)
(make_abs (Array.of_list [v]) e_some [tau] m)
let handle_default = Var.t (new_var "handle_default")
let handle_default_opt = Var.t (new_var "handle_default_opt")
let handle_default = Var.make "handle_default"
let handle_default_opt = Var.make "handle_default_opt"
type 'm binder = ('m expr, 'm marked_expr) Bindlib.binder

View File

@ -15,79 +15,23 @@
the License. *)
open Utils
module Runtime = Runtime_ocaml.Runtime
include module type of Astgen
(** Abstract syntax tree for the lambda calculus *)
(** {1 Abstract syntax tree} *)
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
library, based on higher-order abstract syntax*)
type lit = lcalc glit
type lit =
| LBool of bool
| LInt of Runtime.integer
| LRat of Runtime.decimal
| LMoney of Runtime.money
| LUnit
| LDate of Runtime.date
| LDuration of Runtime.duration
type except = ConflictError | EmptyError | NoValueProvided | Crash
type 'm mark = 'm Dcalc.Ast.mark
type 'm marked_expr = ('m expr, 'm) Dcalc.Ast.marked
and 'm expr =
| EVar of 'm expr Bindlib.var
| ETuple of 'm marked_expr list * Dcalc.Ast.StructName.t option
(** The [MarkedString.info] is the former struct field name*)
| ETupleAccess of
'm marked_expr
* int
* Dcalc.Ast.StructName.t option
* Dcalc.Ast.typ Marked.pos list
(** The [MarkedString.info] is the former struct field name *)
| EInj of
'm marked_expr
* int
* Dcalc.Ast.EnumName.t
* Dcalc.Ast.typ Marked.pos list
(** The [MarkedString.info] is the former enum case name *)
| EMatch of 'm marked_expr * 'm marked_expr list * Dcalc.Ast.EnumName.t
(** The [MarkedString.info] is the former enum case name *)
| EArray of 'm marked_expr list
| ELit of lit
| EAbs of
('m expr, 'm marked_expr) Bindlib.mbinder * Dcalc.Ast.typ Marked.pos list
| EApp of 'm marked_expr * 'm marked_expr list
| EAssert of 'm marked_expr
| EOp of Dcalc.Ast.operator
| EIfThenElse of 'm marked_expr * 'm marked_expr * 'm marked_expr
| ERaise of except
| ECatch of 'm marked_expr * except * 'm marked_expr
type 'm expr = (lcalc, 'm mark) gexpr
and 'm marked_expr = (lcalc, 'm mark) marked_gexpr
type 'm program = ('m expr, 'm) Dcalc.Ast.program_generic
(** {1 Variable helpers} *)
type 'm var = 'm expr Bindlib.var
type 'm vars = 'm expr Bindlib.mvar
module Var : sig
type t
val t : 'm expr Bindlib.var -> t
val get : t -> 'm expr Bindlib.var
val compare : t -> t -> int
end
module VarMap : Map.S with type key = Var.t
module VarSet : Set.S with type elt = Var.t
val new_var : string -> 'm var
type 'm binder = ('m expr, 'm marked_expr) Bindlib.binder
type 'm var = 'm expr Var.t
type 'm vars = 'm expr Var.vars
(** {2 Program traversal} *)
@ -246,5 +190,5 @@ val box_expr : 'm marked_expr -> 'm marked_expr Bindlib.box
(** {1 Special symbols} *)
val handle_default : Var.t
val handle_default_opt : Var.t
val handle_default : untyped var
val handle_default_opt : untyped var

View File

@ -21,12 +21,12 @@ module D = Dcalc.Ast
(** TODO: This version is not yet debugged and ought to be specialized when
Lcalc has more structure. *)
type ctx = { name_context : string; globally_bound_vars : VarSet.t }
type 'm ctx = { name_context : string; globally_bound_vars : 'm expr Var.Set.t }
(** Returns the expression with closed closures and the set of free variables
inside this new expression. Implementation guided by
http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=9. *)
let closure_conversion_expr (type m) (ctx : ctx) (e : m marked_expr) :
let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) :
m marked_expr Bindlib.box =
let module MVarSet = Set.Make (struct
type t = m var
@ -39,7 +39,7 @@ let closure_conversion_expr (type m) (ctx : ctx) (e : m marked_expr) :
( Bindlib.box_apply
(fun new_v -> new_v, Marked.get_mark e)
(Bindlib.box_var v),
if VarSet.mem (Var.t v) ctx.globally_bound_vars then MVarSet.empty
if Var.Set.mem v ctx.globally_bound_vars then MVarSet.empty
else MVarSet.singleton v )
| ETuple (args, s) ->
let new_args, free_vars =
@ -138,9 +138,9 @@ let closure_conversion_expr (type m) (ctx : ctx) (e : m marked_expr) :
in
let extra_vars_list = MVarSet.elements extra_vars in
(* x1, ..., xn *)
let code_var = new_var ctx.name_context in
let code_var = Var.make ctx.name_context in
(* code *)
let inner_c_var = new_var "env" in
let inner_c_var = Var.make "env" in
let any_ty = Dcalc.Ast.TAny, binder_pos in
let new_closure_body =
make_multiple_let_in
@ -200,8 +200,7 @@ let closure_conversion_expr (type m) (ctx : ctx) (e : m marked_expr) :
(fun new_e2 -> EApp ((EOp op, pos_op), new_e2), Marked.get_mark e)
(Bindlib.box_list new_args),
free_vars )
| EApp ((EVar v, v_pos), args)
when VarSet.mem (Var.t v) ctx.globally_bound_vars ->
| EApp ((EVar v, v_pos), args) when Var.Set.mem v ctx.globally_bound_vars ->
(* This corresponds to a scope call, which we don't want to transform*)
let new_args, free_vars =
List.fold_right
@ -217,8 +216,8 @@ let closure_conversion_expr (type m) (ctx : ctx) (e : m marked_expr) :
free_vars )
| EApp (e1, args) ->
let new_e1, free_vars = aux e1 in
let env_var = new_var "env" in
let code_var = new_var "code" in
let env_var = Var.make "env" in
let code_var = Var.make "code" in
let new_args, free_vars =
List.fold_right
(fun arg (new_args, free_vars) ->
@ -286,7 +285,7 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
let scope_input_var, scope_body_expr =
Bindlib.unbind scope.scope_body.scope_body_expr
in
let global_vars = VarSet.add (Var.t scope_var) global_vars in
let global_vars = Var.Set.add scope_var global_vars in
let ctx =
{
name_context =
@ -320,7 +319,10 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
new_scope_body_expr
(Bindlib.bind_var scope_var next))),
global_vars ))
~init:(Fun.id, VarSet.of_list [handle_default; handle_default_opt])
~init:
( Fun.id,
Var.Set.of_list
(List.map Var.translate [handle_default; handle_default_opt]) )
p.scopes
in
Bindlib.box_apply

View File

@ -18,7 +18,7 @@ open Utils
module D = Dcalc.Ast
module A = Ast
type 'm ctx = 'm A.var D.VarMap.t
type 'm ctx = ('m D.expr, 'm A.expr Var.t) Var.Map.t
(** This environment contains a mapping between the variables in Dcalc and their
correspondance in Lcalc. *)
@ -35,7 +35,7 @@ let translate_lit (l : D.lit) : 'm A.expr =
let thunk_expr (e : 'm A.marked_expr Bindlib.box) (mark : 'm A.mark) :
'm A.marked_expr Bindlib.box =
let dummy_var = A.new_var "_" in
let dummy_var = Var.make "_" in
A.make_abs [| dummy_var |] e [D.TAny, D.mark_pos mark] mark
let rec translate_default
@ -51,7 +51,7 @@ let rec translate_default
in
let exceptions =
A.make_app
(A.make_var (A.Var.get A.handle_default, mark_default))
(A.make_var (Var.translate A.handle_default, mark_default))
[
A.earray exceptions mark_default;
thunk_expr (translate_expr ctx just) mark_default;
@ -64,7 +64,7 @@ let rec translate_default
and translate_expr (ctx : 'm ctx) (e : 'm D.marked_expr) :
'm A.marked_expr Bindlib.box =
match Marked.unmark e with
| D.EVar v -> A.make_var (D.VarMap.find (D.Var.t v) ctx, Marked.get_mark e)
| D.EVar v -> A.make_var (Var.Map.find v ctx, Marked.get_mark e)
| D.ETuple (args, s) ->
A.etuple (List.map (translate_expr ctx) args) s (Marked.get_mark e)
| D.ETupleAccess (e1, i, s, ts) ->
@ -96,8 +96,8 @@ and translate_expr (ctx : 'm ctx) (e : 'm D.marked_expr) :
let ctx, lc_vars =
Array.fold_right
(fun var (ctx, lc_vars) ->
let lc_var = A.new_var (Bindlib.name_of var) in
D.VarMap.add (D.Var.t var) lc_var ctx, lc_var :: lc_vars)
let lc_var = Var.make (Bindlib.name_of var) in
Var.Map.add var lc_var ctx, lc_var :: lc_vars)
vars (ctx, [])
in
let lc_vars = Array.of_list lc_vars in
@ -126,11 +126,9 @@ let rec translate_scope_lets
let old_scope_let_var, scope_let_next =
Bindlib.unbind scope_let.scope_let_next
in
let new_scope_let_var = A.new_var (Bindlib.name_of old_scope_let_var) in
let new_scope_let_var = Var.make (Bindlib.name_of old_scope_let_var) in
let new_scope_let_expr = translate_expr ctx scope_let.scope_let_expr in
let new_ctx =
D.VarMap.add (D.Var.t old_scope_let_var) new_scope_let_var ctx
in
let new_ctx = Var.Map.add old_scope_let_var new_scope_let_var ctx in
let new_scope_next = translate_scope_lets decl_ctx new_ctx scope_let_next in
let new_scope_next = Bindlib.bind_var new_scope_let_var new_scope_next in
Bindlib.box_apply2
@ -154,15 +152,13 @@ let rec translate_scopes
| ScopeDef scope_def ->
let old_scope_var, scope_next = Bindlib.unbind scope_def.scope_next in
let new_scope_var =
A.new_var (Marked.unmark (D.ScopeName.get_info scope_def.scope_name))
Var.make (Marked.unmark (D.ScopeName.get_info scope_def.scope_name))
in
let old_scope_input_var, scope_body_expr =
Bindlib.unbind scope_def.scope_body.scope_body_expr
in
let new_scope_input_var = A.new_var (Bindlib.name_of old_scope_input_var) in
let new_ctx =
D.VarMap.add (D.Var.t old_scope_input_var) new_scope_input_var ctx
in
let new_scope_input_var = Var.make (Bindlib.name_of old_scope_input_var) in
let new_ctx = Var.Map.add old_scope_input_var new_scope_input_var ctx in
let new_scope_body_expr =
translate_scope_lets decl_ctx new_ctx scope_body_expr
in
@ -181,7 +177,7 @@ let rec translate_scopes
})
new_scope_body_expr
in
let new_ctx = D.VarMap.add (D.Var.t old_scope_var) new_scope_var new_ctx in
let new_ctx = Var.Map.add old_scope_var new_scope_var new_ctx in
let scope_next =
Bindlib.bind_var new_scope_var
(translate_scopes decl_ctx new_ctx scope_next)
@ -199,6 +195,6 @@ let rec translate_scopes
let translate_program (prgm : 'm D.program) : 'm A.program =
{
scopes =
Bindlib.unbox (translate_scopes prgm.decl_ctx D.VarMap.empty prgm.scopes);
Bindlib.unbox (translate_scopes prgm.decl_ctx Var.Map.empty prgm.scopes);
decl_ctx = prgm.decl_ctx;
}

View File

@ -40,7 +40,7 @@ module A = Ast
hoisted and later handled by the [translate_expr] function. Every other
cases is found in the translate_and_hoist function. *)
type 'm hoists = 'm D.marked_expr A.VarMap.t
type 'm hoists = ('m A.expr, 'm D.marked_expr) Var.Map.t
(** Hoists definition. It represent bindings between [A.Var.t] and [D.expr]. *)
type 'm info = {
@ -60,14 +60,13 @@ let pp_info (fmt : Format.formatter) (info : 'm info) =
type 'm ctx = {
decl_ctx : D.decl_ctx;
vars : 'm info D.VarMap.t;
vars : ('m D.expr, 'm info) Var.Map.t;
(** information context about variables in the current scope *)
}
let _pp_ctx (fmt : Format.formatter) (ctx : 'm ctx) =
let pp_binding (fmt : Format.formatter) ((v, info) : D.Var.t * 'm info) =
Format.fprintf fmt "%a: %a" Dcalc.Print.format_var (D.Var.get v) pp_info
info
let pp_binding (fmt : Format.formatter) ((v, info) : 'm D.var * 'm info) =
Format.fprintf fmt "%a: %a" Dcalc.Print.format_var v pp_info info
in
let pp_bindings =
@ -76,14 +75,14 @@ let _pp_ctx (fmt : Format.formatter) (ctx : 'm ctx) =
pp_binding
in
Format.fprintf fmt "@[<2>[%a]@]" pp_bindings (D.VarMap.bindings ctx.vars)
Format.fprintf fmt "@[<2>[%a]@]" pp_bindings (Var.Map.bindings ctx.vars)
(** [find ~info n ctx] is a warpper to ocaml's Map.find that handle errors in a
slightly better way. *)
let find ?(info : string = "none") (n : 'm D.var) (ctx : 'm ctx) : 'm info =
(* let _ = Format.asprintf "Searching for variable %a inside context %a"
Dcalc.Print.format_var n pp_ctx ctx |> Cli.debug_print in *)
try D.VarMap.find (D.Var.t n) ctx.vars
try Var.Map.find n ctx.vars
with Not_found ->
Errors.raise_spanned_error Pos.no_pos
"Internal Error: Variable %a was not found in the current environment. \
@ -96,7 +95,7 @@ let find ?(info : string = "none") (n : 'm D.var) (ctx : 'm ctx) : 'm info =
debuging purposes as it printing each of the Dcalc/Lcalc variable pairs. *)
let add_var (mark : 'm D.mark) (var : 'm D.var) (is_pure : bool) (ctx : 'm ctx)
: 'm ctx =
let new_var = A.new_var (Bindlib.name_of var) in
let new_var = Var.make (Bindlib.name_of var) in
let expr = A.make_var (new_var, mark) in
(* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Dcalc.Print.format_var
@ -104,7 +103,7 @@ let add_var (mark : 'm D.mark) (var : 'm D.var) (is_pure : bool) (ctx : 'm ctx)
{
ctx with
vars =
D.VarMap.update (D.Var.t var)
Var.Map.update var
(fun _ -> Some { expr; var = new_var; is_pure })
ctx.vars;
}
@ -147,16 +146,16 @@ let translate_lit (l : D.lit) (pos : Pos.t) : A.lit =
(** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps.
Raises an internal error if there is two identicals keys in differnts parts. *)
let disjoint_union_maps (pos : Pos.t) (cs : 'a A.VarMap.t list) : 'a A.VarMap.t
=
let disjoint_union_maps (pos : Pos.t) (cs : ('e, 'a) Var.Map.t list) :
('e, 'a) Var.Map.t =
let disjoint_union =
A.VarMap.union (fun _ _ _ ->
Var.Map.union (fun _ _ _ ->
Errors.raise_spanned_error pos
"Internal Error: Two supposed to be disjoints maps have one shared \
key.")
in
List.fold_left disjoint_union A.VarMap.empty cs
List.fold_left disjoint_union Var.Map.empty cs
(** [e' = translate_and_hoist ctx e ] Translate the Dcalc expression e into an
expression in Lcalc, given we translate each hoists correctly. It ensures
@ -176,34 +175,34 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
assumption can change in the future, and this case is here for this
reason. *)
if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = A.new_var (Bindlib.name_of v) in
let v' = Var.make (Bindlib.name_of v) in
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
created a variable %a to replace it" Dcalc.Print.format_var v
Print.format_var v'; *)
A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') e
else (find ~info:"should never happend" v ctx).expr, A.VarMap.empty
A.make_var (v', pos), Var.Map.singleton v' e
else (find ~info:"should never happend" v ctx).expr, Var.Map.empty
| D.EApp ((D.EVar v, p), [(D.ELit D.LUnit, _)]) ->
if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = A.new_var (Bindlib.name_of v) in
let v' = Var.make (Bindlib.name_of v) in
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
created a variable %a to replace it" Dcalc.Print.format_var v
Print.format_var v'; *)
A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') (D.EVar v, p)
A.make_var (v', pos), Var.Map.singleton v' (D.EVar v, p)
else
Errors.raise_spanned_error (D.pos e)
"Internal error: an pure variable was found in an unpure environment."
| D.EDefault (_exceptions, _just, _cons) ->
let v' = A.new_var "default_term" in
A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') e
let v' = Var.make "default_term" in
A.make_var (v', pos), Var.Map.singleton v' e
| D.ELit D.LEmptyError ->
let v' = A.new_var "empty_litteral" in
A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') e
let v' = Var.make "empty_litteral" in
A.make_var (v', pos), Var.Map.singleton v' e
(* This one is a very special case. It transform an unpure expression
environement to a pure expression. *)
| ErrorOnEmpty arg ->
(* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *)
let silent_var = A.new_var "_" in
let x = A.new_var "non_empty_argument" in
let silent_var = Var.make "_" in
let x = Var.make "non_empty_argument" in
let arg' = translate_expr ctx arg in
@ -213,9 +212,9 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
[D.TAny, D.pos e]
pos)
(A.make_abs [| x |] (A.make_var (x, pos)) [D.TAny, D.pos e] pos),
A.VarMap.empty )
Var.Map.empty )
(* pure terms *)
| D.ELit l -> A.elit (translate_lit l (D.pos e)) pos, A.VarMap.empty
| D.ELit l -> A.elit (translate_lit l (D.pos e)) pos, Var.Map.empty
| D.EIfThenElse (e1, e2, e3) ->
let e1', h1 = translate_and_hoist ctx e1 in
let e2', h2 = translate_and_hoist ctx e2 in
@ -293,12 +292,12 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in
A.earray es' pos, disjoint_union_maps (D.pos e) hoists
| EOp op -> Bindlib.box (A.EOp op, pos), A.VarMap.empty
| EOp op -> Bindlib.box (A.EOp op, pos), Var.Map.empty
and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
: 'm A.marked_expr Bindlib.box =
let e', hoists = translate_and_hoist ctx e in
let hoists = A.VarMap.bindings hoists in
let hoists = Var.Map.bindings hoists in
let _pos = Marked.get_mark e in
@ -321,7 +320,7 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
let cons' = translate_expr ctx cons in
(* calls handle_option. *)
A.make_app
(A.make_var (A.Var.get A.handle_default_opt, mark_hoist))
(A.make_var (Var.translate A.handle_default_opt, mark_hoist))
[
Bindlib.box_apply
(fun excep' -> A.EArray excep', mark_hoist)
@ -336,8 +335,8 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
(* [ match arg with | None -> raise NoValueProvided | Some v -> assert
{{ v }} ] *)
let silent_var = A.new_var "_" in
let x = A.new_var "assertion_argument" in
let silent_var = Var.make "_" in
let x = Var.make "assertion_argument" in
A.make_matchopt_with_abs_arms arg'
(A.make_abs [| silent_var |]
@ -360,7 +359,7 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
] *)
(* Cli.debug_print @@ Format.asprintf "build matchopt using %a"
Print.format_var v; *)
A.make_matchopt mark_hoist (A.Var.get v)
A.make_matchopt mark_hoist v
(D.TAny, D.mark_pos mark_hoist)
c' (A.make_none mark_hoist) acc)
@ -582,7 +581,7 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
let scopes =
Bindlib.unbox
(translate_scopes { decl_ctx; vars = D.VarMap.empty } prgm.scopes)
(translate_scopes { decl_ctx; vars = Var.Map.empty } prgm.scopes)
in
{ scopes; decl_ctx }

View File

@ -416,8 +416,8 @@ let rec format_expr
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos)
format_with_parens arg1
| EApp ((EVar x, pos), args)
when Ast.Var.compare (Ast.Var.t x) Ast.handle_default = 0
|| Ast.Var.compare (Ast.Var.t x) Ast.handle_default_opt = 0 ->
when Var.compare x (Var.translate Ast.handle_default) = 0
|| Var.compare x (Var.translate Ast.handle_default_opt) = 0 ->
Format.fprintf fmt
"@[<hov 2>%a@ @[<hov 2>{filename = \"%s\";@ start_line=%d;@ \
start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a@]"

View File

@ -19,23 +19,23 @@ module A = Ast
module L = Lcalc.Ast
module D = Dcalc.Ast
type ctxt = {
func_dict : A.TopLevelName.t L.VarMap.t;
type 'm ctxt = {
func_dict : ('m L.expr, A.TopLevelName.t) Var.Map.t;
decl_ctx : D.decl_ctx;
var_dict : A.LocalName.t L.VarMap.t;
var_dict : ('m L.expr, A.LocalName.t) Var.Map.t;
inside_definition_of : A.LocalName.t option;
context_name : string;
}
(* Expressions can spill out side effect, hence this function also returns a
list of statements to be prepended before the expression is evaluated *)
let rec translate_expr (ctxt : ctxt) (expr : 'm L.marked_expr) :
let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) :
A.block * A.expr Marked.pos =
match Marked.unmark expr with
| L.EVar v ->
let local_var =
try A.EVar (L.VarMap.find (L.Var.t v) ctxt.var_dict)
with Not_found -> A.EFunc (L.VarMap.find (L.Var.t v) ctxt.func_dict)
try A.EVar (Var.Map.find v ctxt.var_dict)
with Not_found -> A.EFunc (Var.Map.find v ctxt.func_dict)
in
[], (local_var, D.pos expr)
| L.ETuple (args, Some s_name) ->
@ -115,8 +115,8 @@ let rec translate_expr (ctxt : ctxt) (expr : 'm L.marked_expr) :
:: tmp_stmts,
(A.EVar tmp_var, D.pos expr) )
and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
=
and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) :
A.block =
match Marked.unmark block_expr with
| L.EAssert e ->
(* Assertions are always encapsulated in a unit-typed let binding *)
@ -133,7 +133,7 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
var_dict =
List.fold_left
(fun var_dict (x, _) ->
L.VarMap.add (L.Var.t x)
Var.Map.add x
(A.LocalName.fresh (Bindlib.name_of x, binder_pos))
var_dict)
ctxt.var_dict vars_tau;
@ -142,15 +142,14 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
let local_decls =
List.map
(fun (x, tau) ->
( A.SLocalDecl
((L.VarMap.find (L.Var.t x) ctxt.var_dict, binder_pos), tau),
( A.SLocalDecl ((Var.Map.find x ctxt.var_dict, binder_pos), tau),
binder_pos ))
vars_tau
in
let vars_args =
List.map2
(fun (x, tau) arg ->
(L.VarMap.find (L.Var.t x) ctxt.var_dict, binder_pos), tau, arg)
(Var.Map.find x ctxt.var_dict, binder_pos), tau, arg)
vars_tau args
in
let def_blocks =
@ -185,7 +184,7 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
var_dict =
List.fold_left
(fun var_dict (x, _) ->
L.VarMap.add (L.Var.t x)
Var.Map.add x
(A.LocalName.fresh (Bindlib.name_of x, binder_pos))
var_dict)
ctxt.var_dict vars_tau;
@ -200,7 +199,7 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
func_params =
List.map
(fun (var, tau) ->
(L.VarMap.find (L.Var.t var) ctxt.var_dict, binder_pos), tau)
(Var.Map.find var ctxt.var_dict, binder_pos), tau)
vars_tau;
func_body = new_body;
} ),
@ -220,10 +219,7 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
A.LocalName.fresh (Bindlib.name_of var, D.pos arg)
in
let ctxt =
{
ctxt with
var_dict = L.VarMap.add (L.Var.t var) scalc_var ctxt.var_dict;
}
{ ctxt with var_dict = Var.Map.add var scalc_var ctxt.var_dict }
in
let new_arg = translate_statements ctxt body in
(new_arg, scalc_var) :: new_args
@ -275,8 +271,8 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
let rec translate_scope_body_expr
(scope_name : D.ScopeName.t)
(decl_ctx : D.decl_ctx)
(var_dict : A.LocalName.t L.VarMap.t)
(func_dict : A.TopLevelName.t L.VarMap.t)
(var_dict : ('m L.expr, A.LocalName.t) Var.Map.t)
(func_dict : ('m L.expr, A.TopLevelName.t) Var.Map.t)
(scope_expr : ('m L.expr, 'm) D.scope_body_expr) : A.block =
match scope_expr with
| Result e ->
@ -297,7 +293,7 @@ let rec translate_scope_body_expr
let let_var_id =
A.LocalName.fresh (Bindlib.name_of let_var, scope_let.scope_let_pos)
in
let new_var_dict = L.VarMap.add (L.Var.t let_var) let_var_id var_dict in
let new_var_dict = Var.Map.add let_var let_var_id var_dict in
(match scope_let.scope_let_kind with
| D.Assertion ->
translate_statements
@ -349,7 +345,7 @@ let translate_program (p : 'm L.program) : A.program =
A.LocalName.fresh (Bindlib.name_of scope_input_var, input_pos)
in
let var_dict =
L.VarMap.singleton (L.Var.t scope_input_var) scope_input_var_id
Var.Map.singleton scope_input_var scope_input_var_id
in
let new_scope_body =
translate_scope_body_expr scope_def.D.scope_name p.decl_ctx
@ -358,9 +354,7 @@ let translate_program (p : 'm L.program) : A.program =
let func_id =
A.TopLevelName.fresh (Bindlib.name_of scope_var, Pos.no_pos)
in
let func_dict =
L.VarMap.add (L.Var.t scope_var) func_id func_dict
in
let func_dict = Var.Map.add scope_var func_id func_dict in
( func_dict,
{
Ast.scope_body_name = scope_def.D.scope_name;
@ -387,8 +381,8 @@ let translate_program (p : 'm L.program) : A.program =
:: new_scopes ))
~init:
( (if !Cli.avoid_exceptions_flag then
L.VarMap.singleton L.handle_default_opt A.handle_default_opt
else L.VarMap.singleton L.handle_default A.handle_default),
Var.Map.singleton L.handle_default_opt A.handle_default_opt
else Var.Map.singleton L.handle_default A.handle_default),
[] )
p.D.scopes
in

View File

@ -332,9 +332,7 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
Bindlib.box_apply Marked.unmark new_e
| EAbs (binder, typ) ->
let xs, body = Bindlib.unmbind binder in
let new_xs =
Array.map (fun x -> Dcalc.Ast.new_var (Bindlib.name_of x)) xs
in
let new_xs = Array.map (fun x -> Var.make (Bindlib.name_of x)) xs in
let both_xs = Array.map2 (fun x new_x -> x, new_x) xs new_xs in
let body =
translate_expr
@ -415,7 +413,7 @@ let translate_rule
match rule with
| Definition ((ScopeVar a, var_def_pos), tau, a_io, e) ->
let a_name = Ast.ScopeVar.get_info (Marked.unmark a) in
let a_var = Dcalc.Ast.new_var (Marked.unmark a_name) in
let a_var = Var.make (Marked.unmark a_name) in
let tau = translate_typ ctx tau in
let new_e = translate_expr ctx e in
let a_expr = Dcalc.Ast.make_var (a_var, pos_mark var_def_pos) in
@ -469,14 +467,14 @@ let translate_rule
^ Marked.unmark (Ast.ScopeVar.get_info (Marked.unmark subs_var)))
(Ast.SubScopeName.get_info (Marked.unmark subs_index))
in
let a_var = Dcalc.Ast.new_var (Marked.unmark a_name) in
let a_var = Var.make (Marked.unmark a_name) in
let tau = translate_typ ctx tau in
let new_e =
tag_with_log_entry (translate_expr ctx e)
(Dcalc.Ast.VarDef (Marked.unmark tau))
[sigma_name, pos_sigma; a_name]
in
let silent_var = Dcalc.Ast.new_var "_" in
let silent_var = Var.make "_" in
let thunked_or_nonempty_new_e =
match Marked.unmark a_io.io_input with
| NoInput -> failwith "should not happen"
@ -582,7 +580,7 @@ let translate_rule
List.map
(fun (subvar : scope_var_ctx) ->
let sub_dcalc_var =
Dcalc.Ast.new_var
Var.make
(Marked.unmark (Ast.SubScopeName.get_info subindex)
^ "."
^ Marked.unmark (Ast.ScopeVar.get_info subvar.scope_var_name))
@ -613,7 +611,7 @@ let translate_rule
Ast.ScopeName.get_info subname;
]
in
let result_tuple_var = Dcalc.Ast.new_var "result" in
let result_tuple_var = Var.make "result" in
let result_tuple_typ =
( Dcalc.Ast.TTuple
( List.map
@ -698,7 +696,7 @@ let translate_rule
new_e;
Dcalc.Ast.scope_let_kind = Dcalc.Ast.Assertion;
})
(Bindlib.bind_var (Dcalc.Ast.new_var "_") next)
(Bindlib.bind_var (Var.make "_") next)
new_e),
ctx )
@ -753,7 +751,7 @@ let translate_scope_decl
(sigma : Ast.scope_decl) :
(Dcalc.Ast.untyped Dcalc.Ast.expr, Dcalc.Ast.untyped) Dcalc.Ast.scope_body
Bindlib.box
* Dcalc.Ast.struct_ctx =
* Astgen.struct_ctx =
let sigma_info = Ast.ScopeName.get_info sigma.scope_decl_name in
let scope_sig = Ast.ScopeMap.find sigma.scope_decl_name sctx in
let scope_variables = scope_sig.scope_sig_local_vars in
@ -765,9 +763,7 @@ let translate_scope_decl
match Marked.unmark scope_var.scope_var_io.io_input with
| OnlyInput ->
let scope_var_name = Ast.ScopeVar.get_info scope_var.scope_var_name in
let scope_var_dcalc =
Dcalc.Ast.new_var (Marked.unmark scope_var_name)
in
let scope_var_dcalc = Var.make (Marked.unmark scope_var_name) in
{
ctx with
scope_vars =
@ -916,7 +912,7 @@ let translate_program (prgm : Ast.program) :
Ast.ScopeMap.mapi
(fun scope_name scope ->
let scope_dvar =
Dcalc.Ast.new_var
Var.make
(Marked.unmark (Ast.ScopeName.get_info scope.Ast.scope_decl_name))
in
let scope_return_struct_name =
@ -926,8 +922,7 @@ let translate_program (prgm : Ast.program) :
(Ast.ScopeName.get_info scope_name))
in
let scope_input_var =
Dcalc.Ast.new_var
(Marked.unmark (Ast.ScopeName.get_info scope_name) ^ "_in")
Var.make (Marked.unmark (Ast.ScopeName.get_info scope_name) ^ "_in")
in
let scope_input_struct_name =
Ast.StructName.fresh

293
compiler/utils/astgen.ml Normal file
View File

@ -0,0 +1,293 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët-Tixeuil
<alain.delaet--tixeuil@inria.fr>, Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
module Runtime = Runtime_ocaml.Runtime
module ScopeName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module StructName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module StructMap : Map.S with type key = StructName.t = Map.Make (StructName)
module EnumName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module EnumMap : Map.S with type key = EnumName.t = Map.Make (EnumName)
(** Abstract syntax tree for the default calculus *)
(** {1 Abstract syntax tree} *)
(** {2 Types} *)
type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration
type marked_typ = typ Marked.pos
and typ =
| TLit of typ_lit
| TTuple of marked_typ list * StructName.t option
| TEnum of marked_typ list * EnumName.t
| TArrow of marked_typ * marked_typ
| TArray of marked_typ
| TAny
(** {2 Constants and operators} *)
type date = Runtime.date
type duration = Runtime.duration
type op_kind =
| KInt
| KRat
| KMoney
| KDate
| KDuration (** All ops don't have a KDate and KDuration. *)
type ternop = Fold
type binop =
| And
| Or
| Xor
| Add of op_kind
| Sub of op_kind
| Mult of op_kind
| Div of op_kind
| Lt of op_kind
| Lte of op_kind
| Gt of op_kind
| Gte of op_kind
| Eq
| Neq
| Map
| Concat
| Filter
type log_entry =
| VarDef of typ
(** During code generation, we need to know the type of the variable being
logged for embedding *)
| BeginCall
| EndCall
| PosRecordIfTrueBool
type unop =
| Not
| Minus of op_kind
| Log of log_entry * Uid.MarkedString.info list
| Length
| IntToRat
| MoneyToRat
| RatToMoney
| GetDay
| GetMonth
| GetYear
| FirstDayOfMonth
| LastDayOfMonth
| RoundMoney
| RoundDecimal
type operator = Ternop of ternop | Binop of binop | Unop of unop
type except = ConflictError | EmptyError | NoValueProvided | Crash
(** {2 Generic expressions} *)
(** Define a common base type for the expressions in most passes of the compiler *)
type desugared = [ `Desugared ]
type scopelang = [ `Scopelang ]
type dcalc = [ `Dcalc ]
type lcalc = [ `Lcalc ]
type scalc = [ `Scalc ]
type any = [ desugared | scopelang | dcalc | lcalc | scalc ]
(** Literals are the same throughout compilation except for the [LEmptyError]
case which is eliminated midway through. *)
type 'a glit =
| LBool : bool -> 'a glit
| LEmptyError : [< desugared | scopelang | dcalc ] glit
| LInt : Runtime.integer -> 'a glit
| LRat : Runtime.decimal -> 'a glit
| LMoney : Runtime.money -> 'a glit
| LUnit : 'a glit
| LDate : date -> 'a glit
| LDuration : duration -> 'a glit
type ('a, 't) marked_gexpr = (('a, 't) gexpr, 't) Marked.t
(** General expressions: groups all expression cases of the different ASTs, and
uses a GADT to eliminate irrelevant cases for each one. The ['t] annotations
are also totally unconstrained at this point. The dcalc exprs, for example,
are then defined with [type expr = dcalc gexpr] plus the annotations. *)
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
library, based on higher-order abstract syntax *)
and ('a, 't) gexpr =
(* Constructors common to all ASTs *)
| ELit : 'a glit -> ('a, 't) gexpr
| EApp : ('a, 't) marked_gexpr * ('a, 't) marked_gexpr list -> ('a, 't) gexpr
| EOp : operator -> ('a, 't) gexpr
| EArray : ('a, 't) marked_gexpr list -> ('a, 't) gexpr
(* All but statement calculus *)
| EVar :
('a, 't) gexpr Bindlib.var
-> (([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) gexpr
| EAbs :
(('a, 't) gexpr, ('a, 't) marked_gexpr) Bindlib.mbinder
* typ Marked.pos list
-> (([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) gexpr
| EIfThenElse :
('a, 't) marked_gexpr * ('a, 't) marked_gexpr * ('a, 't) marked_gexpr
-> (([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) gexpr
(* (* Early stages *) | ELocation: location -> ([< desugared | scopelang ] as
'a, 't) gexpr | EStruct: StructName.t * ('a, 't) marked_gexpr
StructFieldMap.t -> ([< desugared | scopelang ] as 'a, 't) gexpr |
EStructAccess: ('a, 't) marked_gexpr * StructFieldName.t * StructName.t ->
([< desugared | scopelang ] as 'a, 't) gexpr | EEnumInj: ('a, 't)
marked_gexpr * EnumConstructor.t * EnumName.t -> ([< desugared | scopelang
] as 'a, 't) gexpr | EMatchS: ('a, 't) marked_gexpr * EnumName.t * ('a, 't)
marked_gexpr EnumConstructorMap.t -> ([< desugared | scopelang ] as 'a, 't)
gexpr *)
(* Lambda-like *)
| ETuple :
('a, 't) marked_gexpr list * StructName.t option
-> (([< dcalc | lcalc ] as 'a), 't) gexpr
| ETupleAccess :
('a, 't) marked_gexpr * int * StructName.t option * typ Marked.pos list
-> (([< dcalc | lcalc ] as 'a), 't) gexpr
| EInj :
('a, 't) marked_gexpr * int * EnumName.t * typ Marked.pos list
-> (([< dcalc | lcalc ] as 'a), 't) gexpr
| EMatch :
('a, 't) marked_gexpr * ('a, 't) marked_gexpr list * EnumName.t
-> (([< dcalc | lcalc ] as 'a), 't) gexpr
| EAssert : ('a, 't) marked_gexpr -> (([< dcalc | lcalc ] as 'a), 't) gexpr
(* Default terms *)
| EDefault :
('a, 't) marked_gexpr list * ('a, 't) marked_gexpr * ('a, 't) marked_gexpr
-> (([< desugared | scopelang | dcalc ] as 'a), 't) gexpr
| ErrorOnEmpty :
('a, 't) marked_gexpr
-> (([< desugared | scopelang | dcalc ] as 'a), 't) gexpr
(* Lambda calculus with exceptions *)
| ERaise : except -> ((lcalc as 'a), 't) gexpr
| ECatch :
('a, 't) marked_gexpr * except * ('a, 't) marked_gexpr
-> ((lcalc as 'a), 't) gexpr
(* (\* Statement calculus *\)
* | ESVar: LocalName.t -> (scalc as 'a, 't) gexpr
* | ESStruct: ('a, 't) marked_gexpr list * StructName.t -> (scalc as 'a, 't) gexpr
* | ESStructFieldAccess: ('a, 't) marked_gexpr * StructFieldName.t * StructName.t -> (scalc as 'a, 't) gexpr
* | ESInj: ('a, 't) marked_gexpr * EnumConstructor.t * EnumName.t -> (scalc as 'a, 't) gexpr
* | ESFunc: TopLevelName.t -> (scalc as 'a, 't) gexpr *)
(** {2 Markings} *)
type untyped = { pos : Pos.t } [@@ocaml.unboxed]
type typed = { pos : Pos.t; ty : marked_typ }
(* type inferring = { pos : Pos.t; uf : Infer.unionfind_typ } *)
(** The generic type of AST markings. Using a GADT allows functions to be
polymorphic in the marking, but still do transformations on types when
appropriate. Expected to fill the ['t] parameter of [gexpr] and
[marked_gexpr] *)
type _ mark = Untyped : untyped -> untyped mark | Typed : typed -> typed mark
(* | Inferring : inferring -> inferring mark *)
type ('a, 'm) marked = ('a, 'm mark) Marked.t
(** Useful for errors and printing, for example *)
type any_marked_expr =
| AnyExpr : ([< any ], 'm mark) marked_gexpr -> any_marked_expr
(** {2 Higher-level program structure} *)
(** Constructs scopes and programs on top of expressions. We may use the [gexpr]
type above at some point, but at the moment this is polymorphic in the types
of the expressions. Their markings are constrained to belong to the [mark]
GADT defined above. *)
(** This kind annotation signals that the let-binding respects a structural
invariant. These invariants concern the shape of the expression in the
let-binding, and are documented below. *)
type scope_let_kind =
| DestructuringInputStruct (** [let x = input.field]*)
| ScopeVarDefinition (** [let x = error_on_empty e]*)
| SubScopeVarDefinition
(** [let s.x = fun _ -> e] or [let s.x = error_on_empty e] for input-only
subscope variables. *)
| CallingSubScope (** [let result = s ({ x = s.x; y = s.x; ...}) ]*)
| DestructuringSubScopeResults (** [let s.x = result.x ]**)
| Assertion (** [let _ = assert e]*)
type ('expr, 'm) scope_let = {
scope_let_kind : scope_let_kind;
scope_let_typ : marked_typ;
scope_let_expr : ('expr, 'm) marked;
scope_let_next : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
scope_let_pos : Pos.t;
}
(** This type is parametrized by the expression type so it can be reused in
later intermediate representations. *)
(** A scope let-binding has all the information necessary to make a proper
let-binding expression, plus an annotation for the kind of the let-binding
that comes from the compilation of a {!module: Scopelang.Ast} statement. *)
and ('expr, 'm) scope_body_expr =
| Result of ('expr, 'm) marked
| ScopeLet of ('expr, 'm) scope_let
type ('expr, 'm) scope_body = {
scope_body_input_struct : StructName.t;
scope_body_output_struct : StructName.t;
scope_body_expr : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
}
(** Instead of being a single expression, we give a little more ad-hoc structure
to the scope body by decomposing it in an ordered list of let-bindings, and
a result expression that uses the let-binded variables. The first binder is
the argument of type [scope_body_input_struct]. *)
type ('expr, 'm) scope_def = {
scope_name : ScopeName.t;
scope_body : ('expr, 'm) scope_body;
scope_next : ('expr, ('expr, 'm) scopes) Bindlib.binder;
}
(** Finally, we do the same transformation for the whole program for the kinded
lets. This permit us to use bindlib variables for scopes names. *)
and ('expr, 'm) scopes = Nil | ScopeDef of ('expr, 'm) scope_def
type struct_ctx = (StructFieldName.t * marked_typ) list StructMap.t
type decl_ctx = {
ctx_enums : (EnumConstructor.t * marked_typ) list EnumMap.t;
ctx_structs : struct_ctx;
}
type ('expr, 'm) program_generic = {
decl_ctx : decl_ctx;
scopes : ('expr, 'm) scopes;
}

View File

@ -0,0 +1,180 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët-Tixeuil
<alain.delaet--tixeuil@inria.fr>, Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Astgen
(** Functions handling the types in [Astgen] *)
let evar v mark = Bindlib.box_apply (Marked.mark mark) (Bindlib.box_var v)
let etuple args s mark =
Bindlib.box_apply (fun args -> ETuple (args, s), mark) (Bindlib.box_list args)
let etupleaccess e1 i s typs mark =
Bindlib.box_apply (fun e1 -> ETupleAccess (e1, i, s, typs), mark) e1
let einj e1 i e_name typs mark =
Bindlib.box_apply (fun e1 -> EInj (e1, i, e_name, typs), mark) e1
let ematch arg arms e_name mark =
Bindlib.box_apply2
(fun arg arms -> EMatch (arg, arms, e_name), mark)
arg (Bindlib.box_list arms)
let earray args mark =
Bindlib.box_apply (fun args -> EArray args, mark) (Bindlib.box_list args)
let elit l mark = Bindlib.box (ELit l, mark)
let eabs binder typs mark =
Bindlib.box_apply (fun binder -> EAbs (binder, typs), mark) binder
let eapp e1 args mark =
Bindlib.box_apply2
(fun e1 args -> EApp (e1, args), mark)
e1 (Bindlib.box_list args)
let eassert e1 mark = Bindlib.box_apply (fun e1 -> EAssert e1, mark) e1
let eop op mark = Bindlib.box (EOp op, mark)
let edefault excepts just cons mark =
Bindlib.box_apply3
(fun excepts just cons -> EDefault (excepts, just, cons), mark)
(Bindlib.box_list excepts) just cons
let eifthenelse e1 e2 e3 mark =
Bindlib.box_apply3 (fun e1 e2 e3 -> EIfThenElse (e1, e2, e3), mark) e1 e2 e3
let eerroronempty e1 mark =
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) e1
let eraise e1 pos = Bindlib.box (ERaise e1, pos)
let ecatch e1 exn e2 pos =
Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2
let translate_var v = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
let map_gexpr
(type a)
(ctx : 'ctx)
~(f : 'ctx -> (a, 'm1) marked_gexpr -> (a, 'm2) marked_gexpr Bindlib.box)
(e : ((a, 'm1) gexpr, 'm2) Marked.t) : (a, 'm2) marked_gexpr Bindlib.box =
let m = Marked.get_mark e in
match Marked.unmark e with
| ELit l -> elit l m
| EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) m
| EOp op -> Bindlib.box (EOp op, m)
| EArray args -> earray (List.map (f ctx) args) m
| EVar v -> evar (translate_var v) m
| EAbs (binder, typs) ->
let vars, body = Bindlib.unmbind binder in
eabs (Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body)) typs m
| EIfThenElse (e1, e2, e3) ->
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m
| ETuple (args, s) -> etuple (List.map (f ctx) args) s m
| ETupleAccess (e1, n, s_name, typs) ->
etupleaccess ((f ctx) e1) n s_name typs m
| EInj (e1, i, e_name, typs) -> einj ((f ctx) e1) i e_name typs m
| EMatch (arg, arms, e_name) ->
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name m
| EAssert e1 -> eassert ((f ctx) e1) m
| EDefault (excepts, just, cons) ->
edefault (List.map (f ctx) excepts) ((f ctx) just) ((f ctx) cons) m
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m
| ECatch (e1, exn, e2) -> ecatch (f ctx e1) exn (f ctx e2) (Marked.get_mark e)
| ERaise exn -> eraise exn (Marked.get_mark e)
let rec map_gexpr_top_down ~f e =
map_gexpr () ~f:(fun () -> map_gexpr_top_down ~f) (f e)
let map_gexpr_marks ~f e =
map_gexpr_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
let rec fold_left_scope_lets ~f ~init scope_body_expr =
match scope_body_expr with
| Result _ -> init
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
fold_left_scope_lets ~f ~init:(f init scope_let var) next
let rec fold_right_scope_lets ~f ~init scope_body_expr =
match scope_body_expr with
| Result result -> init result
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
let next_result = fold_right_scope_lets ~f ~init next in
f scope_let var next_result
let map_exprs_in_scope_lets ~f ~varf scope_body_expr =
fold_right_scope_lets
~f:(fun scope_let var_next acc ->
Bindlib.box_apply2
(fun scope_let_next scope_let_expr ->
ScopeLet { scope_let with scope_let_next; scope_let_expr })
(Bindlib.bind_var (varf var_next) acc)
(f scope_let.scope_let_expr))
~init:(fun res -> Bindlib.box_apply (fun res -> Result res) (f res))
scope_body_expr
let rec fold_left_scope_defs ~f ~init scopes =
match scopes with
| Nil -> init
| ScopeDef scope_def ->
let var, next = Bindlib.unbind scope_def.scope_next in
fold_left_scope_defs ~f ~init:(f init scope_def var) next
let rec fold_right_scope_defs ~f ~init scopes =
match scopes with
| Nil -> init
| ScopeDef scope_def ->
let var_next, next = Bindlib.unbind scope_def.scope_next in
let result_next = fold_right_scope_defs ~f ~init next in
f scope_def var_next result_next
let map_scope_defs ~f scopes =
fold_right_scope_defs
~f:(fun scope_def var_next acc ->
let new_scope_def = f scope_def in
let new_next = Bindlib.bind_var var_next acc in
Bindlib.box_apply2
(fun new_scope_def new_next ->
ScopeDef { new_scope_def with scope_next = new_next })
new_scope_def new_next)
~init:(Bindlib.box Nil) scopes
let map_exprs_in_scopes ~f ~varf scopes =
fold_right_scope_defs
~f:(fun scope_def var_next acc ->
let scope_input_var, scope_lets =
Bindlib.unbind scope_def.scope_body.scope_body_expr
in
let new_scope_body_expr = map_exprs_in_scope_lets ~f ~varf scope_lets in
let new_scope_body_expr =
Bindlib.bind_var (varf scope_input_var) new_scope_body_expr
in
let new_next = Bindlib.bind_var (varf var_next) acc in
Bindlib.box_apply2
(fun scope_body_expr scope_next ->
ScopeDef
{
scope_def with
scope_body = { scope_def.scope_body with scope_body_expr };
scope_next;
})
new_scope_body_expr new_next)
~init:(Bindlib.box Nil) scopes

View File

@ -0,0 +1,183 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët-Tixeuil
<alain.delaet--tixeuil@inria.fr>, Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Functions handling the types in [Astgen] *)
open Astgen
(** {2 Boxed constructors} *)
val evar :
(([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) gexpr Bindlib.var ->
't ->
('a, 't) marked_gexpr Bindlib.box
val etuple :
(([< dcalc | lcalc ] as 'a), 't) marked_gexpr Bindlib.box list ->
StructName.t option ->
't ->
('a, 't) marked_gexpr Bindlib.box
val etupleaccess :
(([< dcalc | lcalc ] as 'a), 't) marked_gexpr Bindlib.box ->
int ->
StructName.t option ->
marked_typ list ->
't ->
('a, 't) marked_gexpr Bindlib.box
val einj :
(([< dcalc | lcalc ] as 'a), 't) marked_gexpr Bindlib.box ->
int ->
EnumName.t ->
marked_typ list ->
't ->
('a, 't) marked_gexpr Bindlib.box
val ematch :
(([< dcalc | lcalc ] as 'a), 't) marked_gexpr Bindlib.box ->
('a, 't) marked_gexpr Bindlib.box list ->
EnumName.t ->
't ->
('a, 't) marked_gexpr Bindlib.box
val earray :
('a, 't) marked_gexpr Bindlib.box list ->
't ->
('a, 't) marked_gexpr Bindlib.box
val elit : 'a glit -> 't -> ('a, 't) marked_gexpr Bindlib.box
val eabs :
( (([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) gexpr,
('a, 't) marked_gexpr )
Bindlib.mbinder
Bindlib.box ->
marked_typ list ->
't ->
('a, 't) marked_gexpr Bindlib.box
val eapp :
('a, 't) marked_gexpr Bindlib.box ->
('a, 't) marked_gexpr Bindlib.box list ->
't ->
('a, 't) marked_gexpr Bindlib.box
val eassert :
(([< dcalc | lcalc ] as 'a), 't) marked_gexpr Bindlib.box ->
't ->
('a, 't) marked_gexpr Bindlib.box
val eop : operator -> 't -> ('a, 't) marked_gexpr Bindlib.box
val edefault :
(([< desugared | scopelang | dcalc ] as 'a), 't) marked_gexpr Bindlib.box list ->
('a, 't) marked_gexpr Bindlib.box ->
('a, 't) marked_gexpr Bindlib.box ->
't ->
('a, 't) marked_gexpr Bindlib.box
val eifthenelse :
(([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) marked_gexpr
Bindlib.box ->
('a, 't) marked_gexpr Bindlib.box ->
('a, 't) marked_gexpr Bindlib.box ->
't ->
('a, 't) marked_gexpr Bindlib.box
val eerroronempty :
(([< desugared | scopelang | dcalc ] as 'a), 't) marked_gexpr Bindlib.box ->
't ->
('a, 't) marked_gexpr Bindlib.box
(** ---------- *)
val map_gexpr :
'ctx ->
f:('ctx -> ('a, 't1) marked_gexpr -> ('a, 't2) marked_gexpr Bindlib.box) ->
(('a, 't1) gexpr, 't2) Marked.t ->
('a, 't2) marked_gexpr Bindlib.box
val map_gexpr_top_down :
f:(('a, 't1) marked_gexpr -> (('a, 't1) gexpr, 't2) Marked.t) ->
('a, 't1) marked_gexpr ->
('a, 't2) marked_gexpr Bindlib.box
(** Recursively applies [f] to the nodes of the expression tree. The type
returned by [f] is hybrid since the mark at top-level has been rewritten,
but not yet the marks in the subtrees. *)
val map_gexpr_marks :
f:('t1 -> 't2) -> ('a, 't1) marked_gexpr -> ('a, 't2) marked_gexpr Bindlib.box
val fold_left_scope_lets :
f:('a -> ('expr, 'm) scope_let -> 'expr Bindlib.var -> 'a) ->
init:'a ->
('expr, 'm) scope_body_expr ->
'a
(** Usage:
[fold_left_scope_lets ~f:(fun acc scope_let scope_let_var -> ...) ~init scope_lets],
where [scope_let_var] is the variable bound to the scope let in the next
scope lets to be examined. *)
val fold_right_scope_lets :
f:(('expr1, 'm1) scope_let -> 'expr1 Bindlib.var -> 'a -> 'a) ->
init:(('expr1, 'm1) marked -> 'a) ->
('expr1, 'm1) scope_body_expr ->
'a
(** Usage:
[fold_right_scope_lets ~f:(fun scope_let scope_let_var acc -> ...) ~init scope_lets],
where [scope_let_var] is the variable bound to the scope let in the next
scope lets to be examined (which are before in the program order). *)
val map_exprs_in_scope_lets :
f:(('expr1, 'm1) marked -> ('expr2, 'm2) marked Bindlib.box) ->
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
('expr1, 'm1) scope_body_expr ->
('expr2, 'm2) scope_body_expr Bindlib.box
val fold_left_scope_defs :
f:('a -> ('expr1, 'm1) scope_def -> 'expr1 Bindlib.var -> 'a) ->
init:'a ->
('expr1, 'm1) scopes ->
'a
(** Usage:
[fold_left_scope_defs ~f:(fun acc scope_def scope_var -> ...) ~init scope_def],
where [scope_var] is the variable bound to the scope in the next scopes to
be examined. *)
val fold_right_scope_defs :
f:(('expr1, 'm1) scope_def -> 'expr1 Bindlib.var -> 'a -> 'a) ->
init:'a ->
('expr1, 'm1) scopes ->
'a
(** Usage:
[fold_right_scope_defs ~f:(fun scope_def scope_var acc -> ...) ~init scope_def],
where [scope_var] is the variable bound to the scope in the next scopes to
be examined (which are before in the program order). *)
val map_scope_defs :
f:(('expr, 'm) scope_def -> ('expr, 'm) scope_def Bindlib.box) ->
('expr, 'm) scopes ->
('expr, 'm) scopes Bindlib.box
val map_exprs_in_scopes :
f:(('expr1, 'm1) marked -> ('expr2, 'm2) marked Bindlib.box) ->
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
('expr1, 'm1) scopes ->
('expr2, 'm2) scopes Bindlib.box
(** This is the main map visitor for all the expressions inside all the scopes
of the program. *)

View File

@ -1,7 +1,7 @@
(library
(name utils)
(public_name catala.utils)
(libraries cmdliner ubase ANSITerminal re))
(libraries cmdliner ubase ANSITerminal re bindlib catala.runtime_ocaml))
(documentation
(package catala)

View File

@ -21,6 +21,7 @@ type 'a pos = ('a, Pos.t) t
let mark m e : ('a, 'm) t = e, m
let unmark ((x, _) : ('a, 'm) t) : 'a = x
let get_mark ((_, x) : ('a, 'm) t) : 'm = x
let map_mark (f : 'm1 -> 'm2) ((a, m) : ('a, 'm1) t) : ('a, 'm2) t = a, f m
let map_under_mark (f : 'a -> 'b) ((x, y) : ('a, 'm) t) : ('b, 'c) t = f x, y
let same_mark_as (x : 'a) ((_, y) : ('b, 'm) t) : ('a, 'm) t = x, y

View File

@ -27,6 +27,7 @@ type 'a pos = ('a, Pos.t) t
val mark : 'm -> 'a -> ('a, 'm) t
val unmark : ('a, 'm) t -> 'a
val get_mark : ('a, 'm) t -> 'm
val map_mark : ('m1 -> 'm2) -> ('a, 'm1) t -> ('a, 'm2) t
val map_under_mark : ('a -> 'b) -> ('a, 'm) t -> ('b, 'm) t
val same_mark_as : 'a -> ('b, 'm) t -> ('a, 'm) t
val unmark_option : ('a, 'm) t option -> 'a option

108
compiler/utils/var.ml Normal file
View File

@ -0,0 +1,108 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Astgen
(** {1 Variables and their collections} *)
(** This module provides types and helpers for Bindlib variables on the
[Astgen.gexpr] type *)
(* The subtypes of the generic AST that hold vars *)
type 'e expr = 'e
constraint 'e = ([< desugared | scopelang | dcalc | lcalc ], 't) gexpr
type 'e var = 'e expr Bindlib.var
type 'e t = 'e var
type 'e vars = 'e expr Bindlib.mvar
type 'e binder = (('a, 't) gexpr, ('a, 't) marked_gexpr) Bindlib.binder
constraint 'e = ('a, 't) gexpr
let make (name : string) : 'e var = Bindlib.new_var (fun x -> EVar x) name
let compare = Bindlib.compare_vars
let eq = Bindlib.eq_vars
let translate (v : 'e1 var) : 'e2 var =
Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
(* The purpose of this module is just to lift a type parameter outside of
[Set.S] and [Map.S], so that we can have ['e Var.Set.t] for sets of variables
bound to the ['e = ('a, 't) gexpr] expression type. This is made possible by
the fact that [Bindlib.compare_vars] is polymorphic in that parameter; we
first hide that parameter inside an existential, then re-add a phantom type
outside of the set to ensure consistency. Extracting the elements is then
done with [Bindlib.copy_var] but technically it's not much different from an
[Obj] conversion.
If anyone has a better solution, besides a copy-paste of Set.Make / Map.Make
code... *)
module Generic = struct
(* Existentially quantify the type parameters to allow application of
Set.Make *)
type t = Var : 'e var -> t
(* Note: adding [[@@ocaml.unboxed]] would be OK and make our wrappers live at
the type-level without affecting the actual data representation. But
[Bindlib.var] being abstract, we can't convince OCaml it's ok at the moment
and have to hold it *)
let t v = Var v
let get (Var v) = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
let compare (Var x) (Var y) = Bindlib.compare_vars x y
let eq (Var x) (Var y) = Bindlib.eq_vars x y
end
(* Wrapper around Set.Make to re-add type parameters (avoid inconsistent
sets) *)
module Set = struct
open Generic
open Set.Make (Generic)
type nonrec 'e t = t constraint 'e = 'e expr
let empty = empty
let singleton x = singleton (t x)
let add x s = add (t x) s
let remove x s = remove (t x) s
let union s1 s2 = union s1 s2
let mem x s = mem (t x) s
let of_list l = of_list (List.map t l)
let elements s = elements s |> List.map get
(* Add more as needed *)
end
(* Wrapper around Map.Make to re-add type parameters (avoid inconsistent
maps) *)
module Map = struct
open Generic
open Map.Make (Generic)
type nonrec ('e, 'x) t = 'x t constraint 'e = 'e expr
let empty = empty
let singleton v x = singleton (t v) x
let add v x m = add (t v) x m
let update v f m = update (t v) f m
let find v m = find (t v) m
let find_opt v m = find_opt (t v) m
let bindings m = bindings m |> List.map (fun (v, x) -> get v, x)
let mem x m = mem (t x) m
let union f m1 m2 = union (fun v x1 x2 -> f (get v) x1 x2) m1 m2
let fold f m acc = fold (fun v x acc -> f (get v) x acc) m acc
(* Add more as needed *)
end

73
compiler/utils/var.mli Normal file
View File

@ -0,0 +1,73 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Astgen
(** {1 Variables and their collections} *)
(** This module provides types and helpers for Bindlib variables on the
[Astgen.gexpr] type *)
type 'e expr = 'e
constraint 'e = ([< desugared | scopelang | dcalc | lcalc ], 't) gexpr
(** Subtype of Astgen.gexpr where variables are handled *)
type 'e var = 'e expr Bindlib.var
type 'e t = 'e var
type 'e vars = 'e expr Bindlib.mvar
val make : string -> 'e t
val compare : 'e t -> 'e t -> int
val eq : 'e t -> 'e t -> bool
val translate : 'e1 t -> 'e2 t
(** Needed when converting from one AST type to another. See the note of caution
on [Bindlib.copy_var]. *)
(** Wrapper over [Set.S] but with a type variable for the AST type parameters.
Extend as needed *)
module Set : sig
type 'e t constraint 'e = 'e expr
val empty : 'e t
val singleton : 'e var -> 'e t
val add : 'e var -> 'e t -> 'e t
val remove : 'e var -> 'e t -> 'e t
val union : 'e t -> 'e t -> 'e t
val mem : 'e var -> 'e t -> bool
val of_list : 'e var list -> 'e t
val elements : 'e t -> 'e var list
end
(** Wrapper over [Map.S] but with a type variable for the AST type parameters.
Extend as needed *)
module Map : sig
type ('e, 'x) t constraint 'e = 'e expr
val empty : ('e, 'x) t
val singleton : 'e var -> 'x -> ('e, 'x) t
val add : 'e var -> 'x -> ('e, 'x) t -> ('e, 'x) t
val update : 'e var -> ('x option -> 'x option) -> ('e, 'x) t -> ('e, 'x) t
val find : 'e var -> ('e, 'x) t -> 'x
val find_opt : 'e var -> ('e, 'x) t -> 'x option
val bindings : ('e, 'x) t -> ('e var * 'x) list
val mem : 'e var -> ('e, 'x) t -> bool
val union :
('e var -> 'x -> 'x -> 'x option) -> ('e, 'x) t -> ('e, 'x) t -> ('e, 'x) t
val fold : ('e var -> 'x -> 'acc -> 'acc) -> ('e, 'x) t -> 'acc -> 'acc
end

View File

@ -21,27 +21,28 @@ open Ast
(** {1 Helpers and type definitions}*)
type vc_return = typed marked_expr * typ Marked.pos VarMap.t
type vc_return = typed marked_expr * (typed expr, typ Marked.pos) Var.Map.t
(** The return type of VC generators is the VC expression plus the types of any
locally free variable inside that expression. *)
type ctx = {
current_scope_name : ScopeName.t;
decl : decl_ctx;
input_vars : Var.t list;
scope_variables_typs : typ Marked.pos VarMap.t;
input_vars : typed var list;
scope_variables_typs : (typed expr, typ Marked.pos) Var.Map.t;
}
let conjunction (args : vc_return list) (mark : typed mark) : vc_return =
let acc, list =
match args with
| hd :: tl -> hd, tl
| [] -> ((ELit (LBool true), mark), VarMap.empty), []
| [] -> ((ELit (LBool true), mark), Var.Map.empty), []
in
List.fold_left
(fun (acc, acc_ty) (arg, arg_ty) ->
( (EApp ((EOp (Binop And), mark), [arg; acc]), mark),
VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty ))
Var.Map.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty
))
acc list
let negation ((arg, arg_ty) : vc_return) (mark : typed mark) : vc_return =
@ -51,12 +52,13 @@ let disjunction (args : vc_return list) (mark : typed mark) : vc_return =
let acc, list =
match args with
| hd :: tl -> hd, tl
| [] -> ((ELit (LBool false), mark), VarMap.empty), []
| [] -> ((ELit (LBool false), mark), Var.Map.empty), []
in
List.fold_left
(fun ((acc, acc_ty) : vc_return) (arg, arg_ty) ->
( (EApp ((EOp (Binop Or), mark), [arg; acc]), mark),
VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty ))
Var.Map.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty
))
acc list
(** [half_product \[a1,...,an\] \[b1,...,bm\] returns \[(a1,b1),...(a1,bn),...(an,b1),...(an,bm)\]] *)
@ -80,7 +82,7 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : typed marked_expr)
(ELit (LBool true), _),
cons ),
_ )
when List.exists (fun x' -> Var.eq (Var.t x) x') ctx.input_vars ->
when List.exists (fun x' -> Var.eq x x') ctx.input_vars ->
(* scope variables*)
cons
| EAbs (binder, [(TLit TUnit, _)]) ->
@ -130,7 +132,7 @@ let rec generate_vc_must_not_return_empty (ctx : ctx) (e : typed marked_expr) :
in
( vc_body_expr,
List.fold_left
(fun acc (var, ty) -> VarMap.add (Var.t var) ty acc)
(fun acc (var, ty) -> Var.Map.add var ty acc)
vc_body_ty
(List.map2 (fun x y -> x, y) (Array.to_list vars) typs) )
| EApp (f, args) ->
@ -147,18 +149,18 @@ let rec generate_vc_must_not_return_empty (ctx : ctx) (e : typed marked_expr) :
[
e1_vc, vc_typ1;
( (EIfThenElse (e1, e2_vc, e3_vc), Marked.get_mark e),
VarMap.union
Var.Map.union
(fun _ _ _ -> failwith "should not happen")
vc_typ2 vc_typ3 );
]
(Marked.get_mark e)
| ELit LEmptyError ->
Marked.same_mark_as (ELit (LBool false)) e, VarMap.empty
Marked.same_mark_as (ELit (LBool false)) e, Var.Map.empty
| EVar _
(* Per default calculus semantics, you cannot call a function with an argument
that evaluates to the empty error. Thus, all variable evaluate to non-empty-error terms. *)
| ELit _ | EOp _ ->
Marked.same_mark_as (ELit (LBool true)) e, VarMap.empty
Marked.same_mark_as (ELit (LBool true)) e, Var.Map.empty
| EDefault (exceptions, just, cons) ->
(* <e1 ... en | ejust :- econs > never returns empty if and only if:
- first we look if e1 .. en ejust can return empty;
@ -223,7 +225,7 @@ let rec generate_vs_must_not_return_confict (ctx : ctx) (e : typed marked_expr)
in
( vc_body_expr,
List.fold_left
(fun acc (var, ty) -> VarMap.add (Var.t var) ty acc)
(fun acc (var, ty) -> Var.Map.add var ty acc)
vc_body_ty
(List.map2 (fun x y -> x, y) (Array.to_list vars) typs) )
| EApp (f, args) ->
@ -238,13 +240,13 @@ let rec generate_vs_must_not_return_confict (ctx : ctx) (e : typed marked_expr)
[
e1_vc, vc_typ1;
( (EIfThenElse (e1, e2_vc, e3_vc), Marked.get_mark e),
VarMap.union
Var.Map.union
(fun _ _ _ -> failwith "should not happen")
vc_typ2 vc_typ3 );
]
(Marked.get_mark e)
| EVar _ | ELit _ | EOp _ ->
Marked.same_mark_as (ELit (LBool true)) e, VarMap.empty
Marked.same_mark_as (ELit (LBool true)) e, Var.Map.empty
| EDefault (exceptions, just, cons) ->
(* <e1 ... en | ejust :- econs > never returns conflict if and only if:
- neither e1 nor ... nor en nor ejust nor econs return conflict
@ -284,8 +286,8 @@ type verification_condition = {
(* should have type bool *)
vc_kind : verification_condition_kind;
vc_scope : ScopeName.t;
vc_variable : Var.t Marked.pos;
vc_free_vars_typ : typ Marked.pos VarMap.t;
vc_variable : typed var Marked.pos;
vc_free_vars_typ : (typed expr, typ Marked.pos) Var.Map.t;
}
let rec generate_verification_conditions_scope_body_expr
@ -301,7 +303,7 @@ let rec generate_verification_conditions_scope_body_expr
let new_ctx, vc_list =
match scope_let.scope_let_kind with
| DestructuringInputStruct ->
{ ctx with input_vars = Var.t scope_let_var :: ctx.input_vars }, []
{ ctx with input_vars = scope_let_var :: ctx.input_vars }, []
| ScopeVarDefinition | SubScopeVarDefinition ->
(* For scope variables, we should check both that they never evaluate to
emptyError nor conflictError. But for subscope variable definitions,
@ -324,11 +326,11 @@ let rec generate_verification_conditions_scope_body_expr
vc_guard = Marked.same_mark_as (Marked.unmark vc_confl) e;
vc_kind = NoOverlappingExceptions;
vc_free_vars_typ =
VarMap.union
Var.Map.union
(fun _ _ -> failwith "should not happen")
ctx.scope_variables_typs vc_confl_typs;
vc_scope = ctx.current_scope_name;
vc_variable = Var.t scope_let_var, scope_let.scope_let_pos;
vc_variable = scope_let_var, scope_let.scope_let_pos;
};
]
in
@ -347,11 +349,11 @@ let rec generate_verification_conditions_scope_body_expr
vc_guard = Marked.same_mark_as (Marked.unmark vc_empty) e;
vc_kind = NoEmptyError;
vc_free_vars_typ =
VarMap.union
Var.Map.union
(fun _ _ -> failwith "should not happen")
ctx.scope_variables_typs vc_empty_typs;
vc_scope = ctx.current_scope_name;
vc_variable = Var.t scope_let_var, scope_let.scope_let_pos;
vc_variable = scope_let_var, scope_let.scope_let_pos;
}
:: vc_list
| _ -> vc_list
@ -364,7 +366,7 @@ let rec generate_verification_conditions_scope_body_expr
{
new_ctx with
scope_variables_typs =
VarMap.add (Var.t scope_let_var) scope_let.scope_let_typ
Var.Map.add scope_let_var scope_let.scope_let_typ
new_ctx.scope_variables_typs;
}
scope_let_next
@ -396,7 +398,7 @@ let rec generate_verification_conditions_scopes
decl = decl_ctx;
input_vars = [];
scope_variables_typs =
VarMap.empty
Var.Map.empty
(* We don't need to add the typ of the scope input var here
because it will never appear in an expression for which we
generate a verification conditions (the big struct is
@ -423,7 +425,7 @@ let generate_verification_conditions
let to_str vc =
Format.asprintf "%s.%s"
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable)))
(Bindlib.name_of (Marked.unmark vc.vc_variable))
in
String.compare (to_str vc1) (to_str vc2))
vcs

View File

@ -17,6 +17,8 @@
(** Generates verification conditions from scope definitions *)
open Utils
type verification_condition_kind =
| NoEmptyError
(** This verification condition checks whether a definition never returns
@ -30,8 +32,9 @@ type verification_condition = {
(** This expression should have type [bool]*)
vc_kind : verification_condition_kind;
vc_scope : Dcalc.Ast.ScopeName.t;
vc_variable : Dcalc.Ast.Var.t Utils.Marked.pos;
vc_free_vars_typ : Dcalc.Ast.typ Utils.Marked.pos Dcalc.Ast.VarMap.t;
vc_variable : Astgen.typed Dcalc.Ast.var Marked.pos;
vc_free_vars_typ :
(Astgen.typed Dcalc.Ast.expr, Dcalc.Ast.typ Marked.pos) Var.Map.t;
(** Types of the locally free variables in [vc_guard]. The types of other
free variables linked to scope variables can be obtained with
[Dcalc.Ast.variable_types]. *)

View File

@ -23,7 +23,8 @@ module type Backend = sig
type backend_context
val make_context : decl_ctx -> typ Marked.pos VarMap.t -> backend_context
val make_context :
decl_ctx -> (typed expr, typ Marked.pos) Var.Map.t -> backend_context
type vc_encoding
@ -37,7 +38,9 @@ module type Backend = sig
val is_model_empty : model -> bool
val translate_expr :
backend_context -> 'm Dcalc.Ast.marked_expr -> backend_context * vc_encoding
backend_context ->
Astgen.typed Dcalc.Ast.marked_expr ->
backend_context * vc_encoding
end
module type BackendIO = sig
@ -45,12 +48,15 @@ module type BackendIO = sig
type backend_context
val make_context : decl_ctx -> typ Marked.pos VarMap.t -> backend_context
val make_context :
decl_ctx -> (Astgen.typed expr, typ Marked.pos) Var.Map.t -> backend_context
type vc_encoding
val translate_expr :
backend_context -> 'm Dcalc.Ast.marked_expr -> backend_context * vc_encoding
backend_context ->
Astgen.typed Dcalc.Ast.marked_expr ->
backend_context * vc_encoding
type model
@ -95,12 +101,12 @@ module MakeBackendIO (B : Backend) = struct
Format.asprintf "%s This variable never returns an empty error"
(Cli.with_style [ANSITerminal.yellow] "[%s.%s]"
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable))))
(Bindlib.name_of (Marked.unmark vc.vc_variable)))
| Conditions.NoOverlappingExceptions ->
Format.asprintf "%s No two exceptions to ever overlap for this variable"
(Cli.with_style [ANSITerminal.yellow] "[%s.%s]"
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable))))
(Bindlib.name_of (Marked.unmark vc.vc_variable)))
let print_negative_result
(vc : Conditions.verification_condition)
@ -112,14 +118,14 @@ module MakeBackendIO (B : Backend) = struct
Format.asprintf "%s This variable might return an empty error:\n%s"
(Cli.with_style [ANSITerminal.yellow] "[%s.%s]"
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable))))
(Bindlib.name_of (Marked.unmark vc.vc_variable)))
(Pos.retrieve_loc_text (Marked.get_mark vc.vc_variable))
| Conditions.NoOverlappingExceptions ->
Format.asprintf
"%s At least two exceptions overlap for this variable:\n%s"
(Cli.with_style [ANSITerminal.yellow] "[%s.%s]"
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable))))
(Bindlib.name_of (Marked.unmark vc.vc_variable)))
(Pos.retrieve_loc_text (Marked.get_mark vc.vc_variable))
in
let counterexample : string option =
@ -178,6 +184,6 @@ module MakeBackendIO (B : Backend) = struct
Cli.error_print "%s The translation to Z3 failed:\n%s"
(Cli.with_style [ANSITerminal.yellow] "[%s.%s]"
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable))))
(Bindlib.name_of (Marked.unmark vc.vc_variable)))
msg
end

View File

@ -17,6 +17,8 @@
(** Common code for handling the IO of all proof backends supported *)
open Utils
module type Backend = sig
val init_backend : unit -> unit
@ -24,7 +26,7 @@ module type Backend = sig
val make_context :
Dcalc.Ast.decl_ctx ->
Dcalc.Ast.typ Utils.Marked.pos Dcalc.Ast.VarMap.t ->
(Astgen.typed Dcalc.Ast.expr, Dcalc.Ast.typ Utils.Marked.pos) Var.Map.t ->
backend_context
type vc_encoding
@ -39,7 +41,9 @@ module type Backend = sig
val is_model_empty : model -> bool
val translate_expr :
backend_context -> 'm Dcalc.Ast.marked_expr -> backend_context * vc_encoding
backend_context ->
Astgen.typed Dcalc.Ast.marked_expr ->
backend_context * vc_encoding
end
module type BackendIO = sig
@ -49,13 +53,15 @@ module type BackendIO = sig
val make_context :
Dcalc.Ast.decl_ctx ->
Dcalc.Ast.typ Utils.Marked.pos Dcalc.Ast.VarMap.t ->
(Astgen.typed Dcalc.Ast.expr, Dcalc.Ast.typ Utils.Marked.pos) Var.Map.t ->
backend_context
type vc_encoding
val translate_expr :
backend_context -> 'm Dcalc.Ast.marked_expr -> backend_context * vc_encoding
backend_context ->
Astgen.typed Dcalc.Ast.marked_expr ->
backend_context * vc_encoding
type model

View File

@ -26,20 +26,20 @@ type context = {
ctx_decl : decl_ctx;
(* The declaration context from the Catala program, containing information to
precisely pretty print Catala expressions *)
ctx_var : typ Marked.pos VarMap.t;
ctx_var : (typed expr, typ Marked.pos) Var.Map.t;
(* A map from Catala variables to their types, needed to create Z3 expressions
of the right sort *)
ctx_funcdecl : FuncDecl.func_decl VarMap.t;
ctx_funcdecl : (typed expr, FuncDecl.func_decl) Var.Map.t;
(* A map from Catala function names (represented as variables) to Z3 function
declarations, used to only define once functions in Z3 queries *)
ctx_z3vars : Var.t StringMap.t;
ctx_z3vars : typed var StringMap.t;
(* A map from strings, corresponding to Z3 symbol names, to the Catala
variable they represent. Used when to pretty-print Z3 models when a
counterexample is generated *)
ctx_z3datatypes : Sort.sort EnumMap.t;
(* A map from Catala enumeration names to the corresponding Z3 sort, from
which we can retrieve constructors and accessors *)
ctx_z3matchsubsts : Expr.expr VarMap.t;
ctx_z3matchsubsts : (typed expr, Expr.expr) Var.Map.t;
(* A map from Catala temporary variables, generated when translating a match,
to the corresponding enum accessor call as a Z3 expression *)
ctx_z3structs : Sort.sort StructMap.t;
@ -64,13 +64,13 @@ type context = {
(** [add_funcdecl] adds the mapping between the Catala variable [v] and the Z3
function declaration [fd] to the context **)
let add_funcdecl (v : Var.t) (fd : FuncDecl.func_decl) (ctx : context) : context
=
{ ctx with ctx_funcdecl = VarMap.add v fd ctx.ctx_funcdecl }
let add_funcdecl (v : typed var) (fd : FuncDecl.func_decl) (ctx : context) :
context =
{ ctx with ctx_funcdecl = Var.Map.add v fd ctx.ctx_funcdecl }
(** [add_z3var] adds the mapping between [name] and the Catala variable [v] to
the context **)
let add_z3var (name : string) (v : Var.t) (ctx : context) : context =
let add_z3var (name : string) (v : typed var) (ctx : context) : context =
{ ctx with ctx_z3vars = StringMap.add name v ctx.ctx_z3vars }
(** [add_z3enum] adds the mapping between the Catala enumeration [enum] and the
@ -81,8 +81,8 @@ let add_z3enum (enum : EnumName.t) (sort : Sort.sort) (ctx : context) : context
(** [add_z3var] adds the mapping between temporary variable [v] and the Z3
expression [e] representing an accessor application to the context **)
let add_z3matchsubst (v : Var.t) (e : Expr.expr) (ctx : context) : context =
{ ctx with ctx_z3matchsubsts = VarMap.add v e ctx.ctx_z3matchsubsts }
let add_z3matchsubst (v : typed var) (e : Expr.expr) (ctx : context) : context =
{ ctx with ctx_z3matchsubsts = Var.Map.add v e ctx.ctx_z3matchsubsts }
(** [add_z3struct] adds the mapping between the Catala struct [s] and the
corresponding Z3 datatype [sort] to the context **)
@ -223,9 +223,8 @@ let print_model (ctx : context) (model : Model.model) : string =
let v = StringMap.find symbol_name ctx.ctx_z3vars in
Format.fprintf fmt "%s %s : %s"
(Cli.with_style [ANSITerminal.blue] "%s" "-->")
(Cli.with_style [ANSITerminal.yellow] "%s"
(Bindlib.name_of (Var.get v)))
(print_z3model_expr ctx (VarMap.find v ctx.ctx_var) e)
(Cli.with_style [ANSITerminal.yellow] "%s" (Bindlib.name_of v))
(print_z3model_expr ctx (Var.Map.find v ctx.ctx_var) e)
else
(* Declaration d is a function *)
match Model.get_func_interp model d with
@ -239,8 +238,7 @@ let print_model (ctx : context) (model : Model.model) : string =
let v = StringMap.find symbol_name ctx.ctx_z3vars in
Format.fprintf fmt "%s %s : %s"
(Cli.with_style [ANSITerminal.blue] "%s" "-->")
(Cli.with_style [ANSITerminal.yellow] "%s"
(Bindlib.name_of (Var.get v)))
(Cli.with_style [ANSITerminal.yellow] "%s" (Bindlib.name_of v))
(* TODO: Model of a Z3 function should be pretty-printed *)
(Model.FuncInterp.to_string f)))
decls
@ -387,18 +385,18 @@ let translate_lit (ctx : context) (l : lit) : Expr.expr =
corresponding to the variable [v]. If no such function declaration exists
yet, we construct it and add it to the context, thus requiring to return a
new context *)
let find_or_create_funcdecl (ctx : context) (v : Var.t) :
let find_or_create_funcdecl (ctx : context) (v : typed var) :
context * FuncDecl.func_decl =
match VarMap.find_opt v ctx.ctx_funcdecl with
match Var.Map.find_opt v ctx.ctx_funcdecl with
| Some fd -> ctx, fd
| None -> (
(* Retrieves the Catala type of the function [v] *)
let f_ty = VarMap.find v ctx.ctx_var in
let f_ty = Var.Map.find v ctx.ctx_var in
match Marked.unmark f_ty with
| TArrow (t1, t2) ->
let ctx, z3_t1 = translate_typ ctx (Marked.unmark t1) in
let ctx, z3_t2 = translate_typ ctx (Marked.unmark t2) in
let name = unique_name (Var.get v) in
let name = unique_name v in
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name [z3_t1] z3_t2 in
let ctx = add_funcdecl v fd ctx in
let ctx = add_z3var name v ctx in
@ -631,7 +629,7 @@ and translate_expr (ctx : context) (vc : 'm marked_expr) : context * Expr.expr =
match Marked.unmark e with
| EAbs (e, _) ->
(* Create a fresh Catala variable to substitue and obtain the body *)
let fresh_v = new_var "arm!tmp" in
let fresh_v = Var.make "arm!tmp" in
let fresh_e = EVar fresh_v in
(* Invariant: Catala enums always have exactly one argument *)
@ -639,7 +637,7 @@ and translate_expr (ctx : context) (vc : 'm marked_expr) : context * Expr.expr =
let proj = Expr.mk_app ctx.ctx_z3 accessor [head] in
(* The fresh variable should be substituted by a projection into the enum
in the body, we add this to the context *)
let ctx = add_z3matchsubst (Var.t fresh_v) proj ctx in
let ctx = add_z3matchsubst fresh_v proj ctx in
let body = Bindlib.msubst e [| fresh_e |] in
translate_expr ctx body
@ -649,12 +647,12 @@ and translate_expr (ctx : context) (vc : 'm marked_expr) : context * Expr.expr =
match Marked.unmark vc with
| EVar v -> (
match VarMap.find_opt (Var.t v) ctx.ctx_z3matchsubsts with
match Var.Map.find_opt v ctx.ctx_z3matchsubsts with
| None ->
(* We are in the standard case, where this is a true Catala variable *)
let t = VarMap.find (Var.t v) ctx.ctx_var in
let t = Var.Map.find v ctx.ctx_var in
let name = unique_name v in
let ctx = add_z3var name (Var.t v) ctx in
let ctx = add_z3var name v ctx in
let ctx, ty = translate_typ ctx (Marked.unmark t) in
let z3_var = Expr.mk_const_s ctx.ctx_z3 name ty in
let ctx =
@ -726,7 +724,7 @@ and translate_expr (ctx : context) (vc : 'm marked_expr) : context * Expr.expr =
match Marked.unmark head with
| EOp op -> translate_op ctx op args
| EVar v ->
let ctx, fd = find_or_create_funcdecl ctx (Var.t v) in
let ctx, fd = find_or_create_funcdecl ctx v in
(* Fold_right to preserve the order of the arguments: The head argument is
appended at the head *)
let ctx, z3_args =
@ -804,7 +802,8 @@ module Backend = struct
let make_context
(decl_ctx : decl_ctx)
(free_vars_typ : typ Marked.pos VarMap.t) : backend_context =
(free_vars_typ : (typed expr, typ Marked.pos) Var.Map.t) : backend_context
=
let cfg =
(if !Cli.disable_counterexamples then [] else ["model", "true"])
@ ["proof", "false"]
@ -815,10 +814,10 @@ module Backend = struct
ctx_z3 = z3_ctx;
ctx_decl = decl_ctx;
ctx_var = free_vars_typ;
ctx_funcdecl = VarMap.empty;
ctx_funcdecl = Var.Map.empty;
ctx_z3vars = StringMap.empty;
ctx_z3datatypes = EnumMap.empty;
ctx_z3matchsubsts = VarMap.empty;
ctx_z3matchsubsts = Var.Map.empty;
ctx_z3structs = StructMap.empty;
ctx_z3unit = z3unit;
ctx_z3constraints = [];

View File

@ -1,7 +1,7 @@
let Foo =
λ (Foo_in_28: Foo_in{}) →
let bar_29 : integer =
λ (Foo_in_27: Foo_in{}) →
let bar_28 : integer =
try
handle_default_0 [] (λ (__30: any) → true) (λ (__31: any) → 0)
handle_default_0 [] (λ (__29: any) → true) (λ (__30: any) → 0)
with EmptyError -> raise NoValueProvided in
Foo_out {"bar_out": bar_29}
Foo_out {"bar_out": bar_28}