Merge pull request #5742 from roc-lang/improve-rust-glue

Situationally gen Eq/Ord/Hash glue for tag unions
This commit is contained in:
Folkert de Vries 2023-08-10 20:10:13 +02:00 committed by GitHub
commit 2bd998a215
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -426,107 +426,122 @@ deriveDebugTagUnion = \buf, types, tagUnionType, tags ->
}
"""
deriveEqTagUnion : Str, Str -> Str
deriveEqTagUnion = \buf, tagUnionType ->
"""
\(buf)
deriveEqTagUnion : Str, Types, Shape, Str -> Str
deriveEqTagUnion = \buf, types, shape, tagUnionType ->
if canSupportEqHashOrd types shape then
"""
\(buf)
impl Eq for \(tagUnionType) {}
"""
impl Eq for \(tagUnionType) {}
"""
else
buf
derivePartialEqTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
derivePartialEqTagUnion = \buf, tagUnionType, tags ->
checks =
List.walk tags "" \accum, { name: tagName } ->
"""
\(accum)
\(tagName) => self.payload.\(tagName) == other.payload.\(tagName),
"""
derivePartialEqTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
derivePartialEqTagUnion = \buf, types, shape, tagUnionType, tags ->
if canSupportPartialEqOrd types shape then
checks =
List.walk tags "" \accum, { name: tagName } ->
"""
\(accum)
\(tagName) => self.payload.\(tagName) == other.payload.\(tagName),
"""
"""
\(buf)
"""
\(buf)
impl PartialEq for \(tagUnionType) {
fn eq(&self, other: &Self) -> bool {
use discriminant_\(tagUnionType)::*;
impl PartialEq for \(tagUnionType) {
fn eq(&self, other: &Self) -> bool {
use discriminant_\(tagUnionType)::*;
if self.discriminant != other.discriminant {
return false;
}
unsafe {
match self.discriminant {\(checks)
if self.discriminant != other.discriminant {
return false;
}
}
}
}
"""
deriveOrdTagUnion : Str, Str -> Str
deriveOrdTagUnion = \buf, tagUnionType ->
"""
\(buf)
impl Ord for \(tagUnionType) {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.partial_cmp(other).unwrap()
}
}
"""
derivePartialOrdTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
derivePartialOrdTagUnion = \buf, tagUnionType, tags ->
checks =
List.walk tags "" \accum, { name: tagName } ->
"""
\(accum)
\(tagName) => self.payload.\(tagName).partial_cmp(&other.payload.\(tagName)),
"""
"""
\(buf)
impl PartialOrd for \(tagUnionType) {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
use discriminant_\(tagUnionType)::*;
use std::cmp::Ordering::*;
match self.discriminant.cmp(&other.discriminant) {
Less => Option::Some(Less),
Greater => Option::Some(Greater),
Equal => unsafe {
unsafe {
match self.discriminant {\(checks)
}
},
}
}
}
"""
deriveHashTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
deriveHashTagUnion = \buf, tagUnionType, tags ->
checks =
List.walk tags "" \accum, { name: tagName } ->
"""
\(accum)
\(tagName) => self.payload.\(tagName).hash(state),
"""
"""
\(buf)
impl core::hash::Hash for \(tagUnionType) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
use discriminant_\(tagUnionType)::*;
unsafe {
match self.discriminant {\(checks)
}
}
}
}
"""
"""
else
buf
deriveOrdTagUnion : Str, Types, Shape, Str -> Str
deriveOrdTagUnion = \buf, types, shape, tagUnionType ->
if canSupportEqHashOrd types shape then
"""
\(buf)
impl Ord for \(tagUnionType) {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.partial_cmp(other).unwrap()
}
}
"""
else
buf
derivePartialOrdTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
derivePartialOrdTagUnion = \buf, types, shape, tagUnionType, tags ->
if canSupportPartialEqOrd types shape then
checks =
List.walk tags "" \accum, { name: tagName } ->
"""
\(accum)
\(tagName) => self.payload.\(tagName).partial_cmp(&other.payload.\(tagName)),
"""
"""
\(buf)
impl PartialOrd for \(tagUnionType) {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
use discriminant_\(tagUnionType)::*;
use std::cmp::Ordering::*;
match self.discriminant.cmp(&other.discriminant) {
Less => Option::Some(Less),
Greater => Option::Some(Greater),
Equal => unsafe {
match self.discriminant {\(checks)
}
},
}
}
}
"""
else
buf
deriveHashTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
deriveHashTagUnion = \buf, types, shape, tagUnionType, tags ->
if canSupportEqHashOrd types shape then
checks =
List.walk tags "" \accum, { name: tagName } ->
"""
\(accum)
\(tagName) => self.payload.\(tagName).hash(state),
"""
"""
\(buf)
impl core::hash::Hash for \(tagUnionType) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
use discriminant_\(tagUnionType)::*;
unsafe {
match self.discriminant {\(checks)
}
}
}
}
"""
else
buf
generateConstructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
generateConstructorFunctions = \buf, types, tagUnionType, tags ->
@ -646,6 +661,7 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
sizeOfSelf = Num.toStr (Types.size types id)
alignOfSelf = Num.toStr (Types.alignment types id)
shape = Types.shape types id
# TODO: this value can be different than the alignment of `id`
align =
@ -701,16 +717,16 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
"""
|> deriveCloneTagUnion escapedName tags
|> deriveDebugTagUnion types escapedName tags
|> deriveEqTagUnion escapedName
|> derivePartialEqTagUnion escapedName tags
|> deriveOrdTagUnion escapedName
|> derivePartialOrdTagUnion escapedName tags
|> deriveHashTagUnion escapedName tags
|> deriveEqTagUnion types shape escapedName
|> derivePartialEqTagUnion types shape escapedName tags
|> deriveOrdTagUnion types shape escapedName
|> derivePartialOrdTagUnion types shape escapedName tags
|> deriveHashTagUnion types shape escapedName tags
|> generateDestructorFunctions types escapedName tags
|> generateConstructorFunctions types escapedName tags
|> \b ->
type = Types.shape types id
if cannotDeriveCopy types type then
if cannotSupportCopy types type then
# A custom drop impl is only needed when we can't derive copy.
b
|> Str.concat
@ -942,7 +958,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
|> Str.joinWith "\n"
partialEqImpl =
if canDerivePartialEq types (Types.shape types id) then
if canSupportPartialEqOrd types (Types.shape types id) then
"""
impl PartialEq for \(escapedName) {
fn eq(&self, other: &Self) -> bool {
@ -1027,7 +1043,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
hashImpl =
if canDerivePartialEq types (Types.shape types id) then
if canSupportPartialEqOrd types (Types.shape types id) then
"""
impl core::hash::Hash for \(escapedName) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
@ -1067,7 +1083,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
|> Str.joinWith "\n"
partialOrdImpl =
if canDerivePartialEq types (Types.shape types id) then
if canSupportPartialEqOrd types (Types.shape types id) then
"""
impl PartialOrd for \(escapedName) {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
@ -1198,7 +1214,7 @@ generateTagUnionDropPayload = \buf, types, selfMut, tags, discriminantName, disc
buf
|> writeTagImpls tags discriminantName indents \name, payload ->
when payload is
Some id if cannotDeriveCopy types (Types.shape types id) ->
Some id if cannotSupportCopy types (Types.shape types id) ->
"unsafe { core::mem::ManuallyDrop::drop(&mut \(selfMut).payload.\(name)) },"
_ ->
@ -1272,7 +1288,7 @@ generateUnionField = \types ->
type = Types.shape types id
fullTypeStr =
if cannotDeriveCopy types type then
if cannotSupportCopy types type then
# types with pointers need ManuallyDrop
# because rust unions don't (and can't)
# know how to drop them automatically!
@ -1673,54 +1689,58 @@ generateDeriveStr = \buf, types, type, includeDebug ->
buf
|> Str.concat "#[derive(Clone, "
|> condWrite (!(cannotDeriveCopy types type)) "Copy, "
|> condWrite (!(cannotDeriveDefault types type)) "Default, "
|> condWrite (!(cannotSupportCopy types type)) "Copy, "
|> condWrite (!(cannotSupportDefault types type)) "Default, "
|> condWrite deriveDebug "Debug, "
|> condWrite (canDerivePartialEq types type) "PartialEq, PartialOrd, "
|> condWrite (!(hasFloat types type) && (canDerivePartialEq types type)) "Eq, Ord, Hash, "
|> condWrite (canSupportPartialEqOrd types type) "PartialEq, PartialOrd, "
|> condWrite (canSupportEqHashOrd types type) "Eq, Ord, Hash, "
|> Str.concat ")]\n"
canDerivePartialEq : Types, Shape -> Bool
canDerivePartialEq = \types, type ->
canSupportEqHashOrd : Types, Shape -> Bool
canSupportEqHashOrd = \types, type ->
!(hasFloat types type) && (canSupportPartialEqOrd types type)
canSupportPartialEqOrd : Types, Shape -> Bool
canSupportPartialEqOrd = \types, type ->
when type is
Function rocFn ->
runtimeRepresentation = Types.shape types rocFn.lambdaSet
canDerivePartialEq types runtimeRepresentation
canSupportPartialEqOrd types runtimeRepresentation
Unsized -> Bool.false
Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.true
RocStr -> Bool.true
RocList inner | RocSet inner | RocBox inner ->
innerType = Types.shape types inner
canDerivePartialEq types innerType
canSupportPartialEqOrd types innerType
RocDict k v ->
kType = Types.shape types k
vType = Types.shape types v
canDerivePartialEq types kType && canDerivePartialEq types vType
canSupportPartialEqOrd types kType && canSupportPartialEqOrd types vType
TagUnion (Recursive { tags }) ->
List.all tags \{ payload } ->
when payload is
None -> Bool.true
Some id -> canDerivePartialEq types (Types.shape types id)
Some id -> canSupportPartialEqOrd types (Types.shape types id)
TagUnion (NullableWrapped { tags }) ->
List.all tags \{ payload } ->
when payload is
None -> Bool.true
Some id -> canDerivePartialEq types (Types.shape types id)
Some id -> canSupportPartialEqOrd types (Types.shape types id)
TagUnion (NonNullableUnwrapped { payload }) ->
canDerivePartialEq types (Types.shape types payload)
canSupportPartialEqOrd types (Types.shape types payload)
TagUnion (NullableUnwrapped { nonNullPayload }) ->
canDerivePartialEq types (Types.shape types nonNullPayload)
canSupportPartialEqOrd types (Types.shape types nonNullPayload)
RecursivePointer _ -> Bool.true
TagUnion (SingleTagStruct { payload: HasNoClosure fields }) ->
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id)
List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id)
TagUnion (SingleTagStruct { payload: HasClosure _ }) ->
Bool.false
@ -1728,23 +1748,23 @@ canDerivePartialEq = \types, type ->
TagUnion (NonRecursive { tags }) ->
List.all tags \{ payload } ->
when payload is
Some id -> canDerivePartialEq types (Types.shape types id)
Some id -> canSupportPartialEqOrd types (Types.shape types id)
None -> Bool.true
RocResult okId errId ->
okShape = Types.shape types okId
errShape = Types.shape types errId
canDerivePartialEq types okShape && canDerivePartialEq types errShape
canSupportPartialEqOrd types okShape && canSupportPartialEqOrd types errShape
Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } ->
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id)
List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id)
Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } ->
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id)
List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id)
cannotDeriveCopy : Types, Shape -> Bool
cannotDeriveCopy = \types, type ->
cannotSupportCopy : Types, Shape -> Bool
cannotSupportCopy = \types, type ->
!(canDeriveCopy types type)
canDeriveCopy : Types, Shape -> Bool
@ -1780,22 +1800,22 @@ canDeriveCopy = \types, type ->
Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } ->
List.all fields \{ id } -> canDeriveCopy types (Types.shape types id)
cannotDeriveDefault = \types, type ->
cannotSupportDefault = \types, type ->
when type is
Unit | Unsized | EmptyTagUnion | TagUnion _ | RocResult _ _ | RecursivePointer _ | Function _ -> Bool.true
RocStr | Bool | Num _ -> Bool.false
RocList id | RocSet id | RocBox id ->
cannotDeriveDefault types (Types.shape types id)
cannotSupportDefault types (Types.shape types id)
TagUnionPayload { fields: HasClosure _ } -> Bool.true
RocDict keyId valId ->
cannotDeriveCopy types (Types.shape types keyId)
|| cannotDeriveCopy types (Types.shape types valId)
cannotSupportCopy types (Types.shape types keyId)
|| cannotSupportCopy types (Types.shape types valId)
Struct { fields: HasClosure _ } -> Bool.true
Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } ->
List.any fields \{ id } -> cannotDeriveDefault types (Types.shape types id)
List.any fields \{ id } -> cannotSupportDefault types (Types.shape types id)
hasFloat = \types, type ->
hasFloatHelp types type (Set.empty {})