mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-12 21:48:25 +03:00
Merge pull request #315 from AltGr/ast-factorisation
Factorise ASTs (between dcalc and lcalc)
This commit is contained in:
commit
efa7cec4c1
@ -1,7 +1,7 @@
|
||||
# Reformatting commits to be skipped when running 'git blame'
|
||||
# Use `git config --global blame.ignoreRevsFile .git-blame-ignore-revs` to use it
|
||||
# Add new reformatting commits at the top
|
||||
99b6fc33b508c879f669172005b6c359d7d4f596
|
||||
ba620fca280338139e015e316894a7cf49c450d5
|
||||
|
||||
7485c7f2ce726f59f1ec66ddfe1d3f7d640201d8
|
||||
|
||||
|
@ -15,240 +15,23 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
[@@@ocaml.warning "-7-34"]
|
||||
|
||||
open Utils
|
||||
module Runtime = Runtime_ocaml.Runtime
|
||||
include Astgen
|
||||
include Astgen_utils
|
||||
|
||||
module ScopeName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
type lit = dcalc glit
|
||||
|
||||
module StructName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module StructMap : Map.S with type key = StructName.t = Map.Make (StructName)
|
||||
|
||||
module EnumName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module EnumMap : Map.S with type key = EnumName.t = Map.Make (EnumName)
|
||||
|
||||
type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration
|
||||
|
||||
type marked_typ = typ Marked.pos
|
||||
|
||||
and typ =
|
||||
| TLit of typ_lit
|
||||
| TTuple of marked_typ list * StructName.t option
|
||||
| TEnum of marked_typ list * EnumName.t
|
||||
| TArrow of marked_typ * marked_typ
|
||||
| TArray of marked_typ
|
||||
| TAny
|
||||
|
||||
type date = Runtime.date
|
||||
type duration = Runtime.duration
|
||||
type integer = Runtime.integer
|
||||
type decimal = Runtime.decimal
|
||||
type money = Runtime.money
|
||||
|
||||
type lit =
|
||||
| LBool of bool
|
||||
| LEmptyError
|
||||
| LInt of integer
|
||||
| LRat of decimal
|
||||
| LMoney of money
|
||||
| LUnit
|
||||
| LDate of date
|
||||
| LDuration of duration
|
||||
|
||||
type op_kind = KInt | KRat | KMoney | KDate | KDuration
|
||||
type ternop = Fold
|
||||
|
||||
type binop =
|
||||
| And
|
||||
| Or
|
||||
| Xor
|
||||
| Add of op_kind
|
||||
| Sub of op_kind
|
||||
| Mult of op_kind
|
||||
| Div of op_kind
|
||||
| Lt of op_kind
|
||||
| Lte of op_kind
|
||||
| Gt of op_kind
|
||||
| Gte of op_kind
|
||||
| Eq
|
||||
| Neq
|
||||
| Map
|
||||
| Concat
|
||||
| Filter
|
||||
|
||||
type log_entry = VarDef of typ | BeginCall | EndCall | PosRecordIfTrueBool
|
||||
|
||||
type unop =
|
||||
| Not
|
||||
| Minus of op_kind
|
||||
| Log of log_entry * Utils.Uid.MarkedString.info list
|
||||
| Length
|
||||
| IntToRat
|
||||
| MoneyToRat
|
||||
| RatToMoney
|
||||
| GetDay
|
||||
| GetMonth
|
||||
| GetYear
|
||||
| FirstDayOfMonth
|
||||
| LastDayOfMonth
|
||||
| RoundMoney
|
||||
| RoundDecimal
|
||||
|
||||
type operator = Ternop of ternop | Binop of binop | Unop of unop
|
||||
|
||||
(** Some structures used for type inference *)
|
||||
module Infer = struct
|
||||
module Any =
|
||||
Utils.Uid.Make
|
||||
(struct
|
||||
type info = unit
|
||||
|
||||
let format_info fmt () = Format.fprintf fmt "any"
|
||||
end)
|
||||
()
|
||||
|
||||
type unionfind_typ = typ Marked.pos UnionFind.elem
|
||||
(** We do not reuse {!type: Dcalc.Ast.typ} because we have to include a new
|
||||
[TAny] variant. Indeed, error terms can have any type and this has to be
|
||||
captured by the type sytem. *)
|
||||
|
||||
and typ =
|
||||
| TLit of typ_lit
|
||||
| TArrow of unionfind_typ * unionfind_typ
|
||||
| TTuple of unionfind_typ list * StructName.t option
|
||||
| TEnum of unionfind_typ list * EnumName.t
|
||||
| TArray of unionfind_typ
|
||||
| TAny of Any.t
|
||||
|
||||
let rec typ_to_ast (ty : unionfind_typ) : marked_typ =
|
||||
let ty, pos = UnionFind.get (UnionFind.find ty) in
|
||||
match ty with
|
||||
| TLit l -> TLit l, pos
|
||||
| TTuple (ts, s) -> TTuple (List.map typ_to_ast ts, s), pos
|
||||
| TEnum (ts, e) -> TEnum (List.map typ_to_ast ts, e), pos
|
||||
| TArrow (t1, t2) -> TArrow (typ_to_ast t1, typ_to_ast t2), pos
|
||||
| TAny _ -> TAny, pos
|
||||
| TArray t1 -> TArray (typ_to_ast t1), pos
|
||||
|
||||
let rec ast_to_typ (ty : marked_typ) : unionfind_typ =
|
||||
let ty' =
|
||||
match Marked.unmark ty with
|
||||
| TLit l -> TLit l
|
||||
| TArrow (t1, t2) -> TArrow (ast_to_typ t1, ast_to_typ t2)
|
||||
| TTuple (ts, s) -> TTuple (List.map (fun t -> ast_to_typ t) ts, s)
|
||||
| TEnum (ts, e) -> TEnum (List.map (fun t -> ast_to_typ t) ts, e)
|
||||
| TArray t -> TArray (ast_to_typ t)
|
||||
| TAny -> TAny (Any.fresh ())
|
||||
in
|
||||
UnionFind.make (Marked.same_mark_as ty' ty)
|
||||
end
|
||||
|
||||
type untyped = { pos : Pos.t } [@@ocaml.unboxed]
|
||||
type typed = { pos : Pos.t; ty : marked_typ }
|
||||
type inferring = { pos : Pos.t; uf : Infer.unionfind_typ }
|
||||
|
||||
(** The generic type of AST markings. Using a GADT allows functions to be
|
||||
polymorphic in the marking, but still do transformations on types when
|
||||
appropriate *)
|
||||
type _ mark =
|
||||
| Untyped : untyped -> untyped mark
|
||||
| Typed : typed -> typed mark
|
||||
| Inferring : inferring -> inferring mark
|
||||
|
||||
type ('a, 'm) marked = ('a, 'm mark) Marked.t
|
||||
|
||||
type 'm marked_expr = ('m expr, 'm) marked
|
||||
|
||||
and 'm expr =
|
||||
| EVar of 'm expr Bindlib.var
|
||||
| ETuple of 'm marked_expr list * StructName.t option
|
||||
| ETupleAccess of
|
||||
'm marked_expr * int * StructName.t option * typ Marked.pos list
|
||||
| EInj of 'm marked_expr * int * EnumName.t * typ Marked.pos list
|
||||
| EMatch of 'm marked_expr * 'm marked_expr list * EnumName.t
|
||||
| EArray of 'm marked_expr list
|
||||
| ELit of lit
|
||||
| EAbs of
|
||||
(('m expr, 'm marked_expr) Bindlib.mbinder[@opaque]) * typ Marked.pos list
|
||||
| EApp of 'm marked_expr * 'm marked_expr list
|
||||
| EAssert of 'm marked_expr
|
||||
| EOp of operator
|
||||
| EDefault of 'm marked_expr list * 'm marked_expr * 'm marked_expr
|
||||
| EIfThenElse of 'm marked_expr * 'm marked_expr * 'm marked_expr
|
||||
| ErrorOnEmpty of 'm marked_expr
|
||||
|
||||
type typed_expr = typed marked_expr
|
||||
type struct_ctx = (StructFieldName.t * typ Marked.pos) list StructMap.t
|
||||
type enum_ctx = (EnumConstructor.t * typ Marked.pos) list EnumMap.t
|
||||
type decl_ctx = { ctx_enums : enum_ctx; ctx_structs : struct_ctx }
|
||||
type 'm binder = ('m expr, 'm marked_expr) Bindlib.binder
|
||||
|
||||
type scope_let_kind =
|
||||
| DestructuringInputStruct
|
||||
| ScopeVarDefinition
|
||||
| SubScopeVarDefinition
|
||||
| CallingSubScope
|
||||
| DestructuringSubScopeResults
|
||||
| Assertion
|
||||
|
||||
type ('expr, 'm) scope_let = {
|
||||
scope_let_kind : scope_let_kind;
|
||||
scope_let_typ : typ Marked.pos;
|
||||
scope_let_expr : ('expr, 'm) marked;
|
||||
scope_let_next : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
|
||||
scope_let_pos : Pos.t;
|
||||
}
|
||||
|
||||
and ('expr, 'm) scope_body_expr =
|
||||
| Result of ('expr, 'm) marked
|
||||
| ScopeLet of ('expr, 'm) scope_let
|
||||
|
||||
type ('expr, 'm) scope_body = {
|
||||
scope_body_input_struct : StructName.t;
|
||||
scope_body_output_struct : StructName.t;
|
||||
scope_body_expr : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
|
||||
}
|
||||
|
||||
type ('expr, 'm) scope_def = {
|
||||
scope_name : ScopeName.t;
|
||||
scope_body : ('expr, 'm) scope_body;
|
||||
scope_next : ('expr, ('expr, 'm) scopes) Bindlib.binder;
|
||||
}
|
||||
|
||||
and ('expr, 'm) scopes = Nil | ScopeDef of ('expr, 'm) scope_def
|
||||
|
||||
type ('expr, 'm) program_generic = {
|
||||
decl_ctx : decl_ctx;
|
||||
scopes : ('expr, 'm) scopes;
|
||||
}
|
||||
type 'm expr = (dcalc, 'm mark) gexpr
|
||||
and 'm marked_expr = (dcalc, 'm mark) marked_gexpr
|
||||
|
||||
type 'm program = ('m expr, 'm) program_generic
|
||||
|
||||
let no_mark (type m) : m mark -> m mark = function
|
||||
| Untyped _ -> Untyped { pos = Pos.no_pos }
|
||||
| Typed _ -> Typed { pos = Pos.no_pos; ty = Marked.mark Pos.no_pos TAny }
|
||||
| Inferring _ ->
|
||||
Inferring
|
||||
{
|
||||
pos = Pos.no_pos;
|
||||
uf = UnionFind.make Infer.(TAny (Any.fresh ()), Pos.no_pos);
|
||||
}
|
||||
|
||||
let mark_pos (type m) (m : m mark) : Pos.t =
|
||||
match m with
|
||||
| Untyped { pos } | Typed { pos; _ } | Inferring { pos; _ } -> pos
|
||||
match m with Untyped { pos } | Typed { pos; _ } -> pos
|
||||
|
||||
let pos (type m) (x : ('a, m) marked) : Pos.t = mark_pos (Marked.get_mark x)
|
||||
let ty (_, m) : marked_typ = match m with Typed { ty; _ } -> ty
|
||||
@ -257,78 +40,11 @@ let with_ty (type m) (ty : marked_typ) (x : ('a, m) marked) : ('a, typed) marked
|
||||
=
|
||||
Marked.mark
|
||||
(match Marked.get_mark x with
|
||||
| Untyped { pos } | Inferring { pos; _ } -> Typed { pos; ty }
|
||||
| Untyped { pos } -> Typed { pos; ty }
|
||||
| Typed m -> Typed { m with ty })
|
||||
(Marked.unmark x)
|
||||
|
||||
let evar v mark = Bindlib.box_apply (Marked.mark mark) (Bindlib.box_var v)
|
||||
|
||||
let etuple args s mark =
|
||||
Bindlib.box_apply (fun args -> ETuple (args, s), mark) (Bindlib.box_list args)
|
||||
|
||||
let etupleaccess e1 i s typs mark =
|
||||
Bindlib.box_apply (fun e1 -> ETupleAccess (e1, i, s, typs), mark) e1
|
||||
|
||||
let einj e1 i e_name typs mark =
|
||||
Bindlib.box_apply (fun e1 -> EInj (e1, i, e_name, typs), mark) e1
|
||||
|
||||
let ematch arg arms e_name mark =
|
||||
Bindlib.box_apply2
|
||||
(fun arg arms -> EMatch (arg, arms, e_name), mark)
|
||||
arg (Bindlib.box_list arms)
|
||||
|
||||
let earray args mark =
|
||||
Bindlib.box_apply (fun args -> EArray args, mark) (Bindlib.box_list args)
|
||||
|
||||
let elit l mark = Bindlib.box (ELit l, mark)
|
||||
|
||||
let eabs binder typs mark =
|
||||
Bindlib.box_apply (fun binder -> EAbs (binder, typs), mark) binder
|
||||
|
||||
let eapp e1 args mark =
|
||||
Bindlib.box_apply2
|
||||
(fun e1 args -> EApp (e1, args), mark)
|
||||
e1 (Bindlib.box_list args)
|
||||
|
||||
let eassert e1 mark = Bindlib.box_apply (fun e1 -> EAssert e1, mark) e1
|
||||
let eop op mark = Bindlib.box (EOp op, mark)
|
||||
|
||||
let edefault excepts just cons mark =
|
||||
Bindlib.box_apply3
|
||||
(fun excepts just cons -> EDefault (excepts, just, cons), mark)
|
||||
(Bindlib.box_list excepts) just cons
|
||||
|
||||
let eifthenelse e1 e2 e3 mark =
|
||||
Bindlib.box_apply3 (fun e1 e2 e3 -> EIfThenElse (e1, e2, e3), mark) e1 e2 e3
|
||||
|
||||
let eerroronempty e1 mark =
|
||||
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) e1
|
||||
|
||||
let translate_var v = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
|
||||
|
||||
let map_expr ctx ~f e =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| EVar v -> evar (translate_var v) m
|
||||
| EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) m
|
||||
| EAbs (binder, typs) ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
eabs (Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body)) typs m
|
||||
| ETuple (args, s) -> etuple (List.map (f ctx) args) s m
|
||||
| ETupleAccess (e1, n, s_name, typs) ->
|
||||
etupleaccess ((f ctx) e1) n s_name typs m
|
||||
| EInj (e1, i, e_name, typs) -> einj ((f ctx) e1) i e_name typs m
|
||||
| EMatch (arg, arms, e_name) ->
|
||||
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name m
|
||||
| EArray args -> earray (List.map (f ctx) args) m
|
||||
| ELit l -> elit l m
|
||||
| EAssert e1 -> eassert ((f ctx) e1) m
|
||||
| EOp op -> Bindlib.box (EOp op, m)
|
||||
| EDefault (excepts, just, cons) ->
|
||||
edefault (List.map (f ctx) excepts) ((f ctx) just) ((f ctx) cons) m
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m
|
||||
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m
|
||||
let map_expr ctx ~f e = Astgen_utils.map_gexpr ctx ~f e
|
||||
|
||||
let rec map_expr_top_down ~f e =
|
||||
map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e)
|
||||
@ -347,79 +63,7 @@ let box_expr : ('m expr, 'm) box_expr_sig =
|
||||
let rec id_t () e = map_expr () ~f:id_t e in
|
||||
id_t () e
|
||||
|
||||
let rec fold_left_scope_lets ~f ~init scope_body_expr =
|
||||
match scope_body_expr with
|
||||
| Result _ -> init
|
||||
| ScopeLet scope_let ->
|
||||
let var, next = Bindlib.unbind scope_let.scope_let_next in
|
||||
fold_left_scope_lets ~f ~init:(f init scope_let var) next
|
||||
|
||||
let rec fold_right_scope_lets ~f ~init scope_body_expr =
|
||||
match scope_body_expr with
|
||||
| Result result -> init result
|
||||
| ScopeLet scope_let ->
|
||||
let var, next = Bindlib.unbind scope_let.scope_let_next in
|
||||
let next_result = fold_right_scope_lets ~f ~init next in
|
||||
f scope_let var next_result
|
||||
|
||||
let map_exprs_in_scope_lets ~f ~varf scope_body_expr =
|
||||
fold_right_scope_lets
|
||||
~f:(fun scope_let var_next acc ->
|
||||
Bindlib.box_apply2
|
||||
(fun scope_let_next scope_let_expr ->
|
||||
ScopeLet { scope_let with scope_let_next; scope_let_expr })
|
||||
(Bindlib.bind_var (varf var_next) acc)
|
||||
(f scope_let.scope_let_expr))
|
||||
~init:(fun res -> Bindlib.box_apply (fun res -> Result res) (f res))
|
||||
scope_body_expr
|
||||
|
||||
let rec fold_left_scope_defs ~f ~init scopes =
|
||||
match scopes with
|
||||
| Nil -> init
|
||||
| ScopeDef scope_def ->
|
||||
let var, next = Bindlib.unbind scope_def.scope_next in
|
||||
fold_left_scope_defs ~f ~init:(f init scope_def var) next
|
||||
|
||||
let rec fold_right_scope_defs ~f ~init scopes =
|
||||
match scopes with
|
||||
| Nil -> init
|
||||
| ScopeDef scope_def ->
|
||||
let var_next, next = Bindlib.unbind scope_def.scope_next in
|
||||
let result_next = fold_right_scope_defs ~f ~init next in
|
||||
f scope_def var_next result_next
|
||||
|
||||
let map_scope_defs ~f scopes =
|
||||
fold_right_scope_defs
|
||||
~f:(fun scope_def var_next acc ->
|
||||
let new_scope_def = f scope_def in
|
||||
let new_next = Bindlib.bind_var var_next acc in
|
||||
Bindlib.box_apply2
|
||||
(fun new_scope_def new_next ->
|
||||
ScopeDef { new_scope_def with scope_next = new_next })
|
||||
new_scope_def new_next)
|
||||
~init:(Bindlib.box Nil) scopes
|
||||
|
||||
let map_exprs_in_scopes ~f ~varf scopes =
|
||||
fold_right_scope_defs
|
||||
~f:(fun scope_def var_next acc ->
|
||||
let scope_input_var, scope_lets =
|
||||
Bindlib.unbind scope_def.scope_body.scope_body_expr
|
||||
in
|
||||
let new_scope_body_expr = map_exprs_in_scope_lets ~f ~varf scope_lets in
|
||||
let new_scope_body_expr =
|
||||
Bindlib.bind_var (varf scope_input_var) new_scope_body_expr
|
||||
in
|
||||
let new_next = Bindlib.bind_var (varf var_next) acc in
|
||||
Bindlib.box_apply2
|
||||
(fun scope_body_expr scope_next ->
|
||||
ScopeDef
|
||||
{
|
||||
scope_def with
|
||||
scope_body = { scope_def.scope_body with scope_body_expr };
|
||||
scope_next;
|
||||
})
|
||||
new_scope_body_expr new_next)
|
||||
~init:(Bindlib.box Nil) scopes
|
||||
open Astgen_utils
|
||||
|
||||
let untype_program prg =
|
||||
{
|
||||
@ -428,35 +72,17 @@ let untype_program prg =
|
||||
Bindlib.unbox
|
||||
(map_exprs_in_scopes
|
||||
~f:(fun e -> untype_expr e)
|
||||
~varf:translate_var prg.scopes);
|
||||
~varf:Var.translate prg.scopes);
|
||||
}
|
||||
|
||||
type 'm var = 'm expr Bindlib.var
|
||||
type 'm vars = 'm expr Bindlib.mvar
|
||||
type 'm var = 'm expr Var.t
|
||||
type 'm vars = 'm expr Var.vars
|
||||
|
||||
let new_var s = Bindlib.new_var (fun x -> EVar x) s
|
||||
|
||||
module Var = struct
|
||||
type t = V : 'a expr Bindlib.var -> t
|
||||
(* We use this trivial GADT to make the 'm parameter disappear under an
|
||||
existential. It's fine for a use as keys only. (bindlib defines [any_var]
|
||||
similarly but it's not exported) todo: add [@@ocaml.unboxed] once it's
|
||||
possible through abstract types *)
|
||||
|
||||
let t v = V v
|
||||
let get (V v) = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
|
||||
let compare (V x) (V y) = Bindlib.compare_vars x y
|
||||
let eq (V x) (V y) = Bindlib.eq_vars x y
|
||||
end
|
||||
|
||||
module VarSet = Set.Make (Var)
|
||||
module VarMap = Map.Make (Var)
|
||||
|
||||
let rec free_vars_expr (e : 'm marked_expr) : VarSet.t =
|
||||
let rec free_vars_expr (e : 'm marked_expr) : 'm expr Var.Set.t =
|
||||
match Marked.unmark e with
|
||||
| EVar v -> VarSet.singleton (Var.t v)
|
||||
| EVar v -> Var.Set.singleton v
|
||||
| ETuple (es, _) | EArray es ->
|
||||
es |> List.map free_vars_expr |> List.fold_left VarSet.union VarSet.empty
|
||||
es |> List.map free_vars_expr |> List.fold_left Var.Set.union Var.Set.empty
|
||||
| ETupleAccess (e1, _, _, _)
|
||||
| EAssert e1
|
||||
| ErrorOnEmpty e1
|
||||
@ -465,43 +91,43 @@ let rec free_vars_expr (e : 'm marked_expr) : VarSet.t =
|
||||
| EApp (e1, es) | EMatch (e1, es, _) ->
|
||||
e1 :: es
|
||||
|> List.map free_vars_expr
|
||||
|> List.fold_left VarSet.union VarSet.empty
|
||||
|> List.fold_left Var.Set.union Var.Set.empty
|
||||
| EDefault (es, ejust, econs) ->
|
||||
ejust :: econs :: es
|
||||
|> List.map free_vars_expr
|
||||
|> List.fold_left VarSet.union VarSet.empty
|
||||
| EOp _ | ELit _ -> VarSet.empty
|
||||
|> List.fold_left Var.Set.union Var.Set.empty
|
||||
| EOp _ | ELit _ -> Var.Set.empty
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
[e1; e2; e3]
|
||||
|> List.map free_vars_expr
|
||||
|> List.fold_left VarSet.union VarSet.empty
|
||||
|> List.fold_left Var.Set.union Var.Set.empty
|
||||
| EAbs (binder, _) ->
|
||||
let vs, body = Bindlib.unmbind binder in
|
||||
Array.fold_right VarSet.remove (Array.map Var.t vs) (free_vars_expr body)
|
||||
Array.fold_right Var.Set.remove vs (free_vars_expr body)
|
||||
|
||||
let rec free_vars_scope_body_expr (scope_lets : ('m expr, 'm) scope_body_expr) :
|
||||
VarSet.t =
|
||||
'm expr Var.Set.t =
|
||||
match scope_lets with
|
||||
| Result e -> free_vars_expr e
|
||||
| ScopeLet { scope_let_expr = e; scope_let_next = next; _ } ->
|
||||
let v, body = Bindlib.unbind next in
|
||||
VarSet.union (free_vars_expr e)
|
||||
(VarSet.remove (Var.t v) (free_vars_scope_body_expr body))
|
||||
Var.Set.union (free_vars_expr e)
|
||||
(Var.Set.remove v (free_vars_scope_body_expr body))
|
||||
|
||||
let free_vars_scope_body (scope_body : ('m expr, 'm) scope_body) : VarSet.t =
|
||||
let free_vars_scope_body (scope_body : ('m expr, 'm) scope_body) :
|
||||
'm expr Var.Set.t =
|
||||
let { scope_body_expr = binder; _ } = scope_body in
|
||||
let v, body = Bindlib.unbind binder in
|
||||
VarSet.remove (Var.t v) (free_vars_scope_body_expr body)
|
||||
Var.Set.remove v (free_vars_scope_body_expr body)
|
||||
|
||||
let rec free_vars_scopes (scopes : ('m expr, 'm) scopes) : VarSet.t =
|
||||
let rec free_vars_scopes (scopes : ('m expr, 'm) scopes) : 'm expr Var.Set.t =
|
||||
match scopes with
|
||||
| Nil -> VarSet.empty
|
||||
| Nil -> Var.Set.empty
|
||||
| ScopeDef { scope_body = body; scope_next = next; _ } ->
|
||||
let v, next = Bindlib.unbind next in
|
||||
VarSet.union
|
||||
(VarSet.remove (Var.t v) (free_vars_scopes next))
|
||||
Var.Set.union
|
||||
(Var.Set.remove v (free_vars_scopes next))
|
||||
(free_vars_scope_body body)
|
||||
(* type vars = expr Bindlib.mvar *)
|
||||
|
||||
let make_var ((x, mark) : ('m expr Bindlib.var, 'm) marked) :
|
||||
'm marked_expr Bindlib.box =
|
||||
@ -542,11 +168,6 @@ let map_mark
|
||||
match m with
|
||||
| Untyped { pos } -> Untyped { pos = pos_f pos }
|
||||
| Typed { pos; ty } -> Typed { pos = pos_f pos; ty = ty_f ty }
|
||||
| Inferring { pos; uf } ->
|
||||
Inferring
|
||||
{ pos = pos_f pos; uf = Infer.ast_to_typ (ty_f (Infer.typ_to_ast uf)) }
|
||||
|
||||
let resolve_inferring { uf; pos } = { ty = Infer.typ_to_ast uf; pos }
|
||||
|
||||
let map_mark2
|
||||
(type m)
|
||||
@ -557,13 +178,6 @@ let map_mark2
|
||||
match m1, m2 with
|
||||
| Untyped m1, Untyped m2 -> Untyped { pos = pos_f m1.pos m2.pos }
|
||||
| Typed m1, Typed m2 -> Typed { pos = pos_f m1.pos m2.pos; ty = ty_f m1 m2 }
|
||||
| Inferring m1, Inferring m2 ->
|
||||
Inferring
|
||||
{
|
||||
pos = pos_f m1.pos m2.pos;
|
||||
uf =
|
||||
Infer.ast_to_typ (ty_f (resolve_inferring m1) (resolve_inferring m2));
|
||||
}
|
||||
|
||||
let fold_marks
|
||||
(type m)
|
||||
@ -580,17 +194,9 @@ let fold_marks
|
||||
pos = pos_f (List.map (function Typed { pos; _ } -> pos) ms);
|
||||
ty = ty_f (List.map (function Typed m -> m) ms);
|
||||
}
|
||||
| Inferring _ :: _ ->
|
||||
Inferring
|
||||
{
|
||||
pos = pos_f (List.map (function Inferring { pos; _ } -> pos) ms);
|
||||
uf =
|
||||
Infer.ast_to_typ
|
||||
(ty_f (List.map (function Inferring m -> resolve_inferring m) ms));
|
||||
}
|
||||
|
||||
let empty_thunked_term mark : 'm marked_expr =
|
||||
let silent = new_var "_" in
|
||||
let silent = Var.make "_" in
|
||||
let pos = mark_pos mark in
|
||||
Bindlib.unbox
|
||||
(make_abs [| silent |]
|
||||
|
@ -18,224 +18,32 @@
|
||||
(** Abstract syntax tree of the default calculus intermediate representation *)
|
||||
|
||||
open Utils
|
||||
module Runtime = Runtime_ocaml.Runtime
|
||||
module ScopeName : Uid.Id with type info = Uid.MarkedString.info
|
||||
module StructName : Uid.Id with type info = Uid.MarkedString.info
|
||||
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info
|
||||
module StructMap : Map.S with type key = StructName.t
|
||||
module EnumName : Uid.Id with type info = Uid.MarkedString.info
|
||||
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info
|
||||
module EnumMap : Map.S with type key = EnumName.t
|
||||
include module type of Astgen
|
||||
include module type of Astgen_utils
|
||||
|
||||
(** Abstract syntax tree for the default calculus *)
|
||||
type lit = dcalc glit
|
||||
|
||||
(** {1 Abstract syntax tree} *)
|
||||
|
||||
type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration
|
||||
|
||||
type marked_typ = typ Marked.pos
|
||||
|
||||
and typ =
|
||||
| TLit of typ_lit
|
||||
| TTuple of marked_typ list * StructName.t option
|
||||
| TEnum of marked_typ list * EnumName.t
|
||||
| TArrow of marked_typ * marked_typ
|
||||
| TArray of marked_typ
|
||||
| TAny
|
||||
|
||||
type date = Runtime.date
|
||||
type duration = Runtime.duration
|
||||
|
||||
type lit =
|
||||
| LBool of bool
|
||||
| LEmptyError
|
||||
| LInt of Runtime.integer
|
||||
| LRat of Runtime.decimal
|
||||
| LMoney of Runtime.money
|
||||
| LUnit
|
||||
| LDate of date
|
||||
| LDuration of duration
|
||||
|
||||
type op_kind =
|
||||
| KInt
|
||||
| KRat
|
||||
| KMoney
|
||||
| KDate
|
||||
| KDuration (** All ops don't have a KDate and KDuration. *)
|
||||
|
||||
type ternop = Fold
|
||||
|
||||
type binop =
|
||||
| And
|
||||
| Or
|
||||
| Xor
|
||||
| Add of op_kind
|
||||
| Sub of op_kind
|
||||
| Mult of op_kind
|
||||
| Div of op_kind
|
||||
| Lt of op_kind
|
||||
| Lte of op_kind
|
||||
| Gt of op_kind
|
||||
| Gte of op_kind
|
||||
| Eq
|
||||
| Neq
|
||||
| Map
|
||||
| Concat
|
||||
| Filter
|
||||
|
||||
type log_entry =
|
||||
| VarDef of typ
|
||||
(** During code generation, we need to know the type of the variable being
|
||||
logged for embedding *)
|
||||
| BeginCall
|
||||
| EndCall
|
||||
| PosRecordIfTrueBool
|
||||
|
||||
type unop =
|
||||
| Not
|
||||
| Minus of op_kind
|
||||
| Log of log_entry * Utils.Uid.MarkedString.info list
|
||||
| Length
|
||||
| IntToRat
|
||||
| MoneyToRat
|
||||
| RatToMoney
|
||||
| GetDay
|
||||
| GetMonth
|
||||
| GetYear
|
||||
| FirstDayOfMonth
|
||||
| LastDayOfMonth
|
||||
| RoundMoney
|
||||
| RoundDecimal
|
||||
|
||||
type operator = Ternop of ternop | Binop of binop | Unop of unop
|
||||
|
||||
(** Contains some structures used for type inference *)
|
||||
module Infer : sig
|
||||
module Any : Utils.Uid.Id with type info = unit
|
||||
|
||||
type unionfind_typ = typ Marked.pos UnionFind.elem
|
||||
(** We do not reuse {!type: typ} because we have to include a new [TAny]
|
||||
variant. Indeed, error terms can have any type and this has to be captured
|
||||
by the type sytem. *)
|
||||
|
||||
and typ =
|
||||
| TLit of typ_lit
|
||||
| TArrow of unionfind_typ * unionfind_typ
|
||||
| TTuple of unionfind_typ list * StructName.t option
|
||||
| TEnum of unionfind_typ list * EnumName.t
|
||||
| TArray of unionfind_typ
|
||||
| TAny of Any.t
|
||||
|
||||
val typ_to_ast : unionfind_typ -> marked_typ
|
||||
val ast_to_typ : marked_typ -> unionfind_typ
|
||||
end
|
||||
|
||||
type untyped = { pos : Pos.t } [@@unboxed]
|
||||
type typed = { pos : Pos.t; ty : marked_typ }
|
||||
type inferring = { pos : Pos.t; uf : Infer.unionfind_typ }
|
||||
|
||||
(** The generic type of AST markings. Using a GADT allows functions to be
|
||||
polymorphic in the marking, but still do transformations on types when
|
||||
appropriate *)
|
||||
type _ mark =
|
||||
| Untyped : untyped -> untyped mark
|
||||
| Typed : typed -> typed mark
|
||||
| Inferring : inferring -> inferring mark
|
||||
|
||||
type ('a, 'm) marked = ('a, 'm mark) Marked.t
|
||||
|
||||
type 'm marked_expr = ('m expr, 'm) marked
|
||||
|
||||
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
|
||||
library, based on higher-order abstract syntax*)
|
||||
and 'm expr =
|
||||
| EVar of 'm expr Bindlib.var
|
||||
| ETuple of 'm marked_expr list * StructName.t option
|
||||
(** The [MarkedString.info] is the former struct field name*)
|
||||
| ETupleAccess of 'm marked_expr * int * StructName.t option * marked_typ list
|
||||
(** The [MarkedString.info] is the former struct field name *)
|
||||
| EInj of 'm marked_expr * int * EnumName.t * marked_typ list
|
||||
(** The [MarkedString.info] is the former enum case name *)
|
||||
| EMatch of 'm marked_expr * 'm marked_expr list * EnumName.t
|
||||
(** The [MarkedString.info] is the former enum case name *)
|
||||
| EArray of 'm marked_expr list
|
||||
| ELit of lit
|
||||
| EAbs of
|
||||
(('m expr, 'm marked_expr) Bindlib.mbinder[@opaque]) * marked_typ list
|
||||
| EApp of 'm marked_expr * 'm marked_expr list
|
||||
| EAssert of 'm marked_expr
|
||||
| EOp of operator
|
||||
| EDefault of 'm marked_expr list * 'm marked_expr * 'm marked_expr
|
||||
| EIfThenElse of 'm marked_expr * 'm marked_expr * 'm marked_expr
|
||||
| ErrorOnEmpty of 'm marked_expr
|
||||
|
||||
(** {3 Expression annotations ([Marked.t])} *)
|
||||
|
||||
type typed_expr = typed marked_expr
|
||||
type struct_ctx = (StructFieldName.t * marked_typ) list StructMap.t
|
||||
type enum_ctx = (EnumConstructor.t * marked_typ) list EnumMap.t
|
||||
type decl_ctx = { ctx_enums : enum_ctx; ctx_structs : struct_ctx }
|
||||
type 'm binder = ('m expr, 'm marked_expr) Bindlib.binder
|
||||
|
||||
(** This kind annotation signals that the let-binding respects a structural
|
||||
invariant. These invariants concern the shape of the expression in the
|
||||
let-binding, and are documented below. *)
|
||||
type scope_let_kind =
|
||||
| DestructuringInputStruct (** [let x = input.field]*)
|
||||
| ScopeVarDefinition (** [let x = error_on_empty e]*)
|
||||
| SubScopeVarDefinition
|
||||
(** [let s.x = fun _ -> e] or [let s.x = error_on_empty e] for input-only
|
||||
subscope variables. *)
|
||||
| CallingSubScope (** [let result = s ({ x = s.x; y = s.x; ...}) ]*)
|
||||
| DestructuringSubScopeResults (** [let s.x = result.x ]**)
|
||||
| Assertion (** [let _ = assert e]*)
|
||||
|
||||
type ('expr, 'm) scope_let = {
|
||||
scope_let_kind : scope_let_kind;
|
||||
scope_let_typ : marked_typ;
|
||||
scope_let_expr : ('expr, 'm) marked;
|
||||
scope_let_next : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
|
||||
scope_let_pos : Pos.t;
|
||||
}
|
||||
(** This type is parametrized by the expression type so it can be reused in
|
||||
later intermediate representations. *)
|
||||
|
||||
(** A scope let-binding has all the information necessary to make a proper
|
||||
let-binding expression, plus an annotation for the kind of the let-binding
|
||||
that comes from the compilation of a {!module: Scopelang.Ast} statement. *)
|
||||
and ('expr, 'm) scope_body_expr =
|
||||
| Result of ('expr, 'm) marked
|
||||
| ScopeLet of ('expr, 'm) scope_let
|
||||
|
||||
type ('expr, 'm) scope_body = {
|
||||
scope_body_input_struct : StructName.t;
|
||||
scope_body_output_struct : StructName.t;
|
||||
scope_body_expr : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
|
||||
}
|
||||
(** Instead of being a single expression, we give a little more ad-hoc structure
|
||||
to the scope body by decomposing it in an ordered list of let-bindings, and
|
||||
a result expression that uses the let-binded variables. The first binder is
|
||||
the argument of type [scope_body_input_struct]. *)
|
||||
|
||||
type ('expr, 'm) scope_def = {
|
||||
scope_name : ScopeName.t;
|
||||
scope_body : ('expr, 'm) scope_body;
|
||||
scope_next : ('expr, ('expr, 'm) scopes) Bindlib.binder;
|
||||
}
|
||||
|
||||
(** Finally, we do the same transformation for the whole program for the kinded
|
||||
lets. This permit us to use bindlib variables for scopes names. *)
|
||||
and ('expr, 'm) scopes = Nil | ScopeDef of ('expr, 'm) scope_def
|
||||
|
||||
type ('expr, 'm) program_generic = {
|
||||
decl_ctx : decl_ctx;
|
||||
scopes : ('expr, 'm) scopes;
|
||||
}
|
||||
type 'm expr = (dcalc, 'm mark) gexpr
|
||||
and 'm marked_expr = (dcalc, 'm mark) marked_gexpr
|
||||
|
||||
type 'm program = ('m expr, 'm) program_generic
|
||||
|
||||
(** {1 Helpers} *)
|
||||
|
||||
(** {2 Variables} *)
|
||||
|
||||
type 'm var = 'm expr Var.t
|
||||
type 'm vars = 'm expr Var.vars
|
||||
|
||||
val free_vars_expr : 'm marked_expr -> 'm expr Var.Set.t
|
||||
|
||||
val free_vars_scope_body_expr :
|
||||
('m expr, 'm) scope_body_expr -> 'm expr Var.Set.t
|
||||
|
||||
val free_vars_scope_body : ('m expr, 'm) scope_body -> 'm expr Var.Set.t
|
||||
val free_vars_scopes : ('m expr, 'm) scopes -> 'm expr Var.Set.t
|
||||
val make_var : ('m var, 'm) marked -> 'm marked_expr Bindlib.box
|
||||
|
||||
(** {2 Manipulation of marks} *)
|
||||
|
||||
val no_mark : 'm mark -> 'm mark
|
||||
@ -379,101 +187,6 @@ val map_expr_top_down :
|
||||
val map_expr_marks :
|
||||
f:('m1 mark -> 'm2 mark) -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box
|
||||
|
||||
val fold_left_scope_lets :
|
||||
f:('a -> ('expr, 'm) scope_let -> 'expr Bindlib.var -> 'a) ->
|
||||
init:'a ->
|
||||
('expr, 'm) scope_body_expr ->
|
||||
'a
|
||||
(** Usage:
|
||||
[fold_left_scope_lets ~f:(fun acc scope_let scope_let_var -> ...) ~init scope_lets],
|
||||
where [scope_let_var] is the variable bound to the scope let in the next
|
||||
scope lets to be examined. *)
|
||||
|
||||
val fold_right_scope_lets :
|
||||
f:(('expr1, 'm1) scope_let -> 'expr1 Bindlib.var -> 'a -> 'a) ->
|
||||
init:(('expr1, 'm1) marked -> 'a) ->
|
||||
('expr1, 'm1) scope_body_expr ->
|
||||
'a
|
||||
(** Usage:
|
||||
[fold_right_scope_lets ~f:(fun scope_let scope_let_var acc -> ...) ~init scope_lets],
|
||||
where [scope_let_var] is the variable bound to the scope let in the next
|
||||
scope lets to be examined (which are before in the program order). *)
|
||||
|
||||
val map_exprs_in_scope_lets :
|
||||
f:(('expr1, 'm1) marked -> ('expr2, 'm2) marked Bindlib.box) ->
|
||||
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
|
||||
('expr1, 'm1) scope_body_expr ->
|
||||
('expr2, 'm2) scope_body_expr Bindlib.box
|
||||
|
||||
val fold_left_scope_defs :
|
||||
f:('a -> ('expr1, 'm1) scope_def -> 'expr1 Bindlib.var -> 'a) ->
|
||||
init:'a ->
|
||||
('expr1, 'm1) scopes ->
|
||||
'a
|
||||
(** Usage:
|
||||
[fold_left_scope_defs ~f:(fun acc scope_def scope_var -> ...) ~init scope_def],
|
||||
where [scope_var] is the variable bound to the scope in the next scopes to
|
||||
be examined. *)
|
||||
|
||||
val fold_right_scope_defs :
|
||||
f:(('expr1, 'm1) scope_def -> 'expr1 Bindlib.var -> 'a -> 'a) ->
|
||||
init:'a ->
|
||||
('expr1, 'm1) scopes ->
|
||||
'a
|
||||
(** Usage:
|
||||
[fold_right_scope_defs ~f:(fun scope_def scope_var acc -> ...) ~init scope_def],
|
||||
where [scope_var] is the variable bound to the scope in the next scopes to
|
||||
be examined (which are before in the program order). *)
|
||||
|
||||
val map_scope_defs :
|
||||
f:(('expr, 'm) scope_def -> ('expr, 'm) scope_def Bindlib.box) ->
|
||||
('expr, 'm) scopes ->
|
||||
('expr, 'm) scopes Bindlib.box
|
||||
|
||||
val map_exprs_in_scopes :
|
||||
f:(('expr1, 'm1) marked -> ('expr2, 'm2) marked Bindlib.box) ->
|
||||
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
|
||||
('expr1, 'm1) scopes ->
|
||||
('expr2, 'm2) scopes Bindlib.box
|
||||
(** This is the main map visitor for all the expressions inside all the scopes
|
||||
of the program. *)
|
||||
|
||||
(** {2 Variables} *)
|
||||
|
||||
type 'm var = 'm expr Bindlib.var
|
||||
|
||||
val new_var : string -> 'm var
|
||||
|
||||
val translate_var : 'm1 var -> 'm2 var
|
||||
(** used to convert between e.g. [untyped expr var] into a [typed expr var] *)
|
||||
|
||||
module Var : sig
|
||||
type t
|
||||
|
||||
val t : 'm expr Bindlib.var -> t
|
||||
(** Hides the marking type parameter annotation behind an existential type so
|
||||
that variables can be stored in non-polymorphic sets and maps *)
|
||||
|
||||
val get : t -> 'm expr Bindlib.var
|
||||
(** Be careful with this, it breaks the type abstraction by casting the
|
||||
existential type annotation. See [!Bindlib.copy_var] for more detail. *)
|
||||
|
||||
val compare : t -> t -> int
|
||||
val eq : t -> t -> bool
|
||||
end
|
||||
|
||||
module VarMap : Map.S with type key = Var.t
|
||||
module VarSet : Set.S with type elt = Var.t
|
||||
|
||||
val free_vars_expr : 'm marked_expr -> VarSet.t
|
||||
val free_vars_scope_body_expr : ('m expr, 'm) scope_body_expr -> VarSet.t
|
||||
val free_vars_scope_body : ('m expr, 'm) scope_body -> VarSet.t
|
||||
val free_vars_scopes : ('m expr, 'm) scopes -> VarSet.t
|
||||
|
||||
(* type vars = expr Bindlib.mvar *)
|
||||
|
||||
val make_var : ('m var, 'm) marked -> 'm marked_expr Bindlib.box
|
||||
|
||||
(** {2 Boxed term constructors} *)
|
||||
|
||||
type ('e, 'm) make_abs_sig =
|
||||
|
@ -18,7 +18,7 @@ open Utils
|
||||
open Ast
|
||||
|
||||
type partial_evaluation_ctx = {
|
||||
var_values : typed marked_expr Ast.VarMap.t;
|
||||
var_values : (typed expr, typed marked_expr) Var.Map.t;
|
||||
decl_ctx : decl_ctx;
|
||||
}
|
||||
|
||||
@ -184,7 +184,7 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm marked_expr) :
|
||||
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, pos) (rec_helper e1)
|
||||
|
||||
let optimize_expr (decl_ctx : decl_ctx) (e : 'm marked_expr) =
|
||||
partial_evaluation { var_values = VarMap.empty; decl_ctx } e
|
||||
partial_evaluation { var_values = Var.Map.empty; decl_ctx } e
|
||||
|
||||
let rec scope_lets_map
|
||||
(t : 'a -> 'm marked_expr -> 'm marked_expr Bindlib.box)
|
||||
@ -250,6 +250,6 @@ let program_map
|
||||
let optimize_program (p : 'm program) : untyped program =
|
||||
Bindlib.unbox
|
||||
(program_map partial_evaluation
|
||||
{ var_values = VarMap.empty; decl_ctx = p.decl_ctx }
|
||||
{ var_values = Var.Map.empty; decl_ctx = p.decl_ctx }
|
||||
p)
|
||||
|> untype_program
|
||||
|
@ -18,8 +18,51 @@
|
||||
inference using the classical W algorithm with union-find unification. *)
|
||||
|
||||
open Utils
|
||||
module A = Ast
|
||||
open A.Infer
|
||||
module A = Astgen
|
||||
|
||||
module Any =
|
||||
Utils.Uid.Make
|
||||
(struct
|
||||
type info = unit
|
||||
|
||||
let format_info fmt () = Format.fprintf fmt "any"
|
||||
end)
|
||||
()
|
||||
|
||||
type unionfind_typ = typ Marked.pos UnionFind.elem
|
||||
(** We do not reuse {!type: Dcalc.Ast.typ} because we have to include a new
|
||||
[TAny] variant. Indeed, error terms can have any type and this has to be
|
||||
captured by the type sytem. *)
|
||||
|
||||
and typ =
|
||||
| TLit of A.typ_lit
|
||||
| TArrow of unionfind_typ * unionfind_typ
|
||||
| TTuple of unionfind_typ list * A.StructName.t option
|
||||
| TEnum of unionfind_typ list * A.EnumName.t
|
||||
| TArray of unionfind_typ
|
||||
| TAny of Any.t
|
||||
|
||||
let rec typ_to_ast (ty : unionfind_typ) : A.marked_typ =
|
||||
let ty, pos = UnionFind.get (UnionFind.find ty) in
|
||||
match ty with
|
||||
| TLit l -> TLit l, pos
|
||||
| TTuple (ts, s) -> TTuple (List.map typ_to_ast ts, s), pos
|
||||
| TEnum (ts, e) -> TEnum (List.map typ_to_ast ts, e), pos
|
||||
| TArrow (t1, t2) -> TArrow (typ_to_ast t1, typ_to_ast t2), pos
|
||||
| TAny _ -> TAny, pos
|
||||
| TArray t1 -> TArray (typ_to_ast t1), pos
|
||||
|
||||
let rec ast_to_typ (ty : A.marked_typ) : unionfind_typ =
|
||||
let ty' =
|
||||
match Marked.unmark ty with
|
||||
| TLit l -> TLit l
|
||||
| TArrow (t1, t2) -> TArrow (ast_to_typ t1, ast_to_typ t2)
|
||||
| TTuple (ts, s) -> TTuple (List.map (fun t -> ast_to_typ t) ts, s)
|
||||
| TEnum (ts, e) -> TEnum (List.map (fun t -> ast_to_typ t) ts, e)
|
||||
| TArray t -> TArray (ast_to_typ t)
|
||||
| TAny -> TAny (Any.fresh ())
|
||||
in
|
||||
UnionFind.make (Marked.same_mark_as ty' ty)
|
||||
|
||||
(** {1 Types and unification} *)
|
||||
|
||||
@ -57,14 +100,16 @@ let rec format_typ
|
||||
|
||||
exception
|
||||
Type_error of
|
||||
A.untyped A.marked_expr
|
||||
A.any_marked_expr
|
||||
* typ Marked.pos UnionFind.elem
|
||||
* typ Marked.pos UnionFind.elem
|
||||
|
||||
type mark = { pos : Pos.t; uf : unionfind_typ }
|
||||
|
||||
(** Raises an error if unification cannot be performed *)
|
||||
let rec unify
|
||||
(ctx : Ast.decl_ctx)
|
||||
(e : 'm A.marked_expr) (* used for error context *)
|
||||
(e : ('a, 'm A.mark) Ast.marked_gexpr) (* used for error context *)
|
||||
(t1 : typ Marked.pos UnionFind.elem)
|
||||
(t2 : typ Marked.pos UnionFind.elem) : unit =
|
||||
let unify = unify ctx in
|
||||
@ -72,9 +117,7 @@ let rec unify
|
||||
t2; *)
|
||||
let t1_repr = UnionFind.get (UnionFind.find t1) in
|
||||
let t2_repr = UnionFind.get (UnionFind.find t2) in
|
||||
let raise_type_error () =
|
||||
raise (Type_error (Bindlib.unbox (A.untype_expr e), t1, t2))
|
||||
in
|
||||
let raise_type_error () = raise (Type_error (A.AnyExpr e, t1, t2)) in
|
||||
let repr =
|
||||
match Marked.unmark t1_repr, Marked.unmark t2_repr with
|
||||
| TLit tl1, TLit tl2 when tl1 = tl2 -> None
|
||||
@ -108,6 +151,11 @@ let rec unify
|
||||
let handle_type_error ctx e t1 t2 =
|
||||
(* TODO: if we get weird error messages, then it means that we should use the
|
||||
persistent version of the union-find data structure. *)
|
||||
let pos =
|
||||
match e with
|
||||
| A.AnyExpr e -> (
|
||||
match Marked.get_mark e with Untyped { pos } | Typed { pos; _ } -> pos)
|
||||
in
|
||||
let t1_repr = UnionFind.get (UnionFind.find t1) in
|
||||
let t2_repr = UnionFind.get (UnionFind.find t2) in
|
||||
let t1_pos = Marked.get_mark t1_repr in
|
||||
@ -132,7 +180,7 @@ let handle_type_error ctx e t1 t2 =
|
||||
( Some
|
||||
(Format.asprintf
|
||||
"Error coming from typechecking the following expression:"),
|
||||
A.pos e );
|
||||
pos );
|
||||
Some (Format.asprintf "Type %a coming from expression:" t1_s ()), t1_pos;
|
||||
Some (Format.asprintf "Type %a coming from expression:" t2_s ()), t2_pos;
|
||||
]
|
||||
@ -213,11 +261,10 @@ let op_type (op : A.operator Marked.pos) : typ Marked.pos UnionFind.elem =
|
||||
|
||||
(** {1 Double-directed typing} *)
|
||||
|
||||
type env = typ Marked.pos UnionFind.elem A.VarMap.t
|
||||
type 'e env = ('e, typ Marked.pos UnionFind.elem) Var.Map.t
|
||||
|
||||
let translate_var v = Bindlib.copy_var v (fun x -> A.EVar x) (Bindlib.name_of v)
|
||||
let add_pos e ty = Marked.mark (A.pos e) ty
|
||||
let ty (_, A.Inferring { A.uf; _ }) = uf
|
||||
let add_pos e ty = Marked.mark (Ast.pos e) ty
|
||||
let ty (_, { uf; _ }) = uf
|
||||
let ( let+ ) x f = Bindlib.box_apply f x
|
||||
let ( and+ ) x1 x2 = Bindlib.box_pair x1 x2
|
||||
|
||||
@ -244,24 +291,24 @@ let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e)
|
||||
(** Infers the most permissive type from an expression *)
|
||||
let rec typecheck_expr_bottom_up
|
||||
(ctx : Ast.decl_ctx)
|
||||
(env : env)
|
||||
(e : 'm A.marked_expr) : A.inferring A.marked_expr Bindlib.box =
|
||||
(env : 'm Ast.expr env)
|
||||
(e : 'm Ast.marked_expr) : (A.dcalc, mark) A.marked_gexpr Bindlib.box =
|
||||
(* Cli.debug_format "Looking for type of %a" (Print.format_expr ~debug:true
|
||||
ctx) e; *)
|
||||
let pos_e = A.pos e in
|
||||
let mark (e : A.inferring A.expr) uf =
|
||||
Marked.mark (A.Inferring { A.uf; pos = pos_e }) e
|
||||
let pos_e = Ast.pos e in
|
||||
let mark (e : (A.dcalc, mark) A.gexpr) uf =
|
||||
Marked.mark { uf; pos = pos_e } e
|
||||
in
|
||||
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
|
||||
let mark_with_uf e1 ?pos ty = mark e1 (unionfind_make ?pos ty) in
|
||||
match Marked.unmark e with
|
||||
| A.EVar v -> begin
|
||||
match A.VarMap.find_opt (A.Var.t v) env with
|
||||
match Var.Map.find_opt v env with
|
||||
| Some t ->
|
||||
let+ v' = Bindlib.box_var (translate_var v) in
|
||||
let+ v' = Bindlib.box_var (Var.translate v) in
|
||||
mark v' t
|
||||
| None ->
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
Errors.raise_spanned_error (Ast.pos e)
|
||||
"Variable %s not found in the current context." (Bindlib.name_of v)
|
||||
end
|
||||
| A.ELit (LBool _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TBool)
|
||||
@ -285,7 +332,7 @@ let rec typecheck_expr_bottom_up
|
||||
match List.nth_opt utyps n with
|
||||
| Some t' -> mark (ETupleAccess (e1, n, s, typs)) t'
|
||||
| None ->
|
||||
Errors.raise_spanned_error (A.pos e1)
|
||||
Errors.raise_spanned_error (Marked.get_mark e1).pos
|
||||
"Expression should have a tuple type with at least %d elements but \
|
||||
only has %d"
|
||||
n (List.length typs)
|
||||
@ -296,7 +343,7 @@ let rec typecheck_expr_bottom_up
|
||||
match List.nth_opt ts' n with
|
||||
| Some ts_n -> ts_n
|
||||
| None ->
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
Errors.raise_spanned_error (Ast.pos e)
|
||||
"Expression should have a sum type with at least %d cases but only \
|
||||
has %d"
|
||||
n (List.length ts')
|
||||
@ -321,18 +368,16 @@ let rec typecheck_expr_bottom_up
|
||||
mark (EMatch (e1', es', e_name)) t_ret
|
||||
| A.EAbs (binder, taus) ->
|
||||
if Bindlib.mbinder_arity binder <> List.length taus then
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
Errors.raise_spanned_error (Ast.pos e)
|
||||
"function has %d variables but was supplied %d types"
|
||||
(Bindlib.mbinder_arity binder)
|
||||
(List.length taus)
|
||||
else
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs' = Array.map translate_var xs in
|
||||
let xstaus = List.mapi (fun i tau -> xs'.(i), ast_to_typ tau) taus in
|
||||
let xs' = Array.map Var.translate xs in
|
||||
let xstaus = List.mapi (fun i tau -> xs.(i), ast_to_typ tau) taus in
|
||||
let env =
|
||||
List.fold_left
|
||||
(fun env (x, tau) -> A.VarMap.add (A.Var.t x) tau env)
|
||||
env xstaus
|
||||
List.fold_left (fun env (x, tau) -> Var.Map.add x tau env) env xstaus
|
||||
in
|
||||
let body' = typecheck_expr_bottom_up ctx env body in
|
||||
let t_func =
|
||||
@ -402,27 +447,26 @@ let rec typecheck_expr_bottom_up
|
||||
(** Checks whether the expression can be typed with the provided type *)
|
||||
and typecheck_expr_top_down
|
||||
(ctx : Ast.decl_ctx)
|
||||
(env : env)
|
||||
(env : 'm Ast.expr env)
|
||||
(tau : typ Marked.pos UnionFind.elem)
|
||||
(e : 'm A.marked_expr) : A.inferring A.marked_expr Bindlib.box =
|
||||
(e : 'm Ast.marked_expr) : (A.dcalc, mark) A.marked_gexpr Bindlib.box =
|
||||
(* Cli.debug_format "Propagating type %a for expr %a" (format_typ ctx) tau
|
||||
(Print.format_expr ctx) e; *)
|
||||
let pos_e = A.pos e in
|
||||
let mark e = Marked.mark (A.Inferring { uf = tau; pos = pos_e }) e in
|
||||
let unify_and_mark (e : A.inferring A.expr) tau' =
|
||||
let e = Marked.mark (A.Inferring { uf = tau'; pos = pos_e }) e in
|
||||
unify ctx (Bindlib.unbox (A.untype_expr e)) tau tau';
|
||||
e
|
||||
let pos_e = Ast.pos e in
|
||||
let mark e = Marked.mark { uf = tau; pos = pos_e } e in
|
||||
let unify_and_mark (e' : (A.dcalc, mark) A.gexpr) tau' =
|
||||
unify ctx e tau tau';
|
||||
Marked.mark { uf = tau; pos = pos_e } e'
|
||||
in
|
||||
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
|
||||
match Marked.unmark e with
|
||||
| A.EVar v -> begin
|
||||
match A.VarMap.find_opt (A.Var.t v) env with
|
||||
match Var.Map.find_opt v env with
|
||||
| Some tau' ->
|
||||
let+ v' = Bindlib.box_var (translate_var v) in
|
||||
let+ v' = Bindlib.box_var (Var.translate v) in
|
||||
unify_and_mark v' tau'
|
||||
| None ->
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
Errors.raise_spanned_error pos_e
|
||||
"Variable %s not found in the current context" (Bindlib.name_of v)
|
||||
end
|
||||
| A.ELit (LBool _) as e1 ->
|
||||
@ -465,7 +509,7 @@ and typecheck_expr_top_down
|
||||
match List.nth_opt ts' n with
|
||||
| Some ts_n -> ts_n
|
||||
| None ->
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
Errors.raise_spanned_error (Ast.pos e)
|
||||
"Expression should have a sum type with at least %d cases but only \
|
||||
has %d"
|
||||
n (List.length ts)
|
||||
@ -496,19 +540,19 @@ and typecheck_expr_top_down
|
||||
unify_and_mark (EMatch (e1', es', e_name)) t_ret
|
||||
| A.EAbs (binder, t_args) ->
|
||||
if Bindlib.mbinder_arity binder <> List.length t_args then
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
Errors.raise_spanned_error (Ast.pos e)
|
||||
"function has %d variables but was supplied %d types"
|
||||
(Bindlib.mbinder_arity binder)
|
||||
(List.length t_args)
|
||||
else
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs' = Array.map translate_var xs in
|
||||
let xs' = Array.map Var.translate xs in
|
||||
let xstaus =
|
||||
List.map2 (fun x t_arg -> x, ast_to_typ t_arg) (Array.to_list xs) t_args
|
||||
in
|
||||
let env =
|
||||
List.fold_left
|
||||
(fun env (x, t_arg) -> A.VarMap.add (A.Var.t x) t_arg env)
|
||||
(fun env (x, t_arg) -> Var.Map.add x t_arg env)
|
||||
env xstaus
|
||||
in
|
||||
let body' = typecheck_expr_bottom_up ctx env body in
|
||||
@ -577,30 +621,28 @@ let wrap ctx f e =
|
||||
|
||||
(** {1 API} *)
|
||||
|
||||
let get_ty_mark (A.Inferring { uf; pos }) =
|
||||
A.Typed { ty = A.Infer.typ_to_ast uf; pos }
|
||||
let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos }
|
||||
|
||||
(* Infer the type of an expression *)
|
||||
let infer_types (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) :
|
||||
let infer_types (ctx : Ast.decl_ctx) (e : 'm Ast.marked_expr) :
|
||||
Ast.typed Ast.marked_expr Bindlib.box =
|
||||
A.map_expr_marks ~f:get_ty_mark
|
||||
Astgen_utils.map_gexpr_marks ~f:get_ty_mark
|
||||
@@ Bindlib.unbox
|
||||
@@ wrap ctx (typecheck_expr_bottom_up ctx A.VarMap.empty) e
|
||||
@@ wrap ctx (typecheck_expr_bottom_up ctx Var.Map.empty) e
|
||||
|
||||
let infer_type (type m) ctx (e : m A.marked_expr) =
|
||||
let infer_type (type m) ctx (e : m Ast.marked_expr) =
|
||||
match Marked.get_mark e with
|
||||
| A.Typed { ty; _ } -> ty
|
||||
| A.Inferring { uf; _ } -> typ_to_ast uf
|
||||
| A.Untyped _ -> A.ty (Bindlib.unbox (infer_types ctx e))
|
||||
| A.Untyped _ -> Ast.ty (Bindlib.unbox (infer_types ctx e))
|
||||
|
||||
(** Typechecks an expression given an expected type *)
|
||||
let check_type
|
||||
(ctx : Ast.decl_ctx)
|
||||
(e : 'm A.marked_expr)
|
||||
(e : 'm Ast.marked_expr)
|
||||
(tau : A.typ Marked.pos) =
|
||||
(* todo: consider using the already inferred type if ['m] = [typed] *)
|
||||
ignore
|
||||
@@ wrap ctx (typecheck_expr_top_down ctx A.VarMap.empty (ast_to_typ tau)) e
|
||||
@@ wrap ctx (typecheck_expr_top_down ctx Var.Map.empty (ast_to_typ tau)) e
|
||||
|
||||
let infer_types_program prg =
|
||||
let ctx = prg.A.decl_ctx in
|
||||
@ -633,32 +675,34 @@ let infer_types_program prg =
|
||||
| A.Result e ->
|
||||
let e' = typecheck_expr_bottom_up ctx env e in
|
||||
Bindlib.box_apply
|
||||
(fun e ->
|
||||
unify ctx e (ty e) ty_out;
|
||||
A.Result e)
|
||||
(fun e1 ->
|
||||
unify ctx e (ty e1) ty_out;
|
||||
let e1 = Astgen_utils.map_gexpr_marks ~f:get_ty_mark e1 in
|
||||
A.Result (Bindlib.unbox e1))
|
||||
e'
|
||||
| A.ScopeLet
|
||||
{
|
||||
scope_let_kind;
|
||||
scope_let_typ;
|
||||
scope_let_expr = e;
|
||||
scope_let_expr = e0;
|
||||
scope_let_next;
|
||||
scope_let_pos;
|
||||
} ->
|
||||
let ty_e = ast_to_typ scope_let_typ in
|
||||
let e = typecheck_expr_bottom_up ctx env e in
|
||||
let e = typecheck_expr_bottom_up ctx env e0 in
|
||||
let var, next = Bindlib.unbind scope_let_next in
|
||||
let env = A.VarMap.add (A.Var.t var) ty_e env in
|
||||
let env = Var.Map.add var ty_e env in
|
||||
let next = process_scope_body_expr env next in
|
||||
let scope_let_next = Bindlib.bind_var (translate_var var) next in
|
||||
let scope_let_next = Bindlib.bind_var (Var.translate var) next in
|
||||
Bindlib.box_apply2
|
||||
(fun scope_let_expr scope_let_next ->
|
||||
unify ctx scope_let_expr (ty scope_let_expr) ty_e;
|
||||
(fun e scope_let_next ->
|
||||
unify ctx e0 (ty e) ty_e;
|
||||
let e = Astgen_utils.map_gexpr_marks ~f:get_ty_mark e in
|
||||
A.ScopeLet
|
||||
{
|
||||
scope_let_kind;
|
||||
scope_let_typ;
|
||||
scope_let_expr;
|
||||
scope_let_expr = Bindlib.unbox e;
|
||||
scope_let_next;
|
||||
scope_let_pos;
|
||||
})
|
||||
@ -666,26 +710,15 @@ let infer_types_program prg =
|
||||
in
|
||||
let scope_body_expr =
|
||||
let var, e = Bindlib.unbind body in
|
||||
let env = A.VarMap.add (A.Var.t var) ty_in env in
|
||||
let env = Var.Map.add var ty_in env in
|
||||
let e' = process_scope_body_expr env e in
|
||||
let e' =
|
||||
Bindlib.box_apply
|
||||
(fun e ->
|
||||
Bindlib.unbox
|
||||
@@ A.map_exprs_in_scope_lets ~varf:translate_var
|
||||
~f:
|
||||
(A.map_expr_top_down ~f:(fun e ->
|
||||
Marked.(mark (get_ty_mark (get_mark e)) (unmark e))))
|
||||
e)
|
||||
e'
|
||||
in
|
||||
Bindlib.bind_var (translate_var var) e'
|
||||
Bindlib.bind_var (Var.translate var) e'
|
||||
in
|
||||
let scope_next =
|
||||
let scope_var, next = Bindlib.unbind scope_next in
|
||||
let env = A.VarMap.add (A.Var.t scope_var) ty_scope env in
|
||||
let env = Var.Map.add scope_var ty_scope env in
|
||||
let next' = process_scopes env next in
|
||||
Bindlib.bind_var (translate_var scope_var) next'
|
||||
Bindlib.bind_var (Var.translate scope_var) next'
|
||||
in
|
||||
Bindlib.box_apply2
|
||||
(fun scope_body_expr scope_next ->
|
||||
@ -702,6 +735,6 @@ let infer_types_program prg =
|
||||
})
|
||||
scope_body_expr scope_next
|
||||
in
|
||||
let scopes = wrap ctx (process_scopes A.VarMap.empty) prg.scopes in
|
||||
let scopes = wrap ctx (process_scopes Var.Map.empty) prg.scopes in
|
||||
Bindlib.box_apply (fun scopes -> { A.decl_ctx = ctx; scopes }) scopes
|
||||
|> Bindlib.unbox
|
||||
|
@ -15,45 +15,17 @@
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
module Runtime = Runtime_ocaml.Runtime
|
||||
include Astgen
|
||||
module D = Dcalc.Ast
|
||||
|
||||
type lit =
|
||||
| LBool of bool
|
||||
| LInt of Runtime.integer
|
||||
| LRat of Runtime.decimal
|
||||
| LMoney of Runtime.money
|
||||
| LUnit
|
||||
| LDate of Runtime.date
|
||||
| LDuration of Runtime.duration
|
||||
type lit = lcalc glit
|
||||
|
||||
type except = ConflictError | EmptyError | NoValueProvided | Crash
|
||||
type 'm mark = 'm D.mark
|
||||
|
||||
type 'm marked_expr = ('m expr, 'm) D.marked
|
||||
|
||||
and 'm expr =
|
||||
| EVar of 'm expr Bindlib.var
|
||||
| ETuple of 'm marked_expr list * D.StructName.t option
|
||||
(** The [MarkedString.info] is the former struct field name*)
|
||||
| ETupleAccess of
|
||||
'm marked_expr * int * D.StructName.t option * D.typ Marked.pos list
|
||||
(** The [MarkedString.info] is the former struct field name *)
|
||||
| EInj of 'm marked_expr * int * D.EnumName.t * D.typ Marked.pos list
|
||||
(** The [MarkedString.info] is the former enum case name *)
|
||||
| EMatch of 'm marked_expr * 'm marked_expr list * D.EnumName.t
|
||||
(** The [MarkedString.info] is the former enum case name *)
|
||||
| EArray of 'm marked_expr list
|
||||
| ELit of lit
|
||||
| EAbs of ('m expr, 'm marked_expr) Bindlib.mbinder * D.typ Marked.pos list
|
||||
| EApp of 'm marked_expr * 'm marked_expr list
|
||||
| EAssert of 'm marked_expr
|
||||
| EOp of D.operator
|
||||
| EIfThenElse of 'm marked_expr * 'm marked_expr * 'm marked_expr
|
||||
| ERaise of except
|
||||
| ECatch of 'm marked_expr * except * 'm marked_expr
|
||||
type 'm expr = (lcalc, 'm mark) gexpr
|
||||
and 'm marked_expr = (lcalc, 'm mark) marked_gexpr
|
||||
|
||||
type 'm program = ('m expr, 'm) Dcalc.Ast.program_generic
|
||||
type 'm var = 'm expr Var.t
|
||||
type 'm vars = 'm expr Var.vars
|
||||
|
||||
(* <copy-paste from dcalc/ast.ml> *)
|
||||
|
||||
@ -92,23 +64,6 @@ let eop op mark = Bindlib.box (EOp op, mark)
|
||||
let eifthenelse e1 e2 e3 pos =
|
||||
Bindlib.box_apply3 (fun e1 e2 e3 -> EIfThenElse (e1, e2, e3), pos) e1 e2 e3
|
||||
|
||||
type 'm var = 'm expr Bindlib.var
|
||||
type 'm vars = 'm expr Bindlib.mvar
|
||||
|
||||
let new_var s = Bindlib.new_var (fun x -> EVar x) s
|
||||
|
||||
module Var = struct
|
||||
type t = V : 'a var -> t
|
||||
(* See Dcalc.Ast.var *)
|
||||
|
||||
let t v = V v
|
||||
let get (V v) = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
|
||||
let compare (V x) (V y) = Bindlib.compare_vars x y
|
||||
end
|
||||
|
||||
module VarSet = Set.Make (Var)
|
||||
module VarMap = Map.Make (Var)
|
||||
|
||||
(* </copy-paste> *)
|
||||
|
||||
let eraise e1 pos = Bindlib.box (ERaise e1, pos)
|
||||
@ -116,32 +71,7 @@ let eraise e1 pos = Bindlib.box (ERaise e1, pos)
|
||||
let ecatch e1 exn e2 pos =
|
||||
Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2
|
||||
|
||||
let translate_var v = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
|
||||
|
||||
let map_expr ctx ~f e =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| EVar v -> evar (translate_var v) (Marked.get_mark e)
|
||||
| EApp (e1, args) ->
|
||||
eapp (f ctx e1) (List.map (f ctx) args) (Marked.get_mark e)
|
||||
| EAbs (binder, typs) ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
eabs (Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body)) typs m
|
||||
| ETuple (args, s) -> etuple (List.map (f ctx) args) s (Marked.get_mark e)
|
||||
| ETupleAccess (e1, n, s_name, typs) ->
|
||||
etupleaccess ((f ctx) e1) n s_name typs (Marked.get_mark e)
|
||||
| EInj (e1, i, e_name, typs) ->
|
||||
einj ((f ctx) e1) i e_name typs (Marked.get_mark e)
|
||||
| EMatch (arg, arms, e_name) ->
|
||||
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name (Marked.get_mark e)
|
||||
| EArray args -> earray (List.map (f ctx) args) (Marked.get_mark e)
|
||||
| ELit l -> elit l (Marked.get_mark e)
|
||||
| EAssert e1 -> eassert ((f ctx) e1) (Marked.get_mark e)
|
||||
| EOp op -> Bindlib.box (EOp op, Marked.get_mark e)
|
||||
| ERaise exn -> eraise exn (Marked.get_mark e)
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) (Marked.get_mark e)
|
||||
| ECatch (e1, exn, e2) -> ecatch (f ctx e1) exn (f ctx e2) (Marked.get_mark e)
|
||||
let map_expr ctx ~f e = Astgen_utils.map_gexpr ctx ~f e
|
||||
|
||||
let rec map_expr_top_down ~f e =
|
||||
map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e)
|
||||
@ -159,7 +89,7 @@ let untype_program prg =
|
||||
Bindlib.unbox
|
||||
(D.map_exprs_in_scopes
|
||||
~f:(fun e -> untype_expr e)
|
||||
~varf:translate_var prg.D.scopes);
|
||||
~varf:Var.translate prg.D.scopes);
|
||||
}
|
||||
|
||||
(** See [Bindlib.box_term] documentation for why we are doing that. *)
|
||||
@ -254,13 +184,13 @@ let make_matchopt_with_abs_arms arg e_none e_some =
|
||||
e_some, permitting it to be used inside the expression. There is no
|
||||
requirements on the form of both e_some and e_none. *)
|
||||
let make_matchopt m v tau arg e_none e_some =
|
||||
let x = new_var "_" in
|
||||
let x = Var.make "_" in
|
||||
|
||||
make_matchopt_with_abs_arms arg
|
||||
(make_abs (Array.of_list [x]) e_none [D.TLit D.TUnit, D.mark_pos m] m)
|
||||
(make_abs (Array.of_list [v]) e_some [tau] m)
|
||||
|
||||
let handle_default = Var.t (new_var "handle_default")
|
||||
let handle_default_opt = Var.t (new_var "handle_default_opt")
|
||||
let handle_default = Var.make "handle_default"
|
||||
let handle_default_opt = Var.make "handle_default_opt"
|
||||
|
||||
type 'm binder = ('m expr, 'm marked_expr) Bindlib.binder
|
||||
|
@ -15,79 +15,23 @@
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
module Runtime = Runtime_ocaml.Runtime
|
||||
include module type of Astgen
|
||||
|
||||
(** Abstract syntax tree for the lambda calculus *)
|
||||
|
||||
(** {1 Abstract syntax tree} *)
|
||||
|
||||
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
|
||||
library, based on higher-order abstract syntax*)
|
||||
type lit = lcalc glit
|
||||
|
||||
type lit =
|
||||
| LBool of bool
|
||||
| LInt of Runtime.integer
|
||||
| LRat of Runtime.decimal
|
||||
| LMoney of Runtime.money
|
||||
| LUnit
|
||||
| LDate of Runtime.date
|
||||
| LDuration of Runtime.duration
|
||||
|
||||
type except = ConflictError | EmptyError | NoValueProvided | Crash
|
||||
type 'm mark = 'm Dcalc.Ast.mark
|
||||
|
||||
type 'm marked_expr = ('m expr, 'm) Dcalc.Ast.marked
|
||||
|
||||
and 'm expr =
|
||||
| EVar of 'm expr Bindlib.var
|
||||
| ETuple of 'm marked_expr list * Dcalc.Ast.StructName.t option
|
||||
(** The [MarkedString.info] is the former struct field name*)
|
||||
| ETupleAccess of
|
||||
'm marked_expr
|
||||
* int
|
||||
* Dcalc.Ast.StructName.t option
|
||||
* Dcalc.Ast.typ Marked.pos list
|
||||
(** The [MarkedString.info] is the former struct field name *)
|
||||
| EInj of
|
||||
'm marked_expr
|
||||
* int
|
||||
* Dcalc.Ast.EnumName.t
|
||||
* Dcalc.Ast.typ Marked.pos list
|
||||
(** The [MarkedString.info] is the former enum case name *)
|
||||
| EMatch of 'm marked_expr * 'm marked_expr list * Dcalc.Ast.EnumName.t
|
||||
(** The [MarkedString.info] is the former enum case name *)
|
||||
| EArray of 'm marked_expr list
|
||||
| ELit of lit
|
||||
| EAbs of
|
||||
('m expr, 'm marked_expr) Bindlib.mbinder * Dcalc.Ast.typ Marked.pos list
|
||||
| EApp of 'm marked_expr * 'm marked_expr list
|
||||
| EAssert of 'm marked_expr
|
||||
| EOp of Dcalc.Ast.operator
|
||||
| EIfThenElse of 'm marked_expr * 'm marked_expr * 'm marked_expr
|
||||
| ERaise of except
|
||||
| ECatch of 'm marked_expr * except * 'm marked_expr
|
||||
type 'm expr = (lcalc, 'm mark) gexpr
|
||||
and 'm marked_expr = (lcalc, 'm mark) marked_gexpr
|
||||
|
||||
type 'm program = ('m expr, 'm) Dcalc.Ast.program_generic
|
||||
|
||||
(** {1 Variable helpers} *)
|
||||
|
||||
type 'm var = 'm expr Bindlib.var
|
||||
type 'm vars = 'm expr Bindlib.mvar
|
||||
|
||||
module Var : sig
|
||||
type t
|
||||
|
||||
val t : 'm expr Bindlib.var -> t
|
||||
val get : t -> 'm expr Bindlib.var
|
||||
val compare : t -> t -> int
|
||||
end
|
||||
|
||||
module VarMap : Map.S with type key = Var.t
|
||||
module VarSet : Set.S with type elt = Var.t
|
||||
|
||||
val new_var : string -> 'm var
|
||||
|
||||
type 'm binder = ('m expr, 'm marked_expr) Bindlib.binder
|
||||
type 'm var = 'm expr Var.t
|
||||
type 'm vars = 'm expr Var.vars
|
||||
|
||||
(** {2 Program traversal} *)
|
||||
|
||||
@ -246,5 +190,5 @@ val box_expr : 'm marked_expr -> 'm marked_expr Bindlib.box
|
||||
|
||||
(** {1 Special symbols} *)
|
||||
|
||||
val handle_default : Var.t
|
||||
val handle_default_opt : Var.t
|
||||
val handle_default : untyped var
|
||||
val handle_default_opt : untyped var
|
||||
|
@ -21,12 +21,12 @@ module D = Dcalc.Ast
|
||||
(** TODO: This version is not yet debugged and ought to be specialized when
|
||||
Lcalc has more structure. *)
|
||||
|
||||
type ctx = { name_context : string; globally_bound_vars : VarSet.t }
|
||||
type 'm ctx = { name_context : string; globally_bound_vars : 'm expr Var.Set.t }
|
||||
|
||||
(** Returns the expression with closed closures and the set of free variables
|
||||
inside this new expression. Implementation guided by
|
||||
http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=9. *)
|
||||
let closure_conversion_expr (type m) (ctx : ctx) (e : m marked_expr) :
|
||||
let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) :
|
||||
m marked_expr Bindlib.box =
|
||||
let module MVarSet = Set.Make (struct
|
||||
type t = m var
|
||||
@ -39,7 +39,7 @@ let closure_conversion_expr (type m) (ctx : ctx) (e : m marked_expr) :
|
||||
( Bindlib.box_apply
|
||||
(fun new_v -> new_v, Marked.get_mark e)
|
||||
(Bindlib.box_var v),
|
||||
if VarSet.mem (Var.t v) ctx.globally_bound_vars then MVarSet.empty
|
||||
if Var.Set.mem v ctx.globally_bound_vars then MVarSet.empty
|
||||
else MVarSet.singleton v )
|
||||
| ETuple (args, s) ->
|
||||
let new_args, free_vars =
|
||||
@ -138,9 +138,9 @@ let closure_conversion_expr (type m) (ctx : ctx) (e : m marked_expr) :
|
||||
in
|
||||
let extra_vars_list = MVarSet.elements extra_vars in
|
||||
(* x1, ..., xn *)
|
||||
let code_var = new_var ctx.name_context in
|
||||
let code_var = Var.make ctx.name_context in
|
||||
(* code *)
|
||||
let inner_c_var = new_var "env" in
|
||||
let inner_c_var = Var.make "env" in
|
||||
let any_ty = Dcalc.Ast.TAny, binder_pos in
|
||||
let new_closure_body =
|
||||
make_multiple_let_in
|
||||
@ -200,8 +200,7 @@ let closure_conversion_expr (type m) (ctx : ctx) (e : m marked_expr) :
|
||||
(fun new_e2 -> EApp ((EOp op, pos_op), new_e2), Marked.get_mark e)
|
||||
(Bindlib.box_list new_args),
|
||||
free_vars )
|
||||
| EApp ((EVar v, v_pos), args)
|
||||
when VarSet.mem (Var.t v) ctx.globally_bound_vars ->
|
||||
| EApp ((EVar v, v_pos), args) when Var.Set.mem v ctx.globally_bound_vars ->
|
||||
(* This corresponds to a scope call, which we don't want to transform*)
|
||||
let new_args, free_vars =
|
||||
List.fold_right
|
||||
@ -217,8 +216,8 @@ let closure_conversion_expr (type m) (ctx : ctx) (e : m marked_expr) :
|
||||
free_vars )
|
||||
| EApp (e1, args) ->
|
||||
let new_e1, free_vars = aux e1 in
|
||||
let env_var = new_var "env" in
|
||||
let code_var = new_var "code" in
|
||||
let env_var = Var.make "env" in
|
||||
let code_var = Var.make "code" in
|
||||
let new_args, free_vars =
|
||||
List.fold_right
|
||||
(fun arg (new_args, free_vars) ->
|
||||
@ -286,7 +285,7 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
|
||||
let scope_input_var, scope_body_expr =
|
||||
Bindlib.unbind scope.scope_body.scope_body_expr
|
||||
in
|
||||
let global_vars = VarSet.add (Var.t scope_var) global_vars in
|
||||
let global_vars = Var.Set.add scope_var global_vars in
|
||||
let ctx =
|
||||
{
|
||||
name_context =
|
||||
@ -320,7 +319,10 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
|
||||
new_scope_body_expr
|
||||
(Bindlib.bind_var scope_var next))),
|
||||
global_vars ))
|
||||
~init:(Fun.id, VarSet.of_list [handle_default; handle_default_opt])
|
||||
~init:
|
||||
( Fun.id,
|
||||
Var.Set.of_list
|
||||
(List.map Var.translate [handle_default; handle_default_opt]) )
|
||||
p.scopes
|
||||
in
|
||||
Bindlib.box_apply
|
||||
|
@ -18,7 +18,7 @@ open Utils
|
||||
module D = Dcalc.Ast
|
||||
module A = Ast
|
||||
|
||||
type 'm ctx = 'm A.var D.VarMap.t
|
||||
type 'm ctx = ('m D.expr, 'm A.expr Var.t) Var.Map.t
|
||||
(** This environment contains a mapping between the variables in Dcalc and their
|
||||
correspondance in Lcalc. *)
|
||||
|
||||
@ -35,7 +35,7 @@ let translate_lit (l : D.lit) : 'm A.expr =
|
||||
|
||||
let thunk_expr (e : 'm A.marked_expr Bindlib.box) (mark : 'm A.mark) :
|
||||
'm A.marked_expr Bindlib.box =
|
||||
let dummy_var = A.new_var "_" in
|
||||
let dummy_var = Var.make "_" in
|
||||
A.make_abs [| dummy_var |] e [D.TAny, D.mark_pos mark] mark
|
||||
|
||||
let rec translate_default
|
||||
@ -51,7 +51,7 @@ let rec translate_default
|
||||
in
|
||||
let exceptions =
|
||||
A.make_app
|
||||
(A.make_var (A.Var.get A.handle_default, mark_default))
|
||||
(A.make_var (Var.translate A.handle_default, mark_default))
|
||||
[
|
||||
A.earray exceptions mark_default;
|
||||
thunk_expr (translate_expr ctx just) mark_default;
|
||||
@ -64,7 +64,7 @@ let rec translate_default
|
||||
and translate_expr (ctx : 'm ctx) (e : 'm D.marked_expr) :
|
||||
'm A.marked_expr Bindlib.box =
|
||||
match Marked.unmark e with
|
||||
| D.EVar v -> A.make_var (D.VarMap.find (D.Var.t v) ctx, Marked.get_mark e)
|
||||
| D.EVar v -> A.make_var (Var.Map.find v ctx, Marked.get_mark e)
|
||||
| D.ETuple (args, s) ->
|
||||
A.etuple (List.map (translate_expr ctx) args) s (Marked.get_mark e)
|
||||
| D.ETupleAccess (e1, i, s, ts) ->
|
||||
@ -96,8 +96,8 @@ and translate_expr (ctx : 'm ctx) (e : 'm D.marked_expr) :
|
||||
let ctx, lc_vars =
|
||||
Array.fold_right
|
||||
(fun var (ctx, lc_vars) ->
|
||||
let lc_var = A.new_var (Bindlib.name_of var) in
|
||||
D.VarMap.add (D.Var.t var) lc_var ctx, lc_var :: lc_vars)
|
||||
let lc_var = Var.make (Bindlib.name_of var) in
|
||||
Var.Map.add var lc_var ctx, lc_var :: lc_vars)
|
||||
vars (ctx, [])
|
||||
in
|
||||
let lc_vars = Array.of_list lc_vars in
|
||||
@ -126,11 +126,9 @@ let rec translate_scope_lets
|
||||
let old_scope_let_var, scope_let_next =
|
||||
Bindlib.unbind scope_let.scope_let_next
|
||||
in
|
||||
let new_scope_let_var = A.new_var (Bindlib.name_of old_scope_let_var) in
|
||||
let new_scope_let_var = Var.make (Bindlib.name_of old_scope_let_var) in
|
||||
let new_scope_let_expr = translate_expr ctx scope_let.scope_let_expr in
|
||||
let new_ctx =
|
||||
D.VarMap.add (D.Var.t old_scope_let_var) new_scope_let_var ctx
|
||||
in
|
||||
let new_ctx = Var.Map.add old_scope_let_var new_scope_let_var ctx in
|
||||
let new_scope_next = translate_scope_lets decl_ctx new_ctx scope_let_next in
|
||||
let new_scope_next = Bindlib.bind_var new_scope_let_var new_scope_next in
|
||||
Bindlib.box_apply2
|
||||
@ -154,15 +152,13 @@ let rec translate_scopes
|
||||
| ScopeDef scope_def ->
|
||||
let old_scope_var, scope_next = Bindlib.unbind scope_def.scope_next in
|
||||
let new_scope_var =
|
||||
A.new_var (Marked.unmark (D.ScopeName.get_info scope_def.scope_name))
|
||||
Var.make (Marked.unmark (D.ScopeName.get_info scope_def.scope_name))
|
||||
in
|
||||
let old_scope_input_var, scope_body_expr =
|
||||
Bindlib.unbind scope_def.scope_body.scope_body_expr
|
||||
in
|
||||
let new_scope_input_var = A.new_var (Bindlib.name_of old_scope_input_var) in
|
||||
let new_ctx =
|
||||
D.VarMap.add (D.Var.t old_scope_input_var) new_scope_input_var ctx
|
||||
in
|
||||
let new_scope_input_var = Var.make (Bindlib.name_of old_scope_input_var) in
|
||||
let new_ctx = Var.Map.add old_scope_input_var new_scope_input_var ctx in
|
||||
let new_scope_body_expr =
|
||||
translate_scope_lets decl_ctx new_ctx scope_body_expr
|
||||
in
|
||||
@ -181,7 +177,7 @@ let rec translate_scopes
|
||||
})
|
||||
new_scope_body_expr
|
||||
in
|
||||
let new_ctx = D.VarMap.add (D.Var.t old_scope_var) new_scope_var new_ctx in
|
||||
let new_ctx = Var.Map.add old_scope_var new_scope_var new_ctx in
|
||||
let scope_next =
|
||||
Bindlib.bind_var new_scope_var
|
||||
(translate_scopes decl_ctx new_ctx scope_next)
|
||||
@ -199,6 +195,6 @@ let rec translate_scopes
|
||||
let translate_program (prgm : 'm D.program) : 'm A.program =
|
||||
{
|
||||
scopes =
|
||||
Bindlib.unbox (translate_scopes prgm.decl_ctx D.VarMap.empty prgm.scopes);
|
||||
Bindlib.unbox (translate_scopes prgm.decl_ctx Var.Map.empty prgm.scopes);
|
||||
decl_ctx = prgm.decl_ctx;
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ module A = Ast
|
||||
hoisted and later handled by the [translate_expr] function. Every other
|
||||
cases is found in the translate_and_hoist function. *)
|
||||
|
||||
type 'm hoists = 'm D.marked_expr A.VarMap.t
|
||||
type 'm hoists = ('m A.expr, 'm D.marked_expr) Var.Map.t
|
||||
(** Hoists definition. It represent bindings between [A.Var.t] and [D.expr]. *)
|
||||
|
||||
type 'm info = {
|
||||
@ -60,14 +60,13 @@ let pp_info (fmt : Format.formatter) (info : 'm info) =
|
||||
|
||||
type 'm ctx = {
|
||||
decl_ctx : D.decl_ctx;
|
||||
vars : 'm info D.VarMap.t;
|
||||
vars : ('m D.expr, 'm info) Var.Map.t;
|
||||
(** information context about variables in the current scope *)
|
||||
}
|
||||
|
||||
let _pp_ctx (fmt : Format.formatter) (ctx : 'm ctx) =
|
||||
let pp_binding (fmt : Format.formatter) ((v, info) : D.Var.t * 'm info) =
|
||||
Format.fprintf fmt "%a: %a" Dcalc.Print.format_var (D.Var.get v) pp_info
|
||||
info
|
||||
let pp_binding (fmt : Format.formatter) ((v, info) : 'm D.var * 'm info) =
|
||||
Format.fprintf fmt "%a: %a" Dcalc.Print.format_var v pp_info info
|
||||
in
|
||||
|
||||
let pp_bindings =
|
||||
@ -76,14 +75,14 @@ let _pp_ctx (fmt : Format.formatter) (ctx : 'm ctx) =
|
||||
pp_binding
|
||||
in
|
||||
|
||||
Format.fprintf fmt "@[<2>[%a]@]" pp_bindings (D.VarMap.bindings ctx.vars)
|
||||
Format.fprintf fmt "@[<2>[%a]@]" pp_bindings (Var.Map.bindings ctx.vars)
|
||||
|
||||
(** [find ~info n ctx] is a warpper to ocaml's Map.find that handle errors in a
|
||||
slightly better way. *)
|
||||
let find ?(info : string = "none") (n : 'm D.var) (ctx : 'm ctx) : 'm info =
|
||||
(* let _ = Format.asprintf "Searching for variable %a inside context %a"
|
||||
Dcalc.Print.format_var n pp_ctx ctx |> Cli.debug_print in *)
|
||||
try D.VarMap.find (D.Var.t n) ctx.vars
|
||||
try Var.Map.find n ctx.vars
|
||||
with Not_found ->
|
||||
Errors.raise_spanned_error Pos.no_pos
|
||||
"Internal Error: Variable %a was not found in the current environment. \
|
||||
@ -96,7 +95,7 @@ let find ?(info : string = "none") (n : 'm D.var) (ctx : 'm ctx) : 'm info =
|
||||
debuging purposes as it printing each of the Dcalc/Lcalc variable pairs. *)
|
||||
let add_var (mark : 'm D.mark) (var : 'm D.var) (is_pure : bool) (ctx : 'm ctx)
|
||||
: 'm ctx =
|
||||
let new_var = A.new_var (Bindlib.name_of var) in
|
||||
let new_var = Var.make (Bindlib.name_of var) in
|
||||
let expr = A.make_var (new_var, mark) in
|
||||
|
||||
(* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Dcalc.Print.format_var
|
||||
@ -104,7 +103,7 @@ let add_var (mark : 'm D.mark) (var : 'm D.var) (is_pure : bool) (ctx : 'm ctx)
|
||||
{
|
||||
ctx with
|
||||
vars =
|
||||
D.VarMap.update (D.Var.t var)
|
||||
Var.Map.update var
|
||||
(fun _ -> Some { expr; var = new_var; is_pure })
|
||||
ctx.vars;
|
||||
}
|
||||
@ -147,16 +146,16 @@ let translate_lit (l : D.lit) (pos : Pos.t) : A.lit =
|
||||
|
||||
(** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps.
|
||||
Raises an internal error if there is two identicals keys in differnts parts. *)
|
||||
let disjoint_union_maps (pos : Pos.t) (cs : 'a A.VarMap.t list) : 'a A.VarMap.t
|
||||
=
|
||||
let disjoint_union_maps (pos : Pos.t) (cs : ('e, 'a) Var.Map.t list) :
|
||||
('e, 'a) Var.Map.t =
|
||||
let disjoint_union =
|
||||
A.VarMap.union (fun _ _ _ ->
|
||||
Var.Map.union (fun _ _ _ ->
|
||||
Errors.raise_spanned_error pos
|
||||
"Internal Error: Two supposed to be disjoints maps have one shared \
|
||||
key.")
|
||||
in
|
||||
|
||||
List.fold_left disjoint_union A.VarMap.empty cs
|
||||
List.fold_left disjoint_union Var.Map.empty cs
|
||||
|
||||
(** [e' = translate_and_hoist ctx e ] Translate the Dcalc expression e into an
|
||||
expression in Lcalc, given we translate each hoists correctly. It ensures
|
||||
@ -176,34 +175,34 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
|
||||
assumption can change in the future, and this case is here for this
|
||||
reason. *)
|
||||
if not (find ~info:"search for a variable" v ctx).is_pure then
|
||||
let v' = A.new_var (Bindlib.name_of v) in
|
||||
let v' = Var.make (Bindlib.name_of v) in
|
||||
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
|
||||
created a variable %a to replace it" Dcalc.Print.format_var v
|
||||
Print.format_var v'; *)
|
||||
A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') e
|
||||
else (find ~info:"should never happend" v ctx).expr, A.VarMap.empty
|
||||
A.make_var (v', pos), Var.Map.singleton v' e
|
||||
else (find ~info:"should never happend" v ctx).expr, Var.Map.empty
|
||||
| D.EApp ((D.EVar v, p), [(D.ELit D.LUnit, _)]) ->
|
||||
if not (find ~info:"search for a variable" v ctx).is_pure then
|
||||
let v' = A.new_var (Bindlib.name_of v) in
|
||||
let v' = Var.make (Bindlib.name_of v) in
|
||||
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
|
||||
created a variable %a to replace it" Dcalc.Print.format_var v
|
||||
Print.format_var v'; *)
|
||||
A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') (D.EVar v, p)
|
||||
A.make_var (v', pos), Var.Map.singleton v' (D.EVar v, p)
|
||||
else
|
||||
Errors.raise_spanned_error (D.pos e)
|
||||
"Internal error: an pure variable was found in an unpure environment."
|
||||
| D.EDefault (_exceptions, _just, _cons) ->
|
||||
let v' = A.new_var "default_term" in
|
||||
A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') e
|
||||
let v' = Var.make "default_term" in
|
||||
A.make_var (v', pos), Var.Map.singleton v' e
|
||||
| D.ELit D.LEmptyError ->
|
||||
let v' = A.new_var "empty_litteral" in
|
||||
A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') e
|
||||
let v' = Var.make "empty_litteral" in
|
||||
A.make_var (v', pos), Var.Map.singleton v' e
|
||||
(* This one is a very special case. It transform an unpure expression
|
||||
environement to a pure expression. *)
|
||||
| ErrorOnEmpty arg ->
|
||||
(* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *)
|
||||
let silent_var = A.new_var "_" in
|
||||
let x = A.new_var "non_empty_argument" in
|
||||
let silent_var = Var.make "_" in
|
||||
let x = Var.make "non_empty_argument" in
|
||||
|
||||
let arg' = translate_expr ctx arg in
|
||||
|
||||
@ -213,9 +212,9 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
|
||||
[D.TAny, D.pos e]
|
||||
pos)
|
||||
(A.make_abs [| x |] (A.make_var (x, pos)) [D.TAny, D.pos e] pos),
|
||||
A.VarMap.empty )
|
||||
Var.Map.empty )
|
||||
(* pure terms *)
|
||||
| D.ELit l -> A.elit (translate_lit l (D.pos e)) pos, A.VarMap.empty
|
||||
| D.ELit l -> A.elit (translate_lit l (D.pos e)) pos, Var.Map.empty
|
||||
| D.EIfThenElse (e1, e2, e3) ->
|
||||
let e1', h1 = translate_and_hoist ctx e1 in
|
||||
let e2', h2 = translate_and_hoist ctx e2 in
|
||||
@ -293,12 +292,12 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
|
||||
let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in
|
||||
|
||||
A.earray es' pos, disjoint_union_maps (D.pos e) hoists
|
||||
| EOp op -> Bindlib.box (A.EOp op, pos), A.VarMap.empty
|
||||
| EOp op -> Bindlib.box (A.EOp op, pos), Var.Map.empty
|
||||
|
||||
and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
|
||||
: 'm A.marked_expr Bindlib.box =
|
||||
let e', hoists = translate_and_hoist ctx e in
|
||||
let hoists = A.VarMap.bindings hoists in
|
||||
let hoists = Var.Map.bindings hoists in
|
||||
|
||||
let _pos = Marked.get_mark e in
|
||||
|
||||
@ -321,7 +320,7 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
|
||||
let cons' = translate_expr ctx cons in
|
||||
(* calls handle_option. *)
|
||||
A.make_app
|
||||
(A.make_var (A.Var.get A.handle_default_opt, mark_hoist))
|
||||
(A.make_var (Var.translate A.handle_default_opt, mark_hoist))
|
||||
[
|
||||
Bindlib.box_apply
|
||||
(fun excep' -> A.EArray excep', mark_hoist)
|
||||
@ -336,8 +335,8 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
|
||||
|
||||
(* [ match arg with | None -> raise NoValueProvided | Some v -> assert
|
||||
{{ v }} ] *)
|
||||
let silent_var = A.new_var "_" in
|
||||
let x = A.new_var "assertion_argument" in
|
||||
let silent_var = Var.make "_" in
|
||||
let x = Var.make "assertion_argument" in
|
||||
|
||||
A.make_matchopt_with_abs_arms arg'
|
||||
(A.make_abs [| silent_var |]
|
||||
@ -360,7 +359,7 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
|
||||
] *)
|
||||
(* Cli.debug_print @@ Format.asprintf "build matchopt using %a"
|
||||
Print.format_var v; *)
|
||||
A.make_matchopt mark_hoist (A.Var.get v)
|
||||
A.make_matchopt mark_hoist v
|
||||
(D.TAny, D.mark_pos mark_hoist)
|
||||
c' (A.make_none mark_hoist) acc)
|
||||
|
||||
@ -582,7 +581,7 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
|
||||
|
||||
let scopes =
|
||||
Bindlib.unbox
|
||||
(translate_scopes { decl_ctx; vars = D.VarMap.empty } prgm.scopes)
|
||||
(translate_scopes { decl_ctx; vars = Var.Map.empty } prgm.scopes)
|
||||
in
|
||||
|
||||
{ scopes; decl_ctx }
|
||||
|
@ -416,8 +416,8 @@ let rec format_expr
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos)
|
||||
format_with_parens arg1
|
||||
| EApp ((EVar x, pos), args)
|
||||
when Ast.Var.compare (Ast.Var.t x) Ast.handle_default = 0
|
||||
|| Ast.Var.compare (Ast.Var.t x) Ast.handle_default_opt = 0 ->
|
||||
when Var.compare x (Var.translate Ast.handle_default) = 0
|
||||
|| Var.compare x (Var.translate Ast.handle_default_opt) = 0 ->
|
||||
Format.fprintf fmt
|
||||
"@[<hov 2>%a@ @[<hov 2>{filename = \"%s\";@ start_line=%d;@ \
|
||||
start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a@]"
|
||||
|
@ -19,23 +19,23 @@ module A = Ast
|
||||
module L = Lcalc.Ast
|
||||
module D = Dcalc.Ast
|
||||
|
||||
type ctxt = {
|
||||
func_dict : A.TopLevelName.t L.VarMap.t;
|
||||
type 'm ctxt = {
|
||||
func_dict : ('m L.expr, A.TopLevelName.t) Var.Map.t;
|
||||
decl_ctx : D.decl_ctx;
|
||||
var_dict : A.LocalName.t L.VarMap.t;
|
||||
var_dict : ('m L.expr, A.LocalName.t) Var.Map.t;
|
||||
inside_definition_of : A.LocalName.t option;
|
||||
context_name : string;
|
||||
}
|
||||
|
||||
(* Expressions can spill out side effect, hence this function also returns a
|
||||
list of statements to be prepended before the expression is evaluated *)
|
||||
let rec translate_expr (ctxt : ctxt) (expr : 'm L.marked_expr) :
|
||||
let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) :
|
||||
A.block * A.expr Marked.pos =
|
||||
match Marked.unmark expr with
|
||||
| L.EVar v ->
|
||||
let local_var =
|
||||
try A.EVar (L.VarMap.find (L.Var.t v) ctxt.var_dict)
|
||||
with Not_found -> A.EFunc (L.VarMap.find (L.Var.t v) ctxt.func_dict)
|
||||
try A.EVar (Var.Map.find v ctxt.var_dict)
|
||||
with Not_found -> A.EFunc (Var.Map.find v ctxt.func_dict)
|
||||
in
|
||||
[], (local_var, D.pos expr)
|
||||
| L.ETuple (args, Some s_name) ->
|
||||
@ -115,8 +115,8 @@ let rec translate_expr (ctxt : ctxt) (expr : 'm L.marked_expr) :
|
||||
:: tmp_stmts,
|
||||
(A.EVar tmp_var, D.pos expr) )
|
||||
|
||||
and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
|
||||
=
|
||||
and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) :
|
||||
A.block =
|
||||
match Marked.unmark block_expr with
|
||||
| L.EAssert e ->
|
||||
(* Assertions are always encapsulated in a unit-typed let binding *)
|
||||
@ -133,7 +133,7 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
|
||||
var_dict =
|
||||
List.fold_left
|
||||
(fun var_dict (x, _) ->
|
||||
L.VarMap.add (L.Var.t x)
|
||||
Var.Map.add x
|
||||
(A.LocalName.fresh (Bindlib.name_of x, binder_pos))
|
||||
var_dict)
|
||||
ctxt.var_dict vars_tau;
|
||||
@ -142,15 +142,14 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
|
||||
let local_decls =
|
||||
List.map
|
||||
(fun (x, tau) ->
|
||||
( A.SLocalDecl
|
||||
((L.VarMap.find (L.Var.t x) ctxt.var_dict, binder_pos), tau),
|
||||
( A.SLocalDecl ((Var.Map.find x ctxt.var_dict, binder_pos), tau),
|
||||
binder_pos ))
|
||||
vars_tau
|
||||
in
|
||||
let vars_args =
|
||||
List.map2
|
||||
(fun (x, tau) arg ->
|
||||
(L.VarMap.find (L.Var.t x) ctxt.var_dict, binder_pos), tau, arg)
|
||||
(Var.Map.find x ctxt.var_dict, binder_pos), tau, arg)
|
||||
vars_tau args
|
||||
in
|
||||
let def_blocks =
|
||||
@ -185,7 +184,7 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
|
||||
var_dict =
|
||||
List.fold_left
|
||||
(fun var_dict (x, _) ->
|
||||
L.VarMap.add (L.Var.t x)
|
||||
Var.Map.add x
|
||||
(A.LocalName.fresh (Bindlib.name_of x, binder_pos))
|
||||
var_dict)
|
||||
ctxt.var_dict vars_tau;
|
||||
@ -200,7 +199,7 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
|
||||
func_params =
|
||||
List.map
|
||||
(fun (var, tau) ->
|
||||
(L.VarMap.find (L.Var.t var) ctxt.var_dict, binder_pos), tau)
|
||||
(Var.Map.find var ctxt.var_dict, binder_pos), tau)
|
||||
vars_tau;
|
||||
func_body = new_body;
|
||||
} ),
|
||||
@ -220,10 +219,7 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
|
||||
A.LocalName.fresh (Bindlib.name_of var, D.pos arg)
|
||||
in
|
||||
let ctxt =
|
||||
{
|
||||
ctxt with
|
||||
var_dict = L.VarMap.add (L.Var.t var) scalc_var ctxt.var_dict;
|
||||
}
|
||||
{ ctxt with var_dict = Var.Map.add var scalc_var ctxt.var_dict }
|
||||
in
|
||||
let new_arg = translate_statements ctxt body in
|
||||
(new_arg, scalc_var) :: new_args
|
||||
@ -275,8 +271,8 @@ and translate_statements (ctxt : ctxt) (block_expr : 'm L.marked_expr) : A.block
|
||||
let rec translate_scope_body_expr
|
||||
(scope_name : D.ScopeName.t)
|
||||
(decl_ctx : D.decl_ctx)
|
||||
(var_dict : A.LocalName.t L.VarMap.t)
|
||||
(func_dict : A.TopLevelName.t L.VarMap.t)
|
||||
(var_dict : ('m L.expr, A.LocalName.t) Var.Map.t)
|
||||
(func_dict : ('m L.expr, A.TopLevelName.t) Var.Map.t)
|
||||
(scope_expr : ('m L.expr, 'm) D.scope_body_expr) : A.block =
|
||||
match scope_expr with
|
||||
| Result e ->
|
||||
@ -297,7 +293,7 @@ let rec translate_scope_body_expr
|
||||
let let_var_id =
|
||||
A.LocalName.fresh (Bindlib.name_of let_var, scope_let.scope_let_pos)
|
||||
in
|
||||
let new_var_dict = L.VarMap.add (L.Var.t let_var) let_var_id var_dict in
|
||||
let new_var_dict = Var.Map.add let_var let_var_id var_dict in
|
||||
(match scope_let.scope_let_kind with
|
||||
| D.Assertion ->
|
||||
translate_statements
|
||||
@ -349,7 +345,7 @@ let translate_program (p : 'm L.program) : A.program =
|
||||
A.LocalName.fresh (Bindlib.name_of scope_input_var, input_pos)
|
||||
in
|
||||
let var_dict =
|
||||
L.VarMap.singleton (L.Var.t scope_input_var) scope_input_var_id
|
||||
Var.Map.singleton scope_input_var scope_input_var_id
|
||||
in
|
||||
let new_scope_body =
|
||||
translate_scope_body_expr scope_def.D.scope_name p.decl_ctx
|
||||
@ -358,9 +354,7 @@ let translate_program (p : 'm L.program) : A.program =
|
||||
let func_id =
|
||||
A.TopLevelName.fresh (Bindlib.name_of scope_var, Pos.no_pos)
|
||||
in
|
||||
let func_dict =
|
||||
L.VarMap.add (L.Var.t scope_var) func_id func_dict
|
||||
in
|
||||
let func_dict = Var.Map.add scope_var func_id func_dict in
|
||||
( func_dict,
|
||||
{
|
||||
Ast.scope_body_name = scope_def.D.scope_name;
|
||||
@ -387,8 +381,8 @@ let translate_program (p : 'm L.program) : A.program =
|
||||
:: new_scopes ))
|
||||
~init:
|
||||
( (if !Cli.avoid_exceptions_flag then
|
||||
L.VarMap.singleton L.handle_default_opt A.handle_default_opt
|
||||
else L.VarMap.singleton L.handle_default A.handle_default),
|
||||
Var.Map.singleton L.handle_default_opt A.handle_default_opt
|
||||
else Var.Map.singleton L.handle_default A.handle_default),
|
||||
[] )
|
||||
p.D.scopes
|
||||
in
|
||||
|
@ -332,9 +332,7 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
|
||||
Bindlib.box_apply Marked.unmark new_e
|
||||
| EAbs (binder, typ) ->
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let new_xs =
|
||||
Array.map (fun x -> Dcalc.Ast.new_var (Bindlib.name_of x)) xs
|
||||
in
|
||||
let new_xs = Array.map (fun x -> Var.make (Bindlib.name_of x)) xs in
|
||||
let both_xs = Array.map2 (fun x new_x -> x, new_x) xs new_xs in
|
||||
let body =
|
||||
translate_expr
|
||||
@ -415,7 +413,7 @@ let translate_rule
|
||||
match rule with
|
||||
| Definition ((ScopeVar a, var_def_pos), tau, a_io, e) ->
|
||||
let a_name = Ast.ScopeVar.get_info (Marked.unmark a) in
|
||||
let a_var = Dcalc.Ast.new_var (Marked.unmark a_name) in
|
||||
let a_var = Var.make (Marked.unmark a_name) in
|
||||
let tau = translate_typ ctx tau in
|
||||
let new_e = translate_expr ctx e in
|
||||
let a_expr = Dcalc.Ast.make_var (a_var, pos_mark var_def_pos) in
|
||||
@ -469,14 +467,14 @@ let translate_rule
|
||||
^ Marked.unmark (Ast.ScopeVar.get_info (Marked.unmark subs_var)))
|
||||
(Ast.SubScopeName.get_info (Marked.unmark subs_index))
|
||||
in
|
||||
let a_var = Dcalc.Ast.new_var (Marked.unmark a_name) in
|
||||
let a_var = Var.make (Marked.unmark a_name) in
|
||||
let tau = translate_typ ctx tau in
|
||||
let new_e =
|
||||
tag_with_log_entry (translate_expr ctx e)
|
||||
(Dcalc.Ast.VarDef (Marked.unmark tau))
|
||||
[sigma_name, pos_sigma; a_name]
|
||||
in
|
||||
let silent_var = Dcalc.Ast.new_var "_" in
|
||||
let silent_var = Var.make "_" in
|
||||
let thunked_or_nonempty_new_e =
|
||||
match Marked.unmark a_io.io_input with
|
||||
| NoInput -> failwith "should not happen"
|
||||
@ -582,7 +580,7 @@ let translate_rule
|
||||
List.map
|
||||
(fun (subvar : scope_var_ctx) ->
|
||||
let sub_dcalc_var =
|
||||
Dcalc.Ast.new_var
|
||||
Var.make
|
||||
(Marked.unmark (Ast.SubScopeName.get_info subindex)
|
||||
^ "."
|
||||
^ Marked.unmark (Ast.ScopeVar.get_info subvar.scope_var_name))
|
||||
@ -613,7 +611,7 @@ let translate_rule
|
||||
Ast.ScopeName.get_info subname;
|
||||
]
|
||||
in
|
||||
let result_tuple_var = Dcalc.Ast.new_var "result" in
|
||||
let result_tuple_var = Var.make "result" in
|
||||
let result_tuple_typ =
|
||||
( Dcalc.Ast.TTuple
|
||||
( List.map
|
||||
@ -698,7 +696,7 @@ let translate_rule
|
||||
new_e;
|
||||
Dcalc.Ast.scope_let_kind = Dcalc.Ast.Assertion;
|
||||
})
|
||||
(Bindlib.bind_var (Dcalc.Ast.new_var "_") next)
|
||||
(Bindlib.bind_var (Var.make "_") next)
|
||||
new_e),
|
||||
ctx )
|
||||
|
||||
@ -753,7 +751,7 @@ let translate_scope_decl
|
||||
(sigma : Ast.scope_decl) :
|
||||
(Dcalc.Ast.untyped Dcalc.Ast.expr, Dcalc.Ast.untyped) Dcalc.Ast.scope_body
|
||||
Bindlib.box
|
||||
* Dcalc.Ast.struct_ctx =
|
||||
* Astgen.struct_ctx =
|
||||
let sigma_info = Ast.ScopeName.get_info sigma.scope_decl_name in
|
||||
let scope_sig = Ast.ScopeMap.find sigma.scope_decl_name sctx in
|
||||
let scope_variables = scope_sig.scope_sig_local_vars in
|
||||
@ -765,9 +763,7 @@ let translate_scope_decl
|
||||
match Marked.unmark scope_var.scope_var_io.io_input with
|
||||
| OnlyInput ->
|
||||
let scope_var_name = Ast.ScopeVar.get_info scope_var.scope_var_name in
|
||||
let scope_var_dcalc =
|
||||
Dcalc.Ast.new_var (Marked.unmark scope_var_name)
|
||||
in
|
||||
let scope_var_dcalc = Var.make (Marked.unmark scope_var_name) in
|
||||
{
|
||||
ctx with
|
||||
scope_vars =
|
||||
@ -916,7 +912,7 @@ let translate_program (prgm : Ast.program) :
|
||||
Ast.ScopeMap.mapi
|
||||
(fun scope_name scope ->
|
||||
let scope_dvar =
|
||||
Dcalc.Ast.new_var
|
||||
Var.make
|
||||
(Marked.unmark (Ast.ScopeName.get_info scope.Ast.scope_decl_name))
|
||||
in
|
||||
let scope_return_struct_name =
|
||||
@ -926,8 +922,7 @@ let translate_program (prgm : Ast.program) :
|
||||
(Ast.ScopeName.get_info scope_name))
|
||||
in
|
||||
let scope_input_var =
|
||||
Dcalc.Ast.new_var
|
||||
(Marked.unmark (Ast.ScopeName.get_info scope_name) ^ "_in")
|
||||
Var.make (Marked.unmark (Ast.ScopeName.get_info scope_name) ^ "_in")
|
||||
in
|
||||
let scope_input_struct_name =
|
||||
Ast.StructName.fresh
|
||||
|
293
compiler/utils/astgen.ml
Normal file
293
compiler/utils/astgen.ml
Normal file
@ -0,0 +1,293 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
|
||||
contributor: Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët-Tixeuil
|
||||
<alain.delaet--tixeuil@inria.fr>, Louis Gesbert <louis.gesbert@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
module Runtime = Runtime_ocaml.Runtime
|
||||
|
||||
module ScopeName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module StructName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module StructMap : Map.S with type key = StructName.t = Map.Make (StructName)
|
||||
|
||||
module EnumName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module EnumMap : Map.S with type key = EnumName.t = Map.Make (EnumName)
|
||||
|
||||
(** Abstract syntax tree for the default calculus *)
|
||||
|
||||
(** {1 Abstract syntax tree} *)
|
||||
|
||||
(** {2 Types} *)
|
||||
|
||||
type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration
|
||||
|
||||
type marked_typ = typ Marked.pos
|
||||
|
||||
and typ =
|
||||
| TLit of typ_lit
|
||||
| TTuple of marked_typ list * StructName.t option
|
||||
| TEnum of marked_typ list * EnumName.t
|
||||
| TArrow of marked_typ * marked_typ
|
||||
| TArray of marked_typ
|
||||
| TAny
|
||||
|
||||
(** {2 Constants and operators} *)
|
||||
|
||||
type date = Runtime.date
|
||||
type duration = Runtime.duration
|
||||
|
||||
type op_kind =
|
||||
| KInt
|
||||
| KRat
|
||||
| KMoney
|
||||
| KDate
|
||||
| KDuration (** All ops don't have a KDate and KDuration. *)
|
||||
|
||||
type ternop = Fold
|
||||
|
||||
type binop =
|
||||
| And
|
||||
| Or
|
||||
| Xor
|
||||
| Add of op_kind
|
||||
| Sub of op_kind
|
||||
| Mult of op_kind
|
||||
| Div of op_kind
|
||||
| Lt of op_kind
|
||||
| Lte of op_kind
|
||||
| Gt of op_kind
|
||||
| Gte of op_kind
|
||||
| Eq
|
||||
| Neq
|
||||
| Map
|
||||
| Concat
|
||||
| Filter
|
||||
|
||||
type log_entry =
|
||||
| VarDef of typ
|
||||
(** During code generation, we need to know the type of the variable being
|
||||
logged for embedding *)
|
||||
| BeginCall
|
||||
| EndCall
|
||||
| PosRecordIfTrueBool
|
||||
|
||||
type unop =
|
||||
| Not
|
||||
| Minus of op_kind
|
||||
| Log of log_entry * Uid.MarkedString.info list
|
||||
| Length
|
||||
| IntToRat
|
||||
| MoneyToRat
|
||||
| RatToMoney
|
||||
| GetDay
|
||||
| GetMonth
|
||||
| GetYear
|
||||
| FirstDayOfMonth
|
||||
| LastDayOfMonth
|
||||
| RoundMoney
|
||||
| RoundDecimal
|
||||
|
||||
type operator = Ternop of ternop | Binop of binop | Unop of unop
|
||||
type except = ConflictError | EmptyError | NoValueProvided | Crash
|
||||
|
||||
(** {2 Generic expressions} *)
|
||||
|
||||
(** Define a common base type for the expressions in most passes of the compiler *)
|
||||
|
||||
type desugared = [ `Desugared ]
|
||||
type scopelang = [ `Scopelang ]
|
||||
type dcalc = [ `Dcalc ]
|
||||
type lcalc = [ `Lcalc ]
|
||||
type scalc = [ `Scalc ]
|
||||
type any = [ desugared | scopelang | dcalc | lcalc | scalc ]
|
||||
|
||||
(** Literals are the same throughout compilation except for the [LEmptyError]
|
||||
case which is eliminated midway through. *)
|
||||
type 'a glit =
|
||||
| LBool : bool -> 'a glit
|
||||
| LEmptyError : [< desugared | scopelang | dcalc ] glit
|
||||
| LInt : Runtime.integer -> 'a glit
|
||||
| LRat : Runtime.decimal -> 'a glit
|
||||
| LMoney : Runtime.money -> 'a glit
|
||||
| LUnit : 'a glit
|
||||
| LDate : date -> 'a glit
|
||||
| LDuration : duration -> 'a glit
|
||||
|
||||
type ('a, 't) marked_gexpr = (('a, 't) gexpr, 't) Marked.t
|
||||
(** General expressions: groups all expression cases of the different ASTs, and
|
||||
uses a GADT to eliminate irrelevant cases for each one. The ['t] annotations
|
||||
are also totally unconstrained at this point. The dcalc exprs, for example,
|
||||
are then defined with [type expr = dcalc gexpr] plus the annotations. *)
|
||||
|
||||
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
|
||||
library, based on higher-order abstract syntax *)
|
||||
and ('a, 't) gexpr =
|
||||
(* Constructors common to all ASTs *)
|
||||
| ELit : 'a glit -> ('a, 't) gexpr
|
||||
| EApp : ('a, 't) marked_gexpr * ('a, 't) marked_gexpr list -> ('a, 't) gexpr
|
||||
| EOp : operator -> ('a, 't) gexpr
|
||||
| EArray : ('a, 't) marked_gexpr list -> ('a, 't) gexpr
|
||||
(* All but statement calculus *)
|
||||
| EVar :
|
||||
('a, 't) gexpr Bindlib.var
|
||||
-> (([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) gexpr
|
||||
| EAbs :
|
||||
(('a, 't) gexpr, ('a, 't) marked_gexpr) Bindlib.mbinder
|
||||
* typ Marked.pos list
|
||||
-> (([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) gexpr
|
||||
| EIfThenElse :
|
||||
('a, 't) marked_gexpr * ('a, 't) marked_gexpr * ('a, 't) marked_gexpr
|
||||
-> (([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) gexpr
|
||||
(* (* Early stages *) | ELocation: location -> ([< desugared | scopelang ] as
|
||||
'a, 't) gexpr | EStruct: StructName.t * ('a, 't) marked_gexpr
|
||||
StructFieldMap.t -> ([< desugared | scopelang ] as 'a, 't) gexpr |
|
||||
EStructAccess: ('a, 't) marked_gexpr * StructFieldName.t * StructName.t ->
|
||||
([< desugared | scopelang ] as 'a, 't) gexpr | EEnumInj: ('a, 't)
|
||||
marked_gexpr * EnumConstructor.t * EnumName.t -> ([< desugared | scopelang
|
||||
] as 'a, 't) gexpr | EMatchS: ('a, 't) marked_gexpr * EnumName.t * ('a, 't)
|
||||
marked_gexpr EnumConstructorMap.t -> ([< desugared | scopelang ] as 'a, 't)
|
||||
gexpr *)
|
||||
(* Lambda-like *)
|
||||
| ETuple :
|
||||
('a, 't) marked_gexpr list * StructName.t option
|
||||
-> (([< dcalc | lcalc ] as 'a), 't) gexpr
|
||||
| ETupleAccess :
|
||||
('a, 't) marked_gexpr * int * StructName.t option * typ Marked.pos list
|
||||
-> (([< dcalc | lcalc ] as 'a), 't) gexpr
|
||||
| EInj :
|
||||
('a, 't) marked_gexpr * int * EnumName.t * typ Marked.pos list
|
||||
-> (([< dcalc | lcalc ] as 'a), 't) gexpr
|
||||
| EMatch :
|
||||
('a, 't) marked_gexpr * ('a, 't) marked_gexpr list * EnumName.t
|
||||
-> (([< dcalc | lcalc ] as 'a), 't) gexpr
|
||||
| EAssert : ('a, 't) marked_gexpr -> (([< dcalc | lcalc ] as 'a), 't) gexpr
|
||||
(* Default terms *)
|
||||
| EDefault :
|
||||
('a, 't) marked_gexpr list * ('a, 't) marked_gexpr * ('a, 't) marked_gexpr
|
||||
-> (([< desugared | scopelang | dcalc ] as 'a), 't) gexpr
|
||||
| ErrorOnEmpty :
|
||||
('a, 't) marked_gexpr
|
||||
-> (([< desugared | scopelang | dcalc ] as 'a), 't) gexpr
|
||||
(* Lambda calculus with exceptions *)
|
||||
| ERaise : except -> ((lcalc as 'a), 't) gexpr
|
||||
| ECatch :
|
||||
('a, 't) marked_gexpr * except * ('a, 't) marked_gexpr
|
||||
-> ((lcalc as 'a), 't) gexpr
|
||||
|
||||
(* (\* Statement calculus *\)
|
||||
* | ESVar: LocalName.t -> (scalc as 'a, 't) gexpr
|
||||
* | ESStruct: ('a, 't) marked_gexpr list * StructName.t -> (scalc as 'a, 't) gexpr
|
||||
* | ESStructFieldAccess: ('a, 't) marked_gexpr * StructFieldName.t * StructName.t -> (scalc as 'a, 't) gexpr
|
||||
* | ESInj: ('a, 't) marked_gexpr * EnumConstructor.t * EnumName.t -> (scalc as 'a, 't) gexpr
|
||||
* | ESFunc: TopLevelName.t -> (scalc as 'a, 't) gexpr *)
|
||||
|
||||
(** {2 Markings} *)
|
||||
|
||||
type untyped = { pos : Pos.t } [@@ocaml.unboxed]
|
||||
type typed = { pos : Pos.t; ty : marked_typ }
|
||||
(* type inferring = { pos : Pos.t; uf : Infer.unionfind_typ } *)
|
||||
|
||||
(** The generic type of AST markings. Using a GADT allows functions to be
|
||||
polymorphic in the marking, but still do transformations on types when
|
||||
appropriate. Expected to fill the ['t] parameter of [gexpr] and
|
||||
[marked_gexpr] *)
|
||||
type _ mark = Untyped : untyped -> untyped mark | Typed : typed -> typed mark
|
||||
(* | Inferring : inferring -> inferring mark *)
|
||||
|
||||
type ('a, 'm) marked = ('a, 'm mark) Marked.t
|
||||
|
||||
(** Useful for errors and printing, for example *)
|
||||
type any_marked_expr =
|
||||
| AnyExpr : ([< any ], 'm mark) marked_gexpr -> any_marked_expr
|
||||
|
||||
(** {2 Higher-level program structure} *)
|
||||
|
||||
(** Constructs scopes and programs on top of expressions. We may use the [gexpr]
|
||||
type above at some point, but at the moment this is polymorphic in the types
|
||||
of the expressions. Their markings are constrained to belong to the [mark]
|
||||
GADT defined above. *)
|
||||
|
||||
(** This kind annotation signals that the let-binding respects a structural
|
||||
invariant. These invariants concern the shape of the expression in the
|
||||
let-binding, and are documented below. *)
|
||||
type scope_let_kind =
|
||||
| DestructuringInputStruct (** [let x = input.field]*)
|
||||
| ScopeVarDefinition (** [let x = error_on_empty e]*)
|
||||
| SubScopeVarDefinition
|
||||
(** [let s.x = fun _ -> e] or [let s.x = error_on_empty e] for input-only
|
||||
subscope variables. *)
|
||||
| CallingSubScope (** [let result = s ({ x = s.x; y = s.x; ...}) ]*)
|
||||
| DestructuringSubScopeResults (** [let s.x = result.x ]**)
|
||||
| Assertion (** [let _ = assert e]*)
|
||||
|
||||
type ('expr, 'm) scope_let = {
|
||||
scope_let_kind : scope_let_kind;
|
||||
scope_let_typ : marked_typ;
|
||||
scope_let_expr : ('expr, 'm) marked;
|
||||
scope_let_next : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
|
||||
scope_let_pos : Pos.t;
|
||||
}
|
||||
(** This type is parametrized by the expression type so it can be reused in
|
||||
later intermediate representations. *)
|
||||
|
||||
(** A scope let-binding has all the information necessary to make a proper
|
||||
let-binding expression, plus an annotation for the kind of the let-binding
|
||||
that comes from the compilation of a {!module: Scopelang.Ast} statement. *)
|
||||
and ('expr, 'm) scope_body_expr =
|
||||
| Result of ('expr, 'm) marked
|
||||
| ScopeLet of ('expr, 'm) scope_let
|
||||
|
||||
type ('expr, 'm) scope_body = {
|
||||
scope_body_input_struct : StructName.t;
|
||||
scope_body_output_struct : StructName.t;
|
||||
scope_body_expr : ('expr, ('expr, 'm) scope_body_expr) Bindlib.binder;
|
||||
}
|
||||
(** Instead of being a single expression, we give a little more ad-hoc structure
|
||||
to the scope body by decomposing it in an ordered list of let-bindings, and
|
||||
a result expression that uses the let-binded variables. The first binder is
|
||||
the argument of type [scope_body_input_struct]. *)
|
||||
|
||||
type ('expr, 'm) scope_def = {
|
||||
scope_name : ScopeName.t;
|
||||
scope_body : ('expr, 'm) scope_body;
|
||||
scope_next : ('expr, ('expr, 'm) scopes) Bindlib.binder;
|
||||
}
|
||||
|
||||
(** Finally, we do the same transformation for the whole program for the kinded
|
||||
lets. This permit us to use bindlib variables for scopes names. *)
|
||||
and ('expr, 'm) scopes = Nil | ScopeDef of ('expr, 'm) scope_def
|
||||
|
||||
type struct_ctx = (StructFieldName.t * marked_typ) list StructMap.t
|
||||
|
||||
type decl_ctx = {
|
||||
ctx_enums : (EnumConstructor.t * marked_typ) list EnumMap.t;
|
||||
ctx_structs : struct_ctx;
|
||||
}
|
||||
|
||||
type ('expr, 'm) program_generic = {
|
||||
decl_ctx : decl_ctx;
|
||||
scopes : ('expr, 'm) scopes;
|
||||
}
|
180
compiler/utils/astgen_utils.ml
Normal file
180
compiler/utils/astgen_utils.ml
Normal file
@ -0,0 +1,180 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
|
||||
contributor: Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët-Tixeuil
|
||||
<alain.delaet--tixeuil@inria.fr>, Louis Gesbert <louis.gesbert@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Astgen
|
||||
|
||||
(** Functions handling the types in [Astgen] *)
|
||||
|
||||
let evar v mark = Bindlib.box_apply (Marked.mark mark) (Bindlib.box_var v)
|
||||
|
||||
let etuple args s mark =
|
||||
Bindlib.box_apply (fun args -> ETuple (args, s), mark) (Bindlib.box_list args)
|
||||
|
||||
let etupleaccess e1 i s typs mark =
|
||||
Bindlib.box_apply (fun e1 -> ETupleAccess (e1, i, s, typs), mark) e1
|
||||
|
||||
let einj e1 i e_name typs mark =
|
||||
Bindlib.box_apply (fun e1 -> EInj (e1, i, e_name, typs), mark) e1
|
||||
|
||||
let ematch arg arms e_name mark =
|
||||
Bindlib.box_apply2
|
||||
(fun arg arms -> EMatch (arg, arms, e_name), mark)
|
||||
arg (Bindlib.box_list arms)
|
||||
|
||||
let earray args mark =
|
||||
Bindlib.box_apply (fun args -> EArray args, mark) (Bindlib.box_list args)
|
||||
|
||||
let elit l mark = Bindlib.box (ELit l, mark)
|
||||
|
||||
let eabs binder typs mark =
|
||||
Bindlib.box_apply (fun binder -> EAbs (binder, typs), mark) binder
|
||||
|
||||
let eapp e1 args mark =
|
||||
Bindlib.box_apply2
|
||||
(fun e1 args -> EApp (e1, args), mark)
|
||||
e1 (Bindlib.box_list args)
|
||||
|
||||
let eassert e1 mark = Bindlib.box_apply (fun e1 -> EAssert e1, mark) e1
|
||||
let eop op mark = Bindlib.box (EOp op, mark)
|
||||
|
||||
let edefault excepts just cons mark =
|
||||
Bindlib.box_apply3
|
||||
(fun excepts just cons -> EDefault (excepts, just, cons), mark)
|
||||
(Bindlib.box_list excepts) just cons
|
||||
|
||||
let eifthenelse e1 e2 e3 mark =
|
||||
Bindlib.box_apply3 (fun e1 e2 e3 -> EIfThenElse (e1, e2, e3), mark) e1 e2 e3
|
||||
|
||||
let eerroronempty e1 mark =
|
||||
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) e1
|
||||
|
||||
let eraise e1 pos = Bindlib.box (ERaise e1, pos)
|
||||
|
||||
let ecatch e1 exn e2 pos =
|
||||
Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2
|
||||
|
||||
let translate_var v = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
|
||||
|
||||
let map_gexpr
|
||||
(type a)
|
||||
(ctx : 'ctx)
|
||||
~(f : 'ctx -> (a, 'm1) marked_gexpr -> (a, 'm2) marked_gexpr Bindlib.box)
|
||||
(e : ((a, 'm1) gexpr, 'm2) Marked.t) : (a, 'm2) marked_gexpr Bindlib.box =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| ELit l -> elit l m
|
||||
| EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) m
|
||||
| EOp op -> Bindlib.box (EOp op, m)
|
||||
| EArray args -> earray (List.map (f ctx) args) m
|
||||
| EVar v -> evar (translate_var v) m
|
||||
| EAbs (binder, typs) ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
eabs (Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body)) typs m
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m
|
||||
| ETuple (args, s) -> etuple (List.map (f ctx) args) s m
|
||||
| ETupleAccess (e1, n, s_name, typs) ->
|
||||
etupleaccess ((f ctx) e1) n s_name typs m
|
||||
| EInj (e1, i, e_name, typs) -> einj ((f ctx) e1) i e_name typs m
|
||||
| EMatch (arg, arms, e_name) ->
|
||||
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name m
|
||||
| EAssert e1 -> eassert ((f ctx) e1) m
|
||||
| EDefault (excepts, just, cons) ->
|
||||
edefault (List.map (f ctx) excepts) ((f ctx) just) ((f ctx) cons) m
|
||||
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m
|
||||
| ECatch (e1, exn, e2) -> ecatch (f ctx e1) exn (f ctx e2) (Marked.get_mark e)
|
||||
| ERaise exn -> eraise exn (Marked.get_mark e)
|
||||
|
||||
let rec map_gexpr_top_down ~f e =
|
||||
map_gexpr () ~f:(fun () -> map_gexpr_top_down ~f) (f e)
|
||||
|
||||
let map_gexpr_marks ~f e =
|
||||
map_gexpr_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
|
||||
|
||||
let rec fold_left_scope_lets ~f ~init scope_body_expr =
|
||||
match scope_body_expr with
|
||||
| Result _ -> init
|
||||
| ScopeLet scope_let ->
|
||||
let var, next = Bindlib.unbind scope_let.scope_let_next in
|
||||
fold_left_scope_lets ~f ~init:(f init scope_let var) next
|
||||
|
||||
let rec fold_right_scope_lets ~f ~init scope_body_expr =
|
||||
match scope_body_expr with
|
||||
| Result result -> init result
|
||||
| ScopeLet scope_let ->
|
||||
let var, next = Bindlib.unbind scope_let.scope_let_next in
|
||||
let next_result = fold_right_scope_lets ~f ~init next in
|
||||
f scope_let var next_result
|
||||
|
||||
let map_exprs_in_scope_lets ~f ~varf scope_body_expr =
|
||||
fold_right_scope_lets
|
||||
~f:(fun scope_let var_next acc ->
|
||||
Bindlib.box_apply2
|
||||
(fun scope_let_next scope_let_expr ->
|
||||
ScopeLet { scope_let with scope_let_next; scope_let_expr })
|
||||
(Bindlib.bind_var (varf var_next) acc)
|
||||
(f scope_let.scope_let_expr))
|
||||
~init:(fun res -> Bindlib.box_apply (fun res -> Result res) (f res))
|
||||
scope_body_expr
|
||||
|
||||
let rec fold_left_scope_defs ~f ~init scopes =
|
||||
match scopes with
|
||||
| Nil -> init
|
||||
| ScopeDef scope_def ->
|
||||
let var, next = Bindlib.unbind scope_def.scope_next in
|
||||
fold_left_scope_defs ~f ~init:(f init scope_def var) next
|
||||
|
||||
let rec fold_right_scope_defs ~f ~init scopes =
|
||||
match scopes with
|
||||
| Nil -> init
|
||||
| ScopeDef scope_def ->
|
||||
let var_next, next = Bindlib.unbind scope_def.scope_next in
|
||||
let result_next = fold_right_scope_defs ~f ~init next in
|
||||
f scope_def var_next result_next
|
||||
|
||||
let map_scope_defs ~f scopes =
|
||||
fold_right_scope_defs
|
||||
~f:(fun scope_def var_next acc ->
|
||||
let new_scope_def = f scope_def in
|
||||
let new_next = Bindlib.bind_var var_next acc in
|
||||
Bindlib.box_apply2
|
||||
(fun new_scope_def new_next ->
|
||||
ScopeDef { new_scope_def with scope_next = new_next })
|
||||
new_scope_def new_next)
|
||||
~init:(Bindlib.box Nil) scopes
|
||||
|
||||
let map_exprs_in_scopes ~f ~varf scopes =
|
||||
fold_right_scope_defs
|
||||
~f:(fun scope_def var_next acc ->
|
||||
let scope_input_var, scope_lets =
|
||||
Bindlib.unbind scope_def.scope_body.scope_body_expr
|
||||
in
|
||||
let new_scope_body_expr = map_exprs_in_scope_lets ~f ~varf scope_lets in
|
||||
let new_scope_body_expr =
|
||||
Bindlib.bind_var (varf scope_input_var) new_scope_body_expr
|
||||
in
|
||||
let new_next = Bindlib.bind_var (varf var_next) acc in
|
||||
Bindlib.box_apply2
|
||||
(fun scope_body_expr scope_next ->
|
||||
ScopeDef
|
||||
{
|
||||
scope_def with
|
||||
scope_body = { scope_def.scope_body with scope_body_expr };
|
||||
scope_next;
|
||||
})
|
||||
new_scope_body_expr new_next)
|
||||
~init:(Bindlib.box Nil) scopes
|
183
compiler/utils/astgen_utils.mli
Normal file
183
compiler/utils/astgen_utils.mli
Normal file
@ -0,0 +1,183 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
|
||||
contributor: Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët-Tixeuil
|
||||
<alain.delaet--tixeuil@inria.fr>, Louis Gesbert <louis.gesbert@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
(** Functions handling the types in [Astgen] *)
|
||||
|
||||
open Astgen
|
||||
|
||||
(** {2 Boxed constructors} *)
|
||||
|
||||
val evar :
|
||||
(([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) gexpr Bindlib.var ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val etuple :
|
||||
(([< dcalc | lcalc ] as 'a), 't) marked_gexpr Bindlib.box list ->
|
||||
StructName.t option ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val etupleaccess :
|
||||
(([< dcalc | lcalc ] as 'a), 't) marked_gexpr Bindlib.box ->
|
||||
int ->
|
||||
StructName.t option ->
|
||||
marked_typ list ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val einj :
|
||||
(([< dcalc | lcalc ] as 'a), 't) marked_gexpr Bindlib.box ->
|
||||
int ->
|
||||
EnumName.t ->
|
||||
marked_typ list ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val ematch :
|
||||
(([< dcalc | lcalc ] as 'a), 't) marked_gexpr Bindlib.box ->
|
||||
('a, 't) marked_gexpr Bindlib.box list ->
|
||||
EnumName.t ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val earray :
|
||||
('a, 't) marked_gexpr Bindlib.box list ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val elit : 'a glit -> 't -> ('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val eabs :
|
||||
( (([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) gexpr,
|
||||
('a, 't) marked_gexpr )
|
||||
Bindlib.mbinder
|
||||
Bindlib.box ->
|
||||
marked_typ list ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val eapp :
|
||||
('a, 't) marked_gexpr Bindlib.box ->
|
||||
('a, 't) marked_gexpr Bindlib.box list ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val eassert :
|
||||
(([< dcalc | lcalc ] as 'a), 't) marked_gexpr Bindlib.box ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val eop : operator -> 't -> ('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val edefault :
|
||||
(([< desugared | scopelang | dcalc ] as 'a), 't) marked_gexpr Bindlib.box list ->
|
||||
('a, 't) marked_gexpr Bindlib.box ->
|
||||
('a, 't) marked_gexpr Bindlib.box ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val eifthenelse :
|
||||
(([< desugared | scopelang | dcalc | lcalc ] as 'a), 't) marked_gexpr
|
||||
Bindlib.box ->
|
||||
('a, 't) marked_gexpr Bindlib.box ->
|
||||
('a, 't) marked_gexpr Bindlib.box ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
val eerroronempty :
|
||||
(([< desugared | scopelang | dcalc ] as 'a), 't) marked_gexpr Bindlib.box ->
|
||||
't ->
|
||||
('a, 't) marked_gexpr Bindlib.box
|
||||
|
||||
(** ---------- *)
|
||||
|
||||
val map_gexpr :
|
||||
'ctx ->
|
||||
f:('ctx -> ('a, 't1) marked_gexpr -> ('a, 't2) marked_gexpr Bindlib.box) ->
|
||||
(('a, 't1) gexpr, 't2) Marked.t ->
|
||||
('a, 't2) marked_gexpr Bindlib.box
|
||||
|
||||
val map_gexpr_top_down :
|
||||
f:(('a, 't1) marked_gexpr -> (('a, 't1) gexpr, 't2) Marked.t) ->
|
||||
('a, 't1) marked_gexpr ->
|
||||
('a, 't2) marked_gexpr Bindlib.box
|
||||
(** Recursively applies [f] to the nodes of the expression tree. The type
|
||||
returned by [f] is hybrid since the mark at top-level has been rewritten,
|
||||
but not yet the marks in the subtrees. *)
|
||||
|
||||
val map_gexpr_marks :
|
||||
f:('t1 -> 't2) -> ('a, 't1) marked_gexpr -> ('a, 't2) marked_gexpr Bindlib.box
|
||||
|
||||
val fold_left_scope_lets :
|
||||
f:('a -> ('expr, 'm) scope_let -> 'expr Bindlib.var -> 'a) ->
|
||||
init:'a ->
|
||||
('expr, 'm) scope_body_expr ->
|
||||
'a
|
||||
(** Usage:
|
||||
[fold_left_scope_lets ~f:(fun acc scope_let scope_let_var -> ...) ~init scope_lets],
|
||||
where [scope_let_var] is the variable bound to the scope let in the next
|
||||
scope lets to be examined. *)
|
||||
|
||||
val fold_right_scope_lets :
|
||||
f:(('expr1, 'm1) scope_let -> 'expr1 Bindlib.var -> 'a -> 'a) ->
|
||||
init:(('expr1, 'm1) marked -> 'a) ->
|
||||
('expr1, 'm1) scope_body_expr ->
|
||||
'a
|
||||
(** Usage:
|
||||
[fold_right_scope_lets ~f:(fun scope_let scope_let_var acc -> ...) ~init scope_lets],
|
||||
where [scope_let_var] is the variable bound to the scope let in the next
|
||||
scope lets to be examined (which are before in the program order). *)
|
||||
|
||||
val map_exprs_in_scope_lets :
|
||||
f:(('expr1, 'm1) marked -> ('expr2, 'm2) marked Bindlib.box) ->
|
||||
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
|
||||
('expr1, 'm1) scope_body_expr ->
|
||||
('expr2, 'm2) scope_body_expr Bindlib.box
|
||||
|
||||
val fold_left_scope_defs :
|
||||
f:('a -> ('expr1, 'm1) scope_def -> 'expr1 Bindlib.var -> 'a) ->
|
||||
init:'a ->
|
||||
('expr1, 'm1) scopes ->
|
||||
'a
|
||||
(** Usage:
|
||||
[fold_left_scope_defs ~f:(fun acc scope_def scope_var -> ...) ~init scope_def],
|
||||
where [scope_var] is the variable bound to the scope in the next scopes to
|
||||
be examined. *)
|
||||
|
||||
val fold_right_scope_defs :
|
||||
f:(('expr1, 'm1) scope_def -> 'expr1 Bindlib.var -> 'a -> 'a) ->
|
||||
init:'a ->
|
||||
('expr1, 'm1) scopes ->
|
||||
'a
|
||||
(** Usage:
|
||||
[fold_right_scope_defs ~f:(fun scope_def scope_var acc -> ...) ~init scope_def],
|
||||
where [scope_var] is the variable bound to the scope in the next scopes to
|
||||
be examined (which are before in the program order). *)
|
||||
|
||||
val map_scope_defs :
|
||||
f:(('expr, 'm) scope_def -> ('expr, 'm) scope_def Bindlib.box) ->
|
||||
('expr, 'm) scopes ->
|
||||
('expr, 'm) scopes Bindlib.box
|
||||
|
||||
val map_exprs_in_scopes :
|
||||
f:(('expr1, 'm1) marked -> ('expr2, 'm2) marked Bindlib.box) ->
|
||||
varf:('expr1 Bindlib.var -> 'expr2 Bindlib.var) ->
|
||||
('expr1, 'm1) scopes ->
|
||||
('expr2, 'm2) scopes Bindlib.box
|
||||
(** This is the main map visitor for all the expressions inside all the scopes
|
||||
of the program. *)
|
@ -1,7 +1,7 @@
|
||||
(library
|
||||
(name utils)
|
||||
(public_name catala.utils)
|
||||
(libraries cmdliner ubase ANSITerminal re))
|
||||
(libraries cmdliner ubase ANSITerminal re bindlib catala.runtime_ocaml))
|
||||
|
||||
(documentation
|
||||
(package catala)
|
||||
|
@ -21,6 +21,7 @@ type 'a pos = ('a, Pos.t) t
|
||||
let mark m e : ('a, 'm) t = e, m
|
||||
let unmark ((x, _) : ('a, 'm) t) : 'a = x
|
||||
let get_mark ((_, x) : ('a, 'm) t) : 'm = x
|
||||
let map_mark (f : 'm1 -> 'm2) ((a, m) : ('a, 'm1) t) : ('a, 'm2) t = a, f m
|
||||
let map_under_mark (f : 'a -> 'b) ((x, y) : ('a, 'm) t) : ('b, 'c) t = f x, y
|
||||
let same_mark_as (x : 'a) ((_, y) : ('b, 'm) t) : ('a, 'm) t = x, y
|
||||
|
||||
|
@ -27,6 +27,7 @@ type 'a pos = ('a, Pos.t) t
|
||||
val mark : 'm -> 'a -> ('a, 'm) t
|
||||
val unmark : ('a, 'm) t -> 'a
|
||||
val get_mark : ('a, 'm) t -> 'm
|
||||
val map_mark : ('m1 -> 'm2) -> ('a, 'm1) t -> ('a, 'm2) t
|
||||
val map_under_mark : ('a -> 'b) -> ('a, 'm) t -> ('b, 'm) t
|
||||
val same_mark_as : 'a -> ('b, 'm) t -> ('a, 'm) t
|
||||
val unmark_option : ('a, 'm) t option -> 'a option
|
||||
|
108
compiler/utils/var.ml
Normal file
108
compiler/utils/var.ml
Normal file
@ -0,0 +1,108 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
|
||||
contributor: Louis Gesbert <louis.gesbert@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Astgen
|
||||
|
||||
(** {1 Variables and their collections} *)
|
||||
|
||||
(** This module provides types and helpers for Bindlib variables on the
|
||||
[Astgen.gexpr] type *)
|
||||
|
||||
(* The subtypes of the generic AST that hold vars *)
|
||||
type 'e expr = 'e
|
||||
constraint 'e = ([< desugared | scopelang | dcalc | lcalc ], 't) gexpr
|
||||
|
||||
type 'e var = 'e expr Bindlib.var
|
||||
type 'e t = 'e var
|
||||
type 'e vars = 'e expr Bindlib.mvar
|
||||
|
||||
type 'e binder = (('a, 't) gexpr, ('a, 't) marked_gexpr) Bindlib.binder
|
||||
constraint 'e = ('a, 't) gexpr
|
||||
|
||||
let make (name : string) : 'e var = Bindlib.new_var (fun x -> EVar x) name
|
||||
let compare = Bindlib.compare_vars
|
||||
let eq = Bindlib.eq_vars
|
||||
|
||||
let translate (v : 'e1 var) : 'e2 var =
|
||||
Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
|
||||
|
||||
(* The purpose of this module is just to lift a type parameter outside of
|
||||
[Set.S] and [Map.S], so that we can have ['e Var.Set.t] for sets of variables
|
||||
bound to the ['e = ('a, 't) gexpr] expression type. This is made possible by
|
||||
the fact that [Bindlib.compare_vars] is polymorphic in that parameter; we
|
||||
first hide that parameter inside an existential, then re-add a phantom type
|
||||
outside of the set to ensure consistency. Extracting the elements is then
|
||||
done with [Bindlib.copy_var] but technically it's not much different from an
|
||||
[Obj] conversion.
|
||||
|
||||
If anyone has a better solution, besides a copy-paste of Set.Make / Map.Make
|
||||
code... *)
|
||||
module Generic = struct
|
||||
(* Existentially quantify the type parameters to allow application of
|
||||
Set.Make *)
|
||||
type t = Var : 'e var -> t
|
||||
(* Note: adding [[@@ocaml.unboxed]] would be OK and make our wrappers live at
|
||||
the type-level without affecting the actual data representation. But
|
||||
[Bindlib.var] being abstract, we can't convince OCaml it's ok at the moment
|
||||
and have to hold it *)
|
||||
|
||||
let t v = Var v
|
||||
let get (Var v) = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v)
|
||||
let compare (Var x) (Var y) = Bindlib.compare_vars x y
|
||||
let eq (Var x) (Var y) = Bindlib.eq_vars x y
|
||||
end
|
||||
|
||||
(* Wrapper around Set.Make to re-add type parameters (avoid inconsistent
|
||||
sets) *)
|
||||
module Set = struct
|
||||
open Generic
|
||||
open Set.Make (Generic)
|
||||
|
||||
type nonrec 'e t = t constraint 'e = 'e expr
|
||||
|
||||
let empty = empty
|
||||
let singleton x = singleton (t x)
|
||||
let add x s = add (t x) s
|
||||
let remove x s = remove (t x) s
|
||||
let union s1 s2 = union s1 s2
|
||||
let mem x s = mem (t x) s
|
||||
let of_list l = of_list (List.map t l)
|
||||
let elements s = elements s |> List.map get
|
||||
|
||||
(* Add more as needed *)
|
||||
end
|
||||
|
||||
(* Wrapper around Map.Make to re-add type parameters (avoid inconsistent
|
||||
maps) *)
|
||||
module Map = struct
|
||||
open Generic
|
||||
open Map.Make (Generic)
|
||||
|
||||
type nonrec ('e, 'x) t = 'x t constraint 'e = 'e expr
|
||||
|
||||
let empty = empty
|
||||
let singleton v x = singleton (t v) x
|
||||
let add v x m = add (t v) x m
|
||||
let update v f m = update (t v) f m
|
||||
let find v m = find (t v) m
|
||||
let find_opt v m = find_opt (t v) m
|
||||
let bindings m = bindings m |> List.map (fun (v, x) -> get v, x)
|
||||
let mem x m = mem (t x) m
|
||||
let union f m1 m2 = union (fun v x1 x2 -> f (get v) x1 x2) m1 m2
|
||||
let fold f m acc = fold (fun v x acc -> f (get v) x acc) m acc
|
||||
|
||||
(* Add more as needed *)
|
||||
end
|
73
compiler/utils/var.mli
Normal file
73
compiler/utils/var.mli
Normal file
@ -0,0 +1,73 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
|
||||
contributor: Louis Gesbert <louis.gesbert@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Astgen
|
||||
|
||||
(** {1 Variables and their collections} *)
|
||||
|
||||
(** This module provides types and helpers for Bindlib variables on the
|
||||
[Astgen.gexpr] type *)
|
||||
|
||||
type 'e expr = 'e
|
||||
constraint 'e = ([< desugared | scopelang | dcalc | lcalc ], 't) gexpr
|
||||
(** Subtype of Astgen.gexpr where variables are handled *)
|
||||
|
||||
type 'e var = 'e expr Bindlib.var
|
||||
type 'e t = 'e var
|
||||
type 'e vars = 'e expr Bindlib.mvar
|
||||
|
||||
val make : string -> 'e t
|
||||
val compare : 'e t -> 'e t -> int
|
||||
val eq : 'e t -> 'e t -> bool
|
||||
|
||||
val translate : 'e1 t -> 'e2 t
|
||||
(** Needed when converting from one AST type to another. See the note of caution
|
||||
on [Bindlib.copy_var]. *)
|
||||
|
||||
(** Wrapper over [Set.S] but with a type variable for the AST type parameters.
|
||||
Extend as needed *)
|
||||
module Set : sig
|
||||
type 'e t constraint 'e = 'e expr
|
||||
|
||||
val empty : 'e t
|
||||
val singleton : 'e var -> 'e t
|
||||
val add : 'e var -> 'e t -> 'e t
|
||||
val remove : 'e var -> 'e t -> 'e t
|
||||
val union : 'e t -> 'e t -> 'e t
|
||||
val mem : 'e var -> 'e t -> bool
|
||||
val of_list : 'e var list -> 'e t
|
||||
val elements : 'e t -> 'e var list
|
||||
end
|
||||
|
||||
(** Wrapper over [Map.S] but with a type variable for the AST type parameters.
|
||||
Extend as needed *)
|
||||
module Map : sig
|
||||
type ('e, 'x) t constraint 'e = 'e expr
|
||||
|
||||
val empty : ('e, 'x) t
|
||||
val singleton : 'e var -> 'x -> ('e, 'x) t
|
||||
val add : 'e var -> 'x -> ('e, 'x) t -> ('e, 'x) t
|
||||
val update : 'e var -> ('x option -> 'x option) -> ('e, 'x) t -> ('e, 'x) t
|
||||
val find : 'e var -> ('e, 'x) t -> 'x
|
||||
val find_opt : 'e var -> ('e, 'x) t -> 'x option
|
||||
val bindings : ('e, 'x) t -> ('e var * 'x) list
|
||||
val mem : 'e var -> ('e, 'x) t -> bool
|
||||
|
||||
val union :
|
||||
('e var -> 'x -> 'x -> 'x option) -> ('e, 'x) t -> ('e, 'x) t -> ('e, 'x) t
|
||||
|
||||
val fold : ('e var -> 'x -> 'acc -> 'acc) -> ('e, 'x) t -> 'acc -> 'acc
|
||||
end
|
@ -21,27 +21,28 @@ open Ast
|
||||
|
||||
(** {1 Helpers and type definitions}*)
|
||||
|
||||
type vc_return = typed marked_expr * typ Marked.pos VarMap.t
|
||||
type vc_return = typed marked_expr * (typed expr, typ Marked.pos) Var.Map.t
|
||||
(** The return type of VC generators is the VC expression plus the types of any
|
||||
locally free variable inside that expression. *)
|
||||
|
||||
type ctx = {
|
||||
current_scope_name : ScopeName.t;
|
||||
decl : decl_ctx;
|
||||
input_vars : Var.t list;
|
||||
scope_variables_typs : typ Marked.pos VarMap.t;
|
||||
input_vars : typed var list;
|
||||
scope_variables_typs : (typed expr, typ Marked.pos) Var.Map.t;
|
||||
}
|
||||
|
||||
let conjunction (args : vc_return list) (mark : typed mark) : vc_return =
|
||||
let acc, list =
|
||||
match args with
|
||||
| hd :: tl -> hd, tl
|
||||
| [] -> ((ELit (LBool true), mark), VarMap.empty), []
|
||||
| [] -> ((ELit (LBool true), mark), Var.Map.empty), []
|
||||
in
|
||||
List.fold_left
|
||||
(fun (acc, acc_ty) (arg, arg_ty) ->
|
||||
( (EApp ((EOp (Binop And), mark), [arg; acc]), mark),
|
||||
VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty ))
|
||||
Var.Map.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty
|
||||
))
|
||||
acc list
|
||||
|
||||
let negation ((arg, arg_ty) : vc_return) (mark : typed mark) : vc_return =
|
||||
@ -51,12 +52,13 @@ let disjunction (args : vc_return list) (mark : typed mark) : vc_return =
|
||||
let acc, list =
|
||||
match args with
|
||||
| hd :: tl -> hd, tl
|
||||
| [] -> ((ELit (LBool false), mark), VarMap.empty), []
|
||||
| [] -> ((ELit (LBool false), mark), Var.Map.empty), []
|
||||
in
|
||||
List.fold_left
|
||||
(fun ((acc, acc_ty) : vc_return) (arg, arg_ty) ->
|
||||
( (EApp ((EOp (Binop Or), mark), [arg; acc]), mark),
|
||||
VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty ))
|
||||
Var.Map.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty
|
||||
))
|
||||
acc list
|
||||
|
||||
(** [half_product \[a1,...,an\] \[b1,...,bm\] returns \[(a1,b1),...(a1,bn),...(an,b1),...(an,bm)\]] *)
|
||||
@ -80,7 +82,7 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : typed marked_expr)
|
||||
(ELit (LBool true), _),
|
||||
cons ),
|
||||
_ )
|
||||
when List.exists (fun x' -> Var.eq (Var.t x) x') ctx.input_vars ->
|
||||
when List.exists (fun x' -> Var.eq x x') ctx.input_vars ->
|
||||
(* scope variables*)
|
||||
cons
|
||||
| EAbs (binder, [(TLit TUnit, _)]) ->
|
||||
@ -130,7 +132,7 @@ let rec generate_vc_must_not_return_empty (ctx : ctx) (e : typed marked_expr) :
|
||||
in
|
||||
( vc_body_expr,
|
||||
List.fold_left
|
||||
(fun acc (var, ty) -> VarMap.add (Var.t var) ty acc)
|
||||
(fun acc (var, ty) -> Var.Map.add var ty acc)
|
||||
vc_body_ty
|
||||
(List.map2 (fun x y -> x, y) (Array.to_list vars) typs) )
|
||||
| EApp (f, args) ->
|
||||
@ -147,18 +149,18 @@ let rec generate_vc_must_not_return_empty (ctx : ctx) (e : typed marked_expr) :
|
||||
[
|
||||
e1_vc, vc_typ1;
|
||||
( (EIfThenElse (e1, e2_vc, e3_vc), Marked.get_mark e),
|
||||
VarMap.union
|
||||
Var.Map.union
|
||||
(fun _ _ _ -> failwith "should not happen")
|
||||
vc_typ2 vc_typ3 );
|
||||
]
|
||||
(Marked.get_mark e)
|
||||
| ELit LEmptyError ->
|
||||
Marked.same_mark_as (ELit (LBool false)) e, VarMap.empty
|
||||
Marked.same_mark_as (ELit (LBool false)) e, Var.Map.empty
|
||||
| EVar _
|
||||
(* Per default calculus semantics, you cannot call a function with an argument
|
||||
that evaluates to the empty error. Thus, all variable evaluate to non-empty-error terms. *)
|
||||
| ELit _ | EOp _ ->
|
||||
Marked.same_mark_as (ELit (LBool true)) e, VarMap.empty
|
||||
Marked.same_mark_as (ELit (LBool true)) e, Var.Map.empty
|
||||
| EDefault (exceptions, just, cons) ->
|
||||
(* <e1 ... en | ejust :- econs > never returns empty if and only if:
|
||||
- first we look if e1 .. en ejust can return empty;
|
||||
@ -223,7 +225,7 @@ let rec generate_vs_must_not_return_confict (ctx : ctx) (e : typed marked_expr)
|
||||
in
|
||||
( vc_body_expr,
|
||||
List.fold_left
|
||||
(fun acc (var, ty) -> VarMap.add (Var.t var) ty acc)
|
||||
(fun acc (var, ty) -> Var.Map.add var ty acc)
|
||||
vc_body_ty
|
||||
(List.map2 (fun x y -> x, y) (Array.to_list vars) typs) )
|
||||
| EApp (f, args) ->
|
||||
@ -238,13 +240,13 @@ let rec generate_vs_must_not_return_confict (ctx : ctx) (e : typed marked_expr)
|
||||
[
|
||||
e1_vc, vc_typ1;
|
||||
( (EIfThenElse (e1, e2_vc, e3_vc), Marked.get_mark e),
|
||||
VarMap.union
|
||||
Var.Map.union
|
||||
(fun _ _ _ -> failwith "should not happen")
|
||||
vc_typ2 vc_typ3 );
|
||||
]
|
||||
(Marked.get_mark e)
|
||||
| EVar _ | ELit _ | EOp _ ->
|
||||
Marked.same_mark_as (ELit (LBool true)) e, VarMap.empty
|
||||
Marked.same_mark_as (ELit (LBool true)) e, Var.Map.empty
|
||||
| EDefault (exceptions, just, cons) ->
|
||||
(* <e1 ... en | ejust :- econs > never returns conflict if and only if:
|
||||
- neither e1 nor ... nor en nor ejust nor econs return conflict
|
||||
@ -284,8 +286,8 @@ type verification_condition = {
|
||||
(* should have type bool *)
|
||||
vc_kind : verification_condition_kind;
|
||||
vc_scope : ScopeName.t;
|
||||
vc_variable : Var.t Marked.pos;
|
||||
vc_free_vars_typ : typ Marked.pos VarMap.t;
|
||||
vc_variable : typed var Marked.pos;
|
||||
vc_free_vars_typ : (typed expr, typ Marked.pos) Var.Map.t;
|
||||
}
|
||||
|
||||
let rec generate_verification_conditions_scope_body_expr
|
||||
@ -301,7 +303,7 @@ let rec generate_verification_conditions_scope_body_expr
|
||||
let new_ctx, vc_list =
|
||||
match scope_let.scope_let_kind with
|
||||
| DestructuringInputStruct ->
|
||||
{ ctx with input_vars = Var.t scope_let_var :: ctx.input_vars }, []
|
||||
{ ctx with input_vars = scope_let_var :: ctx.input_vars }, []
|
||||
| ScopeVarDefinition | SubScopeVarDefinition ->
|
||||
(* For scope variables, we should check both that they never evaluate to
|
||||
emptyError nor conflictError. But for subscope variable definitions,
|
||||
@ -324,11 +326,11 @@ let rec generate_verification_conditions_scope_body_expr
|
||||
vc_guard = Marked.same_mark_as (Marked.unmark vc_confl) e;
|
||||
vc_kind = NoOverlappingExceptions;
|
||||
vc_free_vars_typ =
|
||||
VarMap.union
|
||||
Var.Map.union
|
||||
(fun _ _ -> failwith "should not happen")
|
||||
ctx.scope_variables_typs vc_confl_typs;
|
||||
vc_scope = ctx.current_scope_name;
|
||||
vc_variable = Var.t scope_let_var, scope_let.scope_let_pos;
|
||||
vc_variable = scope_let_var, scope_let.scope_let_pos;
|
||||
};
|
||||
]
|
||||
in
|
||||
@ -347,11 +349,11 @@ let rec generate_verification_conditions_scope_body_expr
|
||||
vc_guard = Marked.same_mark_as (Marked.unmark vc_empty) e;
|
||||
vc_kind = NoEmptyError;
|
||||
vc_free_vars_typ =
|
||||
VarMap.union
|
||||
Var.Map.union
|
||||
(fun _ _ -> failwith "should not happen")
|
||||
ctx.scope_variables_typs vc_empty_typs;
|
||||
vc_scope = ctx.current_scope_name;
|
||||
vc_variable = Var.t scope_let_var, scope_let.scope_let_pos;
|
||||
vc_variable = scope_let_var, scope_let.scope_let_pos;
|
||||
}
|
||||
:: vc_list
|
||||
| _ -> vc_list
|
||||
@ -364,7 +366,7 @@ let rec generate_verification_conditions_scope_body_expr
|
||||
{
|
||||
new_ctx with
|
||||
scope_variables_typs =
|
||||
VarMap.add (Var.t scope_let_var) scope_let.scope_let_typ
|
||||
Var.Map.add scope_let_var scope_let.scope_let_typ
|
||||
new_ctx.scope_variables_typs;
|
||||
}
|
||||
scope_let_next
|
||||
@ -396,7 +398,7 @@ let rec generate_verification_conditions_scopes
|
||||
decl = decl_ctx;
|
||||
input_vars = [];
|
||||
scope_variables_typs =
|
||||
VarMap.empty
|
||||
Var.Map.empty
|
||||
(* We don't need to add the typ of the scope input var here
|
||||
because it will never appear in an expression for which we
|
||||
generate a verification conditions (the big struct is
|
||||
@ -423,7 +425,7 @@ let generate_verification_conditions
|
||||
let to_str vc =
|
||||
Format.asprintf "%s.%s"
|
||||
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
|
||||
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable)))
|
||||
(Bindlib.name_of (Marked.unmark vc.vc_variable))
|
||||
in
|
||||
String.compare (to_str vc1) (to_str vc2))
|
||||
vcs
|
||||
|
@ -17,6 +17,8 @@
|
||||
|
||||
(** Generates verification conditions from scope definitions *)
|
||||
|
||||
open Utils
|
||||
|
||||
type verification_condition_kind =
|
||||
| NoEmptyError
|
||||
(** This verification condition checks whether a definition never returns
|
||||
@ -30,8 +32,9 @@ type verification_condition = {
|
||||
(** This expression should have type [bool]*)
|
||||
vc_kind : verification_condition_kind;
|
||||
vc_scope : Dcalc.Ast.ScopeName.t;
|
||||
vc_variable : Dcalc.Ast.Var.t Utils.Marked.pos;
|
||||
vc_free_vars_typ : Dcalc.Ast.typ Utils.Marked.pos Dcalc.Ast.VarMap.t;
|
||||
vc_variable : Astgen.typed Dcalc.Ast.var Marked.pos;
|
||||
vc_free_vars_typ :
|
||||
(Astgen.typed Dcalc.Ast.expr, Dcalc.Ast.typ Marked.pos) Var.Map.t;
|
||||
(** Types of the locally free variables in [vc_guard]. The types of other
|
||||
free variables linked to scope variables can be obtained with
|
||||
[Dcalc.Ast.variable_types]. *)
|
||||
|
@ -23,7 +23,8 @@ module type Backend = sig
|
||||
|
||||
type backend_context
|
||||
|
||||
val make_context : decl_ctx -> typ Marked.pos VarMap.t -> backend_context
|
||||
val make_context :
|
||||
decl_ctx -> (typed expr, typ Marked.pos) Var.Map.t -> backend_context
|
||||
|
||||
type vc_encoding
|
||||
|
||||
@ -37,7 +38,9 @@ module type Backend = sig
|
||||
val is_model_empty : model -> bool
|
||||
|
||||
val translate_expr :
|
||||
backend_context -> 'm Dcalc.Ast.marked_expr -> backend_context * vc_encoding
|
||||
backend_context ->
|
||||
Astgen.typed Dcalc.Ast.marked_expr ->
|
||||
backend_context * vc_encoding
|
||||
end
|
||||
|
||||
module type BackendIO = sig
|
||||
@ -45,12 +48,15 @@ module type BackendIO = sig
|
||||
|
||||
type backend_context
|
||||
|
||||
val make_context : decl_ctx -> typ Marked.pos VarMap.t -> backend_context
|
||||
val make_context :
|
||||
decl_ctx -> (Astgen.typed expr, typ Marked.pos) Var.Map.t -> backend_context
|
||||
|
||||
type vc_encoding
|
||||
|
||||
val translate_expr :
|
||||
backend_context -> 'm Dcalc.Ast.marked_expr -> backend_context * vc_encoding
|
||||
backend_context ->
|
||||
Astgen.typed Dcalc.Ast.marked_expr ->
|
||||
backend_context * vc_encoding
|
||||
|
||||
type model
|
||||
|
||||
@ -95,12 +101,12 @@ module MakeBackendIO (B : Backend) = struct
|
||||
Format.asprintf "%s This variable never returns an empty error"
|
||||
(Cli.with_style [ANSITerminal.yellow] "[%s.%s]"
|
||||
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
|
||||
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable))))
|
||||
(Bindlib.name_of (Marked.unmark vc.vc_variable)))
|
||||
| Conditions.NoOverlappingExceptions ->
|
||||
Format.asprintf "%s No two exceptions to ever overlap for this variable"
|
||||
(Cli.with_style [ANSITerminal.yellow] "[%s.%s]"
|
||||
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
|
||||
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable))))
|
||||
(Bindlib.name_of (Marked.unmark vc.vc_variable)))
|
||||
|
||||
let print_negative_result
|
||||
(vc : Conditions.verification_condition)
|
||||
@ -112,14 +118,14 @@ module MakeBackendIO (B : Backend) = struct
|
||||
Format.asprintf "%s This variable might return an empty error:\n%s"
|
||||
(Cli.with_style [ANSITerminal.yellow] "[%s.%s]"
|
||||
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
|
||||
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable))))
|
||||
(Bindlib.name_of (Marked.unmark vc.vc_variable)))
|
||||
(Pos.retrieve_loc_text (Marked.get_mark vc.vc_variable))
|
||||
| Conditions.NoOverlappingExceptions ->
|
||||
Format.asprintf
|
||||
"%s At least two exceptions overlap for this variable:\n%s"
|
||||
(Cli.with_style [ANSITerminal.yellow] "[%s.%s]"
|
||||
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
|
||||
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable))))
|
||||
(Bindlib.name_of (Marked.unmark vc.vc_variable)))
|
||||
(Pos.retrieve_loc_text (Marked.get_mark vc.vc_variable))
|
||||
in
|
||||
let counterexample : string option =
|
||||
@ -178,6 +184,6 @@ module MakeBackendIO (B : Backend) = struct
|
||||
Cli.error_print "%s The translation to Z3 failed:\n%s"
|
||||
(Cli.with_style [ANSITerminal.yellow] "[%s.%s]"
|
||||
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
|
||||
(Bindlib.name_of (Var.get (Marked.unmark vc.vc_variable))))
|
||||
(Bindlib.name_of (Marked.unmark vc.vc_variable)))
|
||||
msg
|
||||
end
|
||||
|
@ -17,6 +17,8 @@
|
||||
|
||||
(** Common code for handling the IO of all proof backends supported *)
|
||||
|
||||
open Utils
|
||||
|
||||
module type Backend = sig
|
||||
val init_backend : unit -> unit
|
||||
|
||||
@ -24,7 +26,7 @@ module type Backend = sig
|
||||
|
||||
val make_context :
|
||||
Dcalc.Ast.decl_ctx ->
|
||||
Dcalc.Ast.typ Utils.Marked.pos Dcalc.Ast.VarMap.t ->
|
||||
(Astgen.typed Dcalc.Ast.expr, Dcalc.Ast.typ Utils.Marked.pos) Var.Map.t ->
|
||||
backend_context
|
||||
|
||||
type vc_encoding
|
||||
@ -39,7 +41,9 @@ module type Backend = sig
|
||||
val is_model_empty : model -> bool
|
||||
|
||||
val translate_expr :
|
||||
backend_context -> 'm Dcalc.Ast.marked_expr -> backend_context * vc_encoding
|
||||
backend_context ->
|
||||
Astgen.typed Dcalc.Ast.marked_expr ->
|
||||
backend_context * vc_encoding
|
||||
end
|
||||
|
||||
module type BackendIO = sig
|
||||
@ -49,13 +53,15 @@ module type BackendIO = sig
|
||||
|
||||
val make_context :
|
||||
Dcalc.Ast.decl_ctx ->
|
||||
Dcalc.Ast.typ Utils.Marked.pos Dcalc.Ast.VarMap.t ->
|
||||
(Astgen.typed Dcalc.Ast.expr, Dcalc.Ast.typ Utils.Marked.pos) Var.Map.t ->
|
||||
backend_context
|
||||
|
||||
type vc_encoding
|
||||
|
||||
val translate_expr :
|
||||
backend_context -> 'm Dcalc.Ast.marked_expr -> backend_context * vc_encoding
|
||||
backend_context ->
|
||||
Astgen.typed Dcalc.Ast.marked_expr ->
|
||||
backend_context * vc_encoding
|
||||
|
||||
type model
|
||||
|
||||
|
@ -26,20 +26,20 @@ type context = {
|
||||
ctx_decl : decl_ctx;
|
||||
(* The declaration context from the Catala program, containing information to
|
||||
precisely pretty print Catala expressions *)
|
||||
ctx_var : typ Marked.pos VarMap.t;
|
||||
ctx_var : (typed expr, typ Marked.pos) Var.Map.t;
|
||||
(* A map from Catala variables to their types, needed to create Z3 expressions
|
||||
of the right sort *)
|
||||
ctx_funcdecl : FuncDecl.func_decl VarMap.t;
|
||||
ctx_funcdecl : (typed expr, FuncDecl.func_decl) Var.Map.t;
|
||||
(* A map from Catala function names (represented as variables) to Z3 function
|
||||
declarations, used to only define once functions in Z3 queries *)
|
||||
ctx_z3vars : Var.t StringMap.t;
|
||||
ctx_z3vars : typed var StringMap.t;
|
||||
(* A map from strings, corresponding to Z3 symbol names, to the Catala
|
||||
variable they represent. Used when to pretty-print Z3 models when a
|
||||
counterexample is generated *)
|
||||
ctx_z3datatypes : Sort.sort EnumMap.t;
|
||||
(* A map from Catala enumeration names to the corresponding Z3 sort, from
|
||||
which we can retrieve constructors and accessors *)
|
||||
ctx_z3matchsubsts : Expr.expr VarMap.t;
|
||||
ctx_z3matchsubsts : (typed expr, Expr.expr) Var.Map.t;
|
||||
(* A map from Catala temporary variables, generated when translating a match,
|
||||
to the corresponding enum accessor call as a Z3 expression *)
|
||||
ctx_z3structs : Sort.sort StructMap.t;
|
||||
@ -64,13 +64,13 @@ type context = {
|
||||
|
||||
(** [add_funcdecl] adds the mapping between the Catala variable [v] and the Z3
|
||||
function declaration [fd] to the context **)
|
||||
let add_funcdecl (v : Var.t) (fd : FuncDecl.func_decl) (ctx : context) : context
|
||||
=
|
||||
{ ctx with ctx_funcdecl = VarMap.add v fd ctx.ctx_funcdecl }
|
||||
let add_funcdecl (v : typed var) (fd : FuncDecl.func_decl) (ctx : context) :
|
||||
context =
|
||||
{ ctx with ctx_funcdecl = Var.Map.add v fd ctx.ctx_funcdecl }
|
||||
|
||||
(** [add_z3var] adds the mapping between [name] and the Catala variable [v] to
|
||||
the context **)
|
||||
let add_z3var (name : string) (v : Var.t) (ctx : context) : context =
|
||||
let add_z3var (name : string) (v : typed var) (ctx : context) : context =
|
||||
{ ctx with ctx_z3vars = StringMap.add name v ctx.ctx_z3vars }
|
||||
|
||||
(** [add_z3enum] adds the mapping between the Catala enumeration [enum] and the
|
||||
@ -81,8 +81,8 @@ let add_z3enum (enum : EnumName.t) (sort : Sort.sort) (ctx : context) : context
|
||||
|
||||
(** [add_z3var] adds the mapping between temporary variable [v] and the Z3
|
||||
expression [e] representing an accessor application to the context **)
|
||||
let add_z3matchsubst (v : Var.t) (e : Expr.expr) (ctx : context) : context =
|
||||
{ ctx with ctx_z3matchsubsts = VarMap.add v e ctx.ctx_z3matchsubsts }
|
||||
let add_z3matchsubst (v : typed var) (e : Expr.expr) (ctx : context) : context =
|
||||
{ ctx with ctx_z3matchsubsts = Var.Map.add v e ctx.ctx_z3matchsubsts }
|
||||
|
||||
(** [add_z3struct] adds the mapping between the Catala struct [s] and the
|
||||
corresponding Z3 datatype [sort] to the context **)
|
||||
@ -223,9 +223,8 @@ let print_model (ctx : context) (model : Model.model) : string =
|
||||
let v = StringMap.find symbol_name ctx.ctx_z3vars in
|
||||
Format.fprintf fmt "%s %s : %s"
|
||||
(Cli.with_style [ANSITerminal.blue] "%s" "-->")
|
||||
(Cli.with_style [ANSITerminal.yellow] "%s"
|
||||
(Bindlib.name_of (Var.get v)))
|
||||
(print_z3model_expr ctx (VarMap.find v ctx.ctx_var) e)
|
||||
(Cli.with_style [ANSITerminal.yellow] "%s" (Bindlib.name_of v))
|
||||
(print_z3model_expr ctx (Var.Map.find v ctx.ctx_var) e)
|
||||
else
|
||||
(* Declaration d is a function *)
|
||||
match Model.get_func_interp model d with
|
||||
@ -239,8 +238,7 @@ let print_model (ctx : context) (model : Model.model) : string =
|
||||
let v = StringMap.find symbol_name ctx.ctx_z3vars in
|
||||
Format.fprintf fmt "%s %s : %s"
|
||||
(Cli.with_style [ANSITerminal.blue] "%s" "-->")
|
||||
(Cli.with_style [ANSITerminal.yellow] "%s"
|
||||
(Bindlib.name_of (Var.get v)))
|
||||
(Cli.with_style [ANSITerminal.yellow] "%s" (Bindlib.name_of v))
|
||||
(* TODO: Model of a Z3 function should be pretty-printed *)
|
||||
(Model.FuncInterp.to_string f)))
|
||||
decls
|
||||
@ -387,18 +385,18 @@ let translate_lit (ctx : context) (l : lit) : Expr.expr =
|
||||
corresponding to the variable [v]. If no such function declaration exists
|
||||
yet, we construct it and add it to the context, thus requiring to return a
|
||||
new context *)
|
||||
let find_or_create_funcdecl (ctx : context) (v : Var.t) :
|
||||
let find_or_create_funcdecl (ctx : context) (v : typed var) :
|
||||
context * FuncDecl.func_decl =
|
||||
match VarMap.find_opt v ctx.ctx_funcdecl with
|
||||
match Var.Map.find_opt v ctx.ctx_funcdecl with
|
||||
| Some fd -> ctx, fd
|
||||
| None -> (
|
||||
(* Retrieves the Catala type of the function [v] *)
|
||||
let f_ty = VarMap.find v ctx.ctx_var in
|
||||
let f_ty = Var.Map.find v ctx.ctx_var in
|
||||
match Marked.unmark f_ty with
|
||||
| TArrow (t1, t2) ->
|
||||
let ctx, z3_t1 = translate_typ ctx (Marked.unmark t1) in
|
||||
let ctx, z3_t2 = translate_typ ctx (Marked.unmark t2) in
|
||||
let name = unique_name (Var.get v) in
|
||||
let name = unique_name v in
|
||||
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name [z3_t1] z3_t2 in
|
||||
let ctx = add_funcdecl v fd ctx in
|
||||
let ctx = add_z3var name v ctx in
|
||||
@ -631,7 +629,7 @@ and translate_expr (ctx : context) (vc : 'm marked_expr) : context * Expr.expr =
|
||||
match Marked.unmark e with
|
||||
| EAbs (e, _) ->
|
||||
(* Create a fresh Catala variable to substitue and obtain the body *)
|
||||
let fresh_v = new_var "arm!tmp" in
|
||||
let fresh_v = Var.make "arm!tmp" in
|
||||
let fresh_e = EVar fresh_v in
|
||||
|
||||
(* Invariant: Catala enums always have exactly one argument *)
|
||||
@ -639,7 +637,7 @@ and translate_expr (ctx : context) (vc : 'm marked_expr) : context * Expr.expr =
|
||||
let proj = Expr.mk_app ctx.ctx_z3 accessor [head] in
|
||||
(* The fresh variable should be substituted by a projection into the enum
|
||||
in the body, we add this to the context *)
|
||||
let ctx = add_z3matchsubst (Var.t fresh_v) proj ctx in
|
||||
let ctx = add_z3matchsubst fresh_v proj ctx in
|
||||
|
||||
let body = Bindlib.msubst e [| fresh_e |] in
|
||||
translate_expr ctx body
|
||||
@ -649,12 +647,12 @@ and translate_expr (ctx : context) (vc : 'm marked_expr) : context * Expr.expr =
|
||||
|
||||
match Marked.unmark vc with
|
||||
| EVar v -> (
|
||||
match VarMap.find_opt (Var.t v) ctx.ctx_z3matchsubsts with
|
||||
match Var.Map.find_opt v ctx.ctx_z3matchsubsts with
|
||||
| None ->
|
||||
(* We are in the standard case, where this is a true Catala variable *)
|
||||
let t = VarMap.find (Var.t v) ctx.ctx_var in
|
||||
let t = Var.Map.find v ctx.ctx_var in
|
||||
let name = unique_name v in
|
||||
let ctx = add_z3var name (Var.t v) ctx in
|
||||
let ctx = add_z3var name v ctx in
|
||||
let ctx, ty = translate_typ ctx (Marked.unmark t) in
|
||||
let z3_var = Expr.mk_const_s ctx.ctx_z3 name ty in
|
||||
let ctx =
|
||||
@ -726,7 +724,7 @@ and translate_expr (ctx : context) (vc : 'm marked_expr) : context * Expr.expr =
|
||||
match Marked.unmark head with
|
||||
| EOp op -> translate_op ctx op args
|
||||
| EVar v ->
|
||||
let ctx, fd = find_or_create_funcdecl ctx (Var.t v) in
|
||||
let ctx, fd = find_or_create_funcdecl ctx v in
|
||||
(* Fold_right to preserve the order of the arguments: The head argument is
|
||||
appended at the head *)
|
||||
let ctx, z3_args =
|
||||
@ -804,7 +802,8 @@ module Backend = struct
|
||||
|
||||
let make_context
|
||||
(decl_ctx : decl_ctx)
|
||||
(free_vars_typ : typ Marked.pos VarMap.t) : backend_context =
|
||||
(free_vars_typ : (typed expr, typ Marked.pos) Var.Map.t) : backend_context
|
||||
=
|
||||
let cfg =
|
||||
(if !Cli.disable_counterexamples then [] else ["model", "true"])
|
||||
@ ["proof", "false"]
|
||||
@ -815,10 +814,10 @@ module Backend = struct
|
||||
ctx_z3 = z3_ctx;
|
||||
ctx_decl = decl_ctx;
|
||||
ctx_var = free_vars_typ;
|
||||
ctx_funcdecl = VarMap.empty;
|
||||
ctx_funcdecl = Var.Map.empty;
|
||||
ctx_z3vars = StringMap.empty;
|
||||
ctx_z3datatypes = EnumMap.empty;
|
||||
ctx_z3matchsubsts = VarMap.empty;
|
||||
ctx_z3matchsubsts = Var.Map.empty;
|
||||
ctx_z3structs = StructMap.empty;
|
||||
ctx_z3unit = z3unit;
|
||||
ctx_z3constraints = [];
|
||||
|
@ -1,7 +1,7 @@
|
||||
let Foo =
|
||||
λ (Foo_in_28: Foo_in{}) →
|
||||
let bar_29 : integer =
|
||||
λ (Foo_in_27: Foo_in{}) →
|
||||
let bar_28 : integer =
|
||||
try
|
||||
handle_default_0 [] (λ (__30: any) → true) (λ (__31: any) → 0)
|
||||
handle_default_0 [] (λ (__29: any) → true) (λ (__30: any) → 0)
|
||||
with EmptyError -> raise NoValueProvided in
|
||||
Foo_out {"bar_out": bar_29}
|
||||
Foo_out {"bar_out": bar_28}
|
||||
|
Loading…
Reference in New Issue
Block a user