mirror of
https://github.com/roc-lang/roc.git
synced 2024-11-12 23:50:20 +03:00
Overhaul nonrecursive tag union bindgen
This commit is contained in:
parent
c685acd3cd
commit
3ba05bdef2
@ -1,4 +1,5 @@
|
||||
use roc_mono::layout::UnionLayout;
|
||||
use roc_target::TargetInfo;
|
||||
|
||||
use crate::types::{RocTagUnion, RocType, TypeId, Types};
|
||||
use std::{
|
||||
@ -140,15 +141,6 @@ fn write_tag_union(
|
||||
let discriminant_name = write_discriminant(name, tag_names, types, buf)?;
|
||||
let typ = types.get(type_id);
|
||||
|
||||
// The tag union's variant union, e.g.
|
||||
//
|
||||
// #[repr(C)]
|
||||
// union union_MyTagUnion {
|
||||
// Bar: u128,
|
||||
// Foo: core::mem::ManuallyDrop<roc_std::RocStr>,
|
||||
// }
|
||||
let variant_name = format!("union_{name}");
|
||||
|
||||
{
|
||||
// No deriving for unions; we have to add the impls ourselves!
|
||||
|
||||
@ -156,7 +148,7 @@ fn write_tag_union(
|
||||
buf,
|
||||
r#"
|
||||
#[repr(C)]
|
||||
pub union {variant_name} {{"#
|
||||
pub union {name} {{"#
|
||||
)?;
|
||||
|
||||
for (tag_name, opt_payload_id) in tags {
|
||||
@ -184,36 +176,35 @@ pub union {variant_name} {{"#
|
||||
buf.write_str("}\n")?;
|
||||
}
|
||||
|
||||
// The tag union struct itself, e.g.
|
||||
//
|
||||
// #[repr(C)]
|
||||
// pub struct MyTagUnion {
|
||||
// variant: variant_MyTagUnion,
|
||||
// tag: tag_MyTagUnion,
|
||||
// }
|
||||
{
|
||||
// no deriving because it contains a union; we have to
|
||||
// generate the impls explicitly!
|
||||
|
||||
writeln!(
|
||||
buf,
|
||||
r#"
|
||||
#[repr(C)]
|
||||
pub struct {name} {{
|
||||
variant: {variant_name},
|
||||
tag: {discriminant_name},
|
||||
}}"#,
|
||||
)?;
|
||||
}
|
||||
|
||||
// The impl for the tag union
|
||||
{
|
||||
// TODO also do this for other targets. Remember, these can change based on more
|
||||
// than just pointer width; e.g. on wasm, the alignments of U16 and U8 are both 4!
|
||||
let discriminant_offset = RocTagUnion::discriminant_offset(
|
||||
tags,
|
||||
types,
|
||||
TargetInfo::from(&target_lexicon::Triple::host()),
|
||||
);
|
||||
|
||||
write!(
|
||||
buf,
|
||||
r#"
|
||||
impl {name} {{
|
||||
pub fn tag(&self) -> {discriminant_name} {{
|
||||
self.tag
|
||||
unsafe {{
|
||||
let bytes = core::mem::transmute::<&Self, &[u8; core::mem::size_of::<Self>()]>(self);
|
||||
|
||||
core::mem::transmute::<u8, {discriminant_name}>(*bytes.as_ptr().add({discriminant_offset}))
|
||||
}}
|
||||
}}
|
||||
|
||||
/// Internal helper
|
||||
fn set_discriminant(&mut self, tag: {discriminant_name}) {{
|
||||
let discriminant_ptr: *mut {discriminant_name} = (self as *mut {name}).cast();
|
||||
|
||||
unsafe {{
|
||||
*(discriminant_ptr.add({discriminant_offset})) = tag;
|
||||
}}
|
||||
}}
|
||||
"#
|
||||
)?;
|
||||
@ -238,9 +229,9 @@ impl {name} {{
|
||||
if payload_type.has_pointer(types) {
|
||||
(
|
||||
"core::mem::ManuallyDrop::new(payload)",
|
||||
format!("core::mem::ManuallyDrop::take(&mut self.variant.{tag_name})",),
|
||||
format!("core::mem::ManuallyDrop::take(&mut self.{tag_name})",),
|
||||
// Since this is a ManuallyDrop, our `as_` method will need
|
||||
// to dereference the variant (e.g. `&self.variant.Foo`)
|
||||
// to dereference the variant (e.g. `&self.Foo`)
|
||||
"&",
|
||||
// we need `mut self` for the argument because of ManuallyDrop
|
||||
"mut self",
|
||||
@ -248,9 +239,9 @@ impl {name} {{
|
||||
} else {
|
||||
(
|
||||
"payload",
|
||||
format!("self.variant.{tag_name}"),
|
||||
format!("self.{tag_name}"),
|
||||
// Since this is not a ManuallyDrop, our `as_` method will not
|
||||
// want to dereference the variant (e.g. `self.variant.Foo` with no '&')
|
||||
// want to dereference the variant (e.g. `self.Foo` with no '&')
|
||||
"",
|
||||
// we don't need `mut self` unless we need ManuallyDrop
|
||||
"self",
|
||||
@ -263,12 +254,13 @@ impl {name} {{
|
||||
r#"
|
||||
/// Construct a tag named {tag_name}, with the appropriate payload
|
||||
pub fn {tag_name}(payload: {payload_type_name}) -> Self {{
|
||||
Self {{
|
||||
tag: {discriminant_name}::{tag_name},
|
||||
variant: {variant_name} {{
|
||||
{tag_name}: {init_payload}
|
||||
}},
|
||||
}}
|
||||
let mut answer = Self {{
|
||||
{tag_name}: {init_payload}
|
||||
}};
|
||||
|
||||
answer.set_discriminant({discriminant_name}::{tag_name});
|
||||
|
||||
answer
|
||||
}}"#,
|
||||
)?;
|
||||
|
||||
@ -294,7 +286,7 @@ impl {name} {{
|
||||
/// Panics in debug builds if the .tag() doesn't return {tag_name}.
|
||||
pub unsafe fn as_{tag_name}(&self) -> {ref_if_needed}{payload_type_name} {{
|
||||
debug_assert_eq!(self.tag(), {discriminant_name}::{tag_name});
|
||||
{ref_if_needed}self.variant.{tag_name}
|
||||
{ref_if_needed}self.{tag_name}
|
||||
}}"#,
|
||||
)?;
|
||||
} else {
|
||||
@ -303,14 +295,12 @@ impl {name} {{
|
||||
// Don't use indoc because this must be indented once!
|
||||
r#"
|
||||
/// A tag named {tag_name}, which has no payload.
|
||||
pub const {tag_name}: Self = Self {{
|
||||
tag: {discriminant_name}::{tag_name},
|
||||
variant: unsafe {{
|
||||
core::mem::transmute::<
|
||||
core::mem::MaybeUninit<{variant_name}>,
|
||||
{variant_name},
|
||||
>(core::mem::MaybeUninit::uninit())
|
||||
}},
|
||||
pub const {tag_name}: Self = unsafe {{
|
||||
let mut bytes = [0; core::mem::size_of::<{name}>()];
|
||||
|
||||
bytes[{discriminant_offset}] = {discriminant_name}::{tag_name} as u8;
|
||||
|
||||
core::mem::transmute::<[u8; core::mem::size_of::<{name}>()], {name}>(bytes)
|
||||
}};"#,
|
||||
)?;
|
||||
|
||||
@ -358,9 +348,7 @@ impl Drop for {name} {{
|
||||
|tag_name, opt_payload_id| {
|
||||
match opt_payload_id {
|
||||
Some(payload_id) if types.get(payload_id).has_pointer(types) => {
|
||||
format!(
|
||||
"unsafe {{ core::mem::ManuallyDrop::drop(&mut self.variant.{tag_name}) }},",
|
||||
)
|
||||
format!("unsafe {{ core::mem::ManuallyDrop::drop(&mut self.{tag_name}) }},",)
|
||||
}
|
||||
_ => {
|
||||
// If it had no payload, or if the payload had no pointers,
|
||||
@ -385,7 +373,7 @@ impl Drop for {name} {{
|
||||
r#"
|
||||
impl PartialEq for {name} {{
|
||||
fn eq(&self, other: &Self) -> bool {{
|
||||
if self.tag != other.tag {{
|
||||
if self.tag() != other.tag() {{
|
||||
return false;
|
||||
}}
|
||||
|
||||
@ -399,7 +387,7 @@ impl PartialEq for {name} {{
|
||||
buf,
|
||||
|tag_name, opt_payload_id| {
|
||||
if opt_payload_id.is_some() {
|
||||
format!("self.variant.{tag_name} == other.variant.{tag_name},")
|
||||
format!("self.{tag_name} == other.{tag_name},")
|
||||
} else {
|
||||
// if the tags themselves had been unequal, we already would have
|
||||
// early-returned with false, so this means the tags were equal
|
||||
@ -428,7 +416,7 @@ impl PartialEq for {name} {{
|
||||
r#"
|
||||
impl PartialOrd for {name} {{
|
||||
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {{
|
||||
match self.tag.partial_cmp(&other.tag) {{
|
||||
match self.tag().partial_cmp(&other.tag()) {{
|
||||
Some(core::cmp::Ordering::Equal) => {{}}
|
||||
not_eq => return not_eq,
|
||||
}}
|
||||
@ -443,7 +431,7 @@ impl PartialOrd for {name} {{
|
||||
buf,
|
||||
|tag_name, opt_payload_id| {
|
||||
if opt_payload_id.is_some() {
|
||||
format!("self.variant.{tag_name}.partial_cmp(&other.variant.{tag_name}),",)
|
||||
format!("self.{tag_name}.partial_cmp(&other.{tag_name}),",)
|
||||
} else {
|
||||
// if the tags themselves had been unequal, we already would have
|
||||
// early-returned, so this means the tags were equal and there's
|
||||
@ -468,7 +456,7 @@ impl PartialOrd for {name} {{
|
||||
r#"
|
||||
impl Ord for {name} {{
|
||||
fn cmp(&self, other: &Self) -> core::cmp::Ordering {{
|
||||
match self.tag.cmp(&other.tag) {{
|
||||
match self.tag().cmp(&other.tag()) {{
|
||||
core::cmp::Ordering::Equal => {{}}
|
||||
not_eq => return not_eq,
|
||||
}}
|
||||
@ -483,7 +471,7 @@ impl Ord for {name} {{
|
||||
buf,
|
||||
|tag_name, opt_payload_id| {
|
||||
if opt_payload_id.is_some() {
|
||||
format!("self.variant.{tag_name}.cmp(&other.variant.{tag_name}),",)
|
||||
format!("self.{tag_name}.cmp(&other.{tag_name}),",)
|
||||
} else {
|
||||
// if the tags themselves had been unequal, we already would have
|
||||
// early-returned, so this means the tags were equal and there's
|
||||
@ -507,11 +495,12 @@ impl Ord for {name} {{
|
||||
buf,
|
||||
r#"
|
||||
impl Clone for {name} {{
|
||||
fn clone(&self) -> Self {{"#
|
||||
fn clone(&self) -> Self {{
|
||||
let mut answer = unsafe {{"#
|
||||
)?;
|
||||
|
||||
write_impl_tags(
|
||||
2,
|
||||
3,
|
||||
tags.iter(),
|
||||
&discriminant_name,
|
||||
buf,
|
||||
@ -519,25 +508,16 @@ impl Clone for {name} {{
|
||||
if opt_payload_id.is_some() {
|
||||
format!(
|
||||
r#"Self {{
|
||||
variant: {variant_name} {{
|
||||
{tag_name}: unsafe {{ self.variant.{tag_name}.clone() }},
|
||||
}},
|
||||
tag: {discriminant_name}::{tag_name},
|
||||
}},"#,
|
||||
{tag_name}: self.{tag_name}.clone(),
|
||||
}},"#,
|
||||
)
|
||||
} else {
|
||||
// when there's no payload, we set the clone's `variant` field to
|
||||
// garbage memory
|
||||
// when there's no payload, initialize to garbage memory.
|
||||
format!(
|
||||
r#"Self {{
|
||||
variant: unsafe {{
|
||||
core::mem::transmute::<
|
||||
core::mem::MaybeUninit<{variant_name}>,
|
||||
{variant_name},
|
||||
>(core::mem::MaybeUninit::uninit())
|
||||
}},
|
||||
tag: {discriminant_name}::{tag_name},
|
||||
}},"#,
|
||||
r#"core::mem::transmute::<
|
||||
core::mem::MaybeUninit<{name}>,
|
||||
{name},
|
||||
>(core::mem::MaybeUninit::uninit()),"#,
|
||||
)
|
||||
}
|
||||
},
|
||||
@ -545,7 +525,13 @@ impl Clone for {name} {{
|
||||
|
||||
writeln!(
|
||||
buf,
|
||||
r#" }}
|
||||
r#"
|
||||
}};
|
||||
|
||||
answer.set_discriminant(self.tag());
|
||||
|
||||
answer
|
||||
}}
|
||||
}}"#
|
||||
)?;
|
||||
}
|
||||
@ -575,7 +561,7 @@ impl core::hash::Hash for {name} {{
|
||||
format!(
|
||||
r#"unsafe {{
|
||||
{hash_tag};
|
||||
&self.variant.{tag_name}.hash(state);
|
||||
&self.{tag_name}.hash(state);
|
||||
}},"#
|
||||
)
|
||||
} else {
|
||||
@ -620,7 +606,7 @@ impl core::fmt::Debug for {name} {{
|
||||
};
|
||||
|
||||
format!(
|
||||
r#"f.debug_tuple("{tag_name}").field({deref_str}self.variant.{tag_name}).finish(),"#,
|
||||
r#"f.debug_tuple("{tag_name}").field({deref_str}self.{tag_name}).finish(),"#,
|
||||
)
|
||||
}
|
||||
None => format!(r#"f.write_str("{tag_name}"),"#),
|
||||
|
Loading…
Reference in New Issue
Block a user