Leverage the shared AST: big cleanup (part I)

This commit is contained in:
Louis Gesbert 2022-08-12 22:42:39 +02:00
parent 988e5eff1c
commit 2b6ee8dd4b
58 changed files with 1420 additions and 1655 deletions

View File

@ -16,8 +16,7 @@
the License. *) the License. *)
open Utils open Utils
include Shared_ast open Shared_ast
include Shared_ast.Expr
type lit = dcalc glit type lit = dcalc glit
@ -26,53 +25,9 @@ and 'm marked_expr = (dcalc, 'm mark) marked_gexpr
type 'm program = ('m expr, 'm) program_generic type 'm program = ('m expr, 'm) program_generic
let no_mark (type m) : m mark -> m mark = function
| Untyped _ -> Untyped { pos = Pos.no_pos }
| Typed _ -> Typed { pos = Pos.no_pos; ty = Marked.mark Pos.no_pos TAny }
let mark_pos (type m) (m : m mark) : Pos.t =
match m with Untyped { pos } | Typed { pos; _ } -> pos
let pos (type m) (x : ('a, m) marked) : Pos.t = mark_pos (Marked.get_mark x)
let ty (_, m) : marked_typ = match m with Typed { ty; _ } -> ty
let with_ty (type m) (ty : marked_typ) (x : ('a, m) marked) : ('a, typed) marked
=
Marked.mark
(match Marked.get_mark x with
| Untyped { pos } -> Typed { pos; ty }
| Typed m -> Typed { m with ty })
(Marked.unmark x)
let map_expr ctx ~f e = Expr.map ctx ~f e
let rec map_expr_top_down ~f e =
map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e)
let map_expr_marks ~f e =
map_expr_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
let untype_expr e = map_expr_marks ~f:(fun m -> Untyped { pos = mark_pos m }) e
type ('expr, 'm) box_expr_sig = type ('expr, 'm) box_expr_sig =
('expr, 'm) marked -> ('expr, 'm) marked Bindlib.box ('expr, 'm) marked -> ('expr, 'm) marked Bindlib.box
(** See [Bindlib.box_term] documentation for why we are doing that. *)
let box_expr : ('m expr, 'm) box_expr_sig =
fun e ->
let rec id_t () e = map_expr () ~f:id_t e in
id_t () e
let untype_program prg =
{
prg with
scopes =
Bindlib.unbox
(map_exprs_in_scopes
~f:(fun e -> untype_expr e)
~varf:Var.translate prg.scopes);
}
type 'm var = 'm expr Var.t type 'm var = 'm expr Var.t
type 'm vars = 'm expr Var.vars type 'm vars = 'm expr Var.vars
@ -158,49 +113,14 @@ type ('expr, 'm) make_let_in_sig =
Pos.t -> Pos.t ->
('expr, 'm) marked Bindlib.box ('expr, 'm) marked Bindlib.box
let map_mark
(type m)
(pos_f : Pos.t -> Pos.t)
(ty_f : marked_typ -> marked_typ)
(m : m mark) : m mark =
match m with
| Untyped { pos } -> Untyped { pos = pos_f pos }
| Typed { pos; ty } -> Typed { pos = pos_f pos; ty = ty_f ty }
let map_mark2
(type m)
(pos_f : Pos.t -> Pos.t -> Pos.t)
(ty_f : typed -> typed -> marked_typ)
(m1 : m mark)
(m2 : m mark) : m mark =
match m1, m2 with
| Untyped m1, Untyped m2 -> Untyped { pos = pos_f m1.pos m2.pos }
| Typed m1, Typed m2 -> Typed { pos = pos_f m1.pos m2.pos; ty = ty_f m1 m2 }
let fold_marks
(type m)
(pos_f : Pos.t list -> Pos.t)
(ty_f : typed list -> marked_typ)
(ms : m mark list) : m mark =
match ms with
| [] -> invalid_arg "Dcalc.Ast.fold_mark"
| Untyped _ :: _ as ms ->
Untyped { pos = pos_f (List.map (function Untyped { pos } -> pos) ms) }
| Typed _ :: _ ->
Typed
{
pos = pos_f (List.map (function Typed { pos; _ } -> pos) ms);
ty = ty_f (List.map (function Typed m -> m) ms);
}
let empty_thunked_term mark : 'm marked_expr = let empty_thunked_term mark : 'm marked_expr =
let silent = Var.make "_" in let silent = Var.make "_" in
let pos = mark_pos mark in let pos = Expr.mark_pos mark in
Bindlib.unbox Bindlib.unbox
(make_abs [| silent |] (make_abs [| silent |]
(Bindlib.box (ELit LEmptyError, mark)) (Bindlib.box (ELit LEmptyError, mark))
[TLit TUnit, pos] [TLit TUnit, pos]
(map_mark (Expr.map_mark
(fun pos -> pos) (fun pos -> pos)
(fun ty -> (fun ty ->
Marked.mark pos (TArrow (Marked.mark pos (TLit TUnit), ty))) Marked.mark pos (TArrow (Marked.mark pos (TLit TUnit), ty)))
@ -211,7 +131,7 @@ let (make_let_in : ('m expr, 'm) make_let_in_sig) =
let m_e1 = Marked.get_mark (Bindlib.unbox e1) in let m_e1 = Marked.get_mark (Bindlib.unbox e1) in
let m_e2 = Marked.get_mark (Bindlib.unbox e2) in let m_e2 = Marked.get_mark (Bindlib.unbox e2) in
let m_abs = let m_abs =
map_mark2 Expr.map_mark2
(fun _ _ -> pos) (fun _ _ -> pos)
(fun m1 m2 -> Marked.mark pos (TArrow (m1.ty, m2.ty))) (fun m1 m2 -> Marked.mark pos (TArrow (m1.ty, m2.ty)))
m_e1 m_e2 m_e1 m_e2
@ -329,7 +249,7 @@ let build_whole_scope_expr
( List.map snd ( List.map snd
(StructMap.find body.scope_body_input_struct ctx.ctx_structs), (StructMap.find body.scope_body_input_struct ctx.ctx_structs),
Some body.scope_body_input_struct ), Some body.scope_body_input_struct ),
mark_pos mark_scope ); Expr.mark_pos mark_scope );
] ]
mark_scope mark_scope
@ -354,10 +274,6 @@ type 'expr scope_name_or_var =
| ScopeName of ScopeName.t | ScopeName of ScopeName.t
| ScopeVar of 'expr Bindlib.var | ScopeVar of 'expr Bindlib.var
let get_scope_body_mark scope_body =
match snd (Bindlib.unbind scope_body.scope_body_expr) with
| Result e | ScopeLet { scope_let_expr = e; _ } -> Marked.get_mark e
let rec unfold_scopes let rec unfold_scopes
~(box_expr : ('expr, 'm) box_expr_sig) ~(box_expr : ('expr, 'm) box_expr_sig)
~(make_abs : ('expr, 'm) make_abs_sig) ~(make_abs : ('expr, 'm) make_abs_sig)
@ -374,7 +290,7 @@ let rec unfold_scopes
| ScopeDef { scope_name; scope_body; scope_next } -> | ScopeDef { scope_name; scope_body; scope_next } ->
let scope_var, scope_next = Bindlib.unbind scope_next in let scope_var, scope_next = Bindlib.unbind scope_next in
let scope_pos = Marked.get_mark (ScopeName.get_info scope_name) in let scope_pos = Marked.get_mark (ScopeName.get_info scope_name) in
let scope_body_mark = get_scope_body_mark scope_body in let scope_body_mark = Expr.get_scope_body_mark scope_body in
let main_scope = let main_scope =
match main_scope with match main_scope with
| ScopeVar v -> ScopeVar v | ScopeVar v -> ScopeVar v
@ -407,7 +323,7 @@ let build_whole_program_expr
(main_scope : ScopeName.t) : ('expr, 'm) marked Bindlib.box = (main_scope : ScopeName.t) : ('expr, 'm) marked Bindlib.box =
let _, main_scope_body = find_scope main_scope [] p.scopes in let _, main_scope_body = find_scope main_scope [] p.scopes in
unfold_scopes ~box_expr ~make_abs ~make_let_in p.decl_ctx p.scopes unfold_scopes ~box_expr ~make_abs ~make_let_in p.decl_ctx p.scopes
(get_scope_body_mark main_scope_body) (Expr.get_scope_body_mark main_scope_body)
(ScopeName main_scope) (ScopeName main_scope)
let rec expr_size (e : 'm marked_expr) : int = let rec expr_size (e : 'm marked_expr) : int =
@ -435,7 +351,7 @@ let rec expr_size (e : 'm marked_expr) : int =
let remove_logging_calls (e : 'm marked_expr) : 'm marked_expr Bindlib.box = let remove_logging_calls (e : 'm marked_expr) : 'm marked_expr Bindlib.box =
let rec f () e = let rec f () e =
match Marked.unmark e with match Marked.unmark e with
| EApp ((EOp (Unop (Log _)), _), [arg]) -> map_expr () ~f arg | EApp ((EOp (Unop (Log _)), _), [arg]) -> Expr.map () ~f arg
| _ -> map_expr () ~f e | _ -> Expr.map () ~f e
in in
f () e f () e

View File

@ -18,8 +18,7 @@
(** Abstract syntax tree of the default calculus intermediate representation *) (** Abstract syntax tree of the default calculus intermediate representation *)
open Utils open Utils
include module type of Shared_ast open Shared_ast
include module type of Shared_ast.Expr
type lit = dcalc glit type lit = dcalc glit
@ -44,149 +43,9 @@ val free_vars_scope_body : ('m expr, 'm) scope_body -> 'm expr Var.Set.t
val free_vars_scopes : ('m expr, 'm) scopes -> 'm expr Var.Set.t val free_vars_scopes : ('m expr, 'm) scopes -> 'm expr Var.Set.t
val make_var : ('m var, 'm) marked -> 'm marked_expr Bindlib.box val make_var : ('m var, 'm) marked -> 'm marked_expr Bindlib.box
(** {2 Manipulation of marks} *)
val no_mark : 'm mark -> 'm mark
val mark_pos : 'm mark -> Pos.t
val pos : ('a, 'm) marked -> Pos.t
val ty : ('a, typed) marked -> marked_typ
val with_ty : marked_typ -> ('a, 'm) marked -> ('a, typed) marked
(** All the following functions will resolve the types if called on an
[Inferring] type *)
val map_mark :
(Pos.t -> Pos.t) -> (marked_typ -> marked_typ) -> 'm mark -> 'm mark
val map_mark2 :
(Pos.t -> Pos.t -> Pos.t) ->
(typed -> typed -> marked_typ) ->
'm mark ->
'm mark ->
'm mark
val fold_marks :
(Pos.t list -> Pos.t) -> (typed list -> marked_typ) -> 'm mark list -> 'm mark
val get_scope_body_mark : ('expr, 'm) scope_body -> 'm mark
val untype_expr : 'm marked_expr -> untyped marked_expr Bindlib.box
val untype_program : 'm program -> untyped program
(** {2 Boxed constructors} *)
val evar : 'm expr Bindlib.var -> 'm mark -> 'm marked_expr Bindlib.box
val etuple :
'm marked_expr Bindlib.box list ->
StructName.t option ->
'm mark ->
'm marked_expr Bindlib.box
val etupleaccess :
'm marked_expr Bindlib.box ->
int ->
StructName.t option ->
marked_typ list ->
'm mark ->
'm marked_expr Bindlib.box
val einj :
'm marked_expr Bindlib.box ->
int ->
EnumName.t ->
marked_typ list ->
'm mark ->
'm marked_expr Bindlib.box
val ematch :
'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box list ->
EnumName.t ->
'm mark ->
'm marked_expr Bindlib.box
val earray :
'm marked_expr Bindlib.box list -> 'm mark -> 'm marked_expr Bindlib.box
val elit : lit -> 'm mark -> 'm marked_expr Bindlib.box
val eabs :
('m expr, 'm marked_expr) Bindlib.mbinder Bindlib.box ->
marked_typ list ->
'm mark ->
'm marked_expr Bindlib.box
val eapp :
'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box list ->
'm mark ->
'm marked_expr Bindlib.box
val eassert :
'm marked_expr Bindlib.box -> 'm mark -> 'm marked_expr Bindlib.box
val eop : operator -> 'm mark -> 'm marked_expr Bindlib.box
val edefault :
'm marked_expr Bindlib.box list ->
'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box ->
'm mark ->
'm marked_expr Bindlib.box
val eifthenelse :
'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box ->
'm mark ->
'm marked_expr Bindlib.box
val eerroronempty :
'm marked_expr Bindlib.box -> 'm mark -> 'm marked_expr Bindlib.box
type ('expr, 'm) box_expr_sig = type ('expr, 'm) box_expr_sig =
('expr, 'm) marked -> ('expr, 'm) marked Bindlib.box ('expr, 'm) marked -> ('expr, 'm) marked Bindlib.box
val box_expr : ('m expr, 'm) box_expr_sig
(**{2 Program traversal}*)
(** Be careful when using these traversal functions, as the bound variables they
open will be different at each traversal. *)
val map_expr :
'a ->
f:('a -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box) ->
('m1 expr, 'm2 mark) Marked.t ->
'm2 marked_expr Bindlib.box
(** If you want to apply a map transform to an expression, you can save up
writing a painful match over all the cases of the AST. For instance, if you
want to remove all errors on empty, you can write
{[
let remove_error_empty =
let rec f () e =
match Marked.unmark e with
| ErrorOnEmpty e1 -> map_expr () f e1
| _ -> map_expr () f e
in
f () e
]}
The first argument of map_expr is an optional context that you can carry
around during your map traversal. *)
val map_expr_top_down :
f:('m1 marked_expr -> ('m1 expr, 'm2 mark) Marked.t) ->
'm1 marked_expr ->
'm2 marked_expr Bindlib.box
(** Recursively applies [f] to the nodes of the expression tree. The type
returned by [f] is hybrid since the mark at top-level has been rewritten,
but not yet the marks in the subtrees. *)
val map_expr_marks :
f:('m1 mark -> 'm2 mark) -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box
(** {2 Boxed term constructors} *) (** {2 Boxed term constructors} *)
type ('e, 'm) make_abs_sig = type ('e, 'm) make_abs_sig =

View File

@ -17,12 +17,13 @@
(** Reference interpreter for the default calculus *) (** Reference interpreter for the default calculus *)
open Utils open Utils
open Shared_ast
module A = Ast module A = Ast
module Runtime = Runtime_ocaml.Runtime module Runtime = Runtime_ocaml.Runtime
(** {1 Helpers} *) (** {1 Helpers} *)
let is_empty_error (e : 'm A.marked_expr) : bool = let is_empty_error (e : 'm Ast.marked_expr) : bool =
match Marked.unmark e with ELit LEmptyError -> true | _ -> false match Marked.unmark e with ELit LEmptyError -> true | _ -> false
let log_indent = ref 0 let log_indent = ref 0
@ -30,25 +31,25 @@ let log_indent = ref 0
(** {1 Evaluation} *) (** {1 Evaluation} *)
let rec evaluate_operator let rec evaluate_operator
(ctx : Ast.decl_ctx) (ctx : decl_ctx)
(op : A.operator) (op : operator)
(pos : Pos.t) (pos : Pos.t)
(args : 'm A.marked_expr list) : 'm A.expr = (args : 'm Ast.marked_expr list) : 'm Ast.expr =
(* Try to apply [div] and if a [Division_by_zero] exceptions is catched, use (* Try to apply [div] and if a [Division_by_zero] exceptions is catched, use
[op] to raise multispanned errors. *) [op] to raise multispanned errors. *)
let apply_div_or_raise_err (div : unit -> 'm A.expr) : 'm A.expr = let apply_div_or_raise_err (div : unit -> 'm Ast.expr) : 'm Ast.expr =
try div () try div ()
with Division_by_zero -> with Division_by_zero ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ [
Some "The division operator:", pos; Some "The division operator:", pos;
Some "The null denominator:", Ast.pos (List.nth args 1); Some "The null denominator:", Expr.pos (List.nth args 1);
] ]
"division by zero at runtime" "division by zero at runtime"
in in
let get_binop_args_pos = function let get_binop_args_pos = function
| (arg0 :: arg1 :: _ : 'm A.marked_expr list) -> | (arg0 :: arg1 :: _ : 'm A.marked_expr list) ->
[None, Ast.pos arg0; None, Ast.pos arg1] [None, Expr.pos arg0; None, Expr.pos arg1]
| _ -> assert false | _ -> assert false
in in
(* Try to apply [cmp] and if a [UncomparableDurations] exceptions is catched, (* Try to apply [cmp] and if a [UncomparableDurations] exceptions is catched,
@ -63,211 +64,211 @@ let rec evaluate_operator
precise number of days" precise number of days"
in in
match op, List.map Marked.unmark args with match op, List.map Marked.unmark args with
| A.Ternop A.Fold, [_f; _init; EArray es] -> | Ternop Fold, [_f; _init; EArray es] ->
Marked.unmark Marked.unmark
(List.fold_left (List.fold_left
(fun acc e' -> (fun acc e' ->
evaluate_expr ctx evaluate_expr ctx
(Marked.same_mark_as (A.EApp (List.nth args 0, [acc; e'])) e')) (Marked.same_mark_as (EApp (List.nth args 0, [acc; e'])) e'))
(List.nth args 1) es) (List.nth args 1) es)
| A.Binop A.And, [ELit (LBool b1); ELit (LBool b2)] -> | Binop And, [ELit (LBool b1); ELit (LBool b2)] ->
A.ELit (LBool (b1 && b2)) ELit (LBool (b1 && b2))
| A.Binop A.Or, [ELit (LBool b1); ELit (LBool b2)] -> | Binop Or, [ELit (LBool b1); ELit (LBool b2)] ->
A.ELit (LBool (b1 || b2)) ELit (LBool (b1 || b2))
| A.Binop A.Xor, [ELit (LBool b1); ELit (LBool b2)] -> | Binop Xor, [ELit (LBool b1); ELit (LBool b2)] ->
A.ELit (LBool (b1 <> b2)) ELit (LBool (b1 <> b2))
| A.Binop (A.Add KInt), [ELit (LInt i1); ELit (LInt i2)] -> | Binop (Add KInt), [ELit (LInt i1); ELit (LInt i2)] ->
A.ELit (LInt Runtime.(i1 +! i2)) ELit (LInt Runtime.(i1 +! i2))
| A.Binop (A.Sub KInt), [ELit (LInt i1); ELit (LInt i2)] -> | Binop (Sub KInt), [ELit (LInt i1); ELit (LInt i2)] ->
A.ELit (LInt Runtime.(i1 -! i2)) ELit (LInt Runtime.(i1 -! i2))
| A.Binop (A.Mult KInt), [ELit (LInt i1); ELit (LInt i2)] -> | Binop (Mult KInt), [ELit (LInt i1); ELit (LInt i2)] ->
A.ELit (LInt Runtime.(i1 *! i2)) ELit (LInt Runtime.(i1 *! i2))
| A.Binop (A.Div KInt), [ELit (LInt i1); ELit (LInt i2)] -> | Binop (Div KInt), [ELit (LInt i1); ELit (LInt i2)] ->
apply_div_or_raise_err (fun _ -> A.ELit (LInt Runtime.(i1 /! i2))) apply_div_or_raise_err (fun _ -> ELit (LInt Runtime.(i1 /! i2)))
| A.Binop (A.Add KRat), [ELit (LRat i1); ELit (LRat i2)] -> | Binop (Add KRat), [ELit (LRat i1); ELit (LRat i2)] ->
A.ELit (LRat Runtime.(i1 +& i2)) ELit (LRat Runtime.(i1 +& i2))
| A.Binop (A.Sub KRat), [ELit (LRat i1); ELit (LRat i2)] -> | Binop (Sub KRat), [ELit (LRat i1); ELit (LRat i2)] ->
A.ELit (LRat Runtime.(i1 -& i2)) ELit (LRat Runtime.(i1 -& i2))
| A.Binop (A.Mult KRat), [ELit (LRat i1); ELit (LRat i2)] -> | Binop (Mult KRat), [ELit (LRat i1); ELit (LRat i2)] ->
A.ELit (LRat Runtime.(i1 *& i2)) ELit (LRat Runtime.(i1 *& i2))
| A.Binop (A.Div KRat), [ELit (LRat i1); ELit (LRat i2)] -> | Binop (Div KRat), [ELit (LRat i1); ELit (LRat i2)] ->
apply_div_or_raise_err (fun _ -> A.ELit (LRat Runtime.(i1 /& i2))) apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(i1 /& i2)))
| A.Binop (A.Add KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> | Binop (Add KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
A.ELit (LMoney Runtime.(m1 +$ m2)) ELit (LMoney Runtime.(m1 +$ m2))
| A.Binop (A.Sub KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> | Binop (Sub KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
A.ELit (LMoney Runtime.(m1 -$ m2)) ELit (LMoney Runtime.(m1 -$ m2))
| A.Binop (A.Mult KMoney), [ELit (LMoney m1); ELit (LRat m2)] -> | Binop (Mult KMoney), [ELit (LMoney m1); ELit (LRat m2)] ->
A.ELit (LMoney Runtime.(m1 *$ m2)) ELit (LMoney Runtime.(m1 *$ m2))
| A.Binop (A.Div KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> | Binop (Div KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
apply_div_or_raise_err (fun _ -> A.ELit (LRat Runtime.(m1 /$ m2))) apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(m1 /$ m2)))
| A.Binop (A.Add KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> | Binop (Add KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
A.ELit (LDuration Runtime.(d1 +^ d2)) ELit (LDuration Runtime.(d1 +^ d2))
| A.Binop (A.Sub KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> | Binop (Sub KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
A.ELit (LDuration Runtime.(d1 -^ d2)) ELit (LDuration Runtime.(d1 -^ d2))
| A.Binop (A.Sub KDate), [ELit (LDate d1); ELit (LDate d2)] -> | Binop (Sub KDate), [ELit (LDate d1); ELit (LDate d2)] ->
A.ELit (LDuration Runtime.(d1 -@ d2)) ELit (LDuration Runtime.(d1 -@ d2))
| A.Binop (A.Add KDate), [ELit (LDate d1); ELit (LDuration d2)] -> | Binop (Add KDate), [ELit (LDate d1); ELit (LDuration d2)] ->
A.ELit (LDate Runtime.(d1 +@ d2)) ELit (LDate Runtime.(d1 +@ d2))
| A.Binop (A.Div KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> | Binop (Div KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_div_or_raise_err (fun _ -> apply_div_or_raise_err (fun _ ->
try A.ELit (LRat Runtime.(d1 /^ d2)) try ELit (LRat Runtime.(d1 /^ d2))
with Runtime.IndivisableDurations -> with Runtime.IndivisableDurations ->
Errors.raise_multispanned_error (get_binop_args_pos args) Errors.raise_multispanned_error (get_binop_args_pos args)
"Cannot divide durations that cannot be converted to a precise \ "Cannot divide durations that cannot be converted to a precise \
number of days") number of days")
| A.Binop (A.Mult KDuration), [ELit (LDuration d1); ELit (LInt i1)] -> | Binop (Mult KDuration), [ELit (LDuration d1); ELit (LInt i1)] ->
A.ELit (LDuration Runtime.(d1 *^ i1)) ELit (LDuration Runtime.(d1 *^ i1))
| A.Binop (A.Lt KInt), [ELit (LInt i1); ELit (LInt i2)] -> | Binop (Lt KInt), [ELit (LInt i1); ELit (LInt i2)] ->
A.ELit (LBool Runtime.(i1 <! i2)) ELit (LBool Runtime.(i1 <! i2))
| A.Binop (A.Lte KInt), [ELit (LInt i1); ELit (LInt i2)] -> | Binop (Lte KInt), [ELit (LInt i1); ELit (LInt i2)] ->
A.ELit (LBool Runtime.(i1 <=! i2)) ELit (LBool Runtime.(i1 <=! i2))
| A.Binop (A.Gt KInt), [ELit (LInt i1); ELit (LInt i2)] -> | Binop (Gt KInt), [ELit (LInt i1); ELit (LInt i2)] ->
A.ELit (LBool Runtime.(i1 >! i2)) ELit (LBool Runtime.(i1 >! i2))
| A.Binop (A.Gte KInt), [ELit (LInt i1); ELit (LInt i2)] -> | Binop (Gte KInt), [ELit (LInt i1); ELit (LInt i2)] ->
A.ELit (LBool Runtime.(i1 >=! i2)) ELit (LBool Runtime.(i1 >=! i2))
| A.Binop (A.Lt KRat), [ELit (LRat i1); ELit (LRat i2)] -> | Binop (Lt KRat), [ELit (LRat i1); ELit (LRat i2)] ->
A.ELit (LBool Runtime.(i1 <& i2)) ELit (LBool Runtime.(i1 <& i2))
| A.Binop (A.Lte KRat), [ELit (LRat i1); ELit (LRat i2)] -> | Binop (Lte KRat), [ELit (LRat i1); ELit (LRat i2)] ->
A.ELit (LBool Runtime.(i1 <=& i2)) ELit (LBool Runtime.(i1 <=& i2))
| A.Binop (A.Gt KRat), [ELit (LRat i1); ELit (LRat i2)] -> | Binop (Gt KRat), [ELit (LRat i1); ELit (LRat i2)] ->
A.ELit (LBool Runtime.(i1 >& i2)) ELit (LBool Runtime.(i1 >& i2))
| A.Binop (A.Gte KRat), [ELit (LRat i1); ELit (LRat i2)] -> | Binop (Gte KRat), [ELit (LRat i1); ELit (LRat i2)] ->
A.ELit (LBool Runtime.(i1 >=& i2)) ELit (LBool Runtime.(i1 >=& i2))
| A.Binop (A.Lt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> | Binop (Lt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
A.ELit (LBool Runtime.(m1 <$ m2)) ELit (LBool Runtime.(m1 <$ m2))
| A.Binop (A.Lte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> | Binop (Lte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
A.ELit (LBool Runtime.(m1 <=$ m2)) ELit (LBool Runtime.(m1 <=$ m2))
| A.Binop (A.Gt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> | Binop (Gt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
A.ELit (LBool Runtime.(m1 >$ m2)) ELit (LBool Runtime.(m1 >$ m2))
| A.Binop (A.Gte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> | Binop (Gte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
A.ELit (LBool Runtime.(m1 >=$ m2)) ELit (LBool Runtime.(m1 >=$ m2))
| A.Binop (A.Lt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> | Binop (Lt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 <^ d2))) args apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <^ d2))) args
| A.Binop (A.Lte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> | Binop (Lte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 <=^ d2))) args apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <=^ d2))) args
| A.Binop (A.Gt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> | Binop (Gt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 >^ d2))) args apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >^ d2))) args
| A.Binop (A.Gte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> | Binop (Gte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 >=^ d2))) args apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >=^ d2))) args
| A.Binop (A.Lt KDate), [ELit (LDate d1); ELit (LDate d2)] -> | Binop (Lt KDate), [ELit (LDate d1); ELit (LDate d2)] ->
A.ELit (LBool Runtime.(d1 <@ d2)) ELit (LBool Runtime.(d1 <@ d2))
| A.Binop (A.Lte KDate), [ELit (LDate d1); ELit (LDate d2)] -> | Binop (Lte KDate), [ELit (LDate d1); ELit (LDate d2)] ->
A.ELit (LBool Runtime.(d1 <=@ d2)) ELit (LBool Runtime.(d1 <=@ d2))
| A.Binop (A.Gt KDate), [ELit (LDate d1); ELit (LDate d2)] -> | Binop (Gt KDate), [ELit (LDate d1); ELit (LDate d2)] ->
A.ELit (LBool Runtime.(d1 >@ d2)) ELit (LBool Runtime.(d1 >@ d2))
| A.Binop (A.Gte KDate), [ELit (LDate d1); ELit (LDate d2)] -> | Binop (Gte KDate), [ELit (LDate d1); ELit (LDate d2)] ->
A.ELit (LBool Runtime.(d1 >=@ d2)) ELit (LBool Runtime.(d1 >=@ d2))
| A.Binop A.Eq, [ELit LUnit; ELit LUnit] -> A.ELit (LBool true) | Binop Eq, [ELit LUnit; ELit LUnit] -> ELit (LBool true)
| A.Binop A.Eq, [ELit (LDuration d1); ELit (LDuration d2)] -> | Binop Eq, [ELit (LDuration d1); ELit (LDuration d2)] ->
A.ELit (LBool Runtime.(d1 =^ d2)) ELit (LBool Runtime.(d1 =^ d2))
| A.Binop A.Eq, [ELit (LDate d1); ELit (LDate d2)] -> | Binop Eq, [ELit (LDate d1); ELit (LDate d2)] ->
A.ELit (LBool Runtime.(d1 =@ d2)) ELit (LBool Runtime.(d1 =@ d2))
| A.Binop A.Eq, [ELit (LMoney m1); ELit (LMoney m2)] -> | Binop Eq, [ELit (LMoney m1); ELit (LMoney m2)] ->
A.ELit (LBool Runtime.(m1 =$ m2)) ELit (LBool Runtime.(m1 =$ m2))
| A.Binop A.Eq, [ELit (LRat i1); ELit (LRat i2)] -> | Binop Eq, [ELit (LRat i1); ELit (LRat i2)] ->
A.ELit (LBool Runtime.(i1 =& i2)) ELit (LBool Runtime.(i1 =& i2))
| A.Binop A.Eq, [ELit (LInt i1); ELit (LInt i2)] -> | Binop Eq, [ELit (LInt i1); ELit (LInt i2)] ->
A.ELit (LBool Runtime.(i1 =! i2)) ELit (LBool Runtime.(i1 =! i2))
| A.Binop A.Eq, [ELit (LBool b1); ELit (LBool b2)] -> A.ELit (LBool (b1 = b2)) | Binop Eq, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 = b2))
| A.Binop A.Eq, [EArray es1; EArray es2] -> | Binop Eq, [EArray es1; EArray es2] ->
A.ELit ELit
(LBool (LBool
(try (try
List.for_all2 List.for_all2
(fun e1 e2 -> (fun e1 e2 ->
match evaluate_operator ctx op pos [e1; e2] with match evaluate_operator ctx op pos [e1; e2] with
| A.ELit (LBool b) -> b | ELit (LBool b) -> b
| _ -> assert false | _ -> assert false
(* should not happen *)) (* should not happen *))
es1 es2 es1 es2
with Invalid_argument _ -> false)) with Invalid_argument _ -> false))
| A.Binop A.Eq, [ETuple (es1, s1); ETuple (es2, s2)] -> | Binop Eq, [ETuple (es1, s1); ETuple (es2, s2)] ->
A.ELit ELit
(LBool (LBool
(try (try
s1 = s2 s1 = s2
&& List.for_all2 && List.for_all2
(fun e1 e2 -> (fun e1 e2 ->
match evaluate_operator ctx op pos [e1; e2] with match evaluate_operator ctx op pos [e1; e2] with
| A.ELit (LBool b) -> b | ELit (LBool b) -> b
| _ -> assert false | _ -> assert false
(* should not happen *)) (* should not happen *))
es1 es2 es1 es2
with Invalid_argument _ -> false)) with Invalid_argument _ -> false))
| A.Binop A.Eq, [EInj (e1, i1, en1, _ts1); EInj (e2, i2, en2, _ts2)] -> | Binop Eq, [EInj (e1, i1, en1, _ts1); EInj (e2, i2, en2, _ts2)] ->
A.ELit ELit
(LBool (LBool
(try (try
en1 = en2 en1 = en2
&& i1 = i2 && i1 = i2
&& &&
match evaluate_operator ctx op pos [e1; e2] with match evaluate_operator ctx op pos [e1; e2] with
| A.ELit (LBool b) -> b | ELit (LBool b) -> b
| _ -> assert false | _ -> assert false
(* should not happen *) (* should not happen *)
with Invalid_argument _ -> false)) with Invalid_argument _ -> false))
| A.Binop A.Eq, [_; _] -> | Binop Eq, [_; _] ->
A.ELit (LBool false) (* comparing anything else return false *) ELit (LBool false) (* comparing anything else return false *)
| A.Binop A.Neq, [_; _] -> ( | Binop Neq, [_; _] -> (
match evaluate_operator ctx (A.Binop A.Eq) pos args with match evaluate_operator ctx (Binop Eq) pos args with
| A.ELit (A.LBool b) -> A.ELit (A.LBool (not b)) | ELit (LBool b) -> ELit (LBool (not b))
| _ -> assert false (*should not happen *)) | _ -> assert false (*should not happen *))
| A.Binop A.Concat, [A.EArray es1; A.EArray es2] -> A.EArray (es1 @ es2) | Binop Concat, [EArray es1; EArray es2] -> EArray (es1 @ es2)
| A.Binop A.Map, [_; A.EArray es] -> | Binop Map, [_; EArray es] ->
A.EArray EArray
(List.map (List.map
(fun e' -> (fun e' ->
evaluate_expr ctx evaluate_expr ctx
(Marked.same_mark_as (A.EApp (List.nth args 0, [e'])) e')) (Marked.same_mark_as (EApp (List.nth args 0, [e'])) e'))
es) es)
| A.Binop A.Filter, [_; A.EArray es] -> | Binop Filter, [_; EArray es] ->
A.EArray EArray
(List.filter (List.filter
(fun e' -> (fun e' ->
match match
evaluate_expr ctx evaluate_expr ctx
(Marked.same_mark_as (A.EApp (List.nth args 0, [e'])) e') (Marked.same_mark_as (EApp (List.nth args 0, [e'])) e')
with with
| A.ELit (A.LBool b), _ -> b | ELit (LBool b), _ -> b
| _ -> | _ ->
Errors.raise_spanned_error Errors.raise_spanned_error
(A.pos (List.nth args 0)) (Expr.pos (List.nth args 0))
"This predicate evaluated to something else than a boolean \ "This predicate evaluated to something else than a boolean \
(should not happen if the term was well-typed)") (should not happen if the term was well-typed)")
es) es)
| A.Binop _, ([ELit LEmptyError; _] | [_; ELit LEmptyError]) -> | Binop _, ([ELit LEmptyError; _] | [_; ELit LEmptyError]) ->
A.ELit LEmptyError ELit LEmptyError
| A.Unop (A.Minus KInt), [ELit (LInt i)] -> | Unop (Minus KInt), [ELit (LInt i)] ->
A.ELit (LInt Runtime.(integer_of_int 0 -! i)) ELit (LInt Runtime.(integer_of_int 0 -! i))
| A.Unop (A.Minus KRat), [ELit (LRat i)] -> | Unop (Minus KRat), [ELit (LRat i)] ->
A.ELit (LRat Runtime.(decimal_of_string "0" -& i)) ELit (LRat Runtime.(decimal_of_string "0" -& i))
| A.Unop (A.Minus KMoney), [ELit (LMoney i)] -> | Unop (Minus KMoney), [ELit (LMoney i)] ->
A.ELit (LMoney Runtime.(money_of_units_int 0 -$ i)) ELit (LMoney Runtime.(money_of_units_int 0 -$ i))
| A.Unop (A.Minus KDuration), [ELit (LDuration i)] -> | Unop (Minus KDuration), [ELit (LDuration i)] ->
A.ELit (LDuration Runtime.(~-^i)) ELit (LDuration Runtime.(~-^i))
| A.Unop A.Not, [ELit (LBool b)] -> A.ELit (LBool (not b)) | Unop Not, [ELit (LBool b)] -> ELit (LBool (not b))
| A.Unop A.Length, [EArray es] -> | Unop Length, [EArray es] ->
A.ELit (LInt (Runtime.integer_of_int (List.length es))) ELit (LInt (Runtime.integer_of_int (List.length es)))
| A.Unop A.GetDay, [ELit (LDate d)] -> | Unop GetDay, [ELit (LDate d)] ->
A.ELit (LInt Runtime.(day_of_month_of_date d)) ELit (LInt Runtime.(day_of_month_of_date d))
| A.Unop A.GetMonth, [ELit (LDate d)] -> | Unop GetMonth, [ELit (LDate d)] ->
A.ELit (LInt Runtime.(month_number_of_date d)) ELit (LInt Runtime.(month_number_of_date d))
| A.Unop A.GetYear, [ELit (LDate d)] -> A.ELit (LInt Runtime.(year_of_date d)) | Unop GetYear, [ELit (LDate d)] -> ELit (LInt Runtime.(year_of_date d))
| A.Unop A.FirstDayOfMonth, [ELit (LDate d)] -> | Unop FirstDayOfMonth, [ELit (LDate d)] ->
A.ELit (LDate Runtime.(first_day_of_month d)) ELit (LDate Runtime.(first_day_of_month d))
| A.Unop A.LastDayOfMonth, [ELit (LDate d)] -> | Unop LastDayOfMonth, [ELit (LDate d)] ->
A.ELit (LDate Runtime.(first_day_of_month d)) ELit (LDate Runtime.(first_day_of_month d))
| A.Unop A.IntToRat, [ELit (LInt i)] -> | Unop IntToRat, [ELit (LInt i)] ->
A.ELit (LRat Runtime.(decimal_of_integer i)) ELit (LRat Runtime.(decimal_of_integer i))
| A.Unop A.MoneyToRat, [ELit (LMoney i)] -> | Unop MoneyToRat, [ELit (LMoney i)] ->
A.ELit (LRat Runtime.(decimal_of_money i)) ELit (LRat Runtime.(decimal_of_money i))
| A.Unop A.RatToMoney, [ELit (LRat i)] -> | Unop RatToMoney, [ELit (LRat i)] ->
A.ELit (LMoney Runtime.(money_of_decimal i)) ELit (LMoney Runtime.(money_of_decimal i))
| A.Unop A.RoundMoney, [ELit (LMoney m)] -> | Unop RoundMoney, [ELit (LMoney m)] ->
A.ELit (LMoney Runtime.(money_round m)) ELit (LMoney Runtime.(money_round m))
| A.Unop A.RoundDecimal, [ELit (LRat m)] -> | Unop RoundDecimal, [ELit (LRat m)] ->
A.ELit (LRat Runtime.(decimal_round m)) ELit (LRat Runtime.(decimal_round m))
| A.Unop (A.Log (entry, infos)), [e'] -> | Unop (Log (entry, infos)), [e'] ->
if !Cli.trace_flag then ( if !Cli.trace_flag then (
match entry with match entry with
| VarDef _ -> | VarDef _ ->
@ -276,7 +277,7 @@ let rec evaluate_operator
Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Cli.log_format "%*s%a %a: %s" (!log_indent * 2) ""
Print.format_log_entry entry Print.format_uid_list infos Print.format_log_entry entry Print.format_uid_list infos
(match e' with (match e' with
| Ast.EAbs _ -> Cli.with_style [ANSITerminal.green] "<function>" | EAbs _ -> Cli.with_style [ANSITerminal.green] "<function>"
| _ -> | _ ->
let expr_str = let expr_str =
Format.asprintf "%a" Format.asprintf "%a"
@ -308,7 +309,7 @@ let rec evaluate_operator
entry Print.format_uid_list infos) entry Print.format_uid_list infos)
else (); else ();
e' e'
| A.Unop _, [ELit LEmptyError] -> A.ELit LEmptyError | Unop _, [ELit LEmptyError] -> ELit LEmptyError
| _ -> | _ ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
([Some "Operator:", pos] ([Some "Operator:", pos]
@ -318,16 +319,16 @@ let rec evaluate_operator
(Format.asprintf "Argument n°%d, value %a" (i + 1) (Format.asprintf "Argument n°%d, value %a" (i + 1)
(Print.format_expr ctx ~debug:true) (Print.format_expr ctx ~debug:true)
arg), arg),
A.pos arg )) Expr.pos arg ))
args) args)
"Operator applied to the wrong arguments\n\ "Operator applied to the wrong arguments\n\
(should not happen if the term was well-typed)" (should not happen if the term was well-typed)"
and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.marked_expr) : 'm Ast.marked_expr
= =
match Marked.unmark e with match Marked.unmark e with
| EVar _ -> | EVar _ ->
Errors.raise_spanned_error (A.pos e) Errors.raise_spanned_error (Expr.pos e)
"free variable found at evaluation (should not happen if term was \ "free variable found at evaluation (should not happen if term was \
well-typed" well-typed"
| EApp (e1, args) -> ( | EApp (e1, args) -> (
@ -339,22 +340,22 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr
evaluate_expr ctx evaluate_expr ctx
(Bindlib.msubst binder (Array.of_list (List.map Marked.unmark args))) (Bindlib.msubst binder (Array.of_list (List.map Marked.unmark args)))
else else
Errors.raise_spanned_error (A.pos e) Errors.raise_spanned_error (Expr.pos e)
"wrong function call, expected %d arguments, got %d" "wrong function call, expected %d arguments, got %d"
(Bindlib.mbinder_arity binder) (Bindlib.mbinder_arity binder)
(List.length args) (List.length args)
| EOp op -> Marked.same_mark_as (evaluate_operator ctx op (A.pos e) args) e | EOp op -> Marked.same_mark_as (evaluate_operator ctx op (Expr.pos e) args) e
| ELit LEmptyError -> Marked.same_mark_as (A.ELit LEmptyError) e | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (A.pos e) Errors.raise_spanned_error (Expr.pos e)
"function has not been reduced to a lambda at evaluation (should not \ "function has not been reduced to a lambda at evaluation (should not \
happen if the term was well-typed") happen if the term was well-typed")
| EAbs _ | ELit _ | EOp _ -> e (* these are values *) | EAbs _ | ELit _ | EOp _ -> e (* these are values *)
| ETuple (es, s) -> | ETuple (es, s) ->
let new_es = List.map (evaluate_expr ctx) es in let new_es = List.map (evaluate_expr ctx) es in
if List.exists is_empty_error new_es then if List.exists is_empty_error new_es then
Marked.same_mark_as (A.ELit LEmptyError) e Marked.same_mark_as (ELit LEmptyError) e
else Marked.same_mark_as (A.ETuple (new_es, s)) e else Marked.same_mark_as (ETuple (new_es, s)) e
| ETupleAccess (e1, n, s, _) -> ( | ETupleAccess (e1, n, s, _) -> (
let e1 = evaluate_expr ctx e1 in let e1 = evaluate_expr ctx e1 in
match Marked.unmark e1 with match Marked.unmark e1 with
@ -364,49 +365,49 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr
| Some s, Some s' when s = s' -> () | Some s, Some s' when s = s' -> ()
| _ -> | _ ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
[None, A.pos e; None, A.pos e1] [None, Expr.pos e; None, Expr.pos e1]
"Error during tuple access: not the same structs (should not happen \ "Error during tuple access: not the same structs (should not happen \
if the term was well-typed)"); if the term was well-typed)");
match List.nth_opt es n with match List.nth_opt es n with
| Some e' -> e' | Some e' -> e'
| None -> | None ->
Errors.raise_spanned_error (A.pos e1) Errors.raise_spanned_error (Expr.pos e1)
"The tuple has %d components but the %i-th element was requested \ "The tuple has %d components but the %i-th element was requested \
(should not happen if the term was well-type)" (should not happen if the term was well-type)"
(List.length es) n) (List.length es) n)
| ELit LEmptyError -> Marked.same_mark_as (A.ELit LEmptyError) e | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (A.pos e1) Errors.raise_spanned_error (Expr.pos e1)
"The expression %a should be a tuple with %d components but is not \ "The expression %a should be a tuple with %d components but is not \
(should not happen if the term was well-typed)" (should not happen if the term was well-typed)"
(Print.format_expr ctx ~debug:true) (Print.format_expr ctx ~debug:true)
e n) e n)
| EInj (e1, n, en, ts) -> | EInj (e1, n, en, ts) ->
let e1' = evaluate_expr ctx e1 in let e1' = evaluate_expr ctx e1 in
if is_empty_error e1' then Marked.same_mark_as (A.ELit LEmptyError) e if is_empty_error e1' then Marked.same_mark_as (ELit LEmptyError) e
else Marked.same_mark_as (A.EInj (e1', n, en, ts)) e else Marked.same_mark_as (EInj (e1', n, en, ts)) e
| EMatch (e1, es, e_name) -> ( | EMatch (e1, es, e_name) -> (
let e1 = evaluate_expr ctx e1 in let e1 = evaluate_expr ctx e1 in
match Marked.unmark e1 with match Marked.unmark e1 with
| A.EInj (e1, n, e_name', _) -> | EInj (e1, n, e_name', _) ->
if e_name <> e_name' then if e_name <> e_name' then
Errors.raise_multispanned_error Errors.raise_multispanned_error
[None, A.pos e; None, A.pos e1] [None, Expr.pos e; None, Expr.pos e1]
"Error during match: two different enums found (should not happend \ "Error during match: two different enums found (should not happend \
if the term was well-typed)"; if the term was well-typed)";
let es_n = let es_n =
match List.nth_opt es n with match List.nth_opt es n with
| Some es_n -> es_n | Some es_n -> es_n
| None -> | None ->
Errors.raise_spanned_error (A.pos e) Errors.raise_spanned_error (Expr.pos e)
"sum type index error (should not happend if the term was \ "sum type index error (should not happend if the term was \
well-typed)" well-typed)"
in in
let new_e = Marked.same_mark_as (A.EApp (es_n, [e1])) e in let new_e = Marked.same_mark_as (EApp (es_n, [e1])) e in
evaluate_expr ctx new_e evaluate_expr ctx new_e
| A.ELit A.LEmptyError -> Marked.same_mark_as (A.ELit A.LEmptyError) e | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (A.pos e1) Errors.raise_spanned_error (Expr.pos e1)
"Expected a term having a sum type as an argument to a match (should \ "Expected a term having a sum type as an argument to a match (should \
not happend if the term was well-typed") not happend if the term was well-typed")
| EDefault (exceptions, just, cons) -> ( | EDefault (exceptions, just, cons) -> (
@ -416,11 +417,11 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr
| 0 -> ( | 0 -> (
let just = evaluate_expr ctx just in let just = evaluate_expr ctx just in
match Marked.unmark just with match Marked.unmark just with
| ELit LEmptyError -> Marked.same_mark_as (A.ELit LEmptyError) e | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| ELit (LBool true) -> evaluate_expr ctx cons | ELit (LBool true) -> evaluate_expr ctx cons
| ELit (LBool false) -> Marked.same_mark_as (A.ELit LEmptyError) e | ELit (LBool false) -> Marked.same_mark_as (ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (A.pos e) Errors.raise_spanned_error (Expr.pos e)
"Default justification has not been reduced to a boolean at \ "Default justification has not been reduced to a boolean at \
evaluation (should not happen if the term was well-typed") evaluation (should not happen if the term was well-typed")
| 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions | 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions
@ -428,7 +429,7 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr
Errors.raise_multispanned_error Errors.raise_multispanned_error
(List.map (List.map
(fun except -> (fun except ->
Some "This consequence has a valid justification:", A.pos except) Some "This consequence has a valid justification:", Expr.pos except)
(List.filter (fun sub -> not (is_empty_error sub)) exceptions)) (List.filter (fun sub -> not (is_empty_error sub)) exceptions))
"There is a conflict between multiple valid consequences for assigning \ "There is a conflict between multiple valid consequences for assigning \
the same variable.") the same variable.")
@ -436,55 +437,55 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr
match Marked.unmark (evaluate_expr ctx cond) with match Marked.unmark (evaluate_expr ctx cond) with
| ELit (LBool true) -> evaluate_expr ctx et | ELit (LBool true) -> evaluate_expr ctx et
| ELit (LBool false) -> evaluate_expr ctx ef | ELit (LBool false) -> evaluate_expr ctx ef
| ELit LEmptyError -> Marked.same_mark_as (A.ELit LEmptyError) e | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (A.pos cond) Errors.raise_spanned_error (Expr.pos cond)
"Expected a boolean literal for the result of this condition (should \ "Expected a boolean literal for the result of this condition (should \
not happen if the term was well-typed)") not happen if the term was well-typed)")
| EArray es -> | EArray es ->
let new_es = List.map (evaluate_expr ctx) es in let new_es = List.map (evaluate_expr ctx) es in
if List.exists is_empty_error new_es then if List.exists is_empty_error new_es then
Marked.same_mark_as (A.ELit LEmptyError) e Marked.same_mark_as (ELit LEmptyError) e
else Marked.same_mark_as (A.EArray new_es) e else Marked.same_mark_as (EArray new_es) e
| ErrorOnEmpty e' -> | ErrorOnEmpty e' ->
let e' = evaluate_expr ctx e' in let e' = evaluate_expr ctx e' in
if Marked.unmark e' = A.ELit LEmptyError then if Marked.unmark e' = ELit LEmptyError then
Errors.raise_spanned_error (A.pos e') Errors.raise_spanned_error (Expr.pos e')
"This variable evaluated to an empty term (no rule that defined it \ "This variable evaluated to an empty term (no rule that defined it \
applied in this situation)" applied in this situation)"
else e' else e'
| EAssert e' -> ( | EAssert e' -> (
match Marked.unmark (evaluate_expr ctx e') with match Marked.unmark (evaluate_expr ctx e') with
| ELit (LBool true) -> Marked.same_mark_as (Ast.ELit LUnit) e' | ELit (LBool true) -> Marked.same_mark_as (ELit LUnit) e'
| ELit (LBool false) -> ( | ELit (LBool false) -> (
match Marked.unmark e' with match Marked.unmark e' with
| Ast.ErrorOnEmpty | ErrorOnEmpty
( EApp ( EApp
( (Ast.EOp (Binop op), _), ( (EOp (Binop op), _),
[((ELit _, _) as e1); ((ELit _, _) as e2)] ), [((ELit _, _) as e1); ((ELit _, _) as e2)] ),
_ ) _ )
| EApp | EApp
( (Ast.EOp (Ast.Unop (Ast.Log _)), _), ( (EOp (Unop (Log _)), _),
[ [
( Ast.EApp ( EApp
( (Ast.EOp (Binop op), _), ( (EOp (Binop op), _),
[((ELit _, _) as e1); ((ELit _, _) as e2)] ), [((ELit _, _) as e1); ((ELit _, _) as e2)] ),
_ ); _ );
] ) ] )
| EApp | EApp
((Ast.EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)]) ((EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)])
-> ->
Errors.raise_spanned_error (A.pos e') "Assertion failed: %a %a %a" Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
(Print.format_expr ctx ~debug:false) (Print.format_expr ctx ~debug:false)
e1 Print.format_binop op e1 Print.format_binop op
(Print.format_expr ctx ~debug:false) (Print.format_expr ctx ~debug:false)
e2 e2
| _ -> | _ ->
Cli.debug_format "%a" (Print.format_expr ctx) e'; Cli.debug_format "%a" (Print.format_expr ctx) e';
Errors.raise_spanned_error (A.pos e') "Assertion failed") Errors.raise_spanned_error (Expr.pos e') "Assertion failed")
| ELit LEmptyError -> Marked.same_mark_as (A.ELit LEmptyError) e | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (A.pos e') Errors.raise_spanned_error (Expr.pos e')
"Expected a boolean literal for the result of this assertion (should \ "Expected a boolean literal for the result of this assertion (should \
not happen if the term was well-typed)") not happen if the term was well-typed)")
@ -492,13 +493,13 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr
let interpret_program : let interpret_program :
'm. 'm.
Ast.decl_ctx -> decl_ctx ->
'm Ast.marked_expr -> 'm Ast.marked_expr ->
(Uid.MarkedString.info * 'm Ast.marked_expr) list = (Uid.MarkedString.info * 'm Ast.marked_expr) list =
fun (ctx : Ast.decl_ctx) (e : 'm Ast.marked_expr) : fun (ctx : decl_ctx) (e : 'm Ast.marked_expr) :
(Uid.MarkedString.info * 'm Ast.marked_expr) list -> (Uid.MarkedString.info * 'm Ast.marked_expr) list ->
match evaluate_expr ctx e with match evaluate_expr ctx e with
| Ast.EAbs (_, [((Ast.TTuple (taus, Some s_in), _) as targs)]), mark_e -> | EAbs (_, [((TTuple (taus, Some s_in), _) as targs)]), mark_e ->
begin begin
(* At this point, the interpreter seeks to execute the scope but does not (* At this point, the interpreter seeks to execute the scope but does not
have a way to retrieve input values from the command line. [taus] contain have a way to retrieve input values from the command line. [taus] contain
@ -509,9 +510,9 @@ let interpret_program :
List.map List.map
(fun ty -> (fun ty ->
match Marked.unmark ty with match Marked.unmark ty with
| A.TArrow ((A.TLit A.TUnit, _), ty_in) -> | TArrow ((TLit TUnit, _), ty_in) ->
Ast.empty_thunked_term Ast.empty_thunked_term
(A.map_mark (fun pos -> pos) (fun _ -> ty_in) mark_e) (Expr.map_mark (fun pos -> pos) (fun _ -> ty_in) mark_e)
| _ -> | _ ->
Errors.raise_spanned_error (Marked.get_mark ty) Errors.raise_spanned_error (Marked.get_mark ty)
"This scope needs input arguments to be executed. But the Catala \ "This scope needs input arguments to be executed. But the Catala \
@ -522,23 +523,23 @@ let interpret_program :
taus taus
in in
let to_interpret = let to_interpret =
( Ast.EApp ( EApp
( e, ( e,
[ [
( Ast.ETuple (application_term, Some s_in), ( ETuple (application_term, Some s_in),
let pos = let pos =
match application_term with match application_term with
| a :: _ -> A.pos a | a :: _ -> Expr.pos a
| [] -> Pos.no_pos | [] -> Pos.no_pos
in in
A.map_mark (fun _ -> pos) (fun _ -> targs) mark_e ); Expr.map_mark (fun _ -> pos) (fun _ -> targs) mark_e );
] ), ] ),
A.map_mark Expr.map_mark
(fun pos -> pos) (fun pos -> pos)
(fun ty -> (fun ty ->
match application_term, ty with match application_term, ty with
| [], t_out -> t_out | [], t_out -> t_out
| _ :: _, (A.TArrow (_, t_out), _) -> t_out | _ :: _, (TArrow (_, t_out), _) -> t_out
| _ :: _, (_, bad_pos) -> | _ :: _, (_, bad_pos) ->
Errors.raise_spanned_error bad_pos Errors.raise_spanned_error bad_pos
"@[<hv 2>(bug) Result of interpretation doesn't have the \ "@[<hv 2>(bug) Result of interpretation doesn't have the \
@ -547,19 +548,19 @@ let interpret_program :
mark_e ) mark_e )
in in
match Marked.unmark (evaluate_expr ctx to_interpret) with match Marked.unmark (evaluate_expr ctx to_interpret) with
| Ast.ETuple (args, Some s_out) -> | ETuple (args, Some s_out) ->
let s_out_fields = let s_out_fields =
List.map List.map
(fun (f, _) -> Ast.StructFieldName.get_info f) (fun (f, _) -> StructFieldName.get_info f)
(Ast.StructMap.find s_out ctx.ctx_structs) (StructMap.find s_out ctx.ctx_structs)
in in
List.map2 (fun arg var -> var, arg) args s_out_fields List.map2 (fun arg var -> var, arg) args s_out_fields
| _ -> | _ ->
Errors.raise_spanned_error (A.pos e) Errors.raise_spanned_error (Expr.pos e)
"The interpretation of a program should always yield a struct \ "The interpretation of a program should always yield a struct \
corresponding to the scope variables" corresponding to the scope variables"
end end
| _ -> | _ ->
Errors.raise_spanned_error (A.pos e) Errors.raise_spanned_error (Expr.pos e)
"The interpreter can only interpret terms starting with functions having \ "The interpreter can only interpret terms starting with functions having \
thunked arguments" thunked arguments"

View File

@ -17,12 +17,13 @@
(** Reference interpreter for the default calculus *) (** Reference interpreter for the default calculus *)
open Utils open Utils
open Shared_ast
val evaluate_expr : Ast.decl_ctx -> 'm Ast.marked_expr -> 'm Ast.marked_expr val evaluate_expr : decl_ctx -> 'm Ast.marked_expr -> 'm Ast.marked_expr
(** Evaluates an expression according to the semantics of the default calculus. *) (** Evaluates an expression according to the semantics of the default calculus. *)
val interpret_program : val interpret_program :
Ast.decl_ctx -> decl_ctx ->
'm Ast.marked_expr -> 'm Ast.marked_expr ->
(Uid.MarkedString.info * 'm Ast.marked_expr) list (Uid.MarkedString.info * 'm Ast.marked_expr) list
(** Interprets a program. This function expects an expression typed as a (** Interprets a program. This function expects an expression typed as a

View File

@ -15,6 +15,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Ast open Ast
type partial_evaluation_ctx = { type partial_evaluation_ctx = {
@ -82,7 +83,7 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm marked_expr) :
(fun arg arms -> (fun arg arms ->
match arg, arms with match arg, arms with
| (EInj (e1, i, e_name', _ts), _), _ | (EInj (e1, i, e_name', _ts), _), _
when Ast.EnumName.compare e_name e_name' = 0 -> when EnumName.compare e_name e_name' = 0 ->
(* iota reduction *) (* iota reduction *)
EApp (List.nth arms i, [e1]), pos EApp (List.nth arms i, [e1]), pos
| _ -> EMatch (arg, arms, e_name), pos) | _ -> EMatch (arg, arms, e_name), pos)
@ -252,4 +253,4 @@ let optimize_program (p : 'm program) : untyped program =
(program_map partial_evaluation (program_map partial_evaluation
{ var_values = Var.Map.empty; decl_ctx = p.decl_ctx } { var_values = Var.Map.empty; decl_ctx = p.decl_ctx }
p) p)
|> untype_program |> Expr.untype_program

View File

@ -17,6 +17,7 @@
(** Optimization passes for default calculus programs and expressions *) (** Optimization passes for default calculus programs and expressions *)
open Shared_ast
open Ast open Ast
val optimize_expr : decl_ctx -> 'm marked_expr -> 'm marked_expr Bindlib.box val optimize_expr : decl_ctx -> 'm marked_expr -> 'm marked_expr Bindlib.box

View File

@ -15,6 +15,7 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Ast open Ast
open String_common open String_common
@ -68,7 +69,7 @@ let format_enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) :
(Utils.Cli.format_with_style [ANSITerminal.magenta]) (Utils.Cli.format_with_style [ANSITerminal.magenta])
(Format.asprintf "%a" EnumConstructor.format_t c) (Format.asprintf "%a" EnumConstructor.format_t c)
let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ) : let rec format_typ (ctx : decl_ctx) (fmt : Format.formatter) (typ : typ) :
unit = unit =
let format_typ = format_typ ctx in let format_typ = format_typ ctx in
let format_typ_with_parens (fmt : Format.formatter) (t : typ) = let format_typ_with_parens (fmt : Format.formatter) (t : typ) =
@ -84,7 +85,7 @@ let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ) :
(fun fmt t -> Format.fprintf fmt "%a" format_typ t)) (fun fmt t -> Format.fprintf fmt "%a" format_typ t))
(List.map Marked.unmark ts) (List.map Marked.unmark ts)
| TTuple (_args, Some s) -> | TTuple (_args, Some s) ->
Format.fprintf fmt "@[<hov 2>%a%a%a%a@]" Ast.StructName.format_t s Format.fprintf fmt "@[<hov 2>%a%a%a%a@]" StructName.format_t s
format_punctuation "{" format_punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> ~pp_sep:(fun fmt () ->
@ -98,7 +99,7 @@ let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ) :
(StructMap.find s ctx.ctx_structs)) (StructMap.find s ctx.ctx_structs))
format_punctuation "}" format_punctuation "}"
| TEnum (_, e) -> | TEnum (_, e) ->
Format.fprintf fmt "@[<hov 2>%a%a%a%a@]" Ast.EnumName.format_t e Format.fprintf fmt "@[<hov 2>%a%a%a%a@]" EnumName.format_t e
format_punctuation "[" format_punctuation "["
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> ~pp_sep:(fun fmt () ->
@ -211,7 +212,7 @@ let format_var (fmt : Format.formatter) (v : 'm Ast.var) : unit =
let rec format_expr let rec format_expr
?(debug : bool = false) ?(debug : bool = false)
(ctx : Ast.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(e : 'm marked_expr) : unit = (e : 'm marked_expr) : unit =
let format_expr = format_expr ~debug ctx in let format_expr = format_expr ~debug ctx in
@ -231,15 +232,15 @@ let rec format_expr
es format_punctuation ")" es format_punctuation ")"
| ETuple (es, Some s) -> | ETuple (es, Some s) ->
Format.fprintf fmt "@[<hov 2>%a@ @[<hov 2>%a%a%a@]@]" Format.fprintf fmt "@[<hov 2>%a@ @[<hov 2>%a%a%a@]@]"
Ast.StructName.format_t s format_punctuation "{" StructName.format_t s format_punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> ~pp_sep:(fun fmt () ->
Format.fprintf fmt "%a@ " format_punctuation ";") Format.fprintf fmt "%a@ " format_punctuation ";")
(fun fmt (e, struct_field) -> (fun fmt (e, struct_field) ->
Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\"" Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\""
Ast.StructFieldName.format_t struct_field format_punctuation "\"" StructFieldName.format_t struct_field format_punctuation "\""
format_punctuation "=" format_expr e)) format_punctuation "=" format_expr e))
(List.combine es (List.map fst (Ast.StructMap.find s ctx.ctx_structs))) (List.combine es (List.map fst (StructMap.find s ctx.ctx_structs)))
format_punctuation "}" format_punctuation "}"
| EArray es -> | EArray es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" format_punctuation "[" Format.fprintf fmt "@[<hov 2>%a%a%a@]" format_punctuation "["
@ -253,12 +254,12 @@ let rec format_expr
Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n
| Some s -> | Some s ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_operator "." Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_operator "."
format_punctuation "\"" Ast.StructFieldName.format_t format_punctuation "\"" StructFieldName.format_t
(fst (List.nth (Ast.StructMap.find s ctx.ctx_structs) n)) (fst (List.nth (StructMap.find s ctx.ctx_structs) n))
format_punctuation "\"") format_punctuation "\"")
| EInj (e, n, en, _ts) -> | EInj (e, n, en, _ts) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_enum_constructor Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_enum_constructor
(fst (List.nth (Ast.EnumMap.find en ctx.ctx_enums) n)) (fst (List.nth (EnumMap.find en ctx.ctx_enums) n))
format_expr e format_expr e
| EMatch (e, es, e_name) -> | EMatch (e, es, e_name) ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" format_keyword Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" format_keyword
@ -268,7 +269,7 @@ let rec format_expr
(fun fmt (e, c) -> (fun fmt (e, c) ->
Format.fprintf fmt "@[<hov 2>%a %a%a@ %a@]" format_punctuation "|" Format.fprintf fmt "@[<hov 2>%a %a%a@ %a@]" format_punctuation "|"
format_enum_constructor c format_punctuation ":" format_expr e)) format_enum_constructor c format_punctuation ":" format_expr e))
(List.combine es (List.map fst (Ast.EnumMap.find e_name ctx.ctx_enums))) (List.combine es (List.map fst (EnumMap.find e_name ctx.ctx_enums)))
| ELit l -> format_lit fmt l | ELit l -> format_lit fmt l
| EApp ((EAbs (binder, taus), _), args) -> | EApp ((EAbs (binder, taus), _), args) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
@ -298,7 +299,7 @@ let rec format_expr
Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var x Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var x
format_punctuation ":" (format_typ ctx) tau format_punctuation ")")) format_punctuation ":" (format_typ ctx) tau format_punctuation ")"))
xs_tau format_punctuation "" format_expr body xs_tau format_punctuation "" format_expr body
| EApp ((EOp (Binop ((Ast.Map | Ast.Filter) as op)), _), [arg1; arg2]) -> | EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop op Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop op
format_with_parens arg1 format_with_parens arg2 format_with_parens arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) -> | EApp ((EOp (Binop op), _), [arg1; arg2]) ->
@ -347,13 +348,13 @@ let format_scope
?(debug : bool = false) ?(debug : bool = false)
(ctx : decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
((n, s) : Ast.ScopeName.t * ('m Ast.expr, 'm) scope_body) = ((n, s) : ScopeName.t * ('m Ast.expr, 'm) scope_body) =
Format.fprintf fmt "@[<hov 2>%a %a =@ %a@]" format_keyword "let" Format.fprintf fmt "@[<hov 2>%a %a =@ %a@]" format_keyword "let"
Ast.ScopeName.format_t n (format_expr ctx ~debug) ScopeName.format_t n (format_expr ctx ~debug)
(Bindlib.unbox (Bindlib.unbox
(Ast.build_whole_scope_expr ~make_abs:Ast.make_abs (Ast.build_whole_scope_expr ~make_abs:Ast.make_abs
~make_let_in:Ast.make_let_in ~box_expr:Ast.box_expr ctx s ~make_let_in:Ast.make_let_in ~box_expr:Expr.box ctx s
(Ast.map_mark (Expr.map_mark
(fun _ -> Marked.get_mark (Ast.ScopeName.get_info n)) (fun _ -> Marked.get_mark (ScopeName.get_info n))
(fun ty -> ty) (fun ty -> ty)
(Ast.get_scope_body_mark s)))) (Expr.get_scope_body_mark s))))

View File

@ -17,6 +17,7 @@
(** Printing functions for the default calculus AST *) (** Printing functions for the default calculus AST *)
open Utils open Utils
open Shared_ast
(** {1 Common syntax highlighting helpers}*) (** {1 Common syntax highlighting helpers}*)
@ -29,27 +30,27 @@ val format_lit_style : Format.formatter -> string -> unit
(** {1 Formatters} *) (** {1 Formatters} *)
val format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit val format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit
val format_enum_constructor : Format.formatter -> Ast.EnumConstructor.t -> unit val format_enum_constructor : Format.formatter -> EnumConstructor.t -> unit
val format_tlit : Format.formatter -> Ast.typ_lit -> unit val format_tlit : Format.formatter -> typ_lit -> unit
val format_typ : Ast.decl_ctx -> Format.formatter -> Ast.typ -> unit val format_typ : decl_ctx -> Format.formatter -> typ -> unit
val format_lit : Format.formatter -> Ast.lit -> unit val format_lit : Format.formatter -> Ast.lit -> unit
val format_op_kind : Format.formatter -> Ast.op_kind -> unit val format_op_kind : Format.formatter -> op_kind -> unit
val format_binop : Format.formatter -> Ast.binop -> unit val format_binop : Format.formatter -> binop -> unit
val format_ternop : Format.formatter -> Ast.ternop -> unit val format_ternop : Format.formatter -> ternop -> unit
val format_log_entry : Format.formatter -> Ast.log_entry -> unit val format_log_entry : Format.formatter -> log_entry -> unit
val format_unop : Format.formatter -> Ast.unop -> unit val format_unop : Format.formatter -> unop -> unit
val format_var : Format.formatter -> 'm Ast.var -> unit val format_var : Format.formatter -> 'm Ast.var -> unit
val format_expr : val format_expr :
?debug:bool (** [true] for debug printing *) -> ?debug:bool (** [true] for debug printing *) ->
Ast.decl_ctx -> decl_ctx ->
Format.formatter -> Format.formatter ->
'm Ast.marked_expr -> 'm Ast.marked_expr ->
unit unit
val format_scope : val format_scope :
?debug:bool (** [true] for debug printing *) -> ?debug:bool (** [true] for debug printing *) ->
Ast.decl_ctx -> decl_ctx ->
Format.formatter -> Format.formatter ->
Ast.ScopeName.t * ('m Ast.expr, 'm) Ast.scope_body -> ScopeName.t * ('m Ast.expr, 'm) scope_body ->
unit unit

View File

@ -71,7 +71,7 @@ let typ_needs_parens (t : typ Marked.pos UnionFind.elem) : bool =
match Marked.unmark t with TArrow _ | TArray _ -> true | _ -> false match Marked.unmark t with TArrow _ | TArray _ -> true | _ -> false
let rec format_typ let rec format_typ
(ctx : Ast.decl_ctx) (ctx : A.decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(typ : typ Marked.pos UnionFind.elem) : unit = (typ : typ Marked.pos UnionFind.elem) : unit =
let format_typ = format_typ ctx in let format_typ = format_typ ctx in
@ -90,8 +90,8 @@ let rec format_typ
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ")
(fun fmt t -> Format.fprintf fmt "%a" format_typ t)) (fun fmt t -> Format.fprintf fmt "%a" format_typ t))
ts ts
| TTuple (_ts, Some s) -> Format.fprintf fmt "%a" Ast.StructName.format_t s | TTuple (_ts, Some s) -> Format.fprintf fmt "%a" A.StructName.format_t s
| TEnum (_ts, e) -> Format.fprintf fmt "%a" Ast.EnumName.format_t e | TEnum (_ts, e) -> Format.fprintf fmt "%a" A.EnumName.format_t e
| TArrow (t1, t2) -> | TArrow (t1, t2) ->
Format.fprintf fmt "@[<hov 2>%a →@ %a@]" format_typ_with_parens t1 Format.fprintf fmt "@[<hov 2>%a →@ %a@]" format_typ_with_parens t1
format_typ t2 format_typ t2
@ -108,8 +108,8 @@ type mark = { pos : Pos.t; uf : unionfind_typ }
(** Raises an error if unification cannot be performed *) (** Raises an error if unification cannot be performed *)
let rec unify let rec unify
(ctx : Ast.decl_ctx) (ctx : A.decl_ctx)
(e : ('a, 'm A.mark) Ast.marked_gexpr) (* used for error context *) (e : ('a, 'm A.mark) A.marked_gexpr) (* used for error context *)
(t1 : typ Marked.pos UnionFind.elem) (t1 : typ Marked.pos UnionFind.elem)
(t2 : typ Marked.pos UnionFind.elem) : unit = (t2 : typ Marked.pos UnionFind.elem) : unit =
let unify = unify ctx in let unify = unify ctx in
@ -263,7 +263,7 @@ let op_type (op : A.operator Marked.pos) : typ Marked.pos UnionFind.elem =
type 'e env = ('e, typ Marked.pos UnionFind.elem) A.Var.Map.t type 'e env = ('e, typ Marked.pos UnionFind.elem) A.Var.Map.t
let add_pos e ty = Marked.mark (Ast.pos e) ty let add_pos e ty = Marked.mark (A.Expr.pos e) ty
let ty (_, { uf; _ }) = uf let ty (_, { uf; _ }) = uf
let ( let+ ) x f = Bindlib.box_apply f x let ( let+ ) x f = Bindlib.box_apply f x
let ( and+ ) x1 x2 = Bindlib.box_pair x1 x2 let ( and+ ) x1 x2 = Bindlib.box_pair x1 x2
@ -290,12 +290,12 @@ let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e)
(** Infers the most permissive type from an expression *) (** Infers the most permissive type from an expression *)
let rec typecheck_expr_bottom_up let rec typecheck_expr_bottom_up
(ctx : Ast.decl_ctx) (ctx : A.decl_ctx)
(env : 'm Ast.expr env) (env : 'm Ast.expr env)
(e : 'm Ast.marked_expr) : (A.dcalc, mark) A.marked_gexpr Bindlib.box = (e : 'm Ast.marked_expr) : (A.dcalc, mark) A.marked_gexpr Bindlib.box =
(* Cli.debug_format "Looking for type of %a" (Print.format_expr ~debug:true (* Cli.debug_format "Looking for type of %a" (Print.format_expr ~debug:true
ctx) e; *) ctx) e; *)
let pos_e = Ast.pos e in let pos_e = A.Expr.pos e in
let mark (e : (A.dcalc, mark) A.gexpr) uf = let mark (e : (A.dcalc, mark) A.gexpr) uf =
Marked.mark { uf; pos = pos_e } e Marked.mark { uf; pos = pos_e } e
in in
@ -308,7 +308,7 @@ let rec typecheck_expr_bottom_up
let+ v' = Bindlib.box_var (A.Var.translate v) in let+ v' = Bindlib.box_var (A.Var.translate v) in
mark v' t mark v' t
| None -> | None ->
Errors.raise_spanned_error (Ast.pos e) Errors.raise_spanned_error (A.Expr.pos e)
"Variable %s not found in the current context." (Bindlib.name_of v) "Variable %s not found in the current context." (Bindlib.name_of v)
end end
| A.ELit (LBool _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TBool) | A.ELit (LBool _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TBool)
@ -343,7 +343,7 @@ let rec typecheck_expr_bottom_up
match List.nth_opt ts' n with match List.nth_opt ts' n with
| Some ts_n -> ts_n | Some ts_n -> ts_n
| None -> | None ->
Errors.raise_spanned_error (Ast.pos e) Errors.raise_spanned_error (A.Expr.pos e)
"Expression should have a sum type with at least %d cases but only \ "Expression should have a sum type with at least %d cases but only \
has %d" has %d"
n (List.length ts') n (List.length ts')
@ -368,7 +368,7 @@ let rec typecheck_expr_bottom_up
mark (EMatch (e1', es', e_name)) t_ret mark (EMatch (e1', es', e_name)) t_ret
| A.EAbs (binder, taus) -> | A.EAbs (binder, taus) ->
if Bindlib.mbinder_arity binder <> List.length taus then if Bindlib.mbinder_arity binder <> List.length taus then
Errors.raise_spanned_error (Ast.pos e) Errors.raise_spanned_error (A.Expr.pos e)
"function has %d variables but was supplied %d types" "function has %d variables but was supplied %d types"
(Bindlib.mbinder_arity binder) (Bindlib.mbinder_arity binder)
(List.length taus) (List.length taus)
@ -446,13 +446,13 @@ let rec typecheck_expr_bottom_up
(** Checks whether the expression can be typed with the provided type *) (** Checks whether the expression can be typed with the provided type *)
and typecheck_expr_top_down and typecheck_expr_top_down
(ctx : Ast.decl_ctx) (ctx : A.decl_ctx)
(env : 'm Ast.expr env) (env : 'm Ast.expr env)
(tau : typ Marked.pos UnionFind.elem) (tau : typ Marked.pos UnionFind.elem)
(e : 'm Ast.marked_expr) : (A.dcalc, mark) A.marked_gexpr Bindlib.box = (e : 'm Ast.marked_expr) : (A.dcalc, mark) A.marked_gexpr Bindlib.box =
(* Cli.debug_format "Propagating type %a for expr %a" (format_typ ctx) tau (* Cli.debug_format "Propagating type %a for expr %a" (format_typ ctx) tau
(Print.format_expr ctx) e; *) (Print.format_expr ctx) e; *)
let pos_e = Ast.pos e in let pos_e = A.Expr.pos e in
let mark e = Marked.mark { uf = tau; pos = pos_e } e in let mark e = Marked.mark { uf = tau; pos = pos_e } e in
let unify_and_mark (e' : (A.dcalc, mark) A.gexpr) tau' = let unify_and_mark (e' : (A.dcalc, mark) A.gexpr) tau' =
(* This try...with was added because of (* This try...with was added because of
@ -502,7 +502,7 @@ and typecheck_expr_top_down
match List.nth_opt typs' n with match List.nth_opt typs' n with
| Some t1n -> unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n | Some t1n -> unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n
| None -> | None ->
Errors.raise_spanned_error (Ast.pos e1) Errors.raise_spanned_error (A.Expr.pos e1)
"Expression should have a tuple type with at least %d elements but \ "Expression should have a tuple type with at least %d elements but \
only has %d" only has %d"
n (List.length typs) n (List.length typs)
@ -513,7 +513,7 @@ and typecheck_expr_top_down
match List.nth_opt ts' n with match List.nth_opt ts' n with
| Some ts_n -> ts_n | Some ts_n -> ts_n
| None -> | None ->
Errors.raise_spanned_error (Ast.pos e) Errors.raise_spanned_error (A.Expr.pos e)
"Expression should have a sum type with at least %d cases but only \ "Expression should have a sum type with at least %d cases but only \
has %d" has %d"
n (List.length ts) n (List.length ts)
@ -544,7 +544,7 @@ and typecheck_expr_top_down
unify_and_mark (EMatch (e1', es', e_name)) t_ret unify_and_mark (EMatch (e1', es', e_name)) t_ret
| A.EAbs (binder, t_args) -> | A.EAbs (binder, t_args) ->
if Bindlib.mbinder_arity binder <> List.length t_args then if Bindlib.mbinder_arity binder <> List.length t_args then
Errors.raise_spanned_error (Ast.pos e) Errors.raise_spanned_error (A.Expr.pos e)
"function has %d variables but was supplied %d types" "function has %d variables but was supplied %d types"
(Bindlib.mbinder_arity binder) (Bindlib.mbinder_arity binder)
(List.length t_args) (List.length t_args)
@ -628,8 +628,8 @@ let wrap ctx f e =
let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos } let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos }
(* Infer the type of an expression *) (* Infer the type of an expression *)
let infer_types (ctx : Ast.decl_ctx) (e : 'm Ast.marked_expr) : let infer_types (ctx : A.decl_ctx) (e : 'm Ast.marked_expr) :
Ast.typed Ast.marked_expr Bindlib.box = A.typed Ast.marked_expr Bindlib.box =
A.Expr.map_marks ~f:get_ty_mark A.Expr.map_marks ~f:get_ty_mark
@@ Bindlib.unbox @@ Bindlib.unbox
@@ wrap ctx (typecheck_expr_bottom_up ctx A.Var.Map.empty) e @@ wrap ctx (typecheck_expr_bottom_up ctx A.Var.Map.empty) e
@ -637,11 +637,11 @@ let infer_types (ctx : Ast.decl_ctx) (e : 'm Ast.marked_expr) :
let infer_type (type m) ctx (e : m Ast.marked_expr) = let infer_type (type m) ctx (e : m Ast.marked_expr) =
match Marked.get_mark e with match Marked.get_mark e with
| A.Typed { ty; _ } -> ty | A.Typed { ty; _ } -> ty
| A.Untyped _ -> Ast.ty (Bindlib.unbox (infer_types ctx e)) | A.Untyped _ -> A.Expr.ty (Bindlib.unbox (infer_types ctx e))
(** Typechecks an expression given an expected type *) (** Typechecks an expression given an expected type *)
let check_type let check_type
(ctx : Ast.decl_ctx) (ctx : A.decl_ctx)
(e : 'm Ast.marked_expr) (e : 'm Ast.marked_expr)
(tau : A.typ Marked.pos) = (tau : A.typ Marked.pos) =
(* todo: consider using the already inferred type if ['m] = [typed] *) (* todo: consider using the already inferred type if ['m] = [typed] *)

View File

@ -17,18 +17,20 @@
(** Typing for the default calculus. Because of the error terms, we perform type (** Typing for the default calculus. Because of the error terms, we perform type
inference using the classical W algorithm with union-find unification. *) inference using the classical W algorithm with union-find unification. *)
open Shared_ast
val infer_types : val infer_types :
Ast.decl_ctx -> decl_ctx ->
Ast.untyped Ast.marked_expr -> untyped Ast.marked_expr ->
Ast.typed Ast.marked_expr Bindlib.box typed Ast.marked_expr Bindlib.box
(** Infers types everywhere on the given expression, and adds (or replaces) type (** Infers types everywhere on the given expression, and adds (or replaces) type
annotations on each node *) annotations on each node *)
val infer_type : Ast.decl_ctx -> 'm Ast.marked_expr -> Ast.typ Utils.Marked.pos val infer_type : decl_ctx -> 'm Ast.marked_expr -> typ Utils.Marked.pos
(** Gets the outer type of the given expression, using either the existing (** Gets the outer type of the given expression, using either the existing
annotations or inference *) annotations or inference *)
val check_type : val check_type :
Ast.decl_ctx -> 'm Ast.marked_expr -> Ast.typ Utils.Marked.pos -> unit decl_ctx -> 'm Ast.marked_expr -> typ Utils.Marked.pos -> unit
val infer_types_program : Ast.untyped Ast.program -> Ast.typed Ast.program val infer_types_program : untyped Ast.program -> typed Ast.program

View File

@ -17,6 +17,7 @@
(** Abstract syntax tree of the desugared representation *) (** Abstract syntax tree of the desugared representation *)
open Utils open Utils
open Shared_ast
(** {1 Names, Maps and Keys} *) (** {1 Names, Maps and Keys} *)
@ -99,7 +100,7 @@ module ScopeDefSet : Set.S with type elt = ScopeDef.t = Set.Make (ScopeDef)
type location = type location =
| ScopeVar of ScopeVar.t Marked.pos * StateName.t option | ScopeVar of ScopeVar.t Marked.pos * StateName.t option
| SubScopeVar of | SubScopeVar of
Scopelang.Ast.ScopeName.t ScopeName.t
* Scopelang.Ast.SubScopeName.t Marked.pos * Scopelang.Ast.SubScopeName.t Marked.pos
* ScopeVar.t Marked.pos * ScopeVar.t Marked.pos
@ -132,20 +133,20 @@ and expr =
| ELocation of location | ELocation of location
| EVar of expr Bindlib.var | EVar of expr Bindlib.var
| EStruct of | EStruct of
Scopelang.Ast.StructName.t * marked_expr Scopelang.Ast.StructFieldMap.t StructName.t * marked_expr Scopelang.Ast.StructFieldMap.t
| EStructAccess of | EStructAccess of
marked_expr * Scopelang.Ast.StructFieldName.t * Scopelang.Ast.StructName.t marked_expr * StructFieldName.t * StructName.t
| EEnumInj of | EEnumInj of
marked_expr * Scopelang.Ast.EnumConstructor.t * Scopelang.Ast.EnumName.t marked_expr * EnumConstructor.t * EnumName.t
| EMatch of | EMatch of
marked_expr marked_expr
* Scopelang.Ast.EnumName.t * EnumName.t
* marked_expr Scopelang.Ast.EnumConstructorMap.t * marked_expr Scopelang.Ast.EnumConstructorMap.t
| ELit of Dcalc.Ast.lit | ELit of Dcalc.Ast.lit
| EAbs of | EAbs of
(expr, marked_expr) Bindlib.mbinder * Scopelang.Ast.typ Marked.pos list (expr, marked_expr) Bindlib.mbinder * Scopelang.Ast.typ Marked.pos list
| EApp of marked_expr * marked_expr list | EApp of marked_expr * marked_expr list
| EOp of Dcalc.Ast.operator | EOp of operator
| EDefault of marked_expr list * marked_expr * marked_expr | EDefault of marked_expr list * marked_expr * marked_expr
| EIfThenElse of marked_expr * marked_expr * marked_expr | EIfThenElse of marked_expr * marked_expr * marked_expr
| EArray of marked_expr list | EArray of marked_expr list
@ -170,7 +171,7 @@ module Expr = struct
| ELocation _, ELocation _ -> 0 | ELocation _, ELocation _ -> 0
| EVar v1, EVar v2 -> Bindlib.compare_vars v1 v2 | EVar v1, EVar v2 -> Bindlib.compare_vars v1 v2
| EStruct (name1, field_map1), EStruct (name2, field_map2) -> ( | EStruct (name1, field_map1), EStruct (name2, field_map2) -> (
match Scopelang.Ast.StructName.compare name1 name2 with match StructName.compare name1 name2 with
| 0 -> | 0 ->
Scopelang.Ast.StructFieldMap.compare (Marked.compare compare) field_map1 Scopelang.Ast.StructFieldMap.compare (Marked.compare compare) field_map1
field_map2 field_map2
@ -179,21 +180,21 @@ module Expr = struct
EStructAccess ((e2, _), field_name2, struct_name2) ) -> ( EStructAccess ((e2, _), field_name2, struct_name2) ) -> (
match compare e1 e2 with match compare e1 e2 with
| 0 -> ( | 0 -> (
match Scopelang.Ast.StructFieldName.compare field_name1 field_name2 with match StructFieldName.compare field_name1 field_name2 with
| 0 -> Scopelang.Ast.StructName.compare struct_name1 struct_name2 | 0 -> StructName.compare struct_name1 struct_name2
| n -> n) | n -> n)
| n -> n) | n -> n)
| EEnumInj ((e1, _), cstr1, name1), EEnumInj ((e2, _), cstr2, name2) -> ( | EEnumInj ((e1, _), cstr1, name1), EEnumInj ((e2, _), cstr2, name2) -> (
match compare e1 e2 with match compare e1 e2 with
| 0 -> ( | 0 -> (
match Scopelang.Ast.EnumName.compare name1 name2 with match EnumName.compare name1 name2 with
| 0 -> Scopelang.Ast.EnumConstructor.compare cstr1 cstr2 | 0 -> EnumConstructor.compare cstr1 cstr2
| n -> n) | n -> n)
| n -> n) | n -> n)
| EMatch ((e1, _), name1, emap1), EMatch ((e2, _), name2, emap2) -> ( | EMatch ((e1, _), name1, emap1), EMatch ((e2, _), name2, emap2) -> (
match compare e1 e2 with match compare e1 e2 with
| 0 -> ( | 0 -> (
match Scopelang.Ast.EnumName.compare name1 name2 with match EnumName.compare name1 name2 with
| 0 -> | 0 ->
Scopelang.Ast.EnumConstructorMap.compare (Marked.compare compare) Scopelang.Ast.EnumConstructorMap.compare (Marked.compare compare)
emap1 emap2 emap1 emap2
@ -325,8 +326,8 @@ let empty_rule
(pos : Pos.t) (pos : Pos.t)
(have_parameter : Scopelang.Ast.typ Marked.pos option) : rule = (have_parameter : Scopelang.Ast.typ Marked.pos option) : rule =
{ {
rule_just = Bindlib.box (ELit (Dcalc.Ast.LBool false), pos); rule_just = Bindlib.box (ELit (LBool false), pos);
rule_cons = Bindlib.box (ELit Dcalc.Ast.LEmptyError, pos); rule_cons = Bindlib.box (ELit LEmptyError, pos);
rule_parameter = rule_parameter =
(match have_parameter with (match have_parameter with
| Some typ -> Some (Var.make "dummy", typ) | Some typ -> Some (Var.make "dummy", typ)
@ -340,8 +341,8 @@ let always_false_rule
(pos : Pos.t) (pos : Pos.t)
(have_parameter : Scopelang.Ast.typ Marked.pos option) : rule = (have_parameter : Scopelang.Ast.typ Marked.pos option) : rule =
{ {
rule_just = Bindlib.box (ELit (Dcalc.Ast.LBool true), pos); rule_just = Bindlib.box (ELit (LBool true), pos);
rule_cons = Bindlib.box (ELit (Dcalc.Ast.LBool false), pos); rule_cons = Bindlib.box (ELit (LBool false), pos);
rule_parameter = rule_parameter =
(match have_parameter with (match have_parameter with
| Some typ -> Some (Var.make "dummy", typ) | Some typ -> Some (Var.make "dummy", typ)
@ -370,8 +371,8 @@ type var_or_states = WholeVar | States of StateName.t list
type scope = { type scope = {
scope_vars : var_or_states ScopeVarMap.t; scope_vars : var_or_states ScopeVarMap.t;
scope_sub_scopes : Scopelang.Ast.ScopeName.t Scopelang.Ast.SubScopeMap.t; scope_sub_scopes : ScopeName.t Scopelang.Ast.SubScopeMap.t;
scope_uid : Scopelang.Ast.ScopeName.t; scope_uid : ScopeName.t;
scope_defs : scope_def ScopeDefMap.t; scope_defs : scope_def ScopeDefMap.t;
scope_assertions : assertion list; scope_assertions : assertion list;
scope_meta_assertions : meta_assertion list; scope_meta_assertions : meta_assertion list;

View File

@ -17,6 +17,7 @@
(** Abstract syntax tree of the desugared representation *) (** Abstract syntax tree of the desugared representation *)
open Utils open Utils
open Shared_ast
(** {1 Names, Maps and Keys} *) (** {1 Names, Maps and Keys} *)
@ -54,7 +55,7 @@ module ScopeDefSet : Set.S with type elt = ScopeDef.t
type location = type location =
| ScopeVar of ScopeVar.t Marked.pos * StateName.t option | ScopeVar of ScopeVar.t Marked.pos * StateName.t option
| SubScopeVar of | SubScopeVar of
Scopelang.Ast.ScopeName.t ScopeName.t
* Scopelang.Ast.SubScopeName.t Marked.pos * Scopelang.Ast.SubScopeName.t Marked.pos
* ScopeVar.t Marked.pos * ScopeVar.t Marked.pos
@ -68,20 +69,20 @@ and expr =
| ELocation of location | ELocation of location
| EVar of expr Bindlib.var | EVar of expr Bindlib.var
| EStruct of | EStruct of
Scopelang.Ast.StructName.t * marked_expr Scopelang.Ast.StructFieldMap.t StructName.t * marked_expr Scopelang.Ast.StructFieldMap.t
| EStructAccess of | EStructAccess of
marked_expr * Scopelang.Ast.StructFieldName.t * Scopelang.Ast.StructName.t marked_expr * StructFieldName.t * StructName.t
| EEnumInj of | EEnumInj of
marked_expr * Scopelang.Ast.EnumConstructor.t * Scopelang.Ast.EnumName.t marked_expr * EnumConstructor.t * EnumName.t
| EMatch of | EMatch of
marked_expr marked_expr
* Scopelang.Ast.EnumName.t * EnumName.t
* marked_expr Scopelang.Ast.EnumConstructorMap.t * marked_expr Scopelang.Ast.EnumConstructorMap.t
| ELit of Dcalc.Ast.lit | ELit of Dcalc.Ast.lit
| EAbs of | EAbs of
(expr, marked_expr) Bindlib.mbinder * Scopelang.Ast.typ Marked.pos list (expr, marked_expr) Bindlib.mbinder * Scopelang.Ast.typ Marked.pos list
| EApp of marked_expr * marked_expr list | EApp of marked_expr * marked_expr list
| EOp of Dcalc.Ast.operator | EOp of operator
| EDefault of marked_expr list * marked_expr * marked_expr | EDefault of marked_expr list * marked_expr * marked_expr
| EIfThenElse of marked_expr * marked_expr * marked_expr | EIfThenElse of marked_expr * marked_expr * marked_expr
| EArray of marked_expr list | EArray of marked_expr list
@ -166,8 +167,8 @@ type var_or_states = WholeVar | States of StateName.t list
type scope = { type scope = {
scope_vars : var_or_states ScopeVarMap.t; scope_vars : var_or_states ScopeVarMap.t;
scope_sub_scopes : Scopelang.Ast.ScopeName.t Scopelang.Ast.SubScopeMap.t; scope_sub_scopes : ScopeName.t Scopelang.Ast.SubScopeMap.t;
scope_uid : Scopelang.Ast.ScopeName.t; scope_uid : ScopeName.t;
scope_defs : scope_def ScopeDefMap.t; scope_defs : scope_def ScopeDefMap.t;
scope_assertions : assertion list; scope_assertions : assertion list;
scope_meta_assertions : meta_assertion list; scope_meta_assertions : meta_assertion list;

View File

@ -18,6 +18,7 @@
OCamlgraph} *) OCamlgraph} *)
open Utils open Utils
open Shared_ast
(** {1 Scope variables dependency graph} *) (** {1 Scope variables dependency graph} *)
@ -140,7 +141,7 @@ let check_for_cycle (scope : Ast.scope) (g : ScopeDependencies.t) : unit =
in in
Errors.raise_multispanned_error spans Errors.raise_multispanned_error spans
"Cyclic dependency detected between variables of scope %a!" "Cyclic dependency detected between variables of scope %a!"
Scopelang.Ast.ScopeName.format_t scope.scope_uid ScopeName.format_t scope.scope_uid
(** Builds the dependency graph of a particular scope *) (** Builds the dependency graph of a particular scope *)
let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =

View File

@ -17,6 +17,7 @@
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *) (** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
open Utils open Utils
open Shared_ast
(** {1 Expression translation}*) (** {1 Expression translation}*)
@ -31,11 +32,11 @@ type ctx = {
let tag_with_log_entry let tag_with_log_entry
(e : Scopelang.Ast.expr Marked.pos) (e : Scopelang.Ast.expr Marked.pos)
(l : Dcalc.Ast.log_entry) (l : log_entry)
(markings : Utils.Uid.MarkedString.info list) : (markings : Utils.Uid.MarkedString.info list) :
Scopelang.Ast.expr Marked.pos = Scopelang.Ast.expr Marked.pos =
( Scopelang.Ast.EApp ( Scopelang.Ast.EApp
( ( Scopelang.Ast.EOp (Dcalc.Ast.Unop (Dcalc.Ast.Log (l, markings))), ( ( Scopelang.Ast.EOp (Unop (Log (l, markings))),
Marked.get_mark e ), Marked.get_mark e ),
[e] ), [e] ),
Marked.get_mark e ) Marked.get_mark e )
@ -263,11 +264,11 @@ let rec rule_tree_to_expr
Scopelang.Ast.make_default ~pos:def_pos [] Scopelang.Ast.make_default ~pos:def_pos []
(* Here we insert the logging command that records when a (* Here we insert the logging command that records when a
decision is taken for the value of a variable. *) decision is taken for the value of a variable. *)
(tag_with_log_entry base_just Dcalc.Ast.PosRecordIfTrueBool []) (tag_with_log_entry base_just PosRecordIfTrueBool [])
base_cons) base_cons)
base_just_list base_cons_list) base_just_list base_cons_list)
(Scopelang.Ast.ELit (Dcalc.Ast.LBool false), def_pos) (Scopelang.Ast.ELit (LBool false), def_pos)
(Scopelang.Ast.ELit Dcalc.Ast.LEmptyError, def_pos)) (Scopelang.Ast.ELit LEmptyError, def_pos))
(Bindlib.box_list (translate_and_unbox_list base_just_list)) (Bindlib.box_list (translate_and_unbox_list base_just_list))
(Bindlib.box_list (translate_and_unbox_list base_cons_list)) (Bindlib.box_list (translate_and_unbox_list base_cons_list))
in in
@ -281,7 +282,7 @@ let rec rule_tree_to_expr
Bindlib.box_apply2 Bindlib.box_apply2
(fun exceptions default_containing_base_cases -> (fun exceptions default_containing_base_cases ->
Scopelang.Ast.make_default exceptions Scopelang.Ast.make_default exceptions
(Scopelang.Ast.ELit (Dcalc.Ast.LBool true), def_pos) (Scopelang.Ast.ELit (LBool true), def_pos)
default_containing_base_cases) default_containing_base_cases)
exceptions default_containing_base_cases exceptions default_containing_base_cases
in in

View File

@ -200,10 +200,10 @@ let driver source_file (options : Cli.options) : int =
(Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) (Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx)
( scope_uid, ( scope_uid,
Option.get Option.get
(Dcalc.Ast.fold_left_scope_defs ~init:None (Shared_ast.Expr.fold_left_scope_defs ~init:None
~f:(fun acc scope_def _ -> ~f:(fun acc scope_def _ ->
if if
Dcalc.Ast.ScopeName.compare scope_def.scope_name Shared_ast.ScopeName.compare scope_def.scope_name
scope_uid scope_uid
= 0 = 0
then Some scope_def.scope_body then Some scope_def.scope_body
@ -212,7 +212,7 @@ let driver source_file (options : Cli.options) : int =
else else
let prgrm_dcalc_expr = let prgrm_dcalc_expr =
Bindlib.unbox Bindlib.unbox
(Dcalc.Ast.build_whole_program_expr ~box_expr:Dcalc.Ast.box_expr (Dcalc.Ast.build_whole_program_expr ~box_expr:Shared_ast.Expr.box
~make_abs:Dcalc.Ast.make_abs ~make_abs:Dcalc.Ast.make_abs
~make_let_in:Dcalc.Ast.make_let_in prgm scope_uid) ~make_let_in:Dcalc.Ast.make_let_in prgm scope_uid)
in in
@ -242,7 +242,7 @@ let driver source_file (options : Cli.options) : int =
Cli.debug_print "Starting interpretation..."; Cli.debug_print "Starting interpretation...";
let prgrm_dcalc_expr = let prgrm_dcalc_expr =
Bindlib.unbox Bindlib.unbox
(Dcalc.Ast.build_whole_program_expr ~box_expr:Dcalc.Ast.box_expr (Dcalc.Ast.build_whole_program_expr ~box_expr:Shared_ast.Expr.box
~make_abs:Dcalc.Ast.make_abs ~make_abs:Dcalc.Ast.make_abs
~make_let_in:Dcalc.Ast.make_let_in prgm scope_uid) ~make_let_in:Dcalc.Ast.make_let_in prgm scope_uid)
in in
@ -285,7 +285,7 @@ let driver source_file (options : Cli.options) : int =
Cli.debug_print "Optimizing lambda calculus..."; Cli.debug_print "Optimizing lambda calculus...";
Lcalc.Optimizations.optimize_program prgm Lcalc.Optimizations.optimize_program prgm
end end
else Lcalc.Ast.untype_program prgm else Shared_ast.Expr.untype_program prgm
in in
let prgm = let prgm =
if options.closure_conversion then ( if options.closure_conversion then (
@ -305,10 +305,10 @@ let driver source_file (options : Cli.options) : int =
(Lcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) (Lcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx)
( scope_uid, ( scope_uid,
Option.get Option.get
(Dcalc.Ast.fold_left_scope_defs ~init:None (Shared_ast.Expr.fold_left_scope_defs ~init:None
~f:(fun acc scope_def _ -> ~f:(fun acc scope_def _ ->
if if
Dcalc.Ast.ScopeName.compare scope_def.scope_name Shared_ast.ScopeName.compare scope_def.scope_name
scope_uid scope_uid
= 0 = 0
then Some scope_def.scope_body then Some scope_def.scope_body
@ -318,7 +318,7 @@ let driver source_file (options : Cli.options) : int =
let prgrm_lcalc_expr = let prgrm_lcalc_expr =
Bindlib.unbox Bindlib.unbox
(Dcalc.Ast.build_whole_program_expr (Dcalc.Ast.build_whole_program_expr
~box_expr:Lcalc.Ast.box_expr ~make_abs:Lcalc.Ast.make_abs ~box_expr:Shared_ast.Expr.box ~make_abs:Lcalc.Ast.make_abs
~make_let_in:Lcalc.Ast.make_let_in prgm scope_uid) ~make_let_in:Lcalc.Ast.make_let_in prgm scope_uid)
in in
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"

View File

@ -23,80 +23,10 @@ type lit = lcalc glit
type 'm expr = (lcalc, 'm mark) gexpr type 'm expr = (lcalc, 'm mark) gexpr
and 'm marked_expr = (lcalc, 'm mark) marked_gexpr and 'm marked_expr = (lcalc, 'm mark) marked_gexpr
type 'm program = ('m expr, 'm) Dcalc.Ast.program_generic type 'm program = ('m expr, 'm) program_generic
type 'm var = 'm expr Var.t type 'm var = 'm expr Var.t
type 'm vars = 'm expr Var.vars type 'm vars = 'm expr Var.vars
(* <copy-paste from dcalc/ast.ml> *)
let evar v mark = Bindlib.box_apply (Marked.mark mark) (Bindlib.box_var v)
let etuple args s mark =
Bindlib.box_apply (fun args -> ETuple (args, s), mark) (Bindlib.box_list args)
let etupleaccess e1 i s typs mark =
Bindlib.box_apply (fun e1 -> ETupleAccess (e1, i, s, typs), mark) e1
let einj e1 i e_name typs mark =
Bindlib.box_apply (fun e1 -> EInj (e1, i, e_name, typs), mark) e1
let ematch arg arms e_name mark =
Bindlib.box_apply2
(fun arg arms -> EMatch (arg, arms, e_name), mark)
arg (Bindlib.box_list arms)
let earray args mark =
Bindlib.box_apply (fun args -> EArray args, mark) (Bindlib.box_list args)
let elit l mark = Bindlib.box (ELit l, mark)
let eabs binder typs mark =
Bindlib.box_apply (fun binder -> EAbs (binder, typs), mark) binder
let eapp e1 args mark =
Bindlib.box_apply2
(fun e1 args -> EApp (e1, args), mark)
e1 (Bindlib.box_list args)
let eassert e1 mark = Bindlib.box_apply (fun e1 -> EAssert e1, mark) e1
let eop op mark = Bindlib.box (EOp op, mark)
let eifthenelse e1 e2 e3 pos =
Bindlib.box_apply3 (fun e1 e2 e3 -> EIfThenElse (e1, e2, e3), pos) e1 e2 e3
(* </copy-paste> *)
let eraise e1 pos = Bindlib.box (ERaise e1, pos)
let ecatch e1 exn e2 pos =
Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2
let map_expr ctx ~f e = Expr.map ctx ~f e
let rec map_expr_top_down ~f e =
map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e)
let map_expr_marks ~f e =
map_expr_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
let untype_expr e =
map_expr_marks ~f:(fun m -> Untyped { pos = D.mark_pos m }) e
let untype_program prg =
{
prg with
D.scopes =
Bindlib.unbox
(D.map_exprs_in_scopes
~f:(fun e -> untype_expr e)
~varf:Var.translate prg.D.scopes);
}
(** See [Bindlib.box_term] documentation for why we are doing that. *)
let box_expr (e : 'm marked_expr) : 'm marked_expr Bindlib.box =
let rec id_t () e = map_expr () ~f:id_t e in
id_t () e
let make_var (x, mark) = let make_var (x, mark) =
Bindlib.box_apply (fun x -> x, mark) (Bindlib.box_var x) Bindlib.box_apply (fun x -> x, mark) (Bindlib.box_var x)
@ -110,7 +40,7 @@ let make_let_in x tau e1 e2 pos =
let m_e1 = Marked.get_mark (Bindlib.unbox e1) in let m_e1 = Marked.get_mark (Bindlib.unbox e1) in
let m_e2 = Marked.get_mark (Bindlib.unbox e2) in let m_e2 = Marked.get_mark (Bindlib.unbox e2) in
let m_abs = let m_abs =
D.map_mark2 Expr.map_mark2
(fun _ _ -> pos) (fun _ _ -> pos)
(fun m1 m2 -> TArrow (m1.ty, m2.ty), m1.pos) (fun m1 m2 -> TArrow (m1.ty, m2.ty), m1.pos)
m_e1 m_e2 m_e1 m_e2
@ -120,47 +50,47 @@ let make_let_in x tau e1 e2 pos =
let make_multiple_let_in xs taus e1s e2 pos = let make_multiple_let_in xs taus e1s e2 pos =
(* let m_e1s = List.map (fun e -> Marked.get_mark (Bindlib.unbox e)) e1s in *) (* let m_e1s = List.map (fun e -> Marked.get_mark (Bindlib.unbox e)) e1s in *)
let m_e1s = let m_e1s =
D.fold_marks List.hd Expr.fold_marks List.hd
(fun tys -> (fun tys ->
D.TTuple (List.map (fun t -> t.D.ty) tys, None), (List.hd tys).D.pos) TTuple (List.map (fun t -> t.ty) tys, None), (List.hd tys).pos)
(List.map (fun e -> Marked.get_mark (Bindlib.unbox e)) e1s) (List.map (fun e -> Marked.get_mark (Bindlib.unbox e)) e1s)
in in
let m_e2 = Marked.get_mark (Bindlib.unbox e2) in let m_e2 = Marked.get_mark (Bindlib.unbox e2) in
let m_abs = let m_abs =
D.map_mark2 Expr.map_mark2
(fun _ _ -> pos) (fun _ _ -> pos)
(fun m1 m2 -> Marked.mark pos (D.TArrow (m1.ty, m2.ty))) (fun m1 m2 -> Marked.mark pos (TArrow (m1.ty, m2.ty)))
m_e1s m_e2 m_e1s m_e2
in in
make_app (make_abs xs e2 taus m_abs) e1s m_e2 make_app (make_abs xs e2 taus m_abs) e1s m_e2
let ( let+ ) x f = Bindlib.box_apply f x let ( let+ ) x f = Bindlib.box_apply f x
let ( and+ ) x y = Bindlib.box_pair x y let ( and+ ) x y = Bindlib.box_pair x y
let option_enum : D.EnumName.t = D.EnumName.fresh ("eoption", Pos.no_pos) let option_enum : EnumName.t = EnumName.fresh ("eoption", Pos.no_pos)
let none_constr : D.EnumConstructor.t = let none_constr : EnumConstructor.t =
D.EnumConstructor.fresh ("ENone", Pos.no_pos) EnumConstructor.fresh ("ENone", Pos.no_pos)
let some_constr : D.EnumConstructor.t = let some_constr : EnumConstructor.t =
D.EnumConstructor.fresh ("ESome", Pos.no_pos) EnumConstructor.fresh ("ESome", Pos.no_pos)
let option_enum_config : (D.EnumConstructor.t * D.typ Marked.pos) list = let option_enum_config : (EnumConstructor.t * typ Marked.pos) list =
[none_constr, (D.TLit D.TUnit, Pos.no_pos); some_constr, (D.TAny, Pos.no_pos)] [none_constr, (TLit TUnit, Pos.no_pos); some_constr, (TAny, Pos.no_pos)]
(* FIXME: proper typing in all the constructors below *) (* FIXME: proper typing in all the constructors below *)
let make_none m = let make_none m =
let mark = Marked.mark m in let mark = Marked.mark m in
let tunit = D.TLit D.TUnit, D.mark_pos m in let tunit = TLit TUnit, Expr.mark_pos m in
Bindlib.box Bindlib.box
@@ mark @@ mark
@@ EInj @@ EInj
( Marked.mark ( Marked.mark
(D.map_mark (fun pos -> pos) (fun _ -> tunit) m) (Expr.map_mark (fun pos -> pos) (fun _ -> tunit) m)
(ELit LUnit), (ELit LUnit),
0, 0,
option_enum, option_enum,
[D.TLit D.TUnit, Pos.no_pos; D.TAny, Pos.no_pos] ) [TLit TUnit, Pos.no_pos; TAny, Pos.no_pos] )
let make_some e = let make_some e =
let m = Marked.get_mark @@ Bindlib.unbox e in let m = Marked.get_mark @@ Bindlib.unbox e in
@ -168,7 +98,7 @@ let make_some e =
let+ e in let+ e in
mark mark
@@ EInj @@ EInj
(e, 1, option_enum, [D.TLit D.TUnit, D.mark_pos m; D.TAny, D.mark_pos m]) (e, 1, option_enum, [TLit TUnit, Expr.mark_pos m; TAny, Expr.mark_pos m])
(** [make_matchopt_with_abs_arms arg e_none e_some] build an expression (** [make_matchopt_with_abs_arms arg e_none e_some] build an expression
[match arg with |None -> e_none | Some -> e_some] and requires e_some and [match arg with |None -> e_none | Some -> e_some] and requires e_some and
@ -187,7 +117,7 @@ let make_matchopt m v tau arg e_none e_some =
let x = Var.make "_" in let x = Var.make "_" in
make_matchopt_with_abs_arms arg make_matchopt_with_abs_arms arg
(make_abs (Array.of_list [x]) e_none [D.TLit D.TUnit, D.mark_pos m] m) (make_abs (Array.of_list [x]) e_none [TLit TUnit, Expr.mark_pos m] m)
(make_abs (Array.of_list [v]) e_some [tau] m) (make_abs (Array.of_list [v]) e_some [tau] m)
let handle_default = Var.make "handle_default" let handle_default = Var.make "handle_default"

View File

@ -15,7 +15,7 @@
the License. *) the License. *)
open Utils open Utils
include module type of Shared_ast open Shared_ast
(** Abstract syntax tree for the lambda calculus *) (** Abstract syntax tree for the lambda calculus *)
@ -26,114 +26,21 @@ type lit = lcalc glit
type 'm expr = (lcalc, 'm mark) gexpr type 'm expr = (lcalc, 'm mark) gexpr
and 'm marked_expr = (lcalc, 'm mark) marked_gexpr and 'm marked_expr = (lcalc, 'm mark) marked_gexpr
type 'm program = ('m expr, 'm) Dcalc.Ast.program_generic type 'm program = ('m expr, 'm) program_generic
(** {1 Variable helpers} *) (** {1 Variable helpers} *)
type 'm var = 'm expr Var.t type 'm var = 'm expr Var.t
type 'm vars = 'm expr Var.vars type 'm vars = 'm expr Var.vars
(** {2 Program traversal} *)
val map_expr :
'a ->
f:('a -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box) ->
('m1 expr, 'm2 mark) Marked.t ->
'm2 marked_expr Bindlib.box
(** See [Dcalc.Ast.map_expr] *)
val map_expr_top_down :
f:('m1 marked_expr -> ('m1 expr, 'm2 mark) Marked.t) ->
'm1 marked_expr ->
'm2 marked_expr Bindlib.box
(** See [Dcalc.Ast.map_expr_top_down] *)
val map_expr_marks :
f:('m1 mark -> 'm2 mark) -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box
(** See [Dcalc.Ast.map_expr_marks] *)
val untype_expr : 'm marked_expr -> Dcalc.Ast.untyped marked_expr Bindlib.box
val untype_program : 'm program -> Dcalc.Ast.untyped program
(** {1 Boxed constructors} *)
val evar : 'm expr Bindlib.var -> 'm mark -> 'm marked_expr Bindlib.box
val etuple :
'm marked_expr Bindlib.box list ->
Dcalc.Ast.StructName.t option ->
'm mark ->
'm marked_expr Bindlib.box
val etupleaccess :
'm marked_expr Bindlib.box ->
int ->
Dcalc.Ast.StructName.t option ->
Dcalc.Ast.typ Marked.pos list ->
'm mark ->
'm marked_expr Bindlib.box
val einj :
'm marked_expr Bindlib.box ->
int ->
Dcalc.Ast.EnumName.t ->
Dcalc.Ast.typ Marked.pos list ->
'm mark ->
'm marked_expr Bindlib.box
val ematch :
'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box list ->
Dcalc.Ast.EnumName.t ->
'm mark ->
'm marked_expr Bindlib.box
val earray :
'm marked_expr Bindlib.box list -> 'm mark -> 'm marked_expr Bindlib.box
val elit : lit -> 'm mark -> 'm marked_expr Bindlib.box
val eabs :
('m expr, 'm marked_expr) Bindlib.mbinder Bindlib.box ->
Dcalc.Ast.typ Marked.pos list ->
'm mark ->
'm marked_expr Bindlib.box
val eapp :
'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box list ->
'm mark ->
'm marked_expr Bindlib.box
val eassert :
'm marked_expr Bindlib.box -> 'm mark -> 'm marked_expr Bindlib.box
val eop : Dcalc.Ast.operator -> 'm mark -> 'm marked_expr Bindlib.box
val eifthenelse :
'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box ->
'm mark ->
'm marked_expr Bindlib.box
val ecatch :
'm marked_expr Bindlib.box ->
except ->
'm marked_expr Bindlib.box ->
'm mark ->
'm marked_expr Bindlib.box
val eraise : except -> 'm mark -> 'm marked_expr Bindlib.box
(** {1 Language terms construction}*) (** {1 Language terms construction}*)
val make_var : ('m var, 'm) Dcalc.Ast.marked -> 'm marked_expr Bindlib.box val make_var : ('m var, 'm) marked -> 'm marked_expr Bindlib.box
val make_abs : val make_abs :
'm vars -> 'm vars ->
'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box ->
Dcalc.Ast.typ Marked.pos list -> typ Marked.pos list ->
'm mark -> 'm mark ->
'm marked_expr Bindlib.box 'm marked_expr Bindlib.box
@ -145,7 +52,7 @@ val make_app :
val make_let_in : val make_let_in :
'm var -> 'm var ->
Dcalc.Ast.typ Marked.pos -> typ Marked.pos ->
'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box ->
Pos.t -> Pos.t ->
@ -153,18 +60,18 @@ val make_let_in :
val make_multiple_let_in : val make_multiple_let_in :
'm vars -> 'm vars ->
Dcalc.Ast.typ Marked.pos list -> typ Marked.pos list ->
'm marked_expr Bindlib.box list -> 'm marked_expr Bindlib.box list ->
'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box ->
Pos.t -> Pos.t ->
'm marked_expr Bindlib.box 'm marked_expr Bindlib.box
val option_enum : Dcalc.Ast.EnumName.t val option_enum : EnumName.t
val none_constr : Dcalc.Ast.EnumConstructor.t val none_constr : EnumConstructor.t
val some_constr : Dcalc.Ast.EnumConstructor.t val some_constr : EnumConstructor.t
val option_enum_config : val option_enum_config :
(Dcalc.Ast.EnumConstructor.t * Dcalc.Ast.typ Marked.pos) list (EnumConstructor.t * typ Marked.pos) list
val make_none : 'm mark -> 'm marked_expr Bindlib.box val make_none : 'm mark -> 'm marked_expr Bindlib.box
val make_some : 'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box val make_some : 'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box
@ -178,7 +85,7 @@ val make_matchopt_with_abs_arms :
val make_matchopt : val make_matchopt :
'm mark -> 'm mark ->
'm var -> 'm var ->
Dcalc.Ast.typ Marked.pos -> typ Marked.pos ->
'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box ->
@ -186,8 +93,6 @@ val make_matchopt :
(** [e' = make_matchopt'' pos v e e_none e_some] Builds the term corresponding (** [e' = make_matchopt'' pos v e e_none e_some] Builds the term corresponding
to [match e with | None -> fun () -> e_none |Some -> fun v -> e_some]. *) to [match e with | None -> fun () -> e_none |Some -> fun v -> e_some]. *)
val box_expr : 'm marked_expr -> 'm marked_expr Bindlib.box
(** {1 Special symbols} *) (** {1 Special symbols} *)
val handle_default : untyped var val handle_default : untyped var

View File

@ -14,8 +14,9 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Ast
open Utils open Utils
open Shared_ast
open Ast
module D = Dcalc.Ast module D = Dcalc.Ast
(** TODO: This version is not yet debugged and ought to be specialized when (** TODO: This version is not yet debugged and ought to be specialized when
@ -127,7 +128,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) :
| EAbs (binder, typs) -> | EAbs (binder, typs) ->
(* λ x.t *) (* λ x.t *)
let binder_mark = Marked.get_mark e in let binder_mark = Marked.get_mark e in
let binder_pos = D.mark_pos binder_mark in let binder_pos = Expr.mark_pos binder_mark in
(* Converting the closure. *) (* Converting the closure. *)
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
(* t *) (* t *)
@ -141,7 +142,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) :
let code_var = Var.make ctx.name_context in let code_var = Var.make ctx.name_context in
(* code *) (* code *)
let inner_c_var = Var.make "env" in let inner_c_var = Var.make "env" in
let any_ty = Dcalc.Ast.TAny, binder_pos in let any_ty = TAny, binder_pos in
let new_closure_body = let new_closure_body =
make_multiple_let_in make_multiple_let_in
(Array.of_list extra_vars_list) (Array.of_list extra_vars_list)
@ -158,17 +159,17 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) :
binder_mark )) binder_mark ))
(Bindlib.box_var inner_c_var)) (Bindlib.box_var inner_c_var))
extra_vars_list) extra_vars_list)
new_body (D.mark_pos binder_mark) new_body (Expr.mark_pos binder_mark)
in in
let new_closure = let new_closure =
make_abs make_abs
(Array.concat [Array.make 1 inner_c_var; vars]) (Array.concat [Array.make 1 inner_c_var; vars])
new_closure_body new_closure_body
((Dcalc.Ast.TAny, binder_pos) :: typs) ((TAny, binder_pos) :: typs)
(Marked.get_mark e) (Marked.get_mark e)
in in
( make_let_in code_var ( make_let_in code_var
(Dcalc.Ast.TAny, D.pos e) (TAny, Expr.pos e)
new_closure new_closure
(Bindlib.box_apply2 (Bindlib.box_apply2
(fun code_var extra_vars -> (fun code_var extra_vars ->
@ -184,7 +185,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) :
(List.map (List.map
(fun extra_var -> Bindlib.box_var extra_var) (fun extra_var -> Bindlib.box_var extra_var)
extra_vars_list))) extra_vars_list)))
(D.pos e), (Expr.pos e),
extra_vars ) extra_vars )
| EApp ((EOp op, pos_op), args) -> | EApp ((EOp op, pos_op), args) ->
(* This corresponds to an operator call, which we don't want to (* This corresponds to an operator call, which we don't want to
@ -227,7 +228,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) :
in in
let call_expr = let call_expr =
make_let_in code_var make_let_in code_var
(Dcalc.Ast.TAny, D.pos e) (TAny, Expr.pos e)
(Bindlib.box_apply (Bindlib.box_apply
(fun env_var -> (fun env_var ->
( ETupleAccess ( ETupleAccess
@ -242,9 +243,9 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) :
Marked.get_mark e )) Marked.get_mark e ))
(Bindlib.box_var code_var) (Bindlib.box_var env_var) (Bindlib.box_var code_var) (Bindlib.box_var env_var)
(Bindlib.box_list new_args)) (Bindlib.box_list new_args))
(D.pos e) (Expr.pos e)
in in
( make_let_in env_var (Dcalc.Ast.TAny, D.pos e) new_e1 call_expr (D.pos e), ( make_let_in env_var (TAny, Expr.pos e) new_e1 call_expr (Expr.pos e),
free_vars ) free_vars )
| EAssert e1 -> | EAssert e1 ->
let new_e1, free_vars = aux e1 in let new_e1, free_vars = aux e1 in
@ -278,7 +279,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) :
let closure_conversion (p : 'm program) : 'm program Bindlib.box = let closure_conversion (p : 'm program) : 'm program Bindlib.box =
let new_scopes, _ = let new_scopes, _ =
D.fold_left_scope_defs Expr.fold_left_scope_defs
~f:(fun (acc_new_scopes, global_vars) scope scope_var -> ~f:(fun (acc_new_scopes, global_vars) scope scope_var ->
(* [acc_new_scopes] represents what has been translated in the past, it (* [acc_new_scopes] represents what has been translated in the past, it
needs a continuation to attach the rest of the translated scopes. *) needs a continuation to attach the rest of the translated scopes. *)
@ -289,12 +290,12 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
let ctx = let ctx =
{ {
name_context = name_context =
Marked.unmark (Dcalc.Ast.ScopeName.get_info scope.scope_name); Marked.unmark (ScopeName.get_info scope.scope_name);
globally_bound_vars = global_vars; globally_bound_vars = global_vars;
} }
in in
let new_scope_lets = let new_scope_lets =
D.map_exprs_in_scope_lets Expr.map_exprs_in_scope_lets
~f:(closure_conversion_expr ctx) ~f:(closure_conversion_expr ctx)
~varf:(fun v -> v) ~varf:(fun v -> v)
scope_body_expr scope_body_expr
@ -306,7 +307,7 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
acc_new_scopes acc_new_scopes
(Bindlib.box_apply2 (Bindlib.box_apply2
(fun new_scope_body_expr next -> (fun new_scope_body_expr next ->
D.ScopeDef ScopeDef
{ {
scope with scope with
scope_body = scope_body =
@ -327,4 +328,4 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
in in
Bindlib.box_apply Bindlib.box_apply
(fun new_scopes -> { p with scopes = new_scopes }) (fun new_scopes -> { p with scopes = new_scopes })
(new_scopes (Bindlib.box D.Nil)) (new_scopes (Bindlib.box Nil))

View File

@ -25,26 +25,26 @@ type 'm ctx = ('m D.expr, 'm A.expr Var.t) Var.Map.t
let translate_lit (l : D.lit) : 'm A.expr = let translate_lit (l : D.lit) : 'm A.expr =
match l with match l with
| D.LBool l -> A.ELit (A.LBool l) | LBool l -> ELit (LBool l)
| D.LInt i -> A.ELit (A.LInt i) | LInt i -> ELit (LInt i)
| D.LRat r -> A.ELit (A.LRat r) | LRat r -> ELit (LRat r)
| D.LMoney m -> A.ELit (A.LMoney m) | LMoney m -> ELit (LMoney m)
| D.LUnit -> A.ELit A.LUnit | LUnit -> ELit LUnit
| D.LDate d -> A.ELit (A.LDate d) | LDate d -> ELit (LDate d)
| D.LDuration d -> A.ELit (A.LDuration d) | LDuration d -> ELit (LDuration d)
| D.LEmptyError -> A.ERaise A.EmptyError | LEmptyError -> ERaise EmptyError
let thunk_expr (e : 'm A.marked_expr Bindlib.box) (mark : 'm A.mark) : let thunk_expr (e : 'm A.marked_expr Bindlib.box) (mark : 'm mark) :
'm A.marked_expr Bindlib.box = 'm A.marked_expr Bindlib.box =
let dummy_var = Var.make "_" in let dummy_var = Var.make "_" in
A.make_abs [| dummy_var |] e [D.TAny, D.mark_pos mark] mark A.make_abs [| dummy_var |] e [TAny, Expr.mark_pos mark] mark
let rec translate_default let rec translate_default
(ctx : 'm ctx) (ctx : 'm ctx)
(exceptions : 'm D.marked_expr list) (exceptions : 'm D.marked_expr list)
(just : 'm D.marked_expr) (just : 'm D.marked_expr)
(cons : 'm D.marked_expr) (cons : 'm D.marked_expr)
(mark_default : 'm D.mark) : 'm A.marked_expr Bindlib.box = (mark_default : 'm mark) : 'm A.marked_expr Bindlib.box =
let exceptions = let exceptions =
List.map List.map
(fun except -> thunk_expr (translate_expr ctx except) mark_default) (fun except -> thunk_expr (translate_expr ctx except) mark_default)
@ -54,7 +54,7 @@ let rec translate_default
A.make_app A.make_app
(A.make_var (Var.translate A.handle_default, mark_default)) (A.make_var (Var.translate A.handle_default, mark_default))
[ [
A.earray exceptions mark_default; Expr.earray exceptions mark_default;
thunk_expr (translate_expr ctx just) mark_default; thunk_expr (translate_expr ctx just) mark_default;
thunk_expr (translate_expr ctx cons) mark_default; thunk_expr (translate_expr ctx cons) mark_default;
] ]
@ -65,34 +65,34 @@ let rec translate_default
and translate_expr (ctx : 'm ctx) (e : 'm D.marked_expr) : and translate_expr (ctx : 'm ctx) (e : 'm D.marked_expr) :
'm A.marked_expr Bindlib.box = 'm A.marked_expr Bindlib.box =
match Marked.unmark e with match Marked.unmark e with
| D.EVar v -> A.make_var (Var.Map.find v ctx, Marked.get_mark e) | EVar v -> A.make_var (Var.Map.find v ctx, Marked.get_mark e)
| D.ETuple (args, s) -> | ETuple (args, s) ->
A.etuple (List.map (translate_expr ctx) args) s (Marked.get_mark e) Expr.etuple (List.map (translate_expr ctx) args) s (Marked.get_mark e)
| D.ETupleAccess (e1, i, s, ts) -> | ETupleAccess (e1, i, s, ts) ->
A.etupleaccess (translate_expr ctx e1) i s ts (Marked.get_mark e) Expr.etupleaccess (translate_expr ctx e1) i s ts (Marked.get_mark e)
| D.EInj (e1, i, en, ts) -> | EInj (e1, i, en, ts) ->
A.einj (translate_expr ctx e1) i en ts (Marked.get_mark e) Expr.einj (translate_expr ctx e1) i en ts (Marked.get_mark e)
| D.EMatch (e1, cases, en) -> | EMatch (e1, cases, en) ->
A.ematch (translate_expr ctx e1) Expr.ematch (translate_expr ctx e1)
(List.map (translate_expr ctx) cases) (List.map (translate_expr ctx) cases)
en (Marked.get_mark e) en (Marked.get_mark e)
| D.EArray es -> | EArray es ->
A.earray (List.map (translate_expr ctx) es) (Marked.get_mark e) Expr.earray (List.map (translate_expr ctx) es) (Marked.get_mark e)
| D.ELit l -> Bindlib.box (Marked.same_mark_as (translate_lit l) e) | ELit l -> Bindlib.box (Marked.same_mark_as (translate_lit l) e)
| D.EOp op -> A.eop op (Marked.get_mark e) | EOp op -> Expr.eop op (Marked.get_mark e)
| D.EIfThenElse (e1, e2, e3) -> | EIfThenElse (e1, e2, e3) ->
A.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2) Expr.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2)
(translate_expr ctx e3) (Marked.get_mark e) (translate_expr ctx e3) (Marked.get_mark e)
| D.EAssert e1 -> A.eassert (translate_expr ctx e1) (Marked.get_mark e) | EAssert e1 -> Expr.eassert (translate_expr ctx e1) (Marked.get_mark e)
| D.ErrorOnEmpty arg -> | ErrorOnEmpty arg ->
A.ecatch (translate_expr ctx arg) A.EmptyError Expr.ecatch (translate_expr ctx arg) EmptyError
(Bindlib.box (Marked.same_mark_as (A.ERaise A.NoValueProvided) e)) (Bindlib.box (Marked.same_mark_as (ERaise NoValueProvided) e))
(Marked.get_mark e) (Marked.get_mark e)
| D.EApp (e1, args) -> | EApp (e1, args) ->
A.eapp (translate_expr ctx e1) Expr.eapp (translate_expr ctx e1)
(List.map (translate_expr ctx) args) (List.map (translate_expr ctx) args)
(Marked.get_mark e) (Marked.get_mark e)
| D.EAbs (binder, ts) -> | EAbs (binder, ts) ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let ctx, lc_vars = let ctx, lc_vars =
Array.fold_right Array.fold_right
@ -105,24 +105,24 @@ and translate_expr (ctx : 'm ctx) (e : 'm D.marked_expr) :
let new_body = translate_expr ctx body in let new_body = translate_expr ctx body in
let new_binder = Bindlib.bind_mvar lc_vars new_body in let new_binder = Bindlib.bind_mvar lc_vars new_body in
Bindlib.box_apply Bindlib.box_apply
(fun new_binder -> Marked.same_mark_as (A.EAbs (new_binder, ts)) e) (fun new_binder -> Marked.same_mark_as (EAbs (new_binder, ts)) e)
new_binder new_binder
| D.EDefault ([exn], just, cons) when !Cli.optimize_flag -> | EDefault ([exn], just, cons) when !Cli.optimize_flag ->
A.ecatch (translate_expr ctx exn) A.EmptyError Expr.ecatch (translate_expr ctx exn) EmptyError
(A.eifthenelse (translate_expr ctx just) (translate_expr ctx cons) (Expr.eifthenelse (translate_expr ctx just) (translate_expr ctx cons)
(Bindlib.box (Marked.same_mark_as (A.ERaise A.EmptyError) e)) (Bindlib.box (Marked.same_mark_as (ERaise EmptyError) e))
(Marked.get_mark e)) (Marked.get_mark e))
(Marked.get_mark e) (Marked.get_mark e)
| D.EDefault (exceptions, just, cons) -> | EDefault (exceptions, just, cons) ->
translate_default ctx exceptions just cons (Marked.get_mark e) translate_default ctx exceptions just cons (Marked.get_mark e)
let rec translate_scope_lets let rec translate_scope_lets
(decl_ctx : D.decl_ctx) (decl_ctx : decl_ctx)
(ctx : 'm ctx) (ctx : 'm ctx)
(scope_lets : ('m D.expr, 'm) D.scope_body_expr) : (scope_lets : ('m D.expr, 'm) scope_body_expr) :
('m A.expr, 'm) D.scope_body_expr Bindlib.box = ('m A.expr, 'm) scope_body_expr Bindlib.box =
match scope_lets with match scope_lets with
| Result e -> Bindlib.box_apply (fun e -> D.Result e) (translate_expr ctx e) | Result e -> Bindlib.box_apply (fun e -> Result e) (translate_expr ctx e)
| ScopeLet scope_let -> | ScopeLet scope_let ->
let old_scope_let_var, scope_let_next = let old_scope_let_var, scope_let_next =
Bindlib.unbind scope_let.scope_let_next Bindlib.unbind scope_let.scope_let_next
@ -134,26 +134,26 @@ let rec translate_scope_lets
let new_scope_next = Bindlib.bind_var new_scope_let_var new_scope_next in let new_scope_next = Bindlib.bind_var new_scope_let_var new_scope_next in
Bindlib.box_apply2 Bindlib.box_apply2
(fun new_scope_next new_scope_let_expr -> (fun new_scope_next new_scope_let_expr ->
D.ScopeLet ScopeLet
{ {
scope_let_typ = scope_let.D.scope_let_typ; scope_let_typ = scope_let.scope_let_typ;
scope_let_kind = scope_let.D.scope_let_kind; scope_let_kind = scope_let.scope_let_kind;
scope_let_pos = scope_let.D.scope_let_pos; scope_let_pos = scope_let.scope_let_pos;
scope_let_next = new_scope_next; scope_let_next = new_scope_next;
scope_let_expr = new_scope_let_expr; scope_let_expr = new_scope_let_expr;
}) })
new_scope_next new_scope_let_expr new_scope_next new_scope_let_expr
let rec translate_scopes let rec translate_scopes
(decl_ctx : D.decl_ctx) (decl_ctx : decl_ctx)
(ctx : 'm ctx) (ctx : 'm ctx)
(scopes : ('m D.expr, 'm) D.scopes) : ('m A.expr, 'm) D.scopes Bindlib.box = (scopes : ('m D.expr, 'm) scopes) : ('m A.expr, 'm) scopes Bindlib.box =
match scopes with match scopes with
| Nil -> Bindlib.box D.Nil | Nil -> Bindlib.box Nil
| ScopeDef scope_def -> | ScopeDef scope_def ->
let old_scope_var, scope_next = Bindlib.unbind scope_def.scope_next in let old_scope_var, scope_next = Bindlib.unbind scope_def.scope_next in
let new_scope_var = let new_scope_var =
Var.make (Marked.unmark (D.ScopeName.get_info scope_def.scope_name)) Var.make (Marked.unmark (ScopeName.get_info scope_def.scope_name))
in in
let old_scope_input_var, scope_body_expr = let old_scope_input_var, scope_body_expr =
Bindlib.unbind scope_def.scope_body.scope_body_expr Bindlib.unbind scope_def.scope_body.scope_body_expr
@ -166,11 +166,11 @@ let rec translate_scopes
let new_scope_body_expr = let new_scope_body_expr =
Bindlib.bind_var new_scope_input_var new_scope_body_expr Bindlib.bind_var new_scope_input_var new_scope_body_expr
in in
let new_scope : ('m A.expr, 'm) D.scope_body Bindlib.box = let new_scope : ('m A.expr, 'm) scope_body Bindlib.box =
Bindlib.box_apply Bindlib.box_apply
(fun new_scope_body_expr -> (fun new_scope_body_expr ->
{ {
D.scope_body_input_struct = scope_body_input_struct =
scope_def.scope_body.scope_body_input_struct; scope_def.scope_body.scope_body_input_struct;
scope_body_output_struct = scope_body_output_struct =
scope_def.scope_body.scope_body_output_struct; scope_def.scope_body.scope_body_output_struct;
@ -185,7 +185,7 @@ let rec translate_scopes
in in
Bindlib.box_apply2 Bindlib.box_apply2
(fun new_scope scope_next -> (fun new_scope scope_next ->
D.ScopeDef ScopeDef
{ {
scope_name = scope_def.scope_name; scope_name = scope_def.scope_name;
scope_body = new_scope; scope_body = new_scope;

View File

@ -61,7 +61,7 @@ let pp_info (fmt : Format.formatter) (info : 'm info) =
info.is_pure info.is_pure
type 'm ctx = { type 'm ctx = {
decl_ctx : D.decl_ctx; decl_ctx : decl_ctx;
vars : ('m D.expr, 'm info) Var.Map.t; vars : ('m D.expr, 'm info) Var.Map.t;
(** information context about variables in the current scope *) (** information context about variables in the current scope *)
} }
@ -95,7 +95,7 @@ let find ?(info : string = "none") (n : 'm D.var) (ctx : 'm ctx) : 'm info =
var, creating a unique corresponding variable in Lcalc, with the var, creating a unique corresponding variable in Lcalc, with the
corresponding expression, and the boolean is_pure. It is usefull for corresponding expression, and the boolean is_pure. It is usefull for
debuging purposes as it printing each of the Dcalc/Lcalc variable pairs. *) debuging purposes as it printing each of the Dcalc/Lcalc variable pairs. *)
let add_var (mark : 'm D.mark) (var : 'm D.var) (is_pure : bool) (ctx : 'm ctx) let add_var (mark : 'm mark) (var : 'm D.var) (is_pure : bool) (ctx : 'm ctx)
: 'm ctx = : 'm ctx =
let new_var = Var.make (Bindlib.name_of var) in let new_var = Var.make (Bindlib.name_of var) in
let expr = A.make_var (new_var, mark) in let expr = A.make_var (new_var, mark) in
@ -115,33 +115,33 @@ let add_var (mark : 'm D.mark) (var : 'm D.var) (is_pure : bool) (ctx : 'm ctx)
Since positions where there is thunked expressions is exactly where we will Since positions where there is thunked expressions is exactly where we will
put option expressions. Hence, the transformation simply reduce [unit -> 'a] put option expressions. Hence, the transformation simply reduce [unit -> 'a]
into ['a option] recursivly. There is no polymorphism inside catala. *) into ['a option] recursivly. There is no polymorphism inside catala. *)
let rec translate_typ (tau : D.typ Marked.pos) : D.typ Marked.pos = let rec translate_typ (tau : typ Marked.pos) : typ Marked.pos =
(Fun.flip Marked.same_mark_as) (Fun.flip Marked.same_mark_as)
tau tau
begin begin
match Marked.unmark tau with match Marked.unmark tau with
| D.TLit l -> D.TLit l | TLit l -> TLit l
| D.TTuple (ts, s) -> D.TTuple (List.map translate_typ ts, s) | TTuple (ts, s) -> TTuple (List.map translate_typ ts, s)
| D.TEnum (ts, en) -> D.TEnum (List.map translate_typ ts, en) | TEnum (ts, en) -> TEnum (List.map translate_typ ts, en)
| D.TAny -> D.TAny | TAny -> TAny
| D.TArray ts -> D.TArray (translate_typ ts) | TArray ts -> TArray (translate_typ ts)
(* catala is not polymorphic *) (* catala is not polymorphic *)
| D.TArrow ((D.TLit D.TUnit, pos_unit), t2) -> | TArrow ((TLit TUnit, pos_unit), t2) ->
D.TEnum ([D.TLit D.TUnit, pos_unit; translate_typ t2], A.option_enum) TEnum ([TLit TUnit, pos_unit; translate_typ t2], A.option_enum)
(* D.TAny *) (* TAny *)
| D.TArrow (t1, t2) -> D.TArrow (translate_typ t1, translate_typ t2) | TArrow (t1, t2) -> TArrow (translate_typ t1, translate_typ t2)
end end
let translate_lit (l : D.lit) (pos : Pos.t) : A.lit = let translate_lit (l : D.lit) (pos : Pos.t) : A.lit =
match l with match l with
| D.LBool l -> A.LBool l | LBool l -> LBool l
| D.LInt i -> A.LInt i | LInt i -> LInt i
| D.LRat r -> A.LRat r | LRat r -> LRat r
| D.LMoney m -> A.LMoney m | LMoney m -> LMoney m
| D.LUnit -> A.LUnit | LUnit -> LUnit
| D.LDate d -> A.LDate d | LDate d -> LDate d
| D.LDuration d -> A.LDuration d | LDuration d -> LDuration d
| D.LEmptyError -> | LEmptyError ->
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Internal Error: An empty error was found in a place that shouldn't be \ "Internal Error: An empty error was found in a place that shouldn't be \
possible." possible."
@ -171,7 +171,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
(* empty-producing/using terms. We hoist those. (D.EVar in some cases, (* empty-producing/using terms. We hoist those. (D.EVar in some cases,
EApp(D.EVar _, [ELit LUnit]), EDefault _, ELit LEmptyDefault) I'm unsure EApp(D.EVar _, [ELit LUnit]), EDefault _, ELit LEmptyDefault) I'm unsure
about assert. *) about assert. *)
| D.EVar v -> | EVar v ->
(* todo: for now, every unpure (such that [is_pure] is [false] in the (* todo: for now, every unpure (such that [is_pure] is [false] in the
current context) is thunked, hence matched in the next case. This current context) is thunked, hence matched in the next case. This
assumption can change in the future, and this case is here for this assumption can change in the future, and this case is here for this
@ -183,20 +183,20 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
Print.format_var v'; *) Print.format_var v'; *)
A.make_var (v', pos), Var.Map.singleton v' e A.make_var (v', pos), Var.Map.singleton v' e
else (find ~info:"should never happend" v ctx).expr, Var.Map.empty else (find ~info:"should never happend" v ctx).expr, Var.Map.empty
| D.EApp ((D.EVar v, p), [(D.ELit D.LUnit, _)]) -> | EApp ((EVar v, p), [(ELit LUnit, _)]) ->
if not (find ~info:"search for a variable" v ctx).is_pure then if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = Var.make (Bindlib.name_of v) in let v' = Var.make (Bindlib.name_of v) in
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
created a variable %a to replace it" Dcalc.Print.format_var v created a variable %a to replace it" Dcalc.Print.format_var v
Print.format_var v'; *) Print.format_var v'; *)
A.make_var (v', pos), Var.Map.singleton v' (D.EVar v, p) A.make_var (v', pos), Var.Map.singleton v' (EVar v, p)
else else
Errors.raise_spanned_error (D.pos e) Errors.raise_spanned_error (Expr.pos e)
"Internal error: an pure variable was found in an unpure environment." "Internal error: an pure variable was found in an unpure environment."
| D.EDefault (_exceptions, _just, _cons) -> | EDefault (_exceptions, _just, _cons) ->
let v' = Var.make "default_term" in let v' = Var.make "default_term" in
A.make_var (v', pos), Var.Map.singleton v' e A.make_var (v', pos), Var.Map.singleton v' e
| D.ELit D.LEmptyError -> | ELit LEmptyError ->
let v' = Var.make "empty_litteral" in let v' = Var.make "empty_litteral" in
A.make_var (v', pos), Var.Map.singleton v' e A.make_var (v', pos), Var.Map.singleton v' e
(* This one is a very special case. It transform an unpure expression (* This one is a very special case. It transform an unpure expression
@ -210,29 +210,29 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
( A.make_matchopt_with_abs_arms arg' ( A.make_matchopt_with_abs_arms arg'
(A.make_abs [| silent_var |] (A.make_abs [| silent_var |]
(Bindlib.box (A.ERaise A.NoValueProvided, pos)) (Bindlib.box (ERaise NoValueProvided, pos))
[D.TAny, D.pos e] [TAny, Expr.pos e]
pos) pos)
(A.make_abs [| x |] (A.make_var (x, pos)) [D.TAny, D.pos e] pos), (A.make_abs [| x |] (A.make_var (x, pos)) [TAny, Expr.pos e] pos),
Var.Map.empty ) Var.Map.empty )
(* pure terms *) (* pure terms *)
| D.ELit l -> A.elit (translate_lit l (D.pos e)) pos, Var.Map.empty | ELit l -> Expr.elit (translate_lit l (Expr.pos e)) pos, Var.Map.empty
| D.EIfThenElse (e1, e2, e3) -> | EIfThenElse (e1, e2, e3) ->
let e1', h1 = translate_and_hoist ctx e1 in let e1', h1 = translate_and_hoist ctx e1 in
let e2', h2 = translate_and_hoist ctx e2 in let e2', h2 = translate_and_hoist ctx e2 in
let e3', h3 = translate_and_hoist ctx e3 in let e3', h3 = translate_and_hoist ctx e3 in
let e' = A.eifthenelse e1' e2' e3' pos in let e' = Expr.eifthenelse e1' e2' e3' pos in
(*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3' = (*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3' =
e3' in (A.EIfThenElse (e1', e2', e3'), pos) in *) e3' in (A.EIfThenElse (e1', e2', e3'), pos) in *)
e', disjoint_union_maps (D.pos e) [h1; h2; h3] e', disjoint_union_maps (Expr.pos e) [h1; h2; h3]
| D.EAssert e1 -> | EAssert e1 ->
(* same behavior as in the ICFP paper: if e1 is empty, then no error is (* same behavior as in the ICFP paper: if e1 is empty, then no error is
raised. *) raised. *)
let e1', h1 = translate_and_hoist ctx e1 in let e1', h1 = translate_and_hoist ctx e1 in
A.eassert e1' pos, h1 Expr.eassert e1' pos, h1
| D.EAbs (binder, ts) -> | EAbs (binder, ts) ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let ctx, lc_vars = let ctx, lc_vars =
ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) -> ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) ->
@ -254,7 +254,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
let new_binder = Bindlib.bind_mvar lc_vars new_body in let new_binder = Bindlib.bind_mvar lc_vars new_body in
( Bindlib.box_apply ( Bindlib.box_apply
(fun new_binder -> A.EAbs (new_binder, List.map translate_typ ts), pos) (fun new_binder -> EAbs (new_binder, List.map translate_typ ts), pos)
new_binder, new_binder,
hoists ) hoists )
| EApp (e1, args) -> | EApp (e1, args) ->
@ -263,23 +263,23 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
args |> List.map (translate_and_hoist ctx) |> List.split args |> List.map (translate_and_hoist ctx) |> List.split
in in
let hoists = disjoint_union_maps (D.pos e) (h1 :: h_args) in let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_args) in
let e' = A.eapp e1' args' pos in let e' = Expr.eapp e1' args' pos in
e', hoists e', hoists
| ETuple (args, s) -> | ETuple (args, s) ->
let args', h_args = let args', h_args =
args |> List.map (translate_and_hoist ctx) |> List.split args |> List.map (translate_and_hoist ctx) |> List.split
in in
let hoists = disjoint_union_maps (D.pos e) h_args in let hoists = disjoint_union_maps (Expr.pos e) h_args in
A.etuple args' s pos, hoists Expr.etuple args' s pos, hoists
| ETupleAccess (e1, i, s, ts) -> | ETupleAccess (e1, i, s, ts) ->
let e1', hoists = translate_and_hoist ctx e1 in let e1', hoists = translate_and_hoist ctx e1 in
let e1' = A.etupleaccess e1' i s ts pos in let e1' = Expr.etupleaccess e1' i s ts pos in
e1', hoists e1', hoists
| EInj (e1, i, en, ts) -> | EInj (e1, i, en, ts) ->
let e1', hoists = translate_and_hoist ctx e1 in let e1', hoists = translate_and_hoist ctx e1 in
let e1' = A.einj e1' i en ts pos in let e1' = Expr.einj e1' i en ts pos in
e1', hoists e1', hoists
| EMatch (e1, cases, en) -> | EMatch (e1, cases, en) ->
let e1', h1 = translate_and_hoist ctx e1 in let e1', h1 = translate_and_hoist ctx e1 in
@ -287,14 +287,14 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
cases |> List.map (translate_and_hoist ctx) |> List.split cases |> List.map (translate_and_hoist ctx) |> List.split
in in
let hoists = disjoint_union_maps (D.pos e) (h1 :: h_cases) in let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_cases) in
let e' = A.ematch e1' cases' en pos in let e' = Expr.ematch e1' cases' en pos in
e', hoists e', hoists
| EArray es -> | EArray es ->
let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in
A.earray es' pos, disjoint_union_maps (D.pos e) hoists Expr.earray es' pos, disjoint_union_maps (Expr.pos e) hoists
| EOp op -> Bindlib.box (A.EOp op, pos), Var.Map.empty | EOp op -> Bindlib.box (EOp op, pos), Var.Map.empty
and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr) and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
: 'm A.marked_expr Bindlib.box = : 'm A.marked_expr Bindlib.box =
@ -315,8 +315,8 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
match hoist with match hoist with
(* Here we have to handle only the cases appearing in hoists, as defined (* Here we have to handle only the cases appearing in hoists, as defined
the [translate_and_hoist] function. *) the [translate_and_hoist] function. *)
| D.EVar v -> (find ~info:"should never happend" v ctx).expr | EVar v -> (find ~info:"should never happend" v ctx).expr
| D.EDefault (excep, just, cons) -> | EDefault (excep, just, cons) ->
let excep' = List.map (translate_expr ctx) excep in let excep' = List.map (translate_expr ctx) excep in
let just' = translate_expr ctx just in let just' = translate_expr ctx just in
let cons' = translate_expr ctx cons in let cons' = translate_expr ctx cons in
@ -325,14 +325,14 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
(A.make_var (Var.translate A.handle_default_opt, mark_hoist)) (A.make_var (Var.translate A.handle_default_opt, mark_hoist))
[ [
Bindlib.box_apply Bindlib.box_apply
(fun excep' -> A.EArray excep', mark_hoist) (fun excep' -> EArray excep', mark_hoist)
(Bindlib.box_list excep'); (Bindlib.box_list excep');
just'; just';
cons'; cons';
] ]
mark_hoist mark_hoist
| D.ELit D.LEmptyError -> A.make_none mark_hoist | ELit LEmptyError -> A.make_none mark_hoist
| D.EAssert arg -> | EAssert arg ->
let arg' = translate_expr ctx arg in let arg' = translate_expr ctx arg in
(* [ match arg with | None -> raise NoValueProvided | Some v -> assert (* [ match arg with | None -> raise NoValueProvided | Some v -> assert
@ -342,17 +342,17 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
A.make_matchopt_with_abs_arms arg' A.make_matchopt_with_abs_arms arg'
(A.make_abs [| silent_var |] (A.make_abs [| silent_var |]
(Bindlib.box (A.ERaise A.NoValueProvided, mark_hoist)) (Bindlib.box (ERaise NoValueProvided, mark_hoist))
[D.TAny, D.mark_pos mark_hoist] [TAny, Expr.mark_pos mark_hoist]
mark_hoist) mark_hoist)
(A.make_abs [| x |] (A.make_abs [| x |]
(Bindlib.box_apply (Bindlib.box_apply
(fun arg -> A.EAssert arg, mark_hoist) (fun arg -> EAssert arg, mark_hoist)
(A.make_var (x, mark_hoist))) (A.make_var (x, mark_hoist)))
[D.TAny, D.mark_pos mark_hoist] [TAny, Expr.mark_pos mark_hoist]
mark_hoist) mark_hoist)
| _ -> | _ ->
Errors.raise_spanned_error (D.mark_pos mark_hoist) Errors.raise_spanned_error (Expr.mark_pos mark_hoist)
"Internal Error: An term was found in a position where it should \ "Internal Error: An term was found in a position where it should \
not be" not be"
in in
@ -362,23 +362,23 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
(* Cli.debug_print @@ Format.asprintf "build matchopt using %a" (* Cli.debug_print @@ Format.asprintf "build matchopt using %a"
Print.format_var v; *) Print.format_var v; *)
A.make_matchopt mark_hoist v A.make_matchopt mark_hoist v
(D.TAny, D.mark_pos mark_hoist) (TAny, Expr.mark_pos mark_hoist)
c' (A.make_none mark_hoist) acc) c' (A.make_none mark_hoist) acc)
let rec translate_scope_let let rec translate_scope_let
(ctx : 'm ctx) (ctx : 'm ctx)
(lets : ('m D.expr, 'm) D.scope_body_expr) : (lets : ('m D.expr, 'm) scope_body_expr) :
('m A.expr, 'm) D.scope_body_expr Bindlib.box = ('m A.expr, 'm) scope_body_expr Bindlib.box =
match lets with match lets with
| Result e -> | Result e ->
Bindlib.box_apply Bindlib.box_apply
(fun e -> D.Result e) (fun e -> Result e)
(translate_expr ~append_esome:false ctx e) (translate_expr ~append_esome:false ctx e)
| ScopeLet | ScopeLet
{ {
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
scope_let_typ = typ; scope_let_typ = typ;
scope_let_expr = D.EAbs (binder, _), emark; scope_let_expr = EAbs (binder, _), emark;
scope_let_next = next; scope_let_next = next;
scope_let_pos = pos; scope_let_pos = pos;
} -> } ->
@ -390,13 +390,13 @@ let rec translate_scope_let
let var, next = Bindlib.unbind next in let var, next = Bindlib.unbind next in
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var
var; *) var; *)
let vmark = D.map_mark (fun _ -> pos) (fun _ -> typ) emark in let vmark = Expr.map_mark (fun _ -> pos) (fun _ -> typ) emark in
let ctx' = add_var vmark var var_is_pure ctx in let ctx' = add_var vmark var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in let new_var = (find ~info:"variable that was just created" var ctx').var in
let new_next = translate_scope_let ctx' next in let new_next = translate_scope_let ctx' next in
Bindlib.box_apply2 Bindlib.box_apply2
(fun new_expr new_next -> (fun new_expr new_next ->
D.ScopeLet ScopeLet
{ {
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
scope_let_typ = translate_typ typ; scope_let_typ = translate_typ typ;
@ -410,7 +410,7 @@ let rec translate_scope_let
{ {
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
scope_let_typ = typ; scope_let_typ = typ;
scope_let_expr = (D.ErrorOnEmpty _, emark) as expr; scope_let_expr = (ErrorOnEmpty _, emark) as expr;
scope_let_next = next; scope_let_next = next;
scope_let_pos = pos; scope_let_pos = pos;
} -> } ->
@ -419,12 +419,12 @@ let rec translate_scope_let
let var, next = Bindlib.unbind next in let var, next = Bindlib.unbind next in
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var
var; *) var; *)
let vmark = D.map_mark (fun _ -> pos) (fun _ -> typ) emark in let vmark = Expr.map_mark (fun _ -> pos) (fun _ -> typ) emark in
let ctx' = add_var vmark var var_is_pure ctx in let ctx' = add_var vmark var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in let new_var = (find ~info:"variable that was just created" var ctx').var in
Bindlib.box_apply2 Bindlib.box_apply2
(fun new_expr new_next -> (fun new_expr new_next ->
D.ScopeLet ScopeLet
{ {
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
scope_let_typ = translate_typ typ; scope_let_typ = translate_typ typ;
@ -463,7 +463,7 @@ let rec translate_scope_let
thunked, then the variable is context. If it's not thunked, it's a thunked, then the variable is context. If it's not thunked, it's a
regular input. *) regular input. *)
match Marked.unmark typ with match Marked.unmark typ with
| D.TArrow ((D.TLit D.TUnit, _), _) -> false | TArrow ((TLit TUnit, _), _) -> false
| _ -> true) | _ -> true)
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope | ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
| DestructuringSubScopeResults | Assertion -> | DestructuringSubScopeResults | Assertion ->
@ -473,13 +473,13 @@ let rec translate_scope_let
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var
var; *) var; *)
let vmark = let vmark =
D.map_mark (fun _ -> pos) (fun _ -> typ) (Marked.get_mark expr) Expr.map_mark (fun _ -> pos) (fun _ -> typ) (Marked.get_mark expr)
in in
let ctx' = add_var vmark var var_is_pure ctx in let ctx' = add_var vmark var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in let new_var = (find ~info:"variable that was just created" var ctx').var in
Bindlib.box_apply2 Bindlib.box_apply2
(fun new_expr new_next -> (fun new_expr new_next ->
D.ScopeLet ScopeLet
{ {
scope_let_kind = kind; scope_let_kind = kind;
scope_let_typ = translate_typ typ; scope_let_typ = translate_typ typ;
@ -493,8 +493,8 @@ let rec translate_scope_let
let translate_scope_body let translate_scope_body
(scope_pos : Pos.t) (scope_pos : Pos.t)
(ctx : 'm ctx) (ctx : 'm ctx)
(body : ('m D.expr, 'm) D.scope_body) : (body : ('m D.expr, 'm) scope_body) :
('m A.expr, 'm) D.scope_body Bindlib.box = ('m A.expr, 'm) scope_body Bindlib.box =
match body with match body with
| { | {
scope_body_expr = result; scope_body_expr = result;
@ -507,23 +507,23 @@ let translate_scope_body
match lets with match lets with
| Result e | ScopeLet { scope_let_expr = e; _ } -> Marked.get_mark e | Result e | ScopeLet { scope_let_expr = e; _ } -> Marked.get_mark e
in in
D.map_mark (fun _ -> scope_pos) (fun ty -> ty) m Expr.map_mark (fun _ -> scope_pos) (fun ty -> ty) m
in in
let ctx' = add_var vmark v true ctx in let ctx' = add_var vmark v true ctx in
let v' = (find ~info:"variable that was just created" v ctx').var in let v' = (find ~info:"variable that was just created" v ctx').var in
Bindlib.box_apply Bindlib.box_apply
(fun new_expr -> (fun new_expr ->
{ {
D.scope_body_expr = new_expr; scope_body_expr = new_expr;
scope_body_input_struct = input_struct; scope_body_input_struct = input_struct;
scope_body_output_struct = output_struct; scope_body_output_struct = output_struct;
}) })
(Bindlib.bind_var v' (translate_scope_let ctx' lets)) (Bindlib.bind_var v' (translate_scope_let ctx' lets))
let rec translate_scopes (ctx : 'm ctx) (scopes : ('m D.expr, 'm) D.scopes) : let rec translate_scopes (ctx : 'm ctx) (scopes : ('m D.expr, 'm) scopes) :
('m A.expr, 'm) D.scopes Bindlib.box = ('m A.expr, 'm) scopes Bindlib.box =
match scopes with match scopes with
| Nil -> Bindlib.box D.Nil | Nil -> Bindlib.box Nil
| ScopeDef { scope_name; scope_body; scope_next } -> | ScopeDef { scope_name; scope_body; scope_next } ->
let scope_var, next = Bindlib.unbind scope_next in let scope_var, next = Bindlib.unbind scope_next in
let vmark = let vmark =
@ -536,21 +536,21 @@ let rec translate_scopes (ctx : 'm ctx) (scopes : ('m D.expr, 'm) D.scopes) :
(find ~info:"variable that was just created" scope_var new_ctx).var (find ~info:"variable that was just created" scope_var new_ctx).var
in in
let scope_pos = Marked.get_mark (D.ScopeName.get_info scope_name) in let scope_pos = Marked.get_mark (ScopeName.get_info scope_name) in
let new_body = translate_scope_body scope_pos ctx scope_body in let new_body = translate_scope_body scope_pos ctx scope_body in
let tail = translate_scopes new_ctx next in let tail = translate_scopes new_ctx next in
Bindlib.box_apply2 Bindlib.box_apply2
(fun body tail -> (fun body tail ->
D.ScopeDef { scope_name; scope_body = body; scope_next = tail }) ScopeDef { scope_name; scope_body = body; scope_next = tail })
new_body new_body
(Bindlib.bind_var new_scope_name tail) (Bindlib.bind_var new_scope_name tail)
let translate_program (prgm : 'm D.program) : 'm A.program = let translate_program (prgm : 'm D.program) : 'm A.program =
let inputs_structs = let inputs_structs =
D.fold_left_scope_defs prgm.scopes ~init:[] ~f:(fun acc scope_def _ -> Expr.fold_left_scope_defs prgm.scopes ~init:[] ~f:(fun acc scope_def _ ->
scope_def.D.scope_body.scope_body_input_struct :: acc) scope_def.scope_body.scope_body_input_struct :: acc)
in in
(* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]" (* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]"
@ -558,17 +558,17 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
let decl_ctx = let decl_ctx =
{ {
prgm.decl_ctx with prgm.decl_ctx with
D.ctx_enums = ctx_enums =
prgm.decl_ctx.ctx_enums prgm.decl_ctx.ctx_enums
|> D.EnumMap.add A.option_enum A.option_enum_config; |> EnumMap.add A.option_enum A.option_enum_config;
} }
in in
let decl_ctx = let decl_ctx =
{ {
decl_ctx with decl_ctx with
D.ctx_structs = ctx_structs =
prgm.decl_ctx.ctx_structs prgm.decl_ctx.ctx_structs
|> D.StructMap.mapi (fun n l -> |> StructMap.mapi (fun n l ->
if List.mem n inputs_structs then if List.mem n inputs_structs then
ListLabels.map l ~f:(fun (n, tau) -> ListLabels.map l ~f:(fun (n, tau) ->
(* Cli.debug_print @@ Format.asprintf "Input type: %a" (* Cli.debug_print @@ Format.asprintf "Input type: %a"

View File

@ -14,6 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Ast open Ast
module D = Dcalc.Ast module D = Dcalc.Ast
@ -71,7 +72,7 @@ let rec iota_expr (_ : unit) (e : 'm marked_expr) : 'm marked_expr Bindlib.box =
let default_mark e' = Marked.mark (Marked.get_mark e) e' in let default_mark e' = Marked.mark (Marked.get_mark e) e' in
match Marked.unmark e with match Marked.unmark e with
| EMatch ((EInj (e1, i, n', _ts), _), cases, n) | EMatch ((EInj (e1, i, n', _ts), _), cases, n)
when Dcalc.Ast.EnumName.compare n n' = 0 -> when EnumName.compare n n' = 0 ->
let+ e1 = visitor_map iota_expr () e1 let+ e1 = visitor_map iota_expr () e1
and+ case = visitor_map iota_expr () (List.nth cases i) in and+ case = visitor_map iota_expr () (List.nth cases i) in
default_mark @@ EApp (case, [e1]) default_mark @@ EApp (case, [e1])
@ -80,7 +81,7 @@ let rec iota_expr (_ : unit) (e : 'm marked_expr) : 'm marked_expr Bindlib.box =
|> List.mapi (fun i (case, _pos) -> |> List.mapi (fun i (case, _pos) ->
match case with match case with
| EInj (_ei, i', n', _ts') -> | EInj (_ei, i', n', _ts') ->
i = i' && (* n = n' *) Dcalc.Ast.EnumName.compare n n' = 0 i = i' && (* n = n' *) EnumName.compare n n' = 0
| _ -> false) | _ -> false)
|> List.for_all Fun.id -> |> List.for_all Fun.id ->
visitor_map iota_expr () e' visitor_map iota_expr () e'
@ -101,9 +102,9 @@ let rec beta_expr (_ : unit) (e : 'm marked_expr) : 'm marked_expr Bindlib.box =
let iota_optimizations (p : 'm program) : 'm program = let iota_optimizations (p : 'm program) : 'm program =
let new_scopes = let new_scopes =
Dcalc.Ast.map_exprs_in_scopes ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes Expr.map_exprs_in_scopes ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes
in in
{ p with D.scopes = Bindlib.unbox new_scopes } { p with scopes = Bindlib.unbox new_scopes }
(* TODO: beta optimizations apply inlining of the program. We left the inclusion (* TODO: beta optimizations apply inlining of the program. We left the inclusion
of beta-optimization as future work since its produce code that is harder to of beta-optimization as future work since its produce code that is harder to
@ -111,7 +112,7 @@ let iota_optimizations (p : 'm program) : 'm program =
program. *) program. *)
let _beta_optimizations (p : 'm program) : 'm program = let _beta_optimizations (p : 'm program) : 'm program =
let new_scopes = let new_scopes =
Dcalc.Ast.map_exprs_in_scopes ~f:(beta_expr ()) ~varf:(fun v -> v) p.scopes Expr.map_exprs_in_scopes ~f:(beta_expr ()) ~varf:(fun v -> v) p.scopes
in in
{ p with scopes = Bindlib.unbox new_scopes } { p with scopes = Bindlib.unbox new_scopes }
@ -145,11 +146,11 @@ let rec peephole_expr (_ : unit) (e : 'm marked_expr) :
let peephole_optimizations (p : 'm program) : 'm program = let peephole_optimizations (p : 'm program) : 'm program =
let new_scopes = let new_scopes =
Dcalc.Ast.map_exprs_in_scopes ~f:(peephole_expr ()) Expr.map_exprs_in_scopes ~f:(peephole_expr ())
~varf:(fun v -> v) ~varf:(fun v -> v)
p.scopes p.scopes
in in
{ p with scopes = Bindlib.unbox new_scopes } { p with scopes = Bindlib.unbox new_scopes }
let optimize_program (p : 'm program) : Dcalc.Ast.untyped program = let optimize_program (p : 'm program) : untyped program =
p |> iota_optimizations |> peephole_optimizations |> untype_program p |> iota_optimizations |> peephole_optimizations |> Expr.untype_program

View File

@ -16,6 +16,6 @@
open Ast open Ast
val optimize_program : 'm program -> Dcalc.Ast.untyped program val optimize_program : 'm program -> Shared_ast.untyped program
(** Warning/todo: no effort was yet made to ensure correct propagation of type (** Warning/todo: no effort was yet made to ensure correct propagation of type
annotations in the typed case *) annotations in the typed case *)

View File

@ -15,6 +15,7 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Ast open Ast
(** {b Note:} (EmileRolley) seems to be factorizable with (** {b Note:} (EmileRolley) seems to be factorizable with
@ -64,7 +65,7 @@ let format_var (fmt : Format.formatter) (v : 'm Ast.var) : unit =
let rec format_expr let rec format_expr
?(debug : bool = false) ?(debug : bool = false)
(ctx : Dcalc.Ast.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(e : 'm marked_expr) : unit = (e : 'm marked_expr) : unit =
let format_expr = format_expr ctx ~debug in let format_expr = format_expr ctx ~debug in
@ -83,16 +84,16 @@ let rec format_expr
(fun fmt e -> Format.fprintf fmt "%a" format_expr e)) (fun fmt e -> Format.fprintf fmt "%a" format_expr e))
es format_punctuation ")" es format_punctuation ")"
| ETuple (es, Some s) -> | ETuple (es, Some s) ->
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" StructName.format_t s
format_punctuation "{" format_punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, struct_field) -> (fun fmt (e, struct_field) ->
Format.fprintf fmt "%a%a%a%a %a" format_punctuation "\"" Format.fprintf fmt "%a%a%a%a %a" format_punctuation "\""
Dcalc.Ast.StructFieldName.format_t struct_field format_punctuation StructFieldName.format_t struct_field format_punctuation
"\"" format_punctuation ":" format_expr e)) "\"" format_punctuation ":" format_expr e))
(List.combine es (List.combine es
(List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) (List.map fst (StructMap.find s ctx.ctx_structs)))
format_punctuation "}" format_punctuation "}"
| EArray es -> | EArray es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" format_punctuation "[" Format.fprintf fmt "@[<hov 2>%a%a%a@]" format_punctuation "["
@ -106,12 +107,12 @@ let rec format_expr
Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n
| Some s -> | Some s ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_punctuation "." Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_punctuation "."
format_punctuation "\"" Dcalc.Ast.StructFieldName.format_t format_punctuation "\"" StructFieldName.format_t
(fst (List.nth (Dcalc.Ast.StructMap.find s ctx.ctx_structs) n)) (fst (List.nth (StructMap.find s ctx.ctx_structs) n))
format_punctuation "\"") format_punctuation "\"")
| EInj (e, n, en, _ts) -> | EInj (e, n, en, _ts) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_enum_constructor Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_enum_constructor
(fst (List.nth (Dcalc.Ast.EnumMap.find en ctx.ctx_enums) n)) (fst (List.nth (EnumMap.find en ctx.ctx_enums) n))
format_expr e format_expr e
| EMatch (e, es, e_name) -> | EMatch (e, es, e_name) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" format_keyword "match" Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" format_keyword "match"
@ -123,9 +124,9 @@ let rec format_expr
Dcalc.Print.format_enum_constructor c format_punctuation ":" Dcalc.Print.format_enum_constructor c format_punctuation ":"
format_expr e)) format_expr e))
(List.combine es (List.combine es
(List.map fst (Dcalc.Ast.EnumMap.find e_name ctx.ctx_enums))) (List.map fst (EnumMap.find e_name ctx.ctx_enums)))
| ELit l -> | ELit l ->
Format.fprintf fmt "%a" format_lit (Marked.mark (Dcalc.Ast.pos e) l) Format.fprintf fmt "%a" format_lit (Marked.mark (Expr.pos e) l)
| EApp ((EAbs (binder, taus), _), args) -> | EApp ((EAbs (binder, taus), _), args) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
Format.fprintf fmt "%a%a" Format.fprintf fmt "%a%a"
@ -152,7 +153,7 @@ let rec format_expr
(List.combine (Array.to_list xs) taus) (List.combine (Array.to_list xs) taus)
format_punctuation "" format_expr body format_punctuation "" format_expr body
| EApp | EApp
((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2])
-> ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Dcalc.Print.format_binop op Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Dcalc.Print.format_binop op
format_with_parens arg1 format_with_parens arg2 format_with_parens arg1 format_with_parens arg2
@ -190,11 +191,11 @@ let rec format_expr
let format_scope ?(debug = false) ctx fmt (n, s) = let format_scope ?(debug = false) ctx fmt (n, s) =
Format.fprintf fmt "@[<hov 2>%a %a =@ %a@]" format_keyword "let" Format.fprintf fmt "@[<hov 2>%a %a =@ %a@]" format_keyword "let"
Dcalc.Ast.ScopeName.format_t n (format_expr ctx ~debug) ScopeName.format_t n (format_expr ctx ~debug)
(Bindlib.unbox (Bindlib.unbox
(Dcalc.Ast.build_whole_scope_expr ~make_abs:Ast.make_abs (Dcalc.Ast.build_whole_scope_expr ~make_abs:Ast.make_abs
~make_let_in:Ast.make_let_in ~box_expr:Ast.box_expr ctx s ~make_let_in:Ast.make_let_in ~box_expr:Expr.box ctx s
(Dcalc.Ast.map_mark (Expr.map_mark
(fun _ -> Marked.get_mark (Dcalc.Ast.ScopeName.get_info n)) (fun _ -> Marked.get_mark (ScopeName.get_info n))
(fun ty -> ty) (fun ty -> ty)
(Dcalc.Ast.get_scope_body_mark s)))) (Expr.get_scope_body_mark s))))

View File

@ -15,23 +15,24 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
(** {1 Formatters} *) (** {1 Formatters} *)
val format_lit : Format.formatter -> Ast.lit Marked.pos -> unit val format_lit : Format.formatter -> Ast.lit Marked.pos -> unit
val format_var : Format.formatter -> 'm Ast.var -> unit val format_var : Format.formatter -> 'm Ast.var -> unit
val format_exception : Format.formatter -> Ast.except -> unit val format_exception : Format.formatter -> except -> unit
val format_expr : val format_expr :
?debug:bool -> ?debug:bool ->
Dcalc.Ast.decl_ctx -> decl_ctx ->
Format.formatter -> Format.formatter ->
'm Ast.marked_expr -> 'm Ast.marked_expr ->
unit unit
val format_scope : val format_scope :
?debug:bool -> ?debug:bool ->
Dcalc.Ast.decl_ctx -> decl_ctx ->
Format.formatter -> Format.formatter ->
Dcalc.Ast.ScopeName.t * ('m Ast.expr, 'm) Dcalc.Ast.scope_body -> ScopeName.t * ('m Ast.expr, 'm) scope_body ->
unit unit

View File

@ -15,37 +15,38 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Ast open Ast
open String_common open String_common
module D = Dcalc.Ast module D = Dcalc.Ast
let find_struct (s : D.StructName.t) (ctx : D.decl_ctx) : let find_struct (s : StructName.t) (ctx : decl_ctx) :
(D.StructFieldName.t * D.typ Marked.pos) list = (StructFieldName.t * typ Marked.pos) list =
try D.StructMap.find s ctx.D.ctx_structs try StructMap.find s ctx.ctx_structs
with Not_found -> with Not_found ->
let s_name, pos = D.StructName.get_info s in let s_name, pos = StructName.get_info s in
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Internal Error: Structure %s was not found in the current environment." "Internal Error: Structure %s was not found in the current environment."
s_name s_name
let find_enum (en : D.EnumName.t) (ctx : D.decl_ctx) : let find_enum (en : EnumName.t) (ctx : decl_ctx) :
(D.EnumConstructor.t * D.typ Marked.pos) list = (EnumConstructor.t * typ Marked.pos) list =
try D.EnumMap.find en ctx.D.ctx_enums try EnumMap.find en ctx.ctx_enums
with Not_found -> with Not_found ->
let en_name, pos = D.EnumName.get_info en in let en_name, pos = EnumName.get_info en in
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Internal Error: Enumeration %s was not found in the current environment." "Internal Error: Enumeration %s was not found in the current environment."
en_name en_name
let format_lit (fmt : Format.formatter) (l : lit Marked.pos) : unit = let format_lit (fmt : Format.formatter) (l : lit Marked.pos) : unit =
match Marked.unmark l with match Marked.unmark l with
| LBool b -> Dcalc.Print.format_lit fmt (Dcalc.Ast.LBool b) | LBool b -> Dcalc.Print.format_lit fmt (LBool b)
| LInt i -> | LInt i ->
Format.fprintf fmt "integer_of_string@ \"%s\"" (Runtime.integer_to_string i) Format.fprintf fmt "integer_of_string@ \"%s\"" (Runtime.integer_to_string i)
| LUnit -> Dcalc.Print.format_lit fmt Dcalc.Ast.LUnit | LUnit -> Dcalc.Print.format_lit fmt LUnit
| LRat i -> | LRat i ->
Format.fprintf fmt "decimal_of_string \"%a\"" Dcalc.Print.format_lit Format.fprintf fmt "decimal_of_string \"%a\"" Dcalc.Print.format_lit
(Dcalc.Ast.LRat i) (LRat i)
| LMoney e -> | LMoney e ->
Format.fprintf fmt "money_of_cents_string@ \"%s\"" Format.fprintf fmt "money_of_cents_string@ \"%s\""
(Runtime.integer_to_string (Runtime.money_to_cents e)) (Runtime.integer_to_string (Runtime.money_to_cents e))
@ -58,7 +59,7 @@ let format_lit (fmt : Format.formatter) (l : lit Marked.pos) : unit =
let years, months, days = Runtime.duration_to_years_months_days d in let years, months, days = Runtime.duration_to_years_months_days d in
Format.fprintf fmt "duration_of_numbers (%d) (%d) (%d)" years months days Format.fprintf fmt "duration_of_numbers (%d) (%d) (%d)" years months days
let format_op_kind (fmt : Format.formatter) (k : Dcalc.Ast.op_kind) = let format_op_kind (fmt : Format.formatter) (k : op_kind) =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(match k with (match k with
| KInt -> "!" | KInt -> "!"
@ -67,7 +68,7 @@ let format_op_kind (fmt : Format.formatter) (k : Dcalc.Ast.op_kind) =
| KDate -> "@" | KDate -> "@"
| KDuration -> "^") | KDuration -> "^")
let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Marked.pos) : let format_binop (fmt : Format.formatter) (op : binop Marked.pos) :
unit = unit =
match Marked.unmark op with match Marked.unmark op with
| Add k -> Format.fprintf fmt "+%a" format_op_kind k | Add k -> Format.fprintf fmt "+%a" format_op_kind k
@ -86,7 +87,7 @@ let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Marked.pos) :
| Map -> Format.fprintf fmt "Array.map" | Map -> Format.fprintf fmt "Array.map"
| Filter -> Format.fprintf fmt "array_filter" | Filter -> Format.fprintf fmt "array_filter"
let format_ternop (fmt : Format.formatter) (op : Dcalc.Ast.ternop Marked.pos) : let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) :
unit = unit =
match Marked.unmark op with Fold -> Format.fprintf fmt "Array.fold_left" match Marked.unmark op with Fold -> Format.fprintf fmt "Array.fold_left"
@ -109,7 +110,7 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info))) (Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
uids uids
let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Marked.pos) : unit let format_unop (fmt : Format.formatter) (op : unop Marked.pos) : unit
= =
match Marked.unmark op with match Marked.unmark op with
| Minus k -> Format.fprintf fmt "~-%a" format_op_kind k | Minus k -> Format.fprintf fmt "~-%a" format_op_kind k
@ -145,9 +146,9 @@ let avoid_keywords (s : string) : string =
s ^ "_user" s ^ "_user"
| _ -> s | _ -> s
let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) : let format_struct_name (fmt : Format.formatter) (v : StructName.t) :
unit = unit =
Format.asprintf "%a" Dcalc.Ast.StructName.format_t v Format.asprintf "%a" StructName.format_t v
|> to_ascii |> to_ascii
|> to_snake_case |> to_snake_case
|> avoid_keywords |> avoid_keywords
@ -155,10 +156,10 @@ let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) :
let format_to_module_name let format_to_module_name
(fmt : Format.formatter) (fmt : Format.formatter)
(name : [< `Ename of D.EnumName.t | `Sname of D.StructName.t ]) = (name : [< `Ename of EnumName.t | `Sname of StructName.t ]) =
(match name with (match name with
| `Ename v -> Format.asprintf "%a" D.EnumName.format_t v | `Ename v -> Format.asprintf "%a" EnumName.format_t v
| `Sname v -> Format.asprintf "%a" D.StructName.format_t v) | `Sname v -> Format.asprintf "%a" StructName.format_t v)
|> to_ascii |> to_ascii
|> to_snake_case |> to_snake_case
|> avoid_keywords |> avoid_keywords
@ -170,52 +171,52 @@ let format_to_module_name
let format_struct_field_name let format_struct_field_name
(fmt : Format.formatter) (fmt : Format.formatter)
((sname_opt, v) : ((sname_opt, v) :
Dcalc.Ast.StructName.t option * Dcalc.Ast.StructFieldName.t) : unit = StructName.t option * StructFieldName.t) : unit =
(match sname_opt with (match sname_opt with
| Some sname -> | Some sname ->
Format.fprintf fmt "%a.%s" format_to_module_name (`Sname sname) Format.fprintf fmt "%a.%s" format_to_module_name (`Sname sname)
| None -> Format.fprintf fmt "%s") | None -> Format.fprintf fmt "%s")
(avoid_keywords (avoid_keywords
(to_ascii (Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v))) (to_ascii (Format.asprintf "%a" StructFieldName.format_t v)))
let format_enum_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumName.t) : unit let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit
= =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_snake_case (to_snake_case
(to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumName.format_t v)))) (to_ascii (Format.asprintf "%a" EnumName.format_t v))))
let format_enum_cons_name let format_enum_cons_name
(fmt : Format.formatter) (fmt : Format.formatter)
(v : Dcalc.Ast.EnumConstructor.t) : unit = (v : EnumConstructor.t) : unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v))) (to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
let rec typ_embedding_name (fmt : Format.formatter) (ty : D.typ Marked.pos) : let rec typ_embedding_name (fmt : Format.formatter) (ty : typ Marked.pos) :
unit = unit =
match Marked.unmark ty with match Marked.unmark ty with
| D.TLit D.TUnit -> Format.fprintf fmt "embed_unit" | TLit TUnit -> Format.fprintf fmt "embed_unit"
| D.TLit D.TBool -> Format.fprintf fmt "embed_bool" | TLit TBool -> Format.fprintf fmt "embed_bool"
| D.TLit D.TInt -> Format.fprintf fmt "embed_integer" | TLit TInt -> Format.fprintf fmt "embed_integer"
| D.TLit D.TRat -> Format.fprintf fmt "embed_decimal" | TLit TRat -> Format.fprintf fmt "embed_decimal"
| D.TLit D.TMoney -> Format.fprintf fmt "embed_money" | TLit TMoney -> Format.fprintf fmt "embed_money"
| D.TLit D.TDate -> Format.fprintf fmt "embed_date" | TLit TDate -> Format.fprintf fmt "embed_date"
| D.TLit D.TDuration -> Format.fprintf fmt "embed_duration" | TLit TDuration -> Format.fprintf fmt "embed_duration"
| D.TTuple (_, Some s_name) -> | TTuple (_, Some s_name) ->
Format.fprintf fmt "embed_%a" format_struct_name s_name Format.fprintf fmt "embed_%a" format_struct_name s_name
| D.TEnum (_, e_name) -> Format.fprintf fmt "embed_%a" format_enum_name e_name | TEnum (_, e_name) -> Format.fprintf fmt "embed_%a" format_enum_name e_name
| D.TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty | TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty
| _ -> Format.fprintf fmt "unembeddable" | _ -> Format.fprintf fmt "unembeddable"
let typ_needs_parens (e : Dcalc.Ast.typ Marked.pos) : bool = let typ_needs_parens (e : typ Marked.pos) : bool =
match Marked.unmark e with TArrow _ | TArray _ -> true | _ -> false match Marked.unmark e with TArrow _ | TArray _ -> true | _ -> false
let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Marked.pos) : let rec format_typ (fmt : Format.formatter) (typ : typ Marked.pos) :
unit = unit =
let format_typ_with_parens let format_typ_with_parens
(fmt : Format.formatter) (fmt : Format.formatter)
(t : Dcalc.Ast.typ Marked.pos) = (t : typ Marked.pos) =
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
else Format.fprintf fmt "%a" format_typ t else Format.fprintf fmt "%a" format_typ t
in in
@ -229,10 +230,10 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Marked.pos) :
ts ts
| TTuple (_, Some s) -> | TTuple (_, Some s) ->
Format.fprintf fmt "%a.t" format_to_module_name (`Sname s) Format.fprintf fmt "%a.t" format_to_module_name (`Sname s)
| TEnum ([t], e) when D.EnumName.compare e Ast.option_enum = 0 -> | TEnum ([t], e) when EnumName.compare e Ast.option_enum = 0 ->
Format.fprintf fmt "@[<hov 2>(%a)@] %a" format_typ_with_parens t Format.fprintf fmt "@[<hov 2>(%a)@] %a" format_typ_with_parens t
format_enum_name e format_enum_name e
| TEnum (_, e) when D.EnumName.compare e Ast.option_enum = 0 -> | TEnum (_, e) when EnumName.compare e Ast.option_enum = 0 ->
Errors.raise_spanned_error (Marked.get_mark typ) Errors.raise_spanned_error (Marked.get_mark typ)
"Internal Error: found an typing parameter for an eoption type of the \ "Internal Error: found an typing parameter for an eoption type of the \
wrong length." wrong length."
@ -290,7 +291,7 @@ let format_exception (fmt : Format.formatter) (exc : except Marked.pos) : unit =
(Pos.get_law_info pos) (Pos.get_law_info pos)
let rec format_expr let rec format_expr
(ctx : Dcalc.Ast.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(e : 'm marked_expr) : unit = (e : 'm marked_expr) : unit =
let format_expr = format_expr ctx in let format_expr = format_expr ctx in
@ -360,7 +361,7 @@ let rec format_expr
(* should not happen *)) (* should not happen *))
e)) e))
(List.combine es (List.map fst (find_enum e_name ctx))) (List.combine es (List.map fst (find_enum e_name ctx)))
| ELit l -> Format.fprintf fmt "%a" format_lit (Marked.mark (D.pos e) l) | ELit l -> Format.fprintf fmt "%a" format_lit (Marked.mark (Expr.pos e) l)
| EApp ((EAbs (binder, taus), _), args) -> | EApp ((EAbs (binder, taus), _), args) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in
@ -382,35 +383,35 @@ let rec format_expr
Format.fprintf fmt "@[<hov 2>(%a:@ %a)@]" format_var x format_typ tau)) Format.fprintf fmt "@[<hov 2>(%a:@ %a)@]" format_var x format_typ tau))
xs_tau format_expr body xs_tau format_expr body
| EApp | EApp
((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2])
-> ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop (op, Pos.no_pos) Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop (op, Pos.no_pos)
format_with_parens arg1 format_with_parens arg2 format_with_parens arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) -> | EApp ((EOp (Binop op), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1 Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
format_binop (op, Pos.no_pos) format_with_parens arg2 format_binop (op, Pos.no_pos) format_with_parens arg2
| EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [f]), _), [arg]) | EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt "(log_begin_call@ %a@ %a)@ %a" format_uid_list info Format.fprintf fmt "(log_begin_call@ %a@ %a)@ %a" format_uid_list info
format_with_parens f format_with_parens arg format_with_parens f format_with_parens arg
| EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [arg1]) | EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list
info typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1 info typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1
| EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), m), [arg1]) | EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), m), [arg1])
when !Cli.trace_flag -> when !Cli.trace_flag ->
let pos = D.mark_pos m in let pos = Expr.mark_pos m in
Format.fprintf fmt Format.fprintf fmt
"(log_decision_taken@ @[<hov 2>{filename = \"%s\";@ start_line=%d;@ \ "(log_decision_taken@ @[<hov 2>{filename = \"%s\";@ start_line=%d;@ \
start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a)" start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a)"
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) format_with_parens arg1 (Pos.get_law_info pos) format_with_parens arg1
| EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [arg1]) | EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info
format_with_parens arg1 format_with_parens arg1
| EApp ((EOp (Unop (D.Log _)), _), [arg1]) -> | EApp ((EOp (Unop (Log _)), _), [arg1]) ->
Format.fprintf fmt "%a" format_with_parens arg1 Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [arg1]) -> | EApp ((EOp (Unop op), _), [arg1]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos) Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos)
@ -422,13 +423,13 @@ let rec format_expr
"@[<hov 2>%a@ @[<hov 2>{filename = \"%s\";@ start_line=%d;@ \ "@[<hov 2>%a@ @[<hov 2>{filename = \"%s\";@ start_line=%d;@ \
start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a@]" start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a@]"
format_var x format_var x
(Pos.get_file (D.mark_pos pos)) (Pos.get_file (Expr.mark_pos pos))
(Pos.get_start_line (D.mark_pos pos)) (Pos.get_start_line (Expr.mark_pos pos))
(Pos.get_start_column (D.mark_pos pos)) (Pos.get_start_column (Expr.mark_pos pos))
(Pos.get_end_line (D.mark_pos pos)) (Pos.get_end_line (Expr.mark_pos pos))
(Pos.get_end_column (D.mark_pos pos)) (Pos.get_end_column (Expr.mark_pos pos))
format_string_list format_string_list
(Pos.get_law_info (D.mark_pos pos)) (Pos.get_law_info (Expr.mark_pos pos))
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens) format_with_parens)
@ -452,25 +453,25 @@ let rec format_expr
2>{filename = \"%s\";@ start_line=%d;@ start_column=%d;@ end_line=%d; \ 2>{filename = \"%s\";@ start_line=%d;@ start_column=%d;@ end_line=%d; \
end_column=%d;@ law_headings=%a}@])@]" end_column=%d;@ law_headings=%a}@])@]"
format_with_parens e' format_with_parens e'
(Pos.get_file (D.pos e')) (Pos.get_file (Expr.pos e'))
(Pos.get_start_line (D.pos e')) (Pos.get_start_line (Expr.pos e'))
(Pos.get_start_column (D.pos e')) (Pos.get_start_column (Expr.pos e'))
(Pos.get_end_line (D.pos e')) (Pos.get_end_line (Expr.pos e'))
(Pos.get_end_column (D.pos e')) (Pos.get_end_column (Expr.pos e'))
format_string_list format_string_list
(Pos.get_law_info (D.pos e')) (Pos.get_law_info (Expr.pos e'))
| ERaise exc -> Format.fprintf fmt "raise@ %a" format_exception (exc, D.pos e) | ERaise exc -> Format.fprintf fmt "raise@ %a" format_exception (exc, Expr.pos e)
| ECatch (e1, exc, e2) -> | ECatch (e1, exc, e2) ->
Format.fprintf fmt Format.fprintf fmt
"@,@[<hv>@[<hov 2>try@ %a@]@ with@]@ @[<hov 2>%a@ ->@ %a@]" "@,@[<hv>@[<hov 2>try@ %a@]@ with@]@ @[<hov 2>%a@ ->@ %a@]"
format_with_parens e1 format_exception format_with_parens e1 format_exception
(exc, D.pos e) (exc, Expr.pos e)
format_with_parens e2 format_with_parens e2
let format_struct_embedding let format_struct_embedding
(fmt : Format.formatter) (fmt : Format.formatter)
((struct_name, struct_fields) : ((struct_name, struct_fields) :
D.StructName.t * (D.StructFieldName.t * D.typ Marked.pos) list) = StructName.t * (StructFieldName.t * typ Marked.pos) list) =
if List.length struct_fields = 0 then if List.length struct_fields = 0 then
Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n"
format_struct_name struct_name format_to_module_name (`Sname struct_name) format_struct_name struct_name format_to_module_name (`Sname struct_name)
@ -480,11 +481,11 @@ let format_struct_embedding
@[<hov 2>[%a]@])@]@\n\ @[<hov 2>[%a]@])@]@\n\
@\n" @\n"
format_struct_name struct_name format_to_module_name (`Sname struct_name) format_struct_name struct_name format_to_module_name (`Sname struct_name)
D.StructName.format_t struct_name StructName.format_t struct_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n")
(fun _fmt (struct_field, struct_field_type) -> (fun _fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" D.StructFieldName.format_t Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructFieldName.format_t
struct_field typ_embedding_name struct_field_type struct_field typ_embedding_name struct_field_type
format_struct_field_name format_struct_field_name
(Some struct_name, struct_field))) (Some struct_name, struct_field)))
@ -493,7 +494,7 @@ let format_struct_embedding
let format_enum_embedding let format_enum_embedding
(fmt : Format.formatter) (fmt : Format.formatter)
((enum_name, enum_cases) : ((enum_name, enum_cases) :
D.EnumName.t * (D.EnumConstructor.t * D.typ Marked.pos) list) = EnumName.t * (EnumConstructor.t * typ Marked.pos) list) =
if List.length enum_cases = 0 then if List.length enum_cases = 0 then
Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n"
format_to_module_name (`Ename enum_name) format_enum_name enum_name format_to_module_name (`Ename enum_name) format_enum_name enum_name
@ -503,19 +504,19 @@ let format_enum_embedding
=@]@ Enum([\"%a\"],@ @[<hov 2>match x with@ %a@])@]@\n\ =@]@ Enum([\"%a\"],@ @[<hov 2>match x with@ %a@])@]@\n\
@\n" @\n"
format_enum_name enum_name format_to_module_name (`Ename enum_name) format_enum_name enum_name format_to_module_name (`Ename enum_name)
D.EnumName.format_t enum_name EnumName.format_t enum_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (enum_cons, enum_cons_type) -> (fun _fmt (enum_cons, enum_cons_type) ->
Format.fprintf fmt "@[<hov 2>| %a x ->@ (\"%a\", %a x)@]" Format.fprintf fmt "@[<hov 2>| %a x ->@ (\"%a\", %a x)@]"
format_enum_cons_name enum_cons D.EnumConstructor.format_t format_enum_cons_name enum_cons EnumConstructor.format_t
enum_cons typ_embedding_name enum_cons_type)) enum_cons typ_embedding_name enum_cons_type))
enum_cases enum_cases
let format_ctx let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list) (type_ordering : Scopelang.Dependency.TVertex.t list)
(fmt : Format.formatter) (fmt : Format.formatter)
(ctx : D.decl_ctx) : unit = (ctx : decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) = let format_struct_decl fmt (struct_name, struct_fields) =
if List.length struct_fields = 0 then if List.length struct_fields = 0 then
Format.fprintf fmt Format.fprintf fmt
@ -559,8 +560,8 @@ let format_ctx
let scope_structs = let scope_structs =
List.map List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(Dcalc.Ast.StructMap.bindings (StructMap.bindings
(Dcalc.Ast.StructMap.filter (StructMap.filter
(fun s _ -> not (is_in_type_ordering s)) (fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs)) ctx.ctx_structs))
in in
@ -574,12 +575,12 @@ let format_ctx
(type_ordering @ scope_structs) (type_ordering @ scope_structs)
let rec format_scope_body_expr let rec format_scope_body_expr
(ctx : Dcalc.Ast.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(scope_lets : ('m Ast.expr, 'm) Dcalc.Ast.scope_body_expr) : unit = (scope_lets : ('m Ast.expr, 'm) scope_body_expr) : unit =
match scope_lets with match scope_lets with
| Dcalc.Ast.Result e -> format_expr ctx fmt e | Result e -> format_expr ctx fmt e
| Dcalc.Ast.ScopeLet scope_let -> | ScopeLet scope_let ->
let scope_let_var, scope_let_next = let scope_let_var, scope_let_next =
Bindlib.unbind scope_let.scope_let_next Bindlib.unbind scope_let.scope_let_next
in in
@ -590,12 +591,12 @@ let rec format_scope_body_expr
scope_let_next scope_let_next
let rec format_scopes let rec format_scopes
(ctx : Dcalc.Ast.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(scopes : ('m Ast.expr, 'm) Dcalc.Ast.scopes) : unit = (scopes : ('m Ast.expr, 'm) scopes) : unit =
match scopes with match scopes with
| Dcalc.Ast.Nil -> () | Nil -> ()
| Dcalc.Ast.ScopeDef scope_def -> | ScopeDef scope_def ->
let scope_input_var, scope_body_expr = let scope_input_var, scope_body_expr =
Bindlib.unbind scope_def.scope_body.scope_body_expr Bindlib.unbind scope_def.scope_body.scope_body_expr
in in

View File

@ -15,6 +15,7 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Ast open Ast
(** Formats a lambda calculus program into a valid OCaml program *) (** Formats a lambda calculus program into a valid OCaml program *)
@ -22,32 +23,32 @@ open Ast
val avoid_keywords : string -> string val avoid_keywords : string -> string
val find_struct : val find_struct :
Dcalc.Ast.StructName.t -> StructName.t ->
Dcalc.Ast.decl_ctx -> decl_ctx ->
(Dcalc.Ast.StructFieldName.t * Dcalc.Ast.typ Marked.pos) list (StructFieldName.t * typ Marked.pos) list
val find_enum : val find_enum :
Dcalc.Ast.EnumName.t -> EnumName.t ->
Dcalc.Ast.decl_ctx -> decl_ctx ->
(Dcalc.Ast.EnumConstructor.t * Dcalc.Ast.typ Marked.pos) list (EnumConstructor.t * typ Marked.pos) list
val typ_needs_parens : Dcalc.Ast.typ Marked.pos -> bool val typ_needs_parens : typ Marked.pos -> bool
val needs_parens : 'm marked_expr -> bool val needs_parens : 'm marked_expr -> bool
val format_enum_name : Format.formatter -> Dcalc.Ast.EnumName.t -> unit val format_enum_name : Format.formatter -> EnumName.t -> unit
val format_enum_cons_name : val format_enum_cons_name :
Format.formatter -> Dcalc.Ast.EnumConstructor.t -> unit Format.formatter -> EnumConstructor.t -> unit
val format_struct_name : Format.formatter -> Dcalc.Ast.StructName.t -> unit val format_struct_name : Format.formatter -> StructName.t -> unit
val format_struct_field_name : val format_struct_field_name :
Format.formatter -> Format.formatter ->
Dcalc.Ast.StructName.t option * Dcalc.Ast.StructFieldName.t -> StructName.t option * StructFieldName.t ->
unit unit
val format_to_module_name : val format_to_module_name :
Format.formatter -> Format.formatter ->
[< `Ename of Dcalc.Ast.EnumName.t | `Sname of Dcalc.Ast.StructName.t ] -> [< `Ename of EnumName.t | `Sname of StructName.t ] ->
unit unit
val format_lit : Format.formatter -> lit Marked.pos -> unit val format_lit : Format.formatter -> lit Marked.pos -> unit

View File

@ -29,7 +29,7 @@ type 'ast gen = {
} }
type t = type t =
| Lcalc of Dcalc.Ast.untyped Lcalc.Ast.program gen | Lcalc of Shared_ast.untyped Lcalc.Ast.program gen
| Scalc of Scalc.Ast.program gen | Scalc of Scalc.Ast.program gen
let name = function Lcalc { name; _ } | Scalc { name; _ } -> name let name = function Lcalc { name; _ } | Scalc { name; _ } -> name

View File

@ -31,7 +31,7 @@ type 'ast gen = {
} }
type t = type t =
| Lcalc of Dcalc.Ast.untyped Lcalc.Ast.program gen | Lcalc of Shared_ast.untyped Lcalc.Ast.program gen
| Scalc of Scalc.Ast.program gen | Scalc of Scalc.Ast.program gen
val find : string -> t val find : string -> t
@ -49,7 +49,7 @@ module PluginAPI : sig
val register_lcalc : val register_lcalc :
name:string -> name:string ->
extension:string -> extension:string ->
Dcalc.Ast.untyped Lcalc.Ast.program plugin_apply_fun_typ -> Shared_ast.untyped Lcalc.Ast.program plugin_apply_fun_typ ->
unit unit
val register_scalc : val register_scalc :

View File

@ -19,6 +19,7 @@
the associated [js_of_ocaml] wrapper. *) the associated [js_of_ocaml] wrapper. *)
open Utils open Utils
open Shared_ast
open String_common open String_common
open Lcalc open Lcalc
open Lcalc.Ast open Lcalc.Ast
@ -39,9 +40,9 @@ module To_jsoo = struct
let format_struct_field_name_camel_case let format_struct_field_name_camel_case
(fmt : Format.formatter) (fmt : Format.formatter)
(v : Dcalc.Ast.StructFieldName.t) : unit = (v : StructFieldName.t) : unit =
let s = let s =
Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v Format.asprintf "%a" StructFieldName.format_t v
|> to_ascii |> to_ascii
|> to_snake_case |> to_snake_case
|> avoid_keywords |> avoid_keywords
@ -49,7 +50,7 @@ module To_jsoo = struct
in in
Format.fprintf fmt "%s" s Format.fprintf fmt "%s" s
let format_tlit (fmt : Format.formatter) (l : Dcalc.Ast.typ_lit) : unit = let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit =
Dcalc.Print.format_base_type fmt Dcalc.Print.format_base_type fmt
(match l with (match l with
| TUnit -> "unit" | TUnit -> "unit"
@ -59,11 +60,11 @@ module To_jsoo = struct
| TBool -> "bool Js.t" | TBool -> "bool Js.t"
| TDate -> "Js.js_string Js.t") | TDate -> "Js.js_string Js.t")
let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Marked.pos) : let rec format_typ (fmt : Format.formatter) (typ : typ Marked.pos) :
unit = unit =
let format_typ_with_parens let format_typ_with_parens
(fmt : Format.formatter) (fmt : Format.formatter)
(t : Dcalc.Ast.typ Marked.pos) = (t : typ Marked.pos) =
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
else Format.fprintf fmt "%a" format_typ t else Format.fprintf fmt "%a" format_typ t
in in
@ -73,10 +74,10 @@ module To_jsoo = struct
| TTuple (_, None) -> | TTuple (_, None) ->
(* Tuples are encoded as an javascript polymorphic array. *) (* Tuples are encoded as an javascript polymorphic array. *)
Format.fprintf fmt "Js.Unsafe.any_js_array Js.t " Format.fprintf fmt "Js.Unsafe.any_js_array Js.t "
| TEnum ([t], e) when D.EnumName.compare e option_enum = 0 -> | TEnum ([t], e) when EnumName.compare e option_enum = 0 ->
Format.fprintf fmt "@[<hov 2>(%a)@] %a" format_typ_with_parens t Format.fprintf fmt "@[<hov 2>(%a)@] %a" format_typ_with_parens t
format_enum_name e format_enum_name e
| TEnum (_, e) when D.EnumName.compare e option_enum = 0 -> | TEnum (_, e) when EnumName.compare e option_enum = 0 ->
Errors.raise_spanned_error (Marked.get_mark typ) Errors.raise_spanned_error (Marked.get_mark typ)
"Internal Error: found an typing parameter for an eoption type of the \ "Internal Error: found an typing parameter for an eoption type of the \
wrong length." wrong length."
@ -90,41 +91,41 @@ module To_jsoo = struct
let rec format_typ_to_jsoo fmt typ = let rec format_typ_to_jsoo fmt typ =
match Marked.unmark typ with match Marked.unmark typ with
| Dcalc.Ast.TLit TBool -> Format.fprintf fmt "Js.bool" | TLit TBool -> Format.fprintf fmt "Js.bool"
| Dcalc.Ast.TLit TInt -> Format.fprintf fmt "integer_to_int" | TLit TInt -> Format.fprintf fmt "integer_to_int"
| Dcalc.Ast.TLit TRat -> | TLit TRat ->
Format.fprintf fmt "Js.number_of_float %@%@ decimal_to_float" Format.fprintf fmt "Js.number_of_float %@%@ decimal_to_float"
| Dcalc.Ast.TLit TMoney -> | TLit TMoney ->
Format.fprintf fmt "Js.number_of_float %@%@ money_to_float" Format.fprintf fmt "Js.number_of_float %@%@ money_to_float"
| Dcalc.Ast.TLit TDuration -> Format.fprintf fmt "duration_to_jsoo" | TLit TDuration -> Format.fprintf fmt "duration_to_jsoo"
| Dcalc.Ast.TLit TDate -> Format.fprintf fmt "date_to_jsoo" | TLit TDate -> Format.fprintf fmt "date_to_jsoo"
| Dcalc.Ast.TEnum (_, ename) -> | TEnum (_, ename) ->
Format.fprintf fmt "%a_to_jsoo" format_enum_name ename Format.fprintf fmt "%a_to_jsoo" format_enum_name ename
| Dcalc.Ast.TTuple (_, Some sname) -> | TTuple (_, Some sname) ->
Format.fprintf fmt "%a_to_jsoo" format_struct_name sname Format.fprintf fmt "%a_to_jsoo" format_struct_name sname
| Dcalc.Ast.TArray t -> | TArray t ->
Format.fprintf fmt "Js.array %@%@ Array.map (fun x -> %a x)" Format.fprintf fmt "Js.array %@%@ Array.map (fun x -> %a x)"
format_typ_to_jsoo t format_typ_to_jsoo t
| Dcalc.Ast.TAny | Dcalc.Ast.TTuple (_, None) -> | TAny | TTuple (_, None) ->
Format.fprintf fmt "Js.Unsafe.inject" Format.fprintf fmt "Js.Unsafe.inject"
| _ -> Format.fprintf fmt "" | _ -> Format.fprintf fmt ""
let rec format_typ_of_jsoo fmt typ = let rec format_typ_of_jsoo fmt typ =
match Marked.unmark typ with match Marked.unmark typ with
| Dcalc.Ast.TLit TBool -> Format.fprintf fmt "Js.to_bool" | TLit TBool -> Format.fprintf fmt "Js.to_bool"
| Dcalc.Ast.TLit TInt -> Format.fprintf fmt "integer_of_int" | TLit TInt -> Format.fprintf fmt "integer_of_int"
| Dcalc.Ast.TLit TRat -> | TLit TRat ->
Format.fprintf fmt "decimal_of_float %@%@ Js.float_of_number" Format.fprintf fmt "decimal_of_float %@%@ Js.float_of_number"
| Dcalc.Ast.TLit TMoney -> | TLit TMoney ->
Format.fprintf fmt Format.fprintf fmt
"money_of_decimal %@%@ decimal_of_float %@%@ Js.float_of_number" "money_of_decimal %@%@ decimal_of_float %@%@ Js.float_of_number"
| Dcalc.Ast.TLit TDuration -> Format.fprintf fmt "duration_of_jsoo" | TLit TDuration -> Format.fprintf fmt "duration_of_jsoo"
| Dcalc.Ast.TLit TDate -> Format.fprintf fmt "date_of_jsoo" | TLit TDate -> Format.fprintf fmt "date_of_jsoo"
| Dcalc.Ast.TEnum (_, ename) -> | TEnum (_, ename) ->
Format.fprintf fmt "%a_of_jsoo" format_enum_name ename Format.fprintf fmt "%a_of_jsoo" format_enum_name ename
| Dcalc.Ast.TTuple (_, Some sname) -> | TTuple (_, Some sname) ->
Format.fprintf fmt "%a_of_jsoo" format_struct_name sname Format.fprintf fmt "%a_of_jsoo" format_struct_name sname
| Dcalc.Ast.TArray t -> | TArray t ->
Format.fprintf fmt "Array.map (fun x -> %a x) %@%@ Js.to_array" Format.fprintf fmt "Array.map (fun x -> %a x) %@%@ Js.to_array"
format_typ_of_jsoo t format_typ_of_jsoo t
| _ -> Format.fprintf fmt "" | _ -> Format.fprintf fmt ""
@ -150,10 +151,10 @@ module To_jsoo = struct
let format_ctx let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list) (type_ordering : Scopelang.Dependency.TVertex.t list)
(fmt : Format.formatter) (fmt : Format.formatter)
(ctx : D.decl_ctx) : unit = (ctx : decl_ctx) : unit =
let format_prop_or_meth fmt (struct_field_type : D.typ Marked.pos) = let format_prop_or_meth fmt (struct_field_type : typ Marked.pos) =
match Marked.unmark struct_field_type with match Marked.unmark struct_field_type with
| Dcalc.Ast.TArrow _ -> Format.fprintf fmt "Js.meth" | TArrow _ -> Format.fprintf fmt "Js.meth"
| _ -> Format.fprintf fmt "Js.readonly_prop" | _ -> Format.fprintf fmt "Js.readonly_prop"
in in
let format_struct_decl fmt (struct_name, struct_fields) = let format_struct_decl fmt (struct_name, struct_fields) =
@ -167,7 +168,7 @@ module To_jsoo = struct
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (struct_field, struct_field_type) -> (fun fmt (struct_field, struct_field_type) ->
match Marked.unmark struct_field_type with match Marked.unmark struct_field_type with
| Dcalc.Ast.TArrow (t1, t2) -> | TArrow (t1, t2) ->
Format.fprintf fmt Format.fprintf fmt
"@[<hov 2>method %a =@ Js.wrap_meth_callback@ @[<hv 2>(@,\ "@[<hov 2>method %a =@ Js.wrap_meth_callback@ @[<hv 2>(@,\
fun input ->@ %a (%a.%a (%a input)))@]@]" fun input ->@ %a (%a.%a (%a input)))@]@]"
@ -188,7 +189,7 @@ module To_jsoo = struct
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n")
(fun fmt (struct_field, struct_field_type) -> (fun fmt (struct_field, struct_field_type) ->
match Marked.unmark struct_field_type with match Marked.unmark struct_field_type with
| Dcalc.Ast.TArrow _ -> | TArrow _ ->
Format.fprintf fmt Format.fprintf fmt
"%a = failwith \"The function '%a' translation isn't yet \ "%a = failwith \"The function '%a' translation isn't yet \
supported...\"" supported...\""
@ -238,7 +239,7 @@ module To_jsoo = struct
in in
let format_enum_decl let format_enum_decl
fmt fmt
(enum_name, (enum_cons : (D.EnumConstructor.t * D.typ Marked.pos) list)) (enum_name, (enum_cons : (EnumConstructor.t * typ Marked.pos) list))
= =
let fmt_enum_name fmt _ = format_enum_name fmt enum_name in let fmt_enum_name fmt _ = format_enum_name fmt enum_name in
let fmt_module_enum_name fmt _ = let fmt_module_enum_name fmt _ =
@ -250,7 +251,7 @@ module To_jsoo = struct
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (cname, typ) -> (fun fmt (cname, typ) ->
match Marked.unmark typ with match Marked.unmark typ with
| Dcalc.Ast.TTuple (_, None) -> | TTuple (_, None) ->
Cli.error_print Cli.error_print
"Tuples aren't supported yet in the conversion to JS" "Tuples aren't supported yet in the conversion to JS"
| _ -> | _ ->
@ -275,10 +276,10 @@ module To_jsoo = struct
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (cname, typ) -> (fun fmt (cname, typ) ->
match Marked.unmark typ with match Marked.unmark typ with
| Dcalc.Ast.TTuple (_, None) -> | TTuple (_, None) ->
Cli.error_print Cli.error_print
"Tuples aren't yet supported in the conversion to JS..." "Tuples aren't yet supported in the conversion to JS..."
| Dcalc.Ast.TLit TUnit -> | TLit TUnit ->
Format.fprintf fmt "@[<hv 2>| \"%a\" ->@ %a.%a ()@]" Format.fprintf fmt "@[<hv 2>| \"%a\" ->@ %a.%a ()@]"
format_enum_cons_name cname fmt_module_enum_name () format_enum_cons_name cname fmt_module_enum_name ()
format_enum_cons_name cname format_enum_cons_name cname
@ -329,8 +330,8 @@ module To_jsoo = struct
let scope_structs = let scope_structs =
List.map List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(Dcalc.Ast.StructMap.bindings (StructMap.bindings
(Dcalc.Ast.StructMap.filter (StructMap.filter
(fun s _ -> not (is_in_type_ordering s)) (fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs)) ctx.ctx_structs))
in in
@ -343,19 +344,19 @@ module To_jsoo = struct
Format.fprintf fmt "%a@\n" format_enum_decl (e, find_enum e ctx)) Format.fprintf fmt "%a@\n" format_enum_decl (e, find_enum e ctx))
(type_ordering @ scope_structs) (type_ordering @ scope_structs)
let fmt_input_struct_name fmt (scope_def : ('a expr, 'm) D.scope_def) = let fmt_input_struct_name fmt (scope_def : ('a expr, 'm) scope_def) =
format_struct_name fmt scope_def.scope_body.scope_body_input_struct format_struct_name fmt scope_def.scope_body.scope_body_input_struct
let fmt_output_struct_name fmt (scope_def : ('a expr, 'm) D.scope_def) = let fmt_output_struct_name fmt (scope_def : ('a expr, 'm) scope_def) =
format_struct_name fmt scope_def.scope_body.scope_body_output_struct format_struct_name fmt scope_def.scope_body.scope_body_output_struct
let rec format_scopes_to_fun let rec format_scopes_to_fun
(ctx : Dcalc.Ast.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(scopes : ('expr, 'm) Dcalc.Ast.scopes) = (scopes : ('expr, 'm) scopes) =
match scopes with match scopes with
| Dcalc.Ast.Nil -> () | Nil -> ()
| Dcalc.Ast.ScopeDef scope_def -> | ScopeDef scope_def ->
let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in
let fmt_fun_call fmt _ = let fmt_fun_call fmt _ =
Format.fprintf fmt "@[<hv>%a@ |> %a_of_jsoo@ |> %a@ |> %a_to_jsoo@]" Format.fprintf fmt "@[<hv>%a@ |> %a_of_jsoo@ |> %a@ |> %a_to_jsoo@]"
@ -369,12 +370,12 @@ module To_jsoo = struct
fmt_fun_call () (format_scopes_to_fun ctx) scope_next fmt_fun_call () (format_scopes_to_fun ctx) scope_next
let rec format_scopes_to_callbacks let rec format_scopes_to_callbacks
(ctx : Dcalc.Ast.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(scopes : ('expr, 'm) Dcalc.Ast.scopes) : unit = (scopes : ('expr, 'm) scopes) : unit =
match scopes with match scopes with
| Dcalc.Ast.Nil -> () | Nil -> ()
| Dcalc.Ast.ScopeDef scope_def -> | ScopeDef scope_def ->
let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in
let fmt_meth_name fmt _ = let fmt_meth_name fmt _ =
Format.fprintf fmt "method %a : (%a Js.t -> %a Js.t) Js.callback" Format.fprintf fmt "method %a : (%a Js.t -> %a Js.t) Js.callback"

View File

@ -22,6 +22,7 @@ let extension = "_schema.json"
open Utils open Utils
open String_common open String_common
open Shared_ast
open Lcalc.Ast open Lcalc.Ast
open Lcalc.To_ocaml open Lcalc.To_ocaml
module D = Dcalc.Ast module D = Dcalc.Ast
@ -37,9 +38,9 @@ module To_json = struct
let format_struct_field_name_camel_case let format_struct_field_name_camel_case
(fmt : Format.formatter) (fmt : Format.formatter)
(v : Dcalc.Ast.StructFieldName.t) : unit = (v : StructFieldName.t) : unit =
let s = let s =
Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v Format.asprintf "%a" StructFieldName.format_t v
|> to_ascii |> to_ascii
|> to_snake_case |> to_snake_case
|> avoid_keywords |> avoid_keywords
@ -48,18 +49,18 @@ module To_json = struct
Format.fprintf fmt "%s" s Format.fprintf fmt "%s" s
let rec find_scope_def (target_name : string) : let rec find_scope_def (target_name : string) :
('m expr, 'm) D.scopes -> ('m expr, 'm) D.scope_def option = function ('m expr, 'm) scopes -> ('m expr, 'm) scope_def option = function
| D.Nil -> None | Nil -> None
| D.ScopeDef scope_def -> | ScopeDef scope_def ->
let name = let name =
Format.asprintf "%a" D.ScopeName.format_t scope_def.scope_name Format.asprintf "%a" ScopeName.format_t scope_def.scope_name
in in
if name = target_name then Some scope_def if name = target_name then Some scope_def
else else
let _, next_scope = Bindlib.unbind scope_def.scope_next in let _, next_scope = Bindlib.unbind scope_def.scope_next in
find_scope_def target_name next_scope find_scope_def target_name next_scope
let fmt_tlit fmt (tlit : D.typ_lit) = let fmt_tlit fmt (tlit : typ_lit) =
match tlit with match tlit with
| TUnit -> Format.fprintf fmt "\"type\": \"null\",@\n\"default\": null" | TUnit -> Format.fprintf fmt "\"type\": \"null\",@\n\"default\": null"
| TInt | TRat -> Format.fprintf fmt "\"type\": \"number\",@\n\"default\": 0" | TInt | TRat -> Format.fprintf fmt "\"type\": \"number\",@\n\"default\": 0"
@ -70,15 +71,15 @@ module To_json = struct
| TDate -> Format.fprintf fmt "\"type\": \"string\",@\n\"format\": \"date\"" | TDate -> Format.fprintf fmt "\"type\": \"string\",@\n\"format\": \"date\""
| TDuration -> failwith "TODO: tlit duration" | TDuration -> failwith "TODO: tlit duration"
let rec fmt_type fmt (typ : D.marked_typ) = let rec fmt_type fmt (typ : marked_typ) =
match Marked.unmark typ with match Marked.unmark typ with
| D.TLit tlit -> fmt_tlit fmt tlit | TLit tlit -> fmt_tlit fmt tlit
| D.TTuple (_, Some sname) -> | TTuple (_, Some sname) ->
Format.fprintf fmt "\"$ref\": \"#/definitions/%a\"" format_struct_name Format.fprintf fmt "\"$ref\": \"#/definitions/%a\"" format_struct_name
sname sname
| D.TEnum (_, ename) -> | TEnum (_, ename) ->
Format.fprintf fmt "\"$ref\": \"#/definitions/%a\"" format_enum_name ename Format.fprintf fmt "\"$ref\": \"#/definitions/%a\"" format_enum_name ename
| D.TArray t -> | TArray t ->
Format.fprintf fmt Format.fprintf fmt
"\"type\": \"array\",@\n\ "\"type\": \"array\",@\n\
\"default\": [],@\n\ \"default\": [],@\n\
@ -89,9 +90,9 @@ module To_json = struct
| _ -> () | _ -> ()
let fmt_struct_properties let fmt_struct_properties
(ctx : D.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(sname : D.StructName.t) = (sname : StructName.t) =
Format.fprintf fmt "%a" Format.fprintf fmt "%a"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n")
@ -101,26 +102,26 @@ module To_json = struct
(find_struct sname ctx) (find_struct sname ctx)
let fmt_definitions let fmt_definitions
(ctx : D.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(scope_def : ('m expr, 'm) D.scope_def) = (scope_def : ('m expr, 'm) scope_def) =
let get_name t = let get_name t =
match Marked.unmark t with match Marked.unmark t with
| D.TTuple (_, Some sname) -> | TTuple (_, Some sname) ->
Format.asprintf "%a" format_struct_name sname Format.asprintf "%a" format_struct_name sname
| D.TEnum (_, ename) -> Format.asprintf "%a" format_enum_name ename | TEnum (_, ename) -> Format.asprintf "%a" format_enum_name ename
| _ -> failwith "unreachable: only structs and enums are collected." | _ -> failwith "unreachable: only structs and enums are collected."
in in
let rec collect_required_type_defs_from_scope_input let rec collect_required_type_defs_from_scope_input
(input_struct : D.StructName.t) : D.marked_typ list = (input_struct : StructName.t) : marked_typ list =
let rec collect (acc : D.marked_typ list) (t : D.marked_typ) : let rec collect (acc : marked_typ list) (t : marked_typ) :
D.marked_typ list = marked_typ list =
match Marked.unmark t with match Marked.unmark t with
| D.TTuple (_, Some s) -> | TTuple (_, Some s) ->
(* Scope's input is a struct. *) (* Scope's input is a struct. *)
(t :: acc) @ collect_required_type_defs_from_scope_input s (t :: acc) @ collect_required_type_defs_from_scope_input s
| D.TEnum (ts, _) -> List.fold_left collect (t :: acc) ts | TEnum (ts, _) -> List.fold_left collect (t :: acc) ts
| D.TArray t -> collect acc t | TArray t -> collect acc t
| _ -> acc | _ -> acc
in in
find_struct input_struct ctx find_struct input_struct ctx
@ -177,7 +178,7 @@ module To_json = struct
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n")
(fun fmt typ -> (fun fmt typ ->
match Marked.unmark typ with match Marked.unmark typ with
| D.TTuple (_, Some sname) -> | TTuple (_, Some sname) ->
Format.fprintf fmt Format.fprintf fmt
"@[<hov 2>\"%a\": {@\n\ "@[<hov 2>\"%a\": {@\n\
\"type\": \"object\",@\n\ \"type\": \"object\",@\n\
@ -188,7 +189,7 @@ module To_json = struct
format_struct_name sname format_struct_name sname
(fmt_struct_properties ctx) (fmt_struct_properties ctx)
sname sname
| D.TEnum (_, ename) -> | TEnum (_, ename) ->
Format.fprintf fmt Format.fprintf fmt
"@[<hov 2>\"%a\": {@\n\ "@[<hov 2>\"%a\": {@\n\
\"type\": \"object\",@\n\ \"type\": \"object\",@\n\

View File

@ -15,6 +15,7 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
module D = Dcalc.Ast module D = Dcalc.Ast
module L = Lcalc.Ast module L = Lcalc.Ast
module TopLevelName = Uid.Make (Uid.MarkedString) () module TopLevelName = Uid.Make (Uid.MarkedString) ()
@ -27,24 +28,24 @@ let handle_default_opt = TopLevelName.fresh ("handle_default_opt", Pos.no_pos)
type expr = type expr =
| EVar of LocalName.t | EVar of LocalName.t
| EFunc of TopLevelName.t | EFunc of TopLevelName.t
| EStruct of expr Marked.pos list * D.StructName.t | EStruct of expr Marked.pos list * StructName.t
| EStructFieldAccess of expr Marked.pos * D.StructFieldName.t * D.StructName.t | EStructFieldAccess of expr Marked.pos * StructFieldName.t * StructName.t
| EInj of expr Marked.pos * D.EnumConstructor.t * D.EnumName.t | EInj of expr Marked.pos * EnumConstructor.t * EnumName.t
| EArray of expr Marked.pos list | EArray of expr Marked.pos list
| ELit of L.lit | ELit of L.lit
| EApp of expr Marked.pos * expr Marked.pos list | EApp of expr Marked.pos * expr Marked.pos list
| EOp of Dcalc.Ast.operator | EOp of operator
type stmt = type stmt =
| SInnerFuncDef of LocalName.t Marked.pos * func | SInnerFuncDef of LocalName.t Marked.pos * func
| SLocalDecl of LocalName.t Marked.pos * D.typ Marked.pos | SLocalDecl of LocalName.t Marked.pos * typ Marked.pos
| SLocalDef of LocalName.t Marked.pos * expr Marked.pos | SLocalDef of LocalName.t Marked.pos * expr Marked.pos
| STryExcept of block * L.except * block | STryExcept of block * except * block
| SRaise of L.except | SRaise of except
| SIfThenElse of expr Marked.pos * block * block | SIfThenElse of expr Marked.pos * block * block
| SSwitch of | SSwitch of
expr Marked.pos expr Marked.pos
* D.EnumName.t * EnumName.t
* (block (* Statements corresponding to arm closure body*) * (block (* Statements corresponding to arm closure body*)
* (* Variable instantiated with enum payload *) LocalName.t) * (* Variable instantiated with enum payload *) LocalName.t)
list (** Each block corresponds to one case of the enum *) list (** Each block corresponds to one case of the enum *)
@ -54,14 +55,14 @@ type stmt =
and block = stmt Marked.pos list and block = stmt Marked.pos list
and func = { and func = {
func_params : (LocalName.t Marked.pos * D.typ Marked.pos) list; func_params : (LocalName.t Marked.pos * typ Marked.pos) list;
func_body : block; func_body : block;
} }
type scope_body = { type scope_body = {
scope_body_name : Dcalc.Ast.ScopeName.t; scope_body_name : ScopeName.t;
scope_body_var : TopLevelName.t; scope_body_var : TopLevelName.t;
scope_body_func : func; scope_body_func : func;
} }
type program = { decl_ctx : D.decl_ctx; scopes : scope_body list } type program = { decl_ctx : decl_ctx; scopes : scope_body list }

View File

@ -22,7 +22,7 @@ module D = Dcalc.Ast
type 'm ctxt = { type 'm ctxt = {
func_dict : ('m L.expr, A.TopLevelName.t) Var.Map.t; func_dict : ('m L.expr, A.TopLevelName.t) Var.Map.t;
decl_ctx : D.decl_ctx; decl_ctx : decl_ctx;
var_dict : ('m L.expr, A.LocalName.t) Var.Map.t; var_dict : ('m L.expr, A.LocalName.t) Var.Map.t;
inside_definition_of : A.LocalName.t option; inside_definition_of : A.LocalName.t option;
context_name : string; context_name : string;
@ -33,13 +33,13 @@ type 'm ctxt = {
let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) : let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) :
A.block * A.expr Marked.pos = A.block * A.expr Marked.pos =
match Marked.unmark expr with match Marked.unmark expr with
| L.EVar v -> | EVar v ->
let local_var = let local_var =
try A.EVar (Var.Map.find v ctxt.var_dict) try A.EVar (Var.Map.find v ctxt.var_dict)
with Not_found -> A.EFunc (Var.Map.find v ctxt.func_dict) with Not_found -> A.EFunc (Var.Map.find v ctxt.func_dict)
in in
[], (local_var, D.pos expr) [], (local_var, Expr.pos expr)
| L.ETuple (args, Some s_name) -> | ETuple (args, Some s_name) ->
let args_stmts, new_args = let args_stmts, new_args =
List.fold_left List.fold_left
(fun (args_stmts, new_args) arg -> (fun (args_stmts, new_args) arg ->
@ -49,25 +49,25 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) :
in in
let new_args = List.rev new_args in let new_args = List.rev new_args in
let args_stmts = List.rev args_stmts in let args_stmts = List.rev args_stmts in
args_stmts, (A.EStruct (new_args, s_name), D.pos expr) args_stmts, (A.EStruct (new_args, s_name), Expr.pos expr)
| L.ETuple (_, None) -> | ETuple (_, None) ->
failwith "Non-struct tuples cannot be compiled to scalc" failwith "Non-struct tuples cannot be compiled to scalc"
| L.ETupleAccess (e1, num_field, Some s_name, _) -> | ETupleAccess (e1, num_field, Some s_name, _) ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in let e1_stmts, new_e1 = translate_expr ctxt e1 in
let field_name = let field_name =
fst fst
(List.nth (D.StructMap.find s_name ctxt.decl_ctx.ctx_structs) num_field) (List.nth (StructMap.find s_name ctxt.decl_ctx.ctx_structs) num_field)
in in
e1_stmts, (A.EStructFieldAccess (new_e1, field_name, s_name), D.pos expr) e1_stmts, (A.EStructFieldAccess (new_e1, field_name, s_name), Expr.pos expr)
| L.ETupleAccess (_, _, None, _) -> | ETupleAccess (_, _, None, _) ->
failwith "Non-struct tuples cannot be compiled to scalc" failwith "Non-struct tuples cannot be compiled to scalc"
| L.EInj (e1, num_cons, e_name, _) -> | EInj (e1, num_cons, e_name, _) ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in let e1_stmts, new_e1 = translate_expr ctxt e1 in
let cons_name = let cons_name =
fst (List.nth (D.EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons) fst (List.nth (EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons)
in in
e1_stmts, (A.EInj (new_e1, cons_name, e_name), D.pos expr) e1_stmts, (A.EInj (new_e1, cons_name, e_name), Expr.pos expr)
| L.EApp (f, args) -> | EApp (f, args) ->
let f_stmts, new_f = translate_expr ctxt f in let f_stmts, new_f = translate_expr ctxt f in
let args_stmts, new_args = let args_stmts, new_args =
List.fold_left List.fold_left
@ -77,8 +77,8 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) :
([], []) args ([], []) args
in in
let new_args = List.rev new_args in let new_args = List.rev new_args in
f_stmts @ args_stmts, (A.EApp (new_f, new_args), D.pos expr) f_stmts @ args_stmts, (A.EApp (new_f, new_args), Expr.pos expr)
| L.EArray args -> | EArray args ->
let args_stmts, new_args = let args_stmts, new_args =
List.fold_left List.fold_left
(fun (args_stmts, new_args) arg -> (fun (args_stmts, new_args) arg ->
@ -87,9 +87,9 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) :
([], []) args ([], []) args
in in
let new_args = List.rev new_args in let new_args = List.rev new_args in
args_stmts, (A.EArray new_args, D.pos expr) args_stmts, (A.EArray new_args, Expr.pos expr)
| L.EOp op -> [], (A.EOp op, D.pos expr) | EOp op -> [], (A.EOp op, Expr.pos expr)
| L.ELit l -> [], (A.ELit l, D.pos expr) | ELit l -> [], (A.ELit l, Expr.pos expr)
| _ -> | _ ->
let tmp_var = let tmp_var =
A.LocalName.fresh A.LocalName.fresh
@ -102,7 +102,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) :
let v = Marked.unmark (A.LocalName.get_info v) in let v = Marked.unmark (A.LocalName.get_info v) in
let tmp_rex = Re.Pcre.regexp "^temp_" in let tmp_rex = Re.Pcre.regexp "^temp_" in
if Re.Pcre.pmatch ~rex:tmp_rex v then v else "temp_" ^ v), if Re.Pcre.pmatch ~rex:tmp_rex v then v else "temp_" ^ v),
D.pos expr ) Expr.pos expr )
in in
let ctxt = let ctxt =
{ {
@ -112,20 +112,20 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) :
} }
in in
let tmp_stmts = translate_statements ctxt expr in let tmp_stmts = translate_statements ctxt expr in
( (A.SLocalDecl ((tmp_var, D.pos expr), (D.TAny, D.pos expr)), D.pos expr) ( (A.SLocalDecl ((tmp_var, Expr.pos expr), (TAny, Expr.pos expr)), Expr.pos expr)
:: tmp_stmts, :: tmp_stmts,
(A.EVar tmp_var, D.pos expr) ) (A.EVar tmp_var, Expr.pos expr) )
and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) : and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) :
A.block = A.block =
match Marked.unmark block_expr with match Marked.unmark block_expr with
| L.EAssert e -> | EAssert e ->
(* Assertions are always encapsulated in a unit-typed let binding *) (* Assertions are always encapsulated in a unit-typed let binding *)
let e_stmts, new_e = translate_expr ctxt e in let e_stmts, new_e = translate_expr ctxt e in
e_stmts @ [A.SAssert (Marked.unmark new_e), D.pos block_expr] e_stmts @ [A.SAssert (Marked.unmark new_e), Expr.pos block_expr]
| L.EApp ((L.EAbs (binder, taus), binder_mark), args) -> | EApp ((EAbs (binder, taus), binder_mark), args) ->
(* This defines multiple local variables at the time *) (* This defines multiple local variables at the time *)
let binder_pos = D.mark_pos binder_mark in let binder_pos = Expr.mark_pos binder_mark in
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in
let ctxt = let ctxt =
@ -170,13 +170,13 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) :
in in
let rest_of_block = translate_statements ctxt body in let rest_of_block = translate_statements ctxt body in
local_decls @ List.flatten def_blocks @ rest_of_block local_decls @ List.flatten def_blocks @ rest_of_block
| L.EAbs (binder, taus) -> | EAbs (binder, taus) ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let binder_pos = D.pos block_expr in let binder_pos = Expr.pos block_expr in
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in
let closure_name = let closure_name =
match ctxt.inside_definition_of with match ctxt.inside_definition_of with
| None -> A.LocalName.fresh (ctxt.context_name, D.pos block_expr) | None -> A.LocalName.fresh (ctxt.context_name, Expr.pos block_expr)
| Some x -> x | Some x -> x
in in
let ctxt = let ctxt =
@ -206,18 +206,18 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) :
} ), } ),
binder_pos ); binder_pos );
] ]
| L.EMatch (e1, args, e_name) -> | EMatch (e1, args, e_name) ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in let e1_stmts, new_e1 = translate_expr ctxt e1 in
let new_args = let new_args =
List.fold_left List.fold_left
(fun new_args arg -> (fun new_args arg ->
match Marked.unmark arg with match Marked.unmark arg with
| L.EAbs (binder, _) -> | EAbs (binder, _) ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
assert (Array.length vars = 1); assert (Array.length vars = 1);
let var = vars.(0) in let var = vars.(0) in
let scalc_var = let scalc_var =
A.LocalName.fresh (Bindlib.name_of var, D.pos arg) A.LocalName.fresh (Bindlib.name_of var, Expr.pos arg)
in in
let ctxt = let ctxt =
{ ctxt with var_dict = Var.Map.add var scalc_var ctxt.var_dict } { ctxt with var_dict = Var.Map.add var scalc_var ctxt.var_dict }
@ -229,17 +229,17 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) :
[] args [] args
in in
let new_args = List.rev new_args in let new_args = List.rev new_args in
e1_stmts @ [A.SSwitch (new_e1, e_name, new_args), D.pos block_expr] e1_stmts @ [A.SSwitch (new_e1, e_name, new_args), Expr.pos block_expr]
| L.EIfThenElse (cond, e_true, e_false) -> | EIfThenElse (cond, e_true, e_false) ->
let cond_stmts, s_cond = translate_expr ctxt cond in let cond_stmts, s_cond = translate_expr ctxt cond in
let s_e_true = translate_statements ctxt e_true in let s_e_true = translate_statements ctxt e_true in
let s_e_false = translate_statements ctxt e_false in let s_e_false = translate_statements ctxt e_false in
cond_stmts @ [A.SIfThenElse (s_cond, s_e_true, s_e_false), D.pos block_expr] cond_stmts @ [A.SIfThenElse (s_cond, s_e_true, s_e_false), Expr.pos block_expr]
| L.ECatch (e_try, except, e_catch) -> | ECatch (e_try, except, e_catch) ->
let s_e_try = translate_statements ctxt e_try in let s_e_try = translate_statements ctxt e_try in
let s_e_catch = translate_statements ctxt e_catch in let s_e_catch = translate_statements ctxt e_catch in
[A.STryExcept (s_e_try, except, s_e_catch), D.pos block_expr] [A.STryExcept (s_e_try, except, s_e_catch), Expr.pos block_expr]
| L.ERaise except -> | ERaise except ->
(* Before raising the exception, we still give a dummy definition to the (* Before raising the exception, we still give a dummy definition to the
current variable so that tools like mypy don't complain. *) current variable so that tools like mypy don't complain. *)
(match ctxt.inside_definition_of with (match ctxt.inside_definition_of with
@ -247,10 +247,10 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) :
| Some x -> | Some x ->
[ [
( A.SLocalDef ( A.SLocalDef
((x, D.pos block_expr), (Ast.EVar Ast.dead_value, D.pos block_expr)), ((x, Expr.pos block_expr), (Ast.EVar Ast.dead_value, Expr.pos block_expr)),
D.pos block_expr ); Expr.pos block_expr );
]) ])
@ [A.SRaise except, D.pos block_expr] @ [A.SRaise except, Expr.pos block_expr]
| _ -> ( | _ -> (
let e_stmts, new_e = translate_expr ctxt block_expr in let e_stmts, new_e = translate_expr ctxt block_expr in
e_stmts e_stmts
@ -266,15 +266,15 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) :
( (match ctxt.inside_definition_of with ( (match ctxt.inside_definition_of with
| None -> A.SReturn (Marked.unmark new_e) | None -> A.SReturn (Marked.unmark new_e)
| Some x -> A.SLocalDef (Marked.same_mark_as x new_e, new_e)), | Some x -> A.SLocalDef (Marked.same_mark_as x new_e, new_e)),
D.pos block_expr ); Expr.pos block_expr );
]) ])
let rec translate_scope_body_expr let rec translate_scope_body_expr
(scope_name : D.ScopeName.t) (scope_name : ScopeName.t)
(decl_ctx : D.decl_ctx) (decl_ctx : decl_ctx)
(var_dict : ('m L.expr, A.LocalName.t) Var.Map.t) (var_dict : ('m L.expr, A.LocalName.t) Var.Map.t)
(func_dict : ('m L.expr, A.TopLevelName.t) Var.Map.t) (func_dict : ('m L.expr, A.TopLevelName.t) Var.Map.t)
(scope_expr : ('m L.expr, 'm) D.scope_body_expr) : A.block = (scope_expr : ('m L.expr, 'm) scope_body_expr) : A.block =
match scope_expr with match scope_expr with
| Result e -> | Result e ->
let block, new_e = let block, new_e =
@ -284,7 +284,7 @@ let rec translate_scope_body_expr
func_dict; func_dict;
var_dict; var_dict;
inside_definition_of = None; inside_definition_of = None;
context_name = Marked.unmark (D.ScopeName.get_info scope_name); context_name = Marked.unmark (ScopeName.get_info scope_name);
} }
e e
in in
@ -296,14 +296,14 @@ let rec translate_scope_body_expr
in in
let new_var_dict = Var.Map.add let_var let_var_id var_dict in let new_var_dict = Var.Map.add let_var let_var_id var_dict in
(match scope_let.scope_let_kind with (match scope_let.scope_let_kind with
| D.Assertion -> | Assertion ->
translate_statements translate_statements
{ {
decl_ctx; decl_ctx;
func_dict; func_dict;
var_dict; var_dict;
inside_definition_of = Some let_var_id; inside_definition_of = Some let_var_id;
context_name = Marked.unmark (D.ScopeName.get_info scope_name); context_name = Marked.unmark (ScopeName.get_info scope_name);
} }
scope_let.scope_let_expr scope_let.scope_let_expr
| _ -> | _ ->
@ -314,7 +314,7 @@ let rec translate_scope_body_expr
func_dict; func_dict;
var_dict; var_dict;
inside_definition_of = Some let_var_id; inside_definition_of = Some let_var_id;
context_name = Marked.unmark (D.ScopeName.get_info scope_name); context_name = Marked.unmark (ScopeName.get_info scope_name);
} }
scope_let.scope_let_expr scope_let.scope_let_expr
in in
@ -331,16 +331,16 @@ let rec translate_scope_body_expr
let translate_program (p : 'm L.program) : A.program = let translate_program (p : 'm L.program) : A.program =
{ {
decl_ctx = p.D.decl_ctx; decl_ctx = p.decl_ctx;
scopes = scopes =
(let _, new_scopes = (let _, new_scopes =
D.fold_left_scope_defs Expr.fold_left_scope_defs
~f:(fun (func_dict, new_scopes) scope_def scope_var -> ~f:(fun (func_dict, new_scopes) scope_def scope_var ->
let scope_input_var, scope_body_expr = let scope_input_var, scope_body_expr =
Bindlib.unbind scope_def.scope_body.scope_body_expr Bindlib.unbind scope_def.scope_body.scope_body_expr
in in
let input_pos = let input_pos =
Marked.get_mark (D.ScopeName.get_info scope_def.scope_name) Marked.get_mark (ScopeName.get_info scope_def.scope_name)
in in
let scope_input_var_id = let scope_input_var_id =
A.LocalName.fresh (Bindlib.name_of scope_input_var, input_pos) A.LocalName.fresh (Bindlib.name_of scope_input_var, input_pos)
@ -349,7 +349,7 @@ let translate_program (p : 'm L.program) : A.program =
Var.Map.singleton scope_input_var scope_input_var_id Var.Map.singleton scope_input_var scope_input_var_id
in in
let new_scope_body = let new_scope_body =
translate_scope_body_expr scope_def.D.scope_name p.decl_ctx translate_scope_body_expr scope_def.scope_name p.decl_ctx
var_dict func_dict scope_body_expr var_dict func_dict scope_body_expr
in in
let func_id = let func_id =
@ -358,22 +358,22 @@ let translate_program (p : 'm L.program) : A.program =
let func_dict = Var.Map.add scope_var func_id func_dict in let func_dict = Var.Map.add scope_var func_id func_dict in
( func_dict, ( func_dict,
{ {
Ast.scope_body_name = scope_def.D.scope_name; Ast.scope_body_name = scope_def.scope_name;
Ast.scope_body_var = func_id; Ast.scope_body_var = func_id;
scope_body_func = scope_body_func =
{ {
A.func_params = A.func_params =
[ [
( (scope_input_var_id, input_pos), ( (scope_input_var_id, input_pos),
( D.TTuple ( TTuple
( List.map snd ( List.map snd
(D.StructMap.find (StructMap.find
scope_def.D.scope_body scope_def.scope_body
.D.scope_body_input_struct .scope_body_input_struct
p.D.decl_ctx.ctx_structs), p.decl_ctx.ctx_structs),
Some Some
scope_def.D.scope_body scope_def.scope_body
.D.scope_body_input_struct ), .scope_body_input_struct ),
input_pos ) ); input_pos ) );
]; ];
A.func_body = new_scope_body; A.func_body = new_scope_body;
@ -385,7 +385,7 @@ let translate_program (p : 'm L.program) : A.program =
Var.Map.singleton L.handle_default_opt A.handle_default_opt Var.Map.singleton L.handle_default_opt A.handle_default_opt
else Var.Map.singleton L.handle_default A.handle_default), else Var.Map.singleton L.handle_default A.handle_default),
[] ) [] )
p.D.scopes p.scopes
in in
List.rev new_scopes); List.rev new_scopes);
} }

View File

@ -15,6 +15,7 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Ast open Ast
let needs_parens (_e : expr Marked.pos) : bool = false let needs_parens (_e : expr Marked.pos) : bool = false
@ -24,7 +25,7 @@ let format_local_name (fmt : Format.formatter) (v : LocalName.t) : unit =
(string_of_int (LocalName.hash v)) (string_of_int (LocalName.hash v))
let rec format_expr let rec format_expr
(decl_ctx : Dcalc.Ast.decl_ctx) (decl_ctx : decl_ctx)
?(debug : bool = false) ?(debug : bool = false)
(fmt : Format.formatter) (fmt : Format.formatter)
(e : expr Marked.pos) : unit = (e : expr Marked.pos) : unit =
@ -39,17 +40,17 @@ let rec format_expr
| EVar v -> Format.fprintf fmt "%a" format_local_name v | EVar v -> Format.fprintf fmt "%a" format_local_name v
| EFunc v -> Format.fprintf fmt "%a" TopLevelName.format_t v | EFunc v -> Format.fprintf fmt "%a" TopLevelName.format_t v
| EStruct (es, s) -> | EStruct (es, s) ->
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" StructName.format_t s
Dcalc.Print.format_punctuation "{" Dcalc.Print.format_punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, struct_field) -> (fun fmt (e, struct_field) ->
Format.fprintf fmt "%a%a%a%a %a" Dcalc.Print.format_punctuation "\"" Format.fprintf fmt "%a%a%a%a %a" Dcalc.Print.format_punctuation "\""
Dcalc.Ast.StructFieldName.format_t struct_field StructFieldName.format_t struct_field
Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation
":" format_expr e)) ":" format_expr e))
(List.combine es (List.combine es
(List.map fst (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs))) (List.map fst (StructMap.find s decl_ctx.ctx_structs)))
Dcalc.Print.format_punctuation "}" Dcalc.Print.format_punctuation "}"
| EArray es -> | EArray es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" Dcalc.Print.format_punctuation "[" Format.fprintf fmt "@[<hov 2>%a%a%a@]" Dcalc.Print.format_punctuation "["
@ -60,24 +61,24 @@ let rec format_expr
| EStructFieldAccess (e1, field, s) -> | EStructFieldAccess (e1, field, s) ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Format.fprintf fmt "%a%a%a%a%a" format_expr e1
Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\""
Dcalc.Ast.StructFieldName.format_t StructFieldName.format_t
(fst (fst
(List.find (List.find
(fun (field', _) -> (fun (field', _) ->
Dcalc.Ast.StructFieldName.compare field' field = 0) StructFieldName.compare field' field = 0)
(Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs))) (StructMap.find s decl_ctx.ctx_structs)))
Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation "\""
| EInj (e, case, enum) -> | EInj (e, case, enum) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_enum_constructor Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_enum_constructor
(fst (fst
(List.find (List.find
(fun (case', _) -> Dcalc.Ast.EnumConstructor.compare case' case = 0) (fun (case', _) -> EnumConstructor.compare case' case = 0)
(Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums))) (EnumMap.find enum decl_ctx.ctx_enums)))
format_expr e format_expr e
| ELit l -> | ELit l ->
Format.fprintf fmt "%a" Lcalc.Print.format_lit (Marked.same_mark_as l e) Format.fprintf fmt "%a" Lcalc.Print.format_lit (Marked.same_mark_as l e)
| EApp | EApp
((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2])
-> ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Dcalc.Print.format_binop op Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Dcalc.Print.format_binop op
format_with_parens arg1 format_with_parens arg2 format_with_parens arg1 format_with_parens arg2
@ -100,7 +101,7 @@ let rec format_expr
| EOp (Unop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_unop op | EOp (Unop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_unop op
let rec format_statement let rec format_statement
(decl_ctx : Dcalc.Ast.decl_ctx) (decl_ctx : decl_ctx)
?(debug : bool = false) ?(debug : bool = false)
(fmt : Format.formatter) (fmt : Format.formatter)
(stmt : stmt Marked.pos) : unit = (stmt : stmt Marked.pos) : unit =
@ -174,10 +175,10 @@ let rec format_statement
Dcalc.Print.format_punctuation "" Dcalc.Print.format_punctuation ""
(format_block decl_ctx ~debug) (format_block decl_ctx ~debug)
arm_block)) arm_block))
(List.combine (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums) arms) (List.combine (EnumMap.find enum decl_ctx.ctx_enums) arms)
and format_block and format_block
(decl_ctx : Dcalc.Ast.decl_ctx) (decl_ctx : decl_ctx)
?(debug : bool = false) ?(debug : bool = false)
(fmt : Format.formatter) (fmt : Format.formatter)
(block : block) : unit = (block : block) : unit =
@ -188,7 +189,7 @@ and format_block
fmt block fmt block
let format_scope let format_scope
(decl_ctx : Dcalc.Ast.decl_ctx) (decl_ctx : decl_ctx)
?(debug : bool = false) ?(debug : bool = false)
(fmt : Format.formatter) (fmt : Format.formatter)
(body : scope_body) : unit = (body : scope_body) : unit =

View File

@ -15,7 +15,7 @@
the License. *) the License. *)
val format_scope : val format_scope :
Dcalc.Ast.decl_ctx -> Shared_ast.decl_ctx ->
?debug:bool -> ?debug:bool ->
Format.formatter -> Format.formatter ->
Ast.scope_body -> Ast.scope_body ->

View File

@ -16,6 +16,7 @@
[@@@warning "-32-27"] [@@@warning "-32-27"]
open Utils open Utils
open Shared_ast
open Ast open Ast
open String_common open String_common
module Runtime = Runtime_ocaml.Runtime module Runtime = Runtime_ocaml.Runtime
@ -31,7 +32,7 @@ let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit =
| LUnit -> Format.fprintf fmt "Unit()" | LUnit -> Format.fprintf fmt "Unit()"
| LRat i -> | LRat i ->
Format.fprintf fmt "decimal_of_string(\"%a\")" Dcalc.Print.format_lit Format.fprintf fmt "decimal_of_string(\"%a\")" Dcalc.Print.format_lit
(Dcalc.Ast.LRat i) (LRat i)
| LMoney e -> | LMoney e ->
Format.fprintf fmt "money_of_cents_string(\"%s\")" Format.fprintf fmt "money_of_cents_string(\"%s\")"
(Runtime.integer_to_string (Runtime.money_to_cents e)) (Runtime.integer_to_string (Runtime.money_to_cents e))
@ -44,7 +45,7 @@ let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit =
let years, months, days = Runtime.duration_to_years_months_days d in let years, months, days = Runtime.duration_to_years_months_days d in
Format.fprintf fmt "duration_of_numbers(%d,%d,%d)" years months days Format.fprintf fmt "duration_of_numbers(%d,%d,%d)" years months days
let format_log_entry (fmt : Format.formatter) (entry : Dcalc.Ast.log_entry) : let format_log_entry (fmt : Format.formatter) (entry : log_entry) :
unit = unit =
match entry with match entry with
| VarDef _ -> Format.fprintf fmt ":=" | VarDef _ -> Format.fprintf fmt ":="
@ -52,13 +53,13 @@ let format_log_entry (fmt : Format.formatter) (entry : Dcalc.Ast.log_entry) :
| EndCall -> Format.fprintf fmt "%s" "" | EndCall -> Format.fprintf fmt "%s" ""
| PosRecordIfTrueBool -> Format.fprintf fmt "" | PosRecordIfTrueBool -> Format.fprintf fmt ""
let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Marked.pos) : let format_binop (fmt : Format.formatter) (op : binop Marked.pos) :
unit = unit =
match Marked.unmark op with match Marked.unmark op with
| Add _ | Concat -> Format.fprintf fmt "+" | Add _ | Concat -> Format.fprintf fmt "+"
| Sub _ -> Format.fprintf fmt "-" | Sub _ -> Format.fprintf fmt "-"
| Mult _ -> Format.fprintf fmt "*" | Mult _ -> Format.fprintf fmt "*"
| Div D.KInt -> Format.fprintf fmt "//" | Div KInt -> Format.fprintf fmt "//"
| Div _ -> Format.fprintf fmt "/" | Div _ -> Format.fprintf fmt "/"
| And -> Format.fprintf fmt "and" | And -> Format.fprintf fmt "and"
| Or -> Format.fprintf fmt "or" | Or -> Format.fprintf fmt "or"
@ -71,7 +72,7 @@ let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Marked.pos) :
| Map -> Format.fprintf fmt "list_map" | Map -> Format.fprintf fmt "list_map"
| Filter -> Format.fprintf fmt "list_filter" | Filter -> Format.fprintf fmt "list_filter"
let format_ternop (fmt : Format.formatter) (op : Dcalc.Ast.ternop Marked.pos) : let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) :
unit = unit =
match Marked.unmark op with Fold -> Format.fprintf fmt "list_fold_left" match Marked.unmark op with Fold -> Format.fprintf fmt "list_fold_left"
@ -94,7 +95,7 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info))) (Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
uids uids
let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Marked.pos) : unit let format_unop (fmt : Format.formatter) (op : unop Marked.pos) : unit
= =
match Marked.unmark op with match Marked.unmark op with
| Minus _ -> Format.fprintf fmt "-" | Minus _ -> Format.fprintf fmt "-"
@ -127,43 +128,43 @@ let avoid_keywords (s : string) : string =
then s ^ "_" then s ^ "_"
else s else s
let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) : let format_struct_name (fmt : Format.formatter) (v : StructName.t) :
unit = unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_camel_case (to_camel_case
(to_ascii (Format.asprintf "%a" Dcalc.Ast.StructName.format_t v)))) (to_ascii (Format.asprintf "%a" StructName.format_t v))))
let format_struct_field_name let format_struct_field_name
(fmt : Format.formatter) (fmt : Format.formatter)
(v : Dcalc.Ast.StructFieldName.t) : unit = (v : StructFieldName.t) : unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_ascii (Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v))) (to_ascii (Format.asprintf "%a" StructFieldName.format_t v)))
let format_enum_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumName.t) : unit let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit
= =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_camel_case (to_camel_case
(to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumName.format_t v)))) (to_ascii (Format.asprintf "%a" EnumName.format_t v))))
let format_enum_cons_name let format_enum_cons_name
(fmt : Format.formatter) (fmt : Format.formatter)
(v : Dcalc.Ast.EnumConstructor.t) : unit = (v : EnumConstructor.t) : unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v))) (to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
let typ_needs_parens (e : Dcalc.Ast.typ Marked.pos) : bool = let typ_needs_parens (e : typ Marked.pos) : bool =
match Marked.unmark e with TArrow _ | TArray _ -> true | _ -> false match Marked.unmark e with TArrow _ | TArray _ -> true | _ -> false
let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Marked.pos) : let rec format_typ (fmt : Format.formatter) (typ : typ Marked.pos) :
unit = unit =
let format_typ = format_typ in let format_typ = format_typ in
let format_typ_with_parens let format_typ_with_parens
(fmt : Format.formatter) (fmt : Format.formatter)
(t : Dcalc.Ast.typ Marked.pos) = (t : typ Marked.pos) =
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
else Format.fprintf fmt "%a" format_typ t else Format.fprintf fmt "%a" format_typ t
in in
@ -182,7 +183,7 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Marked.pos) :
(fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t)) (fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t))
ts ts
| TTuple (_, Some s) -> Format.fprintf fmt "%a" format_struct_name s | TTuple (_, Some s) -> Format.fprintf fmt "%a" format_struct_name s
| TEnum ([_; some_typ], e) when D.EnumName.compare e L.option_enum = 0 -> | TEnum ([_; some_typ], e) when EnumName.compare e L.option_enum = 0 ->
(* We translate the option type with an overloading by Python's [None] *) (* We translate the option type with an overloading by Python's [None] *)
Format.fprintf fmt "Optional[%a]" format_typ some_typ Format.fprintf fmt "Optional[%a]" format_typ some_typ
| TEnum (_, e) -> Format.fprintf fmt "%a" format_enum_name e | TEnum (_, e) -> Format.fprintf fmt "%a" format_enum_name e
@ -251,7 +252,7 @@ let needs_parens (e : expr Marked.pos) : bool =
| ELit (LBool _ | LUnit) | EVar _ | EOp _ -> false | ELit (LBool _ | LUnit) | EVar _ | EOp _ -> false
| _ -> true | _ -> true
let format_exception (fmt : Format.formatter) (exc : L.except Marked.pos) : unit let format_exception (fmt : Format.formatter) (exc : except Marked.pos) : unit
= =
let pos = Marked.get_mark exc in let pos = Marked.get_mark exc in
match Marked.unmark exc with match Marked.unmark exc with
@ -275,7 +276,7 @@ let format_exception (fmt : Format.formatter) (exc : L.except Marked.pos) : unit
(Pos.get_law_info pos) (Pos.get_law_info pos)
let rec format_expression let rec format_expression
(ctx : Dcalc.Ast.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(e : expr Marked.pos) : unit = (e : expr Marked.pos) : unit =
match Marked.unmark e with match Marked.unmark e with
@ -289,18 +290,18 @@ let rec format_expression
Format.fprintf fmt "%a = %a" format_struct_field_name struct_field Format.fprintf fmt "%a = %a" format_struct_field_name struct_field
(format_expression ctx) e)) (format_expression ctx) e))
(List.combine es (List.combine es
(List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) (List.map fst (StructMap.find s ctx.ctx_structs)))
| EStructFieldAccess (e1, field, _) -> | EStructFieldAccess (e1, field, _) ->
Format.fprintf fmt "%a.%a" (format_expression ctx) e1 Format.fprintf fmt "%a.%a" (format_expression ctx) e1
format_struct_field_name field format_struct_field_name field
| EInj (_, cons, e_name) | EInj (_, cons, e_name)
when D.EnumName.compare e_name L.option_enum = 0 when EnumName.compare e_name L.option_enum = 0
&& D.EnumConstructor.compare cons L.none_constr = 0 -> && EnumConstructor.compare cons L.none_constr = 0 ->
(* We translate the option type with an overloading by Python's [None] *) (* We translate the option type with an overloading by Python's [None] *)
Format.fprintf fmt "None" Format.fprintf fmt "None"
| EInj (e, cons, e_name) | EInj (e, cons, e_name)
when D.EnumName.compare e_name L.option_enum = 0 when EnumName.compare e_name L.option_enum = 0
&& D.EnumConstructor.compare cons L.some_constr = 0 -> && EnumConstructor.compare cons L.some_constr = 0 ->
(* We translate the option type with an overloading by Python's [None] *) (* We translate the option type with an overloading by Python's [None] *)
format_expression ctx fmt e format_expression ctx fmt e
| EInj (e, cons, enum_name) -> | EInj (e, cons, enum_name) ->
@ -315,22 +316,22 @@ let rec format_expression
es es
| ELit l -> Format.fprintf fmt "%a" format_lit (Marked.same_mark_as l e) | ELit l -> Format.fprintf fmt "%a" format_lit (Marked.same_mark_as l e)
| EApp | EApp
((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2])
-> ->
Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos) Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos)
(format_expression ctx) arg1 (format_expression ctx) arg2 (format_expression ctx) arg1 (format_expression ctx) arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) -> | EApp ((EOp (Binop op), _), [arg1; arg2]) ->
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop
(op, Pos.no_pos) (format_expression ctx) arg2 (op, Pos.no_pos) (format_expression ctx) arg2
| EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [f]), _), [arg]) | EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info
(format_expression ctx) f (format_expression ctx) arg (format_expression ctx) f (format_expression ctx) arg
| EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [arg1]) | EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1 (format_expression ctx) arg1
| EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [arg1]) | EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), pos), [arg1])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt Format.fprintf fmt
"log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \ "log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \
@ -338,11 +339,11 @@ let rec format_expression
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) (format_expression ctx) arg1 (Pos.get_law_info pos) (format_expression ctx) arg1
| EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [arg1]) | EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1 (format_expression ctx) arg1
| EApp ((EOp (Unop (D.Log _)), _), [arg1]) -> | EApp ((EOp (Unop (Log _)), _), [arg1]) ->
Format.fprintf fmt "%a" (format_expression ctx) arg1 Format.fprintf fmt "%a" (format_expression ctx) arg1
| EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [arg1]) -> | EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [arg1]) ->
Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos) Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos)
@ -374,7 +375,7 @@ let rec format_expression
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos) | EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
let rec format_statement let rec format_statement
(ctx : Dcalc.Ast.decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(s : stmt Marked.pos) : unit = (s : stmt Marked.pos) : unit =
match Marked.unmark s with match Marked.unmark s with
@ -403,7 +404,7 @@ let rec format_statement
Format.fprintf fmt "@[<hov 4>if %a:@\n%a@]@\n@[<hov 4>else:@\n%a@]" Format.fprintf fmt "@[<hov 4>if %a:@\n%a@]@\n@[<hov 4>else:@\n%a@]"
(format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2 (format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2
| SSwitch (e1, e_name, [(case_none, _); (case_some, case_some_var)]) | SSwitch (e1, e_name, [(case_none, _); (case_some, case_some_var)])
when D.EnumName.compare e_name L.option_enum = 0 -> when EnumName.compare e_name L.option_enum = 0 ->
(* We translate the option type with an overloading by Python's [None] *) (* We translate the option type with an overloading by Python's [None] *)
let tmp_var = LocalName.fresh ("perhaps_none_arg", Pos.no_pos) in let tmp_var = LocalName.fresh ("perhaps_none_arg", Pos.no_pos) in
Format.fprintf fmt Format.fprintf fmt
@ -421,7 +422,7 @@ let rec format_statement
List.map2 List.map2
(fun (x, y) (cons, _) -> x, y, cons) (fun (x, y) (cons, _) -> x, y, cons)
cases cases
(D.EnumMap.find e_name ctx.ctx_enums) (EnumMap.find e_name ctx.ctx_enums)
in in
let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in
Format.fprintf fmt "%a = %a@\n@[<hov 4>if %a@]" format_var tmp_var Format.fprintf fmt "%a = %a@\n@[<hov 4>if %a@]" format_var tmp_var
@ -450,7 +451,7 @@ let rec format_statement
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) (Pos.get_law_info pos)
and format_block (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (b : block) and format_block (ctx : decl_ctx) (fmt : Format.formatter) (b : block)
: unit = : unit =
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
@ -462,7 +463,7 @@ and format_block (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (b : block)
let format_ctx let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list) (type_ordering : Scopelang.Dependency.TVertex.t list)
(fmt : Format.formatter) (fmt : Format.formatter)
(ctx : D.decl_ctx) : unit = (ctx : decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) = let format_struct_decl fmt (struct_name, struct_fields) =
Format.fprintf fmt Format.fprintf fmt
"class %a:@\n\ "class %a:@\n\
@ -562,8 +563,8 @@ let format_ctx
let scope_structs = let scope_structs =
List.map List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(Dcalc.Ast.StructMap.bindings (StructMap.bindings
(Dcalc.Ast.StructMap.filter (StructMap.filter
(fun s _ -> not (is_in_type_ordering s)) (fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs)) ctx.ctx_structs))
in in
@ -572,10 +573,10 @@ let format_ctx
match struct_or_enum with match struct_or_enum with
| Scopelang.Dependency.TVertex.Struct s -> | Scopelang.Dependency.TVertex.Struct s ->
Format.fprintf fmt "%a@\n@\n" format_struct_decl Format.fprintf fmt "%a@\n@\n" format_struct_decl
(s, Dcalc.Ast.StructMap.find s ctx.Dcalc.Ast.ctx_structs) (s, StructMap.find s ctx.ctx_structs)
| Scopelang.Dependency.TVertex.Enum e -> | Scopelang.Dependency.TVertex.Enum e ->
Format.fprintf fmt "%a@\n@\n" format_enum_decl Format.fprintf fmt "%a@\n@\n" format_enum_decl
(e, Dcalc.Ast.EnumMap.find e ctx.Dcalc.Ast.ctx_enums)) (e, EnumMap.find e ctx.ctx_enums))
(type_ordering @ scope_structs) (type_ordering @ scope_structs)
let format_program let format_program

View File

@ -15,8 +15,8 @@
the License. *) the License. *)
open Utils open Utils
module ScopeName = Dcalc.Ast.ScopeName open Shared_ast
module ScopeNameSet : Set.S with type elt = ScopeName.t = Set.Make (ScopeName)
module ScopeMap : Map.S with type key = ScopeName.t = Map.Make (ScopeName) module ScopeMap : Map.S with type key = ScopeName.t = Map.Make (ScopeName)
module SubScopeName : Uid.Id with type info = Uid.MarkedString.info = module SubScopeName : Uid.Id with type info = Uid.MarkedString.info =
@ -33,17 +33,11 @@ module ScopeVar : Uid.Id with type info = Uid.MarkedString.info =
module ScopeVarSet : Set.S with type elt = ScopeVar.t = Set.Make (ScopeVar) module ScopeVarSet : Set.S with type elt = ScopeVar.t = Set.Make (ScopeVar)
module ScopeVarMap : Map.S with type key = ScopeVar.t = Map.Make (ScopeVar) module ScopeVarMap : Map.S with type key = ScopeVar.t = Map.Make (ScopeVar)
module StructName = Dcalc.Ast.StructName
module StructMap = Dcalc.Ast.StructMap
module StructFieldName = Dcalc.Ast.StructFieldName
module StructFieldMap : Map.S with type key = StructFieldName.t = module StructFieldMap : Map.S with type key = StructFieldName.t =
Map.Make (StructFieldName) Map.Make (StructFieldName)
module StructFieldMapLift = Bindlib.Lift (StructFieldMap) module StructFieldMapLift = Bindlib.Lift (StructFieldMap)
module EnumName = Dcalc.Ast.EnumName
module EnumMap = Dcalc.Ast.EnumMap
module EnumConstructor = Dcalc.Ast.EnumConstructor
module EnumConstructorMap : Map.S with type key = EnumConstructor.t = module EnumConstructorMap : Map.S with type key = EnumConstructor.t =
Map.Make (EnumConstructor) Map.Make (EnumConstructor)
@ -71,7 +65,7 @@ Set.Make (struct
end) end)
type typ = type typ =
| TLit of Dcalc.Ast.typ_lit | TLit of typ_lit
| TStruct of StructName.t | TStruct of StructName.t
| TEnum of EnumName.t | TEnum of EnumName.t
| TArrow of typ Marked.pos * typ Marked.pos | TArrow of typ Marked.pos * typ Marked.pos
@ -114,7 +108,7 @@ and expr =
| ELit of Dcalc.Ast.lit | ELit of Dcalc.Ast.lit
| EAbs of (expr, marked_expr) Bindlib.mbinder * typ Marked.pos list | EAbs of (expr, marked_expr) Bindlib.mbinder * typ Marked.pos list
| EApp of marked_expr * marked_expr list | EApp of marked_expr * marked_expr list
| EOp of Dcalc.Ast.operator | EOp of operator
| EDefault of marked_expr list * marked_expr * marked_expr | EDefault of marked_expr list * marked_expr * marked_expr
| EIfThenElse of marked_expr * marked_expr * marked_expr | EIfThenElse of marked_expr * marked_expr * marked_expr
| EArray of marked_expr list | EArray of marked_expr list
@ -319,9 +313,9 @@ let make_let_in
let make_default ?(pos = Pos.no_pos) exceptions just cons = let make_default ?(pos = Pos.no_pos) exceptions just cons =
let rec bool_value = function let rec bool_value = function
| ELit (Dcalc.Ast.LBool b), _ -> Some b | ELit (LBool b), _ -> Some b
| EApp ((EOp (Unop (Log (l, _))), _), [e]), _ | EApp ((EOp (Unop (Log (l, _))), _), [e]), _
when l <> Dcalc.Ast.PosRecordIfTrueBool when l <> PosRecordIfTrueBool
(* we don't remove the log calls corresponding to source code (* we don't remove the log calls corresponding to source code
definitions !*) -> definitions !*) ->
bool_value e bool_value e

View File

@ -17,11 +17,10 @@
(** Abstract syntax tree of the scope language *) (** Abstract syntax tree of the scope language *)
open Utils open Utils
open Shared_ast
(** {1 Identifiers} *) (** {1 Identifiers} *)
module ScopeName = Dcalc.Ast.ScopeName
module ScopeNameSet : Set.S with type elt = ScopeName.t
module ScopeMap : Map.S with type key = ScopeName.t module ScopeMap : Map.S with type key = ScopeName.t
module SubScopeName : Uid.Id with type info = Uid.MarkedString.info module SubScopeName : Uid.Id with type info = Uid.MarkedString.info
module SubScopeNameSet : Set.S with type elt = SubScopeName.t module SubScopeNameSet : Set.S with type elt = SubScopeName.t
@ -29,9 +28,6 @@ module SubScopeMap : Map.S with type key = SubScopeName.t
module ScopeVar : Uid.Id with type info = Uid.MarkedString.info module ScopeVar : Uid.Id with type info = Uid.MarkedString.info
module ScopeVarSet : Set.S with type elt = ScopeVar.t module ScopeVarSet : Set.S with type elt = ScopeVar.t
module ScopeVarMap : Map.S with type key = ScopeVar.t module ScopeVarMap : Map.S with type key = ScopeVar.t
module StructName = Dcalc.Ast.StructName
module StructMap = Dcalc.Ast.StructMap
module StructFieldName = Dcalc.Ast.StructFieldName
module StructFieldMap : Map.S with type key = StructFieldName.t module StructFieldMap : Map.S with type key = StructFieldName.t
module StructFieldMapLift : sig module StructFieldMapLift : sig
@ -39,9 +35,6 @@ module StructFieldMapLift : sig
'a Bindlib.box StructFieldMap.t -> 'a StructFieldMap.t Bindlib.box 'a Bindlib.box StructFieldMap.t -> 'a StructFieldMap.t Bindlib.box
end end
module EnumName = Dcalc.Ast.EnumName
module EnumMap = Dcalc.Ast.EnumMap
module EnumConstructor = Dcalc.Ast.EnumConstructor
module EnumConstructorMap : Map.S with type key = EnumConstructor.t module EnumConstructorMap : Map.S with type key = EnumConstructor.t
module EnumConstructorMapLift : sig module EnumConstructorMapLift : sig
@ -59,7 +52,7 @@ module LocationSet : Set.S with type elt = location Marked.pos
(** {1 Abstract syntax tree} *) (** {1 Abstract syntax tree} *)
type typ = type typ =
| TLit of Dcalc.Ast.typ_lit | TLit of typ_lit
| TStruct of StructName.t | TStruct of StructName.t
| TEnum of EnumName.t | TEnum of EnumName.t
| TArrow of typ Marked.pos * typ Marked.pos | TArrow of typ Marked.pos * typ Marked.pos
@ -82,7 +75,7 @@ and expr =
| ELit of Dcalc.Ast.lit | ELit of Dcalc.Ast.lit
| EAbs of (expr, marked_expr) Bindlib.mbinder * typ Marked.pos list | EAbs of (expr, marked_expr) Bindlib.mbinder * typ Marked.pos list
| EApp of marked_expr * marked_expr list | EApp of marked_expr * marked_expr list
| EOp of Dcalc.Ast.operator | EOp of operator
| EDefault of marked_expr list * marked_expr * marked_expr | EDefault of marked_expr list * marked_expr * marked_expr
| EIfThenElse of marked_expr * marked_expr * marked_expr | EIfThenElse of marked_expr * marked_expr * marked_expr
| EArray of marked_expr list | EArray of marked_expr list

View File

@ -18,13 +18,14 @@
program. Vertices are functions, x -> y if x is used in the definition of y. *) program. Vertices are functions, x -> y if x is used in the definition of y. *)
open Utils open Utils
open Shared_ast
module SVertex = struct module SVertex = struct
type t = Ast.ScopeName.t type t = ScopeName.t
let hash x = Ast.ScopeName.hash x let hash x = ScopeName.hash x
let compare = Ast.ScopeName.compare let compare = ScopeName.compare
let equal x y = Ast.ScopeName.compare x y = 0 let equal x y = ScopeName.compare x y = 0
end end
(** On the edges, the label is the expression responsible for the use of the (** On the edges, the label is the expression responsible for the use of the
@ -62,10 +63,10 @@ let build_program_dep_graph (prgm : Ast.program) : SDependencies.t =
if subscope = scope_name then if subscope = scope_name then
Errors.raise_spanned_error Errors.raise_spanned_error
(Marked.get_mark (Marked.get_mark
(Ast.ScopeName.get_info scope.Ast.scope_decl_name)) (ScopeName.get_info scope.Ast.scope_decl_name))
"The scope %a is calling into itself as a subscope, which is \ "The scope %a is calling into itself as a subscope, which is \
forbidden since Catala does not provide recursion" forbidden since Catala does not provide recursion"
Ast.ScopeName.format_t scope.Ast.scope_decl_name ScopeName.format_t scope.Ast.scope_decl_name
else else
Ast.ScopeMap.add subscope Ast.ScopeMap.add subscope
(Marked.get_mark (Ast.SubScopeName.get_info subindex)) (Marked.get_mark (Ast.SubScopeName.get_info subindex))
@ -90,14 +91,14 @@ let check_for_cycle_in_scope (g : SDependencies.t) : unit =
(List.map (List.map
(fun v -> (fun v ->
let var_str, var_info = let var_str, var_info =
( Format.asprintf "%a" Ast.ScopeName.format_t v, ( Format.asprintf "%a" ScopeName.format_t v,
Ast.ScopeName.get_info v ) ScopeName.get_info v )
in in
let succs = SDependencies.succ_e g v in let succs = SDependencies.succ_e g v in
let _, edge_pos, succ = let _, edge_pos, succ =
List.find (fun (_, _, succ) -> List.mem succ scc) succs List.find (fun (_, _, succ) -> List.mem succ scc) succs
in in
let succ_str = Format.asprintf "%a" Ast.ScopeName.format_t succ in let succ_str = Format.asprintf "%a" ScopeName.format_t succ in
[ [
( Some ("Cycle variable " ^ var_str ^ ", declared:"), ( Some ("Cycle variable " ^ var_str ^ ", declared:"),
Marked.get_mark var_info ); Marked.get_mark var_info );
@ -112,39 +113,39 @@ let check_for_cycle_in_scope (g : SDependencies.t) : unit =
Errors.raise_multispanned_error spans Errors.raise_multispanned_error spans
"Cyclic dependency detected between scopes!" "Cyclic dependency detected between scopes!"
let get_scope_ordering (g : SDependencies.t) : Ast.ScopeName.t list = let get_scope_ordering (g : SDependencies.t) : ScopeName.t list =
List.rev (STopologicalTraversal.fold (fun sd acc -> sd :: acc) g []) List.rev (STopologicalTraversal.fold (fun sd acc -> sd :: acc) g [])
module TVertex = struct module TVertex = struct
type t = Struct of Ast.StructName.t | Enum of Ast.EnumName.t type t = Struct of StructName.t | Enum of EnumName.t
let hash x = let hash x =
match x with match x with
| Struct x -> Ast.StructName.hash x | Struct x -> StructName.hash x
| Enum x -> Ast.EnumName.hash x | Enum x -> EnumName.hash x
let compare x y = let compare x y =
match x, y with match x, y with
| Struct x, Struct y -> Ast.StructName.compare x y | Struct x, Struct y -> StructName.compare x y
| Enum x, Enum y -> Ast.EnumName.compare x y | Enum x, Enum y -> EnumName.compare x y
| Struct _, Enum _ -> 1 | Struct _, Enum _ -> 1
| Enum _, Struct _ -> -1 | Enum _, Struct _ -> -1
let equal x y = let equal x y =
match x, y with match x, y with
| Struct x, Struct y -> Ast.StructName.compare x y = 0 | Struct x, Struct y -> StructName.compare x y = 0
| Enum x, Enum y -> Ast.EnumName.compare x y = 0 | Enum x, Enum y -> EnumName.compare x y = 0
| _ -> false | _ -> false
let format_t (fmt : Format.formatter) (x : t) : unit = let format_t (fmt : Format.formatter) (x : t) : unit =
match x with match x with
| Struct x -> Ast.StructName.format_t fmt x | Struct x -> StructName.format_t fmt x
| Enum x -> Ast.EnumName.format_t fmt x | Enum x -> EnumName.format_t fmt x
let get_info (x : t) = let get_info (x : t) =
match x with match x with
| Struct x -> Ast.StructName.get_info x | Struct x -> StructName.get_info x
| Enum x -> Ast.EnumName.get_info x | Enum x -> EnumName.get_info x
end end
module TVertexSet = Set.Make (TVertex) module TVertexSet = Set.Make (TVertex)
@ -181,7 +182,7 @@ let build_type_graph (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) :
TDependencies.t = TDependencies.t =
let g = TDependencies.empty in let g = TDependencies.empty in
let g = let g =
Ast.StructMap.fold StructMap.fold
(fun s fields g -> (fun s fields g ->
List.fold_left List.fold_left
(fun g (_, typ) -> (fun g (_, typ) ->
@ -205,7 +206,7 @@ let build_type_graph (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) :
structs g structs g
in in
let g = let g =
Ast.EnumMap.fold EnumMap.fold
(fun e cases g -> (fun e cases g ->
List.fold_left List.fold_left
(fun g (_, typ) -> (fun g (_, typ) ->

View File

@ -18,25 +18,26 @@
program. Vertices are functions, x -> y if x is used in the definition of y. *) program. Vertices are functions, x -> y if x is used in the definition of y. *)
open Utils open Utils
open Shared_ast
(** {1 Scope dependencies} *) (** {1 Scope dependencies} *)
(** On the edges, the label is the expression responsible for the use of the (** On the edges, the label is the expression responsible for the use of the
function *) function *)
module SDependencies : module SDependencies :
Graph.Sig.P with type V.t = Ast.ScopeName.t and type E.label = Pos.t Graph.Sig.P with type V.t = ScopeName.t and type E.label = Pos.t
val build_program_dep_graph : Ast.program -> SDependencies.t val build_program_dep_graph : Ast.program -> SDependencies.t
val check_for_cycle_in_scope : SDependencies.t -> unit val check_for_cycle_in_scope : SDependencies.t -> unit
val get_scope_ordering : SDependencies.t -> Ast.ScopeName.t list val get_scope_ordering : SDependencies.t -> ScopeName.t list
(** {1 Type dependencies} *) (** {1 Type dependencies} *)
module TVertex : sig module TVertex : sig
type t = Struct of Ast.StructName.t | Enum of Ast.EnumName.t type t = Struct of StructName.t | Enum of EnumName.t
val format_t : Format.formatter -> t -> unit val format_t : Format.formatter -> t -> unit
val get_info : t -> Ast.StructName.info val get_info : t -> StructName.info
include Graph.Sig.COMPARABLE with type t := t include Graph.Sig.COMPARABLE with type t := t
end end

View File

@ -15,6 +15,7 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Ast open Ast
let needs_parens (e : expr Marked.pos) : bool = let needs_parens (e : expr Marked.pos) : bool =
@ -42,8 +43,8 @@ let rec format_typ (fmt : Format.formatter) (typ : typ Marked.pos) : unit =
in in
match Marked.unmark typ with match Marked.unmark typ with
| TLit l -> Dcalc.Print.format_tlit fmt l | TLit l -> Dcalc.Print.format_tlit fmt l
| TStruct s -> Format.fprintf fmt "%a" Ast.StructName.format_t s | TStruct s -> Format.fprintf fmt "%a" StructName.format_t s
| TEnum e -> Format.fprintf fmt "%a" Ast.EnumName.format_t e | TEnum e -> Format.fprintf fmt "%a" EnumName.format_t e
| TArrow (t1, t2) -> | TArrow (t1, t2) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" format_typ_with_parens t1 Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" format_typ_with_parens t1
Dcalc.Print.format_operator "" format_typ t2 Dcalc.Print.format_operator "" format_typ t2
@ -67,14 +68,14 @@ let rec format_expr
| EVar v -> Format.fprintf fmt "%a" format_var v | EVar v -> Format.fprintf fmt "%a" format_var v
| ELit l -> Format.fprintf fmt "%a" Dcalc.Print.format_lit l | ELit l -> Format.fprintf fmt "%a" Dcalc.Print.format_lit l
| EStruct (name, fields) -> | EStruct (name, fields) ->
Format.fprintf fmt " @[<hov 2>%a@ %a@ %a@ %a@]" Ast.StructName.format_t name Format.fprintf fmt " @[<hov 2>%a@ %a@ %a@ %a@]" StructName.format_t name
Dcalc.Print.format_punctuation "{" Dcalc.Print.format_punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> ~pp_sep:(fun fmt () ->
Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";") Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";")
(fun fmt (field_name, field_expr) -> (fun fmt (field_name, field_expr) ->
Format.fprintf fmt "%a%a%a%a@ %a" Dcalc.Print.format_punctuation "\"" Format.fprintf fmt "%a%a%a%a@ %a" Dcalc.Print.format_punctuation "\""
Ast.StructFieldName.format_t field_name StructFieldName.format_t field_name
Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation
"=" format_expr field_expr)) "=" format_expr field_expr))
(Ast.StructFieldMap.bindings fields) (Ast.StructFieldMap.bindings fields)
@ -82,9 +83,9 @@ let rec format_expr
| EStructAccess (e1, field, _) -> | EStructAccess (e1, field, _) ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Format.fprintf fmt "%a%a%a%a%a" format_expr e1
Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\""
Ast.StructFieldName.format_t field Dcalc.Print.format_punctuation "\"" StructFieldName.format_t field Dcalc.Print.format_punctuation "\""
| EEnumInj (e1, cons, _) -> | EEnumInj (e1, cons, _) ->
Format.fprintf fmt "%a@ %a" Ast.EnumConstructor.format_t cons format_expr e1 Format.fprintf fmt "%a@ %a" EnumConstructor.format_t cons format_expr e1
| EMatch (e1, _, cases) -> | EMatch (e1, _, cases) ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]"
Dcalc.Print.format_keyword "match" format_expr e1 Dcalc.Print.format_keyword "match" format_expr e1

View File

@ -29,7 +29,7 @@ val format_expr :
val format_scope : val format_scope :
?debug:bool (** [true] for debug printing *) -> ?debug:bool (** [true] for debug printing *) ->
Format.formatter -> Format.formatter ->
Ast.ScopeName.t * Ast.scope_decl -> Shared_ast.ScopeName.t * Ast.scope_decl ->
unit unit
val format_program : val format_program :

View File

@ -19,18 +19,18 @@ open Shared_ast
type scope_var_ctx = { type scope_var_ctx = {
scope_var_name : Ast.ScopeVar.t; scope_var_name : Ast.ScopeVar.t;
scope_var_typ : Dcalc.Ast.typ; scope_var_typ : typ;
scope_var_io : Ast.io; scope_var_io : Ast.io;
} }
type scope_sig_ctx = { type scope_sig_ctx = {
scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *) scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *)
scope_sig_scope_var : Dcalc.Ast.untyped Dcalc.Ast.var; scope_sig_scope_var : untyped Dcalc.Ast.var;
(** Var representing the scope *) (** Var representing the scope *)
scope_sig_input_var : Dcalc.Ast.untyped Dcalc.Ast.var; scope_sig_input_var : untyped Dcalc.Ast.var;
(** Var representing the scope input inside the scope func *) (** Var representing the scope input inside the scope func *)
scope_sig_input_struct : Ast.StructName.t; (** Scope input *) scope_sig_input_struct : StructName.t; (** Scope input *)
scope_sig_output_struct : Ast.StructName.t; (** Scope output *) scope_sig_output_struct : StructName.t; (** Scope output *)
} }
type scope_sigs_ctx = scope_sig_ctx Ast.ScopeMap.t type scope_sigs_ctx = scope_sig_ctx Ast.ScopeMap.t
@ -38,21 +38,21 @@ type scope_sigs_ctx = scope_sig_ctx Ast.ScopeMap.t
type ctx = { type ctx = {
structs : Ast.struct_ctx; structs : Ast.struct_ctx;
enums : Ast.enum_ctx; enums : Ast.enum_ctx;
scope_name : Ast.ScopeName.t; scope_name : ScopeName.t;
scopes_parameters : scope_sigs_ctx; scopes_parameters : scope_sigs_ctx;
scope_vars : scope_vars :
(Dcalc.Ast.untyped Dcalc.Ast.var * Dcalc.Ast.typ * Ast.io) Ast.ScopeVarMap.t; (untyped Dcalc.Ast.var * typ * Ast.io) Ast.ScopeVarMap.t;
subscope_vars : subscope_vars :
(Dcalc.Ast.untyped Dcalc.Ast.var * Dcalc.Ast.typ * Ast.io) Ast.ScopeVarMap.t (untyped Dcalc.Ast.var * typ * Ast.io) Ast.ScopeVarMap.t
Ast.SubScopeMap.t; Ast.SubScopeMap.t;
local_vars : Dcalc.Ast.untyped Dcalc.Ast.var Ast.VarMap.t; local_vars : untyped Dcalc.Ast.var Ast.VarMap.t;
} }
let empty_ctx let empty_ctx
(struct_ctx : Ast.struct_ctx) (struct_ctx : Ast.struct_ctx)
(enum_ctx : Ast.enum_ctx) (enum_ctx : Ast.enum_ctx)
(scopes_ctx : scope_sigs_ctx) (scopes_ctx : scope_sigs_ctx)
(scope_name : Ast.ScopeName.t) = (scope_name : ScopeName.t) =
{ {
structs = struct_ctx; structs = struct_ctx;
enums = enum_ctx; enums = enum_ctx;
@ -64,62 +64,62 @@ let empty_ctx
} }
let rec translate_typ (ctx : ctx) (t : Ast.typ Marked.pos) : let rec translate_typ (ctx : ctx) (t : Ast.typ Marked.pos) :
Dcalc.Ast.typ Marked.pos = typ Marked.pos =
Marked.same_mark_as Marked.same_mark_as
(match Marked.unmark t with (match Marked.unmark t with
| Ast.TLit l -> Dcalc.Ast.TLit l | Ast.TLit l -> TLit l
| Ast.TArrow (t1, t2) -> | Ast.TArrow (t1, t2) ->
Dcalc.Ast.TArrow (translate_typ ctx t1, translate_typ ctx t2) TArrow (translate_typ ctx t1, translate_typ ctx t2)
| Ast.TStruct s_uid -> | Ast.TStruct s_uid ->
let s_fields = Ast.StructMap.find s_uid ctx.structs in let s_fields = StructMap.find s_uid ctx.structs in
Dcalc.Ast.TTuple TTuple
(List.map (fun (_, t) -> translate_typ ctx t) s_fields, Some s_uid) (List.map (fun (_, t) -> translate_typ ctx t) s_fields, Some s_uid)
| Ast.TEnum e_uid -> | Ast.TEnum e_uid ->
let e_cases = Ast.EnumMap.find e_uid ctx.enums in let e_cases = EnumMap.find e_uid ctx.enums in
Dcalc.Ast.TEnum TEnum
(List.map (fun (_, t) -> translate_typ ctx t) e_cases, e_uid) (List.map (fun (_, t) -> translate_typ ctx t) e_cases, e_uid)
| Ast.TArray t1 -> | Ast.TArray t1 ->
Dcalc.Ast.TArray (translate_typ ctx (Marked.same_mark_as t1 t)) TArray (translate_typ ctx (Marked.same_mark_as t1 t))
| Ast.TAny -> Dcalc.Ast.TAny) | Ast.TAny -> TAny)
t t
let pos_mark (pos : Pos.t) : Dcalc.Ast.untyped Dcalc.Ast.mark = let pos_mark (pos : Pos.t) : untyped mark =
Dcalc.Ast.Untyped { pos } Untyped { pos }
let pos_mark_as e = pos_mark (Marked.get_mark e) let pos_mark_as e = pos_mark (Marked.get_mark e)
let merge_defaults let merge_defaults
(caller : Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box) (caller : untyped Dcalc.Ast.marked_expr Bindlib.box)
(callee : Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box) : (callee : untyped Dcalc.Ast.marked_expr Bindlib.box) :
Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box = untyped Dcalc.Ast.marked_expr Bindlib.box =
let caller = let caller =
let m = Marked.get_mark (Bindlib.unbox caller) in let m = Marked.get_mark (Bindlib.unbox caller) in
Dcalc.Ast.make_app caller Dcalc.Ast.make_app caller
[Bindlib.box (Dcalc.Ast.ELit Dcalc.Ast.LUnit, m)] [Bindlib.box (ELit LUnit, m)]
m m
in in
let body = let body =
Bindlib.box_apply2 Bindlib.box_apply2
(fun caller callee -> (fun caller callee ->
let m = Marked.get_mark callee in let m = Marked.get_mark callee in
( Dcalc.Ast.EDefault ( EDefault
([caller], (Dcalc.Ast.ELit (Dcalc.Ast.LBool true), m), callee), ([caller], (ELit (LBool true), m), callee),
m )) m ))
caller callee caller callee
in in
body body
let tag_with_log_entry let tag_with_log_entry
(e : Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box) (e : untyped Dcalc.Ast.marked_expr Bindlib.box)
(l : Dcalc.Ast.log_entry) (l : log_entry)
(markings : Utils.Uid.MarkedString.info list) : (markings : Utils.Uid.MarkedString.info list) :
Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box = untyped Dcalc.Ast.marked_expr Bindlib.box =
Bindlib.box_apply Bindlib.box_apply
(fun e -> (fun e ->
Marked.same_mark_as Marked.same_mark_as
(Dcalc.Ast.EApp (EApp
( Marked.same_mark_as ( Marked.same_mark_as
(Dcalc.Ast.EOp (Dcalc.Ast.Unop (Dcalc.Ast.Log (l, markings)))) (EOp (Unop (Log (l, markings))))
e, e,
[e] )) [e] ))
e) e)
@ -165,15 +165,15 @@ let collapse_similar_outcomes (excepts : Ast.expr Marked.pos list) :
excepts excepts
let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box = untyped Dcalc.Ast.marked_expr Bindlib.box =
Bindlib.box_apply (fun (x : Dcalc.Ast.untyped Dcalc.Ast.expr) -> Bindlib.box_apply (fun (x : untyped Dcalc.Ast.expr) ->
Marked.mark (pos_mark_as e) x) Marked.mark (pos_mark_as e) x)
@@ @@
match Marked.unmark e with match Marked.unmark e with
| EVar v -> Bindlib.box_var (Ast.VarMap.find v ctx.local_vars) | EVar v -> Bindlib.box_var (Ast.VarMap.find v ctx.local_vars)
| ELit l -> Bindlib.box (Dcalc.Ast.ELit l) | ELit l -> Bindlib.box (ELit l)
| EStruct (struct_name, e_fields) -> | EStruct (struct_name, e_fields) ->
let struct_sig = Ast.StructMap.find struct_name ctx.structs in let struct_sig = StructMap.find struct_name ctx.structs in
let d_fields, remaining_e_fields = let d_fields, remaining_e_fields =
List.fold_right List.fold_right
(fun (field_name, _) (d_fields, e_fields) -> (fun (field_name, _) (d_fields, e_fields) ->
@ -185,58 +185,58 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
if Ast.StructFieldMap.cardinal remaining_e_fields > 0 then if Ast.StructFieldMap.cardinal remaining_e_fields > 0 then
Errors.raise_spanned_error (Marked.get_mark e) Errors.raise_spanned_error (Marked.get_mark e)
"The fields \"%a\" do not belong to the structure %a" "The fields \"%a\" do not belong to the structure %a"
Ast.StructName.format_t struct_name StructName.format_t struct_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (field_name, _) -> (fun fmt (field_name, _) ->
Format.fprintf fmt "%a" Ast.StructFieldName.format_t field_name)) Format.fprintf fmt "%a" StructFieldName.format_t field_name))
(Ast.StructFieldMap.bindings remaining_e_fields) (Ast.StructFieldMap.bindings remaining_e_fields)
else else
Bindlib.box_apply Bindlib.box_apply
(fun d_fields -> Dcalc.Ast.ETuple (d_fields, Some struct_name)) (fun d_fields -> ETuple (d_fields, Some struct_name))
(Bindlib.box_list d_fields) (Bindlib.box_list d_fields)
| EStructAccess (e1, field_name, struct_name) -> | EStructAccess (e1, field_name, struct_name) ->
let struct_sig = Ast.StructMap.find struct_name ctx.structs in let struct_sig = StructMap.find struct_name ctx.structs in
let _, field_index = let _, field_index =
try try
List.assoc field_name (List.mapi (fun i (x, y) -> x, (y, i)) struct_sig) List.assoc field_name (List.mapi (fun i (x, y) -> x, (y, i)) struct_sig)
with Not_found -> with Not_found ->
Errors.raise_spanned_error (Marked.get_mark e) Errors.raise_spanned_error (Marked.get_mark e)
"The field \"%a\" does not belong to the structure %a" "The field \"%a\" does not belong to the structure %a"
Ast.StructFieldName.format_t field_name Ast.StructName.format_t StructFieldName.format_t field_name StructName.format_t
struct_name struct_name
in in
let e1 = translate_expr ctx e1 in let e1 = translate_expr ctx e1 in
Bindlib.box_apply Bindlib.box_apply
(fun e1 -> (fun e1 ->
Dcalc.Ast.ETupleAccess ETupleAccess
( e1, ( e1,
field_index, field_index,
Some struct_name, Some struct_name,
List.map (fun (_, t) -> translate_typ ctx t) struct_sig )) List.map (fun (_, t) -> translate_typ ctx t) struct_sig ))
e1 e1
| EEnumInj (e1, constructor, enum_name) -> | EEnumInj (e1, constructor, enum_name) ->
let enum_sig = Ast.EnumMap.find enum_name ctx.enums in let enum_sig = EnumMap.find enum_name ctx.enums in
let _, constructor_index = let _, constructor_index =
try try
List.assoc constructor (List.mapi (fun i (x, y) -> x, (y, i)) enum_sig) List.assoc constructor (List.mapi (fun i (x, y) -> x, (y, i)) enum_sig)
with Not_found -> with Not_found ->
Errors.raise_spanned_error (Marked.get_mark e) Errors.raise_spanned_error (Marked.get_mark e)
"The constructor \"%a\" does not belong to the enum %a" "The constructor \"%a\" does not belong to the enum %a"
Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t EnumConstructor.format_t constructor EnumName.format_t
enum_name enum_name
in in
let e1 = translate_expr ctx e1 in let e1 = translate_expr ctx e1 in
Bindlib.box_apply Bindlib.box_apply
(fun e1 -> (fun e1 ->
Dcalc.Ast.EInj EInj
( e1, ( e1,
constructor_index, constructor_index,
enum_name, enum_name,
List.map (fun (_, t) -> translate_typ ctx t) enum_sig )) List.map (fun (_, t) -> translate_typ ctx t) enum_sig ))
e1 e1
| EMatch (e1, enum_name, cases) -> | EMatch (e1, enum_name, cases) ->
let enum_sig = Ast.EnumMap.find enum_name ctx.enums in let enum_sig = EnumMap.find enum_name ctx.enums in
let d_cases, remaining_e_cases = let d_cases, remaining_e_cases =
List.fold_right List.fold_right
(fun (constructor, _) (d_cases, e_cases) -> (fun (constructor, _) (d_cases, e_cases) ->
@ -246,7 +246,7 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
Errors.raise_spanned_error (Marked.get_mark e) Errors.raise_spanned_error (Marked.get_mark e)
"The constructor %a of enum %a is missing from this pattern \ "The constructor %a of enum %a is missing from this pattern \
matching" matching"
Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t EnumConstructor.format_t constructor EnumName.format_t
enum_name enum_name
in in
let case_d = translate_expr ctx case_e in let case_d = translate_expr ctx case_e in
@ -256,16 +256,16 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
if Ast.EnumConstructorMap.cardinal remaining_e_cases > 0 then if Ast.EnumConstructorMap.cardinal remaining_e_cases > 0 then
Errors.raise_spanned_error (Marked.get_mark e) Errors.raise_spanned_error (Marked.get_mark e)
"Patter matching is incomplete for enum %a: missing cases %a" "Patter matching is incomplete for enum %a: missing cases %a"
Ast.EnumName.format_t enum_name EnumName.format_t enum_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (case_name, _) -> (fun fmt (case_name, _) ->
Format.fprintf fmt "%a" Ast.EnumConstructor.format_t case_name)) Format.fprintf fmt "%a" EnumConstructor.format_t case_name))
(Ast.EnumConstructorMap.bindings remaining_e_cases) (Ast.EnumConstructorMap.bindings remaining_e_cases)
else else
let e1 = translate_expr ctx e1 in let e1 = translate_expr ctx e1 in
Bindlib.box_apply2 Bindlib.box_apply2
(fun d_fields e1 -> Dcalc.Ast.EMatch (e1, d_fields, enum_name)) (fun d_fields e1 -> EMatch (e1, d_fields, enum_name))
(Bindlib.box_list d_cases) e1 (Bindlib.box_list d_cases) e1
| EApp (e1, args) -> | EApp (e1, args) ->
(* We insert various log calls to record arguments and outputs of (* We insert various log calls to record arguments and outputs of
@ -274,14 +274,14 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
let markings l = let markings l =
match l with match l with
| Ast.ScopeVar (v, _) -> | Ast.ScopeVar (v, _) ->
[Ast.ScopeName.get_info ctx.scope_name; Ast.ScopeVar.get_info v] [ScopeName.get_info ctx.scope_name; Ast.ScopeVar.get_info v]
| Ast.SubScopeVar (s, _, (v, _)) -> | Ast.SubScopeVar (s, _, (v, _)) ->
[Ast.ScopeName.get_info s; Ast.ScopeVar.get_info v] [ScopeName.get_info s; Ast.ScopeVar.get_info v]
in in
let e1_func = let e1_func =
match Marked.unmark e1 with match Marked.unmark e1 with
| ELocation l -> | ELocation l ->
tag_with_log_entry e1_func Dcalc.Ast.BeginCall (markings l) tag_with_log_entry e1_func BeginCall (markings l)
| _ -> e1_func | _ -> e1_func
in in
let new_args = List.map (translate_expr ctx) args in let new_args = List.map (translate_expr ctx) args in
@ -293,9 +293,9 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
let retrieve_in_and_out_typ_or_any var vars = let retrieve_in_and_out_typ_or_any var vars =
let _, typ, _ = Ast.ScopeVarMap.find (Marked.unmark var) vars in let _, typ, _ = Ast.ScopeVarMap.find (Marked.unmark var) vars in
match typ with match typ with
| Dcalc.Ast.TArrow (marked_input_typ, marked_output_typ) -> | TArrow (marked_input_typ, marked_output_typ) ->
Marked.unmark marked_input_typ, Marked.unmark marked_output_typ Marked.unmark marked_input_typ, Marked.unmark marked_output_typ
| _ -> Dcalc.Ast.TAny, Dcalc.Ast.TAny | _ -> TAny, TAny
in in
match Marked.unmark e1 with match Marked.unmark e1 with
| ELocation (ScopeVar var) -> | ELocation (ScopeVar var) ->
@ -304,20 +304,20 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
ctx.subscope_vars ctx.subscope_vars
|> Ast.SubScopeMap.find (Marked.unmark sname) |> Ast.SubScopeMap.find (Marked.unmark sname)
|> retrieve_in_and_out_typ_or_any var |> retrieve_in_and_out_typ_or_any var
| _ -> Dcalc.Ast.TAny, Dcalc.Ast.TAny | _ -> TAny, TAny
in in
let new_args = let new_args =
match Marked.unmark e1, new_args with match Marked.unmark e1, new_args with
| ELocation l, [new_arg] -> | ELocation l, [new_arg] ->
[ [
tag_with_log_entry new_arg (Dcalc.Ast.VarDef input_typ) tag_with_log_entry new_arg (VarDef input_typ)
(markings l @ [Marked.same_mark_as "input" e]); (markings l @ [Marked.same_mark_as "input" e]);
] ]
| _ -> new_args | _ -> new_args
in in
let new_e = let new_e =
Bindlib.box_apply2 Bindlib.box_apply2
(fun e' u -> Dcalc.Ast.EApp (e', u), pos_mark_as e) (fun e' u -> EApp (e', u), pos_mark_as e)
e1_func e1_func
(Bindlib.box_list new_args) (Bindlib.box_list new_args)
in in
@ -325,9 +325,9 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
match Marked.unmark e1 with match Marked.unmark e1 with
| ELocation l -> | ELocation l ->
tag_with_log_entry tag_with_log_entry
(tag_with_log_entry new_e (Dcalc.Ast.VarDef output_typ) (tag_with_log_entry new_e (VarDef output_typ)
(markings l @ [Marked.same_mark_as "output" e])) (markings l @ [Marked.same_mark_as "output" e]))
Dcalc.Ast.EndCall (markings l) EndCall (markings l)
| _ -> new_e | _ -> new_e
in in
Bindlib.box_apply Marked.unmark new_e Bindlib.box_apply Marked.unmark new_e
@ -348,12 +348,12 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
in in
let binder = Bindlib.bind_mvar new_xs body in let binder = Bindlib.bind_mvar new_xs body in
Bindlib.box_apply Bindlib.box_apply
(fun b -> Dcalc.Ast.EAbs (b, List.map (translate_typ ctx) typ)) (fun b -> EAbs (b, List.map (translate_typ ctx) typ))
binder binder
| EDefault (excepts, just, cons) -> | EDefault (excepts, just, cons) ->
let excepts = collapse_similar_outcomes excepts in let excepts = collapse_similar_outcomes excepts in
Bindlib.box_apply3 Bindlib.box_apply3
(fun e j c -> Dcalc.Ast.EDefault (e, j, c)) (fun e j c -> EDefault (e, j, c))
(Bindlib.box_list (List.map (translate_expr ctx) excepts)) (Bindlib.box_list (List.map (translate_expr ctx) excepts))
(translate_expr ctx just) (translate_expr ctx cons) (translate_expr ctx just) (translate_expr ctx cons)
| ELocation (ScopeVar a) -> | ELocation (ScopeVar a) ->
@ -381,16 +381,16 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) :
(Marked.unmark a) Ast.SubScopeName.format_t (Marked.unmark s)) (Marked.unmark a) Ast.SubScopeName.format_t (Marked.unmark s))
| EIfThenElse (cond, et, ef) -> | EIfThenElse (cond, et, ef) ->
Bindlib.box_apply3 Bindlib.box_apply3
(fun c t f -> Dcalc.Ast.EIfThenElse (c, t, f)) (fun c t f -> EIfThenElse (c, t, f))
(translate_expr ctx cond) (translate_expr ctx et) (translate_expr ctx ef) (translate_expr ctx cond) (translate_expr ctx et) (translate_expr ctx ef)
| EOp op -> Bindlib.box (Dcalc.Ast.EOp op) | EOp op -> Bindlib.box (EOp op)
| ErrorOnEmpty e' -> | ErrorOnEmpty e' ->
Bindlib.box_apply Bindlib.box_apply
(fun e' -> Dcalc.Ast.ErrorOnEmpty e') (fun e' -> ErrorOnEmpty e')
(translate_expr ctx e') (translate_expr ctx e')
| EArray es -> | EArray es ->
Bindlib.box_apply Bindlib.box_apply
(fun es -> Dcalc.Ast.EArray es) (fun es -> EArray es)
(Bindlib.box_list (List.map (translate_expr ctx) es)) (Bindlib.box_list (List.map (translate_expr ctx) es))
(** The result of a rule translation is a list of assignment, with variables and (** The result of a rule translation is a list of assignment, with variables and
@ -402,13 +402,13 @@ let translate_rule
(ctx : ctx) (ctx : ctx)
(rule : Ast.rule) (rule : Ast.rule)
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) : ((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) :
(( Dcalc.Ast.untyped Dcalc.Ast.expr, (( untyped Dcalc.Ast.expr,
Dcalc.Ast.untyped ) untyped )
Dcalc.Ast.scope_body_expr scope_body_expr
Bindlib.box -> Bindlib.box ->
( Dcalc.Ast.untyped Dcalc.Ast.expr, ( untyped Dcalc.Ast.expr,
Dcalc.Ast.untyped ) untyped )
Dcalc.Ast.scope_body_expr scope_body_expr
Bindlib.box) Bindlib.box)
* ctx = * ctx =
match rule with match rule with
@ -421,7 +421,7 @@ let translate_rule
let merged_expr = let merged_expr =
Bindlib.box_apply Bindlib.box_apply
(fun merged_expr -> (fun merged_expr ->
Dcalc.Ast.ErrorOnEmpty merged_expr, pos_mark_as a_name) ErrorOnEmpty merged_expr, pos_mark_as a_name)
(match Marked.unmark a_io.io_input with (match Marked.unmark a_io.io_input with
| OnlyInput -> | OnlyInput ->
failwith "should not happen" failwith "should not happen"
@ -432,19 +432,19 @@ let translate_rule
in in
let merged_expr = let merged_expr =
tag_with_log_entry merged_expr tag_with_log_entry merged_expr
(Dcalc.Ast.VarDef (Marked.unmark tau)) (VarDef (Marked.unmark tau))
[sigma_name, pos_sigma; a_name] [sigma_name, pos_sigma; a_name]
in in
( (fun next -> ( (fun next ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun next merged_expr -> (fun next merged_expr ->
Dcalc.Ast.ScopeLet ScopeLet
{ {
Dcalc.Ast.scope_let_next = next; scope_let_next = next;
Dcalc.Ast.scope_let_typ = tau; scope_let_typ = tau;
Dcalc.Ast.scope_let_expr = merged_expr; scope_let_expr = merged_expr;
Dcalc.Ast.scope_let_kind = Dcalc.Ast.ScopeVarDefinition; scope_let_kind = ScopeVarDefinition;
Dcalc.Ast.scope_let_pos = Marked.get_mark a; scope_let_pos = Marked.get_mark a;
}) })
(Bindlib.bind_var a_var next) (Bindlib.bind_var a_var next)
merged_expr), merged_expr),
@ -472,7 +472,7 @@ let translate_rule
let tau = translate_typ ctx tau in let tau = translate_typ ctx tau in
let new_e = let new_e =
tag_with_log_entry (translate_expr ctx e) tag_with_log_entry (translate_expr ctx e)
(Dcalc.Ast.VarDef (Marked.unmark tau)) (VarDef (Marked.unmark tau))
[sigma_name, pos_sigma; a_name] [sigma_name, pos_sigma; a_name]
in in
let silent_var = Var.make "_" in let silent_var = Var.make "_" in
@ -481,31 +481,31 @@ let translate_rule
| NoInput -> failwith "should not happen" | NoInput -> failwith "should not happen"
| OnlyInput -> | OnlyInput ->
Bindlib.box_apply Bindlib.box_apply
(fun new_e -> Dcalc.Ast.ErrorOnEmpty new_e, pos_mark_as subs_var) (fun new_e -> ErrorOnEmpty new_e, pos_mark_as subs_var)
new_e new_e
| Reentrant -> | Reentrant ->
Dcalc.Ast.make_abs Dcalc.Ast.make_abs
(Array.of_list [silent_var]) (Array.of_list [silent_var])
new_e new_e
[Dcalc.Ast.TLit TUnit, var_def_pos] [TLit TUnit, var_def_pos]
(pos_mark var_def_pos) (pos_mark var_def_pos)
in in
( (fun next -> ( (fun next ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun next thunked_or_nonempty_new_e -> (fun next thunked_or_nonempty_new_e ->
Dcalc.Ast.ScopeLet ScopeLet
{ {
Dcalc.Ast.scope_let_next = next; scope_let_next = next;
Dcalc.Ast.scope_let_pos = Marked.get_mark a_name; scope_let_pos = Marked.get_mark a_name;
Dcalc.Ast.scope_let_typ = scope_let_typ =
(match Marked.unmark a_io.io_input with (match Marked.unmark a_io.io_input with
| NoInput -> failwith "should not happen" | NoInput -> failwith "should not happen"
| OnlyInput -> tau | OnlyInput -> tau
| Reentrant -> | Reentrant ->
( Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), ( TArrow ((TLit TUnit, var_def_pos), tau),
var_def_pos )); var_def_pos ));
Dcalc.Ast.scope_let_expr = thunked_or_nonempty_new_e; scope_let_expr = thunked_or_nonempty_new_e;
Dcalc.Ast.scope_let_kind = Dcalc.Ast.SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
}) })
(Bindlib.bind_var a_var next) (Bindlib.bind_var a_var next)
thunked_or_nonempty_new_e), thunked_or_nonempty_new_e),
@ -573,7 +573,7 @@ let translate_rule
let subscope_struct_arg = let subscope_struct_arg =
Bindlib.box_apply Bindlib.box_apply
(fun subscope_args -> (fun subscope_args ->
( Dcalc.Ast.ETuple (subscope_args, Some called_scope_input_struct), ( ETuple (subscope_args, Some called_scope_input_struct),
pos_mark pos_call )) pos_mark pos_call ))
(Bindlib.box_list subscope_args) (Bindlib.box_list subscope_args)
in in
@ -593,28 +593,28 @@ let translate_rule
tag_with_log_entry tag_with_log_entry
(Dcalc.Ast.make_var (Dcalc.Ast.make_var
(scope_dcalc_var, pos_mark_as (Ast.SubScopeName.get_info subindex))) (scope_dcalc_var, pos_mark_as (Ast.SubScopeName.get_info subindex)))
Dcalc.Ast.BeginCall BeginCall
[ [
sigma_name, pos_sigma; sigma_name, pos_sigma;
Ast.SubScopeName.get_info subindex; Ast.SubScopeName.get_info subindex;
Ast.ScopeName.get_info subname; ScopeName.get_info subname;
] ]
in in
let call_expr = let call_expr =
tag_with_log_entry tag_with_log_entry
(Bindlib.box_apply2 (Bindlib.box_apply2
(fun e u -> Dcalc.Ast.EApp (e, [u]), pos_mark Pos.no_pos) (fun e u -> EApp (e, [u]), pos_mark Pos.no_pos)
subscope_func subscope_struct_arg) subscope_func subscope_struct_arg)
Dcalc.Ast.EndCall EndCall
[ [
sigma_name, pos_sigma; sigma_name, pos_sigma;
Ast.SubScopeName.get_info subindex; Ast.SubScopeName.get_info subindex;
Ast.ScopeName.get_info subname; ScopeName.get_info subname;
] ]
in in
let result_tuple_var = Var.make "result" in let result_tuple_var = Var.make "result" in
let result_tuple_typ = let result_tuple_typ =
( Dcalc.Ast.TTuple ( TTuple
( List.map ( List.map
(fun (subvar, _) -> subvar.scope_var_typ, pos_sigma) (fun (subvar, _) -> subvar.scope_var_typ, pos_sigma)
all_subscope_output_vars_dcalc, all_subscope_output_vars_dcalc,
@ -624,13 +624,13 @@ let translate_rule
let call_scope_let next = let call_scope_let next =
Bindlib.box_apply2 Bindlib.box_apply2
(fun next call_expr -> (fun next call_expr ->
Dcalc.Ast.ScopeLet ScopeLet
{ {
Dcalc.Ast.scope_let_next = next; scope_let_next = next;
Dcalc.Ast.scope_let_pos = pos_sigma; scope_let_pos = pos_sigma;
Dcalc.Ast.scope_let_kind = Dcalc.Ast.CallingSubScope; scope_let_kind = CallingSubScope;
Dcalc.Ast.scope_let_typ = result_tuple_typ; scope_let_typ = result_tuple_typ;
Dcalc.Ast.scope_let_expr = call_expr; scope_let_expr = call_expr;
}) })
(Bindlib.bind_var result_tuple_var next) (Bindlib.bind_var result_tuple_var next)
call_expr call_expr
@ -640,15 +640,15 @@ let translate_rule
(fun (var_ctx, v) (next, i) -> (fun (var_ctx, v) (next, i) ->
( Bindlib.box_apply2 ( Bindlib.box_apply2
(fun next r -> (fun next r ->
Dcalc.Ast.ScopeLet ScopeLet
{ {
Dcalc.Ast.scope_let_next = next; scope_let_next = next;
Dcalc.Ast.scope_let_pos = pos_sigma; scope_let_pos = pos_sigma;
Dcalc.Ast.scope_let_typ = var_ctx.scope_var_typ, pos_sigma; scope_let_typ = var_ctx.scope_var_typ, pos_sigma;
Dcalc.Ast.scope_let_kind = scope_let_kind =
Dcalc.Ast.DestructuringSubScopeResults; DestructuringSubScopeResults;
Dcalc.Ast.scope_let_expr = scope_let_expr =
( Dcalc.Ast.ETupleAccess ( ETupleAccess
( r, ( r,
i, i,
Some called_scope_return_struct, Some called_scope_return_struct,
@ -682,20 +682,20 @@ let translate_rule
( (fun next -> ( (fun next ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun next new_e -> (fun next new_e ->
Dcalc.Ast.ScopeLet ScopeLet
{ {
Dcalc.Ast.scope_let_next = next; scope_let_next = next;
Dcalc.Ast.scope_let_pos = Marked.get_mark e; scope_let_pos = Marked.get_mark e;
Dcalc.Ast.scope_let_typ = scope_let_typ =
Dcalc.Ast.TLit TUnit, Marked.get_mark e; TLit TUnit, Marked.get_mark e;
Dcalc.Ast.scope_let_expr = scope_let_expr =
(* To ensure that we throw an error if the value is not (* To ensure that we throw an error if the value is not
defined, we add an check "ErrorOnEmpty" here. *) defined, we add an check "ErrorOnEmpty" here. *)
Marked.same_mark_as Marked.same_mark_as
(Dcalc.Ast.EAssert (EAssert
(Dcalc.Ast.ErrorOnEmpty new_e, pos_mark_as e)) (ErrorOnEmpty new_e, pos_mark_as e))
new_e; new_e;
Dcalc.Ast.scope_let_kind = Dcalc.Ast.Assertion; scope_let_kind = Assertion;
}) })
(Bindlib.bind_var (Var.make "_") next) (Bindlib.bind_var (Var.make "_") next)
new_e), new_e),
@ -705,10 +705,10 @@ let translate_rules
(ctx : ctx) (ctx : ctx)
(rules : Ast.rule list) (rules : Ast.rule list)
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) ((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info)
(sigma_return_struct_name : Ast.StructName.t) : (sigma_return_struct_name : StructName.t) :
( Dcalc.Ast.untyped Dcalc.Ast.expr, ( untyped Dcalc.Ast.expr,
Dcalc.Ast.untyped ) untyped )
Dcalc.Ast.scope_body_expr scope_body_expr
Bindlib.box Bindlib.box
* ctx = * ctx =
let scope_lets, new_ctx = let scope_lets, new_ctx =
@ -730,7 +730,7 @@ let translate_rules
let return_exp = let return_exp =
Bindlib.box_apply Bindlib.box_apply
(fun args -> (fun args ->
( Dcalc.Ast.ETuple (args, Some sigma_return_struct_name), ( ETuple (args, Some sigma_return_struct_name),
pos_mark pos_sigma )) pos_mark pos_sigma ))
(Bindlib.box_list (Bindlib.box_list
(List.map (List.map
@ -740,7 +740,7 @@ let translate_rules
in in
( scope_lets ( scope_lets
(Bindlib.box_apply (Bindlib.box_apply
(fun return_exp -> Dcalc.Ast.Result return_exp) (fun return_exp -> Result return_exp)
return_exp), return_exp),
new_ctx ) new_ctx )
@ -748,12 +748,12 @@ let translate_scope_decl
(struct_ctx : Ast.struct_ctx) (struct_ctx : Ast.struct_ctx)
(enum_ctx : Ast.enum_ctx) (enum_ctx : Ast.enum_ctx)
(sctx : scope_sigs_ctx) (sctx : scope_sigs_ctx)
(scope_name : Ast.ScopeName.t) (scope_name : ScopeName.t)
(sigma : Ast.scope_decl) : (sigma : Ast.scope_decl) :
(Dcalc.Ast.untyped Dcalc.Ast.expr, Dcalc.Ast.untyped) Dcalc.Ast.scope_body (untyped Dcalc.Ast.expr, untyped) scope_body
Bindlib.box Bindlib.box
* struct_ctx = * struct_ctx =
let sigma_info = Ast.ScopeName.get_info sigma.scope_decl_name in let sigma_info = ScopeName.get_info sigma.scope_decl_name in
let scope_sig = Ast.ScopeMap.find sigma.scope_decl_name sctx in let scope_sig = Ast.ScopeMap.find sigma.scope_decl_name sctx in
let scope_variables = scope_sig.scope_sig_local_vars in let scope_variables = scope_sig.scope_sig_local_vars in
let ctx = let ctx =
@ -813,8 +813,8 @@ let translate_scope_decl
match Marked.unmark var_ctx.scope_var_io.io_input with match Marked.unmark var_ctx.scope_var_io.io_input with
| OnlyInput -> var_ctx.scope_var_typ, pos_sigma | OnlyInput -> var_ctx.scope_var_typ, pos_sigma
| Reentrant -> | Reentrant ->
( Dcalc.Ast.TArrow ( TArrow
((Dcalc.Ast.TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)), ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)),
pos_sigma ) pos_sigma )
| NoInput -> failwith "should not happen" | NoInput -> failwith "should not happen"
in in
@ -824,15 +824,15 @@ let translate_scope_decl
(fun (var_ctx, v) (next, i) -> (fun (var_ctx, v) (next, i) ->
( Bindlib.box_apply2 ( Bindlib.box_apply2
(fun next r -> (fun next r ->
Dcalc.Ast.ScopeLet ScopeLet
{ {
Dcalc.Ast.scope_let_kind = scope_let_kind =
Dcalc.Ast.DestructuringInputStruct; DestructuringInputStruct;
Dcalc.Ast.scope_let_next = next; scope_let_next = next;
Dcalc.Ast.scope_let_pos = pos_sigma; scope_let_pos = pos_sigma;
Dcalc.Ast.scope_let_typ = input_var_typ var_ctx; scope_let_typ = input_var_typ var_ctx;
Dcalc.Ast.scope_let_expr = scope_let_expr =
( Dcalc.Ast.ETupleAccess ( ETupleAccess
( r, ( r,
i, i,
Some scope_input_struct_name, Some scope_input_struct_name,
@ -851,7 +851,7 @@ let translate_scope_decl
List.map List.map
(fun (var_ctx, dvar) -> (fun (var_ctx, dvar) ->
let struct_field_name = let struct_field_name =
Ast.StructFieldName.fresh (Bindlib.name_of dvar ^ "_out", pos_sigma) StructFieldName.fresh (Bindlib.name_of dvar ^ "_out", pos_sigma)
in in
struct_field_name, (var_ctx.scope_var_typ, pos_sigma)) struct_field_name, (var_ctx.scope_var_typ, pos_sigma))
scope_output_variables scope_output_variables
@ -860,29 +860,29 @@ let translate_scope_decl
List.map List.map
(fun (var_ctx, dvar) -> (fun (var_ctx, dvar) ->
let struct_field_name = let struct_field_name =
Ast.StructFieldName.fresh (Bindlib.name_of dvar ^ "_in", pos_sigma) StructFieldName.fresh (Bindlib.name_of dvar ^ "_in", pos_sigma)
in in
struct_field_name, input_var_typ var_ctx) struct_field_name, input_var_typ var_ctx)
scope_input_variables scope_input_variables
in in
let new_struct_ctx = let new_struct_ctx =
Ast.StructMap.add scope_input_struct_name scope_input_struct_fields StructMap.add scope_input_struct_name scope_input_struct_fields
(Ast.StructMap.singleton scope_return_struct_name (StructMap.singleton scope_return_struct_name
scope_return_struct_fields) scope_return_struct_fields)
in in
( Bindlib.box_apply ( Bindlib.box_apply
(fun scope_body_expr -> (fun scope_body_expr ->
{ {
Dcalc.Ast.scope_body_expr; scope_body_expr;
Dcalc.Ast.scope_body_input_struct = scope_input_struct_name; scope_body_input_struct = scope_input_struct_name;
Dcalc.Ast.scope_body_output_struct = scope_return_struct_name; scope_body_output_struct = scope_return_struct_name;
}) })
(Bindlib.bind_var scope_input_var (Bindlib.bind_var scope_input_var
(input_destructurings rules_with_return_expr)), (input_destructurings rules_with_return_expr)),
new_struct_ctx ) new_struct_ctx )
let translate_program (prgm : Ast.program) : let translate_program (prgm : Ast.program) :
Dcalc.Ast.untyped Dcalc.Ast.program * Dependency.TVertex.t list = untyped Dcalc.Ast.program * Dependency.TVertex.t list =
let scope_dependencies = Dependency.build_program_dep_graph prgm in let scope_dependencies = Dependency.build_program_dep_graph prgm in
Dependency.check_for_cycle_in_scope scope_dependencies; Dependency.check_for_cycle_in_scope scope_dependencies;
let types_ordering = let types_ordering =
@ -894,16 +894,16 @@ let translate_program (prgm : Ast.program) :
let ctx_for_typ_translation scope_name = let ctx_for_typ_translation scope_name =
empty_ctx struct_ctx enum_ctx Ast.ScopeMap.empty scope_name empty_ctx struct_ctx enum_ctx Ast.ScopeMap.empty scope_name
in in
let dummy_scope = Ast.ScopeName.fresh ("dummy", Pos.no_pos) in let dummy_scope = ScopeName.fresh ("dummy", Pos.no_pos) in
let decl_ctx = let decl_ctx =
{ {
Dcalc.Ast.ctx_structs = ctx_structs =
Ast.StructMap.map StructMap.map
(List.map (fun (x, y) -> (List.map (fun (x, y) ->
x, translate_typ (ctx_for_typ_translation dummy_scope) y)) x, translate_typ (ctx_for_typ_translation dummy_scope) y))
struct_ctx; struct_ctx;
Dcalc.Ast.ctx_enums = ctx_enums =
Ast.EnumMap.map EnumMap.map
(List.map (fun (x, y) -> (List.map (fun (x, y) ->
x, (translate_typ (ctx_for_typ_translation dummy_scope)) y)) x, (translate_typ (ctx_for_typ_translation dummy_scope)) y))
enum_ctx; enum_ctx;
@ -914,22 +914,22 @@ let translate_program (prgm : Ast.program) :
(fun scope_name scope -> (fun scope_name scope ->
let scope_dvar = let scope_dvar =
Var.make Var.make
(Marked.unmark (Ast.ScopeName.get_info scope.Ast.scope_decl_name)) (Marked.unmark (ScopeName.get_info scope.Ast.scope_decl_name))
in in
let scope_return_struct_name = let scope_return_struct_name =
Ast.StructName.fresh StructName.fresh
(Marked.map_under_mark (Marked.map_under_mark
(fun s -> s ^ "_out") (fun s -> s ^ "_out")
(Ast.ScopeName.get_info scope_name)) (ScopeName.get_info scope_name))
in in
let scope_input_var = let scope_input_var =
Var.make (Marked.unmark (Ast.ScopeName.get_info scope_name) ^ "_in") Var.make (Marked.unmark (ScopeName.get_info scope_name) ^ "_in")
in in
let scope_input_struct_name = let scope_input_struct_name =
Ast.StructName.fresh StructName.fresh
(Marked.map_under_mark (Marked.map_under_mark
(fun s -> s ^ "_in") (fun s -> s ^ "_in")
(Ast.ScopeName.get_info scope_name)) (ScopeName.get_info scope_name))
in in
{ {
scope_sig_local_vars = scope_sig_local_vars =
@ -954,7 +954,7 @@ let translate_program (prgm : Ast.program) :
(* the resulting expression is the list of definitions of all the scopes, (* the resulting expression is the list of definitions of all the scopes,
ending with the top-level scope. *) ending with the top-level scope. *)
let (scopes, decl_ctx) let (scopes, decl_ctx)
: (Dcalc.Ast.untyped Dcalc.Ast.expr, Dcalc.Ast.untyped) Dcalc.Ast.scopes : (untyped Dcalc.Ast.expr, untyped) scopes
Bindlib.box Bindlib.box
* _ = * _ =
List.fold_right List.fold_right
@ -967,21 +967,21 @@ let translate_program (prgm : Ast.program) :
let decl_ctx = let decl_ctx =
{ {
decl_ctx with decl_ctx with
Dcalc.Ast.ctx_structs = ctx_structs =
Ast.StructMap.union StructMap.union
(fun _ _ -> assert false (* should not happen *)) (fun _ _ -> assert false (* should not happen *))
decl_ctx.Dcalc.Ast.ctx_structs scope_out_struct; decl_ctx.ctx_structs scope_out_struct;
} }
in in
let scope_next = Bindlib.bind_var dvar scopes in let scope_next = Bindlib.bind_var dvar scopes in
let new_scopes = let new_scopes =
Bindlib.box_apply2 Bindlib.box_apply2
(fun scope_body scope_next -> (fun scope_body scope_next ->
Dcalc.Ast.ScopeDef { scope_name; scope_body; scope_next }) ScopeDef { scope_name; scope_body; scope_next })
scope_body scope_next scope_body scope_next
in in
new_scopes, decl_ctx) new_scopes, decl_ctx)
scope_ordering scope_ordering
(Bindlib.box Dcalc.Ast.Nil, decl_ctx) (Bindlib.box Nil, decl_ctx)
in in
{ scopes = Bindlib.unbox scopes; decl_ctx }, types_ordering { scopes = Bindlib.unbox scopes; decl_ctx }, types_ordering

View File

@ -17,7 +17,7 @@
(** Scope language to default calculus translator *) (** Scope language to default calculus translator *)
val translate_program : val translate_program :
Ast.program -> Dcalc.Ast.untyped Dcalc.Ast.program * Dependency.TVertex.t list Ast.program -> Shared_ast.untyped Dcalc.Ast.program * Dependency.TVertex.t list
(** Usage [translate_program p] returns a tuple [(new_program, types_list)] (** Usage [translate_program p] returns a tuple [(new_program, types_list)]
where [new_program] is the map of translated scopes. Finally, [types_list] where [new_program] is the map of translated scopes. Finally, [types_list]
is a list of all types (structs and enums) used in the program, correctly is a list of all types (structs and enums) used in the program, correctly

View File

@ -68,7 +68,67 @@ let eraise e1 pos = Bindlib.box (ERaise e1, pos)
let ecatch e1 exn e2 pos = let ecatch e1 exn e2 pos =
Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2 Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2
let translate_var v = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v) (* - Manipulation of marks - *)
let no_mark (type m) : m mark -> m mark = function
| Untyped _ -> Untyped { pos = Pos.no_pos }
| Typed _ -> Typed { pos = Pos.no_pos; ty = Marked.mark Pos.no_pos TAny }
let mark_pos (type m) (m : m mark) : Pos.t =
match m with Untyped { pos } | Typed { pos; _ } -> pos
let pos (type m) (x : ('a, m) marked) : Pos.t = mark_pos (Marked.get_mark x)
let ty (_, m) : marked_typ = match m with Typed { ty; _ } -> ty
let with_ty (type m) (ty : marked_typ) (x : ('a, m) marked) : ('a, typed) marked
=
Marked.mark
(match Marked.get_mark x with
| Untyped { pos } -> Typed { pos; ty }
| Typed m -> Typed { m with ty })
(Marked.unmark x)
let map_mark
(type m)
(pos_f : Pos.t -> Pos.t)
(ty_f : marked_typ -> marked_typ)
(m : m mark) : m mark =
match m with
| Untyped { pos } -> Untyped { pos = pos_f pos }
| Typed { pos; ty } -> Typed { pos = pos_f pos; ty = ty_f ty }
let map_mark2
(type m)
(pos_f : Pos.t -> Pos.t -> Pos.t)
(ty_f : typed -> typed -> marked_typ)
(m1 : m mark)
(m2 : m mark) : m mark =
match m1, m2 with
| Untyped m1, Untyped m2 -> Untyped { pos = pos_f m1.pos m2.pos }
| Typed m1, Typed m2 -> Typed { pos = pos_f m1.pos m2.pos; ty = ty_f m1 m2 }
let fold_marks
(type m)
(pos_f : Pos.t list -> Pos.t)
(ty_f : typed list -> marked_typ)
(ms : m mark list) : m mark =
match ms with
| [] -> invalid_arg "Dcalc.Ast.fold_mark"
| Untyped _ :: _ as ms ->
Untyped { pos = pos_f (List.map (function Untyped { pos } -> pos) ms) }
| Typed _ :: _ ->
Typed
{
pos = pos_f (List.map (function Typed { pos; _ } -> pos) ms);
ty = ty_f (List.map (function Typed m -> m) ms);
}
let get_scope_body_mark scope_body =
match snd (Bindlib.unbind scope_body.scope_body_expr) with
| Result e | ScopeLet { scope_let_expr = e; _ } -> Marked.get_mark e
(* - Traversal functions - *)
let map let map
(type a) (type a)
@ -81,10 +141,10 @@ let map
| EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) m | EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) m
| EOp op -> Bindlib.box (EOp op, m) | EOp op -> Bindlib.box (EOp op, m)
| EArray args -> earray (List.map (f ctx) args) m | EArray args -> earray (List.map (f ctx) args) m
| EVar v -> evar (translate_var v) m | EVar v -> evar (Var.translate v) m
| EAbs (binder, typs) -> | EAbs (binder, typs) ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
eabs (Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body)) typs m eabs (Bindlib.bind_mvar (Array.map Var.translate vars) (f ctx body)) typs m
| EIfThenElse (e1, e2, e3) -> | EIfThenElse (e1, e2, e3) ->
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m
| ETuple (args, s) -> etuple (List.map (f ctx) args) s m | ETuple (args, s) -> etuple (List.map (f ctx) args) s m
@ -179,3 +239,22 @@ let map_exprs_in_scopes ~f ~varf scopes =
}) })
new_scope_body_expr new_next) new_scope_body_expr new_next)
~init:(Bindlib.box Nil) scopes ~init:(Bindlib.box Nil) scopes
(* - *)
(** See [Bindlib.box_term] documentation for why we are doing that. *)
let box e=
let rec id_t () e = map () ~f:id_t e in
id_t () e
let untype e = map_marks ~f:(fun m -> Untyped { pos = mark_pos m }) e
let untype_program prg =
{
prg with
scopes =
Bindlib.unbox
(map_exprs_in_scopes
~f:(fun e -> untype e)
~varf:Var.translate prg.scopes);
}

View File

@ -105,13 +105,70 @@ val eerroronempty :
't -> 't ->
('a, 't) marked_gexpr Bindlib.box ('a, 't) marked_gexpr Bindlib.box
(** ---------- *) val ecatch :
(lcalc, 't) marked_gexpr Bindlib.box ->
except ->
(lcalc, 't) marked_gexpr Bindlib.box ->
't ->
(lcalc, 't) marked_gexpr Bindlib.box
val eraise : except -> 't -> (lcalc, 't) marked_gexpr Bindlib.box
(** Manipulation of marks *)
val no_mark : 'm mark -> 'm mark
val mark_pos : 'm mark -> Pos.t
val pos : ('a, 'm) marked -> Pos.t
val ty : ('a, typed) marked -> marked_typ
val with_ty : marked_typ -> ('a, 'm) marked -> ('a, typed) marked
val map_mark :
(Pos.t -> Pos.t) -> (marked_typ -> marked_typ) -> 'm mark -> 'm mark
val map_mark2 :
(Pos.t -> Pos.t -> Pos.t) ->
(typed -> typed -> marked_typ) ->
'm mark ->
'm mark ->
'm mark
val fold_marks :
(Pos.t list -> Pos.t) -> (typed list -> marked_typ) -> 'm mark list -> 'm mark
val get_scope_body_mark : ('expr, 'm) scope_body -> 'm mark
val untype : ('a, 'm mark) marked_gexpr -> ('a, untyped mark) marked_gexpr Bindlib.box
val untype_program : (('a, 'm mark) gexpr Var.expr, 'm) program_generic -> (('a, untyped mark) gexpr Var.expr, untyped) program_generic
(** {2 Handling of boxing} *)
val box : ('a, 't) marked_gexpr -> ('a, 't) marked_gexpr Bindlib.box
(** {2 Traversal functions} *)
val map: val map:
'ctx -> 'ctx ->
f:('ctx -> ('a, 't1) marked_gexpr -> ('a, 't2) marked_gexpr Bindlib.box) -> f:('ctx -> ('a, 't1) marked_gexpr -> ('a, 't2) marked_gexpr Bindlib.box) ->
(('a, 't1) gexpr, 't2) Marked.t -> (('a, 't1) gexpr, 't2) Marked.t ->
('a, 't2) marked_gexpr Bindlib.box ('a, 't2) marked_gexpr Bindlib.box
(** Flat (non-recursive) mapping on expressions.
If you want to apply a map transform to an expression, you can save up
writing a painful match over all the cases of the AST. For instance, if you
want to remove all errors on empty, you can write
{[
let remove_error_empty =
let rec f () e =
match Marked.unmark e with
| ErrorOnEmpty e1 -> Expr.map () f e1
| _ -> Expr.map () f e
in
f () e
]}
The first argument of map_expr is an optional context that you can carry
around during your map traversal. *)
val map_top_down : val map_top_down :
f:(('a, 't1) marked_gexpr -> (('a, 't1) gexpr, 't2) Marked.t) -> f:(('a, 't1) marked_gexpr -> (('a, 't1) gexpr, 't2) Marked.t) ->

View File

@ -16,6 +16,7 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
module Runtime = Runtime_ocaml.Runtime module Runtime = Runtime_ocaml.Runtime
(** Translation from {!module: Surface.Ast} to {!module: Desugaring.Ast}. (** Translation from {!module: Surface.Ast} to {!module: Desugaring.Ast}.
@ -25,7 +26,7 @@ module Runtime = Runtime_ocaml.Runtime
(** {1 Translating expressions} *) (** {1 Translating expressions} *)
let translate_op_kind (k : Ast.op_kind) : Dcalc.Ast.op_kind = let translate_op_kind (k : Ast.op_kind) : op_kind =
match k with match k with
| KInt -> KInt | KInt -> KInt
| KDec -> KRat | KDec -> KRat
@ -33,7 +34,7 @@ let translate_op_kind (k : Ast.op_kind) : Dcalc.Ast.op_kind =
| KDate -> KDate | KDate -> KDate
| KDuration -> KDuration | KDuration -> KDuration
let translate_binop (op : Ast.binop) : Dcalc.Ast.binop = let translate_binop (op : Ast.binop) : binop =
match op with match op with
| And -> And | And -> And
| Or -> Or | Or -> Or
@ -50,7 +51,7 @@ let translate_binop (op : Ast.binop) : Dcalc.Ast.binop =
| Neq -> Neq | Neq -> Neq
| Concat -> Concat | Concat -> Concat
let translate_unop (op : Ast.unop) : Dcalc.Ast.unop = let translate_unop (op : Ast.unop) : unop =
match op with Not -> Not | Minus l -> Minus (translate_op_kind l) match op with Not -> Not | Minus l -> Minus (translate_op_kind l)
(** The two modules below help performing operations on map with the {!type: (** The two modules below help performing operations on map with the {!type:
@ -65,7 +66,7 @@ module LiftEnumConstructorMap = Bindlib.Lift (Scopelang.Ast.EnumConstructorMap)
let disambiguate_constructor let disambiguate_constructor
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(constructor : (string Marked.pos option * string Marked.pos) list) (constructor : (string Marked.pos option * string Marked.pos) list)
(pos : Pos.t) : Scopelang.Ast.EnumName.t * Scopelang.Ast.EnumConstructor.t = (pos : Pos.t) : EnumName.t * EnumConstructor.t =
let enum, constructor = let enum, constructor =
match constructor with match constructor with
| [c] -> c | [c] -> c
@ -86,7 +87,7 @@ let disambiguate_constructor
in in
match enum with match enum with
| None -> | None ->
if Scopelang.Ast.EnumMap.cardinal possible_c_uids > 1 then if EnumMap.cardinal possible_c_uids > 1 then
Errors.raise_spanned_error Errors.raise_spanned_error
(Marked.get_mark constructor) (Marked.get_mark constructor)
"This constructor name is ambiguous, it can belong to %a. Disambiguate \ "This constructor name is ambiguous, it can belong to %a. Disambiguate \
@ -94,9 +95,9 @@ let disambiguate_constructor
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ")
(fun fmt (s_name, _) -> (fun fmt (s_name, _) ->
Format.fprintf fmt "%a" Scopelang.Ast.EnumName.format_t s_name)) Format.fprintf fmt "%a" EnumName.format_t s_name))
(Scopelang.Ast.EnumMap.bindings possible_c_uids); (EnumMap.bindings possible_c_uids);
Scopelang.Ast.EnumMap.choose possible_c_uids EnumMap.choose possible_c_uids
| Some enum -> ( | Some enum -> (
try try
(* The path is fully qualified *) (* The path is fully qualified *)
@ -104,7 +105,7 @@ let disambiguate_constructor
Desugared.Ast.IdentMap.find (Marked.unmark enum) ctxt.enum_idmap Desugared.Ast.IdentMap.find (Marked.unmark enum) ctxt.enum_idmap
in in
try try
let c_uid = Scopelang.Ast.EnumMap.find e_uid possible_c_uids in let c_uid = EnumMap.find e_uid possible_c_uids in
e_uid, c_uid e_uid, c_uid
with Not_found -> with Not_found ->
Errors.raise_spanned_error pos "Enum %s does not contain case %s" Errors.raise_spanned_error pos "Enum %s does not contain case %s"
@ -119,7 +120,7 @@ let disambiguate_constructor
Translates [expr] into its desugared equivalent. [scope] is used to Translates [expr] into its desugared equivalent. [scope] is used to
disambiguate the scope and subscopes variables than occur in the expression *) disambiguate the scope and subscopes variables than occur in the expression *)
let rec translate_expr let rec translate_expr
(scope : Scopelang.Ast.ScopeName.t) (scope : ScopeName.t)
(inside_definition_of : Desugared.Ast.ScopeDef.t Marked.pos option) (inside_definition_of : Desugared.Ast.ScopeDef.t Marked.pos option)
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
((expr, pos) : Ast.expression Marked.pos) : ((expr, pos) : Ast.expression Marked.pos) :
@ -140,11 +141,11 @@ let rec translate_expr
let cases = let cases =
Scopelang.Ast.EnumConstructorMap.mapi Scopelang.Ast.EnumConstructorMap.mapi
(fun c_uid' tau -> (fun c_uid' tau ->
if Scopelang.Ast.EnumConstructor.compare c_uid c_uid' <> 0 then if EnumConstructor.compare c_uid c_uid' <> 0 then
let nop_var = Desugared.Ast.Var.make "_" in let nop_var = Desugared.Ast.Var.make "_" in
Bindlib.unbox Bindlib.unbox
(Desugared.Ast.make_abs [| nop_var |] (Desugared.Ast.make_abs [| nop_var |]
(Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool false), pos)) (Bindlib.box (Desugared.Ast.ELit (LBool false), pos))
[tau] pos) [tau] pos)
else else
let ctxt, binding_var = let ctxt, binding_var =
@ -153,7 +154,7 @@ let rec translate_expr
let e2 = translate_expr scope inside_definition_of ctxt e2 in let e2 = translate_expr scope inside_definition_of ctxt e2 in
Bindlib.unbox Bindlib.unbox
(Desugared.Ast.make_abs [| binding_var |] e2 [tau] pos)) (Desugared.Ast.make_abs [| binding_var |] e2 [tau] pos))
(Scopelang.Ast.EnumMap.find enum_uid ctxt.enums) (EnumMap.find enum_uid ctxt.enums)
in in
Bindlib.box_apply Bindlib.box_apply
(fun e1_sub -> Desugared.Ast.EMatch (e1_sub, enum_uid, cases), pos) (fun e1_sub -> Desugared.Ast.EMatch (e1_sub, enum_uid, cases), pos)
@ -167,7 +168,7 @@ let rec translate_expr
let op_term = let op_term =
Marked.same_mark_as Marked.same_mark_as
(Desugared.Ast.EOp (Desugared.Ast.EOp
(Dcalc.Ast.Binop (translate_binop (Marked.unmark op)))) (Binop (translate_binop (Marked.unmark op))))
op op
in in
Bindlib.box_apply2 Bindlib.box_apply2
@ -176,7 +177,7 @@ let rec translate_expr
| Unop (op, e) -> | Unop (op, e) ->
let op_term = let op_term =
Marked.same_mark_as Marked.same_mark_as
(Desugared.Ast.EOp (Dcalc.Ast.Unop (translate_unop (Marked.unmark op)))) (Desugared.Ast.EOp (Unop (translate_unop (Marked.unmark op))))
op op
in in
Bindlib.box_apply Bindlib.box_apply
@ -186,38 +187,38 @@ let rec translate_expr
let untyped_term = let untyped_term =
match l with match l with
| LNumber ((Int i, _), None) -> | LNumber ((Int i, _), None) ->
Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_string i)) Desugared.Ast.ELit (LInt (Runtime.integer_of_string i))
| LNumber ((Int i, _), Some (Percent, _)) -> | LNumber ((Int i, _), Some (Percent, _)) ->
Desugared.Ast.ELit Desugared.Ast.ELit
(Dcalc.Ast.LRat (LRat
Runtime.(decimal_of_string i /& decimal_of_string "100")) Runtime.(decimal_of_string i /& decimal_of_string "100"))
| LNumber ((Dec (i, f), _), None) -> | LNumber ((Dec (i, f), _), None) ->
Desugared.Ast.ELit Desugared.Ast.ELit
(Dcalc.Ast.LRat Runtime.(decimal_of_string (i ^ "." ^ f))) (LRat Runtime.(decimal_of_string (i ^ "." ^ f)))
| LNumber ((Dec (i, f), _), Some (Percent, _)) -> | LNumber ((Dec (i, f), _), Some (Percent, _)) ->
Desugared.Ast.ELit Desugared.Ast.ELit
(Dcalc.Ast.LRat (LRat
Runtime.( Runtime.(
decimal_of_string (i ^ "." ^ f) /& decimal_of_string "100")) decimal_of_string (i ^ "." ^ f) /& decimal_of_string "100"))
| LBool b -> Desugared.Ast.ELit (Dcalc.Ast.LBool b) | LBool b -> Desugared.Ast.ELit (LBool b)
| LMoneyAmount i -> | LMoneyAmount i ->
Desugared.Ast.ELit Desugared.Ast.ELit
(Dcalc.Ast.LMoney (LMoney
Runtime.( Runtime.(
money_of_cents_integer money_of_cents_integer
((integer_of_string i.money_amount_units *! integer_of_int 100) ((integer_of_string i.money_amount_units *! integer_of_int 100)
+! integer_of_string i.money_amount_cents))) +! integer_of_string i.money_amount_cents)))
| LNumber ((Int i, _), Some (Year, _)) -> | LNumber ((Int i, _), Some (Year, _)) ->
Desugared.Ast.ELit Desugared.Ast.ELit
(Dcalc.Ast.LDuration (LDuration
(Runtime.duration_of_numbers (int_of_string i) 0 0)) (Runtime.duration_of_numbers (int_of_string i) 0 0))
| LNumber ((Int i, _), Some (Month, _)) -> | LNumber ((Int i, _), Some (Month, _)) ->
Desugared.Ast.ELit Desugared.Ast.ELit
(Dcalc.Ast.LDuration (LDuration
(Runtime.duration_of_numbers 0 (int_of_string i) 0)) (Runtime.duration_of_numbers 0 (int_of_string i) 0))
| LNumber ((Int i, _), Some (Day, _)) -> | LNumber ((Int i, _), Some (Day, _)) ->
Desugared.Ast.ELit Desugared.Ast.ELit
(Dcalc.Ast.LDuration (LDuration
(Runtime.duration_of_numbers 0 0 (int_of_string i))) (Runtime.duration_of_numbers 0 0 (int_of_string i)))
| LNumber ((Dec (_, _), _), Some ((Year | Month | Day), _)) -> | LNumber ((Dec (_, _), _), Some ((Year | Month | Day), _)) ->
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
@ -230,7 +231,7 @@ let rec translate_expr
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"There is an error in this date: the day number is bigger than 31"; "There is an error in this date: the day number is bigger than 31";
Desugared.Ast.ELit Desugared.Ast.ELit
(Dcalc.Ast.LDate (LDate
(try (try
Runtime.date_of_numbers date.literal_date_year Runtime.date_of_numbers date.literal_date_year
date.literal_date_month date.literal_date_day date.literal_date_month date.literal_date_day
@ -307,7 +308,7 @@ let rec translate_expr
let subscope_uid : Scopelang.Ast.SubScopeName.t = let subscope_uid : Scopelang.Ast.SubScopeName.t =
Name_resolution.get_subscope_uid scope ctxt (Marked.same_mark_as y e) Name_resolution.get_subscope_uid scope ctxt (Marked.same_mark_as y e)
in in
let subscope_real_uid : Scopelang.Ast.ScopeName.t = let subscope_real_uid : ScopeName.t =
Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes
in in
let subscope_var_uid = let subscope_var_uid =
@ -330,19 +331,19 @@ let rec translate_expr
match c with match c with
| None -> | None ->
(* No constructor name was specified *) (* No constructor name was specified *)
if Scopelang.Ast.StructMap.cardinal x_possible_structs > 1 then if StructMap.cardinal x_possible_structs > 1 then
Errors.raise_spanned_error (Marked.get_mark x) Errors.raise_spanned_error (Marked.get_mark x)
"This struct field name is ambiguous, it can belong to %a. \ "This struct field name is ambiguous, it can belong to %a. \
Disambiguate it by prefixing it with the struct name." Disambiguate it by prefixing it with the struct name."
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ")
(fun fmt (s_name, _) -> (fun fmt (s_name, _) ->
Format.fprintf fmt "%a" Scopelang.Ast.StructName.format_t Format.fprintf fmt "%a" StructName.format_t
s_name)) s_name))
(Scopelang.Ast.StructMap.bindings x_possible_structs) (StructMap.bindings x_possible_structs)
else else
let s_uid, f_uid = let s_uid, f_uid =
Scopelang.Ast.StructMap.choose x_possible_structs StructMap.choose x_possible_structs
in in
Bindlib.box_apply Bindlib.box_apply
(fun e -> Desugared.Ast.EStructAccess (e, f_uid, s_uid), pos) (fun e -> Desugared.Ast.EStructAccess (e, f_uid, s_uid), pos)
@ -353,7 +354,7 @@ let rec translate_expr
Desugared.Ast.IdentMap.find (Marked.unmark c_name) ctxt.struct_idmap Desugared.Ast.IdentMap.find (Marked.unmark c_name) ctxt.struct_idmap
in in
try try
let f_uid = Scopelang.Ast.StructMap.find c_uid x_possible_structs in let f_uid = StructMap.find c_uid x_possible_structs in
Bindlib.box_apply Bindlib.box_apply
(fun e -> Desugared.Ast.EStructAccess (e, f_uid, c_uid), pos) (fun e -> Desugared.Ast.EStructAccess (e, f_uid, c_uid), pos)
e e
@ -391,7 +392,7 @@ let rec translate_expr
(fun s_fields (f_name, f_e) -> (fun s_fields (f_name, f_e) ->
let f_uid = let f_uid =
try try
Scopelang.Ast.StructMap.find s_uid StructMap.find s_uid
(Desugared.Ast.IdentMap.find (Marked.unmark f_name) (Desugared.Ast.IdentMap.find (Marked.unmark f_name)
ctxt.field_idmap) ctxt.field_idmap)
with Not_found -> with Not_found ->
@ -408,19 +409,19 @@ let rec translate_expr
None, Marked.get_mark (Bindlib.unbox e_field); None, Marked.get_mark (Bindlib.unbox e_field);
] ]
"The field %a has been defined twice:" "The field %a has been defined twice:"
Scopelang.Ast.StructFieldName.format_t f_uid); StructFieldName.format_t f_uid);
let f_e = translate_expr scope inside_definition_of ctxt f_e in let f_e = translate_expr scope inside_definition_of ctxt f_e in
Scopelang.Ast.StructFieldMap.add f_uid f_e s_fields) Scopelang.Ast.StructFieldMap.add f_uid f_e s_fields)
Scopelang.Ast.StructFieldMap.empty fields Scopelang.Ast.StructFieldMap.empty fields
in in
let expected_s_fields = Scopelang.Ast.StructMap.find s_uid ctxt.structs in let expected_s_fields = StructMap.find s_uid ctxt.structs in
Scopelang.Ast.StructFieldMap.iter Scopelang.Ast.StructFieldMap.iter
(fun expected_f _ -> (fun expected_f _ ->
if not (Scopelang.Ast.StructFieldMap.mem expected_f s_fields) then if not (Scopelang.Ast.StructFieldMap.mem expected_f s_fields) then
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Missing field for structure %a: \"%a\"" "Missing field for structure %a: \"%a\""
Scopelang.Ast.StructName.format_t s_uid StructName.format_t s_uid
Scopelang.Ast.StructFieldName.format_t expected_f) StructFieldName.format_t expected_f)
expected_s_fields; expected_s_fields;
Bindlib.box_apply Bindlib.box_apply
@ -443,7 +444,7 @@ let rec translate_expr
| None -> | None ->
if if
(* No constructor name was specified *) (* No constructor name was specified *)
Scopelang.Ast.EnumMap.cardinal possible_c_uids > 1 EnumMap.cardinal possible_c_uids > 1
then then
Errors.raise_spanned_error Errors.raise_spanned_error
(Marked.get_mark constructor) (Marked.get_mark constructor)
@ -452,10 +453,10 @@ let rec translate_expr
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ")
(fun fmt (s_name, _) -> (fun fmt (s_name, _) ->
Format.fprintf fmt "%a" Scopelang.Ast.EnumName.format_t s_name)) Format.fprintf fmt "%a" EnumName.format_t s_name))
(Scopelang.Ast.EnumMap.bindings possible_c_uids) (EnumMap.bindings possible_c_uids)
else else
let e_uid, c_uid = Scopelang.Ast.EnumMap.choose possible_c_uids in let e_uid, c_uid = EnumMap.choose possible_c_uids in
let payload = let payload =
Option.map (translate_expr scope inside_definition_of ctxt) payload Option.map (translate_expr scope inside_definition_of ctxt) payload
in in
@ -465,7 +466,7 @@ let rec translate_expr
( (match payload with ( (match payload with
| Some e' -> e' | Some e' -> e'
| None -> | None ->
( Desugared.Ast.ELit Dcalc.Ast.LUnit, ( Desugared.Ast.ELit LUnit,
Marked.get_mark constructor )), Marked.get_mark constructor )),
c_uid, c_uid,
e_uid ), e_uid ),
@ -478,7 +479,7 @@ let rec translate_expr
Desugared.Ast.IdentMap.find (Marked.unmark enum) ctxt.enum_idmap Desugared.Ast.IdentMap.find (Marked.unmark enum) ctxt.enum_idmap
in in
try try
let c_uid = Scopelang.Ast.EnumMap.find e_uid possible_c_uids in let c_uid = EnumMap.find e_uid possible_c_uids in
let payload = let payload =
Option.map (translate_expr scope inside_definition_of ctxt) payload Option.map (translate_expr scope inside_definition_of ctxt) payload
in in
@ -488,7 +489,7 @@ let rec translate_expr
( (match payload with ( (match payload with
| Some e' -> e' | Some e' -> e'
| None -> | None ->
( Desugared.Ast.ELit Dcalc.Ast.LUnit, ( Desugared.Ast.ELit LUnit,
Marked.get_mark constructor )), Marked.get_mark constructor )),
c_uid, c_uid,
e_uid ), e_uid ),
@ -530,11 +531,11 @@ let rec translate_expr
(Desugared.Ast.make_abs [| nop_var |] (Desugared.Ast.make_abs [| nop_var |]
(Bindlib.box (Bindlib.box
( Desugared.Ast.ELit ( Desugared.Ast.ELit
(Dcalc.Ast.LBool (LBool
(Scopelang.Ast.EnumConstructor.compare c_uid c_uid' = 0)), (EnumConstructor.compare c_uid c_uid' = 0)),
pos )) pos ))
[tau] pos)) [tau] pos))
(Scopelang.Ast.EnumMap.find enum_uid ctxt.enums) (EnumMap.find enum_uid ctxt.enums)
in in
Bindlib.box_apply Bindlib.box_apply
(fun e -> Desugared.Ast.EMatch (e, enum_uid, cases), pos) (fun e -> Desugared.Ast.EMatch (e, enum_uid, cases), pos)
@ -563,8 +564,8 @@ let rec translate_expr
( Desugared.Ast.EApp ( Desugared.Ast.EApp
( ( Desugared.Ast.EOp ( ( Desugared.Ast.EOp
(match op' with (match op' with
| Ast.Map -> Dcalc.Ast.Binop Dcalc.Ast.Map | Ast.Map -> Binop Map
| Ast.Filter -> Dcalc.Ast.Binop Dcalc.Ast.Filter | Ast.Filter -> Binop Filter
| _ -> assert false (* should not happen *)), | _ -> assert false (* should not happen *)),
pos ), pos ),
[f_pred; collection] ), [f_pred; collection] ),
@ -583,11 +584,11 @@ let rec translate_expr
in in
let op_kind = let op_kind =
match pred_typ with match pred_typ with
| Ast.Integer -> Dcalc.Ast.KInt | Ast.Integer -> KInt
| Ast.Decimal -> Dcalc.Ast.KRat | Ast.Decimal -> KRat
| Ast.Money -> Dcalc.Ast.KMoney | Ast.Money -> KMoney
| Ast.Duration -> Dcalc.Ast.KDuration | Ast.Duration -> KDuration
| Ast.Date -> Dcalc.Ast.KDate | Ast.Date -> KDate
| _ -> | _ ->
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"It is impossible to compute the arg-%s of two values of type %a" "It is impossible to compute the arg-%s of two values of type %a"
@ -595,7 +596,7 @@ let rec translate_expr
Print.format_primitive_typ pred_typ Print.format_primitive_typ pred_typ
in in
let cmp_op = let cmp_op =
if max_or_min then Dcalc.Ast.Gt op_kind else Dcalc.Ast.Lt op_kind if max_or_min then Gt op_kind else Lt op_kind
in in
let f_pred = let f_pred =
Desugared.Ast.make_abs [| param |] Desugared.Ast.make_abs [| param |]
@ -619,7 +620,7 @@ let rec translate_expr
(fun acc_var_e item_var_e f_pred_var_e -> (fun acc_var_e item_var_e f_pred_var_e ->
( Desugared.Ast.EIfThenElse ( Desugared.Ast.EIfThenElse
( ( Desugared.Ast.EApp ( ( Desugared.Ast.EApp
( (Desugared.Ast.EOp (Dcalc.Ast.Binop cmp_op), pos_op'), ( (Desugared.Ast.EOp (Binop cmp_op), pos_op'),
[ [
Desugared.Ast.EApp (f_pred_var_e, [acc_var_e]), pos; Desugared.Ast.EApp (f_pred_var_e, [acc_var_e]), pos;
Desugared.Ast.EApp (f_pred_var_e, [item_var_e]), pos; Desugared.Ast.EApp (f_pred_var_e, [item_var_e]), pos;
@ -639,7 +640,7 @@ let rec translate_expr
Bindlib.box_apply3 Bindlib.box_apply3
(fun fold_f collection init -> (fun fold_f collection init ->
( Desugared.Ast.EApp ( Desugared.Ast.EApp
( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), ( (Desugared.Ast.EOp (Ternop Fold), pos),
[fold_f; init; collection] ), [fold_f; init; collection] ),
pos )) pos ))
fold_f collection init fold_f collection init
@ -656,28 +657,28 @@ let rec translate_expr
assert false (* should not happen *) assert false (* should not happen *)
| Ast.Exists -> | Ast.Exists ->
Bindlib.box Bindlib.box
(Desugared.Ast.ELit (Dcalc.Ast.LBool false), Marked.get_mark op') (Desugared.Ast.ELit (LBool false), Marked.get_mark op')
| Ast.Forall -> | Ast.Forall ->
Bindlib.box Bindlib.box
(Desugared.Ast.ELit (Dcalc.Ast.LBool true), Marked.get_mark op') (Desugared.Ast.ELit (LBool true), Marked.get_mark op')
| Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> | Ast.Aggregate (Ast.AggregateSum Ast.Integer) ->
Bindlib.box Bindlib.box
( Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_int 0)), ( Desugared.Ast.ELit (LInt (Runtime.integer_of_int 0)),
Marked.get_mark op' ) Marked.get_mark op' )
| Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) ->
Bindlib.box Bindlib.box
( Desugared.Ast.ELit (Dcalc.Ast.LRat (Runtime.decimal_of_string "0")), ( Desugared.Ast.ELit (LRat (Runtime.decimal_of_string "0")),
Marked.get_mark op' ) Marked.get_mark op' )
| Ast.Aggregate (Ast.AggregateSum Ast.Money) -> | Ast.Aggregate (Ast.AggregateSum Ast.Money) ->
Bindlib.box Bindlib.box
( Desugared.Ast.ELit ( Desugared.Ast.ELit
(Dcalc.Ast.LMoney (LMoney
(Runtime.money_of_cents_integer (Runtime.integer_of_int 0))), (Runtime.money_of_cents_integer (Runtime.integer_of_int 0))),
Marked.get_mark op' ) Marked.get_mark op' )
| Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> | Ast.Aggregate (Ast.AggregateSum Ast.Duration) ->
Bindlib.box Bindlib.box
( Desugared.Ast.ELit ( Desugared.Ast.ELit
(Dcalc.Ast.LDuration (Runtime.duration_of_numbers 0 0 0)), (LDuration (Runtime.duration_of_numbers 0 0 0)),
Marked.get_mark op' ) Marked.get_mark op' )
| Ast.Aggregate (Ast.AggregateSum t) -> | Ast.Aggregate (Ast.AggregateSum t) ->
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
@ -686,24 +687,24 @@ let rec translate_expr
| Ast.Aggregate (Ast.AggregateExtremum (_, _, init)) -> rec_helper init | Ast.Aggregate (Ast.AggregateExtremum (_, _, init)) -> rec_helper init
| Ast.Aggregate Ast.AggregateCount -> | Ast.Aggregate Ast.AggregateCount ->
Bindlib.box Bindlib.box
( Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_int 0)), ( Desugared.Ast.ELit (LInt (Runtime.integer_of_int 0)),
Marked.get_mark op' ) Marked.get_mark op' )
in in
let acc_var = Desugared.Ast.Var.make "acc" in let acc_var = Desugared.Ast.Var.make "acc" in
let acc = Desugared.Ast.make_var (acc_var, Marked.get_mark param') in let acc = Desugared.Ast.make_var (acc_var, Marked.get_mark param') in
let f_body = let f_body =
let make_body (op : Dcalc.Ast.binop) = let make_body (op : binop) =
Bindlib.box_apply2 Bindlib.box_apply2
(fun predicate acc -> (fun predicate acc ->
( Desugared.Ast.EApp ( Desugared.Ast.EApp
( (Desugared.Ast.EOp (Dcalc.Ast.Binop op), Marked.get_mark op'), ( (Desugared.Ast.EOp (Binop op), Marked.get_mark op'),
[acc; predicate] ), [acc; predicate] ),
pos )) pos ))
(translate_expr scope inside_definition_of ctxt predicate) (translate_expr scope inside_definition_of ctxt predicate)
acc acc
in in
let make_extr_body let make_extr_body
(cmp_op : Dcalc.Ast.binop) (cmp_op : binop)
(t : Scopelang.Ast.typ Marked.pos) = (t : Scopelang.Ast.typ Marked.pos) =
let tmp_var = Desugared.Ast.Var.make "tmp" in let tmp_var = Desugared.Ast.Var.make "tmp" in
let tmp = Desugared.Ast.make_var (tmp_var, Marked.get_mark param') in let tmp = Desugared.Ast.make_var (tmp_var, Marked.get_mark param') in
@ -713,7 +714,7 @@ let rec translate_expr
(fun acc tmp -> (fun acc tmp ->
( Desugared.Ast.EIfThenElse ( Desugared.Ast.EIfThenElse
( ( Desugared.Ast.EApp ( ( Desugared.Ast.EApp
( ( Desugared.Ast.EOp (Dcalc.Ast.Binop cmp_op), ( ( Desugared.Ast.EOp (Binop cmp_op),
Marked.get_mark op' ), Marked.get_mark op' ),
[acc; tmp] ), [acc; tmp] ),
pos ), pos ),
@ -725,35 +726,35 @@ let rec translate_expr
match Marked.unmark op' with match Marked.unmark op' with
| Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) ->
assert false (* should not happen *) assert false (* should not happen *)
| Ast.Exists -> make_body Dcalc.Ast.Or | Ast.Exists -> make_body Or
| Ast.Forall -> make_body Dcalc.Ast.And | Ast.Forall -> make_body And
| Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> | Ast.Aggregate (Ast.AggregateSum Ast.Integer) ->
make_body (Dcalc.Ast.Add Dcalc.Ast.KInt) make_body (Add KInt)
| Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) ->
make_body (Dcalc.Ast.Add Dcalc.Ast.KRat) make_body (Add KRat)
| Ast.Aggregate (Ast.AggregateSum Ast.Money) -> | Ast.Aggregate (Ast.AggregateSum Ast.Money) ->
make_body (Dcalc.Ast.Add Dcalc.Ast.KMoney) make_body (Add KMoney)
| Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> | Ast.Aggregate (Ast.AggregateSum Ast.Duration) ->
make_body (Dcalc.Ast.Add Dcalc.Ast.KDuration) make_body (Add KDuration)
| Ast.Aggregate (Ast.AggregateSum _) -> | Ast.Aggregate (Ast.AggregateSum _) ->
assert false (* should not happen *) assert false (* should not happen *)
| Ast.Aggregate (Ast.AggregateExtremum (max_or_min, t, _)) -> | Ast.Aggregate (Ast.AggregateExtremum (max_or_min, t, _)) ->
let op_kind, typ = let op_kind, typ =
match t with match t with
| Ast.Integer -> Dcalc.Ast.KInt, (Scopelang.Ast.TLit TInt, pos) | Ast.Integer -> KInt, (Scopelang.Ast.TLit TInt, pos)
| Ast.Decimal -> Dcalc.Ast.KRat, (Scopelang.Ast.TLit TRat, pos) | Ast.Decimal -> KRat, (Scopelang.Ast.TLit TRat, pos)
| Ast.Money -> Dcalc.Ast.KMoney, (Scopelang.Ast.TLit TMoney, pos) | Ast.Money -> KMoney, (Scopelang.Ast.TLit TMoney, pos)
| Ast.Duration -> | Ast.Duration ->
Dcalc.Ast.KDuration, (Scopelang.Ast.TLit TDuration, pos) KDuration, (Scopelang.Ast.TLit TDuration, pos)
| Ast.Date -> Dcalc.Ast.KDate, (Scopelang.Ast.TLit TDate, pos) | Ast.Date -> KDate, (Scopelang.Ast.TLit TDate, pos)
| _ -> | _ ->
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"It is impossible to compute the %s of two values of type %a" "ssible to compute the %s of two values of type %a"
(if max_or_min then "max" else "min") (if max_or_min then "max" else "min")
Print.format_primitive_typ t Print.format_primitive_typ t
in in
let cmp_op = let cmp_op =
if max_or_min then Dcalc.Ast.Gt op_kind else Dcalc.Ast.Lt op_kind if max_or_min then Gt op_kind else Lt op_kind
in in
make_extr_body cmp_op typ make_extr_body cmp_op typ
| Ast.Aggregate Ast.AggregateCount -> | Ast.Aggregate Ast.AggregateCount ->
@ -763,12 +764,12 @@ let rec translate_expr
( predicate, ( predicate,
( Desugared.Ast.EApp ( Desugared.Ast.EApp
( ( Desugared.Ast.EOp ( ( Desugared.Ast.EOp
(Dcalc.Ast.Binop (Dcalc.Ast.Add Dcalc.Ast.KInt)), (Binop (Add KInt)),
Marked.get_mark op' ), Marked.get_mark op' ),
[ [
acc; acc;
( Desugared.Ast.ELit ( Desugared.Ast.ELit
(Dcalc.Ast.LInt (Runtime.integer_of_int 1)), (LInt (Runtime.integer_of_int 1)),
Marked.get_mark predicate ); Marked.get_mark predicate );
] ), ] ),
pos ), pos ),
@ -778,7 +779,7 @@ let rec translate_expr
acc acc
in in
let f = let f =
let make_f (t : Dcalc.Ast.typ_lit) = let make_f (t : typ_lit) =
Bindlib.box_apply Bindlib.box_apply
(fun binder -> (fun binder ->
( Desugared.Ast.EAbs ( Desugared.Ast.EAbs
@ -796,29 +797,29 @@ let rec translate_expr
match Marked.unmark op' with match Marked.unmark op' with
| Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) ->
assert false (* should not happen *) assert false (* should not happen *)
| Ast.Exists -> make_f Dcalc.Ast.TBool | Ast.Exists -> make_f TBool
| Ast.Forall -> make_f Dcalc.Ast.TBool | Ast.Forall -> make_f TBool
| Ast.Aggregate (Ast.AggregateSum Ast.Integer) | Ast.Aggregate (Ast.AggregateSum Ast.Integer)
| Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Integer, _)) -> | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Integer, _)) ->
make_f Dcalc.Ast.TInt make_f TInt
| Ast.Aggregate (Ast.AggregateSum Ast.Decimal) | Ast.Aggregate (Ast.AggregateSum Ast.Decimal)
| Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Decimal, _)) -> | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Decimal, _)) ->
make_f Dcalc.Ast.TRat make_f TRat
| Ast.Aggregate (Ast.AggregateSum Ast.Money) | Ast.Aggregate (Ast.AggregateSum Ast.Money)
| Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Money, _)) -> | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Money, _)) ->
make_f Dcalc.Ast.TMoney make_f TMoney
| Ast.Aggregate (Ast.AggregateSum Ast.Duration) | Ast.Aggregate (Ast.AggregateSum Ast.Duration)
| Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Duration, _)) -> | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Duration, _)) ->
make_f Dcalc.Ast.TDuration make_f TDuration
| Ast.Aggregate (Ast.AggregateSum _) | Ast.Aggregate (Ast.AggregateSum _)
| Ast.Aggregate (Ast.AggregateExtremum _) -> | Ast.Aggregate (Ast.AggregateExtremum _) ->
assert false (* should not happen *) assert false (* should not happen *)
| Ast.Aggregate Ast.AggregateCount -> make_f Dcalc.Ast.TInt | Ast.Aggregate Ast.AggregateCount -> make_f TInt
in in
Bindlib.box_apply3 Bindlib.box_apply3
(fun f collection init -> (fun f collection init ->
( Desugared.Ast.EApp ( Desugared.Ast.EApp
( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), ( (Desugared.Ast.EOp (Ternop Fold), pos),
[f; init; collection] ), [f; init; collection] ),
pos )) pos ))
f collection init f collection init
@ -826,17 +827,17 @@ let rec translate_expr
let param_var = Desugared.Ast.Var.make "collection_member" in let param_var = Desugared.Ast.Var.make "collection_member" in
let param = Desugared.Ast.make_var (param_var, pos) in let param = Desugared.Ast.make_var (param_var, pos) in
let collection = rec_helper collection in let collection = rec_helper collection in
let init = Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool false), pos) in let init = Bindlib.box (Desugared.Ast.ELit (LBool false), pos) in
let acc_var = Desugared.Ast.Var.make "acc" in let acc_var = Desugared.Ast.Var.make "acc" in
let acc = Desugared.Ast.make_var (acc_var, pos) in let acc = Desugared.Ast.make_var (acc_var, pos) in
let f_body = let f_body =
Bindlib.box_apply3 Bindlib.box_apply3
(fun member acc param -> (fun member acc param ->
( Desugared.Ast.EApp ( Desugared.Ast.EApp
( (Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.Or), pos), ( (Desugared.Ast.EOp (Binop Or), pos),
[ [
( Desugared.Ast.EApp ( Desugared.Ast.EApp
( (Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.Eq), pos), ( (Desugared.Ast.EOp (Binop Eq), pos),
[member; param] ), [member; param] ),
pos ); pos );
acc; acc;
@ -851,7 +852,7 @@ let rec translate_expr
( Desugared.Ast.EAbs ( Desugared.Ast.EAbs
( binder, ( binder,
[ [
Scopelang.Ast.TLit Dcalc.Ast.TBool, pos; Scopelang.Ast.TLit TBool, pos;
Scopelang.Ast.TAny, pos; Scopelang.Ast.TAny, pos;
] ), ] ),
pos )) pos ))
@ -860,42 +861,42 @@ let rec translate_expr
Bindlib.box_apply3 Bindlib.box_apply3
(fun f collection init -> (fun f collection init ->
( Desugared.Ast.EApp ( Desugared.Ast.EApp
( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), ( (Desugared.Ast.EOp (Ternop Fold), pos),
[f; init; collection] ), [f; init; collection] ),
pos )) pos ))
f collection init f collection init
| Builtin IntToDec -> | Builtin IntToDec ->
Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.IntToRat), pos) Bindlib.box (Desugared.Ast.EOp (Unop IntToRat), pos)
| Builtin MoneyToDec -> | Builtin MoneyToDec ->
Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.MoneyToRat), pos) Bindlib.box (Desugared.Ast.EOp (Unop MoneyToRat), pos)
| Builtin DecToMoney -> | Builtin DecToMoney ->
Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.RatToMoney), pos) Bindlib.box (Desugared.Ast.EOp (Unop RatToMoney), pos)
| Builtin Cardinal -> | Builtin Cardinal ->
Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.Length), pos) Bindlib.box (Desugared.Ast.EOp (Unop Length), pos)
| Builtin GetDay -> | Builtin GetDay ->
Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetDay), pos) Bindlib.box (Desugared.Ast.EOp (Unop GetDay), pos)
| Builtin GetMonth -> | Builtin GetMonth ->
Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetMonth), pos) Bindlib.box (Desugared.Ast.EOp (Unop GetMonth), pos)
| Builtin GetYear -> | Builtin GetYear ->
Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetYear), pos) Bindlib.box (Desugared.Ast.EOp (Unop GetYear), pos)
| Builtin FirstDayOfMonth -> | Builtin FirstDayOfMonth ->
Bindlib.box Bindlib.box
(Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.FirstDayOfMonth), pos) (Desugared.Ast.EOp (Unop FirstDayOfMonth), pos)
| Builtin LastDayOfMonth -> | Builtin LastDayOfMonth ->
Bindlib.box Bindlib.box
(Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.LastDayOfMonth), pos) (Desugared.Ast.EOp (Unop LastDayOfMonth), pos)
| Builtin RoundMoney -> | Builtin RoundMoney ->
Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.RoundMoney), pos) Bindlib.box (Desugared.Ast.EOp (Unop RoundMoney), pos)
| Builtin RoundDecimal -> | Builtin RoundDecimal ->
Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.RoundDecimal), pos) Bindlib.box (Desugared.Ast.EOp (Unop RoundDecimal), pos)
and disambiguate_match_and_build_expression and disambiguate_match_and_build_expression
(scope : Scopelang.Ast.ScopeName.t) (scope : ScopeName.t)
(inside_definition_of : Desugared.Ast.ScopeDef.t Marked.pos option) (inside_definition_of : Desugared.Ast.ScopeDef.t Marked.pos option)
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(cases : Ast.match_case Marked.pos list) : (cases : Ast.match_case Marked.pos list) :
Desugared.Ast.expr Marked.pos Bindlib.box Scopelang.Ast.EnumConstructorMap.t Desugared.Ast.expr Marked.pos Bindlib.box Scopelang.Ast.EnumConstructorMap.t
* Scopelang.Ast.EnumName.t = * EnumName.t =
let create_var = function let create_var = function
| None -> ctxt, Desugared.Ast.Var.make "_" | None -> ctxt, Desugared.Ast.Var.make "_"
| Some param -> | Some param ->
@ -903,8 +904,8 @@ and disambiguate_match_and_build_expression
ctxt, param_var ctxt, param_var
in in
let bind_case_body let bind_case_body
(c_uid : Dcalc.Ast.EnumConstructor.t) (c_uid : EnumConstructor.t)
(e_uid : Dcalc.Ast.EnumName.t) (e_uid : EnumName.t)
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(case_body : ('a * Pos.t) Bindlib.box) (case_body : ('a * Pos.t) Bindlib.box)
(e_binder : (e_binder :
@ -917,7 +918,7 @@ and disambiguate_match_and_build_expression
( e_binder, ( e_binder,
[ [
Scopelang.Ast.EnumConstructorMap.find c_uid Scopelang.Ast.EnumConstructorMap.find c_uid
(Scopelang.Ast.EnumMap.find e_uid ctxt.Name_resolution.enums); (EnumMap.find e_uid ctxt.Name_resolution.enums);
] )) ] ))
case_body) case_body)
e_binder case_body e_binder case_body
@ -940,8 +941,8 @@ and disambiguate_match_and_build_expression
(Marked.get_mark case.Ast.match_case_pattern) (Marked.get_mark case.Ast.match_case_pattern)
"This case matches a constructor of enumeration %a but previous \ "This case matches a constructor of enumeration %a but previous \
case were matching constructors of enumeration %a" case were matching constructors of enumeration %a"
Scopelang.Ast.EnumName.format_t e_uid EnumName.format_t e_uid
Scopelang.Ast.EnumName.format_t e_uid' EnumName.format_t e_uid'
in in
(match Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d with (match Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d with
| None -> () | None -> ()
@ -952,7 +953,7 @@ and disambiguate_match_and_build_expression
None, Marked.get_mark (Bindlib.unbox e_case); None, Marked.get_mark (Bindlib.unbox e_case);
] ]
"The constructor %a has been matched twice:" "The constructor %a has been matched twice:"
Scopelang.Ast.EnumConstructor.format_t c_uid); EnumConstructor.format_t c_uid);
let ctxt, param_var = create_var (Option.map Marked.unmark binding) in let ctxt, param_var = create_var (Option.map Marked.unmark binding) in
let case_body = let case_body =
translate_expr scope inside_definition_of ctxt case.Ast.match_case_expr translate_expr scope inside_definition_of ctxt case.Ast.match_case_expr
@ -983,7 +984,7 @@ and disambiguate_match_and_build_expression
| Some e_uid -> | Some e_uid ->
if curr_index < nb_cases - 1 then raise_wildcard_not_last_case_err (); if curr_index < nb_cases - 1 then raise_wildcard_not_last_case_err ();
let missing_constructors = let missing_constructors =
Scopelang.Ast.EnumMap.find e_uid ctxt.Name_resolution.enums EnumMap.find e_uid ctxt.Name_resolution.enums
|> Scopelang.Ast.EnumConstructorMap.filter_map (fun c_uid _ -> |> Scopelang.Ast.EnumConstructorMap.filter_map (fun c_uid _ ->
match match
Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d
@ -995,7 +996,7 @@ and disambiguate_match_and_build_expression
Errors.format_spanned_warning case_pos Errors.format_spanned_warning case_pos
"Unreachable match case, all constructors of the enumeration %a \ "Unreachable match case, all constructors of the enumeration %a \
are already specified" are already specified"
Scopelang.Ast.EnumName.format_t e_uid; EnumName.format_t e_uid;
(* The current used strategy is to replace the wildcard branch: (* The current used strategy is to replace the wildcard branch:
match foo with match foo with
| Case1 x -> x | Case1 x -> x
@ -1048,7 +1049,7 @@ let merge_conditions
match precond, cond with match precond, cond with
| Some precond, Some cond -> | Some precond, Some cond ->
let op_term = let op_term =
( Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.And), ( Desugared.Ast.EOp (Binop And),
Marked.get_mark (Bindlib.unbox cond) ) Marked.get_mark (Bindlib.unbox cond) )
in in
Bindlib.box_apply2 Bindlib.box_apply2
@ -1061,13 +1062,13 @@ let merge_conditions
precond precond
| None, Some cond -> cond | None, Some cond -> cond
| None, None -> | None, None ->
Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool true), default_pos) Bindlib.box (Desugared.Ast.ELit (LBool true), default_pos)
(** Translates a surface definition into condition into a desugared {!type: (** Translates a surface definition into condition into a desugared {!type:
Desugared.Ast.rule} *) Desugared.Ast.rule} *)
let process_default let process_default
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(scope : Scopelang.Ast.ScopeName.t) (scope : ScopeName.t)
(def_key : Desugared.Ast.ScopeDef.t Marked.pos) (def_key : Desugared.Ast.ScopeDef.t Marked.pos)
(rule_id : Desugared.Ast.RuleName.t) (rule_id : Desugared.Ast.RuleName.t)
(param_uid : Desugared.Ast.Var.t Marked.pos option) (param_uid : Desugared.Ast.Var.t Marked.pos option)
@ -1111,7 +1112,7 @@ let process_default
disambiguation *) disambiguation *)
let process_def let process_def
(precond : Desugared.Ast.expr Marked.pos Bindlib.box option) (precond : Desugared.Ast.expr Marked.pos Bindlib.box option)
(scope_uid : Scopelang.Ast.ScopeName.t) (scope_uid : ScopeName.t)
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(prgm : Desugared.Ast.program) (prgm : Desugared.Ast.program)
(def : Ast.definition) : Desugared.Ast.program = (def : Ast.definition) : Desugared.Ast.program =
@ -1200,7 +1201,7 @@ let process_def
(** Translates a {!type: Surface.Ast.rule} from the surface language *) (** Translates a {!type: Surface.Ast.rule} from the surface language *)
let process_rule let process_rule
(precond : Desugared.Ast.expr Marked.pos Bindlib.box option) (precond : Desugared.Ast.expr Marked.pos Bindlib.box option)
(scope : Scopelang.Ast.ScopeName.t) (scope : ScopeName.t)
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(prgm : Desugared.Ast.program) (prgm : Desugared.Ast.program)
(rule : Ast.rule) : Desugared.Ast.program = (rule : Ast.rule) : Desugared.Ast.program =
@ -1210,7 +1211,7 @@ let process_rule
(** Translates assertions *) (** Translates assertions *)
let process_assert let process_assert
(precond : Desugared.Ast.expr Marked.pos Bindlib.box option) (precond : Desugared.Ast.expr Marked.pos Bindlib.box option)
(scope_uid : Scopelang.Ast.ScopeName.t) (scope_uid : ScopeName.t)
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(prgm : Desugared.Ast.program) (prgm : Desugared.Ast.program)
(ass : Ast.assertion) : Desugared.Ast.program = (ass : Ast.assertion) : Desugared.Ast.program =
@ -1236,7 +1237,7 @@ let process_assert
( Desugared.Ast.EIfThenElse ( Desugared.Ast.EIfThenElse
( precond, ( precond,
ass, ass,
Marked.same_mark_as (Desugared.Ast.ELit (Dcalc.Ast.LBool true)) Marked.same_mark_as (Desugared.Ast.ELit (LBool true))
precond ), precond ),
Marked.get_mark precond )) Marked.get_mark precond ))
precond ass precond ass
@ -1254,7 +1255,7 @@ let process_assert
(** Translates a surface definition, rule or assertion *) (** Translates a surface definition, rule or assertion *)
let process_scope_use_item let process_scope_use_item
(precond : Ast.expression Marked.pos option) (precond : Ast.expression Marked.pos option)
(scope : Scopelang.Ast.ScopeName.t) (scope : ScopeName.t)
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(prgm : Desugared.Ast.program) (prgm : Desugared.Ast.program)
(item : Ast.scope_use_item Marked.pos) : Desugared.Ast.program = (item : Ast.scope_use_item Marked.pos) : Desugared.Ast.program =
@ -1270,7 +1271,7 @@ let process_scope_use_item
(* If this is an unlabeled exception, ensures that it has a unique default (* If this is an unlabeled exception, ensures that it has a unique default
definition *) definition *)
let check_unlabeled_exception let check_unlabeled_exception
(scope : Scopelang.Ast.ScopeName.t) (scope : ScopeName.t)
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(item : Ast.scope_use_item Marked.pos) : unit = (item : Ast.scope_use_item Marked.pos) : unit =
let scope_ctxt = Scopelang.Ast.ScopeMap.find scope ctxt.scopes in let scope_ctxt = Scopelang.Ast.ScopeMap.find scope ctxt.scopes in
@ -1353,10 +1354,10 @@ let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) :
let empty_prgm = let empty_prgm =
{ {
Desugared.Ast.program_structs = Desugared.Ast.program_structs =
Scopelang.Ast.StructMap.map Scopelang.Ast.StructFieldMap.bindings StructMap.map Scopelang.Ast.StructFieldMap.bindings
ctxt.Name_resolution.structs; ctxt.Name_resolution.structs;
Desugared.Ast.program_enums = Desugared.Ast.program_enums =
Scopelang.Ast.EnumMap.map Scopelang.Ast.EnumConstructorMap.bindings EnumMap.map Scopelang.Ast.EnumConstructorMap.bindings
ctxt.Name_resolution.enums; ctxt.Name_resolution.enums;
Desugared.Ast.program_scopes = Desugared.Ast.program_scopes =
Scopelang.Ast.ScopeMap.mapi Scopelang.Ast.ScopeMap.mapi

View File

@ -19,6 +19,7 @@
lexical scopes into account *) lexical scopes into account *)
open Utils open Utils
open Shared_ast
(** {1 Name resolution context} *) (** {1 Name resolution context} *)
@ -41,7 +42,7 @@ type scope_context = {
(** What is the default rule to refer to for unnamed exceptions, if any *) (** What is the default rule to refer to for unnamed exceptions, if any *)
sub_scopes_idmap : Scopelang.Ast.SubScopeName.t Desugared.Ast.IdentMap.t; sub_scopes_idmap : Scopelang.Ast.SubScopeName.t Desugared.Ast.IdentMap.t;
(** Sub-scopes variables *) (** Sub-scopes variables *)
sub_scopes : Scopelang.Ast.ScopeName.t Scopelang.Ast.SubScopeMap.t; sub_scopes : ScopeName.t Scopelang.Ast.SubScopeMap.t;
(** To what scope sub-scopes refer to? *) (** To what scope sub-scopes refer to? *)
} }
(** Inside a scope, we distinguish between the variables and the subscopes. *) (** Inside a scope, we distinguish between the variables and the subscopes. *)
@ -64,27 +65,27 @@ type context = {
local_var_idmap : Desugared.Ast.Var.t Desugared.Ast.IdentMap.t; local_var_idmap : Desugared.Ast.Var.t Desugared.Ast.IdentMap.t;
(** Inside a definition, local variables can be introduced by functions (** Inside a definition, local variables can be introduced by functions
arguments or pattern matching *) arguments or pattern matching *)
scope_idmap : Scopelang.Ast.ScopeName.t Desugared.Ast.IdentMap.t; scope_idmap : ScopeName.t Desugared.Ast.IdentMap.t;
(** The names of the scopes *) (** The names of the scopes *)
struct_idmap : Scopelang.Ast.StructName.t Desugared.Ast.IdentMap.t; struct_idmap : StructName.t Desugared.Ast.IdentMap.t;
(** The names of the structs *) (** The names of the structs *)
field_idmap : field_idmap :
Scopelang.Ast.StructFieldName.t Scopelang.Ast.StructMap.t StructFieldName.t StructMap.t
Desugared.Ast.IdentMap.t; Desugared.Ast.IdentMap.t;
(** The names of the struct fields. Names of fields can be shared between (** The names of the struct fields. Names of fields can be shared between
different structs *) different structs *)
enum_idmap : Scopelang.Ast.EnumName.t Desugared.Ast.IdentMap.t; enum_idmap : EnumName.t Desugared.Ast.IdentMap.t;
(** The names of the enums *) (** The names of the enums *)
constructor_idmap : constructor_idmap :
Scopelang.Ast.EnumConstructor.t Scopelang.Ast.EnumMap.t EnumConstructor.t EnumMap.t
Desugared.Ast.IdentMap.t; Desugared.Ast.IdentMap.t;
(** The names of the enum constructors. Constructor names can be shared (** The names of the enum constructors. Constructor names can be shared
between different enums *) between different enums *)
scopes : scope_context Scopelang.Ast.ScopeMap.t; scopes : scope_context Scopelang.Ast.ScopeMap.t;
(** For each scope, its context *) (** For each scope, its context *)
structs : struct_context Scopelang.Ast.StructMap.t; structs : struct_context StructMap.t;
(** For each struct, its context *) (** For each struct, its context *)
enums : enum_context Scopelang.Ast.EnumMap.t; enums : enum_context EnumMap.t;
(** For each enum, its context *) (** For each enum, its context *)
var_typs : var_sig Desugared.Ast.ScopeVarMap.t; var_typs : var_sig Desugared.Ast.ScopeVarMap.t;
(** The signatures of each scope variable declared *) (** The signatures of each scope variable declared *)
@ -120,7 +121,7 @@ let get_var_io (ctxt : context) (uid : Desugared.Ast.ScopeVar.t) :
(** Get the variable uid inside the scope given in argument *) (** Get the variable uid inside the scope given in argument *)
let get_var_uid let get_var_uid
(scope_uid : Scopelang.Ast.ScopeName.t) (scope_uid : ScopeName.t)
(ctxt : context) (ctxt : context)
((x, pos) : ident Marked.pos) : Desugared.Ast.ScopeVar.t = ((x, pos) : ident Marked.pos) : Desugared.Ast.ScopeVar.t =
let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
@ -128,13 +129,13 @@ let get_var_uid
| None -> | None ->
raise_unknown_identifier raise_unknown_identifier
(Format.asprintf "for a variable of scope %a" (Format.asprintf "for a variable of scope %a"
Scopelang.Ast.ScopeName.format_t scope_uid) ScopeName.format_t scope_uid)
(x, pos) (x, pos)
| Some uid -> uid | Some uid -> uid
(** Get the subscope uid inside the scope given in argument *) (** Get the subscope uid inside the scope given in argument *)
let get_subscope_uid let get_subscope_uid
(scope_uid : Scopelang.Ast.ScopeName.t) (scope_uid : ScopeName.t)
(ctxt : context) (ctxt : context)
((y, pos) : ident Marked.pos) : Scopelang.Ast.SubScopeName.t = ((y, pos) : ident Marked.pos) : Scopelang.Ast.SubScopeName.t =
let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
@ -145,7 +146,7 @@ let get_subscope_uid
(** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the (** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the
subscopes of [scope_uid]. *) subscopes of [scope_uid]. *)
let is_subscope_uid let is_subscope_uid
(scope_uid : Scopelang.Ast.ScopeName.t) (scope_uid : ScopeName.t)
(ctxt : context) (ctxt : context)
(y : ident) : bool = (y : ident) : bool =
let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
@ -155,7 +156,7 @@ let is_subscope_uid
let belongs_to let belongs_to
(ctxt : context) (ctxt : context)
(uid : Desugared.Ast.ScopeVar.t) (uid : Desugared.Ast.ScopeVar.t)
(scope_uid : Scopelang.Ast.ScopeName.t) : bool = (scope_uid : ScopeName.t) : bool =
let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
Desugared.Ast.IdentMap.exists Desugared.Ast.IdentMap.exists
(fun _ var_uid -> Desugared.Ast.ScopeVar.compare uid var_uid = 0) (fun _ var_uid -> Desugared.Ast.ScopeVar.compare uid var_uid = 0)
@ -183,7 +184,7 @@ let is_def_cond (ctxt : context) (def : Desugared.Ast.ScopeDef.t) : bool =
(** Process a subscope declaration *) (** Process a subscope declaration *)
let process_subscope_decl let process_subscope_decl
(scope : Scopelang.Ast.ScopeName.t) (scope : ScopeName.t)
(ctxt : context) (ctxt : context)
(decl : Ast.scope_decl_context_scope) : context = (decl : Ast.scope_decl_context_scope) : context =
let name, name_pos = decl.scope_decl_context_scope_name in let name, name_pos = decl.scope_decl_context_scope_name in
@ -277,7 +278,7 @@ let process_type (ctxt : context) ((typ, typ_pos) : Ast.typ Marked.pos) :
(** Process data declaration *) (** Process data declaration *)
let process_data_decl let process_data_decl
(scope : Scopelang.Ast.ScopeName.t) (scope : ScopeName.t)
(ctxt : context) (ctxt : context)
(decl : Ast.scope_decl_context_data) : context = (decl : Ast.scope_decl_context_data) : context =
(* First check the type of the context data *) (* First check the type of the context data *)
@ -330,7 +331,7 @@ let process_data_decl
(** Process an item declaration *) (** Process an item declaration *)
let process_item_decl let process_item_decl
(scope : Scopelang.Ast.ScopeName.t) (scope : ScopeName.t)
(ctxt : context) (ctxt : context)
(decl : Ast.scope_decl_context_item) : context = (decl : Ast.scope_decl_context_item) : context =
match decl with match decl with
@ -372,7 +373,7 @@ let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context =
List.fold_left List.fold_left
(fun ctxt (fdecl, _) -> (fun ctxt (fdecl, _) ->
let f_uid = let f_uid =
Scopelang.Ast.StructFieldName.fresh fdecl.Ast.struct_decl_field_name StructFieldName.fresh fdecl.Ast.struct_decl_field_name
in in
let ctxt = let ctxt =
{ {
@ -382,16 +383,16 @@ let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context =
(Marked.unmark fdecl.Ast.struct_decl_field_name) (Marked.unmark fdecl.Ast.struct_decl_field_name)
(fun uids -> (fun uids ->
match uids with match uids with
| None -> Some (Scopelang.Ast.StructMap.singleton s_uid f_uid) | None -> Some (StructMap.singleton s_uid f_uid)
| Some uids -> | Some uids ->
Some (Scopelang.Ast.StructMap.add s_uid f_uid uids)) Some (StructMap.add s_uid f_uid uids))
ctxt.field_idmap; ctxt.field_idmap;
} }
in in
{ {
ctxt with ctxt with
structs = structs =
Scopelang.Ast.StructMap.update s_uid StructMap.update s_uid
(fun fields -> (fun fields ->
match fields with match fields with
| None -> | None ->
@ -421,7 +422,7 @@ let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context =
List.fold_left List.fold_left
(fun ctxt (cdecl, cdecl_pos) -> (fun ctxt (cdecl, cdecl_pos) ->
let c_uid = let c_uid =
Scopelang.Ast.EnumConstructor.fresh cdecl.Ast.enum_decl_case_name EnumConstructor.fresh cdecl.Ast.enum_decl_case_name
in in
let ctxt = let ctxt =
{ {
@ -431,15 +432,15 @@ let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context =
(Marked.unmark cdecl.Ast.enum_decl_case_name) (Marked.unmark cdecl.Ast.enum_decl_case_name)
(fun uids -> (fun uids ->
match uids with match uids with
| None -> Some (Scopelang.Ast.EnumMap.singleton e_uid c_uid) | None -> Some (EnumMap.singleton e_uid c_uid)
| Some uids -> Some (Scopelang.Ast.EnumMap.add e_uid c_uid uids)) | Some uids -> Some (EnumMap.add e_uid c_uid uids))
ctxt.constructor_idmap; ctxt.constructor_idmap;
} }
in in
{ {
ctxt with ctxt with
enums = enums =
Scopelang.Ast.EnumMap.update e_uid EnumMap.update e_uid
(fun cases -> (fun cases ->
let typ = let typ =
match cdecl.Ast.enum_decl_case_typ with match cdecl.Ast.enum_decl_case_typ with
@ -475,10 +476,10 @@ let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) :
match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with
| Some use -> | Some use ->
raise_already_defined_error raise_already_defined_error
(Scopelang.Ast.ScopeName.get_info use) (ScopeName.get_info use)
name pos "scope" name pos "scope"
| None -> | None ->
let scope_uid = Scopelang.Ast.ScopeName.fresh (name, pos) in let scope_uid = ScopeName.fresh (name, pos) in
{ {
ctxt with ctxt with
scope_idmap = Desugared.Ast.IdentMap.add name scope_uid ctxt.scope_idmap; scope_idmap = Desugared.Ast.IdentMap.add name scope_uid ctxt.scope_idmap;
@ -497,10 +498,10 @@ let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) :
match Desugared.Ast.IdentMap.find_opt name ctxt.struct_idmap with match Desugared.Ast.IdentMap.find_opt name ctxt.struct_idmap with
| Some use -> | Some use ->
raise_already_defined_error raise_already_defined_error
(Scopelang.Ast.StructName.get_info use) (StructName.get_info use)
name pos "struct" name pos "struct"
| None -> | None ->
let s_uid = Scopelang.Ast.StructName.fresh sdecl.struct_decl_name in let s_uid = StructName.fresh sdecl.struct_decl_name in
{ {
ctxt with ctxt with
struct_idmap = struct_idmap =
@ -513,10 +514,10 @@ let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) :
match Desugared.Ast.IdentMap.find_opt name ctxt.enum_idmap with match Desugared.Ast.IdentMap.find_opt name ctxt.enum_idmap with
| Some use -> | Some use ->
raise_already_defined_error raise_already_defined_error
(Scopelang.Ast.EnumName.get_info use) (EnumName.get_info use)
name pos "enum" name pos "enum"
| None -> | None ->
let e_uid = Scopelang.Ast.EnumName.fresh edecl.enum_decl_name in let e_uid = EnumName.fresh edecl.enum_decl_name in
{ {
ctxt with ctxt with
@ -561,7 +562,7 @@ let rec process_law_structure
let get_def_key let get_def_key
(name : Ast.qident) (name : Ast.qident)
(state : Ast.ident Marked.pos option) (state : Ast.ident Marked.pos option)
(scope_uid : Scopelang.Ast.ScopeName.t) (scope_uid : ScopeName.t)
(ctxt : context) (ctxt : context)
(default_pos : Pos.t) : Desugared.Ast.ScopeDef.t = (default_pos : Pos.t) : Desugared.Ast.ScopeDef.t =
let scope_ctxt = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in let scope_ctxt = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
@ -603,7 +604,7 @@ let get_def_key
let subscope_uid : Scopelang.Ast.SubScopeName.t = let subscope_uid : Scopelang.Ast.SubScopeName.t =
get_subscope_uid scope_uid ctxt y get_subscope_uid scope_uid ctxt y
in in
let subscope_real_uid : Scopelang.Ast.ScopeName.t = let subscope_real_uid : ScopeName.t =
Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes
in in
let x_uid = get_var_uid subscope_real_uid ctxt x in let x_uid = get_var_uid subscope_real_uid ctxt x in
@ -616,7 +617,7 @@ let get_def_key
let process_definition let process_definition
(ctxt : context) (ctxt : context)
(s_name : Scopelang.Ast.ScopeName.t) (s_name : ScopeName.t)
(d : Ast.definition) : context = (d : Ast.definition) : context =
(* We update the definition context inside the big context *) (* We update the definition context inside the big context *)
{ {
@ -725,7 +726,7 @@ let process_definition
} }
let process_scope_use_item let process_scope_use_item
(s_name : Scopelang.Ast.ScopeName.t) (s_name : ScopeName.t)
(ctxt : context) (ctxt : context)
(sitem : Ast.scope_use_item Marked.pos) : context = (sitem : Ast.scope_use_item Marked.pos) : context =
match Marked.unmark sitem with match Marked.unmark sitem with
@ -764,10 +765,10 @@ let form_context (prgm : Ast.program) : context =
scope_idmap = Desugared.Ast.IdentMap.empty; scope_idmap = Desugared.Ast.IdentMap.empty;
scopes = Scopelang.Ast.ScopeMap.empty; scopes = Scopelang.Ast.ScopeMap.empty;
var_typs = Desugared.Ast.ScopeVarMap.empty; var_typs = Desugared.Ast.ScopeVarMap.empty;
structs = Scopelang.Ast.StructMap.empty; structs = StructMap.empty;
struct_idmap = Desugared.Ast.IdentMap.empty; struct_idmap = Desugared.Ast.IdentMap.empty;
field_idmap = Desugared.Ast.IdentMap.empty; field_idmap = Desugared.Ast.IdentMap.empty;
enums = Scopelang.Ast.EnumMap.empty; enums = EnumMap.empty;
enum_idmap = Desugared.Ast.IdentMap.empty; enum_idmap = Desugared.Ast.IdentMap.empty;
constructor_idmap = Desugared.Ast.IdentMap.empty; constructor_idmap = Desugared.Ast.IdentMap.empty;
} }

View File

@ -19,6 +19,7 @@
lexical scopes into account *) lexical scopes into account *)
open Utils open Utils
open Shared_ast
(** {1 Name resolution context} *) (** {1 Name resolution context} *)
@ -41,7 +42,7 @@ type scope_context = {
(** What is the default rule to refer to for unnamed exceptions, if any *) (** What is the default rule to refer to for unnamed exceptions, if any *)
sub_scopes_idmap : Scopelang.Ast.SubScopeName.t Desugared.Ast.IdentMap.t; sub_scopes_idmap : Scopelang.Ast.SubScopeName.t Desugared.Ast.IdentMap.t;
(** Sub-scopes variables *) (** Sub-scopes variables *)
sub_scopes : Scopelang.Ast.ScopeName.t Scopelang.Ast.SubScopeMap.t; sub_scopes : ScopeName.t Scopelang.Ast.SubScopeMap.t;
(** To what scope sub-scopes refer to? *) (** To what scope sub-scopes refer to? *)
} }
(** Inside a scope, we distinguish between the variables and the subscopes. *) (** Inside a scope, we distinguish between the variables and the subscopes. *)
@ -64,27 +65,27 @@ type context = {
local_var_idmap : Desugared.Ast.Var.t Desugared.Ast.IdentMap.t; local_var_idmap : Desugared.Ast.Var.t Desugared.Ast.IdentMap.t;
(** Inside a definition, local variables can be introduced by functions (** Inside a definition, local variables can be introduced by functions
arguments or pattern matching *) arguments or pattern matching *)
scope_idmap : Scopelang.Ast.ScopeName.t Desugared.Ast.IdentMap.t; scope_idmap : ScopeName.t Desugared.Ast.IdentMap.t;
(** The names of the scopes *) (** The names of the scopes *)
struct_idmap : Scopelang.Ast.StructName.t Desugared.Ast.IdentMap.t; struct_idmap : StructName.t Desugared.Ast.IdentMap.t;
(** The names of the structs *) (** The names of the structs *)
field_idmap : field_idmap :
Scopelang.Ast.StructFieldName.t Scopelang.Ast.StructMap.t StructFieldName.t StructMap.t
Desugared.Ast.IdentMap.t; Desugared.Ast.IdentMap.t;
(** The names of the struct fields. Names of fields can be shared between (** The names of the struct fields. Names of fields can be shared between
different structs *) different structs *)
enum_idmap : Scopelang.Ast.EnumName.t Desugared.Ast.IdentMap.t; enum_idmap : EnumName.t Desugared.Ast.IdentMap.t;
(** The names of the enums *) (** The names of the enums *)
constructor_idmap : constructor_idmap :
Scopelang.Ast.EnumConstructor.t Scopelang.Ast.EnumMap.t EnumConstructor.t EnumMap.t
Desugared.Ast.IdentMap.t; Desugared.Ast.IdentMap.t;
(** The names of the enum constructors. Constructor names can be shared (** The names of the enum constructors. Constructor names can be shared
between different enums *) between different enums *)
scopes : scope_context Scopelang.Ast.ScopeMap.t; scopes : scope_context Scopelang.Ast.ScopeMap.t;
(** For each scope, its context *) (** For each scope, its context *)
structs : struct_context Scopelang.Ast.StructMap.t; structs : struct_context StructMap.t;
(** For each struct, its context *) (** For each struct, its context *)
enums : enum_context Scopelang.Ast.EnumMap.t; enums : enum_context EnumMap.t;
(** For each enum, its context *) (** For each enum, its context *)
var_typs : var_sig Desugared.Ast.ScopeVarMap.t; var_typs : var_sig Desugared.Ast.ScopeVarMap.t;
(** The signatures of each scope variable declared *) (** The signatures of each scope variable declared *)
@ -110,25 +111,25 @@ val get_var_io :
context -> Desugared.Ast.ScopeVar.t -> Ast.scope_decl_context_io context -> Desugared.Ast.ScopeVar.t -> Ast.scope_decl_context_io
val get_var_uid : val get_var_uid :
Scopelang.Ast.ScopeName.t -> ScopeName.t ->
context -> context ->
ident Marked.pos -> ident Marked.pos ->
Desugared.Ast.ScopeVar.t Desugared.Ast.ScopeVar.t
(** Get the variable uid inside the scope given in argument *) (** Get the variable uid inside the scope given in argument *)
val get_subscope_uid : val get_subscope_uid :
Scopelang.Ast.ScopeName.t -> ScopeName.t ->
context -> context ->
ident Marked.pos -> ident Marked.pos ->
Scopelang.Ast.SubScopeName.t Scopelang.Ast.SubScopeName.t
(** Get the subscope uid inside the scope given in argument *) (** Get the subscope uid inside the scope given in argument *)
val is_subscope_uid : Scopelang.Ast.ScopeName.t -> context -> ident -> bool val is_subscope_uid : ScopeName.t -> context -> ident -> bool
(** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the (** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the
subscopes of [scope_uid]. *) subscopes of [scope_uid]. *)
val belongs_to : val belongs_to :
context -> Desugared.Ast.ScopeVar.t -> Scopelang.Ast.ScopeName.t -> bool context -> Desugared.Ast.ScopeVar.t -> ScopeName.t -> bool
(** Checks if the var_uid belongs to the scope scope_uid *) (** Checks if the var_uid belongs to the scope scope_uid *)
val get_def_typ : context -> Desugared.Ast.ScopeDef.t -> typ Marked.pos val get_def_typ : context -> Desugared.Ast.ScopeDef.t -> typ Marked.pos
@ -143,7 +144,7 @@ val add_def_local_var : context -> ident -> context * Desugared.Ast.Var.t
val get_def_key : val get_def_key :
Ast.qident -> Ast.qident ->
Ast.ident Marked.pos option -> Ast.ident Marked.pos option ->
Scopelang.Ast.ScopeName.t -> ScopeName.t ->
context -> context ->
Pos.t -> Pos.t ->
Desugared.Ast.ScopeDef.t Desugared.Ast.ScopeDef.t

View File

@ -16,6 +16,7 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Dcalc open Dcalc
open Ast open Ast
@ -92,7 +93,7 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : typed marked_expr)
| ErrorOnEmpty d -> | ErrorOnEmpty d ->
d (* input subscope variables and non-input scope variable *) d (* input subscope variables and non-input scope variable *)
| _ -> | _ ->
Errors.raise_spanned_error (pos e) Errors.raise_spanned_error (Expr.pos e)
"Internal error: this expression does not have the structure expected by \ "Internal error: this expression does not have the structure expected by \
the VC generator:\n\ the VC generator:\n\
%a" %a"
@ -382,7 +383,7 @@ let rec generate_verification_conditions_scopes
| ScopeDef scope_def -> | ScopeDef scope_def ->
let is_selected_scope = let is_selected_scope =
match s with match s with
| Some s when Dcalc.Ast.ScopeName.compare s scope_def.scope_name = 0 -> | Some s when ScopeName.compare s scope_def.scope_name = 0 ->
true true
| None -> true | None -> true
| _ -> false | _ -> false
@ -416,7 +417,7 @@ let rec generate_verification_conditions_scopes
let generate_verification_conditions let generate_verification_conditions
(p : 'm program) (p : 'm program)
(s : Dcalc.Ast.ScopeName.t option) : verification_condition list = (s : ScopeName.t option) : verification_condition list =
let vcs = generate_verification_conditions_scopes p.decl_ctx p.scopes s in let vcs = generate_verification_conditions_scopes p.decl_ctx p.scopes s in
(* We sort this list by scope name and then variable name to ensure consistent (* We sort this list by scope name and then variable name to ensure consistent
output for testing*) output for testing*)

View File

@ -29,21 +29,21 @@ type verification_condition_kind =
a conflict error *) a conflict error *)
type verification_condition = { type verification_condition = {
vc_guard : Dcalc.Ast.typed Dcalc.Ast.marked_expr; vc_guard : typed Dcalc.Ast.marked_expr;
(** This expression should have type [bool]*) (** This expression should have type [bool]*)
vc_kind : verification_condition_kind; vc_kind : verification_condition_kind;
vc_scope : Dcalc.Ast.ScopeName.t; vc_scope : ScopeName.t;
vc_variable : typed Dcalc.Ast.var Marked.pos; vc_variable : typed Dcalc.Ast.var Marked.pos;
vc_free_vars_typ : vc_free_vars_typ :
(typed Dcalc.Ast.expr, Dcalc.Ast.typ Marked.pos) Var.Map.t; (typed Dcalc.Ast.expr, typ Marked.pos) Var.Map.t;
(** Types of the locally free variables in [vc_guard]. The types of other (** Types of the locally free variables in [vc_guard]. The types of other
free variables linked to scope variables can be obtained with free variables linked to scope variables can be obtained with
[Dcalc.Ast.variable_types]. *) [Dcalc.Ast.variable_types]. *)
} }
val generate_verification_conditions : val generate_verification_conditions :
Dcalc.Ast.typed Dcalc.Ast.program -> typed Dcalc.Ast.program ->
Dcalc.Ast.ScopeName.t option -> ScopeName.t option ->
verification_condition list verification_condition list
(** [generate_verification_conditions p None] will generate the verification (** [generate_verification_conditions p None] will generate the verification
conditions for all the variables of all the scopes of the program [p], while conditions for all the variables of all the scopes of the program [p], while

View File

@ -16,6 +16,7 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Dcalc.Ast open Dcalc.Ast
module type Backend = sig module type Backend = sig
@ -73,7 +74,7 @@ module type BackendIO = sig
string string
val encode_and_check_vc : val encode_and_check_vc :
Dcalc.Ast.decl_ctx -> decl_ctx ->
Conditions.verification_condition * vc_encoding_result -> Conditions.verification_condition * vc_encoding_result ->
unit unit
end end
@ -161,7 +162,7 @@ module MakeBackendIO (B : Backend) = struct
let vc, z3_vc = vc in let vc, z3_vc = vc in
Cli.debug_print "For this variable:\n%s\n" Cli.debug_print "For this variable:\n%s\n"
(Pos.retrieve_loc_text (pos vc.Conditions.vc_guard)); (Pos.retrieve_loc_text (Expr.pos vc.Conditions.vc_guard));
Cli.debug_format "This verification condition was generated for %a:@\n%a" Cli.debug_format "This verification condition was generated for %a:@\n%a"
(Cli.format_with_style [ANSITerminal.yellow]) (Cli.format_with_style [ANSITerminal.yellow])
(match vc.vc_kind with (match vc.vc_kind with

View File

@ -26,8 +26,8 @@ module type Backend = sig
type backend_context type backend_context
val make_context : val make_context :
Dcalc.Ast.decl_ctx -> decl_ctx ->
(typed Dcalc.Ast.expr, Dcalc.Ast.typ Utils.Marked.pos) Var.Map.t -> (typed Dcalc.Ast.expr, typ Utils.Marked.pos) Var.Map.t ->
backend_context backend_context
type vc_encoding type vc_encoding
@ -53,8 +53,8 @@ module type BackendIO = sig
type backend_context type backend_context
val make_context : val make_context :
Dcalc.Ast.decl_ctx -> decl_ctx ->
(typed Dcalc.Ast.expr, Dcalc.Ast.typ Utils.Marked.pos) Var.Map.t -> (typed Dcalc.Ast.expr, typ Utils.Marked.pos) Var.Map.t ->
backend_context backend_context
type vc_encoding type vc_encoding
@ -79,7 +79,7 @@ module type BackendIO = sig
string string
val encode_and_check_vc : val encode_and_check_vc :
Dcalc.Ast.decl_ctx -> decl_ctx ->
Conditions.verification_condition * vc_encoding_result -> Conditions.verification_condition * vc_encoding_result ->
unit unit
end end

View File

@ -20,7 +20,7 @@ open Dcalc.Ast
expressions [vcs] corresponding to verification conditions that must be expressions [vcs] corresponding to verification conditions that must be
discharged by Z3, and attempts to solve them **) discharged by Z3, and attempts to solve them **)
let solve_vc let solve_vc
(decl_ctx : decl_ctx) (decl_ctx : Shared_ast.decl_ctx)
(vcs : Conditions.verification_condition list) : unit = (vcs : Conditions.verification_condition list) : unit =
(* Right now we only use the Z3 backend but the functorial interface should (* Right now we only use the Z3 backend but the functorial interface should
make it easy to mix and match different proof backends. *) make it easy to mix and match different proof backends. *)

View File

@ -17,4 +17,4 @@
(** Solves verification conditions using various proof backends *) (** Solves verification conditions using various proof backends *)
val solve_vc : val solve_vc :
Dcalc.Ast.decl_ctx -> Conditions.verification_condition list -> unit Shared_ast.decl_ctx -> Conditions.verification_condition list -> unit

View File

@ -15,6 +15,7 @@
the License. *) the License. *)
open Utils open Utils
open Shared_ast
open Dcalc open Dcalc
open Ast open Ast
open Z3 open Z3
@ -428,7 +429,7 @@ let rec translate_op
(Print.format_expr ctx.ctx_decl) (Print.format_expr ctx.ctx_decl)
( EApp ( EApp
( (EOp op, Untyped { pos = Pos.no_pos }), ( (EOp op, Untyped { pos = Pos.no_pos }),
List.map (fun arg -> Bindlib.unbox (untype_expr arg)) args ), List.map (fun arg -> Bindlib.unbox (Shared_ast.Expr.untype arg)) args ),
Untyped { pos = Pos.no_pos } )) Untyped { pos = Pos.no_pos } ))
in in
@ -520,7 +521,7 @@ let rec translate_op
( EApp ( EApp
( (EOp op, Untyped { pos = Pos.no_pos }), ( (EOp op, Untyped { pos = Pos.no_pos }),
List.map List.map
(fun arg -> arg |> untype_expr |> Bindlib.unbox) (fun arg -> arg |> Shared_ast.Expr.untype |> Bindlib.unbox)
args ), args ),
Untyped { pos = Pos.no_pos } )) Untyped { pos = Pos.no_pos } ))
in in
@ -572,7 +573,7 @@ let rec translate_op
( EApp ( EApp
( (EOp op, Untyped { pos = Pos.no_pos }), ( (EOp op, Untyped { pos = Pos.no_pos }),
List.map List.map
(fun arg -> arg |> untype_expr |> Bindlib.unbox) (fun arg -> arg |> Shared_ast.Expr.untype |> Bindlib.unbox)
args ), args ),
Untyped { pos = Pos.no_pos } )) Untyped { pos = Pos.no_pos } ))
in in

View File

@ -1,4 +1,4 @@
(lang dune 2.8) (lang dune 3.0)
(name catala) (name catala)

View File

@ -115,7 +115,7 @@ let run_test () =
exit (-1) exit (-1)
| Runtime.AssertionFailed _ -> () | Runtime.AssertionFailed _ -> ()
let bench = let _bench =
Random.init (int_of_float (Unix.time ())); Random.init (int_of_float (Unix.time ()));
let num_iter = 10000 in let num_iter = 10000 in
let _ = let _ =

View File

@ -12,7 +12,7 @@
(preprocess (preprocess
(pps js_of_ocaml-ppx)) (pps js_of_ocaml-ppx))
(js_of_ocaml (js_of_ocaml
(flags --disable=shortvar --opt 3)) (flags :standard --disable=shortvar --opt 3))
; We need to disable shortvar because ; We need to disable shortvar because
; otherwise Webpack wrongly minifies ; otherwise Webpack wrongly minifies
; the library and it gives bugs. ; the library and it gives bugs.