Overhaul nonrecursive tag union bindgen

This commit is contained in:
Richard Feldman 2022-05-20 20:01:49 -04:00
parent c685acd3cd
commit 3ba05bdef2
No known key found for this signature in database
GPG Key ID: 7E4127D1E4241798

View File

@ -1,4 +1,5 @@
use roc_mono::layout::UnionLayout;
use roc_target::TargetInfo;
use crate::types::{RocTagUnion, RocType, TypeId, Types};
use std::{
@ -140,15 +141,6 @@ fn write_tag_union(
let discriminant_name = write_discriminant(name, tag_names, types, buf)?;
let typ = types.get(type_id);
// The tag union's variant union, e.g.
//
// #[repr(C)]
// union union_MyTagUnion {
// Bar: u128,
// Foo: core::mem::ManuallyDrop<roc_std::RocStr>,
// }
let variant_name = format!("union_{name}");
{
// No deriving for unions; we have to add the impls ourselves!
@ -156,7 +148,7 @@ fn write_tag_union(
buf,
r#"
#[repr(C)]
pub union {variant_name} {{"#
pub union {name} {{"#
)?;
for (tag_name, opt_payload_id) in tags {
@ -184,36 +176,35 @@ pub union {variant_name} {{"#
buf.write_str("}\n")?;
}
// The tag union struct itself, e.g.
//
// #[repr(C)]
// pub struct MyTagUnion {
// variant: variant_MyTagUnion,
// tag: tag_MyTagUnion,
// }
{
// no deriving because it contains a union; we have to
// generate the impls explicitly!
writeln!(
buf,
r#"
#[repr(C)]
pub struct {name} {{
variant: {variant_name},
tag: {discriminant_name},
}}"#,
)?;
}
// The impl for the tag union
{
// TODO also do this for other targets. Remember, these can change based on more
// than just pointer width; e.g. on wasm, the alignments of U16 and U8 are both 4!
let discriminant_offset = RocTagUnion::discriminant_offset(
tags,
types,
TargetInfo::from(&target_lexicon::Triple::host()),
);
write!(
buf,
r#"
impl {name} {{
pub fn tag(&self) -> {discriminant_name} {{
self.tag
unsafe {{
let bytes = core::mem::transmute::<&Self, &[u8; core::mem::size_of::<Self>()]>(self);
core::mem::transmute::<u8, {discriminant_name}>(*bytes.as_ptr().add({discriminant_offset}))
}}
}}
/// Internal helper
fn set_discriminant(&mut self, tag: {discriminant_name}) {{
let discriminant_ptr: *mut {discriminant_name} = (self as *mut {name}).cast();
unsafe {{
*(discriminant_ptr.add({discriminant_offset})) = tag;
}}
}}
"#
)?;
@ -238,9 +229,9 @@ impl {name} {{
if payload_type.has_pointer(types) {
(
"core::mem::ManuallyDrop::new(payload)",
format!("core::mem::ManuallyDrop::take(&mut self.variant.{tag_name})",),
format!("core::mem::ManuallyDrop::take(&mut self.{tag_name})",),
// Since this is a ManuallyDrop, our `as_` method will need
// to dereference the variant (e.g. `&self.variant.Foo`)
// to dereference the variant (e.g. `&self.Foo`)
"&",
// we need `mut self` for the argument because of ManuallyDrop
"mut self",
@ -248,9 +239,9 @@ impl {name} {{
} else {
(
"payload",
format!("self.variant.{tag_name}"),
format!("self.{tag_name}"),
// Since this is not a ManuallyDrop, our `as_` method will not
// want to dereference the variant (e.g. `self.variant.Foo` with no '&')
// want to dereference the variant (e.g. `self.Foo` with no '&')
"",
// we don't need `mut self` unless we need ManuallyDrop
"self",
@ -263,12 +254,13 @@ impl {name} {{
r#"
/// Construct a tag named {tag_name}, with the appropriate payload
pub fn {tag_name}(payload: {payload_type_name}) -> Self {{
Self {{
tag: {discriminant_name}::{tag_name},
variant: {variant_name} {{
{tag_name}: {init_payload}
}},
}}
let mut answer = Self {{
{tag_name}: {init_payload}
}};
answer.set_discriminant({discriminant_name}::{tag_name});
answer
}}"#,
)?;
@ -294,7 +286,7 @@ impl {name} {{
/// Panics in debug builds if the .tag() doesn't return {tag_name}.
pub unsafe fn as_{tag_name}(&self) -> {ref_if_needed}{payload_type_name} {{
debug_assert_eq!(self.tag(), {discriminant_name}::{tag_name});
{ref_if_needed}self.variant.{tag_name}
{ref_if_needed}self.{tag_name}
}}"#,
)?;
} else {
@ -303,14 +295,12 @@ impl {name} {{
// Don't use indoc because this must be indented once!
r#"
/// A tag named {tag_name}, which has no payload.
pub const {tag_name}: Self = Self {{
tag: {discriminant_name}::{tag_name},
variant: unsafe {{
core::mem::transmute::<
core::mem::MaybeUninit<{variant_name}>,
{variant_name},
>(core::mem::MaybeUninit::uninit())
}},
pub const {tag_name}: Self = unsafe {{
let mut bytes = [0; core::mem::size_of::<{name}>()];
bytes[{discriminant_offset}] = {discriminant_name}::{tag_name} as u8;
core::mem::transmute::<[u8; core::mem::size_of::<{name}>()], {name}>(bytes)
}};"#,
)?;
@ -358,9 +348,7 @@ impl Drop for {name} {{
|tag_name, opt_payload_id| {
match opt_payload_id {
Some(payload_id) if types.get(payload_id).has_pointer(types) => {
format!(
"unsafe {{ core::mem::ManuallyDrop::drop(&mut self.variant.{tag_name}) }},",
)
format!("unsafe {{ core::mem::ManuallyDrop::drop(&mut self.{tag_name}) }},",)
}
_ => {
// If it had no payload, or if the payload had no pointers,
@ -385,7 +373,7 @@ impl Drop for {name} {{
r#"
impl PartialEq for {name} {{
fn eq(&self, other: &Self) -> bool {{
if self.tag != other.tag {{
if self.tag() != other.tag() {{
return false;
}}
@ -399,7 +387,7 @@ impl PartialEq for {name} {{
buf,
|tag_name, opt_payload_id| {
if opt_payload_id.is_some() {
format!("self.variant.{tag_name} == other.variant.{tag_name},")
format!("self.{tag_name} == other.{tag_name},")
} else {
// if the tags themselves had been unequal, we already would have
// early-returned with false, so this means the tags were equal
@ -428,7 +416,7 @@ impl PartialEq for {name} {{
r#"
impl PartialOrd for {name} {{
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {{
match self.tag.partial_cmp(&other.tag) {{
match self.tag().partial_cmp(&other.tag()) {{
Some(core::cmp::Ordering::Equal) => {{}}
not_eq => return not_eq,
}}
@ -443,7 +431,7 @@ impl PartialOrd for {name} {{
buf,
|tag_name, opt_payload_id| {
if opt_payload_id.is_some() {
format!("self.variant.{tag_name}.partial_cmp(&other.variant.{tag_name}),",)
format!("self.{tag_name}.partial_cmp(&other.{tag_name}),",)
} else {
// if the tags themselves had been unequal, we already would have
// early-returned, so this means the tags were equal and there's
@ -468,7 +456,7 @@ impl PartialOrd for {name} {{
r#"
impl Ord for {name} {{
fn cmp(&self, other: &Self) -> core::cmp::Ordering {{
match self.tag.cmp(&other.tag) {{
match self.tag().cmp(&other.tag()) {{
core::cmp::Ordering::Equal => {{}}
not_eq => return not_eq,
}}
@ -483,7 +471,7 @@ impl Ord for {name} {{
buf,
|tag_name, opt_payload_id| {
if opt_payload_id.is_some() {
format!("self.variant.{tag_name}.cmp(&other.variant.{tag_name}),",)
format!("self.{tag_name}.cmp(&other.{tag_name}),",)
} else {
// if the tags themselves had been unequal, we already would have
// early-returned, so this means the tags were equal and there's
@ -507,11 +495,12 @@ impl Ord for {name} {{
buf,
r#"
impl Clone for {name} {{
fn clone(&self) -> Self {{"#
fn clone(&self) -> Self {{
let mut answer = unsafe {{"#
)?;
write_impl_tags(
2,
3,
tags.iter(),
&discriminant_name,
buf,
@ -519,25 +508,16 @@ impl Clone for {name} {{
if opt_payload_id.is_some() {
format!(
r#"Self {{
variant: {variant_name} {{
{tag_name}: unsafe {{ self.variant.{tag_name}.clone() }},
}},
tag: {discriminant_name}::{tag_name},
}},"#,
{tag_name}: self.{tag_name}.clone(),
}},"#,
)
} else {
// when there's no payload, we set the clone's `variant` field to
// garbage memory
// when there's no payload, initialize to garbage memory.
format!(
r#"Self {{
variant: unsafe {{
core::mem::transmute::<
core::mem::MaybeUninit<{variant_name}>,
{variant_name},
>(core::mem::MaybeUninit::uninit())
}},
tag: {discriminant_name}::{tag_name},
}},"#,
r#"core::mem::transmute::<
core::mem::MaybeUninit<{name}>,
{name},
>(core::mem::MaybeUninit::uninit()),"#,
)
}
},
@ -545,7 +525,13 @@ impl Clone for {name} {{
writeln!(
buf,
r#" }}
r#"
}};
answer.set_discriminant(self.tag());
answer
}}
}}"#
)?;
}
@ -575,7 +561,7 @@ impl core::hash::Hash for {name} {{
format!(
r#"unsafe {{
{hash_tag};
&self.variant.{tag_name}.hash(state);
&self.{tag_name}.hash(state);
}},"#
)
} else {
@ -620,7 +606,7 @@ impl core::fmt::Debug for {name} {{
};
format!(
r#"f.debug_tuple("{tag_name}").field({deref_str}self.variant.{tag_name}).finish(),"#,
r#"f.debug_tuple("{tag_name}").field({deref_str}self.{tag_name}).finish(),"#,
)
}
None => format!(r#"f.write_str("{tag_name}"),"#),