Merge pull request #5574 from roc-lang/glue-limit-derives

glue: limit derives
This commit is contained in:
Ayaz 2023-06-19 21:29:46 -05:00 committed by GitHub
commit 5cafc50a50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 54 deletions

View File

@ -34,7 +34,7 @@ convertTypesToFile = \types ->
generateStruct buf types id name fields Public
TagUnionPayload { name, fields } ->
generateStruct buf types id name (nameTagUnionPayloadFields fields) Private
generateStruct buf types id name (nameTagUnionPayloadFields fields) Public
TagUnion (Enumeration { name, tags, size }) ->
generateEnumeration buf types type name tags size
@ -898,6 +898,28 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
|> List.mapWithIndex partialEqCase
|> Str.joinWith "\n"
partialEqImpl =
if canDerivePartialEq types (Types.shape types id) then
"""
impl PartialEq for \(escapedName) {
fn eq(&self, other: &Self) -> bool {
use discriminant_\(escapedName)::*;
if self.discriminant() != other.discriminant() {
return false;
}
match self.discriminant() {
\(partialEqCases)
}
}
}
impl Eq for \(escapedName) {}
"""
else
""
debugCase = \{ name: tagName, payload: optPayload }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
@ -959,6 +981,26 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
|> List.mapWithIndex hashCase
|> Str.joinWith "\n"
hashImpl =
if canDerivePartialEq types (Types.shape types id) then
"""
impl core::hash::Hash for \(escapedName) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
use discriminant_\(escapedName)::*;
self.discriminant().hash(state);
match self.discriminant() {
\(hashCases)
}
}
}
"""
else
""
partialOrdCase = \{ name: tagName }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
@ -981,6 +1023,36 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
|> List.mapWithIndex partialOrdCase
|> Str.joinWith "\n"
partialOrdImpl =
if canDerivePartialEq types (Types.shape types id) then
"""
impl PartialOrd for \(escapedName) {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(<Self as Ord>::cmp(self, other))
}
}
impl Ord for \(escapedName) {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
use discriminant_\(escapedName)::*;
use std::cmp::Ordering::*;
match self.discriminant().cmp(&other.discriminant()) {
Less => Less,
Greater => Greater,
Equal => unsafe {
match self.discriminant() {
\(partialOrdCases)
}
},
}
}
}
"""
else
""
sizeOfSelf = Num.toStr (Types.size types id)
alignOfSelf = Num.toStr (Types.alignment types id)
@ -1042,21 +1114,12 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
}
}
impl PartialEq for \(escapedName) {
fn eq(&self, other: &Self) -> bool {
use discriminant_\(escapedName)::*;
\(partialEqImpl)
if self.discriminant() != other.discriminant() {
return false;
}
\(hashImpl)
match self.discriminant() {
\(partialEqCases)
}
}
}
\(partialOrdImpl)
impl Eq for \(escapedName) {}
impl core::fmt::Debug for \(escapedName) {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
@ -1068,41 +1131,6 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
}
}
impl core::hash::Hash for \(escapedName) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
use discriminant_\(escapedName)::*;
self.discriminant().hash(state);
match self.discriminant() {
\(hashCases)
}
}
}
impl PartialOrd for \(escapedName) {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(<Self as Ord>::cmp(self, other))
}
}
impl Ord for \(escapedName) {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
use discriminant_\(escapedName)::*;
use std::cmp::Ordering::*;
match self.discriminant().cmp(&other.discriminant()) {
Less => Less,
Greater => Greater,
Equal => unsafe {
match self.discriminant() {
\(partialOrdCases)
}
},
}
}
}
#[repr(C)]
union \(unionName) {
@ -1651,8 +1679,8 @@ canDerivePartialEq = \types, type ->
TagUnion (SingleTagStruct { payload: HasNoClosure fields }) ->
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id)
TagUnion (SingleTagStruct { payload: HasClosure fields }) ->
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id)
TagUnion (SingleTagStruct { payload: HasClosure _ }) ->
Bool.false
TagUnion (NonRecursive { tags }) ->
List.all tags \{ payload } ->
@ -1712,10 +1740,12 @@ canDeriveCopy = \types, type ->
cannotDeriveDefault = \types, type ->
when type is
Unit | Unsized | EmptyTagUnion | TagUnion _ | RocResult _ _ | RecursivePointer _ | Function _ -> Bool.true
RocStr | Bool | Num _ | TagUnionPayload { fields: HasClosure _ } -> Bool.false
RocStr | Bool | Num _ -> Bool.false
RocList id | RocSet id | RocBox id ->
cannotDeriveDefault types (Types.shape types id)
TagUnionPayload { fields: HasClosure _ } -> Bool.true
RocDict keyId valId ->
cannotDeriveCopy types (Types.shape types keyId)
|| cannotDeriveCopy types (Types.shape types valId)

View File

@ -4,6 +4,7 @@ app "rocLovesRust"
provides [main] to pf
main =
msg = "Roc <3 Rust, also on stderr!\n"
StdoutWrite "Roc <3 Rust!\n" \{} ->
StdoutWrite "Roc <3 Rust!\n" \{} ->
StderrWrite msg \{} ->
Done

View File

@ -89,7 +89,7 @@ pub extern "C" fn rust_main() -> i32 {
loop {
match dbg!(op.discriminant()) {
StdoutWrite => {
let stdout_write = unsafe { op.get_StdoutWrite() };
let stdout_write = op.get_StdoutWrite();
let output: RocStr = stdout_write.f0;
op = unsafe { stdout_write.f1.force_thunk(()) };
@ -98,7 +98,7 @@ pub extern "C" fn rust_main() -> i32 {
}
}
StderrWrite => {
let stderr_write = unsafe { op.get_StderrWrite() };
let stderr_write = op.get_StderrWrite();
let output: RocStr = stderr_write.f0;
op = unsafe { stderr_write.f1.force_thunk(()) };