Compiler: add a few helper comparison functions

Adds syntactic comparison for some expressions, etc., allowing in
particular to detect syntactically equal expressions. Positions are,
obviously, ignored.
This commit is contained in:
Louis Gesbert 2022-05-25 14:41:04 +02:00
parent 30297b27b8
commit 2d41f53300
6 changed files with 292 additions and 0 deletions

View File

@ -155,6 +155,116 @@ type expr =
| EArray of expr Pos.marked list
| ErrorOnEmpty of expr Pos.marked
module Expr = struct
type t = expr
(** Syntactic comparison, up to locations and alpha-renaming *)
let rec compare e1 e2 =
let rec list_compare cmp l1 l2 =
(* List.compare is available from OCaml 4.12 on *)
match l1, l2 with
| [], [] -> 0
| [], _ :: _ -> -1
| _ :: _, [] -> 1
| a1 :: l1, a2 :: l2 ->
let c = cmp a1 a2 in
if c <> 0 then c else list_compare cmp l1 l2
in
match e1, e2 with
| ELocation _, ELocation _ -> 0
| EVar (v1, _), EVar (v2, _) -> Bindlib.compare_vars v1 v2
| EStruct (name1, field_map1), EStruct (name2, field_map2) -> (
match Scopelang.Ast.StructName.compare name1 name2 with
| 0 ->
Scopelang.Ast.StructFieldMap.compare
(Pos.compare_marked compare)
field_map1 field_map2
| n -> n)
| ( EStructAccess ((e1, _), field_name1, struct_name1),
EStructAccess ((e2, _), field_name2, struct_name2) ) -> (
match compare e1 e2 with
| 0 -> (
match Scopelang.Ast.StructFieldName.compare field_name1 field_name2 with
| 0 -> Scopelang.Ast.StructName.compare struct_name1 struct_name2
| n -> n)
| n -> n)
| EEnumInj ((e1, _), cstr1, name1), EEnumInj ((e2, _), cstr2, name2) -> (
match compare e1 e2 with
| 0 -> (
match Scopelang.Ast.EnumName.compare name1 name2 with
| 0 -> Scopelang.Ast.EnumConstructor.compare cstr1 cstr2
| n -> n)
| n -> n)
| EMatch ((e1, _), name1, emap1), EMatch ((e2, _), name2, emap2) -> (
match compare e1 e2 with
| 0 -> (
match Scopelang.Ast.EnumName.compare name1 name2 with
| 0 ->
Scopelang.Ast.EnumConstructorMap.compare
(Pos.compare_marked compare)
emap1 emap2
| n -> n)
| n -> n)
| ELit l1, ELit l2 -> Stdlib.compare l1 l2
| EAbs ((binder1, _), typs1), EAbs ((binder2, _), typs2) -> (
match
list_compare (Pos.compare_marked Scopelang.Ast.Typ.compare) typs1 typs2
with
| 0 ->
let _, (e1, _), (e2, _) = Bindlib.unmbind2 binder1 binder2 in
compare e1 e2
| n -> n)
| EApp ((f1, _), args1), EApp ((f2, _), args2) -> (
match compare f1 f2 with
| 0 -> list_compare (fun (x1, _) (x2, _) -> compare x1 x2) args1 args2
| n -> n)
| EOp op1, EOp op2 -> Stdlib.compare op1 op2
| ( EDefault (exs1, (just1, _), (cons1, _)),
EDefault (exs2, (just2, _), (cons2, _)) ) -> (
match compare just1 just2 with
| 0 -> (
match compare cons1 cons2 with
| 0 -> list_compare (Pos.compare_marked compare) exs1 exs2
| n -> n)
| n -> n)
| ( EIfThenElse ((i1, _), (t1, _), (e1, _)),
EIfThenElse ((i2, _), (t2, _), (e2, _)) ) -> (
match compare i1 i2 with
| 0 -> ( match compare t1 t2 with 0 -> compare e1 e2 | n -> n)
| n -> n)
| EArray a1, EArray a2 ->
list_compare (fun (e1, _) (e2, _) -> compare e1 e2) a1 a2
| ErrorOnEmpty (e1, _), ErrorOnEmpty (e2, _) -> compare e1 e2
| ELocation _, _ -> -1
| _, ELocation _ -> 1
| EVar _, _ -> -1
| _, EVar _ -> 1
| EStruct _, _ -> -1
| _, EStruct _ -> 1
| EStructAccess _, _ -> -1
| _, EStructAccess _ -> 1
| EEnumInj _, _ -> -1
| _, EEnumInj _ -> 1
| EMatch _, _ -> -1
| _, EMatch _ -> 1
| ELit _, _ -> -1
| _, ELit _ -> 1
| EAbs _, _ -> -1
| _, EAbs _ -> 1
| EApp _, _ -> -1
| _, EApp _ -> 1
| EOp _, _ -> -1
| _, EOp _ -> 1
| EDefault _, _ -> -1
| _, EDefault _ -> 1
| EIfThenElse _, _ -> -1
| _, EIfThenElse _ -> 1
| EArray _, _ -> -1
| _, EArray _ -> 1
end
module ExprMap = Map.Make (Expr)
module Var = struct
type t = expr Bindlib.var
@ -176,6 +286,41 @@ type rule = {
rule_exception_to_rules : RuleSet.t Pos.marked;
}
module Rule = struct
type t = rule
(** Structural equality (otherwise, you should just compare the [rule_id]
fields) *)
let compare r1 r2 =
match r1.rule_parameter, r2.rule_parameter with
| None, None -> (
let j1, _ = Bindlib.unbox r1.rule_just in
let j2, _ = Bindlib.unbox r2.rule_just in
match Expr.compare j1 j2 with
| 0 ->
let c1, _ = Bindlib.unbox r1.rule_cons in
let c2, _ = Bindlib.unbox r2.rule_cons in
Expr.compare c1 c2
| n -> n)
| Some (v1, (t1, _)), Some (v2, (t2, _)) -> (
match Scopelang.Ast.Typ.compare t1 t2 with
| 0 -> (
let open Bindlib in
let b1 = unbox (bind_var v1 r1.rule_just) in
let b2 = unbox (bind_var v2 r2.rule_just) in
let _, (j1, _), (j2, _) = unbind2 b1 b2 in
match Expr.compare j1 j2 with
| 0 ->
let b1 = unbox (bind_var v1 r1.rule_cons) in
let b2 = unbox (bind_var v2 r2.rule_cons) in
let _, (c1, _), (c2, _) = unbind2 b1 b2 in
Expr.compare c1 c2
| n -> n)
| n -> n)
| None, Some _ -> -1
| Some _, None -> 1
end
let empty_rule
(pos : Pos.t)
(have_parameter : Scopelang.Ast.typ Pos.marked option) : rule =

View File

@ -91,6 +91,8 @@ type expr =
| EArray of expr Pos.marked list
| ErrorOnEmpty of expr Pos.marked
module ExprMap : Map.S with type key = expr
(** {2 Variable helpers} *)
module Var : sig
@ -137,6 +139,8 @@ type rule = {
rule_exception_to_rules : RuleSet.t Pos.marked;
}
module Rule : Set.OrderedType with type t = rule
val empty_rule : Pos.t -> Scopelang.Ast.typ Pos.marked option -> rule
val always_false_rule : Pos.t -> Scopelang.Ast.typ Pos.marked option -> rule

View File

@ -78,6 +78,30 @@ type typ =
| TArray of typ
| TAny
module Typ = struct
type t = typ
let rec compare ty1 ty2 =
match ty1, ty2 with
| TLit l1, TLit l2 -> Stdlib.compare l1 l2
| TStruct n1, TStruct n2 -> StructName.compare n1 n2
| TEnum en1, TEnum en2 -> EnumName.compare en1 en2
| TArrow ((a1, _), (b1, _)), TArrow ((a2, _), (b2, _)) -> (
match compare a1 a2 with 0 -> compare b1 b2 | n -> n)
| TArray t1, TArray t2 -> compare t1 t2
| TAny, TAny -> 0
| TLit _, _ -> -1
| _, TLit _ -> 1
| TStruct _, _ -> -1
| _, TStruct _ -> 1
| TEnum _, _ -> -1
| _, TEnum _ -> 1
| TArrow _, _ -> -1
| _, TArrow _ -> 1
| TArray _, _ -> -1
| _, TArray _ -> 1
end
type expr =
| ELocation of location
| EVar of expr Bindlib.var Pos.marked
@ -96,6 +120,111 @@ type expr =
| EArray of expr Pos.marked list
| ErrorOnEmpty of expr Pos.marked
module Expr = struct
type t = expr
let rec compare e1 e2 =
let rec list_compare cmp l1 l2 =
(* List.compare is available from OCaml 4.12 on *)
match l1, l2 with
| [], [] -> 0
| [], _ :: _ -> -1
| _ :: _, [] -> 1
| a1 :: l1, a2 :: l2 ->
let c = cmp a1 a2 in
if c <> 0 then c else list_compare cmp l1 l2
in
match e1, e2 with
| ELocation _, ELocation _ -> 0
| EVar (v1, _), EVar (v2, _) -> Bindlib.compare_vars v1 v2
| EStruct (name1, field_map1), EStruct (name2, field_map2) -> (
match StructName.compare name1 name2 with
| 0 ->
StructFieldMap.compare
(Pos.compare_marked compare)
field_map1 field_map2
| n -> n)
| ( EStructAccess ((e1, _), field_name1, struct_name1),
EStructAccess ((e2, _), field_name2, struct_name2) ) -> (
match compare e1 e2 with
| 0 -> (
match StructFieldName.compare field_name1 field_name2 with
| 0 -> StructName.compare struct_name1 struct_name2
| n -> n)
| n -> n)
| EEnumInj ((e1, _), cstr1, name1), EEnumInj ((e2, _), cstr2, name2) -> (
match compare e1 e2 with
| 0 -> (
match EnumName.compare name1 name2 with
| 0 -> EnumConstructor.compare cstr1 cstr2
| n -> n)
| n -> n)
| EMatch ((e1, _), name1, emap1), EMatch ((e2, _), name2, emap2) -> (
match compare e1 e2 with
| 0 -> (
match EnumName.compare name1 name2 with
| 0 ->
EnumConstructorMap.compare (Pos.compare_marked compare) emap1 emap2
| n -> n)
| n -> n)
| ELit l1, ELit l2 -> Stdlib.compare l1 l2
| EAbs ((binder1, _), typs1), EAbs ((binder2, _), typs2) -> (
match list_compare (Pos.compare_marked Typ.compare) typs1 typs2 with
| 0 ->
let _, (e1, _), (e2, _) = Bindlib.unmbind2 binder1 binder2 in
compare e1 e2
| n -> n)
| EApp ((f1, _), args1), EApp ((f2, _), args2) -> (
match compare f1 f2 with
| 0 -> list_compare (fun (x1, _) (x2, _) -> compare x1 x2) args1 args2
| n -> n)
| EOp op1, EOp op2 -> Stdlib.compare op1 op2
| ( EDefault (exs1, (just1, _), (cons1, _)),
EDefault (exs2, (just2, _), (cons2, _)) ) -> (
match compare just1 just2 with
| 0 -> (
match compare cons1 cons2 with
| 0 -> list_compare (Pos.compare_marked compare) exs1 exs2
| n -> n)
| n -> n)
| ( EIfThenElse ((i1, _), (t1, _), (e1, _)),
EIfThenElse ((i2, _), (t2, _), (e2, _)) ) -> (
match compare i1 i2 with
| 0 -> ( match compare t1 t2 with 0 -> compare e1 e2 | n -> n)
| n -> n)
| EArray a1, EArray a2 ->
list_compare (fun (e1, _) (e2, _) -> compare e1 e2) a1 a2
| ErrorOnEmpty (e1, _), ErrorOnEmpty (e2, _) -> compare e1 e2
| ELocation _, _ -> -1
| _, ELocation _ -> 1
| EVar _, _ -> -1
| _, EVar _ -> 1
| EStruct _, _ -> -1
| _, EStruct _ -> 1
| EStructAccess _, _ -> -1
| _, EStructAccess _ -> 1
| EEnumInj _, _ -> -1
| _, EEnumInj _ -> 1
| EMatch _, _ -> -1
| _, EMatch _ -> 1
| ELit _, _ -> -1
| _, ELit _ -> 1
| EAbs _, _ -> -1
| _, EAbs _ -> 1
| EApp _, _ -> -1
| _, EApp _ -> 1
| EOp _, _ -> -1
| _, EOp _ -> 1
| EDefault _, _ -> -1
| _, EDefault _ -> 1
| EIfThenElse _, _ -> -1
| _, EIfThenElse _ -> 1
| EArray _, _ -> -1
| _, EArray _ -> 1
end
module ExprMap = Map.Make (Expr)
let rec locations_used (e : expr Pos.marked) : LocationSet.t =
match Pos.unmark e with
| ELocation l -> LocationSet.singleton (l, Pos.get_position e)

View File

@ -66,6 +66,8 @@ type typ =
| TArray of typ
| TAny
module Typ : Set.OrderedType with type t = typ
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
library, based on higher-order abstract syntax*)
type expr =
@ -86,6 +88,9 @@ type expr =
| EArray of expr Pos.marked list
| ErrorOnEmpty of expr Pos.marked
module Expr : Set.OrderedType with type t = expr
module ExprMap : Map.S with type key = expr
val locations_used : expr Pos.marked -> LocationSet.t
(** This type characterizes the three levels of visibility for a given scope

View File

@ -222,6 +222,12 @@ let get_position ((_, x) : 'a marked) : t = x
let map_under_mark (f : 'a -> 'b) ((x, y) : 'a marked) : 'b marked = f x, y
let same_pos_as (x : 'a) ((_, y) : 'b marked) : 'a marked = x, y
let compare_marked
(cmp : 'a -> 'a -> int)
((x, _) : 'a marked)
((y, _) : 'a marked) : int =
cmp x y
let unmark_option (x : 'a marked option) : 'a option =
match x with Some x -> Some (unmark x) | None -> None

View File

@ -68,6 +68,9 @@ val map_under_mark : ('a -> 'b) -> 'a marked -> 'b marked
val same_pos_as : 'a -> 'b marked -> 'a marked
val unmark_option : 'a marked option -> 'a option
val compare_marked : ('a -> 'a -> int) -> 'a marked -> 'a marked -> int
(** Compares two marked values {b ignoring positions} *)
(** Visitors *)
class ['self] marked_map :