Merge pull request #3148 from rtfeldman/bindgen-recursive-union

Bindgen normal recursive unions
This commit is contained in:
Richard Feldman 2022-05-29 12:19:11 -04:00 committed by GitHub
commit c8454d0f9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1304 additions and 798 deletions

View File

@ -1,9 +1,13 @@
use crate::structs::Structs;
use crate::types::{Field, RocTagUnion, TypeId, Types};
use crate::{enums::Enums, types::RocType};
use crate::types::{RocTagUnion, TypeId, Types};
use crate::{
enums::Enums,
types::{RocNum, RocType},
};
use bumpalo::Bump;
use roc_builtins::bitcode::{FloatWidth::*, IntWidth::*};
use roc_module::ident::{Lowercase, TagName};
use roc_collections::VecMap;
use roc_module::ident::TagName;
use roc_module::symbol::{Interns, Symbol};
use roc_mono::layout::{cmp_fields, ext_var_is_empty_tag_union, Builtin, Layout, LayoutCache};
use roc_types::subs::UnionTags;
@ -11,6 +15,7 @@ use roc_types::{
subs::{Content, FlatType, Subs, Variable},
types::RecordField,
};
use std::fmt::Display;
pub struct Env<'a> {
pub arena: &'a Bump,
@ -19,16 +24,55 @@ pub struct Env<'a> {
pub interns: &'a Interns,
pub struct_names: Structs,
pub enum_names: Enums,
pub pending_recursive_types: VecMap<TypeId, Variable>,
pub known_recursive_types: VecMap<Variable, TypeId>,
}
impl<'a> Env<'a> {
pub fn add_type(&mut self, var: Variable, types: &mut Types) -> TypeId {
/// Registers every variable yielded by `variables` via `add_type`, then
/// resolves any pending recursive pointers, and returns the populated
/// `Types` collection.
pub fn vars_to_types<I>(&mut self, variables: I) -> Types
where
    I: IntoIterator<Item = Variable>,
{
    let mut collected = Types::default();

    variables.into_iter().for_each(|variable| {
        // The TypeId returned by add_type is not needed here.
        self.add_type(variable, &mut collected);
    });

    // Recursive pointers are first stored with a placeholder id; patch them
    // only after every variable has been added.
    self.resolve_pending_recursive_types(&mut collected);

    collected
}
fn add_type(&mut self, var: Variable, types: &mut Types) -> TypeId {
let layout = self
.layout_cache
.from_var(self.arena, var, self.subs)
.expect("Something weird ended up in the content");
add_type_help(self, layout, var, None, types, None)
add_type_help(self, layout, var, None, types)
}
/// Replaces each placeholder recursive pointer recorded in
/// `pending_recursive_types` with a `RocType::RecursivePointer` holding the
/// TypeId that `known_recursive_types` stores for the same variable.
///
/// `pending_recursive_types` is taken (left as its `Default` value) by this
/// call. Hits `unreachable!` if a pending variable has no known TypeId.
fn resolve_pending_recursive_types(&mut self, types: &mut Types) {
    // TODO if VecMap gets a drain() method, use that instead of doing take() and into_iter
    let unresolved = core::mem::take(&mut self.pending_recursive_types);

    for (pointer_id, var) in unresolved.into_iter() {
        let resolved_id = match self.known_recursive_types.get(&var) {
            Some(known_id) => *known_id,
            None => unreachable!(
                "There was no known recursive TypeId for the pending recursive variable {:?}",
                var
            ),
        };

        // The slot being patched must still hold the placeholder pointer.
        debug_assert!(
            matches!(types.get(pointer_id), RocType::RecursivePointer(TypeId::PENDING)),
            "The TypeId {:?} was registered as a pending recursive pointer, but was not stored in Types as one.",
            pointer_id
        );

        types.replace(pointer_id, RocType::RecursivePointer(resolved_id));
    }
}
}
@ -38,7 +82,6 @@ fn add_type_help<'a>(
var: Variable,
opt_name: Option<Symbol>,
types: &mut Types,
opt_recursion_id: Option<TypeId>,
) -> TypeId {
let subs = env.subs;
@ -56,7 +99,7 @@ fn add_type_help<'a>(
.flat_map(|(label, field)| {
match field {
RecordField::Required(field_var) | RecordField::Demanded(field_var) => {
Some((label.clone(), field_var))
Some((label.to_string(), field_var))
}
RecordField::Optional(_) => {
// drop optional fields
@ -70,7 +113,10 @@ fn add_type_help<'a>(
None => env.struct_names.get_name(var),
};
add_struct(env, name, it, types, opt_recursion_id)
add_struct(env, name, it, types, |name, fields| RocType::Struct {
name,
fields,
})
}
Content::Structure(FlatType::TagUnion(tags, ext_var)) => {
debug_assert!(ext_var_is_empty_tag_union(subs, *ext_var));
@ -125,14 +171,17 @@ fn add_type_help<'a>(
} else {
// If this was a non-builtin type alias, we can use that alias name
// in the generated bindings.
add_type_help(env, layout, *real_var, Some(*name), types, opt_recursion_id)
add_type_help(env, layout, *real_var, Some(*name), types)
}
}
Content::RangedNumber(_, _) => todo!(),
Content::Error => todo!(),
Content::RecursionVar { .. } => {
// We should always skip over RecursionVars before we get here.
unreachable!()
Content::RecursionVar { structure, .. } => {
let type_id = types.add(RocType::RecursivePointer(TypeId::PENDING));
env.pending_recursive_types.insert(type_id, *structure);
type_id
}
}
}
@ -146,29 +195,29 @@ fn add_builtin_type<'a>(
) -> TypeId {
match builtin {
Builtin::Int(width) => match width {
U8 => types.add(RocType::U8),
U16 => types.add(RocType::U16),
U32 => types.add(RocType::U32),
U64 => types.add(RocType::U64),
U128 => types.add(RocType::U128),
I8 => types.add(RocType::I8),
I16 => types.add(RocType::I16),
I32 => types.add(RocType::I32),
I64 => types.add(RocType::I64),
I128 => types.add(RocType::I128),
U8 => types.add(RocType::Num(RocNum::U8)),
U16 => types.add(RocType::Num(RocNum::U16)),
U32 => types.add(RocType::Num(RocNum::U32)),
U64 => types.add(RocType::Num(RocNum::U64)),
U128 => types.add(RocType::Num(RocNum::U128)),
I8 => types.add(RocType::Num(RocNum::I8)),
I16 => types.add(RocType::Num(RocNum::I16)),
I32 => types.add(RocType::Num(RocNum::I32)),
I64 => types.add(RocType::Num(RocNum::I64)),
I128 => types.add(RocType::Num(RocNum::I128)),
},
Builtin::Float(width) => match width {
F32 => types.add(RocType::F32),
F64 => types.add(RocType::F64),
F128 => types.add(RocType::F128),
F32 => types.add(RocType::Num(RocNum::F32)),
F64 => types.add(RocType::Num(RocNum::F64)),
F128 => types.add(RocType::Num(RocNum::F128)),
},
Builtin::Decimal => types.add(RocType::Num(RocNum::Dec)),
Builtin::Bool => types.add(RocType::Bool),
Builtin::Decimal => types.add(RocType::RocDec),
Builtin::Str => types.add(RocType::RocStr),
Builtin::Dict(key_layout, val_layout) => {
// TODO FIXME this `var` is wrong - should have a different `var` for key and for val
let key_id = add_type_help(env, *key_layout, var, opt_name, types, None);
let val_id = add_type_help(env, *val_layout, var, opt_name, types, None);
let key_id = add_type_help(env, *key_layout, var, opt_name, types);
let val_id = add_type_help(env, *val_layout, var, opt_name, types);
let dict_id = types.add(RocType::RocDict(key_id, val_id));
types.depends(dict_id, key_id);
@ -177,7 +226,7 @@ fn add_builtin_type<'a>(
dict_id
}
Builtin::Set(elem_layout) => {
let elem_id = add_type_help(env, *elem_layout, var, opt_name, types, None);
let elem_id = add_type_help(env, *elem_layout, var, opt_name, types);
let set_id = types.add(RocType::RocSet(elem_id));
types.depends(set_id, elem_id);
@ -185,7 +234,7 @@ fn add_builtin_type<'a>(
set_id
}
Builtin::List(elem_layout) => {
let elem_id = add_type_help(env, *elem_layout, var, opt_name, types, None);
let elem_id = add_type_help(env, *elem_layout, var, opt_name, types);
let list_id = types.add(RocType::RocList(elem_id));
types.depends(list_id, elem_id);
@ -195,41 +244,24 @@ fn add_builtin_type<'a>(
}
}
fn add_struct<I: IntoIterator<Item = (Lowercase, Variable)>>(
fn add_struct<I, L, F>(
env: &mut Env<'_>,
name: String,
fields: I,
types: &mut Types,
opt_recursion_id: Option<TypeId>,
) -> TypeId {
to_type: F,
) -> TypeId
where
I: IntoIterator<Item = (L, Variable)>,
L: Display + Ord,
F: FnOnce(String, Vec<(L, TypeId)>) -> RocType,
{
let subs = env.subs;
let fields_iter = &mut fields.into_iter();
let first_field = match fields_iter.next() {
Some(field) => field,
None => {
// This is an empty record; there's no more work to do!
return types.add(RocType::Struct {
name,
fields: Vec::new(),
});
}
};
let second_field = match fields_iter.next() {
Some(field) => field,
None => {
// This is a single-field record; put it in a transparent wrapper.
let content = env.add_type(first_field.1, types);
return types.add(RocType::TransparentWrapper { name, content });
}
};
let mut sortables =
bumpalo::collections::Vec::with_capacity_in(2 + fields_iter.size_hint().0, env.arena);
bumpalo::collections::Vec::with_capacity_in(fields_iter.size_hint().0, env.arena);
for (label, field_var) in std::iter::once(first_field)
.chain(std::iter::once(second_field))
.chain(fields_iter)
{
for (label, field_var) in fields_iter {
sortables.push((
label,
field_var,
@ -252,19 +284,13 @@ fn add_struct<I: IntoIterator<Item = (Lowercase, Variable)>>(
let fields = sortables
.into_iter()
.map(|(label, field_var, field_layout)| {
let content = subs.get_content_without_compacting(field_var);
let type_id = add_type_help(env, field_layout, field_var, None, types);
if matches!(content, Content::RecursionVar { .. }) {
Field::Recursive(label.to_string(), opt_recursion_id.unwrap())
} else {
let type_id = add_type_help(env, field_layout, field_var, None, types, None);
Field::NonRecursive(label.to_string(), type_id)
}
(label, type_id)
})
.collect();
.collect::<Vec<(L, TypeId)>>();
types.add(RocType::Struct { name, fields })
types.add(to_type(name, fields))
}
fn add_tag_union(
@ -287,200 +313,138 @@ fn add_tag_union(
})
.collect();
if tags.len() == 1 {
// This is a single-tag union.
let (tag_name, payload_vars) = tags.pop().unwrap();
let layout = env.layout_cache.from_var(env.arena, var, subs).unwrap();
let name = match opt_name {
Some(sym) => sym.as_str(env.interns).to_string(),
None => env.enum_names.get_name(var),
};
// If there was a type alias name, use that. Otherwise use the tag name.
let name = match opt_name {
Some(sym) => sym.as_str(env.interns).to_string(),
None => tag_name,
};
// Sort tags alphabetically by tag name
tags.sort_by(|(name1, _), (name2, _)| name1.cmp(name2));
match payload_vars.len() {
0 => {
// This is a single-tag union with no payload, e.g. `[Foo]`
// so just generate an empty record
types.add(RocType::Struct {
name,
fields: Vec::new(),
})
let is_recursive = is_recursive_tag_union(&layout);
let mut tags: Vec<_> = tags
.into_iter()
.map(|(tag_name, payload_vars)| {
match struct_fields_needed(env, payload_vars.iter().copied()) {
0 => {
// no payload
(tag_name, None)
}
1 if !is_recursive => {
// this isn't recursive and there's 1 payload item, so it doesn't
// need its own struct - e.g. for `[Foo Str, Bar Str]` both of them
// can have payloads of plain old Str, no struct wrapper needed.
let payload_var = payload_vars.get(0).unwrap();
let layout = env
.layout_cache
.from_var(env.arena, *payload_var, env.subs)
.expect("Something weird ended up in the content");
let payload_id = add_type_help(env, layout, *payload_var, None, types);
(tag_name, Some(payload_id))
}
_ => {
// create a RocType for the payload and save it
let struct_name = format!("{}_{}", name, tag_name); // e.g. "MyUnion_MyVariant"
let fields = payload_vars.iter().copied().enumerate();
let struct_id = add_struct(env, struct_name, fields, types, |name, fields| {
RocType::TagUnionPayload { name, fields }
});
(tag_name, Some(struct_id))
}
}
1 => {
// This is a single-tag union with 1 payload field, e.g.`[Foo Str]`.
// We'll just wrap that.
let var = *payload_vars.get(0).unwrap();
let content = env.add_type(var, types);
})
.collect();
types.add(RocType::TransparentWrapper { name, content })
}
_ => {
// This is a single-tag union with multiple payload field, e.g.`[Foo Str U32]`.
// Generate a record.
let fields = payload_vars.iter().enumerate().map(|(index, payload_var)| {
let field_name = format!("f{}", index).into();
let typ = match layout {
Layout::Union(union_layout) => {
use roc_mono::layout::UnionLayout::*;
(field_name, *payload_var)
});
match union_layout {
// A non-recursive tag union
// e.g. `Result ok err : [Ok ok, Err err]`
NonRecursive(_) => RocType::TagUnion(RocTagUnion::NonRecursive { name, tags }),
// A recursive tag union (general case)
// e.g. `Expr : [Sym Str, Add Expr Expr]`
Recursive(_) => RocType::TagUnion(RocTagUnion::Recursive { name, tags }),
// A recursive tag union with just one constructor
// Optimization: No need to store a tag ID (the payload is "unwrapped")
// e.g. `RoseTree a : [Tree a (List (RoseTree a))]`
NonNullableUnwrapped(_) => {
todo!()
}
// A recursive tag union that has an empty variant
// Optimization: Represent the empty variant as null pointer => no memory usage & fast comparison
// It has more than one other variant, so they need tag IDs (payloads are "wrapped")
// e.g. `FingerTree a : [Empty, Single a, More (Some a) (FingerTree (Tuple a)) (Some a)]`
// see also: https://youtu.be/ip92VMpf_-A?t=164
NullableWrapped { .. } => {
todo!()
}
// A recursive tag union with only two variants, where one is empty.
// Optimizations: Use null for the empty variant AND don't store a tag ID for the other variant.
// e.g. `ConsList a : [Nil, Cons a (ConsList a)]`
NullableUnwrapped {
nullable_id: null_represents_first_tag,
other_fields: _, // TODO use this!
} => {
// NullableUnwrapped tag unions should always have exactly 2 tags.
debug_assert_eq!(tags.len(), 2);
// Note that we assume no recursion variable here. If you write something like:
//
// Rec : [Blah Rec]
//
// ...then it's not even theoretically possible to instantiate one, so
// bindgen won't be able to help you do that!
add_struct(env, name, fields, types, None)
let null_tag;
let non_null;
if null_represents_first_tag {
// If nullable_id is true, then the null tag is second, which means
// pop() will return it because it's at the end of the vec.
null_tag = tags.pop().unwrap().0;
non_null = tags.pop().unwrap();
} else {
// The null tag is first, which means the tag with the payload is second.
non_null = tags.pop().unwrap();
null_tag = tags.pop().unwrap().0;
}
let (non_null_tag, non_null_payload) = non_null;
RocType::TagUnion(RocTagUnion::NullableUnwrapped {
name,
null_tag,
non_null_tag,
non_null_payload: non_null_payload.unwrap(),
null_represents_first_tag,
})
}
}
}
} else {
// This is a multi-tag union.
Layout::Builtin(Builtin::Int(_)) => RocType::TagUnion(RocTagUnion::Enumeration {
name,
tags: tags.into_iter().map(|(tag_name, _)| tag_name).collect(),
}),
Layout::Builtin(_)
| Layout::Struct { .. }
| Layout::Boxed(_)
| Layout::LambdaSet(_)
| Layout::RecursivePointer => {
// These must be single-tag unions. Bindgen ordinary nonrecursive
// tag unions for them, and let Rust do the unwrapping.
//
// This should be a very rare use case, and it's not worth overcomplicating
// the rest of bindgen to make it do something different.
RocType::TagUnion(RocTagUnion::NonRecursive { name, tags })
}
};
// This is a placeholder so that we can get a TypeId for future recursion IDs.
// At the end, we will replace this with the real tag union type.
let type_id = types.add(RocType::Struct {
name: "[THIS SHOULD BE REMOVED]".to_string(),
fields: Vec::new(),
});
let layout = env.layout_cache.from_var(env.arena, var, subs).unwrap();
let name = match opt_name {
Some(sym) => sym.as_str(env.interns).to_string(),
None => env.enum_names.get_name(var),
};
let type_id = types.add(typ);
// Sort tags alphabetically by tag name
tags.sort_by(|(name1, _), (name2, _)| name1.cmp(name2));
let opt_recursion_id = if is_recursive_tag_union(&layout) {
Some(type_id)
} else {
None
};
let mut tags: Vec<_> = tags
.into_iter()
.map(|(tag_name, payload_vars)| {
match struct_fields_needed(env, payload_vars.iter().copied()) {
0 => {
// no payload
(tag_name, None)
}
1 if opt_recursion_id.is_none() => {
// this isn't recursive and there's 1 payload item, so it doesn't
// need its own struct - e.g. for `[Foo Str, Bar Str]` both of them
// can have payloads of plain old Str, no struct wrapper needed.
let payload_var = payload_vars.get(0).unwrap();
let layout = env
.layout_cache
.from_var(env.arena, *payload_var, env.subs)
.expect("Something weird ended up in the content");
let payload_id =
add_type_help(env, layout, *payload_var, None, types, opt_recursion_id);
(tag_name, Some(payload_id))
}
_ => {
// create a struct type for the payload and save it
let struct_name = format!("{}_{}", name, tag_name); // e.g. "MyUnion_MyVariant"
let fields = payload_vars.iter().enumerate().map(|(index, payload_var)| {
(format!("f{}", index).into(), *payload_var)
});
let struct_id =
add_struct(env, struct_name, fields, types, opt_recursion_id);
(tag_name, Some(struct_id))
}
}
})
.collect();
let typ = match layout {
Layout::Union(union_layout) => {
use roc_mono::layout::UnionLayout::*;
match union_layout {
// A non-recursive tag union
// e.g. `Result ok err : [Ok ok, Err err]`
NonRecursive(_) => RocType::TagUnion(RocTagUnion::NonRecursive { name, tags }),
// A recursive tag union (general case)
// e.g. `Expr : [Sym Str, Add Expr Expr]`
Recursive(_) => {
todo!()
}
// A recursive tag union with just one constructor
// Optimization: No need to store a tag ID (the payload is "unwrapped")
// e.g. `RoseTree a : [Tree a (List (RoseTree a))]`
NonNullableUnwrapped(_) => {
todo!()
}
// A recursive tag union that has an empty variant
// Optimization: Represent the empty variant as null pointer => no memory usage & fast comparison
// It has more than one other variant, so they need tag IDs (payloads are "wrapped")
// e.g. `FingerTree a : [Empty, Single a, More (Some a) (FingerTree (Tuple a)) (Some a)]`
// see also: https://youtu.be/ip92VMpf_-A?t=164
NullableWrapped { .. } => {
todo!()
}
// A recursive tag union with only two variants, where one is empty.
// Optimizations: Use null for the empty variant AND don't store a tag ID for the other variant.
// e.g. `ConsList a : [Nil, Cons a (ConsList a)]`
NullableUnwrapped {
nullable_id: null_represents_first_tag,
other_fields: _, // TODO use this!
} => {
// NullableUnwrapped tag unions should always have exactly 2 tags.
debug_assert_eq!(tags.len(), 2);
let null_tag;
let non_null;
if null_represents_first_tag {
// If nullable_id is true, then the null tag is second, which means
// pop() will return it because it's at the end of the vec.
null_tag = tags.pop().unwrap().0;
non_null = tags.pop().unwrap();
} else {
// The null tag is first, which means the tag with the payload is second.
non_null = tags.pop().unwrap();
null_tag = tags.pop().unwrap().0;
}
let (non_null_tag, non_null_payload) = non_null;
RocType::TagUnion(RocTagUnion::NullableUnwrapped {
name,
null_tag,
non_null_tag,
non_null_payload: non_null_payload.unwrap(),
null_represents_first_tag,
})
}
}
}
Layout::Builtin(builtin) => match builtin {
Builtin::Int(_) => RocType::TagUnion(RocTagUnion::Enumeration {
name,
tags: tags.into_iter().map(|(tag_name, _)| tag_name).collect(),
}),
Builtin::Bool => RocType::Bool,
Builtin::Float(_)
| Builtin::Decimal
| Builtin::Str
| Builtin::Dict(_, _)
| Builtin::Set(_)
| Builtin::List(_) => unreachable!(),
},
Layout::Struct { .. }
| Layout::Boxed(_)
| Layout::LambdaSet(_)
| Layout::RecursivePointer => {
unreachable!()
}
};
types.replace(type_id, typ);
type_id
if is_recursive {
env.known_recursive_types.insert(var, type_id);
}
type_id
}
fn is_recursive_tag_union(layout: &Layout) -> bool {

File diff suppressed because it is too large Load Diff

View File

@ -19,8 +19,6 @@ pub fn load_types(
dir: &Path,
threading: Threading,
) -> Result<Vec<(Architecture, Types)>, io::Error> {
// TODO: generate both 32-bit and 64-bit #[cfg] macros if structs are different
// depending on 32-bit vs 64-bit targets.
let target_info = (&Triple::host()).into();
let arena = &Bump::new();
@ -61,53 +59,55 @@ pub fn load_types(
let mut answer = Vec::with_capacity(Architecture::iter().size_hint().0);
for architecture in Architecture::iter() {
let mut layout_cache = LayoutCache::new(architecture.into());
let mut env = Env {
arena,
layout_cache: &mut layout_cache,
interns: &interns,
struct_names: Default::default(),
enum_names: Default::default(),
subs,
};
let mut types = Types::default();
let defs_iter = decls.iter().flat_map(|decl| match decl {
Declaration::Declare(def) => {
vec![def.clone()]
}
Declaration::DeclareRec(defs, cycle_mark) => {
if cycle_mark.is_illegal(subs) {
Vec::new()
} else {
defs.clone()
}
}
Declaration::Builtin(..) => {
unreachable!("Builtin decl in userspace module?")
}
Declaration::InvalidCycle(..) => Vec::new(),
});
for decl in decls.iter() {
let defs = match decl {
Declaration::Declare(def) => {
vec![def.clone()]
}
Declaration::DeclareRec(defs, cycle_mark) => {
if cycle_mark.is_illegal(subs) {
Vec::new()
} else {
defs.clone()
}
}
Declaration::Builtin(..) => {
unreachable!("Builtin decl in userspace module?")
}
Declaration::InvalidCycle(..) => Vec::new(),
};
for Def {
loc_pattern,
pattern_vars,
..
} in defs.into_iter()
{
let vars_iter = defs_iter.filter_map(
|Def {
loc_pattern,
pattern_vars,
..
}| {
if let Pattern::Identifier(sym) = loc_pattern.value {
let var = pattern_vars
.get(&sym)
.expect("Indetifier known but it has no var?");
env.add_type(*var, &mut types);
Some(*var)
} else {
// figure out if we need to export non-identifier defs - when would that
// happen?
// figure out if we need to export non-identifier defs - when
// would that happen?
None
}
}
}
},
);
let mut layout_cache = LayoutCache::new(architecture.into());
let mut env = Env {
arena,
layout_cache: &mut layout_cache,
interns: &interns,
subs,
struct_names: Default::default(),
enum_names: Default::default(),
pending_recursive_types: Default::default(),
known_recursive_types: Default::default(),
};
let types = env.vars_to_types(vars_iter);
answer.push((architecture, types));
}

View File

@ -9,6 +9,13 @@ use std::convert::TryInto;
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TypeId(usize);
impl TypeId {
/// Placeholder TypeId used when making recursive pointers. A recursive
/// pointer needs to hold *some* TypeId value temporarily; later in the
/// process, once its real TypeId has been determined, we go back and
/// fix it up.
///
/// The placeholder wraps `usize::MAX`.
pub(crate) const PENDING: Self = Self(usize::MAX);
}
#[derive(Default, Debug, Clone)]
pub struct Types {
by_id: Vec<RocType>,
@ -117,6 +124,28 @@ impl<'a> Iterator for TypesIter<'a> {
/// A type as represented by bindgen. Variants that refer to other types do
/// so by `TypeId` rather than by nesting `RocType` values.
pub enum RocType {
/// A Roc string.
RocStr,
/// A boolean.
Bool,
/// A number; the `RocNum` says which integer or float width, or `Dec`.
Num(RocNum),
/// A list; the TypeId is that of the element type.
RocList(TypeId),
/// A dictionary; the TypeIds are the key type and the value type, in that order.
RocDict(TypeId, TypeId),
/// A set; the TypeId is that of the element type.
RocSet(TypeId),
/// A box. NOTE(review): presumably the TypeId is that of the boxed
/// content - confirm against where RocBox is constructed.
RocBox(TypeId),
/// A tag union; see `RocTagUnion` for the different representations.
TagUnion(RocTagUnion),
/// A named struct whose fields pair a field label with the field's TypeId.
Struct {
name: String,
fields: Vec<(String, TypeId)>,
},
/// The payload of a tag in a tag union. Unlike `Struct`, the fields are
/// labeled with a `usize` (the payload's index) instead of a string.
TagUnionPayload {
name: String,
fields: Vec<(usize, TypeId)>,
},
/// A recursive pointer, e.g. in StrConsList : [Nil, Cons Str StrConsList],
/// this would be the field of Cons containing the (recursive) StrConsList type,
/// and the TypeId is the TypeId of StrConsList itself.
RecursivePointer(TypeId),
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RocNum {
I8,
U8,
I16,
@ -130,44 +159,50 @@ pub enum RocType {
F32,
F64,
F128,
RocDec,
RocList(TypeId),
RocDict(TypeId, TypeId),
RocSet(TypeId),
RocBox(TypeId),
TagUnion(RocTagUnion),
Struct {
name: String,
fields: Vec<Field>,
},
/// Either a single-tag union or a single-field record
TransparentWrapper {
name: String,
content: TypeId,
},
Dec,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Field {
NonRecursive(String, TypeId),
/// A recursive field, e.g. in StrConsList : [Nil, Cons Str StrConsList],
/// this would be the field of Cons containing the (recursive) StrConsList type,
/// and the TypeId is the TypeId of StrConsList itself.
Recursive(String, TypeId),
}
impl RocNum {
fn size(&self) -> usize {
use core::mem::size_of;
use RocNum::*;
impl Field {
pub fn type_id(&self) -> TypeId {
match self {
Field::NonRecursive(_, type_id) => *type_id,
Field::Recursive(_, type_id) => *type_id,
I8 => size_of::<i8>(),
U8 => size_of::<u8>(),
I16 => size_of::<i16>(),
U16 => size_of::<u16>(),
I32 => size_of::<i32>(),
U32 => size_of::<u32>(),
I64 => size_of::<i64>(),
U64 => size_of::<u64>(),
I128 => size_of::<roc_std::I128>(),
U128 => size_of::<roc_std::U128>(),
F32 => size_of::<f32>(),
F64 => size_of::<f64>(),
F128 => todo!(),
Dec => size_of::<roc_std::RocDec>(),
}
}
pub fn label(&self) -> &str {
fn alignment(&self, target_info: TargetInfo) -> usize {
use RocNum::*;
match self {
Field::NonRecursive(label, _) => label,
Field::Recursive(label, _) => label,
I8 => IntWidth::I8.alignment_bytes(target_info) as usize,
U8 => IntWidth::U8.alignment_bytes(target_info) as usize,
I16 => IntWidth::I16.alignment_bytes(target_info) as usize,
U16 => IntWidth::U16.alignment_bytes(target_info) as usize,
I32 => IntWidth::I32.alignment_bytes(target_info) as usize,
U32 => IntWidth::U32.alignment_bytes(target_info) as usize,
I64 => IntWidth::I64.alignment_bytes(target_info) as usize,
U64 => IntWidth::U64.alignment_bytes(target_info) as usize,
I128 => IntWidth::I128.alignment_bytes(target_info) as usize,
U128 => IntWidth::U128.alignment_bytes(target_info) as usize,
F32 => FloatWidth::F32.alignment_bytes(target_info) as usize,
F64 => FloatWidth::F64.alignment_bytes(target_info) as usize,
F128 => FloatWidth::F128.alignment_bytes(target_info) as usize,
Dec => align_of::<RocDec>(),
}
}
}
@ -177,21 +212,8 @@ impl RocType {
pub fn has_pointer(&self, types: &Types) -> bool {
match self {
RocType::Bool
| RocType::I8
| RocType::U8
| RocType::I16
| RocType::U16
| RocType::I32
| RocType::U32
| RocType::I64
| RocType::U64
| RocType::I128
| RocType::U128
| RocType::F32
| RocType::F64
| RocType::F128
| RocType::TagUnion(RocTagUnion::Enumeration { .. })
| RocType::RocDec => false,
| RocType::Num(_)
| RocType::TagUnion(RocTagUnion::Enumeration { .. }) => false,
RocType::RocStr
| RocType::RocList(_)
| RocType::RocDict(_, _)
@ -200,84 +222,90 @@ impl RocType {
| RocType::TagUnion(RocTagUnion::NonNullableUnwrapped { .. })
| RocType::TagUnion(RocTagUnion::NullableUnwrapped { .. })
| RocType::TagUnion(RocTagUnion::NullableWrapped { .. })
| RocType::TagUnion(RocTagUnion::Recursive { .. }) => true,
| RocType::TagUnion(RocTagUnion::Recursive { .. })
| RocType::RecursivePointer { .. } => true,
RocType::TagUnion(RocTagUnion::NonRecursive { tags, .. }) => tags
.iter()
.any(|(_, payloads)| payloads.iter().any(|id| types.get(*id).has_pointer(types))),
RocType::Struct { fields, .. } => fields.iter().any(|field| match field {
Field::NonRecursive(_, type_id) => types.get(*type_id).has_pointer(types),
Field::Recursive(_, _) => true,
}),
RocType::TransparentWrapper { content, .. } => types.get(*content).has_pointer(types),
RocType::Struct { fields, .. } => fields
.iter()
.any(|(_, type_id)| types.get(*type_id).has_pointer(types)),
RocType::TagUnionPayload { fields, .. } => fields
.iter()
.any(|(_, type_id)| types.get(*type_id).has_pointer(types)),
}
}
/// Useful when determining whether to derive Eq, Ord, and Hash in a Rust type.
pub fn has_float(&self, types: &Types) -> bool {
self.has_float_help(types, &[])
}
fn has_float_help(&self, types: &Types, do_not_recurse: &[TypeId]) -> bool {
match self {
RocType::F32 | RocType::F64 | RocType::F128 => true,
RocType::Num(num) => {
use RocNum::*;
match num {
F32 | F64 | F128 => true,
I8 | U8 | I16 | U16 | I32 | U32 | I64 | U64 | I128 | U128 | Dec => false,
}
}
RocType::RocStr
| RocType::Bool
| RocType::I8
| RocType::U8
| RocType::I16
| RocType::U16
| RocType::I32
| RocType::U32
| RocType::I64
| RocType::U64
| RocType::I128
| RocType::U128
| RocType::RocDec
| RocType::TagUnion(RocTagUnion::Enumeration { .. }) => false,
RocType::RocList(id) | RocType::RocSet(id) | RocType::RocBox(id) => {
types.get(*id).has_float(types)
types.get(*id).has_float_help(types, do_not_recurse)
}
RocType::RocDict(key_id, val_id) => {
types.get(*key_id).has_float(types) || types.get(*val_id).has_float(types)
types.get(*key_id).has_float_help(types, do_not_recurse)
|| types.get(*val_id).has_float_help(types, do_not_recurse)
}
RocType::Struct { fields, .. } => fields.iter().any(|field| match field {
Field::NonRecursive(_, type_id) => types.get(*type_id).has_float(types),
// This has a float iff there's a float somewhere else.
// We don't want to recurse here, because that would recurse forever!
Field::Recursive(_, _) => false,
}),
RocType::Struct { fields, .. } => fields
.iter()
.any(|(_, type_id)| types.get(*type_id).has_float_help(types, do_not_recurse)),
RocType::TagUnionPayload { fields, .. } => fields
.iter()
.any(|(_, type_id)| types.get(*type_id).has_float_help(types, do_not_recurse)),
RocType::TagUnion(RocTagUnion::Recursive { tags, .. })
| RocType::TagUnion(RocTagUnion::NonRecursive { tags, .. }) => tags
.iter()
.any(|(_, payloads)| payloads.iter().any(|id| types.get(*id).has_float(types))),
RocType::TagUnion(RocTagUnion::NullableWrapped { non_null_tags, .. }) => non_null_tags
.iter()
.any(|(_, _, payloads)| payloads.iter().any(|id| types.get(*id).has_float(types))),
| RocType::TagUnion(RocTagUnion::NonRecursive { tags, .. }) => {
tags.iter().any(|(_, payloads)| {
payloads
.iter()
.any(|id| types.get(*id).has_float_help(types, do_not_recurse))
})
}
RocType::TagUnion(RocTagUnion::NullableWrapped { non_null_tags, .. }) => {
non_null_tags.iter().any(|(_, _, payloads)| {
payloads
.iter()
.any(|id| types.get(*id).has_float_help(types, do_not_recurse))
})
}
RocType::TagUnion(RocTagUnion::NullableUnwrapped {
non_null_payload: content,
..
})
| RocType::TagUnion(RocTagUnion::NonNullableUnwrapped { content, .. })
| RocType::TransparentWrapper { content, .. } => types.get(*content).has_float(types),
| RocType::RecursivePointer(content) => {
if do_not_recurse.contains(content) {
false
} else {
let mut do_not_recurse: Vec<TypeId> = do_not_recurse.into();
do_not_recurse.push(*content);
types.get(*content).has_float_help(types, &do_not_recurse)
}
}
}
}
/// Useful when determining whether to derive Default in a Rust type.
pub fn has_enumeration(&self, types: &Types) -> bool {
match self {
RocType::TagUnion { .. } => true,
RocType::RocStr
| RocType::Bool
| RocType::I8
| RocType::U8
| RocType::I16
| RocType::U16
| RocType::I32
| RocType::U32
| RocType::I64
| RocType::U64
| RocType::I128
| RocType::U128
| RocType::F32
| RocType::F64
| RocType::F128
| RocType::RocDec => false,
RocType::TagUnion { .. } | RocType::RecursivePointer { .. } => true,
RocType::RocStr | RocType::Bool | RocType::Num(_) => false,
RocType::RocList(id) | RocType::RocSet(id) | RocType::RocBox(id) => {
types.get(*id).has_enumeration(types)
}
@ -285,14 +313,12 @@ impl RocType {
types.get(*key_id).has_enumeration(types)
|| types.get(*val_id).has_enumeration(types)
}
RocType::Struct { fields, .. } => fields.iter().any(|field| match field {
Field::NonRecursive(_, type_id) => types.get(*type_id).has_enumeration(types),
// If this struct has a recursive field, that means we're inside an enumeration!
Field::Recursive(_, _) => true,
}),
RocType::TransparentWrapper { content, .. } => {
types.get(*content).has_enumeration(types)
}
RocType::Struct { fields, .. } => fields
.iter()
.any(|(_, type_id)| types.get(*type_id).has_enumeration(types)),
RocType::TagUnionPayload { fields, .. } => fields
.iter()
.any(|(_, type_id)| types.get(*type_id).has_enumeration(types)),
}
}
@ -301,27 +327,14 @@ impl RocType {
match self {
RocType::Bool => size_of::<bool>(),
RocType::I8 => size_of::<i8>(),
RocType::U8 => size_of::<u8>(),
RocType::I16 => size_of::<i16>(),
RocType::U16 => size_of::<u16>(),
RocType::I32 => size_of::<i32>(),
RocType::U32 => size_of::<u32>(),
RocType::I64 => size_of::<i64>(),
RocType::U64 => size_of::<u64>(),
RocType::I128 => size_of::<roc_std::I128>(),
RocType::U128 => size_of::<roc_std::U128>(),
RocType::F32 => size_of::<f32>(),
RocType::F64 => size_of::<f64>(),
RocType::F128 => todo!(),
RocType::RocDec => size_of::<roc_std::RocDec>(),
RocType::Num(num) => num.size(),
RocType::RocStr | RocType::RocList(_) | RocType::RocDict(_, _) | RocType::RocSet(_) => {
3 * target_info.ptr_size()
}
RocType::RocBox(_) => target_info.ptr_size(),
RocType::TagUnion(tag_union) => match tag_union {
RocTagUnion::Enumeration { tags, .. } => size_for_tag_count(tags.len()),
RocTagUnion::NonRecursive { tags, .. } => {
RocTagUnion::NonRecursive { tags, .. } | RocTagUnion::Recursive { tags, .. } => {
// The "unpadded" size (without taking alignment into account)
// is the sum of all the sizes of the fields.
let size_unpadded = tags.iter().fold(0, |total, (_, opt_payload_id)| {
@ -355,37 +368,23 @@ impl RocType {
size_padded
}
}
RocTagUnion::Recursive { .. } => todo!(),
RocTagUnion::NonNullableUnwrapped { .. } => todo!(),
RocTagUnion::NullableWrapped { .. } => todo!(),
RocTagUnion::NullableUnwrapped { .. } => todo!(),
},
RocType::Struct { fields, .. } => {
// The "unpadded" size (without taking alignment into account)
// is the sum of all the sizes of the fields.
let size_unpadded = fields.iter().fold(0, |total, field| match field {
Field::NonRecursive(_, field_id) => {
total + types.get(*field_id).size(types, target_info)
}
Field::Recursive(_, _) => {
// The recursion var is a pointer.
total + target_info.ptr_size()
}
});
// Round up to the next multiple of alignment, to incorporate
// any necessary alignment padding.
//
// e.g. if we have a record with a Str and a U8, that would be a
// size_unpadded of 25, because Str is three 8-byte pointers and U8 is 1 byte,
// but the 8-byte alignment of the pointers means we'll round 25 up to 32.
let align = self.alignment(types, target_info);
(size_unpadded / align) * align
}
RocType::TransparentWrapper { content, .. } => {
types.get(*content).size(types, target_info)
}
RocType::Struct { fields, .. } => struct_size(
fields.iter().map(|(_, type_id)| *type_id),
types,
target_info,
self.alignment(types, target_info),
),
RocType::TagUnionPayload { fields, .. } => struct_size(
fields.iter().map(|(_, type_id)| *type_id),
types,
target_info,
self.alignment(types, target_info),
),
RocType::RecursivePointer { .. } => target_info.ptr_size(),
}
}
@ -396,7 +395,7 @@ impl RocType {
| RocType::RocDict(_, _)
| RocType::RocSet(_)
| RocType::RocBox(_) => target_info.ptr_alignment_bytes(),
RocType::RocDec => align_of::<RocDec>(),
RocType::Num(num) => num.alignment(target_info),
RocType::Bool => align_of::<bool>(),
RocType::TagUnion(RocTagUnion::NonRecursive { tags, .. }) => {
// The smallest alignment this could possibly have is based on the number of tags - e.g.
@ -444,30 +443,15 @@ impl RocType {
align
}
RocType::Struct { fields, .. } => fields.iter().fold(0, |align, field| match field {
Field::NonRecursive(_, field_id) => {
align.max(types.get(*field_id).alignment(types, target_info))
}
Field::Recursive(_, _) => {
// The recursion var is a pointer.
align.max(target_info.ptr_alignment_bytes())
}
RocType::Struct { fields, .. } => fields.iter().fold(0, |align, (_, field_id)| {
align.max(types.get(*field_id).alignment(types, target_info))
}),
RocType::I8 => IntWidth::I8.alignment_bytes(target_info) as usize,
RocType::U8 => IntWidth::U8.alignment_bytes(target_info) as usize,
RocType::I16 => IntWidth::I16.alignment_bytes(target_info) as usize,
RocType::U16 => IntWidth::U16.alignment_bytes(target_info) as usize,
RocType::I32 => IntWidth::I32.alignment_bytes(target_info) as usize,
RocType::U32 => IntWidth::U32.alignment_bytes(target_info) as usize,
RocType::I64 => IntWidth::I64.alignment_bytes(target_info) as usize,
RocType::U64 => IntWidth::U64.alignment_bytes(target_info) as usize,
RocType::I128 => IntWidth::I128.alignment_bytes(target_info) as usize,
RocType::U128 => IntWidth::U128.alignment_bytes(target_info) as usize,
RocType::F32 => FloatWidth::F32.alignment_bytes(target_info) as usize,
RocType::F64 => FloatWidth::F64.alignment_bytes(target_info) as usize,
RocType::F128 => FloatWidth::F128.alignment_bytes(target_info) as usize,
RocType::TransparentWrapper { content, .. }
| RocType::TagUnion(RocTagUnion::NullableUnwrapped {
RocType::TagUnionPayload { fields, .. } => {
fields.iter().fold(0, |align, (_, field_id)| {
align.max(types.get(*field_id).alignment(types, target_info))
})
}
RocType::TagUnion(RocTagUnion::NullableUnwrapped {
non_null_payload: content,
..
})
@ -480,10 +464,32 @@ impl RocType {
.try_into()
.unwrap()
}
RocType::RecursivePointer { .. } => target_info.ptr_alignment_bytes(),
}
}
}
/// Computes the size in bytes of a struct-like type (a record or a tag
/// union payload) from the sizes of its fields.
///
/// * `fields` - the type IDs of the struct's fields.
/// * `types` - the type store used to look up each field's `RocType`.
/// * `target_info` - the target whose pointer size is used when sizing fields.
/// * `align` - the alignment of the struct as a whole; callers pass
///   `self.alignment(types, target_info)`.
///
/// Returns the sum of the field sizes, rounded up to the next multiple of
/// `align`.
fn struct_size(
    fields: impl Iterator<Item = TypeId>,
    types: &Types,
    target_info: TargetInfo,
    align: usize,
) -> usize {
    // The "unpadded" size (without taking alignment into account)
    // is the sum of all the sizes of the fields.
    let size_unpadded: usize = fields
        .map(|field_id| types.get(field_id).size(types, target_info))
        .sum();

    // The struct alignment is computed as a max over the fields starting
    // from 0, so a struct with no fields reports an alignment of 0. There is
    // nothing to pad in that case, and dividing by it would panic.
    if align == 0 {
        return size_unpadded;
    }

    // Round up to the next multiple of alignment, to incorporate
    // any necessary alignment padding.
    //
    // e.g. if we have a record with a Str and a U8, that would be a
    // size_unpadded of 25, because Str is three 8-byte pointers and U8 is 1 byte,
    // but the 8-byte alignment of the pointers means we'll round 25 up to 32.
    match size_unpadded % align {
        0 => size_unpadded,
        remainder => size_unpadded + (align - remainder),
    }
}
fn size_for_tag_count(num_tags: usize) -> usize {
if num_tags == 0 {
// empty tag union
@ -603,14 +609,8 @@ impl RocTagUnion {
// into account, since the discriminant is
// stored after those bytes.
RocType::Struct { fields, .. } => {
fields.iter().fold(0, |total, field| match field {
Field::NonRecursive(_, field_id) => {
total + types.get(*field_id).size(types, target_info)
}
Field::Recursive(_, _) => {
// The recursion var is a pointer.
total + target_info.ptr_size()
}
fields.iter().fold(0, |total, (_, field_id)| {
total + types.get(*field_id).size(types, target_info)
})
}
typ => max_size.max(typ.size(types, target_info)),

View File

@ -1,6 +1,6 @@
mod bindings;
use bindings::{StrConsList, StrConsList_Cons};
use bindings::StrConsList;
use indoc::indoc;
extern "C" {
@ -38,10 +38,7 @@ pub extern "C" fn rust_main() -> i32 {
"#
),
tag_union,
StrConsList::Cons(StrConsList_Cons {
f0: "small str".into(),
f1: StrConsList::Nil
}),
StrConsList::Cons("small str".into(), StrConsList::Nil),
StrConsList::Nil,
); // Debug

View File

@ -0,0 +1,11 @@
platform "test-platform"
requires {} { main : _ }
exposes []
packages {}
imports []
provides [mainForHost]
# A recursive tag union: `String` carries a `Str`, and `Concat` carries two
# nested `Expr` values.
Expr : [String Str, Concat Expr Expr]
# The value this platform provides to the host: whatever the app supplies as `main`.
mainForHost : Expr
mainForHost = main

View File

@ -0,0 +1,6 @@
app "app"
packages { pf: "." }
imports []
provides [main] to pf
# Builds a `Concat` of two `String` values.
main = Concat (String "Hello, ") (String "World!")

View File

@ -0,0 +1,106 @@
mod bindings;
use bindings::Expr;
use indoc::indoc;
// Entry point exported by the compiled Roc app. `rust_main` calls it with a
// pointer to an uninitialized `Expr` and then treats that memory as
// initialized, so this function is expected to write its result through the
// pointer.
extern "C" {
#[link_name = "roc__mainForHost_1_exposed_generic"]
fn roc_main(_: *mut Expr);
}
/// Host entry point: obtains the `Expr` produced by the Roc app, checks that
/// the generated `Expr` type implements the expected traits, prints a few
/// values using their `Debug` representation, and returns exit code 0.
#[no_mangle]
pub extern "C" fn rust_main() -> i32 {
use std::cmp::Ordering;
use std::collections::hash_set::HashSet;
// SAFETY (assumed, not verifiable from this file): `roc_main` must fully
// initialize the `Expr` behind the pointer before `assume_init` is called.
let tag_union = unsafe {
let mut ret: core::mem::MaybeUninit<Expr> = core::mem::MaybeUninit::uninit();
roc_main(ret.as_mut_ptr());
ret.assume_init()
};
// Verify that it has all the expected traits.
assert!(tag_union == tag_union); // PartialEq
assert!(tag_union.clone() == tag_union.clone()); // Clone
assert!(tag_union.partial_cmp(&tag_union) == Some(Ordering::Equal)); // PartialOrd
assert!(tag_union.cmp(&tag_union) == Ordering::Equal); // Ord
// Print the value received from Roc next to equivalent values constructed
// on the Rust side.
print!(
indoc!(
r#"
tag_union was: {:?}
`Concat (String "Hello, ") (String "World!")` is: {:?}
`String "this is a test"` is: {:?}
"#
),
tag_union,
Expr::Concat(
Expr::String("Hello, ".into()),
Expr::String("World!".into()),
),
Expr::String("this is a test".into()),
); // Debug
// Inserting the same value twice must leave a single entry in the set.
let mut set = HashSet::new();
set.insert(tag_union.clone()); // Eq, Hash
set.insert(tag_union);
assert_eq!(set.len(), 1);
// Exit code
0
}
// Externs required by roc_std and by the Roc app
use core::ffi::c_void;
use std::ffi::CStr;
use std::os::raw::c_char;
/// Allocation hook exported for Roc: forwards to `libc::malloc`.
///
/// The requested alignment is ignored.
#[no_mangle]
pub unsafe extern "C" fn roc_alloc(size: usize, _alignment: u32) -> *mut c_void {
    libc::malloc(size)
}
/// Reallocation hook exported for Roc: forwards to `libc::realloc`.
///
/// Only the pointer and the new size are used; the old size and the
/// alignment are ignored.
#[no_mangle]
pub unsafe extern "C" fn roc_realloc(
    c_ptr: *mut c_void,
    new_size: usize,
    _old_size: usize,
    _alignment: u32,
) -> *mut c_void {
    libc::realloc(c_ptr, new_size)
}
/// Deallocation hook exported for Roc: forwards to `libc::free`.
///
/// The alignment argument is ignored.
#[no_mangle]
pub unsafe extern "C" fn roc_dealloc(c_ptr: *mut c_void, _alignment: u32) {
    libc::free(c_ptr)
}
/// Panic hook exported for Roc.
///
/// For `tag_id == 0`, `c_ptr` is read as a NUL-terminated C string, the
/// message is printed to stderr, and the process exits with status 1.
/// Every other `tag_id` is unimplemented.
#[no_mangle]
pub unsafe extern "C" fn roc_panic(c_ptr: *mut c_void, tag_id: u32) {
    // Only tag 0 is handled; bail out early for everything else.
    if tag_id != 0 {
        todo!();
    }

    let message = CStr::from_ptr(c_ptr as *const c_char)
        .to_str()
        .unwrap();
    eprintln!("Roc hit a panic: {}", message);
    std::process::exit(1);
}
/// Memory-copy hook exported for Roc: forwards its arguments unchanged to
/// `libc::memcpy` and returns that call's result.
#[no_mangle]
pub unsafe extern "C" fn roc_memcpy(dst: *mut c_void, src: *mut c_void, n: usize) -> *mut c_void {
libc::memcpy(dst, src, n)
}
/// Memory-fill hook exported for Roc: forwards its arguments unchanged to
/// `libc::memset` and returns that call's result.
#[no_mangle]
pub unsafe extern "C" fn roc_memset(dst: *mut c_void, c: i32, n: usize) -> *mut c_void {
libc::memset(dst, c, n)
}

View File

@ -233,80 +233,4 @@ mod test_gen_rs {
)
);
}
// Checks the bindings generated for a single-tag union with two payload
// fields (`Id U32 Str`). The expected output is one `UserId` struct per group
// of target architectures; the two groups list the fields `f0` and `f1` in
// opposite orders.
#[test]
fn single_tag_union_with_payloads() {
let module = indoc!(
r#"
UserId : [Id U32 Str]
main : UserId
main = Id 42 "blah"
"#
);
assert_eq!(
generate_bindings(module)
.strip_prefix('\n')
.unwrap_or_default(),
indoc!(
r#"
#[cfg(any(
target_arch = "x86_64",
target_arch = "aarch64"
))]
#[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)]
#[repr(C)]
pub struct UserId {
pub f1: roc_std::RocStr,
pub f0: u32,
}
#[cfg(any(
target_arch = "x86",
target_arch = "arm",
target_arch = "wasm32"
))]
#[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)]
#[repr(C)]
pub struct UserId {
pub f0: u32,
pub f1: roc_std::RocStr,
}
"#
)
);
}
// Checks the bindings generated for a single-tag union with exactly one
// payload field (`Id Str`). The expected output is a single
// `#[repr(transparent)]` tuple struct wrapping `roc_std::RocStr`, shared by
// all listed target architectures.
#[test]
fn single_tag_union_with_one_payload_field() {
let module = indoc!(
r#"
UserId : [Id Str]
main : UserId
main = Id "blah"
"#
);
assert_eq!(
generate_bindings(module)
.strip_prefix('\n')
.unwrap_or_default(),
indoc!(
r#"
#[cfg(any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "arm",
target_arch = "wasm32"
))]
#[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)]
#[repr(transparent)]
pub struct UserId(roc_std::RocStr);
"#
)
);
}
}

View File

@ -117,10 +117,15 @@ mod bindgen_cli_run {
`Blah 456` is: NonRecursive::Blah(456)
"#),
nullable_unwrapped:"nullable-unwrapped" => indoc!(r#"
tag_union was: StrConsList::Cons(StrConsList_Cons { f0: "World!", f1: StrConsList::Cons(StrConsList_Cons { f0: "Hello ", f1: StrConsList::Nil }) })
`Cons "small str" Nil` is: StrConsList::Cons(StrConsList_Cons { f0: "small str", f1: StrConsList::Nil })
tag_union was: StrConsList::Cons("World!", StrConsList::Cons("Hello ", StrConsList::Nil))
`Cons "small str" Nil` is: StrConsList::Cons("small str", StrConsList::Nil)
`Nil` is: StrConsList::Nil
"#),
recursive_union:"recursive-union" => indoc!(r#"
tag_union was: Expr::Concat(Expr::String("Hello, "), Expr::String("World!"))
`Concat (String "Hello, ") (String "World!")` is: Expr::Concat(Expr::String("Hello, "), Expr::String("World!"))
`String "this is a test"` is: Expr::String("this is a test")
"#),
}
fn check_for_tests(all_fixtures: &mut roc_collections::VecSet<String>) {

View File

@ -2971,10 +2971,10 @@ mod test {
/// This is called by both code gen and bindgen, so that
/// their field orderings agree.
#[inline(always)]
pub fn cmp_fields(
label1: &Lowercase,
pub fn cmp_fields<L: Ord>(
label1: &L,
layout1: &Layout<'_>,
label2: &Lowercase,
label2: &L,
layout2: &Layout<'_>,
target_info: TargetInfo,
) -> Ordering {