diff --git a/compiler/dcalc/ast.ml b/compiler/dcalc/ast.ml index 322b18a3..99e8b233 100644 --- a/compiler/dcalc/ast.ml +++ b/compiler/dcalc/ast.ml @@ -16,8 +16,7 @@ the License. *) open Utils -include Shared_ast -include Shared_ast.Expr +open Shared_ast type lit = dcalc glit @@ -26,53 +25,9 @@ and 'm marked_expr = (dcalc, 'm mark) marked_gexpr type 'm program = ('m expr, 'm) program_generic -let no_mark (type m) : m mark -> m mark = function - | Untyped _ -> Untyped { pos = Pos.no_pos } - | Typed _ -> Typed { pos = Pos.no_pos; ty = Marked.mark Pos.no_pos TAny } - -let mark_pos (type m) (m : m mark) : Pos.t = - match m with Untyped { pos } | Typed { pos; _ } -> pos - -let pos (type m) (x : ('a, m) marked) : Pos.t = mark_pos (Marked.get_mark x) -let ty (_, m) : marked_typ = match m with Typed { ty; _ } -> ty - -let with_ty (type m) (ty : marked_typ) (x : ('a, m) marked) : ('a, typed) marked - = - Marked.mark - (match Marked.get_mark x with - | Untyped { pos } -> Typed { pos; ty } - | Typed m -> Typed { m with ty }) - (Marked.unmark x) - -let map_expr ctx ~f e = Expr.map ctx ~f e - -let rec map_expr_top_down ~f e = - map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e) - -let map_expr_marks ~f e = - map_expr_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e - -let untype_expr e = map_expr_marks ~f:(fun m -> Untyped { pos = mark_pos m }) e - type ('expr, 'm) box_expr_sig = ('expr, 'm) marked -> ('expr, 'm) marked Bindlib.box -(** See [Bindlib.box_term] documentation for why we are doing that. *) -let box_expr : ('m expr, 'm) box_expr_sig = - fun e -> - let rec id_t () e = map_expr () ~f:id_t e in - id_t () e - -let untype_program prg = - { - prg with - scopes = - Bindlib.unbox - (map_exprs_in_scopes - ~f:(fun e -> untype_expr e) - ~varf:Var.translate prg.scopes); - } - type 'm var = 'm expr Var.t type 'm vars = 'm expr Var.vars @@ -158,49 +113,14 @@ type ('expr, 'm) make_let_in_sig = Pos.t -> ('expr, 'm) marked Bindlib.box -let map_mark - (type m) - (pos_f : Pos.t -> Pos.t) - (ty_f : marked_typ -> marked_typ) - (m : m mark) : m mark = - match m with - | Untyped { pos } -> Untyped { pos = pos_f pos } - | Typed { pos; ty } -> Typed { pos = pos_f pos; ty = ty_f ty } - -let map_mark2 - (type m) - (pos_f : Pos.t -> Pos.t -> Pos.t) - (ty_f : typed -> typed -> marked_typ) - (m1 : m mark) - (m2 : m mark) : m mark = - match m1, m2 with - | Untyped m1, Untyped m2 -> Untyped { pos = pos_f m1.pos m2.pos } - | Typed m1, Typed m2 -> Typed { pos = pos_f m1.pos m2.pos; ty = ty_f m1 m2 } - -let fold_marks - (type m) - (pos_f : Pos.t list -> Pos.t) - (ty_f : typed list -> marked_typ) - (ms : m mark list) : m mark = - match ms with - | [] -> invalid_arg "Dcalc.Ast.fold_mark" - | Untyped _ :: _ as ms -> - Untyped { pos = pos_f (List.map (function Untyped { pos } -> pos) ms) } - | Typed _ :: _ -> - Typed - { - pos = pos_f (List.map (function Typed { pos; _ } -> pos) ms); - ty = ty_f (List.map (function Typed m -> m) ms); - } - let empty_thunked_term mark : 'm marked_expr = let silent = Var.make "_" in - let pos = mark_pos mark in + let pos = Expr.mark_pos mark in Bindlib.unbox (make_abs [| silent |] (Bindlib.box (ELit LEmptyError, mark)) [TLit TUnit, pos] - (map_mark + (Expr.map_mark (fun pos -> pos) (fun ty -> Marked.mark pos (TArrow (Marked.mark pos (TLit TUnit), ty))) @@ -211,7 +131,7 @@ let (make_let_in : ('m expr, 'm) make_let_in_sig) = let m_e1 = Marked.get_mark (Bindlib.unbox e1) in let m_e2 = Marked.get_mark (Bindlib.unbox e2) in let m_abs = - map_mark2 + Expr.map_mark2 (fun _ _ -> pos) (fun m1 m2 -> Marked.mark pos (TArrow (m1.ty, m2.ty))) m_e1 m_e2 @@ -329,7 +249,7 @@ let build_whole_scope_expr ( List.map snd (StructMap.find body.scope_body_input_struct ctx.ctx_structs), Some body.scope_body_input_struct ), - mark_pos mark_scope ); + Expr.mark_pos mark_scope ); ] mark_scope @@ -354,10 +274,6 @@ type 'expr scope_name_or_var = | ScopeName of ScopeName.t | ScopeVar of 'expr Bindlib.var -let get_scope_body_mark scope_body = - match snd (Bindlib.unbind scope_body.scope_body_expr) with - | Result e | ScopeLet { scope_let_expr = e; _ } -> Marked.get_mark e - let rec unfold_scopes ~(box_expr : ('expr, 'm) box_expr_sig) ~(make_abs : ('expr, 'm) make_abs_sig) @@ -374,7 +290,7 @@ let rec unfold_scopes | ScopeDef { scope_name; scope_body; scope_next } -> let scope_var, scope_next = Bindlib.unbind scope_next in let scope_pos = Marked.get_mark (ScopeName.get_info scope_name) in - let scope_body_mark = get_scope_body_mark scope_body in + let scope_body_mark = Expr.get_scope_body_mark scope_body in let main_scope = match main_scope with | ScopeVar v -> ScopeVar v @@ -407,7 +323,7 @@ let build_whole_program_expr (main_scope : ScopeName.t) : ('expr, 'm) marked Bindlib.box = let _, main_scope_body = find_scope main_scope [] p.scopes in unfold_scopes ~box_expr ~make_abs ~make_let_in p.decl_ctx p.scopes - (get_scope_body_mark main_scope_body) + (Expr.get_scope_body_mark main_scope_body) (ScopeName main_scope) let rec expr_size (e : 'm marked_expr) : int = @@ -435,7 +351,7 @@ let rec expr_size (e : 'm marked_expr) : int = let remove_logging_calls (e : 'm marked_expr) : 'm marked_expr Bindlib.box = let rec f () e = match Marked.unmark e with - | EApp ((EOp (Unop (Log _)), _), [arg]) -> map_expr () ~f arg - | _ -> map_expr () ~f e + | EApp ((EOp (Unop (Log _)), _), [arg]) -> Expr.map () ~f arg + | _ -> Expr.map () ~f e in f () e diff --git a/compiler/dcalc/ast.mli b/compiler/dcalc/ast.mli index 86d07f9d..04ff839e 100644 --- a/compiler/dcalc/ast.mli +++ b/compiler/dcalc/ast.mli @@ -18,8 +18,7 @@ (** Abstract syntax tree of the default calculus intermediate representation *) open Utils -include module type of Shared_ast -include module type of Shared_ast.Expr +open Shared_ast type lit = dcalc glit @@ -44,149 +43,9 @@ val free_vars_scope_body : ('m expr, 'm) scope_body -> 'm expr Var.Set.t val free_vars_scopes : ('m expr, 'm) scopes -> 'm expr Var.Set.t val make_var : ('m var, 'm) marked -> 'm marked_expr Bindlib.box -(** {2 Manipulation of marks} *) - -val no_mark : 'm mark -> 'm mark -val mark_pos : 'm mark -> Pos.t -val pos : ('a, 'm) marked -> Pos.t -val ty : ('a, typed) marked -> marked_typ -val with_ty : marked_typ -> ('a, 'm) marked -> ('a, typed) marked - -(** All the following functions will resolve the types if called on an - [Inferring] type *) - -val map_mark : - (Pos.t -> Pos.t) -> (marked_typ -> marked_typ) -> 'm mark -> 'm mark - -val map_mark2 : - (Pos.t -> Pos.t -> Pos.t) -> - (typed -> typed -> marked_typ) -> - 'm mark -> - 'm mark -> - 'm mark - -val fold_marks : - (Pos.t list -> Pos.t) -> (typed list -> marked_typ) -> 'm mark list -> 'm mark - -val get_scope_body_mark : ('expr, 'm) scope_body -> 'm mark -val untype_expr : 'm marked_expr -> untyped marked_expr Bindlib.box -val untype_program : 'm program -> untyped program - -(** {2 Boxed constructors} *) - -val evar : 'm expr Bindlib.var -> 'm mark -> 'm marked_expr Bindlib.box - -val etuple : - 'm marked_expr Bindlib.box list -> - StructName.t option -> - 'm mark -> - 'm marked_expr Bindlib.box - -val etupleaccess : - 'm marked_expr Bindlib.box -> - int -> - StructName.t option -> - marked_typ list -> - 'm mark -> - 'm marked_expr Bindlib.box - -val einj : - 'm marked_expr Bindlib.box -> - int -> - EnumName.t -> - marked_typ list -> - 'm mark -> - 'm marked_expr Bindlib.box - -val ematch : - 'm marked_expr Bindlib.box -> - 'm marked_expr Bindlib.box list -> - EnumName.t -> - 'm mark -> - 'm marked_expr Bindlib.box - -val earray : - 'm marked_expr Bindlib.box list -> 'm mark -> 'm marked_expr Bindlib.box - -val elit : lit -> 'm mark -> 'm marked_expr Bindlib.box - -val eabs : - ('m expr, 'm marked_expr) Bindlib.mbinder Bindlib.box -> - marked_typ list -> - 'm mark -> - 'm marked_expr Bindlib.box - -val eapp : - 'm marked_expr Bindlib.box -> - 'm marked_expr Bindlib.box list -> - 'm mark -> - 'm marked_expr Bindlib.box - -val eassert : - 'm marked_expr Bindlib.box -> 'm mark -> 'm marked_expr Bindlib.box - -val eop : operator -> 'm mark -> 'm marked_expr Bindlib.box - -val edefault : - 'm marked_expr Bindlib.box list -> - 'm marked_expr Bindlib.box -> - 'm marked_expr Bindlib.box -> - 'm mark -> - 'm marked_expr Bindlib.box - -val eifthenelse : - 'm marked_expr Bindlib.box -> - 'm marked_expr Bindlib.box -> - 'm marked_expr Bindlib.box -> - 'm mark -> - 'm marked_expr Bindlib.box - -val eerroronempty : - 'm marked_expr Bindlib.box -> 'm mark -> 'm marked_expr Bindlib.box - type ('expr, 'm) box_expr_sig = ('expr, 'm) marked -> ('expr, 'm) marked Bindlib.box -val box_expr : ('m expr, 'm) box_expr_sig - -(**{2 Program traversal}*) - -(** Be careful when using these traversal functions, as the bound variables they - open will be different at each traversal. *) - -val map_expr : - 'a -> - f:('a -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box) -> - ('m1 expr, 'm2 mark) Marked.t -> - 'm2 marked_expr Bindlib.box -(** If you want to apply a map transform to an expression, you can save up - writing a painful match over all the cases of the AST. For instance, if you - want to remove all errors on empty, you can write - - {[ - let remove_error_empty = - let rec f () e = - match Marked.unmark e with - | ErrorOnEmpty e1 -> map_expr () f e1 - | _ -> map_expr () f e - in - f () e - ]} - - The first argument of map_expr is an optional context that you can carry - around during your map traversal. *) - -val map_expr_top_down : - f:('m1 marked_expr -> ('m1 expr, 'm2 mark) Marked.t) -> - 'm1 marked_expr -> - 'm2 marked_expr Bindlib.box -(** Recursively applies [f] to the nodes of the expression tree. The type - returned by [f] is hybrid since the mark at top-level has been rewritten, - but not yet the marks in the subtrees. *) - -val map_expr_marks : - f:('m1 mark -> 'm2 mark) -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box - (** {2 Boxed term constructors} *) type ('e, 'm) make_abs_sig = diff --git a/compiler/dcalc/interpreter.ml b/compiler/dcalc/interpreter.ml index 7e6a5556..4aaafd0e 100644 --- a/compiler/dcalc/interpreter.ml +++ b/compiler/dcalc/interpreter.ml @@ -17,12 +17,13 @@ (** Reference interpreter for the default calculus *) open Utils +open Shared_ast module A = Ast module Runtime = Runtime_ocaml.Runtime (** {1 Helpers} *) -let is_empty_error (e : 'm A.marked_expr) : bool = +let is_empty_error (e : 'm Ast.marked_expr) : bool = match Marked.unmark e with ELit LEmptyError -> true | _ -> false let log_indent = ref 0 @@ -30,25 +31,25 @@ let log_indent = ref 0 (** {1 Evaluation} *) let rec evaluate_operator - (ctx : Ast.decl_ctx) - (op : A.operator) + (ctx : decl_ctx) + (op : operator) (pos : Pos.t) - (args : 'm A.marked_expr list) : 'm A.expr = + (args : 'm Ast.marked_expr list) : 'm Ast.expr = (* Try to apply [div] and if a [Division_by_zero] exceptions is catched, use [op] to raise multispanned errors. *) - let apply_div_or_raise_err (div : unit -> 'm A.expr) : 'm A.expr = + let apply_div_or_raise_err (div : unit -> 'm Ast.expr) : 'm Ast.expr = try div () with Division_by_zero -> Errors.raise_multispanned_error [ Some "The division operator:", pos; - Some "The null denominator:", Ast.pos (List.nth args 1); + Some "The null denominator:", Expr.pos (List.nth args 1); ] "division by zero at runtime" in let get_binop_args_pos = function | (arg0 :: arg1 :: _ : 'm A.marked_expr list) -> - [None, Ast.pos arg0; None, Ast.pos arg1] + [None, Expr.pos arg0; None, Expr.pos arg1] | _ -> assert false in (* Try to apply [cmp] and if a [UncomparableDurations] exceptions is catched, @@ -63,211 +64,211 @@ let rec evaluate_operator precise number of days" in match op, List.map Marked.unmark args with - | A.Ternop A.Fold, [_f; _init; EArray es] -> + | Ternop Fold, [_f; _init; EArray es] -> Marked.unmark (List.fold_left (fun acc e' -> evaluate_expr ctx - (Marked.same_mark_as (A.EApp (List.nth args 0, [acc; e'])) e')) + (Marked.same_mark_as (EApp (List.nth args 0, [acc; e'])) e')) (List.nth args 1) es) - | A.Binop A.And, [ELit (LBool b1); ELit (LBool b2)] -> - A.ELit (LBool (b1 && b2)) - | A.Binop A.Or, [ELit (LBool b1); ELit (LBool b2)] -> - A.ELit (LBool (b1 || b2)) - | A.Binop A.Xor, [ELit (LBool b1); ELit (LBool b2)] -> - A.ELit (LBool (b1 <> b2)) - | A.Binop (A.Add KInt), [ELit (LInt i1); ELit (LInt i2)] -> - A.ELit (LInt Runtime.(i1 +! i2)) - | A.Binop (A.Sub KInt), [ELit (LInt i1); ELit (LInt i2)] -> - A.ELit (LInt Runtime.(i1 -! i2)) - | A.Binop (A.Mult KInt), [ELit (LInt i1); ELit (LInt i2)] -> - A.ELit (LInt Runtime.(i1 *! i2)) - | A.Binop (A.Div KInt), [ELit (LInt i1); ELit (LInt i2)] -> - apply_div_or_raise_err (fun _ -> A.ELit (LInt Runtime.(i1 /! i2))) - | A.Binop (A.Add KRat), [ELit (LRat i1); ELit (LRat i2)] -> - A.ELit (LRat Runtime.(i1 +& i2)) - | A.Binop (A.Sub KRat), [ELit (LRat i1); ELit (LRat i2)] -> - A.ELit (LRat Runtime.(i1 -& i2)) - | A.Binop (A.Mult KRat), [ELit (LRat i1); ELit (LRat i2)] -> - A.ELit (LRat Runtime.(i1 *& i2)) - | A.Binop (A.Div KRat), [ELit (LRat i1); ELit (LRat i2)] -> - apply_div_or_raise_err (fun _ -> A.ELit (LRat Runtime.(i1 /& i2))) - | A.Binop (A.Add KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> - A.ELit (LMoney Runtime.(m1 +$ m2)) - | A.Binop (A.Sub KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> - A.ELit (LMoney Runtime.(m1 -$ m2)) - | A.Binop (A.Mult KMoney), [ELit (LMoney m1); ELit (LRat m2)] -> - A.ELit (LMoney Runtime.(m1 *$ m2)) - | A.Binop (A.Div KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> - apply_div_or_raise_err (fun _ -> A.ELit (LRat Runtime.(m1 /$ m2))) - | A.Binop (A.Add KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> - A.ELit (LDuration Runtime.(d1 +^ d2)) - | A.Binop (A.Sub KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> - A.ELit (LDuration Runtime.(d1 -^ d2)) - | A.Binop (A.Sub KDate), [ELit (LDate d1); ELit (LDate d2)] -> - A.ELit (LDuration Runtime.(d1 -@ d2)) - | A.Binop (A.Add KDate), [ELit (LDate d1); ELit (LDuration d2)] -> - A.ELit (LDate Runtime.(d1 +@ d2)) - | A.Binop (A.Div KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + | Binop And, [ELit (LBool b1); ELit (LBool b2)] -> + ELit (LBool (b1 && b2)) + | Binop Or, [ELit (LBool b1); ELit (LBool b2)] -> + ELit (LBool (b1 || b2)) + | Binop Xor, [ELit (LBool b1); ELit (LBool b2)] -> + ELit (LBool (b1 <> b2)) + | Binop (Add KInt), [ELit (LInt i1); ELit (LInt i2)] -> + ELit (LInt Runtime.(i1 +! i2)) + | Binop (Sub KInt), [ELit (LInt i1); ELit (LInt i2)] -> + ELit (LInt Runtime.(i1 -! i2)) + | Binop (Mult KInt), [ELit (LInt i1); ELit (LInt i2)] -> + ELit (LInt Runtime.(i1 *! i2)) + | Binop (Div KInt), [ELit (LInt i1); ELit (LInt i2)] -> + apply_div_or_raise_err (fun _ -> ELit (LInt Runtime.(i1 /! i2))) + | Binop (Add KRat), [ELit (LRat i1); ELit (LRat i2)] -> + ELit (LRat Runtime.(i1 +& i2)) + | Binop (Sub KRat), [ELit (LRat i1); ELit (LRat i2)] -> + ELit (LRat Runtime.(i1 -& i2)) + | Binop (Mult KRat), [ELit (LRat i1); ELit (LRat i2)] -> + ELit (LRat Runtime.(i1 *& i2)) + | Binop (Div KRat), [ELit (LRat i1); ELit (LRat i2)] -> + apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(i1 /& i2))) + | Binop (Add KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + ELit (LMoney Runtime.(m1 +$ m2)) + | Binop (Sub KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + ELit (LMoney Runtime.(m1 -$ m2)) + | Binop (Mult KMoney), [ELit (LMoney m1); ELit (LRat m2)] -> + ELit (LMoney Runtime.(m1 *$ m2)) + | Binop (Div KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(m1 /$ m2))) + | Binop (Add KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + ELit (LDuration Runtime.(d1 +^ d2)) + | Binop (Sub KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + ELit (LDuration Runtime.(d1 -^ d2)) + | Binop (Sub KDate), [ELit (LDate d1); ELit (LDate d2)] -> + ELit (LDuration Runtime.(d1 -@ d2)) + | Binop (Add KDate), [ELit (LDate d1); ELit (LDuration d2)] -> + ELit (LDate Runtime.(d1 +@ d2)) + | Binop (Div KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> apply_div_or_raise_err (fun _ -> - try A.ELit (LRat Runtime.(d1 /^ d2)) + try ELit (LRat Runtime.(d1 /^ d2)) with Runtime.IndivisableDurations -> Errors.raise_multispanned_error (get_binop_args_pos args) "Cannot divide durations that cannot be converted to a precise \ number of days") - | A.Binop (A.Mult KDuration), [ELit (LDuration d1); ELit (LInt i1)] -> - A.ELit (LDuration Runtime.(d1 *^ i1)) - | A.Binop (A.Lt KInt), [ELit (LInt i1); ELit (LInt i2)] -> - A.ELit (LBool Runtime.(i1 - A.ELit (LBool Runtime.(i1 <=! i2)) - | A.Binop (A.Gt KInt), [ELit (LInt i1); ELit (LInt i2)] -> - A.ELit (LBool Runtime.(i1 >! i2)) - | A.Binop (A.Gte KInt), [ELit (LInt i1); ELit (LInt i2)] -> - A.ELit (LBool Runtime.(i1 >=! i2)) - | A.Binop (A.Lt KRat), [ELit (LRat i1); ELit (LRat i2)] -> - A.ELit (LBool Runtime.(i1 <& i2)) - | A.Binop (A.Lte KRat), [ELit (LRat i1); ELit (LRat i2)] -> - A.ELit (LBool Runtime.(i1 <=& i2)) - | A.Binop (A.Gt KRat), [ELit (LRat i1); ELit (LRat i2)] -> - A.ELit (LBool Runtime.(i1 >& i2)) - | A.Binop (A.Gte KRat), [ELit (LRat i1); ELit (LRat i2)] -> - A.ELit (LBool Runtime.(i1 >=& i2)) - | A.Binop (A.Lt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> - A.ELit (LBool Runtime.(m1 <$ m2)) - | A.Binop (A.Lte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> - A.ELit (LBool Runtime.(m1 <=$ m2)) - | A.Binop (A.Gt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> - A.ELit (LBool Runtime.(m1 >$ m2)) - | A.Binop (A.Gte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> - A.ELit (LBool Runtime.(m1 >=$ m2)) - | A.Binop (A.Lt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> - apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 <^ d2))) args - | A.Binop (A.Lte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> - apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 <=^ d2))) args - | A.Binop (A.Gt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> - apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 >^ d2))) args - | A.Binop (A.Gte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> - apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 >=^ d2))) args - | A.Binop (A.Lt KDate), [ELit (LDate d1); ELit (LDate d2)] -> - A.ELit (LBool Runtime.(d1 <@ d2)) - | A.Binop (A.Lte KDate), [ELit (LDate d1); ELit (LDate d2)] -> - A.ELit (LBool Runtime.(d1 <=@ d2)) - | A.Binop (A.Gt KDate), [ELit (LDate d1); ELit (LDate d2)] -> - A.ELit (LBool Runtime.(d1 >@ d2)) - | A.Binop (A.Gte KDate), [ELit (LDate d1); ELit (LDate d2)] -> - A.ELit (LBool Runtime.(d1 >=@ d2)) - | A.Binop A.Eq, [ELit LUnit; ELit LUnit] -> A.ELit (LBool true) - | A.Binop A.Eq, [ELit (LDuration d1); ELit (LDuration d2)] -> - A.ELit (LBool Runtime.(d1 =^ d2)) - | A.Binop A.Eq, [ELit (LDate d1); ELit (LDate d2)] -> - A.ELit (LBool Runtime.(d1 =@ d2)) - | A.Binop A.Eq, [ELit (LMoney m1); ELit (LMoney m2)] -> - A.ELit (LBool Runtime.(m1 =$ m2)) - | A.Binop A.Eq, [ELit (LRat i1); ELit (LRat i2)] -> - A.ELit (LBool Runtime.(i1 =& i2)) - | A.Binop A.Eq, [ELit (LInt i1); ELit (LInt i2)] -> - A.ELit (LBool Runtime.(i1 =! i2)) - | A.Binop A.Eq, [ELit (LBool b1); ELit (LBool b2)] -> A.ELit (LBool (b1 = b2)) - | A.Binop A.Eq, [EArray es1; EArray es2] -> - A.ELit + | Binop (Mult KDuration), [ELit (LDuration d1); ELit (LInt i1)] -> + ELit (LDuration Runtime.(d1 *^ i1)) + | Binop (Lt KInt), [ELit (LInt i1); ELit (LInt i2)] -> + ELit (LBool Runtime.(i1 + ELit (LBool Runtime.(i1 <=! i2)) + | Binop (Gt KInt), [ELit (LInt i1); ELit (LInt i2)] -> + ELit (LBool Runtime.(i1 >! i2)) + | Binop (Gte KInt), [ELit (LInt i1); ELit (LInt i2)] -> + ELit (LBool Runtime.(i1 >=! i2)) + | Binop (Lt KRat), [ELit (LRat i1); ELit (LRat i2)] -> + ELit (LBool Runtime.(i1 <& i2)) + | Binop (Lte KRat), [ELit (LRat i1); ELit (LRat i2)] -> + ELit (LBool Runtime.(i1 <=& i2)) + | Binop (Gt KRat), [ELit (LRat i1); ELit (LRat i2)] -> + ELit (LBool Runtime.(i1 >& i2)) + | Binop (Gte KRat), [ELit (LRat i1); ELit (LRat i2)] -> + ELit (LBool Runtime.(i1 >=& i2)) + | Binop (Lt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + ELit (LBool Runtime.(m1 <$ m2)) + | Binop (Lte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + ELit (LBool Runtime.(m1 <=$ m2)) + | Binop (Gt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + ELit (LBool Runtime.(m1 >$ m2)) + | Binop (Gte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + ELit (LBool Runtime.(m1 >=$ m2)) + | Binop (Lt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <^ d2))) args + | Binop (Lte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <=^ d2))) args + | Binop (Gt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >^ d2))) args + | Binop (Gte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >=^ d2))) args + | Binop (Lt KDate), [ELit (LDate d1); ELit (LDate d2)] -> + ELit (LBool Runtime.(d1 <@ d2)) + | Binop (Lte KDate), [ELit (LDate d1); ELit (LDate d2)] -> + ELit (LBool Runtime.(d1 <=@ d2)) + | Binop (Gt KDate), [ELit (LDate d1); ELit (LDate d2)] -> + ELit (LBool Runtime.(d1 >@ d2)) + | Binop (Gte KDate), [ELit (LDate d1); ELit (LDate d2)] -> + ELit (LBool Runtime.(d1 >=@ d2)) + | Binop Eq, [ELit LUnit; ELit LUnit] -> ELit (LBool true) + | Binop Eq, [ELit (LDuration d1); ELit (LDuration d2)] -> + ELit (LBool Runtime.(d1 =^ d2)) + | Binop Eq, [ELit (LDate d1); ELit (LDate d2)] -> + ELit (LBool Runtime.(d1 =@ d2)) + | Binop Eq, [ELit (LMoney m1); ELit (LMoney m2)] -> + ELit (LBool Runtime.(m1 =$ m2)) + | Binop Eq, [ELit (LRat i1); ELit (LRat i2)] -> + ELit (LBool Runtime.(i1 =& i2)) + | Binop Eq, [ELit (LInt i1); ELit (LInt i2)] -> + ELit (LBool Runtime.(i1 =! i2)) + | Binop Eq, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 = b2)) + | Binop Eq, [EArray es1; EArray es2] -> + ELit (LBool (try List.for_all2 (fun e1 e2 -> match evaluate_operator ctx op pos [e1; e2] with - | A.ELit (LBool b) -> b + | ELit (LBool b) -> b | _ -> assert false (* should not happen *)) es1 es2 with Invalid_argument _ -> false)) - | A.Binop A.Eq, [ETuple (es1, s1); ETuple (es2, s2)] -> - A.ELit + | Binop Eq, [ETuple (es1, s1); ETuple (es2, s2)] -> + ELit (LBool (try s1 = s2 && List.for_all2 (fun e1 e2 -> match evaluate_operator ctx op pos [e1; e2] with - | A.ELit (LBool b) -> b + | ELit (LBool b) -> b | _ -> assert false (* should not happen *)) es1 es2 with Invalid_argument _ -> false)) - | A.Binop A.Eq, [EInj (e1, i1, en1, _ts1); EInj (e2, i2, en2, _ts2)] -> - A.ELit + | Binop Eq, [EInj (e1, i1, en1, _ts1); EInj (e2, i2, en2, _ts2)] -> + ELit (LBool (try en1 = en2 && i1 = i2 && match evaluate_operator ctx op pos [e1; e2] with - | A.ELit (LBool b) -> b + | ELit (LBool b) -> b | _ -> assert false (* should not happen *) with Invalid_argument _ -> false)) - | A.Binop A.Eq, [_; _] -> - A.ELit (LBool false) (* comparing anything else return false *) - | A.Binop A.Neq, [_; _] -> ( - match evaluate_operator ctx (A.Binop A.Eq) pos args with - | A.ELit (A.LBool b) -> A.ELit (A.LBool (not b)) + | Binop Eq, [_; _] -> + ELit (LBool false) (* comparing anything else return false *) + | Binop Neq, [_; _] -> ( + match evaluate_operator ctx (Binop Eq) pos args with + | ELit (LBool b) -> ELit (LBool (not b)) | _ -> assert false (*should not happen *)) - | A.Binop A.Concat, [A.EArray es1; A.EArray es2] -> A.EArray (es1 @ es2) - | A.Binop A.Map, [_; A.EArray es] -> - A.EArray + | Binop Concat, [EArray es1; EArray es2] -> EArray (es1 @ es2) + | Binop Map, [_; EArray es] -> + EArray (List.map (fun e' -> evaluate_expr ctx - (Marked.same_mark_as (A.EApp (List.nth args 0, [e'])) e')) + (Marked.same_mark_as (EApp (List.nth args 0, [e'])) e')) es) - | A.Binop A.Filter, [_; A.EArray es] -> - A.EArray + | Binop Filter, [_; EArray es] -> + EArray (List.filter (fun e' -> match evaluate_expr ctx - (Marked.same_mark_as (A.EApp (List.nth args 0, [e'])) e') + (Marked.same_mark_as (EApp (List.nth args 0, [e'])) e') with - | A.ELit (A.LBool b), _ -> b + | ELit (LBool b), _ -> b | _ -> Errors.raise_spanned_error - (A.pos (List.nth args 0)) + (Expr.pos (List.nth args 0)) "This predicate evaluated to something else than a boolean \ (should not happen if the term was well-typed)") es) - | A.Binop _, ([ELit LEmptyError; _] | [_; ELit LEmptyError]) -> - A.ELit LEmptyError - | A.Unop (A.Minus KInt), [ELit (LInt i)] -> - A.ELit (LInt Runtime.(integer_of_int 0 -! i)) - | A.Unop (A.Minus KRat), [ELit (LRat i)] -> - A.ELit (LRat Runtime.(decimal_of_string "0" -& i)) - | A.Unop (A.Minus KMoney), [ELit (LMoney i)] -> - A.ELit (LMoney Runtime.(money_of_units_int 0 -$ i)) - | A.Unop (A.Minus KDuration), [ELit (LDuration i)] -> - A.ELit (LDuration Runtime.(~-^i)) - | A.Unop A.Not, [ELit (LBool b)] -> A.ELit (LBool (not b)) - | A.Unop A.Length, [EArray es] -> - A.ELit (LInt (Runtime.integer_of_int (List.length es))) - | A.Unop A.GetDay, [ELit (LDate d)] -> - A.ELit (LInt Runtime.(day_of_month_of_date d)) - | A.Unop A.GetMonth, [ELit (LDate d)] -> - A.ELit (LInt Runtime.(month_number_of_date d)) - | A.Unop A.GetYear, [ELit (LDate d)] -> A.ELit (LInt Runtime.(year_of_date d)) - | A.Unop A.FirstDayOfMonth, [ELit (LDate d)] -> - A.ELit (LDate Runtime.(first_day_of_month d)) - | A.Unop A.LastDayOfMonth, [ELit (LDate d)] -> - A.ELit (LDate Runtime.(first_day_of_month d)) - | A.Unop A.IntToRat, [ELit (LInt i)] -> - A.ELit (LRat Runtime.(decimal_of_integer i)) - | A.Unop A.MoneyToRat, [ELit (LMoney i)] -> - A.ELit (LRat Runtime.(decimal_of_money i)) - | A.Unop A.RatToMoney, [ELit (LRat i)] -> - A.ELit (LMoney Runtime.(money_of_decimal i)) - | A.Unop A.RoundMoney, [ELit (LMoney m)] -> - A.ELit (LMoney Runtime.(money_round m)) - | A.Unop A.RoundDecimal, [ELit (LRat m)] -> - A.ELit (LRat Runtime.(decimal_round m)) - | A.Unop (A.Log (entry, infos)), [e'] -> + | Binop _, ([ELit LEmptyError; _] | [_; ELit LEmptyError]) -> + ELit LEmptyError + | Unop (Minus KInt), [ELit (LInt i)] -> + ELit (LInt Runtime.(integer_of_int 0 -! i)) + | Unop (Minus KRat), [ELit (LRat i)] -> + ELit (LRat Runtime.(decimal_of_string "0" -& i)) + | Unop (Minus KMoney), [ELit (LMoney i)] -> + ELit (LMoney Runtime.(money_of_units_int 0 -$ i)) + | Unop (Minus KDuration), [ELit (LDuration i)] -> + ELit (LDuration Runtime.(~-^i)) + | Unop Not, [ELit (LBool b)] -> ELit (LBool (not b)) + | Unop Length, [EArray es] -> + ELit (LInt (Runtime.integer_of_int (List.length es))) + | Unop GetDay, [ELit (LDate d)] -> + ELit (LInt Runtime.(day_of_month_of_date d)) + | Unop GetMonth, [ELit (LDate d)] -> + ELit (LInt Runtime.(month_number_of_date d)) + | Unop GetYear, [ELit (LDate d)] -> ELit (LInt Runtime.(year_of_date d)) + | Unop FirstDayOfMonth, [ELit (LDate d)] -> + ELit (LDate Runtime.(first_day_of_month d)) + | Unop LastDayOfMonth, [ELit (LDate d)] -> + ELit (LDate Runtime.(first_day_of_month d)) + | Unop IntToRat, [ELit (LInt i)] -> + ELit (LRat Runtime.(decimal_of_integer i)) + | Unop MoneyToRat, [ELit (LMoney i)] -> + ELit (LRat Runtime.(decimal_of_money i)) + | Unop RatToMoney, [ELit (LRat i)] -> + ELit (LMoney Runtime.(money_of_decimal i)) + | Unop RoundMoney, [ELit (LMoney m)] -> + ELit (LMoney Runtime.(money_round m)) + | Unop RoundDecimal, [ELit (LRat m)] -> + ELit (LRat Runtime.(decimal_round m)) + | Unop (Log (entry, infos)), [e'] -> if !Cli.trace_flag then ( match entry with | VarDef _ -> @@ -276,7 +277,7 @@ let rec evaluate_operator Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Print.format_log_entry entry Print.format_uid_list infos (match e' with - | Ast.EAbs _ -> Cli.with_style [ANSITerminal.green] "" + | EAbs _ -> Cli.with_style [ANSITerminal.green] "" | _ -> let expr_str = Format.asprintf "%a" @@ -308,7 +309,7 @@ let rec evaluate_operator entry Print.format_uid_list infos) else (); e' - | A.Unop _, [ELit LEmptyError] -> A.ELit LEmptyError + | Unop _, [ELit LEmptyError] -> ELit LEmptyError | _ -> Errors.raise_multispanned_error ([Some "Operator:", pos] @@ -318,16 +319,16 @@ let rec evaluate_operator (Format.asprintf "Argument n°%d, value %a" (i + 1) (Print.format_expr ctx ~debug:true) arg), - A.pos arg )) + Expr.pos arg )) args) "Operator applied to the wrong arguments\n\ (should not happen if the term was well-typed)" -and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr +and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.marked_expr) : 'm Ast.marked_expr = match Marked.unmark e with | EVar _ -> - Errors.raise_spanned_error (A.pos e) + Errors.raise_spanned_error (Expr.pos e) "free variable found at evaluation (should not happen if term was \ well-typed" | EApp (e1, args) -> ( @@ -339,22 +340,22 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr evaluate_expr ctx (Bindlib.msubst binder (Array.of_list (List.map Marked.unmark args))) else - Errors.raise_spanned_error (A.pos e) + Errors.raise_spanned_error (Expr.pos e) "wrong function call, expected %d arguments, got %d" (Bindlib.mbinder_arity binder) (List.length args) - | EOp op -> Marked.same_mark_as (evaluate_operator ctx op (A.pos e) args) e - | ELit LEmptyError -> Marked.same_mark_as (A.ELit LEmptyError) e + | EOp op -> Marked.same_mark_as (evaluate_operator ctx op (Expr.pos e) args) e + | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | _ -> - Errors.raise_spanned_error (A.pos e) + Errors.raise_spanned_error (Expr.pos e) "function has not been reduced to a lambda at evaluation (should not \ happen if the term was well-typed") | EAbs _ | ELit _ | EOp _ -> e (* these are values *) | ETuple (es, s) -> let new_es = List.map (evaluate_expr ctx) es in if List.exists is_empty_error new_es then - Marked.same_mark_as (A.ELit LEmptyError) e - else Marked.same_mark_as (A.ETuple (new_es, s)) e + Marked.same_mark_as (ELit LEmptyError) e + else Marked.same_mark_as (ETuple (new_es, s)) e | ETupleAccess (e1, n, s, _) -> ( let e1 = evaluate_expr ctx e1 in match Marked.unmark e1 with @@ -364,49 +365,49 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr | Some s, Some s' when s = s' -> () | _ -> Errors.raise_multispanned_error - [None, A.pos e; None, A.pos e1] + [None, Expr.pos e; None, Expr.pos e1] "Error during tuple access: not the same structs (should not happen \ if the term was well-typed)"); match List.nth_opt es n with | Some e' -> e' | None -> - Errors.raise_spanned_error (A.pos e1) + Errors.raise_spanned_error (Expr.pos e1) "The tuple has %d components but the %i-th element was requested \ (should not happen if the term was well-type)" (List.length es) n) - | ELit LEmptyError -> Marked.same_mark_as (A.ELit LEmptyError) e + | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | _ -> - Errors.raise_spanned_error (A.pos e1) + Errors.raise_spanned_error (Expr.pos e1) "The expression %a should be a tuple with %d components but is not \ (should not happen if the term was well-typed)" (Print.format_expr ctx ~debug:true) e n) | EInj (e1, n, en, ts) -> let e1' = evaluate_expr ctx e1 in - if is_empty_error e1' then Marked.same_mark_as (A.ELit LEmptyError) e - else Marked.same_mark_as (A.EInj (e1', n, en, ts)) e + if is_empty_error e1' then Marked.same_mark_as (ELit LEmptyError) e + else Marked.same_mark_as (EInj (e1', n, en, ts)) e | EMatch (e1, es, e_name) -> ( let e1 = evaluate_expr ctx e1 in match Marked.unmark e1 with - | A.EInj (e1, n, e_name', _) -> + | EInj (e1, n, e_name', _) -> if e_name <> e_name' then Errors.raise_multispanned_error - [None, A.pos e; None, A.pos e1] + [None, Expr.pos e; None, Expr.pos e1] "Error during match: two different enums found (should not happend \ if the term was well-typed)"; let es_n = match List.nth_opt es n with | Some es_n -> es_n | None -> - Errors.raise_spanned_error (A.pos e) + Errors.raise_spanned_error (Expr.pos e) "sum type index error (should not happend if the term was \ well-typed)" in - let new_e = Marked.same_mark_as (A.EApp (es_n, [e1])) e in + let new_e = Marked.same_mark_as (EApp (es_n, [e1])) e in evaluate_expr ctx new_e - | A.ELit A.LEmptyError -> Marked.same_mark_as (A.ELit A.LEmptyError) e + | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | _ -> - Errors.raise_spanned_error (A.pos e1) + Errors.raise_spanned_error (Expr.pos e1) "Expected a term having a sum type as an argument to a match (should \ not happend if the term was well-typed") | EDefault (exceptions, just, cons) -> ( @@ -416,11 +417,11 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr | 0 -> ( let just = evaluate_expr ctx just in match Marked.unmark just with - | ELit LEmptyError -> Marked.same_mark_as (A.ELit LEmptyError) e + | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | ELit (LBool true) -> evaluate_expr ctx cons - | ELit (LBool false) -> Marked.same_mark_as (A.ELit LEmptyError) e + | ELit (LBool false) -> Marked.same_mark_as (ELit LEmptyError) e | _ -> - Errors.raise_spanned_error (A.pos e) + Errors.raise_spanned_error (Expr.pos e) "Default justification has not been reduced to a boolean at \ evaluation (should not happen if the term was well-typed") | 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions @@ -428,7 +429,7 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr Errors.raise_multispanned_error (List.map (fun except -> - Some "This consequence has a valid justification:", A.pos except) + Some "This consequence has a valid justification:", Expr.pos except) (List.filter (fun sub -> not (is_empty_error sub)) exceptions)) "There is a conflict between multiple valid consequences for assigning \ the same variable.") @@ -436,55 +437,55 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr match Marked.unmark (evaluate_expr ctx cond) with | ELit (LBool true) -> evaluate_expr ctx et | ELit (LBool false) -> evaluate_expr ctx ef - | ELit LEmptyError -> Marked.same_mark_as (A.ELit LEmptyError) e + | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | _ -> - Errors.raise_spanned_error (A.pos cond) + Errors.raise_spanned_error (Expr.pos cond) "Expected a boolean literal for the result of this condition (should \ not happen if the term was well-typed)") | EArray es -> let new_es = List.map (evaluate_expr ctx) es in if List.exists is_empty_error new_es then - Marked.same_mark_as (A.ELit LEmptyError) e - else Marked.same_mark_as (A.EArray new_es) e + Marked.same_mark_as (ELit LEmptyError) e + else Marked.same_mark_as (EArray new_es) e | ErrorOnEmpty e' -> let e' = evaluate_expr ctx e' in - if Marked.unmark e' = A.ELit LEmptyError then - Errors.raise_spanned_error (A.pos e') + if Marked.unmark e' = ELit LEmptyError then + Errors.raise_spanned_error (Expr.pos e') "This variable evaluated to an empty term (no rule that defined it \ applied in this situation)" else e' | EAssert e' -> ( match Marked.unmark (evaluate_expr ctx e') with - | ELit (LBool true) -> Marked.same_mark_as (Ast.ELit LUnit) e' + | ELit (LBool true) -> Marked.same_mark_as (ELit LUnit) e' | ELit (LBool false) -> ( match Marked.unmark e' with - | Ast.ErrorOnEmpty + | ErrorOnEmpty ( EApp - ( (Ast.EOp (Binop op), _), + ( (EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)] ), _ ) | EApp - ( (Ast.EOp (Ast.Unop (Ast.Log _)), _), + ( (EOp (Unop (Log _)), _), [ - ( Ast.EApp - ( (Ast.EOp (Binop op), _), + ( EApp + ( (EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)] ), _ ); ] ) | EApp - ((Ast.EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)]) + ((EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)]) -> - Errors.raise_spanned_error (A.pos e') "Assertion failed: %a %a %a" + Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a" (Print.format_expr ctx ~debug:false) e1 Print.format_binop op (Print.format_expr ctx ~debug:false) e2 | _ -> Cli.debug_format "%a" (Print.format_expr ctx) e'; - Errors.raise_spanned_error (A.pos e') "Assertion failed") - | ELit LEmptyError -> Marked.same_mark_as (A.ELit LEmptyError) e + Errors.raise_spanned_error (Expr.pos e') "Assertion failed") + | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | _ -> - Errors.raise_spanned_error (A.pos e') + Errors.raise_spanned_error (Expr.pos e') "Expected a boolean literal for the result of this assertion (should \ not happen if the term was well-typed)") @@ -492,13 +493,13 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : 'm A.marked_expr let interpret_program : 'm. - Ast.decl_ctx -> + decl_ctx -> 'm Ast.marked_expr -> (Uid.MarkedString.info * 'm Ast.marked_expr) list = - fun (ctx : Ast.decl_ctx) (e : 'm Ast.marked_expr) : + fun (ctx : decl_ctx) (e : 'm Ast.marked_expr) : (Uid.MarkedString.info * 'm Ast.marked_expr) list -> match evaluate_expr ctx e with - | Ast.EAbs (_, [((Ast.TTuple (taus, Some s_in), _) as targs)]), mark_e -> + | EAbs (_, [((TTuple (taus, Some s_in), _) as targs)]), mark_e -> begin (* At this point, the interpreter seeks to execute the scope but does not have a way to retrieve input values from the command line. [taus] contain @@ -509,9 +510,9 @@ let interpret_program : List.map (fun ty -> match Marked.unmark ty with - | A.TArrow ((A.TLit A.TUnit, _), ty_in) -> + | TArrow ((TLit TUnit, _), ty_in) -> Ast.empty_thunked_term - (A.map_mark (fun pos -> pos) (fun _ -> ty_in) mark_e) + (Expr.map_mark (fun pos -> pos) (fun _ -> ty_in) mark_e) | _ -> Errors.raise_spanned_error (Marked.get_mark ty) "This scope needs input arguments to be executed. But the Catala \ @@ -522,23 +523,23 @@ let interpret_program : taus in let to_interpret = - ( Ast.EApp + ( EApp ( e, [ - ( Ast.ETuple (application_term, Some s_in), + ( ETuple (application_term, Some s_in), let pos = match application_term with - | a :: _ -> A.pos a + | a :: _ -> Expr.pos a | [] -> Pos.no_pos in - A.map_mark (fun _ -> pos) (fun _ -> targs) mark_e ); + Expr.map_mark (fun _ -> pos) (fun _ -> targs) mark_e ); ] ), - A.map_mark + Expr.map_mark (fun pos -> pos) (fun ty -> match application_term, ty with | [], t_out -> t_out - | _ :: _, (A.TArrow (_, t_out), _) -> t_out + | _ :: _, (TArrow (_, t_out), _) -> t_out | _ :: _, (_, bad_pos) -> Errors.raise_spanned_error bad_pos "@[(bug) Result of interpretation doesn't have the \ @@ -547,19 +548,19 @@ let interpret_program : mark_e ) in match Marked.unmark (evaluate_expr ctx to_interpret) with - | Ast.ETuple (args, Some s_out) -> + | ETuple (args, Some s_out) -> let s_out_fields = List.map - (fun (f, _) -> Ast.StructFieldName.get_info f) - (Ast.StructMap.find s_out ctx.ctx_structs) + (fun (f, _) -> StructFieldName.get_info f) + (StructMap.find s_out ctx.ctx_structs) in List.map2 (fun arg var -> var, arg) args s_out_fields | _ -> - Errors.raise_spanned_error (A.pos e) + Errors.raise_spanned_error (Expr.pos e) "The interpretation of a program should always yield a struct \ corresponding to the scope variables" end | _ -> - Errors.raise_spanned_error (A.pos e) + Errors.raise_spanned_error (Expr.pos e) "The interpreter can only interpret terms starting with functions having \ thunked arguments" diff --git a/compiler/dcalc/interpreter.mli b/compiler/dcalc/interpreter.mli index 4190c2f1..bd1125b0 100644 --- a/compiler/dcalc/interpreter.mli +++ b/compiler/dcalc/interpreter.mli @@ -17,12 +17,13 @@ (** Reference interpreter for the default calculus *) open Utils +open Shared_ast -val evaluate_expr : Ast.decl_ctx -> 'm Ast.marked_expr -> 'm Ast.marked_expr +val evaluate_expr : decl_ctx -> 'm Ast.marked_expr -> 'm Ast.marked_expr (** Evaluates an expression according to the semantics of the default calculus. *) val interpret_program : - Ast.decl_ctx -> + decl_ctx -> 'm Ast.marked_expr -> (Uid.MarkedString.info * 'm Ast.marked_expr) list (** Interprets a program. This function expects an expression typed as a diff --git a/compiler/dcalc/optimizations.ml b/compiler/dcalc/optimizations.ml index 5ecbe18a..4a00fbfa 100644 --- a/compiler/dcalc/optimizations.ml +++ b/compiler/dcalc/optimizations.ml @@ -15,6 +15,7 @@ License for the specific language governing permissions and limitations under the License. *) open Utils +open Shared_ast open Ast type partial_evaluation_ctx = { @@ -82,7 +83,7 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm marked_expr) : (fun arg arms -> match arg, arms with | (EInj (e1, i, e_name', _ts), _), _ - when Ast.EnumName.compare e_name e_name' = 0 -> + when EnumName.compare e_name e_name' = 0 -> (* iota reduction *) EApp (List.nth arms i, [e1]), pos | _ -> EMatch (arg, arms, e_name), pos) @@ -252,4 +253,4 @@ let optimize_program (p : 'm program) : untyped program = (program_map partial_evaluation { var_values = Var.Map.empty; decl_ctx = p.decl_ctx } p) - |> untype_program + |> Expr.untype_program diff --git a/compiler/dcalc/optimizations.mli b/compiler/dcalc/optimizations.mli index 53c7a600..e70ec2c2 100644 --- a/compiler/dcalc/optimizations.mli +++ b/compiler/dcalc/optimizations.mli @@ -17,6 +17,7 @@ (** Optimization passes for default calculus programs and expressions *) +open Shared_ast open Ast val optimize_expr : decl_ctx -> 'm marked_expr -> 'm marked_expr Bindlib.box diff --git a/compiler/dcalc/print.ml b/compiler/dcalc/print.ml index ac94560d..2f31f12e 100644 --- a/compiler/dcalc/print.ml +++ b/compiler/dcalc/print.ml @@ -15,6 +15,7 @@ the License. *) open Utils +open Shared_ast open Ast open String_common @@ -68,7 +69,7 @@ let format_enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : (Utils.Cli.format_with_style [ANSITerminal.magenta]) (Format.asprintf "%a" EnumConstructor.format_t c) -let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ) : +let rec format_typ (ctx : decl_ctx) (fmt : Format.formatter) (typ : typ) : unit = let format_typ = format_typ ctx in let format_typ_with_parens (fmt : Format.formatter) (t : typ) = @@ -84,7 +85,7 @@ let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ) : (fun fmt t -> Format.fprintf fmt "%a" format_typ t)) (List.map Marked.unmark ts) | TTuple (_args, Some s) -> - Format.fprintf fmt "@[%a%a%a%a@]" Ast.StructName.format_t s + Format.fprintf fmt "@[%a%a%a%a@]" StructName.format_t s format_punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> @@ -98,7 +99,7 @@ let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ) : (StructMap.find s ctx.ctx_structs)) format_punctuation "}" | TEnum (_, e) -> - Format.fprintf fmt "@[%a%a%a%a@]" Ast.EnumName.format_t e + Format.fprintf fmt "@[%a%a%a%a@]" EnumName.format_t e format_punctuation "[" (Format.pp_print_list ~pp_sep:(fun fmt () -> @@ -211,7 +212,7 @@ let format_var (fmt : Format.formatter) (v : 'm Ast.var) : unit = let rec format_expr ?(debug : bool = false) - (ctx : Ast.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm marked_expr) : unit = let format_expr = format_expr ~debug ctx in @@ -231,15 +232,15 @@ let rec format_expr es format_punctuation ")" | ETuple (es, Some s) -> Format.fprintf fmt "@[%a@ @[%a%a%a@]@]" - Ast.StructName.format_t s format_punctuation "{" + StructName.format_t s format_punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " format_punctuation ";") (fun fmt (e, struct_field) -> Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\"" - Ast.StructFieldName.format_t struct_field format_punctuation "\"" + StructFieldName.format_t struct_field format_punctuation "\"" format_punctuation "=" format_expr e)) - (List.combine es (List.map fst (Ast.StructMap.find s ctx.ctx_structs))) + (List.combine es (List.map fst (StructMap.find s ctx.ctx_structs))) format_punctuation "}" | EArray es -> Format.fprintf fmt "@[%a%a%a@]" format_punctuation "[" @@ -253,12 +254,12 @@ let rec format_expr Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n | Some s -> Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_operator "." - format_punctuation "\"" Ast.StructFieldName.format_t - (fst (List.nth (Ast.StructMap.find s ctx.ctx_structs) n)) + format_punctuation "\"" StructFieldName.format_t + (fst (List.nth (StructMap.find s ctx.ctx_structs) n)) format_punctuation "\"") | EInj (e, n, en, _ts) -> Format.fprintf fmt "@[%a@ %a@]" format_enum_constructor - (fst (List.nth (Ast.EnumMap.find en ctx.ctx_enums) n)) + (fst (List.nth (EnumMap.find en ctx.ctx_enums) n)) format_expr e | EMatch (e, es, e_name) -> Format.fprintf fmt "@[%a@ @[%a@]@ %a@ %a@]" format_keyword @@ -268,7 +269,7 @@ let rec format_expr (fun fmt (e, c) -> Format.fprintf fmt "@[%a %a%a@ %a@]" format_punctuation "|" format_enum_constructor c format_punctuation ":" format_expr e)) - (List.combine es (List.map fst (Ast.EnumMap.find e_name ctx.ctx_enums))) + (List.combine es (List.map fst (EnumMap.find e_name ctx.ctx_enums))) | ELit l -> format_lit fmt l | EApp ((EAbs (binder, taus), _), args) -> let xs, body = Bindlib.unmbind binder in @@ -298,7 +299,7 @@ let rec format_expr Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var x format_punctuation ":" (format_typ ctx) tau format_punctuation ")")) xs_tau format_punctuation "→" format_expr body - | EApp ((EOp (Binop ((Ast.Map | Ast.Filter) as op)), _), [arg1; arg2]) -> + | EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) -> Format.fprintf fmt "@[%a@ %a@ %a@]" format_binop op format_with_parens arg1 format_with_parens arg2 | EApp ((EOp (Binop op), _), [arg1; arg2]) -> @@ -347,13 +348,13 @@ let format_scope ?(debug : bool = false) (ctx : decl_ctx) (fmt : Format.formatter) - ((n, s) : Ast.ScopeName.t * ('m Ast.expr, 'm) scope_body) = + ((n, s) : ScopeName.t * ('m Ast.expr, 'm) scope_body) = Format.fprintf fmt "@[%a %a =@ %a@]" format_keyword "let" - Ast.ScopeName.format_t n (format_expr ctx ~debug) + ScopeName.format_t n (format_expr ctx ~debug) (Bindlib.unbox (Ast.build_whole_scope_expr ~make_abs:Ast.make_abs - ~make_let_in:Ast.make_let_in ~box_expr:Ast.box_expr ctx s - (Ast.map_mark - (fun _ -> Marked.get_mark (Ast.ScopeName.get_info n)) + ~make_let_in:Ast.make_let_in ~box_expr:Expr.box ctx s + (Expr.map_mark + (fun _ -> Marked.get_mark (ScopeName.get_info n)) (fun ty -> ty) - (Ast.get_scope_body_mark s)))) + (Expr.get_scope_body_mark s)))) diff --git a/compiler/dcalc/print.mli b/compiler/dcalc/print.mli index 8a3ae7a0..a9148e33 100644 --- a/compiler/dcalc/print.mli +++ b/compiler/dcalc/print.mli @@ -17,6 +17,7 @@ (** Printing functions for the default calculus AST *) open Utils +open Shared_ast (** {1 Common syntax highlighting helpers}*) @@ -29,27 +30,27 @@ val format_lit_style : Format.formatter -> string -> unit (** {1 Formatters} *) val format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit -val format_enum_constructor : Format.formatter -> Ast.EnumConstructor.t -> unit -val format_tlit : Format.formatter -> Ast.typ_lit -> unit -val format_typ : Ast.decl_ctx -> Format.formatter -> Ast.typ -> unit +val format_enum_constructor : Format.formatter -> EnumConstructor.t -> unit +val format_tlit : Format.formatter -> typ_lit -> unit +val format_typ : decl_ctx -> Format.formatter -> typ -> unit val format_lit : Format.formatter -> Ast.lit -> unit -val format_op_kind : Format.formatter -> Ast.op_kind -> unit -val format_binop : Format.formatter -> Ast.binop -> unit -val format_ternop : Format.formatter -> Ast.ternop -> unit -val format_log_entry : Format.formatter -> Ast.log_entry -> unit -val format_unop : Format.formatter -> Ast.unop -> unit +val format_op_kind : Format.formatter -> op_kind -> unit +val format_binop : Format.formatter -> binop -> unit +val format_ternop : Format.formatter -> ternop -> unit +val format_log_entry : Format.formatter -> log_entry -> unit +val format_unop : Format.formatter -> unop -> unit val format_var : Format.formatter -> 'm Ast.var -> unit val format_expr : ?debug:bool (** [true] for debug printing *) -> - Ast.decl_ctx -> + decl_ctx -> Format.formatter -> 'm Ast.marked_expr -> unit val format_scope : ?debug:bool (** [true] for debug printing *) -> - Ast.decl_ctx -> + decl_ctx -> Format.formatter -> - Ast.ScopeName.t * ('m Ast.expr, 'm) Ast.scope_body -> + ScopeName.t * ('m Ast.expr, 'm) scope_body -> unit diff --git a/compiler/dcalc/typing.ml b/compiler/dcalc/typing.ml index 0184c54c..9945eece 100644 --- a/compiler/dcalc/typing.ml +++ b/compiler/dcalc/typing.ml @@ -71,7 +71,7 @@ let typ_needs_parens (t : typ Marked.pos UnionFind.elem) : bool = match Marked.unmark t with TArrow _ | TArray _ -> true | _ -> false let rec format_typ - (ctx : Ast.decl_ctx) + (ctx : A.decl_ctx) (fmt : Format.formatter) (typ : typ Marked.pos UnionFind.elem) : unit = let format_typ = format_typ ctx in @@ -90,8 +90,8 @@ let rec format_typ ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ") (fun fmt t -> Format.fprintf fmt "%a" format_typ t)) ts - | TTuple (_ts, Some s) -> Format.fprintf fmt "%a" Ast.StructName.format_t s - | TEnum (_ts, e) -> Format.fprintf fmt "%a" Ast.EnumName.format_t e + | TTuple (_ts, Some s) -> Format.fprintf fmt "%a" A.StructName.format_t s + | TEnum (_ts, e) -> Format.fprintf fmt "%a" A.EnumName.format_t e | TArrow (t1, t2) -> Format.fprintf fmt "@[%a →@ %a@]" format_typ_with_parens t1 format_typ t2 @@ -108,8 +108,8 @@ type mark = { pos : Pos.t; uf : unionfind_typ } (** Raises an error if unification cannot be performed *) let rec unify - (ctx : Ast.decl_ctx) - (e : ('a, 'm A.mark) Ast.marked_gexpr) (* used for error context *) + (ctx : A.decl_ctx) + (e : ('a, 'm A.mark) A.marked_gexpr) (* used for error context *) (t1 : typ Marked.pos UnionFind.elem) (t2 : typ Marked.pos UnionFind.elem) : unit = let unify = unify ctx in @@ -263,7 +263,7 @@ let op_type (op : A.operator Marked.pos) : typ Marked.pos UnionFind.elem = type 'e env = ('e, typ Marked.pos UnionFind.elem) A.Var.Map.t -let add_pos e ty = Marked.mark (Ast.pos e) ty +let add_pos e ty = Marked.mark (A.Expr.pos e) ty let ty (_, { uf; _ }) = uf let ( let+ ) x f = Bindlib.box_apply f x let ( and+ ) x1 x2 = Bindlib.box_pair x1 x2 @@ -290,12 +290,12 @@ let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e) (** Infers the most permissive type from an expression *) let rec typecheck_expr_bottom_up - (ctx : Ast.decl_ctx) + (ctx : A.decl_ctx) (env : 'm Ast.expr env) (e : 'm Ast.marked_expr) : (A.dcalc, mark) A.marked_gexpr Bindlib.box = (* Cli.debug_format "Looking for type of %a" (Print.format_expr ~debug:true ctx) e; *) - let pos_e = Ast.pos e in + let pos_e = A.Expr.pos e in let mark (e : (A.dcalc, mark) A.gexpr) uf = Marked.mark { uf; pos = pos_e } e in @@ -308,7 +308,7 @@ let rec typecheck_expr_bottom_up let+ v' = Bindlib.box_var (A.Var.translate v) in mark v' t | None -> - Errors.raise_spanned_error (Ast.pos e) + Errors.raise_spanned_error (A.Expr.pos e) "Variable %s not found in the current context." (Bindlib.name_of v) end | A.ELit (LBool _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TBool) @@ -343,7 +343,7 @@ let rec typecheck_expr_bottom_up match List.nth_opt ts' n with | Some ts_n -> ts_n | None -> - Errors.raise_spanned_error (Ast.pos e) + Errors.raise_spanned_error (A.Expr.pos e) "Expression should have a sum type with at least %d cases but only \ has %d" n (List.length ts') @@ -368,7 +368,7 @@ let rec typecheck_expr_bottom_up mark (EMatch (e1', es', e_name)) t_ret | A.EAbs (binder, taus) -> if Bindlib.mbinder_arity binder <> List.length taus then - Errors.raise_spanned_error (Ast.pos e) + Errors.raise_spanned_error (A.Expr.pos e) "function has %d variables but was supplied %d types" (Bindlib.mbinder_arity binder) (List.length taus) @@ -446,13 +446,13 @@ let rec typecheck_expr_bottom_up (** Checks whether the expression can be typed with the provided type *) and typecheck_expr_top_down - (ctx : Ast.decl_ctx) + (ctx : A.decl_ctx) (env : 'm Ast.expr env) (tau : typ Marked.pos UnionFind.elem) (e : 'm Ast.marked_expr) : (A.dcalc, mark) A.marked_gexpr Bindlib.box = (* Cli.debug_format "Propagating type %a for expr %a" (format_typ ctx) tau (Print.format_expr ctx) e; *) - let pos_e = Ast.pos e in + let pos_e = A.Expr.pos e in let mark e = Marked.mark { uf = tau; pos = pos_e } e in let unify_and_mark (e' : (A.dcalc, mark) A.gexpr) tau' = (* This try...with was added because of @@ -502,7 +502,7 @@ and typecheck_expr_top_down match List.nth_opt typs' n with | Some t1n -> unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n | None -> - Errors.raise_spanned_error (Ast.pos e1) + Errors.raise_spanned_error (A.Expr.pos e1) "Expression should have a tuple type with at least %d elements but \ only has %d" n (List.length typs) @@ -513,7 +513,7 @@ and typecheck_expr_top_down match List.nth_opt ts' n with | Some ts_n -> ts_n | None -> - Errors.raise_spanned_error (Ast.pos e) + Errors.raise_spanned_error (A.Expr.pos e) "Expression should have a sum type with at least %d cases but only \ has %d" n (List.length ts) @@ -544,7 +544,7 @@ and typecheck_expr_top_down unify_and_mark (EMatch (e1', es', e_name)) t_ret | A.EAbs (binder, t_args) -> if Bindlib.mbinder_arity binder <> List.length t_args then - Errors.raise_spanned_error (Ast.pos e) + Errors.raise_spanned_error (A.Expr.pos e) "function has %d variables but was supplied %d types" (Bindlib.mbinder_arity binder) (List.length t_args) @@ -628,8 +628,8 @@ let wrap ctx f e = let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos } (* Infer the type of an expression *) -let infer_types (ctx : Ast.decl_ctx) (e : 'm Ast.marked_expr) : - Ast.typed Ast.marked_expr Bindlib.box = +let infer_types (ctx : A.decl_ctx) (e : 'm Ast.marked_expr) : + A.typed Ast.marked_expr Bindlib.box = A.Expr.map_marks ~f:get_ty_mark @@ Bindlib.unbox @@ wrap ctx (typecheck_expr_bottom_up ctx A.Var.Map.empty) e @@ -637,11 +637,11 @@ let infer_types (ctx : Ast.decl_ctx) (e : 'm Ast.marked_expr) : let infer_type (type m) ctx (e : m Ast.marked_expr) = match Marked.get_mark e with | A.Typed { ty; _ } -> ty - | A.Untyped _ -> Ast.ty (Bindlib.unbox (infer_types ctx e)) + | A.Untyped _ -> A.Expr.ty (Bindlib.unbox (infer_types ctx e)) (** Typechecks an expression given an expected type *) let check_type - (ctx : Ast.decl_ctx) + (ctx : A.decl_ctx) (e : 'm Ast.marked_expr) (tau : A.typ Marked.pos) = (* todo: consider using the already inferred type if ['m] = [typed] *) diff --git a/compiler/dcalc/typing.mli b/compiler/dcalc/typing.mli index 5c66ca4f..91ee3df1 100644 --- a/compiler/dcalc/typing.mli +++ b/compiler/dcalc/typing.mli @@ -17,18 +17,20 @@ (** Typing for the default calculus. Because of the error terms, we perform type inference using the classical W algorithm with union-find unification. *) +open Shared_ast + val infer_types : - Ast.decl_ctx -> - Ast.untyped Ast.marked_expr -> - Ast.typed Ast.marked_expr Bindlib.box + decl_ctx -> + untyped Ast.marked_expr -> + typed Ast.marked_expr Bindlib.box (** Infers types everywhere on the given expression, and adds (or replaces) type annotations on each node *) -val infer_type : Ast.decl_ctx -> 'm Ast.marked_expr -> Ast.typ Utils.Marked.pos +val infer_type : decl_ctx -> 'm Ast.marked_expr -> typ Utils.Marked.pos (** Gets the outer type of the given expression, using either the existing annotations or inference *) val check_type : - Ast.decl_ctx -> 'm Ast.marked_expr -> Ast.typ Utils.Marked.pos -> unit + decl_ctx -> 'm Ast.marked_expr -> typ Utils.Marked.pos -> unit -val infer_types_program : Ast.untyped Ast.program -> Ast.typed Ast.program +val infer_types_program : untyped Ast.program -> typed Ast.program diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index a65c0828..7982d783 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -17,6 +17,7 @@ (** Abstract syntax tree of the desugared representation *) open Utils +open Shared_ast (** {1 Names, Maps and Keys} *) @@ -99,7 +100,7 @@ module ScopeDefSet : Set.S with type elt = ScopeDef.t = Set.Make (ScopeDef) type location = | ScopeVar of ScopeVar.t Marked.pos * StateName.t option | SubScopeVar of - Scopelang.Ast.ScopeName.t + ScopeName.t * Scopelang.Ast.SubScopeName.t Marked.pos * ScopeVar.t Marked.pos @@ -132,20 +133,20 @@ and expr = | ELocation of location | EVar of expr Bindlib.var | EStruct of - Scopelang.Ast.StructName.t * marked_expr Scopelang.Ast.StructFieldMap.t + StructName.t * marked_expr Scopelang.Ast.StructFieldMap.t | EStructAccess of - marked_expr * Scopelang.Ast.StructFieldName.t * Scopelang.Ast.StructName.t + marked_expr * StructFieldName.t * StructName.t | EEnumInj of - marked_expr * Scopelang.Ast.EnumConstructor.t * Scopelang.Ast.EnumName.t + marked_expr * EnumConstructor.t * EnumName.t | EMatch of marked_expr - * Scopelang.Ast.EnumName.t + * EnumName.t * marked_expr Scopelang.Ast.EnumConstructorMap.t | ELit of Dcalc.Ast.lit | EAbs of (expr, marked_expr) Bindlib.mbinder * Scopelang.Ast.typ Marked.pos list | EApp of marked_expr * marked_expr list - | EOp of Dcalc.Ast.operator + | EOp of operator | EDefault of marked_expr list * marked_expr * marked_expr | EIfThenElse of marked_expr * marked_expr * marked_expr | EArray of marked_expr list @@ -170,7 +171,7 @@ module Expr = struct | ELocation _, ELocation _ -> 0 | EVar v1, EVar v2 -> Bindlib.compare_vars v1 v2 | EStruct (name1, field_map1), EStruct (name2, field_map2) -> ( - match Scopelang.Ast.StructName.compare name1 name2 with + match StructName.compare name1 name2 with | 0 -> Scopelang.Ast.StructFieldMap.compare (Marked.compare compare) field_map1 field_map2 @@ -179,21 +180,21 @@ module Expr = struct EStructAccess ((e2, _), field_name2, struct_name2) ) -> ( match compare e1 e2 with | 0 -> ( - match Scopelang.Ast.StructFieldName.compare field_name1 field_name2 with - | 0 -> Scopelang.Ast.StructName.compare struct_name1 struct_name2 + match StructFieldName.compare field_name1 field_name2 with + | 0 -> StructName.compare struct_name1 struct_name2 | n -> n) | n -> n) | EEnumInj ((e1, _), cstr1, name1), EEnumInj ((e2, _), cstr2, name2) -> ( match compare e1 e2 with | 0 -> ( - match Scopelang.Ast.EnumName.compare name1 name2 with - | 0 -> Scopelang.Ast.EnumConstructor.compare cstr1 cstr2 + match EnumName.compare name1 name2 with + | 0 -> EnumConstructor.compare cstr1 cstr2 | n -> n) | n -> n) | EMatch ((e1, _), name1, emap1), EMatch ((e2, _), name2, emap2) -> ( match compare e1 e2 with | 0 -> ( - match Scopelang.Ast.EnumName.compare name1 name2 with + match EnumName.compare name1 name2 with | 0 -> Scopelang.Ast.EnumConstructorMap.compare (Marked.compare compare) emap1 emap2 @@ -325,8 +326,8 @@ let empty_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Marked.pos option) : rule = { - rule_just = Bindlib.box (ELit (Dcalc.Ast.LBool false), pos); - rule_cons = Bindlib.box (ELit Dcalc.Ast.LEmptyError, pos); + rule_just = Bindlib.box (ELit (LBool false), pos); + rule_cons = Bindlib.box (ELit LEmptyError, pos); rule_parameter = (match have_parameter with | Some typ -> Some (Var.make "dummy", typ) @@ -340,8 +341,8 @@ let always_false_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Marked.pos option) : rule = { - rule_just = Bindlib.box (ELit (Dcalc.Ast.LBool true), pos); - rule_cons = Bindlib.box (ELit (Dcalc.Ast.LBool false), pos); + rule_just = Bindlib.box (ELit (LBool true), pos); + rule_cons = Bindlib.box (ELit (LBool false), pos); rule_parameter = (match have_parameter with | Some typ -> Some (Var.make "dummy", typ) @@ -370,8 +371,8 @@ type var_or_states = WholeVar | States of StateName.t list type scope = { scope_vars : var_or_states ScopeVarMap.t; - scope_sub_scopes : Scopelang.Ast.ScopeName.t Scopelang.Ast.SubScopeMap.t; - scope_uid : Scopelang.Ast.ScopeName.t; + scope_sub_scopes : ScopeName.t Scopelang.Ast.SubScopeMap.t; + scope_uid : ScopeName.t; scope_defs : scope_def ScopeDefMap.t; scope_assertions : assertion list; scope_meta_assertions : meta_assertion list; diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index 80702135..365aff2d 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -17,6 +17,7 @@ (** Abstract syntax tree of the desugared representation *) open Utils +open Shared_ast (** {1 Names, Maps and Keys} *) @@ -54,7 +55,7 @@ module ScopeDefSet : Set.S with type elt = ScopeDef.t type location = | ScopeVar of ScopeVar.t Marked.pos * StateName.t option | SubScopeVar of - Scopelang.Ast.ScopeName.t + ScopeName.t * Scopelang.Ast.SubScopeName.t Marked.pos * ScopeVar.t Marked.pos @@ -68,20 +69,20 @@ and expr = | ELocation of location | EVar of expr Bindlib.var | EStruct of - Scopelang.Ast.StructName.t * marked_expr Scopelang.Ast.StructFieldMap.t + StructName.t * marked_expr Scopelang.Ast.StructFieldMap.t | EStructAccess of - marked_expr * Scopelang.Ast.StructFieldName.t * Scopelang.Ast.StructName.t + marked_expr * StructFieldName.t * StructName.t | EEnumInj of - marked_expr * Scopelang.Ast.EnumConstructor.t * Scopelang.Ast.EnumName.t + marked_expr * EnumConstructor.t * EnumName.t | EMatch of marked_expr - * Scopelang.Ast.EnumName.t + * EnumName.t * marked_expr Scopelang.Ast.EnumConstructorMap.t | ELit of Dcalc.Ast.lit | EAbs of (expr, marked_expr) Bindlib.mbinder * Scopelang.Ast.typ Marked.pos list | EApp of marked_expr * marked_expr list - | EOp of Dcalc.Ast.operator + | EOp of operator | EDefault of marked_expr list * marked_expr * marked_expr | EIfThenElse of marked_expr * marked_expr * marked_expr | EArray of marked_expr list @@ -166,8 +167,8 @@ type var_or_states = WholeVar | States of StateName.t list type scope = { scope_vars : var_or_states ScopeVarMap.t; - scope_sub_scopes : Scopelang.Ast.ScopeName.t Scopelang.Ast.SubScopeMap.t; - scope_uid : Scopelang.Ast.ScopeName.t; + scope_sub_scopes : ScopeName.t Scopelang.Ast.SubScopeMap.t; + scope_uid : ScopeName.t; scope_defs : scope_def ScopeDefMap.t; scope_assertions : assertion list; scope_meta_assertions : meta_assertion list; diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 9aad2351..7a5f7268 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -18,6 +18,7 @@ OCamlgraph} *) open Utils +open Shared_ast (** {1 Scope variables dependency graph} *) @@ -140,7 +141,7 @@ let check_for_cycle (scope : Ast.scope) (g : ScopeDependencies.t) : unit = in Errors.raise_multispanned_error spans "Cyclic dependency detected between variables of scope %a!" - Scopelang.Ast.ScopeName.format_t scope.scope_uid + ScopeName.format_t scope.scope_uid (** Builds the dependency graph of a particular scope *) let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = diff --git a/compiler/desugared/desugared_to_scope.ml b/compiler/desugared/desugared_to_scope.ml index 531d5d91..df9edbea 100644 --- a/compiler/desugared/desugared_to_scope.ml +++ b/compiler/desugared/desugared_to_scope.ml @@ -17,6 +17,7 @@ (** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *) open Utils +open Shared_ast (** {1 Expression translation}*) @@ -31,11 +32,11 @@ type ctx = { let tag_with_log_entry (e : Scopelang.Ast.expr Marked.pos) - (l : Dcalc.Ast.log_entry) + (l : log_entry) (markings : Utils.Uid.MarkedString.info list) : Scopelang.Ast.expr Marked.pos = ( Scopelang.Ast.EApp - ( ( Scopelang.Ast.EOp (Dcalc.Ast.Unop (Dcalc.Ast.Log (l, markings))), + ( ( Scopelang.Ast.EOp (Unop (Log (l, markings))), Marked.get_mark e ), [e] ), Marked.get_mark e ) @@ -263,11 +264,11 @@ let rec rule_tree_to_expr Scopelang.Ast.make_default ~pos:def_pos [] (* Here we insert the logging command that records when a decision is taken for the value of a variable. *) - (tag_with_log_entry base_just Dcalc.Ast.PosRecordIfTrueBool []) + (tag_with_log_entry base_just PosRecordIfTrueBool []) base_cons) base_just_list base_cons_list) - (Scopelang.Ast.ELit (Dcalc.Ast.LBool false), def_pos) - (Scopelang.Ast.ELit Dcalc.Ast.LEmptyError, def_pos)) + (Scopelang.Ast.ELit (LBool false), def_pos) + (Scopelang.Ast.ELit LEmptyError, def_pos)) (Bindlib.box_list (translate_and_unbox_list base_just_list)) (Bindlib.box_list (translate_and_unbox_list base_cons_list)) in @@ -281,7 +282,7 @@ let rec rule_tree_to_expr Bindlib.box_apply2 (fun exceptions default_containing_base_cases -> Scopelang.Ast.make_default exceptions - (Scopelang.Ast.ELit (Dcalc.Ast.LBool true), def_pos) + (Scopelang.Ast.ELit (LBool true), def_pos) default_containing_base_cases) exceptions default_containing_base_cases in diff --git a/compiler/driver.ml b/compiler/driver.ml index bf4c1e10..d357a48f 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -200,10 +200,10 @@ let driver source_file (options : Cli.options) : int = (Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) ( scope_uid, Option.get - (Dcalc.Ast.fold_left_scope_defs ~init:None + (Shared_ast.Expr.fold_left_scope_defs ~init:None ~f:(fun acc scope_def _ -> if - Dcalc.Ast.ScopeName.compare scope_def.scope_name + Shared_ast.ScopeName.compare scope_def.scope_name scope_uid = 0 then Some scope_def.scope_body @@ -212,7 +212,7 @@ let driver source_file (options : Cli.options) : int = else let prgrm_dcalc_expr = Bindlib.unbox - (Dcalc.Ast.build_whole_program_expr ~box_expr:Dcalc.Ast.box_expr + (Dcalc.Ast.build_whole_program_expr ~box_expr:Shared_ast.Expr.box ~make_abs:Dcalc.Ast.make_abs ~make_let_in:Dcalc.Ast.make_let_in prgm scope_uid) in @@ -242,7 +242,7 @@ let driver source_file (options : Cli.options) : int = Cli.debug_print "Starting interpretation..."; let prgrm_dcalc_expr = Bindlib.unbox - (Dcalc.Ast.build_whole_program_expr ~box_expr:Dcalc.Ast.box_expr + (Dcalc.Ast.build_whole_program_expr ~box_expr:Shared_ast.Expr.box ~make_abs:Dcalc.Ast.make_abs ~make_let_in:Dcalc.Ast.make_let_in prgm scope_uid) in @@ -285,7 +285,7 @@ let driver source_file (options : Cli.options) : int = Cli.debug_print "Optimizing lambda calculus..."; Lcalc.Optimizations.optimize_program prgm end - else Lcalc.Ast.untype_program prgm + else Shared_ast.Expr.untype_program prgm in let prgm = if options.closure_conversion then ( @@ -305,10 +305,10 @@ let driver source_file (options : Cli.options) : int = (Lcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) ( scope_uid, Option.get - (Dcalc.Ast.fold_left_scope_defs ~init:None + (Shared_ast.Expr.fold_left_scope_defs ~init:None ~f:(fun acc scope_def _ -> if - Dcalc.Ast.ScopeName.compare scope_def.scope_name + Shared_ast.ScopeName.compare scope_def.scope_name scope_uid = 0 then Some scope_def.scope_body @@ -318,7 +318,7 @@ let driver source_file (options : Cli.options) : int = let prgrm_lcalc_expr = Bindlib.unbox (Dcalc.Ast.build_whole_program_expr - ~box_expr:Lcalc.Ast.box_expr ~make_abs:Lcalc.Ast.make_abs + ~box_expr:Shared_ast.Expr.box ~make_abs:Lcalc.Ast.make_abs ~make_let_in:Lcalc.Ast.make_let_in prgm scope_uid) in Format.fprintf fmt "%a\n" diff --git a/compiler/lcalc/ast.ml b/compiler/lcalc/ast.ml index 04a09255..df1b80a8 100644 --- a/compiler/lcalc/ast.ml +++ b/compiler/lcalc/ast.ml @@ -23,80 +23,10 @@ type lit = lcalc glit type 'm expr = (lcalc, 'm mark) gexpr and 'm marked_expr = (lcalc, 'm mark) marked_gexpr -type 'm program = ('m expr, 'm) Dcalc.Ast.program_generic +type 'm program = ('m expr, 'm) program_generic type 'm var = 'm expr Var.t type 'm vars = 'm expr Var.vars -(* *) - -let evar v mark = Bindlib.box_apply (Marked.mark mark) (Bindlib.box_var v) - -let etuple args s mark = - Bindlib.box_apply (fun args -> ETuple (args, s), mark) (Bindlib.box_list args) - -let etupleaccess e1 i s typs mark = - Bindlib.box_apply (fun e1 -> ETupleAccess (e1, i, s, typs), mark) e1 - -let einj e1 i e_name typs mark = - Bindlib.box_apply (fun e1 -> EInj (e1, i, e_name, typs), mark) e1 - -let ematch arg arms e_name mark = - Bindlib.box_apply2 - (fun arg arms -> EMatch (arg, arms, e_name), mark) - arg (Bindlib.box_list arms) - -let earray args mark = - Bindlib.box_apply (fun args -> EArray args, mark) (Bindlib.box_list args) - -let elit l mark = Bindlib.box (ELit l, mark) - -let eabs binder typs mark = - Bindlib.box_apply (fun binder -> EAbs (binder, typs), mark) binder - -let eapp e1 args mark = - Bindlib.box_apply2 - (fun e1 args -> EApp (e1, args), mark) - e1 (Bindlib.box_list args) - -let eassert e1 mark = Bindlib.box_apply (fun e1 -> EAssert e1, mark) e1 -let eop op mark = Bindlib.box (EOp op, mark) - -let eifthenelse e1 e2 e3 pos = - Bindlib.box_apply3 (fun e1 e2 e3 -> EIfThenElse (e1, e2, e3), pos) e1 e2 e3 - -(* *) - -let eraise e1 pos = Bindlib.box (ERaise e1, pos) - -let ecatch e1 exn e2 pos = - Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2 - -let map_expr ctx ~f e = Expr.map ctx ~f e - -let rec map_expr_top_down ~f e = - map_expr () ~f:(fun () -> map_expr_top_down ~f) (f e) - -let map_expr_marks ~f e = - map_expr_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e - -let untype_expr e = - map_expr_marks ~f:(fun m -> Untyped { pos = D.mark_pos m }) e - -let untype_program prg = - { - prg with - D.scopes = - Bindlib.unbox - (D.map_exprs_in_scopes - ~f:(fun e -> untype_expr e) - ~varf:Var.translate prg.D.scopes); - } - -(** See [Bindlib.box_term] documentation for why we are doing that. *) -let box_expr (e : 'm marked_expr) : 'm marked_expr Bindlib.box = - let rec id_t () e = map_expr () ~f:id_t e in - id_t () e - let make_var (x, mark) = Bindlib.box_apply (fun x -> x, mark) (Bindlib.box_var x) @@ -110,7 +40,7 @@ let make_let_in x tau e1 e2 pos = let m_e1 = Marked.get_mark (Bindlib.unbox e1) in let m_e2 = Marked.get_mark (Bindlib.unbox e2) in let m_abs = - D.map_mark2 + Expr.map_mark2 (fun _ _ -> pos) (fun m1 m2 -> TArrow (m1.ty, m2.ty), m1.pos) m_e1 m_e2 @@ -120,47 +50,47 @@ let make_let_in x tau e1 e2 pos = let make_multiple_let_in xs taus e1s e2 pos = (* let m_e1s = List.map (fun e -> Marked.get_mark (Bindlib.unbox e)) e1s in *) let m_e1s = - D.fold_marks List.hd + Expr.fold_marks List.hd (fun tys -> - D.TTuple (List.map (fun t -> t.D.ty) tys, None), (List.hd tys).D.pos) + TTuple (List.map (fun t -> t.ty) tys, None), (List.hd tys).pos) (List.map (fun e -> Marked.get_mark (Bindlib.unbox e)) e1s) in let m_e2 = Marked.get_mark (Bindlib.unbox e2) in let m_abs = - D.map_mark2 + Expr.map_mark2 (fun _ _ -> pos) - (fun m1 m2 -> Marked.mark pos (D.TArrow (m1.ty, m2.ty))) + (fun m1 m2 -> Marked.mark pos (TArrow (m1.ty, m2.ty))) m_e1s m_e2 in make_app (make_abs xs e2 taus m_abs) e1s m_e2 let ( let+ ) x f = Bindlib.box_apply f x let ( and+ ) x y = Bindlib.box_pair x y -let option_enum : D.EnumName.t = D.EnumName.fresh ("eoption", Pos.no_pos) +let option_enum : EnumName.t = EnumName.fresh ("eoption", Pos.no_pos) -let none_constr : D.EnumConstructor.t = - D.EnumConstructor.fresh ("ENone", Pos.no_pos) +let none_constr : EnumConstructor.t = + EnumConstructor.fresh ("ENone", Pos.no_pos) -let some_constr : D.EnumConstructor.t = - D.EnumConstructor.fresh ("ESome", Pos.no_pos) +let some_constr : EnumConstructor.t = + EnumConstructor.fresh ("ESome", Pos.no_pos) -let option_enum_config : (D.EnumConstructor.t * D.typ Marked.pos) list = - [none_constr, (D.TLit D.TUnit, Pos.no_pos); some_constr, (D.TAny, Pos.no_pos)] +let option_enum_config : (EnumConstructor.t * typ Marked.pos) list = + [none_constr, (TLit TUnit, Pos.no_pos); some_constr, (TAny, Pos.no_pos)] (* FIXME: proper typing in all the constructors below *) let make_none m = let mark = Marked.mark m in - let tunit = D.TLit D.TUnit, D.mark_pos m in + let tunit = TLit TUnit, Expr.mark_pos m in Bindlib.box @@ mark @@ EInj ( Marked.mark - (D.map_mark (fun pos -> pos) (fun _ -> tunit) m) + (Expr.map_mark (fun pos -> pos) (fun _ -> tunit) m) (ELit LUnit), 0, option_enum, - [D.TLit D.TUnit, Pos.no_pos; D.TAny, Pos.no_pos] ) + [TLit TUnit, Pos.no_pos; TAny, Pos.no_pos] ) let make_some e = let m = Marked.get_mark @@ Bindlib.unbox e in @@ -168,7 +98,7 @@ let make_some e = let+ e in mark @@ EInj - (e, 1, option_enum, [D.TLit D.TUnit, D.mark_pos m; D.TAny, D.mark_pos m]) + (e, 1, option_enum, [TLit TUnit, Expr.mark_pos m; TAny, Expr.mark_pos m]) (** [make_matchopt_with_abs_arms arg e_none e_some] build an expression [match arg with |None -> e_none | Some -> e_some] and requires e_some and @@ -187,7 +117,7 @@ let make_matchopt m v tau arg e_none e_some = let x = Var.make "_" in make_matchopt_with_abs_arms arg - (make_abs (Array.of_list [x]) e_none [D.TLit D.TUnit, D.mark_pos m] m) + (make_abs (Array.of_list [x]) e_none [TLit TUnit, Expr.mark_pos m] m) (make_abs (Array.of_list [v]) e_some [tau] m) let handle_default = Var.make "handle_default" diff --git a/compiler/lcalc/ast.mli b/compiler/lcalc/ast.mli index 51739da6..40ef7373 100644 --- a/compiler/lcalc/ast.mli +++ b/compiler/lcalc/ast.mli @@ -15,7 +15,7 @@ the License. *) open Utils -include module type of Shared_ast +open Shared_ast (** Abstract syntax tree for the lambda calculus *) @@ -26,114 +26,21 @@ type lit = lcalc glit type 'm expr = (lcalc, 'm mark) gexpr and 'm marked_expr = (lcalc, 'm mark) marked_gexpr -type 'm program = ('m expr, 'm) Dcalc.Ast.program_generic +type 'm program = ('m expr, 'm) program_generic (** {1 Variable helpers} *) type 'm var = 'm expr Var.t type 'm vars = 'm expr Var.vars -(** {2 Program traversal} *) - -val map_expr : - 'a -> - f:('a -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box) -> - ('m1 expr, 'm2 mark) Marked.t -> - 'm2 marked_expr Bindlib.box -(** See [Dcalc.Ast.map_expr] *) - -val map_expr_top_down : - f:('m1 marked_expr -> ('m1 expr, 'm2 mark) Marked.t) -> - 'm1 marked_expr -> - 'm2 marked_expr Bindlib.box -(** See [Dcalc.Ast.map_expr_top_down] *) - -val map_expr_marks : - f:('m1 mark -> 'm2 mark) -> 'm1 marked_expr -> 'm2 marked_expr Bindlib.box -(** See [Dcalc.Ast.map_expr_marks] *) - -val untype_expr : 'm marked_expr -> Dcalc.Ast.untyped marked_expr Bindlib.box -val untype_program : 'm program -> Dcalc.Ast.untyped program - -(** {1 Boxed constructors} *) - -val evar : 'm expr Bindlib.var -> 'm mark -> 'm marked_expr Bindlib.box - -val etuple : - 'm marked_expr Bindlib.box list -> - Dcalc.Ast.StructName.t option -> - 'm mark -> - 'm marked_expr Bindlib.box - -val etupleaccess : - 'm marked_expr Bindlib.box -> - int -> - Dcalc.Ast.StructName.t option -> - Dcalc.Ast.typ Marked.pos list -> - 'm mark -> - 'm marked_expr Bindlib.box - -val einj : - 'm marked_expr Bindlib.box -> - int -> - Dcalc.Ast.EnumName.t -> - Dcalc.Ast.typ Marked.pos list -> - 'm mark -> - 'm marked_expr Bindlib.box - -val ematch : - 'm marked_expr Bindlib.box -> - 'm marked_expr Bindlib.box list -> - Dcalc.Ast.EnumName.t -> - 'm mark -> - 'm marked_expr Bindlib.box - -val earray : - 'm marked_expr Bindlib.box list -> 'm mark -> 'm marked_expr Bindlib.box - -val elit : lit -> 'm mark -> 'm marked_expr Bindlib.box - -val eabs : - ('m expr, 'm marked_expr) Bindlib.mbinder Bindlib.box -> - Dcalc.Ast.typ Marked.pos list -> - 'm mark -> - 'm marked_expr Bindlib.box - -val eapp : - 'm marked_expr Bindlib.box -> - 'm marked_expr Bindlib.box list -> - 'm mark -> - 'm marked_expr Bindlib.box - -val eassert : - 'm marked_expr Bindlib.box -> 'm mark -> 'm marked_expr Bindlib.box - -val eop : Dcalc.Ast.operator -> 'm mark -> 'm marked_expr Bindlib.box - -val eifthenelse : - 'm marked_expr Bindlib.box -> - 'm marked_expr Bindlib.box -> - 'm marked_expr Bindlib.box -> - 'm mark -> - 'm marked_expr Bindlib.box - -val ecatch : - 'm marked_expr Bindlib.box -> - except -> - 'm marked_expr Bindlib.box -> - 'm mark -> - 'm marked_expr Bindlib.box - -val eraise : except -> 'm mark -> 'm marked_expr Bindlib.box - (** {1 Language terms construction}*) -val make_var : ('m var, 'm) Dcalc.Ast.marked -> 'm marked_expr Bindlib.box +val make_var : ('m var, 'm) marked -> 'm marked_expr Bindlib.box val make_abs : 'm vars -> 'm marked_expr Bindlib.box -> - Dcalc.Ast.typ Marked.pos list -> + typ Marked.pos list -> 'm mark -> 'm marked_expr Bindlib.box @@ -145,7 +52,7 @@ val make_app : val make_let_in : 'm var -> - Dcalc.Ast.typ Marked.pos -> + typ Marked.pos -> 'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box -> Pos.t -> @@ -153,18 +60,18 @@ val make_let_in : val make_multiple_let_in : 'm vars -> - Dcalc.Ast.typ Marked.pos list -> + typ Marked.pos list -> 'm marked_expr Bindlib.box list -> 'm marked_expr Bindlib.box -> Pos.t -> 'm marked_expr Bindlib.box -val option_enum : Dcalc.Ast.EnumName.t -val none_constr : Dcalc.Ast.EnumConstructor.t -val some_constr : Dcalc.Ast.EnumConstructor.t +val option_enum : EnumName.t +val none_constr : EnumConstructor.t +val some_constr : EnumConstructor.t val option_enum_config : - (Dcalc.Ast.EnumConstructor.t * Dcalc.Ast.typ Marked.pos) list + (EnumConstructor.t * typ Marked.pos) list val make_none : 'm mark -> 'm marked_expr Bindlib.box val make_some : 'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box @@ -178,7 +85,7 @@ val make_matchopt_with_abs_arms : val make_matchopt : 'm mark -> 'm var -> - Dcalc.Ast.typ Marked.pos -> + typ Marked.pos -> 'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box -> 'm marked_expr Bindlib.box -> @@ -186,8 +93,6 @@ val make_matchopt : (** [e' = make_matchopt'' pos v e e_none e_some] Builds the term corresponding to [match e with | None -> fun () -> e_none |Some -> fun v -> e_some]. *) -val box_expr : 'm marked_expr -> 'm marked_expr Bindlib.box - (** {1 Special symbols} *) val handle_default : untyped var diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index b878a05f..62a74bc7 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -14,8 +14,9 @@ License for the specific language governing permissions and limitations under the License. *) -open Ast open Utils +open Shared_ast +open Ast module D = Dcalc.Ast (** TODO: This version is not yet debugged and ought to be specialized when @@ -127,7 +128,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) : | EAbs (binder, typs) -> (* λ x.t *) let binder_mark = Marked.get_mark e in - let binder_pos = D.mark_pos binder_mark in + let binder_pos = Expr.mark_pos binder_mark in (* Converting the closure. *) let vars, body = Bindlib.unmbind binder in (* t *) @@ -141,7 +142,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) : let code_var = Var.make ctx.name_context in (* code *) let inner_c_var = Var.make "env" in - let any_ty = Dcalc.Ast.TAny, binder_pos in + let any_ty = TAny, binder_pos in let new_closure_body = make_multiple_let_in (Array.of_list extra_vars_list) @@ -158,17 +159,17 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) : binder_mark )) (Bindlib.box_var inner_c_var)) extra_vars_list) - new_body (D.mark_pos binder_mark) + new_body (Expr.mark_pos binder_mark) in let new_closure = make_abs (Array.concat [Array.make 1 inner_c_var; vars]) new_closure_body - ((Dcalc.Ast.TAny, binder_pos) :: typs) + ((TAny, binder_pos) :: typs) (Marked.get_mark e) in ( make_let_in code_var - (Dcalc.Ast.TAny, D.pos e) + (TAny, Expr.pos e) new_closure (Bindlib.box_apply2 (fun code_var extra_vars -> @@ -184,7 +185,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) : (List.map (fun extra_var -> Bindlib.box_var extra_var) extra_vars_list))) - (D.pos e), + (Expr.pos e), extra_vars ) | EApp ((EOp op, pos_op), args) -> (* This corresponds to an operator call, which we don't want to @@ -227,7 +228,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) : in let call_expr = make_let_in code_var - (Dcalc.Ast.TAny, D.pos e) + (TAny, Expr.pos e) (Bindlib.box_apply (fun env_var -> ( ETupleAccess @@ -242,9 +243,9 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) : Marked.get_mark e )) (Bindlib.box_var code_var) (Bindlib.box_var env_var) (Bindlib.box_list new_args)) - (D.pos e) + (Expr.pos e) in - ( make_let_in env_var (Dcalc.Ast.TAny, D.pos e) new_e1 call_expr (D.pos e), + ( make_let_in env_var (TAny, Expr.pos e) new_e1 call_expr (Expr.pos e), free_vars ) | EAssert e1 -> let new_e1, free_vars = aux e1 in @@ -278,7 +279,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m marked_expr) : let closure_conversion (p : 'm program) : 'm program Bindlib.box = let new_scopes, _ = - D.fold_left_scope_defs + Expr.fold_left_scope_defs ~f:(fun (acc_new_scopes, global_vars) scope scope_var -> (* [acc_new_scopes] represents what has been translated in the past, it needs a continuation to attach the rest of the translated scopes. *) @@ -289,12 +290,12 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box = let ctx = { name_context = - Marked.unmark (Dcalc.Ast.ScopeName.get_info scope.scope_name); + Marked.unmark (ScopeName.get_info scope.scope_name); globally_bound_vars = global_vars; } in let new_scope_lets = - D.map_exprs_in_scope_lets + Expr.map_exprs_in_scope_lets ~f:(closure_conversion_expr ctx) ~varf:(fun v -> v) scope_body_expr @@ -306,7 +307,7 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box = acc_new_scopes (Bindlib.box_apply2 (fun new_scope_body_expr next -> - D.ScopeDef + ScopeDef { scope with scope_body = @@ -327,4 +328,4 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box = in Bindlib.box_apply (fun new_scopes -> { p with scopes = new_scopes }) - (new_scopes (Bindlib.box D.Nil)) + (new_scopes (Bindlib.box Nil)) diff --git a/compiler/lcalc/compile_with_exceptions.ml b/compiler/lcalc/compile_with_exceptions.ml index 4850891d..2d231195 100644 --- a/compiler/lcalc/compile_with_exceptions.ml +++ b/compiler/lcalc/compile_with_exceptions.ml @@ -25,26 +25,26 @@ type 'm ctx = ('m D.expr, 'm A.expr Var.t) Var.Map.t let translate_lit (l : D.lit) : 'm A.expr = match l with - | D.LBool l -> A.ELit (A.LBool l) - | D.LInt i -> A.ELit (A.LInt i) - | D.LRat r -> A.ELit (A.LRat r) - | D.LMoney m -> A.ELit (A.LMoney m) - | D.LUnit -> A.ELit A.LUnit - | D.LDate d -> A.ELit (A.LDate d) - | D.LDuration d -> A.ELit (A.LDuration d) - | D.LEmptyError -> A.ERaise A.EmptyError + | LBool l -> ELit (LBool l) + | LInt i -> ELit (LInt i) + | LRat r -> ELit (LRat r) + | LMoney m -> ELit (LMoney m) + | LUnit -> ELit LUnit + | LDate d -> ELit (LDate d) + | LDuration d -> ELit (LDuration d) + | LEmptyError -> ERaise EmptyError -let thunk_expr (e : 'm A.marked_expr Bindlib.box) (mark : 'm A.mark) : +let thunk_expr (e : 'm A.marked_expr Bindlib.box) (mark : 'm mark) : 'm A.marked_expr Bindlib.box = let dummy_var = Var.make "_" in - A.make_abs [| dummy_var |] e [D.TAny, D.mark_pos mark] mark + A.make_abs [| dummy_var |] e [TAny, Expr.mark_pos mark] mark let rec translate_default (ctx : 'm ctx) (exceptions : 'm D.marked_expr list) (just : 'm D.marked_expr) (cons : 'm D.marked_expr) - (mark_default : 'm D.mark) : 'm A.marked_expr Bindlib.box = + (mark_default : 'm mark) : 'm A.marked_expr Bindlib.box = let exceptions = List.map (fun except -> thunk_expr (translate_expr ctx except) mark_default) @@ -54,7 +54,7 @@ let rec translate_default A.make_app (A.make_var (Var.translate A.handle_default, mark_default)) [ - A.earray exceptions mark_default; + Expr.earray exceptions mark_default; thunk_expr (translate_expr ctx just) mark_default; thunk_expr (translate_expr ctx cons) mark_default; ] @@ -65,34 +65,34 @@ let rec translate_default and translate_expr (ctx : 'm ctx) (e : 'm D.marked_expr) : 'm A.marked_expr Bindlib.box = match Marked.unmark e with - | D.EVar v -> A.make_var (Var.Map.find v ctx, Marked.get_mark e) - | D.ETuple (args, s) -> - A.etuple (List.map (translate_expr ctx) args) s (Marked.get_mark e) - | D.ETupleAccess (e1, i, s, ts) -> - A.etupleaccess (translate_expr ctx e1) i s ts (Marked.get_mark e) - | D.EInj (e1, i, en, ts) -> - A.einj (translate_expr ctx e1) i en ts (Marked.get_mark e) - | D.EMatch (e1, cases, en) -> - A.ematch (translate_expr ctx e1) + | EVar v -> A.make_var (Var.Map.find v ctx, Marked.get_mark e) + | ETuple (args, s) -> + Expr.etuple (List.map (translate_expr ctx) args) s (Marked.get_mark e) + | ETupleAccess (e1, i, s, ts) -> + Expr.etupleaccess (translate_expr ctx e1) i s ts (Marked.get_mark e) + | EInj (e1, i, en, ts) -> + Expr.einj (translate_expr ctx e1) i en ts (Marked.get_mark e) + | EMatch (e1, cases, en) -> + Expr.ematch (translate_expr ctx e1) (List.map (translate_expr ctx) cases) en (Marked.get_mark e) - | D.EArray es -> - A.earray (List.map (translate_expr ctx) es) (Marked.get_mark e) - | D.ELit l -> Bindlib.box (Marked.same_mark_as (translate_lit l) e) - | D.EOp op -> A.eop op (Marked.get_mark e) - | D.EIfThenElse (e1, e2, e3) -> - A.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2) + | EArray es -> + Expr.earray (List.map (translate_expr ctx) es) (Marked.get_mark e) + | ELit l -> Bindlib.box (Marked.same_mark_as (translate_lit l) e) + | EOp op -> Expr.eop op (Marked.get_mark e) + | EIfThenElse (e1, e2, e3) -> + Expr.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2) (translate_expr ctx e3) (Marked.get_mark e) - | D.EAssert e1 -> A.eassert (translate_expr ctx e1) (Marked.get_mark e) - | D.ErrorOnEmpty arg -> - A.ecatch (translate_expr ctx arg) A.EmptyError - (Bindlib.box (Marked.same_mark_as (A.ERaise A.NoValueProvided) e)) + | EAssert e1 -> Expr.eassert (translate_expr ctx e1) (Marked.get_mark e) + | ErrorOnEmpty arg -> + Expr.ecatch (translate_expr ctx arg) EmptyError + (Bindlib.box (Marked.same_mark_as (ERaise NoValueProvided) e)) (Marked.get_mark e) - | D.EApp (e1, args) -> - A.eapp (translate_expr ctx e1) + | EApp (e1, args) -> + Expr.eapp (translate_expr ctx e1) (List.map (translate_expr ctx) args) (Marked.get_mark e) - | D.EAbs (binder, ts) -> + | EAbs (binder, ts) -> let vars, body = Bindlib.unmbind binder in let ctx, lc_vars = Array.fold_right @@ -105,24 +105,24 @@ and translate_expr (ctx : 'm ctx) (e : 'm D.marked_expr) : let new_body = translate_expr ctx body in let new_binder = Bindlib.bind_mvar lc_vars new_body in Bindlib.box_apply - (fun new_binder -> Marked.same_mark_as (A.EAbs (new_binder, ts)) e) + (fun new_binder -> Marked.same_mark_as (EAbs (new_binder, ts)) e) new_binder - | D.EDefault ([exn], just, cons) when !Cli.optimize_flag -> - A.ecatch (translate_expr ctx exn) A.EmptyError - (A.eifthenelse (translate_expr ctx just) (translate_expr ctx cons) - (Bindlib.box (Marked.same_mark_as (A.ERaise A.EmptyError) e)) + | EDefault ([exn], just, cons) when !Cli.optimize_flag -> + Expr.ecatch (translate_expr ctx exn) EmptyError + (Expr.eifthenelse (translate_expr ctx just) (translate_expr ctx cons) + (Bindlib.box (Marked.same_mark_as (ERaise EmptyError) e)) (Marked.get_mark e)) (Marked.get_mark e) - | D.EDefault (exceptions, just, cons) -> + | EDefault (exceptions, just, cons) -> translate_default ctx exceptions just cons (Marked.get_mark e) let rec translate_scope_lets - (decl_ctx : D.decl_ctx) + (decl_ctx : decl_ctx) (ctx : 'm ctx) - (scope_lets : ('m D.expr, 'm) D.scope_body_expr) : - ('m A.expr, 'm) D.scope_body_expr Bindlib.box = + (scope_lets : ('m D.expr, 'm) scope_body_expr) : + ('m A.expr, 'm) scope_body_expr Bindlib.box = match scope_lets with - | Result e -> Bindlib.box_apply (fun e -> D.Result e) (translate_expr ctx e) + | Result e -> Bindlib.box_apply (fun e -> Result e) (translate_expr ctx e) | ScopeLet scope_let -> let old_scope_let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next @@ -134,26 +134,26 @@ let rec translate_scope_lets let new_scope_next = Bindlib.bind_var new_scope_let_var new_scope_next in Bindlib.box_apply2 (fun new_scope_next new_scope_let_expr -> - D.ScopeLet + ScopeLet { - scope_let_typ = scope_let.D.scope_let_typ; - scope_let_kind = scope_let.D.scope_let_kind; - scope_let_pos = scope_let.D.scope_let_pos; + scope_let_typ = scope_let.scope_let_typ; + scope_let_kind = scope_let.scope_let_kind; + scope_let_pos = scope_let.scope_let_pos; scope_let_next = new_scope_next; scope_let_expr = new_scope_let_expr; }) new_scope_next new_scope_let_expr let rec translate_scopes - (decl_ctx : D.decl_ctx) + (decl_ctx : decl_ctx) (ctx : 'm ctx) - (scopes : ('m D.expr, 'm) D.scopes) : ('m A.expr, 'm) D.scopes Bindlib.box = + (scopes : ('m D.expr, 'm) scopes) : ('m A.expr, 'm) scopes Bindlib.box = match scopes with - | Nil -> Bindlib.box D.Nil + | Nil -> Bindlib.box Nil | ScopeDef scope_def -> let old_scope_var, scope_next = Bindlib.unbind scope_def.scope_next in let new_scope_var = - Var.make (Marked.unmark (D.ScopeName.get_info scope_def.scope_name)) + Var.make (Marked.unmark (ScopeName.get_info scope_def.scope_name)) in let old_scope_input_var, scope_body_expr = Bindlib.unbind scope_def.scope_body.scope_body_expr @@ -166,11 +166,11 @@ let rec translate_scopes let new_scope_body_expr = Bindlib.bind_var new_scope_input_var new_scope_body_expr in - let new_scope : ('m A.expr, 'm) D.scope_body Bindlib.box = + let new_scope : ('m A.expr, 'm) scope_body Bindlib.box = Bindlib.box_apply (fun new_scope_body_expr -> { - D.scope_body_input_struct = + scope_body_input_struct = scope_def.scope_body.scope_body_input_struct; scope_body_output_struct = scope_def.scope_body.scope_body_output_struct; @@ -185,7 +185,7 @@ let rec translate_scopes in Bindlib.box_apply2 (fun new_scope scope_next -> - D.ScopeDef + ScopeDef { scope_name = scope_def.scope_name; scope_body = new_scope; diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index 63ade836..991d7e34 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -61,7 +61,7 @@ let pp_info (fmt : Format.formatter) (info : 'm info) = info.is_pure type 'm ctx = { - decl_ctx : D.decl_ctx; + decl_ctx : decl_ctx; vars : ('m D.expr, 'm info) Var.Map.t; (** information context about variables in the current scope *) } @@ -95,7 +95,7 @@ let find ?(info : string = "none") (n : 'm D.var) (ctx : 'm ctx) : 'm info = var, creating a unique corresponding variable in Lcalc, with the corresponding expression, and the boolean is_pure. It is usefull for debuging purposes as it printing each of the Dcalc/Lcalc variable pairs. *) -let add_var (mark : 'm D.mark) (var : 'm D.var) (is_pure : bool) (ctx : 'm ctx) +let add_var (mark : 'm mark) (var : 'm D.var) (is_pure : bool) (ctx : 'm ctx) : 'm ctx = let new_var = Var.make (Bindlib.name_of var) in let expr = A.make_var (new_var, mark) in @@ -115,33 +115,33 @@ let add_var (mark : 'm D.mark) (var : 'm D.var) (is_pure : bool) (ctx : 'm ctx) Since positions where there is thunked expressions is exactly where we will put option expressions. Hence, the transformation simply reduce [unit -> 'a] into ['a option] recursivly. There is no polymorphism inside catala. *) -let rec translate_typ (tau : D.typ Marked.pos) : D.typ Marked.pos = +let rec translate_typ (tau : typ Marked.pos) : typ Marked.pos = (Fun.flip Marked.same_mark_as) tau begin match Marked.unmark tau with - | D.TLit l -> D.TLit l - | D.TTuple (ts, s) -> D.TTuple (List.map translate_typ ts, s) - | D.TEnum (ts, en) -> D.TEnum (List.map translate_typ ts, en) - | D.TAny -> D.TAny - | D.TArray ts -> D.TArray (translate_typ ts) + | TLit l -> TLit l + | TTuple (ts, s) -> TTuple (List.map translate_typ ts, s) + | TEnum (ts, en) -> TEnum (List.map translate_typ ts, en) + | TAny -> TAny + | TArray ts -> TArray (translate_typ ts) (* catala is not polymorphic *) - | D.TArrow ((D.TLit D.TUnit, pos_unit), t2) -> - D.TEnum ([D.TLit D.TUnit, pos_unit; translate_typ t2], A.option_enum) - (* D.TAny *) - | D.TArrow (t1, t2) -> D.TArrow (translate_typ t1, translate_typ t2) + | TArrow ((TLit TUnit, pos_unit), t2) -> + TEnum ([TLit TUnit, pos_unit; translate_typ t2], A.option_enum) + (* TAny *) + | TArrow (t1, t2) -> TArrow (translate_typ t1, translate_typ t2) end let translate_lit (l : D.lit) (pos : Pos.t) : A.lit = match l with - | D.LBool l -> A.LBool l - | D.LInt i -> A.LInt i - | D.LRat r -> A.LRat r - | D.LMoney m -> A.LMoney m - | D.LUnit -> A.LUnit - | D.LDate d -> A.LDate d - | D.LDuration d -> A.LDuration d - | D.LEmptyError -> + | LBool l -> LBool l + | LInt i -> LInt i + | LRat r -> LRat r + | LMoney m -> LMoney m + | LUnit -> LUnit + | LDate d -> LDate d + | LDuration d -> LDuration d + | LEmptyError -> Errors.raise_spanned_error pos "Internal Error: An empty error was found in a place that shouldn't be \ possible." @@ -171,7 +171,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) : (* empty-producing/using terms. We hoist those. (D.EVar in some cases, EApp(D.EVar _, [ELit LUnit]), EDefault _, ELit LEmptyDefault) I'm unsure about assert. *) - | D.EVar v -> + | EVar v -> (* todo: for now, every unpure (such that [is_pure] is [false] in the current context) is thunked, hence matched in the next case. This assumption can change in the future, and this case is here for this @@ -183,20 +183,20 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) : Print.format_var v'; *) A.make_var (v', pos), Var.Map.singleton v' e else (find ~info:"should never happend" v ctx).expr, Var.Map.empty - | D.EApp ((D.EVar v, p), [(D.ELit D.LUnit, _)]) -> + | EApp ((EVar v, p), [(ELit LUnit, _)]) -> if not (find ~info:"search for a variable" v ctx).is_pure then let v' = Var.make (Bindlib.name_of v) in (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, created a variable %a to replace it" Dcalc.Print.format_var v Print.format_var v'; *) - A.make_var (v', pos), Var.Map.singleton v' (D.EVar v, p) + A.make_var (v', pos), Var.Map.singleton v' (EVar v, p) else - Errors.raise_spanned_error (D.pos e) + Errors.raise_spanned_error (Expr.pos e) "Internal error: an pure variable was found in an unpure environment." - | D.EDefault (_exceptions, _just, _cons) -> + | EDefault (_exceptions, _just, _cons) -> let v' = Var.make "default_term" in A.make_var (v', pos), Var.Map.singleton v' e - | D.ELit D.LEmptyError -> + | ELit LEmptyError -> let v' = Var.make "empty_litteral" in A.make_var (v', pos), Var.Map.singleton v' e (* This one is a very special case. It transform an unpure expression @@ -210,29 +210,29 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) : ( A.make_matchopt_with_abs_arms arg' (A.make_abs [| silent_var |] - (Bindlib.box (A.ERaise A.NoValueProvided, pos)) - [D.TAny, D.pos e] + (Bindlib.box (ERaise NoValueProvided, pos)) + [TAny, Expr.pos e] pos) - (A.make_abs [| x |] (A.make_var (x, pos)) [D.TAny, D.pos e] pos), + (A.make_abs [| x |] (A.make_var (x, pos)) [TAny, Expr.pos e] pos), Var.Map.empty ) (* pure terms *) - | D.ELit l -> A.elit (translate_lit l (D.pos e)) pos, Var.Map.empty - | D.EIfThenElse (e1, e2, e3) -> + | ELit l -> Expr.elit (translate_lit l (Expr.pos e)) pos, Var.Map.empty + | EIfThenElse (e1, e2, e3) -> let e1', h1 = translate_and_hoist ctx e1 in let e2', h2 = translate_and_hoist ctx e2 in let e3', h3 = translate_and_hoist ctx e3 in - let e' = A.eifthenelse e1' e2' e3' pos in + let e' = Expr.eifthenelse e1' e2' e3' pos in (*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3' = e3' in (A.EIfThenElse (e1', e2', e3'), pos) in *) - e', disjoint_union_maps (D.pos e) [h1; h2; h3] - | D.EAssert e1 -> + e', disjoint_union_maps (Expr.pos e) [h1; h2; h3] + | EAssert e1 -> (* same behavior as in the ICFP paper: if e1 is empty, then no error is raised. *) let e1', h1 = translate_and_hoist ctx e1 in - A.eassert e1' pos, h1 - | D.EAbs (binder, ts) -> + Expr.eassert e1' pos, h1 + | EAbs (binder, ts) -> let vars, body = Bindlib.unmbind binder in let ctx, lc_vars = ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) -> @@ -254,7 +254,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) : let new_binder = Bindlib.bind_mvar lc_vars new_body in ( Bindlib.box_apply - (fun new_binder -> A.EAbs (new_binder, List.map translate_typ ts), pos) + (fun new_binder -> EAbs (new_binder, List.map translate_typ ts), pos) new_binder, hoists ) | EApp (e1, args) -> @@ -263,23 +263,23 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) : args |> List.map (translate_and_hoist ctx) |> List.split in - let hoists = disjoint_union_maps (D.pos e) (h1 :: h_args) in - let e' = A.eapp e1' args' pos in + let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_args) in + let e' = Expr.eapp e1' args' pos in e', hoists | ETuple (args, s) -> let args', h_args = args |> List.map (translate_and_hoist ctx) |> List.split in - let hoists = disjoint_union_maps (D.pos e) h_args in - A.etuple args' s pos, hoists + let hoists = disjoint_union_maps (Expr.pos e) h_args in + Expr.etuple args' s pos, hoists | ETupleAccess (e1, i, s, ts) -> let e1', hoists = translate_and_hoist ctx e1 in - let e1' = A.etupleaccess e1' i s ts pos in + let e1' = Expr.etupleaccess e1' i s ts pos in e1', hoists | EInj (e1, i, en, ts) -> let e1', hoists = translate_and_hoist ctx e1 in - let e1' = A.einj e1' i en ts pos in + let e1' = Expr.einj e1' i en ts pos in e1', hoists | EMatch (e1, cases, en) -> let e1', h1 = translate_and_hoist ctx e1 in @@ -287,14 +287,14 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) : cases |> List.map (translate_and_hoist ctx) |> List.split in - let hoists = disjoint_union_maps (D.pos e) (h1 :: h_cases) in - let e' = A.ematch e1' cases' en pos in + let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_cases) in + let e' = Expr.ematch e1' cases' en pos in e', hoists | EArray es -> let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in - A.earray es' pos, disjoint_union_maps (D.pos e) hoists - | EOp op -> Bindlib.box (A.EOp op, pos), Var.Map.empty + Expr.earray es' pos, disjoint_union_maps (Expr.pos e) hoists + | EOp op -> Bindlib.box (EOp op, pos), Var.Map.empty and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr) : 'm A.marked_expr Bindlib.box = @@ -315,8 +315,8 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr) match hoist with (* Here we have to handle only the cases appearing in hoists, as defined the [translate_and_hoist] function. *) - | D.EVar v -> (find ~info:"should never happend" v ctx).expr - | D.EDefault (excep, just, cons) -> + | EVar v -> (find ~info:"should never happend" v ctx).expr + | EDefault (excep, just, cons) -> let excep' = List.map (translate_expr ctx) excep in let just' = translate_expr ctx just in let cons' = translate_expr ctx cons in @@ -325,14 +325,14 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr) (A.make_var (Var.translate A.handle_default_opt, mark_hoist)) [ Bindlib.box_apply - (fun excep' -> A.EArray excep', mark_hoist) + (fun excep' -> EArray excep', mark_hoist) (Bindlib.box_list excep'); just'; cons'; ] mark_hoist - | D.ELit D.LEmptyError -> A.make_none mark_hoist - | D.EAssert arg -> + | ELit LEmptyError -> A.make_none mark_hoist + | EAssert arg -> let arg' = translate_expr ctx arg in (* [ match arg with | None -> raise NoValueProvided | Some v -> assert @@ -342,17 +342,17 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr) A.make_matchopt_with_abs_arms arg' (A.make_abs [| silent_var |] - (Bindlib.box (A.ERaise A.NoValueProvided, mark_hoist)) - [D.TAny, D.mark_pos mark_hoist] + (Bindlib.box (ERaise NoValueProvided, mark_hoist)) + [TAny, Expr.mark_pos mark_hoist] mark_hoist) (A.make_abs [| x |] (Bindlib.box_apply - (fun arg -> A.EAssert arg, mark_hoist) + (fun arg -> EAssert arg, mark_hoist) (A.make_var (x, mark_hoist))) - [D.TAny, D.mark_pos mark_hoist] + [TAny, Expr.mark_pos mark_hoist] mark_hoist) | _ -> - Errors.raise_spanned_error (D.mark_pos mark_hoist) + Errors.raise_spanned_error (Expr.mark_pos mark_hoist) "Internal Error: An term was found in a position where it should \ not be" in @@ -362,23 +362,23 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr) (* Cli.debug_print @@ Format.asprintf "build matchopt using %a" Print.format_var v; *) A.make_matchopt mark_hoist v - (D.TAny, D.mark_pos mark_hoist) + (TAny, Expr.mark_pos mark_hoist) c' (A.make_none mark_hoist) acc) let rec translate_scope_let (ctx : 'm ctx) - (lets : ('m D.expr, 'm) D.scope_body_expr) : - ('m A.expr, 'm) D.scope_body_expr Bindlib.box = + (lets : ('m D.expr, 'm) scope_body_expr) : + ('m A.expr, 'm) scope_body_expr Bindlib.box = match lets with | Result e -> Bindlib.box_apply - (fun e -> D.Result e) + (fun e -> Result e) (translate_expr ~append_esome:false ctx e) | ScopeLet { scope_let_kind = SubScopeVarDefinition; scope_let_typ = typ; - scope_let_expr = D.EAbs (binder, _), emark; + scope_let_expr = EAbs (binder, _), emark; scope_let_next = next; scope_let_pos = pos; } -> @@ -390,13 +390,13 @@ let rec translate_scope_let let var, next = Bindlib.unbind next in (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *) - let vmark = D.map_mark (fun _ -> pos) (fun _ -> typ) emark in + let vmark = Expr.map_mark (fun _ -> pos) (fun _ -> typ) emark in let ctx' = add_var vmark var var_is_pure ctx in let new_var = (find ~info:"variable that was just created" var ctx').var in let new_next = translate_scope_let ctx' next in Bindlib.box_apply2 (fun new_expr new_next -> - D.ScopeLet + ScopeLet { scope_let_kind = SubScopeVarDefinition; scope_let_typ = translate_typ typ; @@ -410,7 +410,7 @@ let rec translate_scope_let { scope_let_kind = SubScopeVarDefinition; scope_let_typ = typ; - scope_let_expr = (D.ErrorOnEmpty _, emark) as expr; + scope_let_expr = (ErrorOnEmpty _, emark) as expr; scope_let_next = next; scope_let_pos = pos; } -> @@ -419,12 +419,12 @@ let rec translate_scope_let let var, next = Bindlib.unbind next in (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *) - let vmark = D.map_mark (fun _ -> pos) (fun _ -> typ) emark in + let vmark = Expr.map_mark (fun _ -> pos) (fun _ -> typ) emark in let ctx' = add_var vmark var var_is_pure ctx in let new_var = (find ~info:"variable that was just created" var ctx').var in Bindlib.box_apply2 (fun new_expr new_next -> - D.ScopeLet + ScopeLet { scope_let_kind = SubScopeVarDefinition; scope_let_typ = translate_typ typ; @@ -463,7 +463,7 @@ let rec translate_scope_let thunked, then the variable is context. If it's not thunked, it's a regular input. *) match Marked.unmark typ with - | D.TArrow ((D.TLit D.TUnit, _), _) -> false + | TArrow ((TLit TUnit, _), _) -> false | _ -> true) | ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope | DestructuringSubScopeResults | Assertion -> @@ -473,13 +473,13 @@ let rec translate_scope_let (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *) let vmark = - D.map_mark (fun _ -> pos) (fun _ -> typ) (Marked.get_mark expr) + Expr.map_mark (fun _ -> pos) (fun _ -> typ) (Marked.get_mark expr) in let ctx' = add_var vmark var var_is_pure ctx in let new_var = (find ~info:"variable that was just created" var ctx').var in Bindlib.box_apply2 (fun new_expr new_next -> - D.ScopeLet + ScopeLet { scope_let_kind = kind; scope_let_typ = translate_typ typ; @@ -493,8 +493,8 @@ let rec translate_scope_let let translate_scope_body (scope_pos : Pos.t) (ctx : 'm ctx) - (body : ('m D.expr, 'm) D.scope_body) : - ('m A.expr, 'm) D.scope_body Bindlib.box = + (body : ('m D.expr, 'm) scope_body) : + ('m A.expr, 'm) scope_body Bindlib.box = match body with | { scope_body_expr = result; @@ -507,23 +507,23 @@ let translate_scope_body match lets with | Result e | ScopeLet { scope_let_expr = e; _ } -> Marked.get_mark e in - D.map_mark (fun _ -> scope_pos) (fun ty -> ty) m + Expr.map_mark (fun _ -> scope_pos) (fun ty -> ty) m in let ctx' = add_var vmark v true ctx in let v' = (find ~info:"variable that was just created" v ctx').var in Bindlib.box_apply (fun new_expr -> { - D.scope_body_expr = new_expr; + scope_body_expr = new_expr; scope_body_input_struct = input_struct; scope_body_output_struct = output_struct; }) (Bindlib.bind_var v' (translate_scope_let ctx' lets)) -let rec translate_scopes (ctx : 'm ctx) (scopes : ('m D.expr, 'm) D.scopes) : - ('m A.expr, 'm) D.scopes Bindlib.box = +let rec translate_scopes (ctx : 'm ctx) (scopes : ('m D.expr, 'm) scopes) : + ('m A.expr, 'm) scopes Bindlib.box = match scopes with - | Nil -> Bindlib.box D.Nil + | Nil -> Bindlib.box Nil | ScopeDef { scope_name; scope_body; scope_next } -> let scope_var, next = Bindlib.unbind scope_next in let vmark = @@ -536,21 +536,21 @@ let rec translate_scopes (ctx : 'm ctx) (scopes : ('m D.expr, 'm) D.scopes) : (find ~info:"variable that was just created" scope_var new_ctx).var in - let scope_pos = Marked.get_mark (D.ScopeName.get_info scope_name) in + let scope_pos = Marked.get_mark (ScopeName.get_info scope_name) in let new_body = translate_scope_body scope_pos ctx scope_body in let tail = translate_scopes new_ctx next in Bindlib.box_apply2 (fun body tail -> - D.ScopeDef { scope_name; scope_body = body; scope_next = tail }) + ScopeDef { scope_name; scope_body = body; scope_next = tail }) new_body (Bindlib.bind_var new_scope_name tail) let translate_program (prgm : 'm D.program) : 'm A.program = let inputs_structs = - D.fold_left_scope_defs prgm.scopes ~init:[] ~f:(fun acc scope_def _ -> - scope_def.D.scope_body.scope_body_input_struct :: acc) + Expr.fold_left_scope_defs prgm.scopes ~init:[] ~f:(fun acc scope_def _ -> + scope_def.scope_body.scope_body_input_struct :: acc) in (* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]" @@ -558,17 +558,17 @@ let translate_program (prgm : 'm D.program) : 'm A.program = let decl_ctx = { prgm.decl_ctx with - D.ctx_enums = + ctx_enums = prgm.decl_ctx.ctx_enums - |> D.EnumMap.add A.option_enum A.option_enum_config; + |> EnumMap.add A.option_enum A.option_enum_config; } in let decl_ctx = { decl_ctx with - D.ctx_structs = + ctx_structs = prgm.decl_ctx.ctx_structs - |> D.StructMap.mapi (fun n l -> + |> StructMap.mapi (fun n l -> if List.mem n inputs_structs then ListLabels.map l ~f:(fun (n, tau) -> (* Cli.debug_print @@ Format.asprintf "Input type: %a" diff --git a/compiler/lcalc/optimizations.ml b/compiler/lcalc/optimizations.ml index c31f3c5e..0dc4c6da 100644 --- a/compiler/lcalc/optimizations.ml +++ b/compiler/lcalc/optimizations.ml @@ -14,6 +14,7 @@ License for the specific language governing permissions and limitations under the License. *) open Utils +open Shared_ast open Ast module D = Dcalc.Ast @@ -71,7 +72,7 @@ let rec iota_expr (_ : unit) (e : 'm marked_expr) : 'm marked_expr Bindlib.box = let default_mark e' = Marked.mark (Marked.get_mark e) e' in match Marked.unmark e with | EMatch ((EInj (e1, i, n', _ts), _), cases, n) - when Dcalc.Ast.EnumName.compare n n' = 0 -> + when EnumName.compare n n' = 0 -> let+ e1 = visitor_map iota_expr () e1 and+ case = visitor_map iota_expr () (List.nth cases i) in default_mark @@ EApp (case, [e1]) @@ -80,7 +81,7 @@ let rec iota_expr (_ : unit) (e : 'm marked_expr) : 'm marked_expr Bindlib.box = |> List.mapi (fun i (case, _pos) -> match case with | EInj (_ei, i', n', _ts') -> - i = i' && (* n = n' *) Dcalc.Ast.EnumName.compare n n' = 0 + i = i' && (* n = n' *) EnumName.compare n n' = 0 | _ -> false) |> List.for_all Fun.id -> visitor_map iota_expr () e' @@ -101,9 +102,9 @@ let rec beta_expr (_ : unit) (e : 'm marked_expr) : 'm marked_expr Bindlib.box = let iota_optimizations (p : 'm program) : 'm program = let new_scopes = - Dcalc.Ast.map_exprs_in_scopes ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes + Expr.map_exprs_in_scopes ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes in - { p with D.scopes = Bindlib.unbox new_scopes } + { p with scopes = Bindlib.unbox new_scopes } (* TODO: beta optimizations apply inlining of the program. We left the inclusion of beta-optimization as future work since its produce code that is harder to @@ -111,7 +112,7 @@ let iota_optimizations (p : 'm program) : 'm program = program. *) let _beta_optimizations (p : 'm program) : 'm program = let new_scopes = - Dcalc.Ast.map_exprs_in_scopes ~f:(beta_expr ()) ~varf:(fun v -> v) p.scopes + Expr.map_exprs_in_scopes ~f:(beta_expr ()) ~varf:(fun v -> v) p.scopes in { p with scopes = Bindlib.unbox new_scopes } @@ -145,11 +146,11 @@ let rec peephole_expr (_ : unit) (e : 'm marked_expr) : let peephole_optimizations (p : 'm program) : 'm program = let new_scopes = - Dcalc.Ast.map_exprs_in_scopes ~f:(peephole_expr ()) + Expr.map_exprs_in_scopes ~f:(peephole_expr ()) ~varf:(fun v -> v) p.scopes in { p with scopes = Bindlib.unbox new_scopes } -let optimize_program (p : 'm program) : Dcalc.Ast.untyped program = - p |> iota_optimizations |> peephole_optimizations |> untype_program +let optimize_program (p : 'm program) : untyped program = + p |> iota_optimizations |> peephole_optimizations |> Expr.untype_program diff --git a/compiler/lcalc/optimizations.mli b/compiler/lcalc/optimizations.mli index da3af2c5..8c0c0b03 100644 --- a/compiler/lcalc/optimizations.mli +++ b/compiler/lcalc/optimizations.mli @@ -16,6 +16,6 @@ open Ast -val optimize_program : 'm program -> Dcalc.Ast.untyped program +val optimize_program : 'm program -> Shared_ast.untyped program (** Warning/todo: no effort was yet made to ensure correct propagation of type annotations in the typed case *) diff --git a/compiler/lcalc/print.ml b/compiler/lcalc/print.ml index 661ceec7..f7caa5b9 100644 --- a/compiler/lcalc/print.ml +++ b/compiler/lcalc/print.ml @@ -15,6 +15,7 @@ the License. *) open Utils +open Shared_ast open Ast (** {b Note:} (EmileRolley) seems to be factorizable with @@ -64,7 +65,7 @@ let format_var (fmt : Format.formatter) (v : 'm Ast.var) : unit = let rec format_expr ?(debug : bool = false) - (ctx : Dcalc.Ast.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm marked_expr) : unit = let format_expr = format_expr ctx ~debug in @@ -83,16 +84,16 @@ let rec format_expr (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) es format_punctuation ")" | ETuple (es, Some s) -> - Format.fprintf fmt "@[%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s + Format.fprintf fmt "@[%a@ %a%a%a@]" StructName.format_t s format_punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (e, struct_field) -> Format.fprintf fmt "%a%a%a%a %a" format_punctuation "\"" - Dcalc.Ast.StructFieldName.format_t struct_field format_punctuation + StructFieldName.format_t struct_field format_punctuation "\"" format_punctuation ":" format_expr e)) (List.combine es - (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) + (List.map fst (StructMap.find s ctx.ctx_structs))) format_punctuation "}" | EArray es -> Format.fprintf fmt "@[%a%a%a@]" format_punctuation "[" @@ -106,12 +107,12 @@ let rec format_expr Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n | Some s -> Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_punctuation "." - format_punctuation "\"" Dcalc.Ast.StructFieldName.format_t - (fst (List.nth (Dcalc.Ast.StructMap.find s ctx.ctx_structs) n)) + format_punctuation "\"" StructFieldName.format_t + (fst (List.nth (StructMap.find s ctx.ctx_structs) n)) format_punctuation "\"") | EInj (e, n, en, _ts) -> Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_enum_constructor - (fst (List.nth (Dcalc.Ast.EnumMap.find en ctx.ctx_enums) n)) + (fst (List.nth (EnumMap.find en ctx.ctx_enums) n)) format_expr e | EMatch (e, es, e_name) -> Format.fprintf fmt "@[%a@ %a@ %a@ %a@]" format_keyword "match" @@ -123,9 +124,9 @@ let rec format_expr Dcalc.Print.format_enum_constructor c format_punctuation ":" format_expr e)) (List.combine es - (List.map fst (Dcalc.Ast.EnumMap.find e_name ctx.ctx_enums))) + (List.map fst (EnumMap.find e_name ctx.ctx_enums))) | ELit l -> - Format.fprintf fmt "%a" format_lit (Marked.mark (Dcalc.Ast.pos e) l) + Format.fprintf fmt "%a" format_lit (Marked.mark (Expr.pos e) l) | EApp ((EAbs (binder, taus), _), args) -> let xs, body = Bindlib.unmbind binder in Format.fprintf fmt "%a%a" @@ -152,7 +153,7 @@ let rec format_expr (List.combine (Array.to_list xs) taus) format_punctuation "→" format_expr body | EApp - ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) + ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) -> Format.fprintf fmt "@[%a@ %a@ %a@]" Dcalc.Print.format_binop op format_with_parens arg1 format_with_parens arg2 @@ -190,11 +191,11 @@ let rec format_expr let format_scope ?(debug = false) ctx fmt (n, s) = Format.fprintf fmt "@[%a %a =@ %a@]" format_keyword "let" - Dcalc.Ast.ScopeName.format_t n (format_expr ctx ~debug) + ScopeName.format_t n (format_expr ctx ~debug) (Bindlib.unbox (Dcalc.Ast.build_whole_scope_expr ~make_abs:Ast.make_abs - ~make_let_in:Ast.make_let_in ~box_expr:Ast.box_expr ctx s - (Dcalc.Ast.map_mark - (fun _ -> Marked.get_mark (Dcalc.Ast.ScopeName.get_info n)) + ~make_let_in:Ast.make_let_in ~box_expr:Expr.box ctx s + (Expr.map_mark + (fun _ -> Marked.get_mark (ScopeName.get_info n)) (fun ty -> ty) - (Dcalc.Ast.get_scope_body_mark s)))) + (Expr.get_scope_body_mark s)))) diff --git a/compiler/lcalc/print.mli b/compiler/lcalc/print.mli index 938aa3d5..46e7b71b 100644 --- a/compiler/lcalc/print.mli +++ b/compiler/lcalc/print.mli @@ -15,23 +15,24 @@ the License. *) open Utils +open Shared_ast (** {1 Formatters} *) val format_lit : Format.formatter -> Ast.lit Marked.pos -> unit val format_var : Format.formatter -> 'm Ast.var -> unit -val format_exception : Format.formatter -> Ast.except -> unit +val format_exception : Format.formatter -> except -> unit val format_expr : ?debug:bool -> - Dcalc.Ast.decl_ctx -> + decl_ctx -> Format.formatter -> 'm Ast.marked_expr -> unit val format_scope : ?debug:bool -> - Dcalc.Ast.decl_ctx -> + decl_ctx -> Format.formatter -> - Dcalc.Ast.ScopeName.t * ('m Ast.expr, 'm) Dcalc.Ast.scope_body -> + ScopeName.t * ('m Ast.expr, 'm) scope_body -> unit diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 2de4ea73..1c53f24b 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -15,37 +15,38 @@ the License. *) open Utils +open Shared_ast open Ast open String_common module D = Dcalc.Ast -let find_struct (s : D.StructName.t) (ctx : D.decl_ctx) : - (D.StructFieldName.t * D.typ Marked.pos) list = - try D.StructMap.find s ctx.D.ctx_structs +let find_struct (s : StructName.t) (ctx : decl_ctx) : + (StructFieldName.t * typ Marked.pos) list = + try StructMap.find s ctx.ctx_structs with Not_found -> - let s_name, pos = D.StructName.get_info s in + let s_name, pos = StructName.get_info s in Errors.raise_spanned_error pos "Internal Error: Structure %s was not found in the current environment." s_name -let find_enum (en : D.EnumName.t) (ctx : D.decl_ctx) : - (D.EnumConstructor.t * D.typ Marked.pos) list = - try D.EnumMap.find en ctx.D.ctx_enums +let find_enum (en : EnumName.t) (ctx : decl_ctx) : + (EnumConstructor.t * typ Marked.pos) list = + try EnumMap.find en ctx.ctx_enums with Not_found -> - let en_name, pos = D.EnumName.get_info en in + let en_name, pos = EnumName.get_info en in Errors.raise_spanned_error pos "Internal Error: Enumeration %s was not found in the current environment." en_name let format_lit (fmt : Format.formatter) (l : lit Marked.pos) : unit = match Marked.unmark l with - | LBool b -> Dcalc.Print.format_lit fmt (Dcalc.Ast.LBool b) + | LBool b -> Dcalc.Print.format_lit fmt (LBool b) | LInt i -> Format.fprintf fmt "integer_of_string@ \"%s\"" (Runtime.integer_to_string i) - | LUnit -> Dcalc.Print.format_lit fmt Dcalc.Ast.LUnit + | LUnit -> Dcalc.Print.format_lit fmt LUnit | LRat i -> Format.fprintf fmt "decimal_of_string \"%a\"" Dcalc.Print.format_lit - (Dcalc.Ast.LRat i) + (LRat i) | LMoney e -> Format.fprintf fmt "money_of_cents_string@ \"%s\"" (Runtime.integer_to_string (Runtime.money_to_cents e)) @@ -58,7 +59,7 @@ let format_lit (fmt : Format.formatter) (l : lit Marked.pos) : unit = let years, months, days = Runtime.duration_to_years_months_days d in Format.fprintf fmt "duration_of_numbers (%d) (%d) (%d)" years months days -let format_op_kind (fmt : Format.formatter) (k : Dcalc.Ast.op_kind) = +let format_op_kind (fmt : Format.formatter) (k : op_kind) = Format.fprintf fmt "%s" (match k with | KInt -> "!" @@ -67,7 +68,7 @@ let format_op_kind (fmt : Format.formatter) (k : Dcalc.Ast.op_kind) = | KDate -> "@" | KDuration -> "^") -let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Marked.pos) : +let format_binop (fmt : Format.formatter) (op : binop Marked.pos) : unit = match Marked.unmark op with | Add k -> Format.fprintf fmt "+%a" format_op_kind k @@ -86,7 +87,7 @@ let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Marked.pos) : | Map -> Format.fprintf fmt "Array.map" | Filter -> Format.fprintf fmt "array_filter" -let format_ternop (fmt : Format.formatter) (op : Dcalc.Ast.ternop Marked.pos) : +let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) : unit = match Marked.unmark op with Fold -> Format.fprintf fmt "Array.fold_left" @@ -109,7 +110,7 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit = (Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info))) uids -let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Marked.pos) : unit +let format_unop (fmt : Format.formatter) (op : unop Marked.pos) : unit = match Marked.unmark op with | Minus k -> Format.fprintf fmt "~-%a" format_op_kind k @@ -145,9 +146,9 @@ let avoid_keywords (s : string) : string = s ^ "_user" | _ -> s -let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) : +let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = - Format.asprintf "%a" Dcalc.Ast.StructName.format_t v + Format.asprintf "%a" StructName.format_t v |> to_ascii |> to_snake_case |> avoid_keywords @@ -155,10 +156,10 @@ let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) : let format_to_module_name (fmt : Format.formatter) - (name : [< `Ename of D.EnumName.t | `Sname of D.StructName.t ]) = + (name : [< `Ename of EnumName.t | `Sname of StructName.t ]) = (match name with - | `Ename v -> Format.asprintf "%a" D.EnumName.format_t v - | `Sname v -> Format.asprintf "%a" D.StructName.format_t v) + | `Ename v -> Format.asprintf "%a" EnumName.format_t v + | `Sname v -> Format.asprintf "%a" StructName.format_t v) |> to_ascii |> to_snake_case |> avoid_keywords @@ -170,52 +171,52 @@ let format_to_module_name let format_struct_field_name (fmt : Format.formatter) ((sname_opt, v) : - Dcalc.Ast.StructName.t option * Dcalc.Ast.StructFieldName.t) : unit = + StructName.t option * StructFieldName.t) : unit = (match sname_opt with | Some sname -> Format.fprintf fmt "%a.%s" format_to_module_name (`Sname sname) | None -> Format.fprintf fmt "%s") (avoid_keywords - (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v))) + (to_ascii (Format.asprintf "%a" StructFieldName.format_t v))) -let format_enum_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumName.t) : unit +let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = Format.fprintf fmt "%s" (avoid_keywords (to_snake_case - (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumName.format_t v)))) + (to_ascii (Format.asprintf "%a" EnumName.format_t v)))) let format_enum_cons_name (fmt : Format.formatter) - (v : Dcalc.Ast.EnumConstructor.t) : unit = + (v : EnumConstructor.t) : unit = Format.fprintf fmt "%s" (avoid_keywords - (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v))) + (to_ascii (Format.asprintf "%a" EnumConstructor.format_t v))) -let rec typ_embedding_name (fmt : Format.formatter) (ty : D.typ Marked.pos) : +let rec typ_embedding_name (fmt : Format.formatter) (ty : typ Marked.pos) : unit = match Marked.unmark ty with - | D.TLit D.TUnit -> Format.fprintf fmt "embed_unit" - | D.TLit D.TBool -> Format.fprintf fmt "embed_bool" - | D.TLit D.TInt -> Format.fprintf fmt "embed_integer" - | D.TLit D.TRat -> Format.fprintf fmt "embed_decimal" - | D.TLit D.TMoney -> Format.fprintf fmt "embed_money" - | D.TLit D.TDate -> Format.fprintf fmt "embed_date" - | D.TLit D.TDuration -> Format.fprintf fmt "embed_duration" - | D.TTuple (_, Some s_name) -> + | TLit TUnit -> Format.fprintf fmt "embed_unit" + | TLit TBool -> Format.fprintf fmt "embed_bool" + | TLit TInt -> Format.fprintf fmt "embed_integer" + | TLit TRat -> Format.fprintf fmt "embed_decimal" + | TLit TMoney -> Format.fprintf fmt "embed_money" + | TLit TDate -> Format.fprintf fmt "embed_date" + | TLit TDuration -> Format.fprintf fmt "embed_duration" + | TTuple (_, Some s_name) -> Format.fprintf fmt "embed_%a" format_struct_name s_name - | D.TEnum (_, e_name) -> Format.fprintf fmt "embed_%a" format_enum_name e_name - | D.TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty + | TEnum (_, e_name) -> Format.fprintf fmt "embed_%a" format_enum_name e_name + | TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty | _ -> Format.fprintf fmt "unembeddable" -let typ_needs_parens (e : Dcalc.Ast.typ Marked.pos) : bool = +let typ_needs_parens (e : typ Marked.pos) : bool = match Marked.unmark e with TArrow _ | TArray _ -> true | _ -> false -let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Marked.pos) : +let rec format_typ (fmt : Format.formatter) (typ : typ Marked.pos) : unit = let format_typ_with_parens (fmt : Format.formatter) - (t : Dcalc.Ast.typ Marked.pos) = + (t : typ Marked.pos) = if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t else Format.fprintf fmt "%a" format_typ t in @@ -229,10 +230,10 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Marked.pos) : ts | TTuple (_, Some s) -> Format.fprintf fmt "%a.t" format_to_module_name (`Sname s) - | TEnum ([t], e) when D.EnumName.compare e Ast.option_enum = 0 -> + | TEnum ([t], e) when EnumName.compare e Ast.option_enum = 0 -> Format.fprintf fmt "@[(%a)@] %a" format_typ_with_parens t format_enum_name e - | TEnum (_, e) when D.EnumName.compare e Ast.option_enum = 0 -> + | TEnum (_, e) when EnumName.compare e Ast.option_enum = 0 -> Errors.raise_spanned_error (Marked.get_mark typ) "Internal Error: found an typing parameter for an eoption type of the \ wrong length." @@ -290,7 +291,7 @@ let format_exception (fmt : Format.formatter) (exc : except Marked.pos) : unit = (Pos.get_law_info pos) let rec format_expr - (ctx : Dcalc.Ast.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm marked_expr) : unit = let format_expr = format_expr ctx in @@ -360,7 +361,7 @@ let rec format_expr (* should not happen *)) e)) (List.combine es (List.map fst (find_enum e_name ctx))) - | ELit l -> Format.fprintf fmt "%a" format_lit (Marked.mark (D.pos e) l) + | ELit l -> Format.fprintf fmt "%a" format_lit (Marked.mark (Expr.pos e) l) | EApp ((EAbs (binder, taus), _), args) -> let xs, body = Bindlib.unmbind binder in let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in @@ -382,35 +383,35 @@ let rec format_expr Format.fprintf fmt "@[(%a:@ %a)@]" format_var x format_typ tau)) xs_tau format_expr body | EApp - ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) + ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) -> Format.fprintf fmt "@[%a@ %a@ %a@]" format_binop (op, Pos.no_pos) format_with_parens arg1 format_with_parens arg2 | EApp ((EOp (Binop op), _), [arg1; arg2]) -> Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 format_binop (op, Pos.no_pos) format_with_parens arg2 - | EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [f]), _), [arg]) + | EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg]) when !Cli.trace_flag -> Format.fprintf fmt "(log_begin_call@ %a@ %a)@ %a" format_uid_list info format_with_parens f format_with_parens arg - | EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [arg1]) + | EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1]) when !Cli.trace_flag -> Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list info typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1 - | EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), m), [arg1]) + | EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), m), [arg1]) when !Cli.trace_flag -> - let pos = D.mark_pos m in + let pos = Expr.mark_pos m in Format.fprintf fmt "(log_decision_taken@ @[{filename = \"%s\";@ start_line=%d;@ \ start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a)" (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos) format_with_parens arg1 - | EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [arg1]) + | EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1]) when !Cli.trace_flag -> Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info format_with_parens arg1 - | EApp ((EOp (Unop (D.Log _)), _), [arg1]) -> + | EApp ((EOp (Unop (Log _)), _), [arg1]) -> Format.fprintf fmt "%a" format_with_parens arg1 | EApp ((EOp (Unop op), _), [arg1]) -> Format.fprintf fmt "@[%a@ %a@]" format_unop (op, Pos.no_pos) @@ -422,13 +423,13 @@ let rec format_expr "@[%a@ @[{filename = \"%s\";@ start_line=%d;@ \ start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a@]" format_var x - (Pos.get_file (D.mark_pos pos)) - (Pos.get_start_line (D.mark_pos pos)) - (Pos.get_start_column (D.mark_pos pos)) - (Pos.get_end_line (D.mark_pos pos)) - (Pos.get_end_column (D.mark_pos pos)) + (Pos.get_file (Expr.mark_pos pos)) + (Pos.get_start_line (Expr.mark_pos pos)) + (Pos.get_start_column (Expr.mark_pos pos)) + (Pos.get_end_line (Expr.mark_pos pos)) + (Pos.get_end_column (Expr.mark_pos pos)) format_string_list - (Pos.get_law_info (D.mark_pos pos)) + (Pos.get_law_info (Expr.mark_pos pos)) (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") format_with_parens) @@ -452,25 +453,25 @@ let rec format_expr 2>{filename = \"%s\";@ start_line=%d;@ start_column=%d;@ end_line=%d; \ end_column=%d;@ law_headings=%a}@])@]" format_with_parens e' - (Pos.get_file (D.pos e')) - (Pos.get_start_line (D.pos e')) - (Pos.get_start_column (D.pos e')) - (Pos.get_end_line (D.pos e')) - (Pos.get_end_column (D.pos e')) + (Pos.get_file (Expr.pos e')) + (Pos.get_start_line (Expr.pos e')) + (Pos.get_start_column (Expr.pos e')) + (Pos.get_end_line (Expr.pos e')) + (Pos.get_end_column (Expr.pos e')) format_string_list - (Pos.get_law_info (D.pos e')) - | ERaise exc -> Format.fprintf fmt "raise@ %a" format_exception (exc, D.pos e) + (Pos.get_law_info (Expr.pos e')) + | ERaise exc -> Format.fprintf fmt "raise@ %a" format_exception (exc, Expr.pos e) | ECatch (e1, exc, e2) -> Format.fprintf fmt "@,@[@[try@ %a@]@ with@]@ @[%a@ ->@ %a@]" format_with_parens e1 format_exception - (exc, D.pos e) + (exc, Expr.pos e) format_with_parens e2 let format_struct_embedding (fmt : Format.formatter) ((struct_name, struct_fields) : - D.StructName.t * (D.StructFieldName.t * D.typ Marked.pos) list) = + StructName.t * (StructFieldName.t * typ Marked.pos) list) = if List.length struct_fields = 0 then Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" format_struct_name struct_name format_to_module_name (`Sname struct_name) @@ -480,11 +481,11 @@ let format_struct_embedding @[[%a]@])@]@\n\ @\n" format_struct_name struct_name format_to_module_name (`Sname struct_name) - D.StructName.format_t struct_name + StructName.format_t struct_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") (fun _fmt (struct_field, struct_field_type) -> - Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" D.StructFieldName.format_t + Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructFieldName.format_t struct_field typ_embedding_name struct_field_type format_struct_field_name (Some struct_name, struct_field))) @@ -493,7 +494,7 @@ let format_struct_embedding let format_enum_embedding (fmt : Format.formatter) ((enum_name, enum_cases) : - D.EnumName.t * (D.EnumConstructor.t * D.typ Marked.pos) list) = + EnumName.t * (EnumConstructor.t * typ Marked.pos) list) = if List.length enum_cases = 0 then Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" format_to_module_name (`Ename enum_name) format_enum_name enum_name @@ -503,19 +504,19 @@ let format_enum_embedding =@]@ Enum([\"%a\"],@ @[match x with@ %a@])@]@\n\ @\n" format_enum_name enum_name format_to_module_name (`Ename enum_name) - D.EnumName.format_t enum_name + EnumName.format_t enum_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun _fmt (enum_cons, enum_cons_type) -> Format.fprintf fmt "@[| %a x ->@ (\"%a\", %a x)@]" - format_enum_cons_name enum_cons D.EnumConstructor.format_t + format_enum_cons_name enum_cons EnumConstructor.format_t enum_cons typ_embedding_name enum_cons_type)) enum_cases let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Format.formatter) - (ctx : D.decl_ctx) : unit = + (ctx : decl_ctx) : unit = let format_struct_decl fmt (struct_name, struct_fields) = if List.length struct_fields = 0 then Format.fprintf fmt @@ -559,8 +560,8 @@ let format_ctx let scope_structs = List.map (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) - (Dcalc.Ast.StructMap.bindings - (Dcalc.Ast.StructMap.filter + (StructMap.bindings + (StructMap.filter (fun s _ -> not (is_in_type_ordering s)) ctx.ctx_structs)) in @@ -574,12 +575,12 @@ let format_ctx (type_ordering @ scope_structs) let rec format_scope_body_expr - (ctx : Dcalc.Ast.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) - (scope_lets : ('m Ast.expr, 'm) Dcalc.Ast.scope_body_expr) : unit = + (scope_lets : ('m Ast.expr, 'm) scope_body_expr) : unit = match scope_lets with - | Dcalc.Ast.Result e -> format_expr ctx fmt e - | Dcalc.Ast.ScopeLet scope_let -> + | Result e -> format_expr ctx fmt e + | ScopeLet scope_let -> let scope_let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next in @@ -590,12 +591,12 @@ let rec format_scope_body_expr scope_let_next let rec format_scopes - (ctx : Dcalc.Ast.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) - (scopes : ('m Ast.expr, 'm) Dcalc.Ast.scopes) : unit = + (scopes : ('m Ast.expr, 'm) scopes) : unit = match scopes with - | Dcalc.Ast.Nil -> () - | Dcalc.Ast.ScopeDef scope_def -> + | Nil -> () + | ScopeDef scope_def -> let scope_input_var, scope_body_expr = Bindlib.unbind scope_def.scope_body.scope_body_expr in diff --git a/compiler/lcalc/to_ocaml.mli b/compiler/lcalc/to_ocaml.mli index f64ca373..b5d75e3d 100644 --- a/compiler/lcalc/to_ocaml.mli +++ b/compiler/lcalc/to_ocaml.mli @@ -15,6 +15,7 @@ the License. *) open Utils +open Shared_ast open Ast (** Formats a lambda calculus program into a valid OCaml program *) @@ -22,32 +23,32 @@ open Ast val avoid_keywords : string -> string val find_struct : - Dcalc.Ast.StructName.t -> - Dcalc.Ast.decl_ctx -> - (Dcalc.Ast.StructFieldName.t * Dcalc.Ast.typ Marked.pos) list + StructName.t -> + decl_ctx -> + (StructFieldName.t * typ Marked.pos) list val find_enum : - Dcalc.Ast.EnumName.t -> - Dcalc.Ast.decl_ctx -> - (Dcalc.Ast.EnumConstructor.t * Dcalc.Ast.typ Marked.pos) list + EnumName.t -> + decl_ctx -> + (EnumConstructor.t * typ Marked.pos) list -val typ_needs_parens : Dcalc.Ast.typ Marked.pos -> bool +val typ_needs_parens : typ Marked.pos -> bool val needs_parens : 'm marked_expr -> bool -val format_enum_name : Format.formatter -> Dcalc.Ast.EnumName.t -> unit +val format_enum_name : Format.formatter -> EnumName.t -> unit val format_enum_cons_name : - Format.formatter -> Dcalc.Ast.EnumConstructor.t -> unit + Format.formatter -> EnumConstructor.t -> unit -val format_struct_name : Format.formatter -> Dcalc.Ast.StructName.t -> unit +val format_struct_name : Format.formatter -> StructName.t -> unit val format_struct_field_name : Format.formatter -> - Dcalc.Ast.StructName.t option * Dcalc.Ast.StructFieldName.t -> + StructName.t option * StructFieldName.t -> unit val format_to_module_name : Format.formatter -> - [< `Ename of Dcalc.Ast.EnumName.t | `Sname of Dcalc.Ast.StructName.t ] -> + [< `Ename of EnumName.t | `Sname of StructName.t ] -> unit val format_lit : Format.formatter -> lit Marked.pos -> unit diff --git a/compiler/plugin.ml b/compiler/plugin.ml index 9dfaf6f5..1378c2da 100644 --- a/compiler/plugin.ml +++ b/compiler/plugin.ml @@ -29,7 +29,7 @@ type 'ast gen = { } type t = - | Lcalc of Dcalc.Ast.untyped Lcalc.Ast.program gen + | Lcalc of Shared_ast.untyped Lcalc.Ast.program gen | Scalc of Scalc.Ast.program gen let name = function Lcalc { name; _ } | Scalc { name; _ } -> name diff --git a/compiler/plugin.mli b/compiler/plugin.mli index 0eb30c17..b6d678f9 100644 --- a/compiler/plugin.mli +++ b/compiler/plugin.mli @@ -31,7 +31,7 @@ type 'ast gen = { } type t = - | Lcalc of Dcalc.Ast.untyped Lcalc.Ast.program gen + | Lcalc of Shared_ast.untyped Lcalc.Ast.program gen | Scalc of Scalc.Ast.program gen val find : string -> t @@ -49,7 +49,7 @@ module PluginAPI : sig val register_lcalc : name:string -> extension:string -> - Dcalc.Ast.untyped Lcalc.Ast.program plugin_apply_fun_typ -> + Shared_ast.untyped Lcalc.Ast.program plugin_apply_fun_typ -> unit val register_scalc : diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index ee9c784e..e58edc63 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -19,6 +19,7 @@ the associated [js_of_ocaml] wrapper. *) open Utils +open Shared_ast open String_common open Lcalc open Lcalc.Ast @@ -39,9 +40,9 @@ module To_jsoo = struct let format_struct_field_name_camel_case (fmt : Format.formatter) - (v : Dcalc.Ast.StructFieldName.t) : unit = + (v : StructFieldName.t) : unit = let s = - Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v + Format.asprintf "%a" StructFieldName.format_t v |> to_ascii |> to_snake_case |> avoid_keywords @@ -49,7 +50,7 @@ module To_jsoo = struct in Format.fprintf fmt "%s" s - let format_tlit (fmt : Format.formatter) (l : Dcalc.Ast.typ_lit) : unit = + let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit = Dcalc.Print.format_base_type fmt (match l with | TUnit -> "unit" @@ -59,11 +60,11 @@ module To_jsoo = struct | TBool -> "bool Js.t" | TDate -> "Js.js_string Js.t") - let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Marked.pos) : + let rec format_typ (fmt : Format.formatter) (typ : typ Marked.pos) : unit = let format_typ_with_parens (fmt : Format.formatter) - (t : Dcalc.Ast.typ Marked.pos) = + (t : typ Marked.pos) = if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t else Format.fprintf fmt "%a" format_typ t in @@ -73,10 +74,10 @@ module To_jsoo = struct | TTuple (_, None) -> (* Tuples are encoded as an javascript polymorphic array. *) Format.fprintf fmt "Js.Unsafe.any_js_array Js.t " - | TEnum ([t], e) when D.EnumName.compare e option_enum = 0 -> + | TEnum ([t], e) when EnumName.compare e option_enum = 0 -> Format.fprintf fmt "@[(%a)@] %a" format_typ_with_parens t format_enum_name e - | TEnum (_, e) when D.EnumName.compare e option_enum = 0 -> + | TEnum (_, e) when EnumName.compare e option_enum = 0 -> Errors.raise_spanned_error (Marked.get_mark typ) "Internal Error: found an typing parameter for an eoption type of the \ wrong length." @@ -90,41 +91,41 @@ module To_jsoo = struct let rec format_typ_to_jsoo fmt typ = match Marked.unmark typ with - | Dcalc.Ast.TLit TBool -> Format.fprintf fmt "Js.bool" - | Dcalc.Ast.TLit TInt -> Format.fprintf fmt "integer_to_int" - | Dcalc.Ast.TLit TRat -> + | TLit TBool -> Format.fprintf fmt "Js.bool" + | TLit TInt -> Format.fprintf fmt "integer_to_int" + | TLit TRat -> Format.fprintf fmt "Js.number_of_float %@%@ decimal_to_float" - | Dcalc.Ast.TLit TMoney -> + | TLit TMoney -> Format.fprintf fmt "Js.number_of_float %@%@ money_to_float" - | Dcalc.Ast.TLit TDuration -> Format.fprintf fmt "duration_to_jsoo" - | Dcalc.Ast.TLit TDate -> Format.fprintf fmt "date_to_jsoo" - | Dcalc.Ast.TEnum (_, ename) -> + | TLit TDuration -> Format.fprintf fmt "duration_to_jsoo" + | TLit TDate -> Format.fprintf fmt "date_to_jsoo" + | TEnum (_, ename) -> Format.fprintf fmt "%a_to_jsoo" format_enum_name ename - | Dcalc.Ast.TTuple (_, Some sname) -> + | TTuple (_, Some sname) -> Format.fprintf fmt "%a_to_jsoo" format_struct_name sname - | Dcalc.Ast.TArray t -> + | TArray t -> Format.fprintf fmt "Js.array %@%@ Array.map (fun x -> %a x)" format_typ_to_jsoo t - | Dcalc.Ast.TAny | Dcalc.Ast.TTuple (_, None) -> + | TAny | TTuple (_, None) -> Format.fprintf fmt "Js.Unsafe.inject" | _ -> Format.fprintf fmt "" let rec format_typ_of_jsoo fmt typ = match Marked.unmark typ with - | Dcalc.Ast.TLit TBool -> Format.fprintf fmt "Js.to_bool" - | Dcalc.Ast.TLit TInt -> Format.fprintf fmt "integer_of_int" - | Dcalc.Ast.TLit TRat -> + | TLit TBool -> Format.fprintf fmt "Js.to_bool" + | TLit TInt -> Format.fprintf fmt "integer_of_int" + | TLit TRat -> Format.fprintf fmt "decimal_of_float %@%@ Js.float_of_number" - | Dcalc.Ast.TLit TMoney -> + | TLit TMoney -> Format.fprintf fmt "money_of_decimal %@%@ decimal_of_float %@%@ Js.float_of_number" - | Dcalc.Ast.TLit TDuration -> Format.fprintf fmt "duration_of_jsoo" - | Dcalc.Ast.TLit TDate -> Format.fprintf fmt "date_of_jsoo" - | Dcalc.Ast.TEnum (_, ename) -> + | TLit TDuration -> Format.fprintf fmt "duration_of_jsoo" + | TLit TDate -> Format.fprintf fmt "date_of_jsoo" + | TEnum (_, ename) -> Format.fprintf fmt "%a_of_jsoo" format_enum_name ename - | Dcalc.Ast.TTuple (_, Some sname) -> + | TTuple (_, Some sname) -> Format.fprintf fmt "%a_of_jsoo" format_struct_name sname - | Dcalc.Ast.TArray t -> + | TArray t -> Format.fprintf fmt "Array.map (fun x -> %a x) %@%@ Js.to_array" format_typ_of_jsoo t | _ -> Format.fprintf fmt "" @@ -150,10 +151,10 @@ module To_jsoo = struct let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Format.formatter) - (ctx : D.decl_ctx) : unit = - let format_prop_or_meth fmt (struct_field_type : D.typ Marked.pos) = + (ctx : decl_ctx) : unit = + let format_prop_or_meth fmt (struct_field_type : typ Marked.pos) = match Marked.unmark struct_field_type with - | Dcalc.Ast.TArrow _ -> Format.fprintf fmt "Js.meth" + | TArrow _ -> Format.fprintf fmt "Js.meth" | _ -> Format.fprintf fmt "Js.readonly_prop" in let format_struct_decl fmt (struct_name, struct_fields) = @@ -167,7 +168,7 @@ module To_jsoo = struct ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (struct_field, struct_field_type) -> match Marked.unmark struct_field_type with - | Dcalc.Ast.TArrow (t1, t2) -> + | TArrow (t1, t2) -> Format.fprintf fmt "@[method %a =@ Js.wrap_meth_callback@ @[(@,\ fun input ->@ %a (%a.%a (%a input)))@]@]" @@ -188,7 +189,7 @@ module To_jsoo = struct ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") (fun fmt (struct_field, struct_field_type) -> match Marked.unmark struct_field_type with - | Dcalc.Ast.TArrow _ -> + | TArrow _ -> Format.fprintf fmt "%a = failwith \"The function '%a' translation isn't yet \ supported...\"" @@ -238,7 +239,7 @@ module To_jsoo = struct in let format_enum_decl fmt - (enum_name, (enum_cons : (D.EnumConstructor.t * D.typ Marked.pos) list)) + (enum_name, (enum_cons : (EnumConstructor.t * typ Marked.pos) list)) = let fmt_enum_name fmt _ = format_enum_name fmt enum_name in let fmt_module_enum_name fmt _ = @@ -250,7 +251,7 @@ module To_jsoo = struct ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (cname, typ) -> match Marked.unmark typ with - | Dcalc.Ast.TTuple (_, None) -> + | TTuple (_, None) -> Cli.error_print "Tuples aren't supported yet in the conversion to JS" | _ -> @@ -275,10 +276,10 @@ module To_jsoo = struct ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (cname, typ) -> match Marked.unmark typ with - | Dcalc.Ast.TTuple (_, None) -> + | TTuple (_, None) -> Cli.error_print "Tuples aren't yet supported in the conversion to JS..." - | Dcalc.Ast.TLit TUnit -> + | TLit TUnit -> Format.fprintf fmt "@[| \"%a\" ->@ %a.%a ()@]" format_enum_cons_name cname fmt_module_enum_name () format_enum_cons_name cname @@ -329,8 +330,8 @@ module To_jsoo = struct let scope_structs = List.map (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) - (Dcalc.Ast.StructMap.bindings - (Dcalc.Ast.StructMap.filter + (StructMap.bindings + (StructMap.filter (fun s _ -> not (is_in_type_ordering s)) ctx.ctx_structs)) in @@ -343,19 +344,19 @@ module To_jsoo = struct Format.fprintf fmt "%a@\n" format_enum_decl (e, find_enum e ctx)) (type_ordering @ scope_structs) - let fmt_input_struct_name fmt (scope_def : ('a expr, 'm) D.scope_def) = + let fmt_input_struct_name fmt (scope_def : ('a expr, 'm) scope_def) = format_struct_name fmt scope_def.scope_body.scope_body_input_struct - let fmt_output_struct_name fmt (scope_def : ('a expr, 'm) D.scope_def) = + let fmt_output_struct_name fmt (scope_def : ('a expr, 'm) scope_def) = format_struct_name fmt scope_def.scope_body.scope_body_output_struct let rec format_scopes_to_fun - (ctx : Dcalc.Ast.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) - (scopes : ('expr, 'm) Dcalc.Ast.scopes) = + (scopes : ('expr, 'm) scopes) = match scopes with - | Dcalc.Ast.Nil -> () - | Dcalc.Ast.ScopeDef scope_def -> + | Nil -> () + | ScopeDef scope_def -> let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in let fmt_fun_call fmt _ = Format.fprintf fmt "@[%a@ |> %a_of_jsoo@ |> %a@ |> %a_to_jsoo@]" @@ -369,12 +370,12 @@ module To_jsoo = struct fmt_fun_call () (format_scopes_to_fun ctx) scope_next let rec format_scopes_to_callbacks - (ctx : Dcalc.Ast.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) - (scopes : ('expr, 'm) Dcalc.Ast.scopes) : unit = + (scopes : ('expr, 'm) scopes) : unit = match scopes with - | Dcalc.Ast.Nil -> () - | Dcalc.Ast.ScopeDef scope_def -> + | Nil -> () + | ScopeDef scope_def -> let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in let fmt_meth_name fmt _ = Format.fprintf fmt "method %a : (%a Js.t -> %a Js.t) Js.callback" diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index 63624cac..240fbc41 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -22,6 +22,7 @@ let extension = "_schema.json" open Utils open String_common +open Shared_ast open Lcalc.Ast open Lcalc.To_ocaml module D = Dcalc.Ast @@ -37,9 +38,9 @@ module To_json = struct let format_struct_field_name_camel_case (fmt : Format.formatter) - (v : Dcalc.Ast.StructFieldName.t) : unit = + (v : StructFieldName.t) : unit = let s = - Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v + Format.asprintf "%a" StructFieldName.format_t v |> to_ascii |> to_snake_case |> avoid_keywords @@ -48,18 +49,18 @@ module To_json = struct Format.fprintf fmt "%s" s let rec find_scope_def (target_name : string) : - ('m expr, 'm) D.scopes -> ('m expr, 'm) D.scope_def option = function - | D.Nil -> None - | D.ScopeDef scope_def -> + ('m expr, 'm) scopes -> ('m expr, 'm) scope_def option = function + | Nil -> None + | ScopeDef scope_def -> let name = - Format.asprintf "%a" D.ScopeName.format_t scope_def.scope_name + Format.asprintf "%a" ScopeName.format_t scope_def.scope_name in if name = target_name then Some scope_def else let _, next_scope = Bindlib.unbind scope_def.scope_next in find_scope_def target_name next_scope - let fmt_tlit fmt (tlit : D.typ_lit) = + let fmt_tlit fmt (tlit : typ_lit) = match tlit with | TUnit -> Format.fprintf fmt "\"type\": \"null\",@\n\"default\": null" | TInt | TRat -> Format.fprintf fmt "\"type\": \"number\",@\n\"default\": 0" @@ -70,15 +71,15 @@ module To_json = struct | TDate -> Format.fprintf fmt "\"type\": \"string\",@\n\"format\": \"date\"" | TDuration -> failwith "TODO: tlit duration" - let rec fmt_type fmt (typ : D.marked_typ) = + let rec fmt_type fmt (typ : marked_typ) = match Marked.unmark typ with - | D.TLit tlit -> fmt_tlit fmt tlit - | D.TTuple (_, Some sname) -> + | TLit tlit -> fmt_tlit fmt tlit + | TTuple (_, Some sname) -> Format.fprintf fmt "\"$ref\": \"#/definitions/%a\"" format_struct_name sname - | D.TEnum (_, ename) -> + | TEnum (_, ename) -> Format.fprintf fmt "\"$ref\": \"#/definitions/%a\"" format_enum_name ename - | D.TArray t -> + | TArray t -> Format.fprintf fmt "\"type\": \"array\",@\n\ \"default\": [],@\n\ @@ -89,9 +90,9 @@ module To_json = struct | _ -> () let fmt_struct_properties - (ctx : D.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) - (sname : D.StructName.t) = + (sname : StructName.t) = Format.fprintf fmt "%a" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n") @@ -101,26 +102,26 @@ module To_json = struct (find_struct sname ctx) let fmt_definitions - (ctx : D.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) - (scope_def : ('m expr, 'm) D.scope_def) = + (scope_def : ('m expr, 'm) scope_def) = let get_name t = match Marked.unmark t with - | D.TTuple (_, Some sname) -> + | TTuple (_, Some sname) -> Format.asprintf "%a" format_struct_name sname - | D.TEnum (_, ename) -> Format.asprintf "%a" format_enum_name ename + | TEnum (_, ename) -> Format.asprintf "%a" format_enum_name ename | _ -> failwith "unreachable: only structs and enums are collected." in let rec collect_required_type_defs_from_scope_input - (input_struct : D.StructName.t) : D.marked_typ list = - let rec collect (acc : D.marked_typ list) (t : D.marked_typ) : - D.marked_typ list = + (input_struct : StructName.t) : marked_typ list = + let rec collect (acc : marked_typ list) (t : marked_typ) : + marked_typ list = match Marked.unmark t with - | D.TTuple (_, Some s) -> + | TTuple (_, Some s) -> (* Scope's input is a struct. *) (t :: acc) @ collect_required_type_defs_from_scope_input s - | D.TEnum (ts, _) -> List.fold_left collect (t :: acc) ts - | D.TArray t -> collect acc t + | TEnum (ts, _) -> List.fold_left collect (t :: acc) ts + | TArray t -> collect acc t | _ -> acc in find_struct input_struct ctx @@ -177,7 +178,7 @@ module To_json = struct ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n") (fun fmt typ -> match Marked.unmark typ with - | D.TTuple (_, Some sname) -> + | TTuple (_, Some sname) -> Format.fprintf fmt "@[\"%a\": {@\n\ \"type\": \"object\",@\n\ @@ -188,7 +189,7 @@ module To_json = struct format_struct_name sname (fmt_struct_properties ctx) sname - | D.TEnum (_, ename) -> + | TEnum (_, ename) -> Format.fprintf fmt "@[\"%a\": {@\n\ \"type\": \"object\",@\n\ diff --git a/compiler/scalc/ast.ml b/compiler/scalc/ast.ml index e4f36a34..ead57cd4 100644 --- a/compiler/scalc/ast.ml +++ b/compiler/scalc/ast.ml @@ -15,6 +15,7 @@ the License. *) open Utils +open Shared_ast module D = Dcalc.Ast module L = Lcalc.Ast module TopLevelName = Uid.Make (Uid.MarkedString) () @@ -27,24 +28,24 @@ let handle_default_opt = TopLevelName.fresh ("handle_default_opt", Pos.no_pos) type expr = | EVar of LocalName.t | EFunc of TopLevelName.t - | EStruct of expr Marked.pos list * D.StructName.t - | EStructFieldAccess of expr Marked.pos * D.StructFieldName.t * D.StructName.t - | EInj of expr Marked.pos * D.EnumConstructor.t * D.EnumName.t + | EStruct of expr Marked.pos list * StructName.t + | EStructFieldAccess of expr Marked.pos * StructFieldName.t * StructName.t + | EInj of expr Marked.pos * EnumConstructor.t * EnumName.t | EArray of expr Marked.pos list | ELit of L.lit | EApp of expr Marked.pos * expr Marked.pos list - | EOp of Dcalc.Ast.operator + | EOp of operator type stmt = | SInnerFuncDef of LocalName.t Marked.pos * func - | SLocalDecl of LocalName.t Marked.pos * D.typ Marked.pos + | SLocalDecl of LocalName.t Marked.pos * typ Marked.pos | SLocalDef of LocalName.t Marked.pos * expr Marked.pos - | STryExcept of block * L.except * block - | SRaise of L.except + | STryExcept of block * except * block + | SRaise of except | SIfThenElse of expr Marked.pos * block * block | SSwitch of expr Marked.pos - * D.EnumName.t + * EnumName.t * (block (* Statements corresponding to arm closure body*) * (* Variable instantiated with enum payload *) LocalName.t) list (** Each block corresponds to one case of the enum *) @@ -54,14 +55,14 @@ type stmt = and block = stmt Marked.pos list and func = { - func_params : (LocalName.t Marked.pos * D.typ Marked.pos) list; + func_params : (LocalName.t Marked.pos * typ Marked.pos) list; func_body : block; } type scope_body = { - scope_body_name : Dcalc.Ast.ScopeName.t; + scope_body_name : ScopeName.t; scope_body_var : TopLevelName.t; scope_body_func : func; } -type program = { decl_ctx : D.decl_ctx; scopes : scope_body list } +type program = { decl_ctx : decl_ctx; scopes : scope_body list } diff --git a/compiler/scalc/compile_from_lambda.ml b/compiler/scalc/compile_from_lambda.ml index 51c4d73a..b17821be 100644 --- a/compiler/scalc/compile_from_lambda.ml +++ b/compiler/scalc/compile_from_lambda.ml @@ -22,7 +22,7 @@ module D = Dcalc.Ast type 'm ctxt = { func_dict : ('m L.expr, A.TopLevelName.t) Var.Map.t; - decl_ctx : D.decl_ctx; + decl_ctx : decl_ctx; var_dict : ('m L.expr, A.LocalName.t) Var.Map.t; inside_definition_of : A.LocalName.t option; context_name : string; @@ -33,13 +33,13 @@ type 'm ctxt = { let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) : A.block * A.expr Marked.pos = match Marked.unmark expr with - | L.EVar v -> + | EVar v -> let local_var = try A.EVar (Var.Map.find v ctxt.var_dict) with Not_found -> A.EFunc (Var.Map.find v ctxt.func_dict) in - [], (local_var, D.pos expr) - | L.ETuple (args, Some s_name) -> + [], (local_var, Expr.pos expr) + | ETuple (args, Some s_name) -> let args_stmts, new_args = List.fold_left (fun (args_stmts, new_args) arg -> @@ -49,25 +49,25 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) : in let new_args = List.rev new_args in let args_stmts = List.rev args_stmts in - args_stmts, (A.EStruct (new_args, s_name), D.pos expr) - | L.ETuple (_, None) -> + args_stmts, (A.EStruct (new_args, s_name), Expr.pos expr) + | ETuple (_, None) -> failwith "Non-struct tuples cannot be compiled to scalc" - | L.ETupleAccess (e1, num_field, Some s_name, _) -> + | ETupleAccess (e1, num_field, Some s_name, _) -> let e1_stmts, new_e1 = translate_expr ctxt e1 in let field_name = fst - (List.nth (D.StructMap.find s_name ctxt.decl_ctx.ctx_structs) num_field) + (List.nth (StructMap.find s_name ctxt.decl_ctx.ctx_structs) num_field) in - e1_stmts, (A.EStructFieldAccess (new_e1, field_name, s_name), D.pos expr) - | L.ETupleAccess (_, _, None, _) -> + e1_stmts, (A.EStructFieldAccess (new_e1, field_name, s_name), Expr.pos expr) + | ETupleAccess (_, _, None, _) -> failwith "Non-struct tuples cannot be compiled to scalc" - | L.EInj (e1, num_cons, e_name, _) -> + | EInj (e1, num_cons, e_name, _) -> let e1_stmts, new_e1 = translate_expr ctxt e1 in let cons_name = - fst (List.nth (D.EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons) + fst (List.nth (EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons) in - e1_stmts, (A.EInj (new_e1, cons_name, e_name), D.pos expr) - | L.EApp (f, args) -> + e1_stmts, (A.EInj (new_e1, cons_name, e_name), Expr.pos expr) + | EApp (f, args) -> let f_stmts, new_f = translate_expr ctxt f in let args_stmts, new_args = List.fold_left @@ -77,8 +77,8 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) : ([], []) args in let new_args = List.rev new_args in - f_stmts @ args_stmts, (A.EApp (new_f, new_args), D.pos expr) - | L.EArray args -> + f_stmts @ args_stmts, (A.EApp (new_f, new_args), Expr.pos expr) + | EArray args -> let args_stmts, new_args = List.fold_left (fun (args_stmts, new_args) arg -> @@ -87,9 +87,9 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) : ([], []) args in let new_args = List.rev new_args in - args_stmts, (A.EArray new_args, D.pos expr) - | L.EOp op -> [], (A.EOp op, D.pos expr) - | L.ELit l -> [], (A.ELit l, D.pos expr) + args_stmts, (A.EArray new_args, Expr.pos expr) + | EOp op -> [], (A.EOp op, Expr.pos expr) + | ELit l -> [], (A.ELit l, Expr.pos expr) | _ -> let tmp_var = A.LocalName.fresh @@ -102,7 +102,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) : let v = Marked.unmark (A.LocalName.get_info v) in let tmp_rex = Re.Pcre.regexp "^temp_" in if Re.Pcre.pmatch ~rex:tmp_rex v then v else "temp_" ^ v), - D.pos expr ) + Expr.pos expr ) in let ctxt = { @@ -112,20 +112,20 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.marked_expr) : } in let tmp_stmts = translate_statements ctxt expr in - ( (A.SLocalDecl ((tmp_var, D.pos expr), (D.TAny, D.pos expr)), D.pos expr) + ( (A.SLocalDecl ((tmp_var, Expr.pos expr), (TAny, Expr.pos expr)), Expr.pos expr) :: tmp_stmts, - (A.EVar tmp_var, D.pos expr) ) + (A.EVar tmp_var, Expr.pos expr) ) and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) : A.block = match Marked.unmark block_expr with - | L.EAssert e -> + | EAssert e -> (* Assertions are always encapsulated in a unit-typed let binding *) let e_stmts, new_e = translate_expr ctxt e in - e_stmts @ [A.SAssert (Marked.unmark new_e), D.pos block_expr] - | L.EApp ((L.EAbs (binder, taus), binder_mark), args) -> + e_stmts @ [A.SAssert (Marked.unmark new_e), Expr.pos block_expr] + | EApp ((EAbs (binder, taus), binder_mark), args) -> (* This defines multiple local variables at the time *) - let binder_pos = D.mark_pos binder_mark in + let binder_pos = Expr.mark_pos binder_mark in let vars, body = Bindlib.unmbind binder in let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in let ctxt = @@ -170,13 +170,13 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) : in let rest_of_block = translate_statements ctxt body in local_decls @ List.flatten def_blocks @ rest_of_block - | L.EAbs (binder, taus) -> + | EAbs (binder, taus) -> let vars, body = Bindlib.unmbind binder in - let binder_pos = D.pos block_expr in + let binder_pos = Expr.pos block_expr in let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in let closure_name = match ctxt.inside_definition_of with - | None -> A.LocalName.fresh (ctxt.context_name, D.pos block_expr) + | None -> A.LocalName.fresh (ctxt.context_name, Expr.pos block_expr) | Some x -> x in let ctxt = @@ -206,18 +206,18 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) : } ), binder_pos ); ] - | L.EMatch (e1, args, e_name) -> + | EMatch (e1, args, e_name) -> let e1_stmts, new_e1 = translate_expr ctxt e1 in let new_args = List.fold_left (fun new_args arg -> match Marked.unmark arg with - | L.EAbs (binder, _) -> + | EAbs (binder, _) -> let vars, body = Bindlib.unmbind binder in assert (Array.length vars = 1); let var = vars.(0) in let scalc_var = - A.LocalName.fresh (Bindlib.name_of var, D.pos arg) + A.LocalName.fresh (Bindlib.name_of var, Expr.pos arg) in let ctxt = { ctxt with var_dict = Var.Map.add var scalc_var ctxt.var_dict } @@ -229,17 +229,17 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) : [] args in let new_args = List.rev new_args in - e1_stmts @ [A.SSwitch (new_e1, e_name, new_args), D.pos block_expr] - | L.EIfThenElse (cond, e_true, e_false) -> + e1_stmts @ [A.SSwitch (new_e1, e_name, new_args), Expr.pos block_expr] + | EIfThenElse (cond, e_true, e_false) -> let cond_stmts, s_cond = translate_expr ctxt cond in let s_e_true = translate_statements ctxt e_true in let s_e_false = translate_statements ctxt e_false in - cond_stmts @ [A.SIfThenElse (s_cond, s_e_true, s_e_false), D.pos block_expr] - | L.ECatch (e_try, except, e_catch) -> + cond_stmts @ [A.SIfThenElse (s_cond, s_e_true, s_e_false), Expr.pos block_expr] + | ECatch (e_try, except, e_catch) -> let s_e_try = translate_statements ctxt e_try in let s_e_catch = translate_statements ctxt e_catch in - [A.STryExcept (s_e_try, except, s_e_catch), D.pos block_expr] - | L.ERaise except -> + [A.STryExcept (s_e_try, except, s_e_catch), Expr.pos block_expr] + | ERaise except -> (* Before raising the exception, we still give a dummy definition to the current variable so that tools like mypy don't complain. *) (match ctxt.inside_definition_of with @@ -247,10 +247,10 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) : | Some x -> [ ( A.SLocalDef - ((x, D.pos block_expr), (Ast.EVar Ast.dead_value, D.pos block_expr)), - D.pos block_expr ); + ((x, Expr.pos block_expr), (Ast.EVar Ast.dead_value, Expr.pos block_expr)), + Expr.pos block_expr ); ]) - @ [A.SRaise except, D.pos block_expr] + @ [A.SRaise except, Expr.pos block_expr] | _ -> ( let e_stmts, new_e = translate_expr ctxt block_expr in e_stmts @@ -266,15 +266,15 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.marked_expr) : ( (match ctxt.inside_definition_of with | None -> A.SReturn (Marked.unmark new_e) | Some x -> A.SLocalDef (Marked.same_mark_as x new_e, new_e)), - D.pos block_expr ); + Expr.pos block_expr ); ]) let rec translate_scope_body_expr - (scope_name : D.ScopeName.t) - (decl_ctx : D.decl_ctx) + (scope_name : ScopeName.t) + (decl_ctx : decl_ctx) (var_dict : ('m L.expr, A.LocalName.t) Var.Map.t) (func_dict : ('m L.expr, A.TopLevelName.t) Var.Map.t) - (scope_expr : ('m L.expr, 'm) D.scope_body_expr) : A.block = + (scope_expr : ('m L.expr, 'm) scope_body_expr) : A.block = match scope_expr with | Result e -> let block, new_e = @@ -284,7 +284,7 @@ let rec translate_scope_body_expr func_dict; var_dict; inside_definition_of = None; - context_name = Marked.unmark (D.ScopeName.get_info scope_name); + context_name = Marked.unmark (ScopeName.get_info scope_name); } e in @@ -296,14 +296,14 @@ let rec translate_scope_body_expr in let new_var_dict = Var.Map.add let_var let_var_id var_dict in (match scope_let.scope_let_kind with - | D.Assertion -> + | Assertion -> translate_statements { decl_ctx; func_dict; var_dict; inside_definition_of = Some let_var_id; - context_name = Marked.unmark (D.ScopeName.get_info scope_name); + context_name = Marked.unmark (ScopeName.get_info scope_name); } scope_let.scope_let_expr | _ -> @@ -314,7 +314,7 @@ let rec translate_scope_body_expr func_dict; var_dict; inside_definition_of = Some let_var_id; - context_name = Marked.unmark (D.ScopeName.get_info scope_name); + context_name = Marked.unmark (ScopeName.get_info scope_name); } scope_let.scope_let_expr in @@ -331,16 +331,16 @@ let rec translate_scope_body_expr let translate_program (p : 'm L.program) : A.program = { - decl_ctx = p.D.decl_ctx; + decl_ctx = p.decl_ctx; scopes = (let _, new_scopes = - D.fold_left_scope_defs + Expr.fold_left_scope_defs ~f:(fun (func_dict, new_scopes) scope_def scope_var -> let scope_input_var, scope_body_expr = Bindlib.unbind scope_def.scope_body.scope_body_expr in let input_pos = - Marked.get_mark (D.ScopeName.get_info scope_def.scope_name) + Marked.get_mark (ScopeName.get_info scope_def.scope_name) in let scope_input_var_id = A.LocalName.fresh (Bindlib.name_of scope_input_var, input_pos) @@ -349,7 +349,7 @@ let translate_program (p : 'm L.program) : A.program = Var.Map.singleton scope_input_var scope_input_var_id in let new_scope_body = - translate_scope_body_expr scope_def.D.scope_name p.decl_ctx + translate_scope_body_expr scope_def.scope_name p.decl_ctx var_dict func_dict scope_body_expr in let func_id = @@ -358,22 +358,22 @@ let translate_program (p : 'm L.program) : A.program = let func_dict = Var.Map.add scope_var func_id func_dict in ( func_dict, { - Ast.scope_body_name = scope_def.D.scope_name; + Ast.scope_body_name = scope_def.scope_name; Ast.scope_body_var = func_id; scope_body_func = { A.func_params = [ ( (scope_input_var_id, input_pos), - ( D.TTuple + ( TTuple ( List.map snd - (D.StructMap.find - scope_def.D.scope_body - .D.scope_body_input_struct - p.D.decl_ctx.ctx_structs), + (StructMap.find + scope_def.scope_body + .scope_body_input_struct + p.decl_ctx.ctx_structs), Some - scope_def.D.scope_body - .D.scope_body_input_struct ), + scope_def.scope_body + .scope_body_input_struct ), input_pos ) ); ]; A.func_body = new_scope_body; @@ -385,7 +385,7 @@ let translate_program (p : 'm L.program) : A.program = Var.Map.singleton L.handle_default_opt A.handle_default_opt else Var.Map.singleton L.handle_default A.handle_default), [] ) - p.D.scopes + p.scopes in List.rev new_scopes); } diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index 44ff6487..052e8315 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -15,6 +15,7 @@ the License. *) open Utils +open Shared_ast open Ast let needs_parens (_e : expr Marked.pos) : bool = false @@ -24,7 +25,7 @@ let format_local_name (fmt : Format.formatter) (v : LocalName.t) : unit = (string_of_int (LocalName.hash v)) let rec format_expr - (decl_ctx : Dcalc.Ast.decl_ctx) + (decl_ctx : decl_ctx) ?(debug : bool = false) (fmt : Format.formatter) (e : expr Marked.pos) : unit = @@ -39,17 +40,17 @@ let rec format_expr | EVar v -> Format.fprintf fmt "%a" format_local_name v | EFunc v -> Format.fprintf fmt "%a" TopLevelName.format_t v | EStruct (es, s) -> - Format.fprintf fmt "@[%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s + Format.fprintf fmt "@[%a@ %a%a%a@]" StructName.format_t s Dcalc.Print.format_punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (e, struct_field) -> Format.fprintf fmt "%a%a%a%a %a" Dcalc.Print.format_punctuation "\"" - Dcalc.Ast.StructFieldName.format_t struct_field + StructFieldName.format_t struct_field Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation ":" format_expr e)) (List.combine es - (List.map fst (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs))) + (List.map fst (StructMap.find s decl_ctx.ctx_structs))) Dcalc.Print.format_punctuation "}" | EArray es -> Format.fprintf fmt "@[%a%a%a@]" Dcalc.Print.format_punctuation "[" @@ -60,24 +61,24 @@ let rec format_expr | EStructFieldAccess (e1, field, s) -> Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\"" - Dcalc.Ast.StructFieldName.format_t + StructFieldName.format_t (fst (List.find (fun (field', _) -> - Dcalc.Ast.StructFieldName.compare field' field = 0) - (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs))) + StructFieldName.compare field' field = 0) + (StructMap.find s decl_ctx.ctx_structs))) Dcalc.Print.format_punctuation "\"" | EInj (e, case, enum) -> Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_enum_constructor (fst (List.find - (fun (case', _) -> Dcalc.Ast.EnumConstructor.compare case' case = 0) - (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums))) + (fun (case', _) -> EnumConstructor.compare case' case = 0) + (EnumMap.find enum decl_ctx.ctx_enums))) format_expr e | ELit l -> Format.fprintf fmt "%a" Lcalc.Print.format_lit (Marked.same_mark_as l e) | EApp - ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) + ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) -> Format.fprintf fmt "@[%a@ %a@ %a@]" Dcalc.Print.format_binop op format_with_parens arg1 format_with_parens arg2 @@ -100,7 +101,7 @@ let rec format_expr | EOp (Unop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_unop op let rec format_statement - (decl_ctx : Dcalc.Ast.decl_ctx) + (decl_ctx : decl_ctx) ?(debug : bool = false) (fmt : Format.formatter) (stmt : stmt Marked.pos) : unit = @@ -174,10 +175,10 @@ let rec format_statement Dcalc.Print.format_punctuation "→" (format_block decl_ctx ~debug) arm_block)) - (List.combine (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums) arms) + (List.combine (EnumMap.find enum decl_ctx.ctx_enums) arms) and format_block - (decl_ctx : Dcalc.Ast.decl_ctx) + (decl_ctx : decl_ctx) ?(debug : bool = false) (fmt : Format.formatter) (block : block) : unit = @@ -188,7 +189,7 @@ and format_block fmt block let format_scope - (decl_ctx : Dcalc.Ast.decl_ctx) + (decl_ctx : decl_ctx) ?(debug : bool = false) (fmt : Format.formatter) (body : scope_body) : unit = diff --git a/compiler/scalc/print.mli b/compiler/scalc/print.mli index b2a5a0fe..512694bb 100644 --- a/compiler/scalc/print.mli +++ b/compiler/scalc/print.mli @@ -15,7 +15,7 @@ the License. *) val format_scope : - Dcalc.Ast.decl_ctx -> + Shared_ast.decl_ctx -> ?debug:bool -> Format.formatter -> Ast.scope_body -> diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 38ae02f3..a55e34a1 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -16,6 +16,7 @@ [@@@warning "-32-27"] open Utils +open Shared_ast open Ast open String_common module Runtime = Runtime_ocaml.Runtime @@ -31,7 +32,7 @@ let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit = | LUnit -> Format.fprintf fmt "Unit()" | LRat i -> Format.fprintf fmt "decimal_of_string(\"%a\")" Dcalc.Print.format_lit - (Dcalc.Ast.LRat i) + (LRat i) | LMoney e -> Format.fprintf fmt "money_of_cents_string(\"%s\")" (Runtime.integer_to_string (Runtime.money_to_cents e)) @@ -44,7 +45,7 @@ let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit = let years, months, days = Runtime.duration_to_years_months_days d in Format.fprintf fmt "duration_of_numbers(%d,%d,%d)" years months days -let format_log_entry (fmt : Format.formatter) (entry : Dcalc.Ast.log_entry) : +let format_log_entry (fmt : Format.formatter) (entry : log_entry) : unit = match entry with | VarDef _ -> Format.fprintf fmt ":=" @@ -52,13 +53,13 @@ let format_log_entry (fmt : Format.formatter) (entry : Dcalc.Ast.log_entry) : | EndCall -> Format.fprintf fmt "%s" "← " | PosRecordIfTrueBool -> Format.fprintf fmt "☛ " -let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Marked.pos) : +let format_binop (fmt : Format.formatter) (op : binop Marked.pos) : unit = match Marked.unmark op with | Add _ | Concat -> Format.fprintf fmt "+" | Sub _ -> Format.fprintf fmt "-" | Mult _ -> Format.fprintf fmt "*" - | Div D.KInt -> Format.fprintf fmt "//" + | Div KInt -> Format.fprintf fmt "//" | Div _ -> Format.fprintf fmt "/" | And -> Format.fprintf fmt "and" | Or -> Format.fprintf fmt "or" @@ -71,7 +72,7 @@ let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Marked.pos) : | Map -> Format.fprintf fmt "list_map" | Filter -> Format.fprintf fmt "list_filter" -let format_ternop (fmt : Format.formatter) (op : Dcalc.Ast.ternop Marked.pos) : +let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) : unit = match Marked.unmark op with Fold -> Format.fprintf fmt "list_fold_left" @@ -94,7 +95,7 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit = (Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info))) uids -let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Marked.pos) : unit +let format_unop (fmt : Format.formatter) (op : unop Marked.pos) : unit = match Marked.unmark op with | Minus _ -> Format.fprintf fmt "-" @@ -127,43 +128,43 @@ let avoid_keywords (s : string) : string = then s ^ "_" else s -let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) : +let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = Format.fprintf fmt "%s" (avoid_keywords (to_camel_case - (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructName.format_t v)))) + (to_ascii (Format.asprintf "%a" StructName.format_t v)))) let format_struct_field_name (fmt : Format.formatter) - (v : Dcalc.Ast.StructFieldName.t) : unit = + (v : StructFieldName.t) : unit = Format.fprintf fmt "%s" (avoid_keywords - (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v))) + (to_ascii (Format.asprintf "%a" StructFieldName.format_t v))) -let format_enum_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumName.t) : unit +let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = Format.fprintf fmt "%s" (avoid_keywords (to_camel_case - (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumName.format_t v)))) + (to_ascii (Format.asprintf "%a" EnumName.format_t v)))) let format_enum_cons_name (fmt : Format.formatter) - (v : Dcalc.Ast.EnumConstructor.t) : unit = + (v : EnumConstructor.t) : unit = Format.fprintf fmt "%s" (avoid_keywords - (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v))) + (to_ascii (Format.asprintf "%a" EnumConstructor.format_t v))) -let typ_needs_parens (e : Dcalc.Ast.typ Marked.pos) : bool = +let typ_needs_parens (e : typ Marked.pos) : bool = match Marked.unmark e with TArrow _ | TArray _ -> true | _ -> false -let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Marked.pos) : +let rec format_typ (fmt : Format.formatter) (typ : typ Marked.pos) : unit = let format_typ = format_typ in let format_typ_with_parens (fmt : Format.formatter) - (t : Dcalc.Ast.typ Marked.pos) = + (t : typ Marked.pos) = if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t else Format.fprintf fmt "%a" format_typ t in @@ -182,7 +183,7 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Marked.pos) : (fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t)) ts | TTuple (_, Some s) -> Format.fprintf fmt "%a" format_struct_name s - | TEnum ([_; some_typ], e) when D.EnumName.compare e L.option_enum = 0 -> + | TEnum ([_; some_typ], e) when EnumName.compare e L.option_enum = 0 -> (* We translate the option type with an overloading by Python's [None] *) Format.fprintf fmt "Optional[%a]" format_typ some_typ | TEnum (_, e) -> Format.fprintf fmt "%a" format_enum_name e @@ -251,7 +252,7 @@ let needs_parens (e : expr Marked.pos) : bool = | ELit (LBool _ | LUnit) | EVar _ | EOp _ -> false | _ -> true -let format_exception (fmt : Format.formatter) (exc : L.except Marked.pos) : unit +let format_exception (fmt : Format.formatter) (exc : except Marked.pos) : unit = let pos = Marked.get_mark exc in match Marked.unmark exc with @@ -275,7 +276,7 @@ let format_exception (fmt : Format.formatter) (exc : L.except Marked.pos) : unit (Pos.get_law_info pos) let rec format_expression - (ctx : Dcalc.Ast.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) (e : expr Marked.pos) : unit = match Marked.unmark e with @@ -289,18 +290,18 @@ let rec format_expression Format.fprintf fmt "%a = %a" format_struct_field_name struct_field (format_expression ctx) e)) (List.combine es - (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) + (List.map fst (StructMap.find s ctx.ctx_structs))) | EStructFieldAccess (e1, field, _) -> Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field | EInj (_, cons, e_name) - when D.EnumName.compare e_name L.option_enum = 0 - && D.EnumConstructor.compare cons L.none_constr = 0 -> + when EnumName.compare e_name L.option_enum = 0 + && EnumConstructor.compare cons L.none_constr = 0 -> (* We translate the option type with an overloading by Python's [None] *) Format.fprintf fmt "None" | EInj (e, cons, e_name) - when D.EnumName.compare e_name L.option_enum = 0 - && D.EnumConstructor.compare cons L.some_constr = 0 -> + when EnumName.compare e_name L.option_enum = 0 + && EnumConstructor.compare cons L.some_constr = 0 -> (* We translate the option type with an overloading by Python's [None] *) format_expression ctx fmt e | EInj (e, cons, enum_name) -> @@ -315,22 +316,22 @@ let rec format_expression es | ELit l -> Format.fprintf fmt "%a" format_lit (Marked.same_mark_as l e) | EApp - ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) + ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) -> Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos) (format_expression ctx) arg1 (format_expression ctx) arg2 | EApp ((EOp (Binop op), _), [arg1; arg2]) -> Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop (op, Pos.no_pos) (format_expression ctx) arg2 - | EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [f]), _), [arg]) + | EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg]) when !Cli.trace_flag -> Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info (format_expression ctx) f (format_expression ctx) arg - | EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [arg1]) + | EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1]) when !Cli.trace_flag -> Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info (format_expression ctx) arg1 - | EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [arg1]) + | EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), pos), [arg1]) when !Cli.trace_flag -> Format.fprintf fmt "log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \ @@ -338,11 +339,11 @@ let rec format_expression (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos) (format_expression ctx) arg1 - | EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [arg1]) + | EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1]) when !Cli.trace_flag -> Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info (format_expression ctx) arg1 - | EApp ((EOp (Unop (D.Log _)), _), [arg1]) -> + | EApp ((EOp (Unop (Log _)), _), [arg1]) -> Format.fprintf fmt "%a" (format_expression ctx) arg1 | EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [arg1]) -> Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos) @@ -374,7 +375,7 @@ let rec format_expression | EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos) let rec format_statement - (ctx : Dcalc.Ast.decl_ctx) + (ctx : decl_ctx) (fmt : Format.formatter) (s : stmt Marked.pos) : unit = match Marked.unmark s with @@ -403,7 +404,7 @@ let rec format_statement Format.fprintf fmt "@[if %a:@\n%a@]@\n@[else:@\n%a@]" (format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2 | SSwitch (e1, e_name, [(case_none, _); (case_some, case_some_var)]) - when D.EnumName.compare e_name L.option_enum = 0 -> + when EnumName.compare e_name L.option_enum = 0 -> (* We translate the option type with an overloading by Python's [None] *) let tmp_var = LocalName.fresh ("perhaps_none_arg", Pos.no_pos) in Format.fprintf fmt @@ -421,7 +422,7 @@ let rec format_statement List.map2 (fun (x, y) (cons, _) -> x, y, cons) cases - (D.EnumMap.find e_name ctx.ctx_enums) + (EnumMap.find e_name ctx.ctx_enums) in let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in Format.fprintf fmt "%a = %a@\n@[if %a@]" format_var tmp_var @@ -450,7 +451,7 @@ let rec format_statement (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos) -and format_block (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (b : block) +and format_block (ctx : decl_ctx) (fmt : Format.formatter) (b : block) : unit = Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") @@ -462,7 +463,7 @@ and format_block (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (b : block) let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Format.formatter) - (ctx : D.decl_ctx) : unit = + (ctx : decl_ctx) : unit = let format_struct_decl fmt (struct_name, struct_fields) = Format.fprintf fmt "class %a:@\n\ @@ -562,8 +563,8 @@ let format_ctx let scope_structs = List.map (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) - (Dcalc.Ast.StructMap.bindings - (Dcalc.Ast.StructMap.filter + (StructMap.bindings + (StructMap.filter (fun s _ -> not (is_in_type_ordering s)) ctx.ctx_structs)) in @@ -572,10 +573,10 @@ let format_ctx match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> Format.fprintf fmt "%a@\n@\n" format_struct_decl - (s, Dcalc.Ast.StructMap.find s ctx.Dcalc.Ast.ctx_structs) + (s, StructMap.find s ctx.ctx_structs) | Scopelang.Dependency.TVertex.Enum e -> Format.fprintf fmt "%a@\n@\n" format_enum_decl - (e, Dcalc.Ast.EnumMap.find e ctx.Dcalc.Ast.ctx_enums)) + (e, EnumMap.find e ctx.ctx_enums)) (type_ordering @ scope_structs) let format_program diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index c3b663bf..5197425a 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -15,8 +15,8 @@ the License. *) open Utils -module ScopeName = Dcalc.Ast.ScopeName -module ScopeNameSet : Set.S with type elt = ScopeName.t = Set.Make (ScopeName) +open Shared_ast + module ScopeMap : Map.S with type key = ScopeName.t = Map.Make (ScopeName) module SubScopeName : Uid.Id with type info = Uid.MarkedString.info = @@ -33,17 +33,11 @@ module ScopeVar : Uid.Id with type info = Uid.MarkedString.info = module ScopeVarSet : Set.S with type elt = ScopeVar.t = Set.Make (ScopeVar) module ScopeVarMap : Map.S with type key = ScopeVar.t = Map.Make (ScopeVar) -module StructName = Dcalc.Ast.StructName -module StructMap = Dcalc.Ast.StructMap -module StructFieldName = Dcalc.Ast.StructFieldName module StructFieldMap : Map.S with type key = StructFieldName.t = Map.Make (StructFieldName) module StructFieldMapLift = Bindlib.Lift (StructFieldMap) -module EnumName = Dcalc.Ast.EnumName -module EnumMap = Dcalc.Ast.EnumMap -module EnumConstructor = Dcalc.Ast.EnumConstructor module EnumConstructorMap : Map.S with type key = EnumConstructor.t = Map.Make (EnumConstructor) @@ -71,7 +65,7 @@ Set.Make (struct end) type typ = - | TLit of Dcalc.Ast.typ_lit + | TLit of typ_lit | TStruct of StructName.t | TEnum of EnumName.t | TArrow of typ Marked.pos * typ Marked.pos @@ -114,7 +108,7 @@ and expr = | ELit of Dcalc.Ast.lit | EAbs of (expr, marked_expr) Bindlib.mbinder * typ Marked.pos list | EApp of marked_expr * marked_expr list - | EOp of Dcalc.Ast.operator + | EOp of operator | EDefault of marked_expr list * marked_expr * marked_expr | EIfThenElse of marked_expr * marked_expr * marked_expr | EArray of marked_expr list @@ -319,9 +313,9 @@ let make_let_in let make_default ?(pos = Pos.no_pos) exceptions just cons = let rec bool_value = function - | ELit (Dcalc.Ast.LBool b), _ -> Some b + | ELit (LBool b), _ -> Some b | EApp ((EOp (Unop (Log (l, _))), _), [e]), _ - when l <> Dcalc.Ast.PosRecordIfTrueBool + when l <> PosRecordIfTrueBool (* we don't remove the log calls corresponding to source code definitions !*) -> bool_value e diff --git a/compiler/scopelang/ast.mli b/compiler/scopelang/ast.mli index 764aef36..4a95098f 100644 --- a/compiler/scopelang/ast.mli +++ b/compiler/scopelang/ast.mli @@ -17,11 +17,10 @@ (** Abstract syntax tree of the scope language *) open Utils +open Shared_ast (** {1 Identifiers} *) -module ScopeName = Dcalc.Ast.ScopeName -module ScopeNameSet : Set.S with type elt = ScopeName.t module ScopeMap : Map.S with type key = ScopeName.t module SubScopeName : Uid.Id with type info = Uid.MarkedString.info module SubScopeNameSet : Set.S with type elt = SubScopeName.t @@ -29,9 +28,6 @@ module SubScopeMap : Map.S with type key = SubScopeName.t module ScopeVar : Uid.Id with type info = Uid.MarkedString.info module ScopeVarSet : Set.S with type elt = ScopeVar.t module ScopeVarMap : Map.S with type key = ScopeVar.t -module StructName = Dcalc.Ast.StructName -module StructMap = Dcalc.Ast.StructMap -module StructFieldName = Dcalc.Ast.StructFieldName module StructFieldMap : Map.S with type key = StructFieldName.t module StructFieldMapLift : sig @@ -39,9 +35,6 @@ module StructFieldMapLift : sig 'a Bindlib.box StructFieldMap.t -> 'a StructFieldMap.t Bindlib.box end -module EnumName = Dcalc.Ast.EnumName -module EnumMap = Dcalc.Ast.EnumMap -module EnumConstructor = Dcalc.Ast.EnumConstructor module EnumConstructorMap : Map.S with type key = EnumConstructor.t module EnumConstructorMapLift : sig @@ -59,7 +52,7 @@ module LocationSet : Set.S with type elt = location Marked.pos (** {1 Abstract syntax tree} *) type typ = - | TLit of Dcalc.Ast.typ_lit + | TLit of typ_lit | TStruct of StructName.t | TEnum of EnumName.t | TArrow of typ Marked.pos * typ Marked.pos @@ -82,7 +75,7 @@ and expr = | ELit of Dcalc.Ast.lit | EAbs of (expr, marked_expr) Bindlib.mbinder * typ Marked.pos list | EApp of marked_expr * marked_expr list - | EOp of Dcalc.Ast.operator + | EOp of operator | EDefault of marked_expr list * marked_expr * marked_expr | EIfThenElse of marked_expr * marked_expr * marked_expr | EArray of marked_expr list diff --git a/compiler/scopelang/dependency.ml b/compiler/scopelang/dependency.ml index 05248304..85b15dfc 100644 --- a/compiler/scopelang/dependency.ml +++ b/compiler/scopelang/dependency.ml @@ -18,13 +18,14 @@ program. Vertices are functions, x -> y if x is used in the definition of y. *) open Utils +open Shared_ast module SVertex = struct - type t = Ast.ScopeName.t + type t = ScopeName.t - let hash x = Ast.ScopeName.hash x - let compare = Ast.ScopeName.compare - let equal x y = Ast.ScopeName.compare x y = 0 + let hash x = ScopeName.hash x + let compare = ScopeName.compare + let equal x y = ScopeName.compare x y = 0 end (** On the edges, the label is the expression responsible for the use of the @@ -62,10 +63,10 @@ let build_program_dep_graph (prgm : Ast.program) : SDependencies.t = if subscope = scope_name then Errors.raise_spanned_error (Marked.get_mark - (Ast.ScopeName.get_info scope.Ast.scope_decl_name)) + (ScopeName.get_info scope.Ast.scope_decl_name)) "The scope %a is calling into itself as a subscope, which is \ forbidden since Catala does not provide recursion" - Ast.ScopeName.format_t scope.Ast.scope_decl_name + ScopeName.format_t scope.Ast.scope_decl_name else Ast.ScopeMap.add subscope (Marked.get_mark (Ast.SubScopeName.get_info subindex)) @@ -90,14 +91,14 @@ let check_for_cycle_in_scope (g : SDependencies.t) : unit = (List.map (fun v -> let var_str, var_info = - ( Format.asprintf "%a" Ast.ScopeName.format_t v, - Ast.ScopeName.get_info v ) + ( Format.asprintf "%a" ScopeName.format_t v, + ScopeName.get_info v ) in let succs = SDependencies.succ_e g v in let _, edge_pos, succ = List.find (fun (_, _, succ) -> List.mem succ scc) succs in - let succ_str = Format.asprintf "%a" Ast.ScopeName.format_t succ in + let succ_str = Format.asprintf "%a" ScopeName.format_t succ in [ ( Some ("Cycle variable " ^ var_str ^ ", declared:"), Marked.get_mark var_info ); @@ -112,39 +113,39 @@ let check_for_cycle_in_scope (g : SDependencies.t) : unit = Errors.raise_multispanned_error spans "Cyclic dependency detected between scopes!" -let get_scope_ordering (g : SDependencies.t) : Ast.ScopeName.t list = +let get_scope_ordering (g : SDependencies.t) : ScopeName.t list = List.rev (STopologicalTraversal.fold (fun sd acc -> sd :: acc) g []) module TVertex = struct - type t = Struct of Ast.StructName.t | Enum of Ast.EnumName.t + type t = Struct of StructName.t | Enum of EnumName.t let hash x = match x with - | Struct x -> Ast.StructName.hash x - | Enum x -> Ast.EnumName.hash x + | Struct x -> StructName.hash x + | Enum x -> EnumName.hash x let compare x y = match x, y with - | Struct x, Struct y -> Ast.StructName.compare x y - | Enum x, Enum y -> Ast.EnumName.compare x y + | Struct x, Struct y -> StructName.compare x y + | Enum x, Enum y -> EnumName.compare x y | Struct _, Enum _ -> 1 | Enum _, Struct _ -> -1 let equal x y = match x, y with - | Struct x, Struct y -> Ast.StructName.compare x y = 0 - | Enum x, Enum y -> Ast.EnumName.compare x y = 0 + | Struct x, Struct y -> StructName.compare x y = 0 + | Enum x, Enum y -> EnumName.compare x y = 0 | _ -> false let format_t (fmt : Format.formatter) (x : t) : unit = match x with - | Struct x -> Ast.StructName.format_t fmt x - | Enum x -> Ast.EnumName.format_t fmt x + | Struct x -> StructName.format_t fmt x + | Enum x -> EnumName.format_t fmt x let get_info (x : t) = match x with - | Struct x -> Ast.StructName.get_info x - | Enum x -> Ast.EnumName.get_info x + | Struct x -> StructName.get_info x + | Enum x -> EnumName.get_info x end module TVertexSet = Set.Make (TVertex) @@ -181,7 +182,7 @@ let build_type_graph (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) : TDependencies.t = let g = TDependencies.empty in let g = - Ast.StructMap.fold + StructMap.fold (fun s fields g -> List.fold_left (fun g (_, typ) -> @@ -205,7 +206,7 @@ let build_type_graph (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) : structs g in let g = - Ast.EnumMap.fold + EnumMap.fold (fun e cases g -> List.fold_left (fun g (_, typ) -> diff --git a/compiler/scopelang/dependency.mli b/compiler/scopelang/dependency.mli index 01c147c7..0414b1ae 100644 --- a/compiler/scopelang/dependency.mli +++ b/compiler/scopelang/dependency.mli @@ -18,25 +18,26 @@ program. Vertices are functions, x -> y if x is used in the definition of y. *) open Utils +open Shared_ast (** {1 Scope dependencies} *) (** On the edges, the label is the expression responsible for the use of the function *) module SDependencies : - Graph.Sig.P with type V.t = Ast.ScopeName.t and type E.label = Pos.t + Graph.Sig.P with type V.t = ScopeName.t and type E.label = Pos.t val build_program_dep_graph : Ast.program -> SDependencies.t val check_for_cycle_in_scope : SDependencies.t -> unit -val get_scope_ordering : SDependencies.t -> Ast.ScopeName.t list +val get_scope_ordering : SDependencies.t -> ScopeName.t list (** {1 Type dependencies} *) module TVertex : sig - type t = Struct of Ast.StructName.t | Enum of Ast.EnumName.t + type t = Struct of StructName.t | Enum of EnumName.t val format_t : Format.formatter -> t -> unit - val get_info : t -> Ast.StructName.info + val get_info : t -> StructName.info include Graph.Sig.COMPARABLE with type t := t end diff --git a/compiler/scopelang/print.ml b/compiler/scopelang/print.ml index 68fbb815..0b13425e 100644 --- a/compiler/scopelang/print.ml +++ b/compiler/scopelang/print.ml @@ -15,6 +15,7 @@ the License. *) open Utils +open Shared_ast open Ast let needs_parens (e : expr Marked.pos) : bool = @@ -42,8 +43,8 @@ let rec format_typ (fmt : Format.formatter) (typ : typ Marked.pos) : unit = in match Marked.unmark typ with | TLit l -> Dcalc.Print.format_tlit fmt l - | TStruct s -> Format.fprintf fmt "%a" Ast.StructName.format_t s - | TEnum e -> Format.fprintf fmt "%a" Ast.EnumName.format_t e + | TStruct s -> Format.fprintf fmt "%a" StructName.format_t s + | TEnum e -> Format.fprintf fmt "%a" EnumName.format_t e | TArrow (t1, t2) -> Format.fprintf fmt "@[%a %a@ %a@]" format_typ_with_parens t1 Dcalc.Print.format_operator "→" format_typ t2 @@ -67,14 +68,14 @@ let rec format_expr | EVar v -> Format.fprintf fmt "%a" format_var v | ELit l -> Format.fprintf fmt "%a" Dcalc.Print.format_lit l | EStruct (name, fields) -> - Format.fprintf fmt " @[%a@ %a@ %a@ %a@]" Ast.StructName.format_t name + Format.fprintf fmt " @[%a@ %a@ %a@ %a@]" StructName.format_t name Dcalc.Print.format_punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";") (fun fmt (field_name, field_expr) -> Format.fprintf fmt "%a%a%a%a@ %a" Dcalc.Print.format_punctuation "\"" - Ast.StructFieldName.format_t field_name + StructFieldName.format_t field_name Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation "=" format_expr field_expr)) (Ast.StructFieldMap.bindings fields) @@ -82,9 +83,9 @@ let rec format_expr | EStructAccess (e1, field, _) -> Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\"" - Ast.StructFieldName.format_t field Dcalc.Print.format_punctuation "\"" + StructFieldName.format_t field Dcalc.Print.format_punctuation "\"" | EEnumInj (e1, cons, _) -> - Format.fprintf fmt "%a@ %a" Ast.EnumConstructor.format_t cons format_expr e1 + Format.fprintf fmt "%a@ %a" EnumConstructor.format_t cons format_expr e1 | EMatch (e1, _, cases) -> Format.fprintf fmt "@[%a@ @[%a@]@ %a@ %a@]" Dcalc.Print.format_keyword "match" format_expr e1 diff --git a/compiler/scopelang/print.mli b/compiler/scopelang/print.mli index ca6b8945..3188ccce 100644 --- a/compiler/scopelang/print.mli +++ b/compiler/scopelang/print.mli @@ -29,7 +29,7 @@ val format_expr : val format_scope : ?debug:bool (** [true] for debug printing *) -> Format.formatter -> - Ast.ScopeName.t * Ast.scope_decl -> + Shared_ast.ScopeName.t * Ast.scope_decl -> unit val format_program : diff --git a/compiler/scopelang/scope_to_dcalc.ml b/compiler/scopelang/scope_to_dcalc.ml index dd336439..3aeb8957 100644 --- a/compiler/scopelang/scope_to_dcalc.ml +++ b/compiler/scopelang/scope_to_dcalc.ml @@ -19,18 +19,18 @@ open Shared_ast type scope_var_ctx = { scope_var_name : Ast.ScopeVar.t; - scope_var_typ : Dcalc.Ast.typ; + scope_var_typ : typ; scope_var_io : Ast.io; } type scope_sig_ctx = { scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *) - scope_sig_scope_var : Dcalc.Ast.untyped Dcalc.Ast.var; + scope_sig_scope_var : untyped Dcalc.Ast.var; (** Var representing the scope *) - scope_sig_input_var : Dcalc.Ast.untyped Dcalc.Ast.var; + scope_sig_input_var : untyped Dcalc.Ast.var; (** Var representing the scope input inside the scope func *) - scope_sig_input_struct : Ast.StructName.t; (** Scope input *) - scope_sig_output_struct : Ast.StructName.t; (** Scope output *) + scope_sig_input_struct : StructName.t; (** Scope input *) + scope_sig_output_struct : StructName.t; (** Scope output *) } type scope_sigs_ctx = scope_sig_ctx Ast.ScopeMap.t @@ -38,21 +38,21 @@ type scope_sigs_ctx = scope_sig_ctx Ast.ScopeMap.t type ctx = { structs : Ast.struct_ctx; enums : Ast.enum_ctx; - scope_name : Ast.ScopeName.t; + scope_name : ScopeName.t; scopes_parameters : scope_sigs_ctx; scope_vars : - (Dcalc.Ast.untyped Dcalc.Ast.var * Dcalc.Ast.typ * Ast.io) Ast.ScopeVarMap.t; + (untyped Dcalc.Ast.var * typ * Ast.io) Ast.ScopeVarMap.t; subscope_vars : - (Dcalc.Ast.untyped Dcalc.Ast.var * Dcalc.Ast.typ * Ast.io) Ast.ScopeVarMap.t + (untyped Dcalc.Ast.var * typ * Ast.io) Ast.ScopeVarMap.t Ast.SubScopeMap.t; - local_vars : Dcalc.Ast.untyped Dcalc.Ast.var Ast.VarMap.t; + local_vars : untyped Dcalc.Ast.var Ast.VarMap.t; } let empty_ctx (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx) (scopes_ctx : scope_sigs_ctx) - (scope_name : Ast.ScopeName.t) = + (scope_name : ScopeName.t) = { structs = struct_ctx; enums = enum_ctx; @@ -64,62 +64,62 @@ let empty_ctx } let rec translate_typ (ctx : ctx) (t : Ast.typ Marked.pos) : - Dcalc.Ast.typ Marked.pos = + typ Marked.pos = Marked.same_mark_as (match Marked.unmark t with - | Ast.TLit l -> Dcalc.Ast.TLit l + | Ast.TLit l -> TLit l | Ast.TArrow (t1, t2) -> - Dcalc.Ast.TArrow (translate_typ ctx t1, translate_typ ctx t2) + TArrow (translate_typ ctx t1, translate_typ ctx t2) | Ast.TStruct s_uid -> - let s_fields = Ast.StructMap.find s_uid ctx.structs in - Dcalc.Ast.TTuple + let s_fields = StructMap.find s_uid ctx.structs in + TTuple (List.map (fun (_, t) -> translate_typ ctx t) s_fields, Some s_uid) | Ast.TEnum e_uid -> - let e_cases = Ast.EnumMap.find e_uid ctx.enums in - Dcalc.Ast.TEnum + let e_cases = EnumMap.find e_uid ctx.enums in + TEnum (List.map (fun (_, t) -> translate_typ ctx t) e_cases, e_uid) | Ast.TArray t1 -> - Dcalc.Ast.TArray (translate_typ ctx (Marked.same_mark_as t1 t)) - | Ast.TAny -> Dcalc.Ast.TAny) + TArray (translate_typ ctx (Marked.same_mark_as t1 t)) + | Ast.TAny -> TAny) t -let pos_mark (pos : Pos.t) : Dcalc.Ast.untyped Dcalc.Ast.mark = - Dcalc.Ast.Untyped { pos } +let pos_mark (pos : Pos.t) : untyped mark = + Untyped { pos } let pos_mark_as e = pos_mark (Marked.get_mark e) let merge_defaults - (caller : Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box) - (callee : Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box) : - Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box = + (caller : untyped Dcalc.Ast.marked_expr Bindlib.box) + (callee : untyped Dcalc.Ast.marked_expr Bindlib.box) : + untyped Dcalc.Ast.marked_expr Bindlib.box = let caller = let m = Marked.get_mark (Bindlib.unbox caller) in Dcalc.Ast.make_app caller - [Bindlib.box (Dcalc.Ast.ELit Dcalc.Ast.LUnit, m)] + [Bindlib.box (ELit LUnit, m)] m in let body = Bindlib.box_apply2 (fun caller callee -> let m = Marked.get_mark callee in - ( Dcalc.Ast.EDefault - ([caller], (Dcalc.Ast.ELit (Dcalc.Ast.LBool true), m), callee), + ( EDefault + ([caller], (ELit (LBool true), m), callee), m )) caller callee in body let tag_with_log_entry - (e : Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box) - (l : Dcalc.Ast.log_entry) + (e : untyped Dcalc.Ast.marked_expr Bindlib.box) + (l : log_entry) (markings : Utils.Uid.MarkedString.info list) : - Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box = + untyped Dcalc.Ast.marked_expr Bindlib.box = Bindlib.box_apply (fun e -> Marked.same_mark_as - (Dcalc.Ast.EApp + (EApp ( Marked.same_mark_as - (Dcalc.Ast.EOp (Dcalc.Ast.Unop (Dcalc.Ast.Log (l, markings)))) + (EOp (Unop (Log (l, markings)))) e, [e] )) e) @@ -165,15 +165,15 @@ let collapse_similar_outcomes (excepts : Ast.expr Marked.pos list) : excepts let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : - Dcalc.Ast.untyped Dcalc.Ast.marked_expr Bindlib.box = - Bindlib.box_apply (fun (x : Dcalc.Ast.untyped Dcalc.Ast.expr) -> + untyped Dcalc.Ast.marked_expr Bindlib.box = + Bindlib.box_apply (fun (x : untyped Dcalc.Ast.expr) -> Marked.mark (pos_mark_as e) x) @@ match Marked.unmark e with | EVar v -> Bindlib.box_var (Ast.VarMap.find v ctx.local_vars) - | ELit l -> Bindlib.box (Dcalc.Ast.ELit l) + | ELit l -> Bindlib.box (ELit l) | EStruct (struct_name, e_fields) -> - let struct_sig = Ast.StructMap.find struct_name ctx.structs in + let struct_sig = StructMap.find struct_name ctx.structs in let d_fields, remaining_e_fields = List.fold_right (fun (field_name, _) (d_fields, e_fields) -> @@ -185,58 +185,58 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : if Ast.StructFieldMap.cardinal remaining_e_fields > 0 then Errors.raise_spanned_error (Marked.get_mark e) "The fields \"%a\" do not belong to the structure %a" - Ast.StructName.format_t struct_name + StructName.format_t struct_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") (fun fmt (field_name, _) -> - Format.fprintf fmt "%a" Ast.StructFieldName.format_t field_name)) + Format.fprintf fmt "%a" StructFieldName.format_t field_name)) (Ast.StructFieldMap.bindings remaining_e_fields) else Bindlib.box_apply - (fun d_fields -> Dcalc.Ast.ETuple (d_fields, Some struct_name)) + (fun d_fields -> ETuple (d_fields, Some struct_name)) (Bindlib.box_list d_fields) | EStructAccess (e1, field_name, struct_name) -> - let struct_sig = Ast.StructMap.find struct_name ctx.structs in + let struct_sig = StructMap.find struct_name ctx.structs in let _, field_index = try List.assoc field_name (List.mapi (fun i (x, y) -> x, (y, i)) struct_sig) with Not_found -> Errors.raise_spanned_error (Marked.get_mark e) "The field \"%a\" does not belong to the structure %a" - Ast.StructFieldName.format_t field_name Ast.StructName.format_t + StructFieldName.format_t field_name StructName.format_t struct_name in let e1 = translate_expr ctx e1 in Bindlib.box_apply (fun e1 -> - Dcalc.Ast.ETupleAccess + ETupleAccess ( e1, field_index, Some struct_name, List.map (fun (_, t) -> translate_typ ctx t) struct_sig )) e1 | EEnumInj (e1, constructor, enum_name) -> - let enum_sig = Ast.EnumMap.find enum_name ctx.enums in + let enum_sig = EnumMap.find enum_name ctx.enums in let _, constructor_index = try List.assoc constructor (List.mapi (fun i (x, y) -> x, (y, i)) enum_sig) with Not_found -> Errors.raise_spanned_error (Marked.get_mark e) "The constructor \"%a\" does not belong to the enum %a" - Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t + EnumConstructor.format_t constructor EnumName.format_t enum_name in let e1 = translate_expr ctx e1 in Bindlib.box_apply (fun e1 -> - Dcalc.Ast.EInj + EInj ( e1, constructor_index, enum_name, List.map (fun (_, t) -> translate_typ ctx t) enum_sig )) e1 | EMatch (e1, enum_name, cases) -> - let enum_sig = Ast.EnumMap.find enum_name ctx.enums in + let enum_sig = EnumMap.find enum_name ctx.enums in let d_cases, remaining_e_cases = List.fold_right (fun (constructor, _) (d_cases, e_cases) -> @@ -246,7 +246,7 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : Errors.raise_spanned_error (Marked.get_mark e) "The constructor %a of enum %a is missing from this pattern \ matching" - Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t + EnumConstructor.format_t constructor EnumName.format_t enum_name in let case_d = translate_expr ctx case_e in @@ -256,16 +256,16 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : if Ast.EnumConstructorMap.cardinal remaining_e_cases > 0 then Errors.raise_spanned_error (Marked.get_mark e) "Patter matching is incomplete for enum %a: missing cases %a" - Ast.EnumName.format_t enum_name + EnumName.format_t enum_name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") (fun fmt (case_name, _) -> - Format.fprintf fmt "%a" Ast.EnumConstructor.format_t case_name)) + Format.fprintf fmt "%a" EnumConstructor.format_t case_name)) (Ast.EnumConstructorMap.bindings remaining_e_cases) else let e1 = translate_expr ctx e1 in Bindlib.box_apply2 - (fun d_fields e1 -> Dcalc.Ast.EMatch (e1, d_fields, enum_name)) + (fun d_fields e1 -> EMatch (e1, d_fields, enum_name)) (Bindlib.box_list d_cases) e1 | EApp (e1, args) -> (* We insert various log calls to record arguments and outputs of @@ -274,14 +274,14 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : let markings l = match l with | Ast.ScopeVar (v, _) -> - [Ast.ScopeName.get_info ctx.scope_name; Ast.ScopeVar.get_info v] + [ScopeName.get_info ctx.scope_name; Ast.ScopeVar.get_info v] | Ast.SubScopeVar (s, _, (v, _)) -> - [Ast.ScopeName.get_info s; Ast.ScopeVar.get_info v] + [ScopeName.get_info s; Ast.ScopeVar.get_info v] in let e1_func = match Marked.unmark e1 with | ELocation l -> - tag_with_log_entry e1_func Dcalc.Ast.BeginCall (markings l) + tag_with_log_entry e1_func BeginCall (markings l) | _ -> e1_func in let new_args = List.map (translate_expr ctx) args in @@ -293,9 +293,9 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : let retrieve_in_and_out_typ_or_any var vars = let _, typ, _ = Ast.ScopeVarMap.find (Marked.unmark var) vars in match typ with - | Dcalc.Ast.TArrow (marked_input_typ, marked_output_typ) -> + | TArrow (marked_input_typ, marked_output_typ) -> Marked.unmark marked_input_typ, Marked.unmark marked_output_typ - | _ -> Dcalc.Ast.TAny, Dcalc.Ast.TAny + | _ -> TAny, TAny in match Marked.unmark e1 with | ELocation (ScopeVar var) -> @@ -304,20 +304,20 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : ctx.subscope_vars |> Ast.SubScopeMap.find (Marked.unmark sname) |> retrieve_in_and_out_typ_or_any var - | _ -> Dcalc.Ast.TAny, Dcalc.Ast.TAny + | _ -> TAny, TAny in let new_args = match Marked.unmark e1, new_args with | ELocation l, [new_arg] -> [ - tag_with_log_entry new_arg (Dcalc.Ast.VarDef input_typ) + tag_with_log_entry new_arg (VarDef input_typ) (markings l @ [Marked.same_mark_as "input" e]); ] | _ -> new_args in let new_e = Bindlib.box_apply2 - (fun e' u -> Dcalc.Ast.EApp (e', u), pos_mark_as e) + (fun e' u -> EApp (e', u), pos_mark_as e) e1_func (Bindlib.box_list new_args) in @@ -325,9 +325,9 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : match Marked.unmark e1 with | ELocation l -> tag_with_log_entry - (tag_with_log_entry new_e (Dcalc.Ast.VarDef output_typ) + (tag_with_log_entry new_e (VarDef output_typ) (markings l @ [Marked.same_mark_as "output" e])) - Dcalc.Ast.EndCall (markings l) + EndCall (markings l) | _ -> new_e in Bindlib.box_apply Marked.unmark new_e @@ -348,12 +348,12 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : in let binder = Bindlib.bind_mvar new_xs body in Bindlib.box_apply - (fun b -> Dcalc.Ast.EAbs (b, List.map (translate_typ ctx) typ)) + (fun b -> EAbs (b, List.map (translate_typ ctx) typ)) binder | EDefault (excepts, just, cons) -> let excepts = collapse_similar_outcomes excepts in Bindlib.box_apply3 - (fun e j c -> Dcalc.Ast.EDefault (e, j, c)) + (fun e j c -> EDefault (e, j, c)) (Bindlib.box_list (List.map (translate_expr ctx) excepts)) (translate_expr ctx just) (translate_expr ctx cons) | ELocation (ScopeVar a) -> @@ -381,16 +381,16 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Marked.pos) : (Marked.unmark a) Ast.SubScopeName.format_t (Marked.unmark s)) | EIfThenElse (cond, et, ef) -> Bindlib.box_apply3 - (fun c t f -> Dcalc.Ast.EIfThenElse (c, t, f)) + (fun c t f -> EIfThenElse (c, t, f)) (translate_expr ctx cond) (translate_expr ctx et) (translate_expr ctx ef) - | EOp op -> Bindlib.box (Dcalc.Ast.EOp op) + | EOp op -> Bindlib.box (EOp op) | ErrorOnEmpty e' -> Bindlib.box_apply - (fun e' -> Dcalc.Ast.ErrorOnEmpty e') + (fun e' -> ErrorOnEmpty e') (translate_expr ctx e') | EArray es -> Bindlib.box_apply - (fun es -> Dcalc.Ast.EArray es) + (fun es -> EArray es) (Bindlib.box_list (List.map (translate_expr ctx) es)) (** The result of a rule translation is a list of assignment, with variables and @@ -402,13 +402,13 @@ let translate_rule (ctx : ctx) (rule : Ast.rule) ((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) : - (( Dcalc.Ast.untyped Dcalc.Ast.expr, - Dcalc.Ast.untyped ) - Dcalc.Ast.scope_body_expr + (( untyped Dcalc.Ast.expr, + untyped ) + scope_body_expr Bindlib.box -> - ( Dcalc.Ast.untyped Dcalc.Ast.expr, - Dcalc.Ast.untyped ) - Dcalc.Ast.scope_body_expr + ( untyped Dcalc.Ast.expr, + untyped ) + scope_body_expr Bindlib.box) * ctx = match rule with @@ -421,7 +421,7 @@ let translate_rule let merged_expr = Bindlib.box_apply (fun merged_expr -> - Dcalc.Ast.ErrorOnEmpty merged_expr, pos_mark_as a_name) + ErrorOnEmpty merged_expr, pos_mark_as a_name) (match Marked.unmark a_io.io_input with | OnlyInput -> failwith "should not happen" @@ -432,19 +432,19 @@ let translate_rule in let merged_expr = tag_with_log_entry merged_expr - (Dcalc.Ast.VarDef (Marked.unmark tau)) + (VarDef (Marked.unmark tau)) [sigma_name, pos_sigma; a_name] in ( (fun next -> Bindlib.box_apply2 (fun next merged_expr -> - Dcalc.Ast.ScopeLet + ScopeLet { - Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_typ = tau; - Dcalc.Ast.scope_let_expr = merged_expr; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.ScopeVarDefinition; - Dcalc.Ast.scope_let_pos = Marked.get_mark a; + scope_let_next = next; + scope_let_typ = tau; + scope_let_expr = merged_expr; + scope_let_kind = ScopeVarDefinition; + scope_let_pos = Marked.get_mark a; }) (Bindlib.bind_var a_var next) merged_expr), @@ -472,7 +472,7 @@ let translate_rule let tau = translate_typ ctx tau in let new_e = tag_with_log_entry (translate_expr ctx e) - (Dcalc.Ast.VarDef (Marked.unmark tau)) + (VarDef (Marked.unmark tau)) [sigma_name, pos_sigma; a_name] in let silent_var = Var.make "_" in @@ -481,31 +481,31 @@ let translate_rule | NoInput -> failwith "should not happen" | OnlyInput -> Bindlib.box_apply - (fun new_e -> Dcalc.Ast.ErrorOnEmpty new_e, pos_mark_as subs_var) + (fun new_e -> ErrorOnEmpty new_e, pos_mark_as subs_var) new_e | Reentrant -> Dcalc.Ast.make_abs (Array.of_list [silent_var]) new_e - [Dcalc.Ast.TLit TUnit, var_def_pos] + [TLit TUnit, var_def_pos] (pos_mark var_def_pos) in ( (fun next -> Bindlib.box_apply2 (fun next thunked_or_nonempty_new_e -> - Dcalc.Ast.ScopeLet + ScopeLet { - Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_pos = Marked.get_mark a_name; - Dcalc.Ast.scope_let_typ = + scope_let_next = next; + scope_let_pos = Marked.get_mark a_name; + scope_let_typ = (match Marked.unmark a_io.io_input with | NoInput -> failwith "should not happen" | OnlyInput -> tau | Reentrant -> - ( Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), + ( TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos )); - Dcalc.Ast.scope_let_expr = thunked_or_nonempty_new_e; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.SubScopeVarDefinition; + scope_let_expr = thunked_or_nonempty_new_e; + scope_let_kind = SubScopeVarDefinition; }) (Bindlib.bind_var a_var next) thunked_or_nonempty_new_e), @@ -573,7 +573,7 @@ let translate_rule let subscope_struct_arg = Bindlib.box_apply (fun subscope_args -> - ( Dcalc.Ast.ETuple (subscope_args, Some called_scope_input_struct), + ( ETuple (subscope_args, Some called_scope_input_struct), pos_mark pos_call )) (Bindlib.box_list subscope_args) in @@ -593,28 +593,28 @@ let translate_rule tag_with_log_entry (Dcalc.Ast.make_var (scope_dcalc_var, pos_mark_as (Ast.SubScopeName.get_info subindex))) - Dcalc.Ast.BeginCall + BeginCall [ sigma_name, pos_sigma; Ast.SubScopeName.get_info subindex; - Ast.ScopeName.get_info subname; + ScopeName.get_info subname; ] in let call_expr = tag_with_log_entry (Bindlib.box_apply2 - (fun e u -> Dcalc.Ast.EApp (e, [u]), pos_mark Pos.no_pos) + (fun e u -> EApp (e, [u]), pos_mark Pos.no_pos) subscope_func subscope_struct_arg) - Dcalc.Ast.EndCall + EndCall [ sigma_name, pos_sigma; Ast.SubScopeName.get_info subindex; - Ast.ScopeName.get_info subname; + ScopeName.get_info subname; ] in let result_tuple_var = Var.make "result" in let result_tuple_typ = - ( Dcalc.Ast.TTuple + ( TTuple ( List.map (fun (subvar, _) -> subvar.scope_var_typ, pos_sigma) all_subscope_output_vars_dcalc, @@ -624,13 +624,13 @@ let translate_rule let call_scope_let next = Bindlib.box_apply2 (fun next call_expr -> - Dcalc.Ast.ScopeLet + ScopeLet { - Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_pos = pos_sigma; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.CallingSubScope; - Dcalc.Ast.scope_let_typ = result_tuple_typ; - Dcalc.Ast.scope_let_expr = call_expr; + scope_let_next = next; + scope_let_pos = pos_sigma; + scope_let_kind = CallingSubScope; + scope_let_typ = result_tuple_typ; + scope_let_expr = call_expr; }) (Bindlib.bind_var result_tuple_var next) call_expr @@ -640,15 +640,15 @@ let translate_rule (fun (var_ctx, v) (next, i) -> ( Bindlib.box_apply2 (fun next r -> - Dcalc.Ast.ScopeLet + ScopeLet { - Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_pos = pos_sigma; - Dcalc.Ast.scope_let_typ = var_ctx.scope_var_typ, pos_sigma; - Dcalc.Ast.scope_let_kind = - Dcalc.Ast.DestructuringSubScopeResults; - Dcalc.Ast.scope_let_expr = - ( Dcalc.Ast.ETupleAccess + scope_let_next = next; + scope_let_pos = pos_sigma; + scope_let_typ = var_ctx.scope_var_typ, pos_sigma; + scope_let_kind = + DestructuringSubScopeResults; + scope_let_expr = + ( ETupleAccess ( r, i, Some called_scope_return_struct, @@ -682,20 +682,20 @@ let translate_rule ( (fun next -> Bindlib.box_apply2 (fun next new_e -> - Dcalc.Ast.ScopeLet + ScopeLet { - Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_pos = Marked.get_mark e; - Dcalc.Ast.scope_let_typ = - Dcalc.Ast.TLit TUnit, Marked.get_mark e; - Dcalc.Ast.scope_let_expr = + scope_let_next = next; + scope_let_pos = Marked.get_mark e; + scope_let_typ = + TLit TUnit, Marked.get_mark e; + scope_let_expr = (* To ensure that we throw an error if the value is not defined, we add an check "ErrorOnEmpty" here. *) Marked.same_mark_as - (Dcalc.Ast.EAssert - (Dcalc.Ast.ErrorOnEmpty new_e, pos_mark_as e)) + (EAssert + (ErrorOnEmpty new_e, pos_mark_as e)) new_e; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.Assertion; + scope_let_kind = Assertion; }) (Bindlib.bind_var (Var.make "_") next) new_e), @@ -705,10 +705,10 @@ let translate_rules (ctx : ctx) (rules : Ast.rule list) ((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) - (sigma_return_struct_name : Ast.StructName.t) : - ( Dcalc.Ast.untyped Dcalc.Ast.expr, - Dcalc.Ast.untyped ) - Dcalc.Ast.scope_body_expr + (sigma_return_struct_name : StructName.t) : + ( untyped Dcalc.Ast.expr, + untyped ) + scope_body_expr Bindlib.box * ctx = let scope_lets, new_ctx = @@ -730,7 +730,7 @@ let translate_rules let return_exp = Bindlib.box_apply (fun args -> - ( Dcalc.Ast.ETuple (args, Some sigma_return_struct_name), + ( ETuple (args, Some sigma_return_struct_name), pos_mark pos_sigma )) (Bindlib.box_list (List.map @@ -740,7 +740,7 @@ let translate_rules in ( scope_lets (Bindlib.box_apply - (fun return_exp -> Dcalc.Ast.Result return_exp) + (fun return_exp -> Result return_exp) return_exp), new_ctx ) @@ -748,12 +748,12 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx) (sctx : scope_sigs_ctx) - (scope_name : Ast.ScopeName.t) + (scope_name : ScopeName.t) (sigma : Ast.scope_decl) : - (Dcalc.Ast.untyped Dcalc.Ast.expr, Dcalc.Ast.untyped) Dcalc.Ast.scope_body + (untyped Dcalc.Ast.expr, untyped) scope_body Bindlib.box * struct_ctx = - let sigma_info = Ast.ScopeName.get_info sigma.scope_decl_name in + let sigma_info = ScopeName.get_info sigma.scope_decl_name in let scope_sig = Ast.ScopeMap.find sigma.scope_decl_name sctx in let scope_variables = scope_sig.scope_sig_local_vars in let ctx = @@ -813,8 +813,8 @@ let translate_scope_decl match Marked.unmark var_ctx.scope_var_io.io_input with | OnlyInput -> var_ctx.scope_var_typ, pos_sigma | Reentrant -> - ( Dcalc.Ast.TArrow - ((Dcalc.Ast.TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)), + ( TArrow + ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)), pos_sigma ) | NoInput -> failwith "should not happen" in @@ -824,15 +824,15 @@ let translate_scope_decl (fun (var_ctx, v) (next, i) -> ( Bindlib.box_apply2 (fun next r -> - Dcalc.Ast.ScopeLet + ScopeLet { - Dcalc.Ast.scope_let_kind = - Dcalc.Ast.DestructuringInputStruct; - Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_pos = pos_sigma; - Dcalc.Ast.scope_let_typ = input_var_typ var_ctx; - Dcalc.Ast.scope_let_expr = - ( Dcalc.Ast.ETupleAccess + scope_let_kind = + DestructuringInputStruct; + scope_let_next = next; + scope_let_pos = pos_sigma; + scope_let_typ = input_var_typ var_ctx; + scope_let_expr = + ( ETupleAccess ( r, i, Some scope_input_struct_name, @@ -851,7 +851,7 @@ let translate_scope_decl List.map (fun (var_ctx, dvar) -> let struct_field_name = - Ast.StructFieldName.fresh (Bindlib.name_of dvar ^ "_out", pos_sigma) + StructFieldName.fresh (Bindlib.name_of dvar ^ "_out", pos_sigma) in struct_field_name, (var_ctx.scope_var_typ, pos_sigma)) scope_output_variables @@ -860,29 +860,29 @@ let translate_scope_decl List.map (fun (var_ctx, dvar) -> let struct_field_name = - Ast.StructFieldName.fresh (Bindlib.name_of dvar ^ "_in", pos_sigma) + StructFieldName.fresh (Bindlib.name_of dvar ^ "_in", pos_sigma) in struct_field_name, input_var_typ var_ctx) scope_input_variables in let new_struct_ctx = - Ast.StructMap.add scope_input_struct_name scope_input_struct_fields - (Ast.StructMap.singleton scope_return_struct_name + StructMap.add scope_input_struct_name scope_input_struct_fields + (StructMap.singleton scope_return_struct_name scope_return_struct_fields) in ( Bindlib.box_apply (fun scope_body_expr -> { - Dcalc.Ast.scope_body_expr; - Dcalc.Ast.scope_body_input_struct = scope_input_struct_name; - Dcalc.Ast.scope_body_output_struct = scope_return_struct_name; + scope_body_expr; + scope_body_input_struct = scope_input_struct_name; + scope_body_output_struct = scope_return_struct_name; }) (Bindlib.bind_var scope_input_var (input_destructurings rules_with_return_expr)), new_struct_ctx ) let translate_program (prgm : Ast.program) : - Dcalc.Ast.untyped Dcalc.Ast.program * Dependency.TVertex.t list = + untyped Dcalc.Ast.program * Dependency.TVertex.t list = let scope_dependencies = Dependency.build_program_dep_graph prgm in Dependency.check_for_cycle_in_scope scope_dependencies; let types_ordering = @@ -894,16 +894,16 @@ let translate_program (prgm : Ast.program) : let ctx_for_typ_translation scope_name = empty_ctx struct_ctx enum_ctx Ast.ScopeMap.empty scope_name in - let dummy_scope = Ast.ScopeName.fresh ("dummy", Pos.no_pos) in + let dummy_scope = ScopeName.fresh ("dummy", Pos.no_pos) in let decl_ctx = { - Dcalc.Ast.ctx_structs = - Ast.StructMap.map + ctx_structs = + StructMap.map (List.map (fun (x, y) -> x, translate_typ (ctx_for_typ_translation dummy_scope) y)) struct_ctx; - Dcalc.Ast.ctx_enums = - Ast.EnumMap.map + ctx_enums = + EnumMap.map (List.map (fun (x, y) -> x, (translate_typ (ctx_for_typ_translation dummy_scope)) y)) enum_ctx; @@ -914,22 +914,22 @@ let translate_program (prgm : Ast.program) : (fun scope_name scope -> let scope_dvar = Var.make - (Marked.unmark (Ast.ScopeName.get_info scope.Ast.scope_decl_name)) + (Marked.unmark (ScopeName.get_info scope.Ast.scope_decl_name)) in let scope_return_struct_name = - Ast.StructName.fresh + StructName.fresh (Marked.map_under_mark (fun s -> s ^ "_out") - (Ast.ScopeName.get_info scope_name)) + (ScopeName.get_info scope_name)) in let scope_input_var = - Var.make (Marked.unmark (Ast.ScopeName.get_info scope_name) ^ "_in") + Var.make (Marked.unmark (ScopeName.get_info scope_name) ^ "_in") in let scope_input_struct_name = - Ast.StructName.fresh + StructName.fresh (Marked.map_under_mark (fun s -> s ^ "_in") - (Ast.ScopeName.get_info scope_name)) + (ScopeName.get_info scope_name)) in { scope_sig_local_vars = @@ -954,7 +954,7 @@ let translate_program (prgm : Ast.program) : (* the resulting expression is the list of definitions of all the scopes, ending with the top-level scope. *) let (scopes, decl_ctx) - : (Dcalc.Ast.untyped Dcalc.Ast.expr, Dcalc.Ast.untyped) Dcalc.Ast.scopes + : (untyped Dcalc.Ast.expr, untyped) scopes Bindlib.box * _ = List.fold_right @@ -967,21 +967,21 @@ let translate_program (prgm : Ast.program) : let decl_ctx = { decl_ctx with - Dcalc.Ast.ctx_structs = - Ast.StructMap.union + ctx_structs = + StructMap.union (fun _ _ -> assert false (* should not happen *)) - decl_ctx.Dcalc.Ast.ctx_structs scope_out_struct; + decl_ctx.ctx_structs scope_out_struct; } in let scope_next = Bindlib.bind_var dvar scopes in let new_scopes = Bindlib.box_apply2 (fun scope_body scope_next -> - Dcalc.Ast.ScopeDef { scope_name; scope_body; scope_next }) + ScopeDef { scope_name; scope_body; scope_next }) scope_body scope_next in new_scopes, decl_ctx) scope_ordering - (Bindlib.box Dcalc.Ast.Nil, decl_ctx) + (Bindlib.box Nil, decl_ctx) in { scopes = Bindlib.unbox scopes; decl_ctx }, types_ordering diff --git a/compiler/scopelang/scope_to_dcalc.mli b/compiler/scopelang/scope_to_dcalc.mli index 38f6228a..510fb5bf 100644 --- a/compiler/scopelang/scope_to_dcalc.mli +++ b/compiler/scopelang/scope_to_dcalc.mli @@ -17,7 +17,7 @@ (** Scope language to default calculus translator *) val translate_program : - Ast.program -> Dcalc.Ast.untyped Dcalc.Ast.program * Dependency.TVertex.t list + Ast.program -> Shared_ast.untyped Dcalc.Ast.program * Dependency.TVertex.t list (** Usage [translate_program p] returns a tuple [(new_program, types_list)] where [new_program] is the map of translated scopes. Finally, [types_list] is a list of all types (structs and enums) used in the program, correctly diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index 34cc1160..b0699cbb 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -68,7 +68,67 @@ let eraise e1 pos = Bindlib.box (ERaise e1, pos) let ecatch e1 exn e2 pos = Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2 -let translate_var v = Bindlib.copy_var v (fun x -> EVar x) (Bindlib.name_of v) +(* - Manipulation of marks - *) + +let no_mark (type m) : m mark -> m mark = function + | Untyped _ -> Untyped { pos = Pos.no_pos } + | Typed _ -> Typed { pos = Pos.no_pos; ty = Marked.mark Pos.no_pos TAny } + +let mark_pos (type m) (m : m mark) : Pos.t = + match m with Untyped { pos } | Typed { pos; _ } -> pos + +let pos (type m) (x : ('a, m) marked) : Pos.t = mark_pos (Marked.get_mark x) + +let ty (_, m) : marked_typ = match m with Typed { ty; _ } -> ty + +let with_ty (type m) (ty : marked_typ) (x : ('a, m) marked) : ('a, typed) marked + = + Marked.mark + (match Marked.get_mark x with + | Untyped { pos } -> Typed { pos; ty } + | Typed m -> Typed { m with ty }) + (Marked.unmark x) + +let map_mark + (type m) + (pos_f : Pos.t -> Pos.t) + (ty_f : marked_typ -> marked_typ) + (m : m mark) : m mark = + match m with + | Untyped { pos } -> Untyped { pos = pos_f pos } + | Typed { pos; ty } -> Typed { pos = pos_f pos; ty = ty_f ty } + +let map_mark2 + (type m) + (pos_f : Pos.t -> Pos.t -> Pos.t) + (ty_f : typed -> typed -> marked_typ) + (m1 : m mark) + (m2 : m mark) : m mark = + match m1, m2 with + | Untyped m1, Untyped m2 -> Untyped { pos = pos_f m1.pos m2.pos } + | Typed m1, Typed m2 -> Typed { pos = pos_f m1.pos m2.pos; ty = ty_f m1 m2 } + +let fold_marks + (type m) + (pos_f : Pos.t list -> Pos.t) + (ty_f : typed list -> marked_typ) + (ms : m mark list) : m mark = + match ms with + | [] -> invalid_arg "Dcalc.Ast.fold_mark" + | Untyped _ :: _ as ms -> + Untyped { pos = pos_f (List.map (function Untyped { pos } -> pos) ms) } + | Typed _ :: _ -> + Typed + { + pos = pos_f (List.map (function Typed { pos; _ } -> pos) ms); + ty = ty_f (List.map (function Typed m -> m) ms); + } + +let get_scope_body_mark scope_body = + match snd (Bindlib.unbind scope_body.scope_body_expr) with + | Result e | ScopeLet { scope_let_expr = e; _ } -> Marked.get_mark e + +(* - Traversal functions - *) let map (type a) @@ -81,10 +141,10 @@ let map | EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) m | EOp op -> Bindlib.box (EOp op, m) | EArray args -> earray (List.map (f ctx) args) m - | EVar v -> evar (translate_var v) m + | EVar v -> evar (Var.translate v) m | EAbs (binder, typs) -> let vars, body = Bindlib.unmbind binder in - eabs (Bindlib.bind_mvar (Array.map translate_var vars) (f ctx body)) typs m + eabs (Bindlib.bind_mvar (Array.map Var.translate vars) (f ctx body)) typs m | EIfThenElse (e1, e2, e3) -> eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m | ETuple (args, s) -> etuple (List.map (f ctx) args) s m @@ -179,3 +239,22 @@ let map_exprs_in_scopes ~f ~varf scopes = }) new_scope_body_expr new_next) ~init:(Bindlib.box Nil) scopes + +(* - *) + +(** See [Bindlib.box_term] documentation for why we are doing that. *) +let box e= + let rec id_t () e = map () ~f:id_t e in + id_t () e + +let untype e = map_marks ~f:(fun m -> Untyped { pos = mark_pos m }) e + +let untype_program prg = + { + prg with + scopes = + Bindlib.unbox + (map_exprs_in_scopes + ~f:(fun e -> untype e) + ~varf:Var.translate prg.scopes); + } diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index 21d2f973..1734505a 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -105,13 +105,70 @@ val eerroronempty : 't -> ('a, 't) marked_gexpr Bindlib.box -(** ---------- *) +val ecatch : + (lcalc, 't) marked_gexpr Bindlib.box -> + except -> + (lcalc, 't) marked_gexpr Bindlib.box -> + 't -> + (lcalc, 't) marked_gexpr Bindlib.box -val map : +val eraise : except -> 't -> (lcalc, 't) marked_gexpr Bindlib.box + +(** Manipulation of marks *) + +val no_mark : 'm mark -> 'm mark +val mark_pos : 'm mark -> Pos.t +val pos : ('a, 'm) marked -> Pos.t +val ty : ('a, typed) marked -> marked_typ +val with_ty : marked_typ -> ('a, 'm) marked -> ('a, typed) marked + +val map_mark : + (Pos.t -> Pos.t) -> (marked_typ -> marked_typ) -> 'm mark -> 'm mark + +val map_mark2 : + (Pos.t -> Pos.t -> Pos.t) -> + (typed -> typed -> marked_typ) -> + 'm mark -> + 'm mark -> + 'm mark + +val fold_marks : + (Pos.t list -> Pos.t) -> (typed list -> marked_typ) -> 'm mark list -> 'm mark + +val get_scope_body_mark : ('expr, 'm) scope_body -> 'm mark +val untype : ('a, 'm mark) marked_gexpr -> ('a, untyped mark) marked_gexpr Bindlib.box +val untype_program : (('a, 'm mark) gexpr Var.expr, 'm) program_generic -> (('a, untyped mark) gexpr Var.expr, untyped) program_generic + +(** {2 Handling of boxing} *) + +val box : ('a, 't) marked_gexpr -> ('a, 't) marked_gexpr Bindlib.box + + +(** {2 Traversal functions} *) + +val map: 'ctx -> f:('ctx -> ('a, 't1) marked_gexpr -> ('a, 't2) marked_gexpr Bindlib.box) -> (('a, 't1) gexpr, 't2) Marked.t -> ('a, 't2) marked_gexpr Bindlib.box +(** Flat (non-recursive) mapping on expressions. + + If you want to apply a map transform to an expression, you can save up + writing a painful match over all the cases of the AST. For instance, if you + want to remove all errors on empty, you can write + + {[ + let remove_error_empty = + let rec f () e = + match Marked.unmark e with + | ErrorOnEmpty e1 -> Expr.map () f e1 + | _ -> Expr.map () f e + in + f () e + ]} + + The first argument of map_expr is an optional context that you can carry + around during your map traversal. *) val map_top_down : f:(('a, 't1) marked_gexpr -> (('a, 't1) gexpr, 't2) Marked.t) -> diff --git a/compiler/surface/desugaring.ml b/compiler/surface/desugaring.ml index 108dad12..b3352fbe 100644 --- a/compiler/surface/desugaring.ml +++ b/compiler/surface/desugaring.ml @@ -16,6 +16,7 @@ the License. *) open Utils +open Shared_ast module Runtime = Runtime_ocaml.Runtime (** Translation from {!module: Surface.Ast} to {!module: Desugaring.Ast}. @@ -25,7 +26,7 @@ module Runtime = Runtime_ocaml.Runtime (** {1 Translating expressions} *) -let translate_op_kind (k : Ast.op_kind) : Dcalc.Ast.op_kind = +let translate_op_kind (k : Ast.op_kind) : op_kind = match k with | KInt -> KInt | KDec -> KRat @@ -33,7 +34,7 @@ let translate_op_kind (k : Ast.op_kind) : Dcalc.Ast.op_kind = | KDate -> KDate | KDuration -> KDuration -let translate_binop (op : Ast.binop) : Dcalc.Ast.binop = +let translate_binop (op : Ast.binop) : binop = match op with | And -> And | Or -> Or @@ -50,7 +51,7 @@ let translate_binop (op : Ast.binop) : Dcalc.Ast.binop = | Neq -> Neq | Concat -> Concat -let translate_unop (op : Ast.unop) : Dcalc.Ast.unop = +let translate_unop (op : Ast.unop) : unop = match op with Not -> Not | Minus l -> Minus (translate_op_kind l) (** The two modules below help performing operations on map with the {!type: @@ -65,7 +66,7 @@ module LiftEnumConstructorMap = Bindlib.Lift (Scopelang.Ast.EnumConstructorMap) let disambiguate_constructor (ctxt : Name_resolution.context) (constructor : (string Marked.pos option * string Marked.pos) list) - (pos : Pos.t) : Scopelang.Ast.EnumName.t * Scopelang.Ast.EnumConstructor.t = + (pos : Pos.t) : EnumName.t * EnumConstructor.t = let enum, constructor = match constructor with | [c] -> c @@ -86,7 +87,7 @@ let disambiguate_constructor in match enum with | None -> - if Scopelang.Ast.EnumMap.cardinal possible_c_uids > 1 then + if EnumMap.cardinal possible_c_uids > 1 then Errors.raise_spanned_error (Marked.get_mark constructor) "This constructor name is ambiguous, it can belong to %a. Disambiguate \ @@ -94,9 +95,9 @@ let disambiguate_constructor (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") (fun fmt (s_name, _) -> - Format.fprintf fmt "%a" Scopelang.Ast.EnumName.format_t s_name)) - (Scopelang.Ast.EnumMap.bindings possible_c_uids); - Scopelang.Ast.EnumMap.choose possible_c_uids + Format.fprintf fmt "%a" EnumName.format_t s_name)) + (EnumMap.bindings possible_c_uids); + EnumMap.choose possible_c_uids | Some enum -> ( try (* The path is fully qualified *) @@ -104,7 +105,7 @@ let disambiguate_constructor Desugared.Ast.IdentMap.find (Marked.unmark enum) ctxt.enum_idmap in try - let c_uid = Scopelang.Ast.EnumMap.find e_uid possible_c_uids in + let c_uid = EnumMap.find e_uid possible_c_uids in e_uid, c_uid with Not_found -> Errors.raise_spanned_error pos "Enum %s does not contain case %s" @@ -119,7 +120,7 @@ let disambiguate_constructor Translates [expr] into its desugared equivalent. [scope] is used to disambiguate the scope and subscopes variables than occur in the expression *) let rec translate_expr - (scope : Scopelang.Ast.ScopeName.t) + (scope : ScopeName.t) (inside_definition_of : Desugared.Ast.ScopeDef.t Marked.pos option) (ctxt : Name_resolution.context) ((expr, pos) : Ast.expression Marked.pos) : @@ -140,11 +141,11 @@ let rec translate_expr let cases = Scopelang.Ast.EnumConstructorMap.mapi (fun c_uid' tau -> - if Scopelang.Ast.EnumConstructor.compare c_uid c_uid' <> 0 then + if EnumConstructor.compare c_uid c_uid' <> 0 then let nop_var = Desugared.Ast.Var.make "_" in Bindlib.unbox (Desugared.Ast.make_abs [| nop_var |] - (Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool false), pos)) + (Bindlib.box (Desugared.Ast.ELit (LBool false), pos)) [tau] pos) else let ctxt, binding_var = @@ -153,7 +154,7 @@ let rec translate_expr let e2 = translate_expr scope inside_definition_of ctxt e2 in Bindlib.unbox (Desugared.Ast.make_abs [| binding_var |] e2 [tau] pos)) - (Scopelang.Ast.EnumMap.find enum_uid ctxt.enums) + (EnumMap.find enum_uid ctxt.enums) in Bindlib.box_apply (fun e1_sub -> Desugared.Ast.EMatch (e1_sub, enum_uid, cases), pos) @@ -167,7 +168,7 @@ let rec translate_expr let op_term = Marked.same_mark_as (Desugared.Ast.EOp - (Dcalc.Ast.Binop (translate_binop (Marked.unmark op)))) + (Binop (translate_binop (Marked.unmark op)))) op in Bindlib.box_apply2 @@ -176,7 +177,7 @@ let rec translate_expr | Unop (op, e) -> let op_term = Marked.same_mark_as - (Desugared.Ast.EOp (Dcalc.Ast.Unop (translate_unop (Marked.unmark op)))) + (Desugared.Ast.EOp (Unop (translate_unop (Marked.unmark op)))) op in Bindlib.box_apply @@ -186,38 +187,38 @@ let rec translate_expr let untyped_term = match l with | LNumber ((Int i, _), None) -> - Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_string i)) + Desugared.Ast.ELit (LInt (Runtime.integer_of_string i)) | LNumber ((Int i, _), Some (Percent, _)) -> Desugared.Ast.ELit - (Dcalc.Ast.LRat + (LRat Runtime.(decimal_of_string i /& decimal_of_string "100")) | LNumber ((Dec (i, f), _), None) -> Desugared.Ast.ELit - (Dcalc.Ast.LRat Runtime.(decimal_of_string (i ^ "." ^ f))) + (LRat Runtime.(decimal_of_string (i ^ "." ^ f))) | LNumber ((Dec (i, f), _), Some (Percent, _)) -> Desugared.Ast.ELit - (Dcalc.Ast.LRat + (LRat Runtime.( decimal_of_string (i ^ "." ^ f) /& decimal_of_string "100")) - | LBool b -> Desugared.Ast.ELit (Dcalc.Ast.LBool b) + | LBool b -> Desugared.Ast.ELit (LBool b) | LMoneyAmount i -> Desugared.Ast.ELit - (Dcalc.Ast.LMoney + (LMoney Runtime.( money_of_cents_integer ((integer_of_string i.money_amount_units *! integer_of_int 100) +! integer_of_string i.money_amount_cents))) | LNumber ((Int i, _), Some (Year, _)) -> Desugared.Ast.ELit - (Dcalc.Ast.LDuration + (LDuration (Runtime.duration_of_numbers (int_of_string i) 0 0)) | LNumber ((Int i, _), Some (Month, _)) -> Desugared.Ast.ELit - (Dcalc.Ast.LDuration + (LDuration (Runtime.duration_of_numbers 0 (int_of_string i) 0)) | LNumber ((Int i, _), Some (Day, _)) -> Desugared.Ast.ELit - (Dcalc.Ast.LDuration + (LDuration (Runtime.duration_of_numbers 0 0 (int_of_string i))) | LNumber ((Dec (_, _), _), Some ((Year | Month | Day), _)) -> Errors.raise_spanned_error pos @@ -230,7 +231,7 @@ let rec translate_expr Errors.raise_spanned_error pos "There is an error in this date: the day number is bigger than 31"; Desugared.Ast.ELit - (Dcalc.Ast.LDate + (LDate (try Runtime.date_of_numbers date.literal_date_year date.literal_date_month date.literal_date_day @@ -307,7 +308,7 @@ let rec translate_expr let subscope_uid : Scopelang.Ast.SubScopeName.t = Name_resolution.get_subscope_uid scope ctxt (Marked.same_mark_as y e) in - let subscope_real_uid : Scopelang.Ast.ScopeName.t = + let subscope_real_uid : ScopeName.t = Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes in let subscope_var_uid = @@ -330,19 +331,19 @@ let rec translate_expr match c with | None -> (* No constructor name was specified *) - if Scopelang.Ast.StructMap.cardinal x_possible_structs > 1 then + if StructMap.cardinal x_possible_structs > 1 then Errors.raise_spanned_error (Marked.get_mark x) "This struct field name is ambiguous, it can belong to %a. \ Disambiguate it by prefixing it with the struct name." (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") (fun fmt (s_name, _) -> - Format.fprintf fmt "%a" Scopelang.Ast.StructName.format_t + Format.fprintf fmt "%a" StructName.format_t s_name)) - (Scopelang.Ast.StructMap.bindings x_possible_structs) + (StructMap.bindings x_possible_structs) else let s_uid, f_uid = - Scopelang.Ast.StructMap.choose x_possible_structs + StructMap.choose x_possible_structs in Bindlib.box_apply (fun e -> Desugared.Ast.EStructAccess (e, f_uid, s_uid), pos) @@ -353,7 +354,7 @@ let rec translate_expr Desugared.Ast.IdentMap.find (Marked.unmark c_name) ctxt.struct_idmap in try - let f_uid = Scopelang.Ast.StructMap.find c_uid x_possible_structs in + let f_uid = StructMap.find c_uid x_possible_structs in Bindlib.box_apply (fun e -> Desugared.Ast.EStructAccess (e, f_uid, c_uid), pos) e @@ -391,7 +392,7 @@ let rec translate_expr (fun s_fields (f_name, f_e) -> let f_uid = try - Scopelang.Ast.StructMap.find s_uid + StructMap.find s_uid (Desugared.Ast.IdentMap.find (Marked.unmark f_name) ctxt.field_idmap) with Not_found -> @@ -408,19 +409,19 @@ let rec translate_expr None, Marked.get_mark (Bindlib.unbox e_field); ] "The field %a has been defined twice:" - Scopelang.Ast.StructFieldName.format_t f_uid); + StructFieldName.format_t f_uid); let f_e = translate_expr scope inside_definition_of ctxt f_e in Scopelang.Ast.StructFieldMap.add f_uid f_e s_fields) Scopelang.Ast.StructFieldMap.empty fields in - let expected_s_fields = Scopelang.Ast.StructMap.find s_uid ctxt.structs in + let expected_s_fields = StructMap.find s_uid ctxt.structs in Scopelang.Ast.StructFieldMap.iter (fun expected_f _ -> if not (Scopelang.Ast.StructFieldMap.mem expected_f s_fields) then Errors.raise_spanned_error pos "Missing field for structure %a: \"%a\"" - Scopelang.Ast.StructName.format_t s_uid - Scopelang.Ast.StructFieldName.format_t expected_f) + StructName.format_t s_uid + StructFieldName.format_t expected_f) expected_s_fields; Bindlib.box_apply @@ -443,7 +444,7 @@ let rec translate_expr | None -> if (* No constructor name was specified *) - Scopelang.Ast.EnumMap.cardinal possible_c_uids > 1 + EnumMap.cardinal possible_c_uids > 1 then Errors.raise_spanned_error (Marked.get_mark constructor) @@ -452,10 +453,10 @@ let rec translate_expr (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") (fun fmt (s_name, _) -> - Format.fprintf fmt "%a" Scopelang.Ast.EnumName.format_t s_name)) - (Scopelang.Ast.EnumMap.bindings possible_c_uids) + Format.fprintf fmt "%a" EnumName.format_t s_name)) + (EnumMap.bindings possible_c_uids) else - let e_uid, c_uid = Scopelang.Ast.EnumMap.choose possible_c_uids in + let e_uid, c_uid = EnumMap.choose possible_c_uids in let payload = Option.map (translate_expr scope inside_definition_of ctxt) payload in @@ -465,7 +466,7 @@ let rec translate_expr ( (match payload with | Some e' -> e' | None -> - ( Desugared.Ast.ELit Dcalc.Ast.LUnit, + ( Desugared.Ast.ELit LUnit, Marked.get_mark constructor )), c_uid, e_uid ), @@ -478,7 +479,7 @@ let rec translate_expr Desugared.Ast.IdentMap.find (Marked.unmark enum) ctxt.enum_idmap in try - let c_uid = Scopelang.Ast.EnumMap.find e_uid possible_c_uids in + let c_uid = EnumMap.find e_uid possible_c_uids in let payload = Option.map (translate_expr scope inside_definition_of ctxt) payload in @@ -488,7 +489,7 @@ let rec translate_expr ( (match payload with | Some e' -> e' | None -> - ( Desugared.Ast.ELit Dcalc.Ast.LUnit, + ( Desugared.Ast.ELit LUnit, Marked.get_mark constructor )), c_uid, e_uid ), @@ -530,11 +531,11 @@ let rec translate_expr (Desugared.Ast.make_abs [| nop_var |] (Bindlib.box ( Desugared.Ast.ELit - (Dcalc.Ast.LBool - (Scopelang.Ast.EnumConstructor.compare c_uid c_uid' = 0)), + (LBool + (EnumConstructor.compare c_uid c_uid' = 0)), pos )) [tau] pos)) - (Scopelang.Ast.EnumMap.find enum_uid ctxt.enums) + (EnumMap.find enum_uid ctxt.enums) in Bindlib.box_apply (fun e -> Desugared.Ast.EMatch (e, enum_uid, cases), pos) @@ -563,8 +564,8 @@ let rec translate_expr ( Desugared.Ast.EApp ( ( Desugared.Ast.EOp (match op' with - | Ast.Map -> Dcalc.Ast.Binop Dcalc.Ast.Map - | Ast.Filter -> Dcalc.Ast.Binop Dcalc.Ast.Filter + | Ast.Map -> Binop Map + | Ast.Filter -> Binop Filter | _ -> assert false (* should not happen *)), pos ), [f_pred; collection] ), @@ -583,11 +584,11 @@ let rec translate_expr in let op_kind = match pred_typ with - | Ast.Integer -> Dcalc.Ast.KInt - | Ast.Decimal -> Dcalc.Ast.KRat - | Ast.Money -> Dcalc.Ast.KMoney - | Ast.Duration -> Dcalc.Ast.KDuration - | Ast.Date -> Dcalc.Ast.KDate + | Ast.Integer -> KInt + | Ast.Decimal -> KRat + | Ast.Money -> KMoney + | Ast.Duration -> KDuration + | Ast.Date -> KDate | _ -> Errors.raise_spanned_error pos "It is impossible to compute the arg-%s of two values of type %a" @@ -595,7 +596,7 @@ let rec translate_expr Print.format_primitive_typ pred_typ in let cmp_op = - if max_or_min then Dcalc.Ast.Gt op_kind else Dcalc.Ast.Lt op_kind + if max_or_min then Gt op_kind else Lt op_kind in let f_pred = Desugared.Ast.make_abs [| param |] @@ -619,7 +620,7 @@ let rec translate_expr (fun acc_var_e item_var_e f_pred_var_e -> ( Desugared.Ast.EIfThenElse ( ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Binop cmp_op), pos_op'), + ( (Desugared.Ast.EOp (Binop cmp_op), pos_op'), [ Desugared.Ast.EApp (f_pred_var_e, [acc_var_e]), pos; Desugared.Ast.EApp (f_pred_var_e, [item_var_e]), pos; @@ -639,7 +640,7 @@ let rec translate_expr Bindlib.box_apply3 (fun fold_f collection init -> ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), + ( (Desugared.Ast.EOp (Ternop Fold), pos), [fold_f; init; collection] ), pos )) fold_f collection init @@ -656,28 +657,28 @@ let rec translate_expr assert false (* should not happen *) | Ast.Exists -> Bindlib.box - (Desugared.Ast.ELit (Dcalc.Ast.LBool false), Marked.get_mark op') + (Desugared.Ast.ELit (LBool false), Marked.get_mark op') | Ast.Forall -> Bindlib.box - (Desugared.Ast.ELit (Dcalc.Ast.LBool true), Marked.get_mark op') + (Desugared.Ast.ELit (LBool true), Marked.get_mark op') | Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> Bindlib.box - ( Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_int 0)), + ( Desugared.Ast.ELit (LInt (Runtime.integer_of_int 0)), Marked.get_mark op' ) | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> Bindlib.box - ( Desugared.Ast.ELit (Dcalc.Ast.LRat (Runtime.decimal_of_string "0")), + ( Desugared.Ast.ELit (LRat (Runtime.decimal_of_string "0")), Marked.get_mark op' ) | Ast.Aggregate (Ast.AggregateSum Ast.Money) -> Bindlib.box ( Desugared.Ast.ELit - (Dcalc.Ast.LMoney + (LMoney (Runtime.money_of_cents_integer (Runtime.integer_of_int 0))), Marked.get_mark op' ) | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> Bindlib.box ( Desugared.Ast.ELit - (Dcalc.Ast.LDuration (Runtime.duration_of_numbers 0 0 0)), + (LDuration (Runtime.duration_of_numbers 0 0 0)), Marked.get_mark op' ) | Ast.Aggregate (Ast.AggregateSum t) -> Errors.raise_spanned_error pos @@ -686,24 +687,24 @@ let rec translate_expr | Ast.Aggregate (Ast.AggregateExtremum (_, _, init)) -> rec_helper init | Ast.Aggregate Ast.AggregateCount -> Bindlib.box - ( Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_int 0)), + ( Desugared.Ast.ELit (LInt (Runtime.integer_of_int 0)), Marked.get_mark op' ) in let acc_var = Desugared.Ast.Var.make "acc" in let acc = Desugared.Ast.make_var (acc_var, Marked.get_mark param') in let f_body = - let make_body (op : Dcalc.Ast.binop) = + let make_body (op : binop) = Bindlib.box_apply2 (fun predicate acc -> ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Binop op), Marked.get_mark op'), + ( (Desugared.Ast.EOp (Binop op), Marked.get_mark op'), [acc; predicate] ), pos )) (translate_expr scope inside_definition_of ctxt predicate) acc in let make_extr_body - (cmp_op : Dcalc.Ast.binop) + (cmp_op : binop) (t : Scopelang.Ast.typ Marked.pos) = let tmp_var = Desugared.Ast.Var.make "tmp" in let tmp = Desugared.Ast.make_var (tmp_var, Marked.get_mark param') in @@ -713,7 +714,7 @@ let rec translate_expr (fun acc tmp -> ( Desugared.Ast.EIfThenElse ( ( Desugared.Ast.EApp - ( ( Desugared.Ast.EOp (Dcalc.Ast.Binop cmp_op), + ( ( Desugared.Ast.EOp (Binop cmp_op), Marked.get_mark op' ), [acc; tmp] ), pos ), @@ -725,35 +726,35 @@ let rec translate_expr match Marked.unmark op' with | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> assert false (* should not happen *) - | Ast.Exists -> make_body Dcalc.Ast.Or - | Ast.Forall -> make_body Dcalc.Ast.And + | Ast.Exists -> make_body Or + | Ast.Forall -> make_body And | Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> - make_body (Dcalc.Ast.Add Dcalc.Ast.KInt) + make_body (Add KInt) | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> - make_body (Dcalc.Ast.Add Dcalc.Ast.KRat) + make_body (Add KRat) | Ast.Aggregate (Ast.AggregateSum Ast.Money) -> - make_body (Dcalc.Ast.Add Dcalc.Ast.KMoney) + make_body (Add KMoney) | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> - make_body (Dcalc.Ast.Add Dcalc.Ast.KDuration) + make_body (Add KDuration) | Ast.Aggregate (Ast.AggregateSum _) -> assert false (* should not happen *) | Ast.Aggregate (Ast.AggregateExtremum (max_or_min, t, _)) -> let op_kind, typ = match t with - | Ast.Integer -> Dcalc.Ast.KInt, (Scopelang.Ast.TLit TInt, pos) - | Ast.Decimal -> Dcalc.Ast.KRat, (Scopelang.Ast.TLit TRat, pos) - | Ast.Money -> Dcalc.Ast.KMoney, (Scopelang.Ast.TLit TMoney, pos) + | Ast.Integer -> KInt, (Scopelang.Ast.TLit TInt, pos) + | Ast.Decimal -> KRat, (Scopelang.Ast.TLit TRat, pos) + | Ast.Money -> KMoney, (Scopelang.Ast.TLit TMoney, pos) | Ast.Duration -> - Dcalc.Ast.KDuration, (Scopelang.Ast.TLit TDuration, pos) - | Ast.Date -> Dcalc.Ast.KDate, (Scopelang.Ast.TLit TDate, pos) + KDuration, (Scopelang.Ast.TLit TDuration, pos) + | Ast.Date -> KDate, (Scopelang.Ast.TLit TDate, pos) | _ -> Errors.raise_spanned_error pos - "It is impossible to compute the %s of two values of type %a" + "ssible to compute the %s of two values of type %a" (if max_or_min then "max" else "min") Print.format_primitive_typ t in let cmp_op = - if max_or_min then Dcalc.Ast.Gt op_kind else Dcalc.Ast.Lt op_kind + if max_or_min then Gt op_kind else Lt op_kind in make_extr_body cmp_op typ | Ast.Aggregate Ast.AggregateCount -> @@ -763,12 +764,12 @@ let rec translate_expr ( predicate, ( Desugared.Ast.EApp ( ( Desugared.Ast.EOp - (Dcalc.Ast.Binop (Dcalc.Ast.Add Dcalc.Ast.KInt)), + (Binop (Add KInt)), Marked.get_mark op' ), [ acc; ( Desugared.Ast.ELit - (Dcalc.Ast.LInt (Runtime.integer_of_int 1)), + (LInt (Runtime.integer_of_int 1)), Marked.get_mark predicate ); ] ), pos ), @@ -778,7 +779,7 @@ let rec translate_expr acc in let f = - let make_f (t : Dcalc.Ast.typ_lit) = + let make_f (t : typ_lit) = Bindlib.box_apply (fun binder -> ( Desugared.Ast.EAbs @@ -796,29 +797,29 @@ let rec translate_expr match Marked.unmark op' with | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> assert false (* should not happen *) - | Ast.Exists -> make_f Dcalc.Ast.TBool - | Ast.Forall -> make_f Dcalc.Ast.TBool + | Ast.Exists -> make_f TBool + | Ast.Forall -> make_f TBool | Ast.Aggregate (Ast.AggregateSum Ast.Integer) | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Integer, _)) -> - make_f Dcalc.Ast.TInt + make_f TInt | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Decimal, _)) -> - make_f Dcalc.Ast.TRat + make_f TRat | Ast.Aggregate (Ast.AggregateSum Ast.Money) | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Money, _)) -> - make_f Dcalc.Ast.TMoney + make_f TMoney | Ast.Aggregate (Ast.AggregateSum Ast.Duration) | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Duration, _)) -> - make_f Dcalc.Ast.TDuration + make_f TDuration | Ast.Aggregate (Ast.AggregateSum _) | Ast.Aggregate (Ast.AggregateExtremum _) -> assert false (* should not happen *) - | Ast.Aggregate Ast.AggregateCount -> make_f Dcalc.Ast.TInt + | Ast.Aggregate Ast.AggregateCount -> make_f TInt in Bindlib.box_apply3 (fun f collection init -> ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), + ( (Desugared.Ast.EOp (Ternop Fold), pos), [f; init; collection] ), pos )) f collection init @@ -826,17 +827,17 @@ let rec translate_expr let param_var = Desugared.Ast.Var.make "collection_member" in let param = Desugared.Ast.make_var (param_var, pos) in let collection = rec_helper collection in - let init = Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool false), pos) in + let init = Bindlib.box (Desugared.Ast.ELit (LBool false), pos) in let acc_var = Desugared.Ast.Var.make "acc" in let acc = Desugared.Ast.make_var (acc_var, pos) in let f_body = Bindlib.box_apply3 (fun member acc param -> ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.Or), pos), + ( (Desugared.Ast.EOp (Binop Or), pos), [ ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.Eq), pos), + ( (Desugared.Ast.EOp (Binop Eq), pos), [member; param] ), pos ); acc; @@ -851,7 +852,7 @@ let rec translate_expr ( Desugared.Ast.EAbs ( binder, [ - Scopelang.Ast.TLit Dcalc.Ast.TBool, pos; + Scopelang.Ast.TLit TBool, pos; Scopelang.Ast.TAny, pos; ] ), pos )) @@ -860,42 +861,42 @@ let rec translate_expr Bindlib.box_apply3 (fun f collection init -> ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), + ( (Desugared.Ast.EOp (Ternop Fold), pos), [f; init; collection] ), pos )) f collection init | Builtin IntToDec -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.IntToRat), pos) + Bindlib.box (Desugared.Ast.EOp (Unop IntToRat), pos) | Builtin MoneyToDec -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.MoneyToRat), pos) + Bindlib.box (Desugared.Ast.EOp (Unop MoneyToRat), pos) | Builtin DecToMoney -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.RatToMoney), pos) + Bindlib.box (Desugared.Ast.EOp (Unop RatToMoney), pos) | Builtin Cardinal -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.Length), pos) + Bindlib.box (Desugared.Ast.EOp (Unop Length), pos) | Builtin GetDay -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetDay), pos) + Bindlib.box (Desugared.Ast.EOp (Unop GetDay), pos) | Builtin GetMonth -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetMonth), pos) + Bindlib.box (Desugared.Ast.EOp (Unop GetMonth), pos) | Builtin GetYear -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetYear), pos) + Bindlib.box (Desugared.Ast.EOp (Unop GetYear), pos) | Builtin FirstDayOfMonth -> Bindlib.box - (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.FirstDayOfMonth), pos) + (Desugared.Ast.EOp (Unop FirstDayOfMonth), pos) | Builtin LastDayOfMonth -> Bindlib.box - (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.LastDayOfMonth), pos) + (Desugared.Ast.EOp (Unop LastDayOfMonth), pos) | Builtin RoundMoney -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.RoundMoney), pos) + Bindlib.box (Desugared.Ast.EOp (Unop RoundMoney), pos) | Builtin RoundDecimal -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.RoundDecimal), pos) + Bindlib.box (Desugared.Ast.EOp (Unop RoundDecimal), pos) and disambiguate_match_and_build_expression - (scope : Scopelang.Ast.ScopeName.t) + (scope : ScopeName.t) (inside_definition_of : Desugared.Ast.ScopeDef.t Marked.pos option) (ctxt : Name_resolution.context) (cases : Ast.match_case Marked.pos list) : Desugared.Ast.expr Marked.pos Bindlib.box Scopelang.Ast.EnumConstructorMap.t - * Scopelang.Ast.EnumName.t = + * EnumName.t = let create_var = function | None -> ctxt, Desugared.Ast.Var.make "_" | Some param -> @@ -903,8 +904,8 @@ and disambiguate_match_and_build_expression ctxt, param_var in let bind_case_body - (c_uid : Dcalc.Ast.EnumConstructor.t) - (e_uid : Dcalc.Ast.EnumName.t) + (c_uid : EnumConstructor.t) + (e_uid : EnumName.t) (ctxt : Name_resolution.context) (case_body : ('a * Pos.t) Bindlib.box) (e_binder : @@ -917,7 +918,7 @@ and disambiguate_match_and_build_expression ( e_binder, [ Scopelang.Ast.EnumConstructorMap.find c_uid - (Scopelang.Ast.EnumMap.find e_uid ctxt.Name_resolution.enums); + (EnumMap.find e_uid ctxt.Name_resolution.enums); ] )) case_body) e_binder case_body @@ -940,8 +941,8 @@ and disambiguate_match_and_build_expression (Marked.get_mark case.Ast.match_case_pattern) "This case matches a constructor of enumeration %a but previous \ case were matching constructors of enumeration %a" - Scopelang.Ast.EnumName.format_t e_uid - Scopelang.Ast.EnumName.format_t e_uid' + EnumName.format_t e_uid + EnumName.format_t e_uid' in (match Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d with | None -> () @@ -952,7 +953,7 @@ and disambiguate_match_and_build_expression None, Marked.get_mark (Bindlib.unbox e_case); ] "The constructor %a has been matched twice:" - Scopelang.Ast.EnumConstructor.format_t c_uid); + EnumConstructor.format_t c_uid); let ctxt, param_var = create_var (Option.map Marked.unmark binding) in let case_body = translate_expr scope inside_definition_of ctxt case.Ast.match_case_expr @@ -983,7 +984,7 @@ and disambiguate_match_and_build_expression | Some e_uid -> if curr_index < nb_cases - 1 then raise_wildcard_not_last_case_err (); let missing_constructors = - Scopelang.Ast.EnumMap.find e_uid ctxt.Name_resolution.enums + EnumMap.find e_uid ctxt.Name_resolution.enums |> Scopelang.Ast.EnumConstructorMap.filter_map (fun c_uid _ -> match Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d @@ -995,7 +996,7 @@ and disambiguate_match_and_build_expression Errors.format_spanned_warning case_pos "Unreachable match case, all constructors of the enumeration %a \ are already specified" - Scopelang.Ast.EnumName.format_t e_uid; + EnumName.format_t e_uid; (* The current used strategy is to replace the wildcard branch: match foo with | Case1 x -> x @@ -1048,7 +1049,7 @@ let merge_conditions match precond, cond with | Some precond, Some cond -> let op_term = - ( Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.And), + ( Desugared.Ast.EOp (Binop And), Marked.get_mark (Bindlib.unbox cond) ) in Bindlib.box_apply2 @@ -1061,13 +1062,13 @@ let merge_conditions precond | None, Some cond -> cond | None, None -> - Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool true), default_pos) + Bindlib.box (Desugared.Ast.ELit (LBool true), default_pos) (** Translates a surface definition into condition into a desugared {!type: Desugared.Ast.rule} *) let process_default (ctxt : Name_resolution.context) - (scope : Scopelang.Ast.ScopeName.t) + (scope : ScopeName.t) (def_key : Desugared.Ast.ScopeDef.t Marked.pos) (rule_id : Desugared.Ast.RuleName.t) (param_uid : Desugared.Ast.Var.t Marked.pos option) @@ -1111,7 +1112,7 @@ let process_default disambiguation *) let process_def (precond : Desugared.Ast.expr Marked.pos Bindlib.box option) - (scope_uid : Scopelang.Ast.ScopeName.t) + (scope_uid : ScopeName.t) (ctxt : Name_resolution.context) (prgm : Desugared.Ast.program) (def : Ast.definition) : Desugared.Ast.program = @@ -1200,7 +1201,7 @@ let process_def (** Translates a {!type: Surface.Ast.rule} from the surface language *) let process_rule (precond : Desugared.Ast.expr Marked.pos Bindlib.box option) - (scope : Scopelang.Ast.ScopeName.t) + (scope : ScopeName.t) (ctxt : Name_resolution.context) (prgm : Desugared.Ast.program) (rule : Ast.rule) : Desugared.Ast.program = @@ -1210,7 +1211,7 @@ let process_rule (** Translates assertions *) let process_assert (precond : Desugared.Ast.expr Marked.pos Bindlib.box option) - (scope_uid : Scopelang.Ast.ScopeName.t) + (scope_uid : ScopeName.t) (ctxt : Name_resolution.context) (prgm : Desugared.Ast.program) (ass : Ast.assertion) : Desugared.Ast.program = @@ -1236,7 +1237,7 @@ let process_assert ( Desugared.Ast.EIfThenElse ( precond, ass, - Marked.same_mark_as (Desugared.Ast.ELit (Dcalc.Ast.LBool true)) + Marked.same_mark_as (Desugared.Ast.ELit (LBool true)) precond ), Marked.get_mark precond )) precond ass @@ -1254,7 +1255,7 @@ let process_assert (** Translates a surface definition, rule or assertion *) let process_scope_use_item (precond : Ast.expression Marked.pos option) - (scope : Scopelang.Ast.ScopeName.t) + (scope : ScopeName.t) (ctxt : Name_resolution.context) (prgm : Desugared.Ast.program) (item : Ast.scope_use_item Marked.pos) : Desugared.Ast.program = @@ -1270,7 +1271,7 @@ let process_scope_use_item (* If this is an unlabeled exception, ensures that it has a unique default definition *) let check_unlabeled_exception - (scope : Scopelang.Ast.ScopeName.t) + (scope : ScopeName.t) (ctxt : Name_resolution.context) (item : Ast.scope_use_item Marked.pos) : unit = let scope_ctxt = Scopelang.Ast.ScopeMap.find scope ctxt.scopes in @@ -1353,10 +1354,10 @@ let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) : let empty_prgm = { Desugared.Ast.program_structs = - Scopelang.Ast.StructMap.map Scopelang.Ast.StructFieldMap.bindings + StructMap.map Scopelang.Ast.StructFieldMap.bindings ctxt.Name_resolution.structs; Desugared.Ast.program_enums = - Scopelang.Ast.EnumMap.map Scopelang.Ast.EnumConstructorMap.bindings + EnumMap.map Scopelang.Ast.EnumConstructorMap.bindings ctxt.Name_resolution.enums; Desugared.Ast.program_scopes = Scopelang.Ast.ScopeMap.mapi diff --git a/compiler/surface/name_resolution.ml b/compiler/surface/name_resolution.ml index 579fc97a..c5497ce2 100644 --- a/compiler/surface/name_resolution.ml +++ b/compiler/surface/name_resolution.ml @@ -19,6 +19,7 @@ lexical scopes into account *) open Utils +open Shared_ast (** {1 Name resolution context} *) @@ -41,7 +42,7 @@ type scope_context = { (** What is the default rule to refer to for unnamed exceptions, if any *) sub_scopes_idmap : Scopelang.Ast.SubScopeName.t Desugared.Ast.IdentMap.t; (** Sub-scopes variables *) - sub_scopes : Scopelang.Ast.ScopeName.t Scopelang.Ast.SubScopeMap.t; + sub_scopes : ScopeName.t Scopelang.Ast.SubScopeMap.t; (** To what scope sub-scopes refer to? *) } (** Inside a scope, we distinguish between the variables and the subscopes. *) @@ -64,27 +65,27 @@ type context = { local_var_idmap : Desugared.Ast.Var.t Desugared.Ast.IdentMap.t; (** Inside a definition, local variables can be introduced by functions arguments or pattern matching *) - scope_idmap : Scopelang.Ast.ScopeName.t Desugared.Ast.IdentMap.t; + scope_idmap : ScopeName.t Desugared.Ast.IdentMap.t; (** The names of the scopes *) - struct_idmap : Scopelang.Ast.StructName.t Desugared.Ast.IdentMap.t; + struct_idmap : StructName.t Desugared.Ast.IdentMap.t; (** The names of the structs *) field_idmap : - Scopelang.Ast.StructFieldName.t Scopelang.Ast.StructMap.t + StructFieldName.t StructMap.t Desugared.Ast.IdentMap.t; (** The names of the struct fields. Names of fields can be shared between different structs *) - enum_idmap : Scopelang.Ast.EnumName.t Desugared.Ast.IdentMap.t; + enum_idmap : EnumName.t Desugared.Ast.IdentMap.t; (** The names of the enums *) constructor_idmap : - Scopelang.Ast.EnumConstructor.t Scopelang.Ast.EnumMap.t + EnumConstructor.t EnumMap.t Desugared.Ast.IdentMap.t; (** The names of the enum constructors. Constructor names can be shared between different enums *) scopes : scope_context Scopelang.Ast.ScopeMap.t; (** For each scope, its context *) - structs : struct_context Scopelang.Ast.StructMap.t; + structs : struct_context StructMap.t; (** For each struct, its context *) - enums : enum_context Scopelang.Ast.EnumMap.t; + enums : enum_context EnumMap.t; (** For each enum, its context *) var_typs : var_sig Desugared.Ast.ScopeVarMap.t; (** The signatures of each scope variable declared *) @@ -120,7 +121,7 @@ let get_var_io (ctxt : context) (uid : Desugared.Ast.ScopeVar.t) : (** Get the variable uid inside the scope given in argument *) let get_var_uid - (scope_uid : Scopelang.Ast.ScopeName.t) + (scope_uid : ScopeName.t) (ctxt : context) ((x, pos) : ident Marked.pos) : Desugared.Ast.ScopeVar.t = let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in @@ -128,13 +129,13 @@ let get_var_uid | None -> raise_unknown_identifier (Format.asprintf "for a variable of scope %a" - Scopelang.Ast.ScopeName.format_t scope_uid) + ScopeName.format_t scope_uid) (x, pos) | Some uid -> uid (** Get the subscope uid inside the scope given in argument *) let get_subscope_uid - (scope_uid : Scopelang.Ast.ScopeName.t) + (scope_uid : ScopeName.t) (ctxt : context) ((y, pos) : ident Marked.pos) : Scopelang.Ast.SubScopeName.t = let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in @@ -145,7 +146,7 @@ let get_subscope_uid (** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the subscopes of [scope_uid]. *) let is_subscope_uid - (scope_uid : Scopelang.Ast.ScopeName.t) + (scope_uid : ScopeName.t) (ctxt : context) (y : ident) : bool = let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in @@ -155,7 +156,7 @@ let is_subscope_uid let belongs_to (ctxt : context) (uid : Desugared.Ast.ScopeVar.t) - (scope_uid : Scopelang.Ast.ScopeName.t) : bool = + (scope_uid : ScopeName.t) : bool = let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in Desugared.Ast.IdentMap.exists (fun _ var_uid -> Desugared.Ast.ScopeVar.compare uid var_uid = 0) @@ -183,7 +184,7 @@ let is_def_cond (ctxt : context) (def : Desugared.Ast.ScopeDef.t) : bool = (** Process a subscope declaration *) let process_subscope_decl - (scope : Scopelang.Ast.ScopeName.t) + (scope : ScopeName.t) (ctxt : context) (decl : Ast.scope_decl_context_scope) : context = let name, name_pos = decl.scope_decl_context_scope_name in @@ -277,7 +278,7 @@ let process_type (ctxt : context) ((typ, typ_pos) : Ast.typ Marked.pos) : (** Process data declaration *) let process_data_decl - (scope : Scopelang.Ast.ScopeName.t) + (scope : ScopeName.t) (ctxt : context) (decl : Ast.scope_decl_context_data) : context = (* First check the type of the context data *) @@ -330,7 +331,7 @@ let process_data_decl (** Process an item declaration *) let process_item_decl - (scope : Scopelang.Ast.ScopeName.t) + (scope : ScopeName.t) (ctxt : context) (decl : Ast.scope_decl_context_item) : context = match decl with @@ -372,7 +373,7 @@ let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context = List.fold_left (fun ctxt (fdecl, _) -> let f_uid = - Scopelang.Ast.StructFieldName.fresh fdecl.Ast.struct_decl_field_name + StructFieldName.fresh fdecl.Ast.struct_decl_field_name in let ctxt = { @@ -382,16 +383,16 @@ let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context = (Marked.unmark fdecl.Ast.struct_decl_field_name) (fun uids -> match uids with - | None -> Some (Scopelang.Ast.StructMap.singleton s_uid f_uid) + | None -> Some (StructMap.singleton s_uid f_uid) | Some uids -> - Some (Scopelang.Ast.StructMap.add s_uid f_uid uids)) + Some (StructMap.add s_uid f_uid uids)) ctxt.field_idmap; } in { ctxt with structs = - Scopelang.Ast.StructMap.update s_uid + StructMap.update s_uid (fun fields -> match fields with | None -> @@ -421,7 +422,7 @@ let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context = List.fold_left (fun ctxt (cdecl, cdecl_pos) -> let c_uid = - Scopelang.Ast.EnumConstructor.fresh cdecl.Ast.enum_decl_case_name + EnumConstructor.fresh cdecl.Ast.enum_decl_case_name in let ctxt = { @@ -431,15 +432,15 @@ let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context = (Marked.unmark cdecl.Ast.enum_decl_case_name) (fun uids -> match uids with - | None -> Some (Scopelang.Ast.EnumMap.singleton e_uid c_uid) - | Some uids -> Some (Scopelang.Ast.EnumMap.add e_uid c_uid uids)) + | None -> Some (EnumMap.singleton e_uid c_uid) + | Some uids -> Some (EnumMap.add e_uid c_uid uids)) ctxt.constructor_idmap; } in { ctxt with enums = - Scopelang.Ast.EnumMap.update e_uid + EnumMap.update e_uid (fun cases -> let typ = match cdecl.Ast.enum_decl_case_typ with @@ -475,10 +476,10 @@ let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) : match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with | Some use -> raise_already_defined_error - (Scopelang.Ast.ScopeName.get_info use) + (ScopeName.get_info use) name pos "scope" | None -> - let scope_uid = Scopelang.Ast.ScopeName.fresh (name, pos) in + let scope_uid = ScopeName.fresh (name, pos) in { ctxt with scope_idmap = Desugared.Ast.IdentMap.add name scope_uid ctxt.scope_idmap; @@ -497,10 +498,10 @@ let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) : match Desugared.Ast.IdentMap.find_opt name ctxt.struct_idmap with | Some use -> raise_already_defined_error - (Scopelang.Ast.StructName.get_info use) + (StructName.get_info use) name pos "struct" | None -> - let s_uid = Scopelang.Ast.StructName.fresh sdecl.struct_decl_name in + let s_uid = StructName.fresh sdecl.struct_decl_name in { ctxt with struct_idmap = @@ -513,10 +514,10 @@ let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) : match Desugared.Ast.IdentMap.find_opt name ctxt.enum_idmap with | Some use -> raise_already_defined_error - (Scopelang.Ast.EnumName.get_info use) + (EnumName.get_info use) name pos "enum" | None -> - let e_uid = Scopelang.Ast.EnumName.fresh edecl.enum_decl_name in + let e_uid = EnumName.fresh edecl.enum_decl_name in { ctxt with @@ -561,7 +562,7 @@ let rec process_law_structure let get_def_key (name : Ast.qident) (state : Ast.ident Marked.pos option) - (scope_uid : Scopelang.Ast.ScopeName.t) + (scope_uid : ScopeName.t) (ctxt : context) (default_pos : Pos.t) : Desugared.Ast.ScopeDef.t = let scope_ctxt = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in @@ -603,7 +604,7 @@ let get_def_key let subscope_uid : Scopelang.Ast.SubScopeName.t = get_subscope_uid scope_uid ctxt y in - let subscope_real_uid : Scopelang.Ast.ScopeName.t = + let subscope_real_uid : ScopeName.t = Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes in let x_uid = get_var_uid subscope_real_uid ctxt x in @@ -616,7 +617,7 @@ let get_def_key let process_definition (ctxt : context) - (s_name : Scopelang.Ast.ScopeName.t) + (s_name : ScopeName.t) (d : Ast.definition) : context = (* We update the definition context inside the big context *) { @@ -725,7 +726,7 @@ let process_definition } let process_scope_use_item - (s_name : Scopelang.Ast.ScopeName.t) + (s_name : ScopeName.t) (ctxt : context) (sitem : Ast.scope_use_item Marked.pos) : context = match Marked.unmark sitem with @@ -764,10 +765,10 @@ let form_context (prgm : Ast.program) : context = scope_idmap = Desugared.Ast.IdentMap.empty; scopes = Scopelang.Ast.ScopeMap.empty; var_typs = Desugared.Ast.ScopeVarMap.empty; - structs = Scopelang.Ast.StructMap.empty; + structs = StructMap.empty; struct_idmap = Desugared.Ast.IdentMap.empty; field_idmap = Desugared.Ast.IdentMap.empty; - enums = Scopelang.Ast.EnumMap.empty; + enums = EnumMap.empty; enum_idmap = Desugared.Ast.IdentMap.empty; constructor_idmap = Desugared.Ast.IdentMap.empty; } diff --git a/compiler/surface/name_resolution.mli b/compiler/surface/name_resolution.mli index 21520a35..b68d5ceb 100644 --- a/compiler/surface/name_resolution.mli +++ b/compiler/surface/name_resolution.mli @@ -19,6 +19,7 @@ lexical scopes into account *) open Utils +open Shared_ast (** {1 Name resolution context} *) @@ -41,7 +42,7 @@ type scope_context = { (** What is the default rule to refer to for unnamed exceptions, if any *) sub_scopes_idmap : Scopelang.Ast.SubScopeName.t Desugared.Ast.IdentMap.t; (** Sub-scopes variables *) - sub_scopes : Scopelang.Ast.ScopeName.t Scopelang.Ast.SubScopeMap.t; + sub_scopes : ScopeName.t Scopelang.Ast.SubScopeMap.t; (** To what scope sub-scopes refer to? *) } (** Inside a scope, we distinguish between the variables and the subscopes. *) @@ -64,27 +65,27 @@ type context = { local_var_idmap : Desugared.Ast.Var.t Desugared.Ast.IdentMap.t; (** Inside a definition, local variables can be introduced by functions arguments or pattern matching *) - scope_idmap : Scopelang.Ast.ScopeName.t Desugared.Ast.IdentMap.t; + scope_idmap : ScopeName.t Desugared.Ast.IdentMap.t; (** The names of the scopes *) - struct_idmap : Scopelang.Ast.StructName.t Desugared.Ast.IdentMap.t; + struct_idmap : StructName.t Desugared.Ast.IdentMap.t; (** The names of the structs *) field_idmap : - Scopelang.Ast.StructFieldName.t Scopelang.Ast.StructMap.t + StructFieldName.t StructMap.t Desugared.Ast.IdentMap.t; (** The names of the struct fields. Names of fields can be shared between different structs *) - enum_idmap : Scopelang.Ast.EnumName.t Desugared.Ast.IdentMap.t; + enum_idmap : EnumName.t Desugared.Ast.IdentMap.t; (** The names of the enums *) constructor_idmap : - Scopelang.Ast.EnumConstructor.t Scopelang.Ast.EnumMap.t + EnumConstructor.t EnumMap.t Desugared.Ast.IdentMap.t; (** The names of the enum constructors. Constructor names can be shared between different enums *) scopes : scope_context Scopelang.Ast.ScopeMap.t; (** For each scope, its context *) - structs : struct_context Scopelang.Ast.StructMap.t; + structs : struct_context StructMap.t; (** For each struct, its context *) - enums : enum_context Scopelang.Ast.EnumMap.t; + enums : enum_context EnumMap.t; (** For each enum, its context *) var_typs : var_sig Desugared.Ast.ScopeVarMap.t; (** The signatures of each scope variable declared *) @@ -110,25 +111,25 @@ val get_var_io : context -> Desugared.Ast.ScopeVar.t -> Ast.scope_decl_context_io val get_var_uid : - Scopelang.Ast.ScopeName.t -> + ScopeName.t -> context -> ident Marked.pos -> Desugared.Ast.ScopeVar.t (** Get the variable uid inside the scope given in argument *) val get_subscope_uid : - Scopelang.Ast.ScopeName.t -> + ScopeName.t -> context -> ident Marked.pos -> Scopelang.Ast.SubScopeName.t (** Get the subscope uid inside the scope given in argument *) -val is_subscope_uid : Scopelang.Ast.ScopeName.t -> context -> ident -> bool +val is_subscope_uid : ScopeName.t -> context -> ident -> bool (** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the subscopes of [scope_uid]. *) val belongs_to : - context -> Desugared.Ast.ScopeVar.t -> Scopelang.Ast.ScopeName.t -> bool + context -> Desugared.Ast.ScopeVar.t -> ScopeName.t -> bool (** Checks if the var_uid belongs to the scope scope_uid *) val get_def_typ : context -> Desugared.Ast.ScopeDef.t -> typ Marked.pos @@ -143,7 +144,7 @@ val add_def_local_var : context -> ident -> context * Desugared.Ast.Var.t val get_def_key : Ast.qident -> Ast.ident Marked.pos option -> - Scopelang.Ast.ScopeName.t -> + ScopeName.t -> context -> Pos.t -> Desugared.Ast.ScopeDef.t diff --git a/compiler/verification/conditions.ml b/compiler/verification/conditions.ml index 67299376..09e30d94 100644 --- a/compiler/verification/conditions.ml +++ b/compiler/verification/conditions.ml @@ -16,6 +16,7 @@ the License. *) open Utils +open Shared_ast open Dcalc open Ast @@ -92,7 +93,7 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : typed marked_expr) | ErrorOnEmpty d -> d (* input subscope variables and non-input scope variable *) | _ -> - Errors.raise_spanned_error (pos e) + Errors.raise_spanned_error (Expr.pos e) "Internal error: this expression does not have the structure expected by \ the VC generator:\n\ %a" @@ -382,7 +383,7 @@ let rec generate_verification_conditions_scopes | ScopeDef scope_def -> let is_selected_scope = match s with - | Some s when Dcalc.Ast.ScopeName.compare s scope_def.scope_name = 0 -> + | Some s when ScopeName.compare s scope_def.scope_name = 0 -> true | None -> true | _ -> false @@ -416,7 +417,7 @@ let rec generate_verification_conditions_scopes let generate_verification_conditions (p : 'm program) - (s : Dcalc.Ast.ScopeName.t option) : verification_condition list = + (s : ScopeName.t option) : verification_condition list = let vcs = generate_verification_conditions_scopes p.decl_ctx p.scopes s in (* We sort this list by scope name and then variable name to ensure consistent output for testing*) diff --git a/compiler/verification/conditions.mli b/compiler/verification/conditions.mli index 924e0676..33ada9ca 100644 --- a/compiler/verification/conditions.mli +++ b/compiler/verification/conditions.mli @@ -29,21 +29,21 @@ type verification_condition_kind = a conflict error *) type verification_condition = { - vc_guard : Dcalc.Ast.typed Dcalc.Ast.marked_expr; + vc_guard : typed Dcalc.Ast.marked_expr; (** This expression should have type [bool]*) vc_kind : verification_condition_kind; - vc_scope : Dcalc.Ast.ScopeName.t; + vc_scope : ScopeName.t; vc_variable : typed Dcalc.Ast.var Marked.pos; vc_free_vars_typ : - (typed Dcalc.Ast.expr, Dcalc.Ast.typ Marked.pos) Var.Map.t; + (typed Dcalc.Ast.expr, typ Marked.pos) Var.Map.t; (** Types of the locally free variables in [vc_guard]. The types of other free variables linked to scope variables can be obtained with [Dcalc.Ast.variable_types]. *) } val generate_verification_conditions : - Dcalc.Ast.typed Dcalc.Ast.program -> - Dcalc.Ast.ScopeName.t option -> + typed Dcalc.Ast.program -> + ScopeName.t option -> verification_condition list (** [generate_verification_conditions p None] will generate the verification conditions for all the variables of all the scopes of the program [p], while diff --git a/compiler/verification/io.ml b/compiler/verification/io.ml index e94bd5d6..f6b20ea7 100644 --- a/compiler/verification/io.ml +++ b/compiler/verification/io.ml @@ -16,6 +16,7 @@ the License. *) open Utils +open Shared_ast open Dcalc.Ast module type Backend = sig @@ -73,7 +74,7 @@ module type BackendIO = sig string val encode_and_check_vc : - Dcalc.Ast.decl_ctx -> + decl_ctx -> Conditions.verification_condition * vc_encoding_result -> unit end @@ -161,7 +162,7 @@ module MakeBackendIO (B : Backend) = struct let vc, z3_vc = vc in Cli.debug_print "For this variable:\n%s\n" - (Pos.retrieve_loc_text (pos vc.Conditions.vc_guard)); + (Pos.retrieve_loc_text (Expr.pos vc.Conditions.vc_guard)); Cli.debug_format "This verification condition was generated for %a:@\n%a" (Cli.format_with_style [ANSITerminal.yellow]) (match vc.vc_kind with diff --git a/compiler/verification/io.mli b/compiler/verification/io.mli index 6a37e07f..7283b2ad 100644 --- a/compiler/verification/io.mli +++ b/compiler/verification/io.mli @@ -26,8 +26,8 @@ module type Backend = sig type backend_context val make_context : - Dcalc.Ast.decl_ctx -> - (typed Dcalc.Ast.expr, Dcalc.Ast.typ Utils.Marked.pos) Var.Map.t -> + decl_ctx -> + (typed Dcalc.Ast.expr, typ Utils.Marked.pos) Var.Map.t -> backend_context type vc_encoding @@ -53,8 +53,8 @@ module type BackendIO = sig type backend_context val make_context : - Dcalc.Ast.decl_ctx -> - (typed Dcalc.Ast.expr, Dcalc.Ast.typ Utils.Marked.pos) Var.Map.t -> + decl_ctx -> + (typed Dcalc.Ast.expr, typ Utils.Marked.pos) Var.Map.t -> backend_context type vc_encoding @@ -79,7 +79,7 @@ module type BackendIO = sig string val encode_and_check_vc : - Dcalc.Ast.decl_ctx -> + decl_ctx -> Conditions.verification_condition * vc_encoding_result -> unit end diff --git a/compiler/verification/solver.ml b/compiler/verification/solver.ml index c47bf8d8..6056ab1e 100644 --- a/compiler/verification/solver.ml +++ b/compiler/verification/solver.ml @@ -20,7 +20,7 @@ open Dcalc.Ast expressions [vcs] corresponding to verification conditions that must be discharged by Z3, and attempts to solve them **) let solve_vc - (decl_ctx : decl_ctx) + (decl_ctx : Shared_ast.decl_ctx) (vcs : Conditions.verification_condition list) : unit = (* Right now we only use the Z3 backend but the functorial interface should make it easy to mix and match different proof backends. *) diff --git a/compiler/verification/solver.mli b/compiler/verification/solver.mli index 8c972cb1..5ea951cb 100644 --- a/compiler/verification/solver.mli +++ b/compiler/verification/solver.mli @@ -17,4 +17,4 @@ (** Solves verification conditions using various proof backends *) val solve_vc : - Dcalc.Ast.decl_ctx -> Conditions.verification_condition list -> unit + Shared_ast.decl_ctx -> Conditions.verification_condition list -> unit diff --git a/compiler/verification/z3backend.real.ml b/compiler/verification/z3backend.real.ml index cb3b2a68..1c79a6e0 100644 --- a/compiler/verification/z3backend.real.ml +++ b/compiler/verification/z3backend.real.ml @@ -15,6 +15,7 @@ the License. *) open Utils +open Shared_ast open Dcalc open Ast open Z3 @@ -428,7 +429,7 @@ let rec translate_op (Print.format_expr ctx.ctx_decl) ( EApp ( (EOp op, Untyped { pos = Pos.no_pos }), - List.map (fun arg -> Bindlib.unbox (untype_expr arg)) args ), + List.map (fun arg -> Bindlib.unbox (Shared_ast.Expr.untype arg)) args ), Untyped { pos = Pos.no_pos } )) in @@ -520,7 +521,7 @@ let rec translate_op ( EApp ( (EOp op, Untyped { pos = Pos.no_pos }), List.map - (fun arg -> arg |> untype_expr |> Bindlib.unbox) + (fun arg -> arg |> Shared_ast.Expr.untype |> Bindlib.unbox) args ), Untyped { pos = Pos.no_pos } )) in @@ -572,7 +573,7 @@ let rec translate_op ( EApp ( (EOp op, Untyped { pos = Pos.no_pos }), List.map - (fun arg -> arg |> untype_expr |> Bindlib.unbox) + (fun arg -> arg |> Shared_ast.Expr.untype |> Bindlib.unbox) args ), Untyped { pos = Pos.no_pos } )) in diff --git a/dune-project b/dune-project index 1b7867ee..45993555 100644 --- a/dune-project +++ b/dune-project @@ -1,4 +1,4 @@ -(lang dune 2.8) +(lang dune 3.0) (name catala) diff --git a/french_law/ocaml/bench.ml b/french_law/ocaml/bench.ml index b2734d96..44e73e97 100644 --- a/french_law/ocaml/bench.ml +++ b/french_law/ocaml/bench.ml @@ -115,7 +115,7 @@ let run_test () = exit (-1) | Runtime.AssertionFailed _ -> () -let bench = +let _bench = Random.init (int_of_float (Unix.time ())); let num_iter = 10000 in let _ = diff --git a/french_law/ocaml/dune b/french_law/ocaml/dune index 734bae77..10350009 100644 --- a/french_law/ocaml/dune +++ b/french_law/ocaml/dune @@ -12,7 +12,7 @@ (preprocess (pps js_of_ocaml-ppx)) (js_of_ocaml - (flags --disable=shortvar --opt 3)) + (flags :standard --disable=shortvar --opt 3)) ; We need to disable shortvar because ; otherwise Webpack wrongly minifies ; the library and it gives bugs.