mirror of
https://github.com/roc-lang/roc.git
synced 2024-09-20 07:17:50 +03:00
Merge pull request #5742 from roc-lang/improve-rust-glue
Situationally gen Eq/Ord/Hash glue for tag unions
This commit is contained in:
commit
2bd998a215
@ -426,107 +426,122 @@ deriveDebugTagUnion = \buf, types, tagUnionType, tags ->
|
||||
}
|
||||
"""
|
||||
|
||||
deriveEqTagUnion : Str, Str -> Str
|
||||
deriveEqTagUnion = \buf, tagUnionType ->
|
||||
"""
|
||||
\(buf)
|
||||
deriveEqTagUnion : Str, Types, Shape, Str -> Str
|
||||
deriveEqTagUnion = \buf, types, shape, tagUnionType ->
|
||||
if canSupportEqHashOrd types shape then
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl Eq for \(tagUnionType) {}
|
||||
"""
|
||||
impl Eq for \(tagUnionType) {}
|
||||
"""
|
||||
else
|
||||
buf
|
||||
|
||||
derivePartialEqTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
derivePartialEqTagUnion = \buf, tagUnionType, tags ->
|
||||
checks =
|
||||
List.walk tags "" \accum, { name: tagName } ->
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => self.payload.\(tagName) == other.payload.\(tagName),
|
||||
"""
|
||||
derivePartialEqTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
derivePartialEqTagUnion = \buf, types, shape, tagUnionType, tags ->
|
||||
if canSupportPartialEqOrd types shape then
|
||||
checks =
|
||||
List.walk tags "" \accum, { name: tagName } ->
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => self.payload.\(tagName) == other.payload.\(tagName),
|
||||
"""
|
||||
|
||||
"""
|
||||
\(buf)
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl PartialEq for \(tagUnionType) {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
impl PartialEq for \(tagUnionType) {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
|
||||
if self.discriminant != other.discriminant {
|
||||
return false;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
match self.discriminant {\(checks)
|
||||
if self.discriminant != other.discriminant {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
deriveOrdTagUnion : Str, Str -> Str
|
||||
deriveOrdTagUnion = \buf, tagUnionType ->
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl Ord for \(tagUnionType) {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.partial_cmp(other).unwrap()
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
derivePartialOrdTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
derivePartialOrdTagUnion = \buf, tagUnionType, tags ->
|
||||
checks =
|
||||
List.walk tags "" \accum, { name: tagName } ->
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => self.payload.\(tagName).partial_cmp(&other.payload.\(tagName)),
|
||||
"""
|
||||
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl PartialOrd for \(tagUnionType) {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
|
||||
use std::cmp::Ordering::*;
|
||||
|
||||
match self.discriminant.cmp(&other.discriminant) {
|
||||
Less => Option::Some(Less),
|
||||
Greater => Option::Some(Greater),
|
||||
Equal => unsafe {
|
||||
unsafe {
|
||||
match self.discriminant {\(checks)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
deriveHashTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
deriveHashTagUnion = \buf, tagUnionType, tags ->
|
||||
checks =
|
||||
List.walk tags "" \accum, { name: tagName } ->
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => self.payload.\(tagName).hash(state),
|
||||
"""
|
||||
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl core::hash::Hash for \(tagUnionType) {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
|
||||
unsafe {
|
||||
match self.discriminant {\(checks)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
"""
|
||||
else
|
||||
buf
|
||||
|
||||
deriveOrdTagUnion : Str, Types, Shape, Str -> Str
|
||||
deriveOrdTagUnion = \buf, types, shape, tagUnionType ->
|
||||
if canSupportEqHashOrd types shape then
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl Ord for \(tagUnionType) {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.partial_cmp(other).unwrap()
|
||||
}
|
||||
}
|
||||
"""
|
||||
else
|
||||
buf
|
||||
|
||||
derivePartialOrdTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
derivePartialOrdTagUnion = \buf, types, shape, tagUnionType, tags ->
|
||||
if canSupportPartialEqOrd types shape then
|
||||
checks =
|
||||
List.walk tags "" \accum, { name: tagName } ->
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => self.payload.\(tagName).partial_cmp(&other.payload.\(tagName)),
|
||||
"""
|
||||
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl PartialOrd for \(tagUnionType) {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
|
||||
use std::cmp::Ordering::*;
|
||||
|
||||
match self.discriminant.cmp(&other.discriminant) {
|
||||
Less => Option::Some(Less),
|
||||
Greater => Option::Some(Greater),
|
||||
Equal => unsafe {
|
||||
match self.discriminant {\(checks)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
else
|
||||
buf
|
||||
|
||||
deriveHashTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
deriveHashTagUnion = \buf, types, shape, tagUnionType, tags ->
|
||||
if canSupportEqHashOrd types shape then
|
||||
checks =
|
||||
List.walk tags "" \accum, { name: tagName } ->
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => self.payload.\(tagName).hash(state),
|
||||
"""
|
||||
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl core::hash::Hash for \(tagUnionType) {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
|
||||
unsafe {
|
||||
match self.discriminant {\(checks)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
else
|
||||
buf
|
||||
|
||||
generateConstructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
generateConstructorFunctions = \buf, types, tagUnionType, tags ->
|
||||
@ -646,6 +661,7 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
|
||||
|
||||
sizeOfSelf = Num.toStr (Types.size types id)
|
||||
alignOfSelf = Num.toStr (Types.alignment types id)
|
||||
shape = Types.shape types id
|
||||
|
||||
# TODO: this value can be different than the alignment of `id`
|
||||
align =
|
||||
@ -701,16 +717,16 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
|
||||
"""
|
||||
|> deriveCloneTagUnion escapedName tags
|
||||
|> deriveDebugTagUnion types escapedName tags
|
||||
|> deriveEqTagUnion escapedName
|
||||
|> derivePartialEqTagUnion escapedName tags
|
||||
|> deriveOrdTagUnion escapedName
|
||||
|> derivePartialOrdTagUnion escapedName tags
|
||||
|> deriveHashTagUnion escapedName tags
|
||||
|> deriveEqTagUnion types shape escapedName
|
||||
|> derivePartialEqTagUnion types shape escapedName tags
|
||||
|> deriveOrdTagUnion types shape escapedName
|
||||
|> derivePartialOrdTagUnion types shape escapedName tags
|
||||
|> deriveHashTagUnion types shape escapedName tags
|
||||
|> generateDestructorFunctions types escapedName tags
|
||||
|> generateConstructorFunctions types escapedName tags
|
||||
|> \b ->
|
||||
type = Types.shape types id
|
||||
if cannotDeriveCopy types type then
|
||||
if cannotSupportCopy types type then
|
||||
# A custom drop impl is only needed when we can't derive copy.
|
||||
b
|
||||
|> Str.concat
|
||||
@ -942,7 +958,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
|
||||
|> Str.joinWith "\n"
|
||||
|
||||
partialEqImpl =
|
||||
if canDerivePartialEq types (Types.shape types id) then
|
||||
if canSupportPartialEqOrd types (Types.shape types id) then
|
||||
"""
|
||||
impl PartialEq for \(escapedName) {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
@ -1027,7 +1043,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
|
||||
|
||||
|
||||
hashImpl =
|
||||
if canDerivePartialEq types (Types.shape types id) then
|
||||
if canSupportPartialEqOrd types (Types.shape types id) then
|
||||
"""
|
||||
impl core::hash::Hash for \(escapedName) {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
@ -1067,7 +1083,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
|
||||
|> Str.joinWith "\n"
|
||||
|
||||
partialOrdImpl =
|
||||
if canDerivePartialEq types (Types.shape types id) then
|
||||
if canSupportPartialEqOrd types (Types.shape types id) then
|
||||
"""
|
||||
impl PartialOrd for \(escapedName) {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
@ -1198,7 +1214,7 @@ generateTagUnionDropPayload = \buf, types, selfMut, tags, discriminantName, disc
|
||||
buf
|
||||
|> writeTagImpls tags discriminantName indents \name, payload ->
|
||||
when payload is
|
||||
Some id if cannotDeriveCopy types (Types.shape types id) ->
|
||||
Some id if cannotSupportCopy types (Types.shape types id) ->
|
||||
"unsafe { core::mem::ManuallyDrop::drop(&mut \(selfMut).payload.\(name)) },"
|
||||
|
||||
_ ->
|
||||
@ -1272,7 +1288,7 @@ generateUnionField = \types ->
|
||||
|
||||
type = Types.shape types id
|
||||
fullTypeStr =
|
||||
if cannotDeriveCopy types type then
|
||||
if cannotSupportCopy types type then
|
||||
# types with pointers need ManuallyDrop
|
||||
# because rust unions don't (and can't)
|
||||
# know how to drop them automatically!
|
||||
@ -1673,54 +1689,58 @@ generateDeriveStr = \buf, types, type, includeDebug ->
|
||||
|
||||
buf
|
||||
|> Str.concat "#[derive(Clone, "
|
||||
|> condWrite (!(cannotDeriveCopy types type)) "Copy, "
|
||||
|> condWrite (!(cannotDeriveDefault types type)) "Default, "
|
||||
|> condWrite (!(cannotSupportCopy types type)) "Copy, "
|
||||
|> condWrite (!(cannotSupportDefault types type)) "Default, "
|
||||
|> condWrite deriveDebug "Debug, "
|
||||
|> condWrite (canDerivePartialEq types type) "PartialEq, PartialOrd, "
|
||||
|> condWrite (!(hasFloat types type) && (canDerivePartialEq types type)) "Eq, Ord, Hash, "
|
||||
|> condWrite (canSupportPartialEqOrd types type) "PartialEq, PartialOrd, "
|
||||
|> condWrite (canSupportEqHashOrd types type) "Eq, Ord, Hash, "
|
||||
|> Str.concat ")]\n"
|
||||
|
||||
canDerivePartialEq : Types, Shape -> Bool
|
||||
canDerivePartialEq = \types, type ->
|
||||
canSupportEqHashOrd : Types, Shape -> Bool
|
||||
canSupportEqHashOrd = \types, type ->
|
||||
!(hasFloat types type) && (canSupportPartialEqOrd types type)
|
||||
|
||||
canSupportPartialEqOrd : Types, Shape -> Bool
|
||||
canSupportPartialEqOrd = \types, type ->
|
||||
when type is
|
||||
Function rocFn ->
|
||||
runtimeRepresentation = Types.shape types rocFn.lambdaSet
|
||||
canDerivePartialEq types runtimeRepresentation
|
||||
canSupportPartialEqOrd types runtimeRepresentation
|
||||
|
||||
Unsized -> Bool.false
|
||||
Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.true
|
||||
RocStr -> Bool.true
|
||||
RocList inner | RocSet inner | RocBox inner ->
|
||||
innerType = Types.shape types inner
|
||||
canDerivePartialEq types innerType
|
||||
canSupportPartialEqOrd types innerType
|
||||
|
||||
RocDict k v ->
|
||||
kType = Types.shape types k
|
||||
vType = Types.shape types v
|
||||
|
||||
canDerivePartialEq types kType && canDerivePartialEq types vType
|
||||
canSupportPartialEqOrd types kType && canSupportPartialEqOrd types vType
|
||||
|
||||
TagUnion (Recursive { tags }) ->
|
||||
List.all tags \{ payload } ->
|
||||
when payload is
|
||||
None -> Bool.true
|
||||
Some id -> canDerivePartialEq types (Types.shape types id)
|
||||
Some id -> canSupportPartialEqOrd types (Types.shape types id)
|
||||
|
||||
TagUnion (NullableWrapped { tags }) ->
|
||||
List.all tags \{ payload } ->
|
||||
when payload is
|
||||
None -> Bool.true
|
||||
Some id -> canDerivePartialEq types (Types.shape types id)
|
||||
Some id -> canSupportPartialEqOrd types (Types.shape types id)
|
||||
|
||||
TagUnion (NonNullableUnwrapped { payload }) ->
|
||||
canDerivePartialEq types (Types.shape types payload)
|
||||
canSupportPartialEqOrd types (Types.shape types payload)
|
||||
|
||||
TagUnion (NullableUnwrapped { nonNullPayload }) ->
|
||||
canDerivePartialEq types (Types.shape types nonNullPayload)
|
||||
canSupportPartialEqOrd types (Types.shape types nonNullPayload)
|
||||
|
||||
RecursivePointer _ -> Bool.true
|
||||
TagUnion (SingleTagStruct { payload: HasNoClosure fields }) ->
|
||||
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id)
|
||||
List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id)
|
||||
|
||||
TagUnion (SingleTagStruct { payload: HasClosure _ }) ->
|
||||
Bool.false
|
||||
@ -1728,23 +1748,23 @@ canDerivePartialEq = \types, type ->
|
||||
TagUnion (NonRecursive { tags }) ->
|
||||
List.all tags \{ payload } ->
|
||||
when payload is
|
||||
Some id -> canDerivePartialEq types (Types.shape types id)
|
||||
Some id -> canSupportPartialEqOrd types (Types.shape types id)
|
||||
None -> Bool.true
|
||||
|
||||
RocResult okId errId ->
|
||||
okShape = Types.shape types okId
|
||||
errShape = Types.shape types errId
|
||||
|
||||
canDerivePartialEq types okShape && canDerivePartialEq types errShape
|
||||
canSupportPartialEqOrd types okShape && canSupportPartialEqOrd types errShape
|
||||
|
||||
Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } ->
|
||||
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id)
|
||||
List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id)
|
||||
|
||||
Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } ->
|
||||
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id)
|
||||
List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id)
|
||||
|
||||
cannotDeriveCopy : Types, Shape -> Bool
|
||||
cannotDeriveCopy = \types, type ->
|
||||
cannotSupportCopy : Types, Shape -> Bool
|
||||
cannotSupportCopy = \types, type ->
|
||||
!(canDeriveCopy types type)
|
||||
|
||||
canDeriveCopy : Types, Shape -> Bool
|
||||
@ -1780,22 +1800,22 @@ canDeriveCopy = \types, type ->
|
||||
Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } ->
|
||||
List.all fields \{ id } -> canDeriveCopy types (Types.shape types id)
|
||||
|
||||
cannotDeriveDefault = \types, type ->
|
||||
cannotSupportDefault = \types, type ->
|
||||
when type is
|
||||
Unit | Unsized | EmptyTagUnion | TagUnion _ | RocResult _ _ | RecursivePointer _ | Function _ -> Bool.true
|
||||
RocStr | Bool | Num _ -> Bool.false
|
||||
RocList id | RocSet id | RocBox id ->
|
||||
cannotDeriveDefault types (Types.shape types id)
|
||||
cannotSupportDefault types (Types.shape types id)
|
||||
|
||||
TagUnionPayload { fields: HasClosure _ } -> Bool.true
|
||||
|
||||
RocDict keyId valId ->
|
||||
cannotDeriveCopy types (Types.shape types keyId)
|
||||
|| cannotDeriveCopy types (Types.shape types valId)
|
||||
cannotSupportCopy types (Types.shape types keyId)
|
||||
|| cannotSupportCopy types (Types.shape types valId)
|
||||
|
||||
Struct { fields: HasClosure _ } -> Bool.true
|
||||
Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } ->
|
||||
List.any fields \{ id } -> cannotDeriveDefault types (Types.shape types id)
|
||||
List.any fields \{ id } -> cannotSupportDefault types (Types.shape types id)
|
||||
|
||||
hasFloat = \types, type ->
|
||||
hasFloatHelp types type (Set.empty {})
|
||||
|
Loading…
Reference in New Issue
Block a user