diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index d03df46c..35aa454d 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -155,6 +155,116 @@ type expr = | EArray of expr Pos.marked list | ErrorOnEmpty of expr Pos.marked +module Expr = struct + type t = expr + + (** Syntactic comparison, up to locations and alpha-renaming *) + let rec compare e1 e2 = + let rec list_compare cmp l1 l2 = + (* List.compare is available from OCaml 4.12 on *) + match l1, l2 with + | [], [] -> 0 + | [], _ :: _ -> -1 + | _ :: _, [] -> 1 + | a1 :: l1, a2 :: l2 -> + let c = cmp a1 a2 in + if c <> 0 then c else list_compare cmp l1 l2 + in + match e1, e2 with + | ELocation _, ELocation _ -> 0 + | EVar (v1, _), EVar (v2, _) -> Bindlib.compare_vars v1 v2 + | EStruct (name1, field_map1), EStruct (name2, field_map2) -> ( + match Scopelang.Ast.StructName.compare name1 name2 with + | 0 -> + Scopelang.Ast.StructFieldMap.compare + (Pos.compare_marked compare) + field_map1 field_map2 + | n -> n) + | ( EStructAccess ((e1, _), field_name1, struct_name1), + EStructAccess ((e2, _), field_name2, struct_name2) ) -> ( + match compare e1 e2 with + | 0 -> ( + match Scopelang.Ast.StructFieldName.compare field_name1 field_name2 with + | 0 -> Scopelang.Ast.StructName.compare struct_name1 struct_name2 + | n -> n) + | n -> n) + | EEnumInj ((e1, _), cstr1, name1), EEnumInj ((e2, _), cstr2, name2) -> ( + match compare e1 e2 with + | 0 -> ( + match Scopelang.Ast.EnumName.compare name1 name2 with + | 0 -> Scopelang.Ast.EnumConstructor.compare cstr1 cstr2 + | n -> n) + | n -> n) + | EMatch ((e1, _), name1, emap1), EMatch ((e2, _), name2, emap2) -> ( + match compare e1 e2 with + | 0 -> ( + match Scopelang.Ast.EnumName.compare name1 name2 with + | 0 -> + Scopelang.Ast.EnumConstructorMap.compare + (Pos.compare_marked compare) + emap1 emap2 + | n -> n) + | n -> n) + | ELit l1, ELit l2 -> Stdlib.compare l1 l2 + | EAbs ((binder1, _), typs1), EAbs ((binder2, _), typs2) -> ( + match + list_compare (Pos.compare_marked Scopelang.Ast.Typ.compare) typs1 typs2 + with + | 0 -> + let _, (e1, _), (e2, _) = Bindlib.unmbind2 binder1 binder2 in + compare e1 e2 + | n -> n) + | EApp ((f1, _), args1), EApp ((f2, _), args2) -> ( + match compare f1 f2 with + | 0 -> list_compare (fun (x1, _) (x2, _) -> compare x1 x2) args1 args2 + | n -> n) + | EOp op1, EOp op2 -> Stdlib.compare op1 op2 + | ( EDefault (exs1, (just1, _), (cons1, _)), + EDefault (exs2, (just2, _), (cons2, _)) ) -> ( + match compare just1 just2 with + | 0 -> ( + match compare cons1 cons2 with + | 0 -> list_compare (Pos.compare_marked compare) exs1 exs2 + | n -> n) + | n -> n) + | ( EIfThenElse ((i1, _), (t1, _), (e1, _)), + EIfThenElse ((i2, _), (t2, _), (e2, _)) ) -> ( + match compare i1 i2 with + | 0 -> ( match compare t1 t2 with 0 -> compare e1 e2 | n -> n) + | n -> n) + | EArray a1, EArray a2 -> + list_compare (fun (e1, _) (e2, _) -> compare e1 e2) a1 a2 + | ErrorOnEmpty (e1, _), ErrorOnEmpty (e2, _) -> compare e1 e2 + | ELocation _, _ -> -1 + | _, ELocation _ -> 1 + | EVar _, _ -> -1 + | _, EVar _ -> 1 + | EStruct _, _ -> -1 + | _, EStruct _ -> 1 + | EStructAccess _, _ -> -1 + | _, EStructAccess _ -> 1 + | EEnumInj _, _ -> -1 + | _, EEnumInj _ -> 1 + | EMatch _, _ -> -1 + | _, EMatch _ -> 1 + | ELit _, _ -> -1 + | _, ELit _ -> 1 + | EAbs _, _ -> -1 + | _, EAbs _ -> 1 + | EApp _, _ -> -1 + | _, EApp _ -> 1 + | EOp _, _ -> -1 + | _, EOp _ -> 1 + | EDefault _, _ -> -1 + | _, EDefault _ -> 1 + | EIfThenElse _, _ -> -1 + | _, EIfThenElse _ -> 1 + | EArray _, _ -> -1 + | _, EArray _ -> 1 +end + +module ExprMap = Map.Make (Expr) + module Var = struct type t = expr Bindlib.var @@ -176,6 +286,41 @@ type rule = { rule_exception_to_rules : RuleSet.t Pos.marked; } +module Rule = struct + type t = rule + + (** Structural equality (otherwise, you should just compare the [rule_id] + fields) *) + let compare r1 r2 = + match r1.rule_parameter, r2.rule_parameter with + | None, None -> ( + let j1, _ = Bindlib.unbox r1.rule_just in + let j2, _ = Bindlib.unbox r2.rule_just in + match Expr.compare j1 j2 with + | 0 -> + let c1, _ = Bindlib.unbox r1.rule_cons in + let c2, _ = Bindlib.unbox r2.rule_cons in + Expr.compare c1 c2 + | n -> n) + | Some (v1, (t1, _)), Some (v2, (t2, _)) -> ( + match Scopelang.Ast.Typ.compare t1 t2 with + | 0 -> ( + let open Bindlib in + let b1 = unbox (bind_var v1 r1.rule_just) in + let b2 = unbox (bind_var v2 r2.rule_just) in + let _, (j1, _), (j2, _) = unbind2 b1 b2 in + match Expr.compare j1 j2 with + | 0 -> + let b1 = unbox (bind_var v1 r1.rule_cons) in + let b2 = unbox (bind_var v2 r2.rule_cons) in + let _, (c1, _), (c2, _) = unbind2 b1 b2 in + Expr.compare c1 c2 + | n -> n) + | n -> n) + | None, Some _ -> -1 + | Some _, None -> 1 +end + let empty_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule = diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index be85be80..f5f1503e 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -91,6 +91,8 @@ type expr = | EArray of expr Pos.marked list | ErrorOnEmpty of expr Pos.marked +module ExprMap : Map.S with type key = expr + (** {2 Variable helpers} *) module Var : sig @@ -137,6 +139,8 @@ type rule = { rule_exception_to_rules : RuleSet.t Pos.marked; } +module Rule : Set.OrderedType with type t = rule + val empty_rule : Pos.t -> Scopelang.Ast.typ Pos.marked option -> rule val always_false_rule : Pos.t -> Scopelang.Ast.typ Pos.marked option -> rule diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index 13dce2d3..751aed86 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -78,6 +78,30 @@ type typ = | TArray of typ | TAny +module Typ = struct + type t = typ + + let rec compare ty1 ty2 = + match ty1, ty2 with + | TLit l1, TLit l2 -> Stdlib.compare l1 l2 + | TStruct n1, TStruct n2 -> StructName.compare n1 n2 + | TEnum en1, TEnum en2 -> EnumName.compare en1 en2 + | TArrow ((a1, _), (b1, _)), TArrow ((a2, _), (b2, _)) -> ( + match compare a1 a2 with 0 -> compare b1 b2 | n -> n) + | TArray t1, TArray t2 -> compare t1 t2 + | TAny, TAny -> 0 + | TLit _, _ -> -1 + | _, TLit _ -> 1 + | TStruct _, _ -> -1 + | _, TStruct _ -> 1 + | TEnum _, _ -> -1 + | _, TEnum _ -> 1 + | TArrow _, _ -> -1 + | _, TArrow _ -> 1 + | TArray _, _ -> -1 + | _, TArray _ -> 1 +end + type expr = | ELocation of location | EVar of expr Bindlib.var Pos.marked @@ -96,6 +120,111 @@ type expr = | EArray of expr Pos.marked list | ErrorOnEmpty of expr Pos.marked +module Expr = struct + type t = expr + + let rec compare e1 e2 = + let rec list_compare cmp l1 l2 = + (* List.compare is available from OCaml 4.12 on *) + match l1, l2 with + | [], [] -> 0 + | [], _ :: _ -> -1 + | _ :: _, [] -> 1 + | a1 :: l1, a2 :: l2 -> + let c = cmp a1 a2 in + if c <> 0 then c else list_compare cmp l1 l2 + in + match e1, e2 with + | ELocation _, ELocation _ -> 0 + | EVar (v1, _), EVar (v2, _) -> Bindlib.compare_vars v1 v2 + | EStruct (name1, field_map1), EStruct (name2, field_map2) -> ( + match StructName.compare name1 name2 with + | 0 -> + StructFieldMap.compare + (Pos.compare_marked compare) + field_map1 field_map2 + | n -> n) + | ( EStructAccess ((e1, _), field_name1, struct_name1), + EStructAccess ((e2, _), field_name2, struct_name2) ) -> ( + match compare e1 e2 with + | 0 -> ( + match StructFieldName.compare field_name1 field_name2 with + | 0 -> StructName.compare struct_name1 struct_name2 + | n -> n) + | n -> n) + | EEnumInj ((e1, _), cstr1, name1), EEnumInj ((e2, _), cstr2, name2) -> ( + match compare e1 e2 with + | 0 -> ( + match EnumName.compare name1 name2 with + | 0 -> EnumConstructor.compare cstr1 cstr2 + | n -> n) + | n -> n) + | EMatch ((e1, _), name1, emap1), EMatch ((e2, _), name2, emap2) -> ( + match compare e1 e2 with + | 0 -> ( + match EnumName.compare name1 name2 with + | 0 -> + EnumConstructorMap.compare (Pos.compare_marked compare) emap1 emap2 + | n -> n) + | n -> n) + | ELit l1, ELit l2 -> Stdlib.compare l1 l2 + | EAbs ((binder1, _), typs1), EAbs ((binder2, _), typs2) -> ( + match list_compare (Pos.compare_marked Typ.compare) typs1 typs2 with + | 0 -> + let _, (e1, _), (e2, _) = Bindlib.unmbind2 binder1 binder2 in + compare e1 e2 + | n -> n) + | EApp ((f1, _), args1), EApp ((f2, _), args2) -> ( + match compare f1 f2 with + | 0 -> list_compare (fun (x1, _) (x2, _) -> compare x1 x2) args1 args2 + | n -> n) + | EOp op1, EOp op2 -> Stdlib.compare op1 op2 + | ( EDefault (exs1, (just1, _), (cons1, _)), + EDefault (exs2, (just2, _), (cons2, _)) ) -> ( + match compare just1 just2 with + | 0 -> ( + match compare cons1 cons2 with + | 0 -> list_compare (Pos.compare_marked compare) exs1 exs2 + | n -> n) + | n -> n) + | ( EIfThenElse ((i1, _), (t1, _), (e1, _)), + EIfThenElse ((i2, _), (t2, _), (e2, _)) ) -> ( + match compare i1 i2 with + | 0 -> ( match compare t1 t2 with 0 -> compare e1 e2 | n -> n) + | n -> n) + | EArray a1, EArray a2 -> + list_compare (fun (e1, _) (e2, _) -> compare e1 e2) a1 a2 + | ErrorOnEmpty (e1, _), ErrorOnEmpty (e2, _) -> compare e1 e2 + | ELocation _, _ -> -1 + | _, ELocation _ -> 1 + | EVar _, _ -> -1 + | _, EVar _ -> 1 + | EStruct _, _ -> -1 + | _, EStruct _ -> 1 + | EStructAccess _, _ -> -1 + | _, EStructAccess _ -> 1 + | EEnumInj _, _ -> -1 + | _, EEnumInj _ -> 1 + | EMatch _, _ -> -1 + | _, EMatch _ -> 1 + | ELit _, _ -> -1 + | _, ELit _ -> 1 + | EAbs _, _ -> -1 + | _, EAbs _ -> 1 + | EApp _, _ -> -1 + | _, EApp _ -> 1 + | EOp _, _ -> -1 + | _, EOp _ -> 1 + | EDefault _, _ -> -1 + | _, EDefault _ -> 1 + | EIfThenElse _, _ -> -1 + | _, EIfThenElse _ -> 1 + | EArray _, _ -> -1 + | _, EArray _ -> 1 +end + +module ExprMap = Map.Make (Expr) + let rec locations_used (e : expr Pos.marked) : LocationSet.t = match Pos.unmark e with | ELocation l -> LocationSet.singleton (l, Pos.get_position e) diff --git a/compiler/scopelang/ast.mli b/compiler/scopelang/ast.mli index f75a9a6f..6add803a 100644 --- a/compiler/scopelang/ast.mli +++ b/compiler/scopelang/ast.mli @@ -66,6 +66,8 @@ type typ = | TArray of typ | TAny +module Typ : Set.OrderedType with type t = typ + (** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib} library, based on higher-order abstract syntax*) type expr = @@ -86,6 +88,9 @@ type expr = | EArray of expr Pos.marked list | ErrorOnEmpty of expr Pos.marked +module Expr : Set.OrderedType with type t = expr +module ExprMap : Map.S with type key = expr + val locations_used : expr Pos.marked -> LocationSet.t (** This type characterizes the three levels of visibility for a given scope diff --git a/compiler/utils/pos.ml b/compiler/utils/pos.ml index fcf86e98..9a682ea2 100644 --- a/compiler/utils/pos.ml +++ b/compiler/utils/pos.ml @@ -222,6 +222,12 @@ let get_position ((_, x) : 'a marked) : t = x let map_under_mark (f : 'a -> 'b) ((x, y) : 'a marked) : 'b marked = f x, y let same_pos_as (x : 'a) ((_, y) : 'b marked) : 'a marked = x, y +let compare_marked + (cmp : 'a -> 'a -> int) + ((x, _) : 'a marked) + ((y, _) : 'a marked) : int = + cmp x y + let unmark_option (x : 'a marked option) : 'a option = match x with Some x -> Some (unmark x) | None -> None diff --git a/compiler/utils/pos.mli b/compiler/utils/pos.mli index ade01d25..dad08092 100644 --- a/compiler/utils/pos.mli +++ b/compiler/utils/pos.mli @@ -68,6 +68,9 @@ val map_under_mark : ('a -> 'b) -> 'a marked -> 'b marked val same_pos_as : 'a -> 'b marked -> 'a marked val unmark_option : 'a marked option -> 'a option +val compare_marked : ('a -> 'a -> int) -> 'a marked -> 'a marked -> int +(** Compares two marked values {b ignoring positions} *) + (** Visitors *) class ['self] marked_map :