diff --git a/Cargo.lock b/Cargo.lock index 8c99dbd32..ec6d475fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2021,10 +2021,7 @@ dependencies = [ "bstr 0.2.17", "log", "mlua", - "serde", - "serde_json", "strsim 0.10.0", - "thiserror", "wezterm-dynamic", ] @@ -3057,7 +3054,7 @@ dependencies = [ "log", "luahelper", "ntapi", - "serde", + "wezterm-dynamic", "winapi 0.3.9", ] diff --git a/config/src/lua.rs b/config/src/lua.rs index e19613c53..3efb89dfd 100644 --- a/config/src/lua.rs +++ b/config/src/lua.rs @@ -133,7 +133,7 @@ pub fn make_lua_context(config_file: &Path) -> anyhow::Result { } }, item @ _ => { - let item = format!("{:?}", ValueWrapper(item)); + let item = format!("{:?}", item); output.push_str(&item); } } diff --git a/luahelper/Cargo.toml b/luahelper/Cargo.toml index 86bc4664e..46e966131 100644 --- a/luahelper/Cargo.toml +++ b/luahelper/Cargo.toml @@ -10,8 +10,5 @@ edition = "2018" bstr = "0.2" log = "0.4" mlua = "0.7" -serde = {version="1.0", features = ["rc", "derive"]} -serde_json = "1.0" strsim = "0.10" -thiserror = "1.0" wezterm-dynamic = { path = "../wezterm-dynamic" } diff --git a/luahelper/src/lib.rs b/luahelper/src/lib.rs index 90299d0c8..ba3faf621 100644 --- a/luahelper/src/lib.rs +++ b/luahelper/src/lib.rs @@ -1,43 +1,16 @@ #![macro_use] +pub use mlua; +use mlua::{ToLua, Value as LuaValue}; use std::collections::BTreeMap; use wezterm_dynamic::{FromDynamic, ToDynamic, Value as DynValue}; -mod serde_lua; -pub use mlua; -use mlua::{ToLua, Value as LuaValue}; -pub use serde_lua::from_lua_value; -pub use serde_lua::ser::to_lua_value; - /// Implement lua conversion traits for a type. /// This implementation requires that the type implement -/// serde Serialize and Deserialize. +/// FromDynamic and ToDynamic. /// Why do we need these traits? They allow `create_function` to /// operate in terms of our internal types rather than forcing /// the implementer to use generic Value parameter or return values. -#[macro_export] -macro_rules! impl_lua_conversion { - ($struct:ident) => { - impl<'lua> $crate::mlua::ToLua<'lua> for $struct { - fn to_lua( - self, - lua: &'lua $crate::mlua::Lua, - ) -> Result<$crate::mlua::Value<'lua>, $crate::mlua::Error> { - Ok($crate::to_lua_value(lua, self)?) - } - } - - impl<'lua> $crate::mlua::FromLua<'lua> for $struct { - fn from_lua( - value: $crate::mlua::Value<'lua>, - _lua: &'lua $crate::mlua::Lua, - ) -> Result { - Ok($crate::from_lua_value(value)?) - } - } - }; -} - #[macro_export] macro_rules! impl_lua_conversion_dynamic { ($struct:ident) => { @@ -103,6 +76,7 @@ pub fn dynamic_to_lua_value<'lua>( }) } +/// FIXME: lua_value_to_dynamic should detect and avoid cycles in the underlying lua object pub fn lua_value_to_dynamic(value: LuaValue) -> mlua::Result { Ok(match value { LuaValue::Nil => DynValue::Null, @@ -166,5 +140,3 @@ pub struct ValueLua { pub value: wezterm_dynamic::Value, } impl_lua_conversion_dynamic!(ValueLua); - -pub use serde_lua::ValueWrapper; diff --git a/luahelper/src/serde_lua/mod.rs b/luahelper/src/serde_lua/mod.rs deleted file mode 100644 index 6b5348ee5..000000000 --- a/luahelper/src/serde_lua/mod.rs +++ /dev/null @@ -1,1045 +0,0 @@ -use mlua::{Table, Value}; -use serde::de::value::{MapDeserializer, SeqDeserializer}; -use serde::de::{ - DeserializeOwned, DeserializeSeed, Deserializer, EnumAccess, Error as SerdeDeError, - IntoDeserializer, Unexpected, VariantAccess, Visitor, -}; -use serde::{serde_if_integer128, Deserialize}; -use std::cmp::Ordering; -use std::collections::{BTreeMap, HashSet}; -use std::convert::TryInto; -use thiserror::*; - -pub mod ser; - -/// This is the key function from this module; it uses serde to -/// "parse" a lua value into a Rust type that implements Deserialize. -pub fn from_lua_value(value: Value) -> Result -where - T: DeserializeOwned, -{ - T::deserialize(ValueWrapper(value)) -} - -fn unexpected<'lua>(v: &'lua Value<'lua>) -> Unexpected<'lua> { - match v { - Value::Nil => Unexpected::Other("lua nil"), - Value::Boolean(b) => Unexpected::Bool(*b), - Value::LightUserData(_) => Unexpected::Other("lua lightuserdata"), - Value::Integer(i) => Unexpected::Signed(*i), - Value::Number(n) => Unexpected::Float(*n), - Value::String(s) => match s.to_str() { - Ok(s) => Unexpected::Str(s), - Err(_) => Unexpected::Bytes(s.as_bytes()), - }, - Value::Table(t) => match t.contains_key(1) { - Ok(true) => Unexpected::Other("lua array-like table"), - Ok(false) => Unexpected::Other("lua map-like table"), - Err(_) => Unexpected::Other( - "lua table (but encountered an error while testing if it is array- or map-like)", - ), - }, - Value::Function(_) => Unexpected::Other("lua function"), - Value::Thread(_) => Unexpected::Other("lua thread"), - Value::UserData(_) => Unexpected::Other("lua userdata"), - Value::Error(_) => Unexpected::Other("lua error"), - } -} - -#[derive(Debug, Error)] -pub enum Error { - #[error("{}", msg)] - Custom { msg: String }, -} - -impl SerdeDeError for Error { - fn custom(msg: T) -> Self { - Error::Custom { - msg: msg.to_string(), - } - } -} - -impl From for mlua::Error { - fn from(e: Error) -> mlua::Error { - mlua::Error::external(e) - } -} - -pub struct ValueWrapper<'lua>(pub Value<'lua>); - -impl<'lua> PartialEq for ValueWrapper<'lua> { - fn eq(&self, other: &Self) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl<'lua> Eq for ValueWrapper<'lua> {} - -impl<'lua> PartialOrd for ValueWrapper<'lua> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl<'lua> Ord for ValueWrapper<'lua> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - match (&self.0, &other.0) { - (Value::Nil, Value::Nil) => Ordering::Equal, - (Value::String(a), Value::String(b)) => a.as_bytes().cmp(b.as_bytes()), - (Value::Boolean(a), Value::Boolean(b)) => a.cmp(&b), - (Value::Integer(a), Value::Integer(b)) => a.cmp(&b), - (Value::Number(a), Value::Number(b)) => { - a.partial_cmp(&b).expect("use of keys that can be ordered") - } - (Value::LightUserData(a), Value::LightUserData(b)) => { - if a == b { - Ordering::Equal - } else { - panic!("use only keys that can be ordered") - } - } - (Value::Table(a), Value::Table(b)) => { - if a == b { - Ordering::Equal - } else { - panic!("use only keys that can be ordered") - } - } - (Value::Function(_), Value::Function(_)) => panic!("cannot order functions"), - (Value::Thread(_), Value::Thread(_)) => panic!("cannot order threads"), - (Value::UserData(_), Value::UserData(_)) => panic!("cannot order userdata"), - (Value::Error(_), Value::Error(_)) => panic!("cannot order errors"), - (Value::Nil, _) => panic!("cannot order differing types"), - (Value::Boolean(_), _) => panic!("cannot order differing types"), - (Value::Integer(_), _) => panic!("cannot order differing types"), - (Value::Number(_), _) => panic!("cannot order differing types"), - (Value::LightUserData(_), _) => panic!("cannot order differing types"), - (Value::UserData(_), _) => panic!("cannot order differing types"), - (Value::String(_), _) => panic!("cannot order differing types"), - (Value::Table(_), _) => panic!("cannot order differing types"), - (Value::Function(_), _) => panic!("cannot order differing types"), - (Value::Thread(_), _) => panic!("cannot order differing types"), - (Value::Error(_), _) => panic!("cannot order differing types"), - } - } -} - -fn table_has_cycle(top: &Table, value: &Value) -> bool { - if let Value::Table(table) = value { - let mut seen = HashSet::new(); - - /// Capture the lua reference number. There's no direct accessor, - /// but the debug struct includes that number. - fn lref(t: &Table) -> String { - format!("{:?}", t) - } - - seen.insert(lref(top)); - - fn check_cycle(seen: &mut HashSet, table: &Table, depth: usize) -> bool { - let cref = lref(table); - if seen.contains(&cref) { - return true; - } - if depth > 128 { - // Seems suspicious - return true; - } - seen.insert(cref); - for pair in table.clone().pairs::() { - if let Ok(pair) = pair { - if let Value::Table(child) = pair.1 { - if check_cycle(seen, &child, depth + 1) { - return true; - } - } - } - } - false - } - - return check_cycle(&mut seen, table, 0); - } - false -} - -impl<'lua> std::fmt::Debug for ValueWrapper<'lua> { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { - match &self.0 { - Value::Nil => fmt.write_str("nil"), - Value::Boolean(b) => fmt.write_str(if *b { "true" } else { "false" }), - Value::Integer(i) => fmt.write_fmt(format_args!("{}", i)), - Value::Number(i) => fmt.write_fmt(format_args!("{}", i)), - Value::String(s) => match s.to_str() { - Ok(s) => fmt.write_fmt(format_args!("{:?}", s)), - Err(_) => fmt.write_fmt(format_args!("{:?}", s.as_bytes())), - }, - Value::Table(t) => { - if let Ok(true) = t.contains_key(1) { - // Treat as list - let mut list = fmt.debug_list(); - for (idx, value) in t.clone().sequence_values().enumerate() { - match value { - Ok(value) => { - if !table_has_cycle(t, &value) { - list.entry(&ValueWrapper(value)); - } else { - log::warn!("Ignoring value at ordinal position {} which has cyclical reference", idx); - } - } - Err(err) => { - list.entry(&err); - } - } - } - list.finish() - } else { - // Treat as map; put it into a BTreeMap so that we have a stable - // order for our tests. - let mut map = BTreeMap::new(); - for pair in t.clone().pairs::() { - match pair { - Ok(pair) => { - if !table_has_cycle(t, &pair.1) { - map.insert(ValueWrapper(pair.0), ValueWrapper(pair.1)); - } else { - log::warn!( - "Ignoring field {:?} which has cyclical reference", - ValueWrapper(pair.0) - ); - } - } - Err(err) => { - log::error!("error while retrieving map entry: {}", err); - break; - } - } - } - fmt.debug_map().entries(&map).finish() - } - } - Value::UserData(_) | Value::LightUserData(_) => fmt.write_str("userdata"), - Value::Thread(_) => fmt.write_str("thread"), - Value::Function(_) => fmt.write_str("function"), - Value::Error(e) => fmt.write_fmt(format_args!("error {}", e)), - } - } -} - -impl<'de, 'lua> IntoDeserializer<'de, Error> for ValueWrapper<'lua> { - type Deserializer = Self; - - fn into_deserializer(self) -> Self { - self - } -} - -fn visit_table<'de, 'lua, V>( - table: Table<'lua>, - visitor: V, - struct_name: Option<&'static str>, - allowed_fields: Option<&'static [&'static str]>, -) -> Result -where - V: Visitor<'de>, -{ - // First we need to determine whether this table looks like an array - // or whether it looks like a map. Lua allows for either or both in - // the same table. - // Since array like tables start with index 1 we look for that key - // and assume that if it has that that it is an array. - if let Ok(true) = table.contains_key(1) { - // Treat it as an array - let mut values = vec![]; - for value in table.sequence_values() { - match value { - Ok(value) => values.push(ValueWrapper(value)), - Err(err) => { - return Err(Error::custom(format!( - "while retrieving an array element: {}", - err - ))) - } - } - } - - let mut deser = SeqDeserializer::new(values.into_iter()); - let seq = match visitor.visit_seq(&mut deser) { - Ok(seq) => seq, - Err(err) => return Err(err), - }; - - deser.end()?; - Ok(seq) - } else { - // Treat it as a map - let mut pairs = vec![]; - for pair in table.pairs::() { - match pair { - Ok(pair) => { - // When deserializing into a struct with known field names, - // we don't want to hard error if the user gave a bogus field - // name; we'd rather generate a warning somewhere and attempt - // to proceed. This makes the config a bit more forgiving of - // typos and also makes it easier to use a given config in - // a future version of wezterm where the configuration may - // evolve over time. - if let Some(allowed_fields) = allowed_fields { - if !allowed_fields.iter().any(|&name| name == pair.0) { - // The field wasn't one of the allowed fields in this - // context. Generate an error message that is hopefully - // helpful; we'll suggest the set of most similar field - // names (ordered by similarity) and list out the remaining - // possible field names in alpha order - - // Produce similar field name list - let mut candidates: Vec<(f64, &str)> = allowed_fields - .iter() - .map(|&name| (strsim::jaro_winkler(&pair.0, name), name)) - .filter(|(confidence, _)| *confidence > 0.8) - .collect(); - candidates.sort_by(|a, b| { - b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal) - }); - let suggestions: Vec<&str> = - candidates.into_iter().map(|(_, name)| name).collect(); - - // Filter the suggestions out of the allowed field names - // and sort what remains. - let mut fields: Vec<&str> = allowed_fields - .iter() - .filter(|&name| { - !suggestions.iter().any(|candidate| candidate == name) - }) - .copied() - .collect(); - fields.sort_unstable(); - - let mut message = String::new(); - - match suggestions.len() { - 0 => {} - 1 => { - message.push_str(&format!("Did you mean `{}`?", suggestions[0])) - } - _ => { - message.push_str("Did you mean one of "); - for (idx, candidate) in suggestions.iter().enumerate() { - if idx > 0 { - message.push_str(", "); - } - message.push('`'); - message.push_str(candidate); - message.push('`'); - } - message.push('?'); - } - } - if !fields.is_empty() { - if suggestions.is_empty() { - message.push_str("Possible fields are "); - } else { - message.push_str(" Other possible fields are "); - } - for (idx, candidate) in fields.iter().enumerate() { - if idx > 0 { - message.push_str(", "); - } - message.push('`'); - message.push_str(candidate); - message.push('`'); - } - message.push('.'); - } - log::error!( - "Ignoring unknown field `{}` in struct of type `{}`. {}", - pair.0, - struct_name.unwrap_or(""), - message - ); - - continue; - } - } - pairs.push((pair.0, ValueWrapper(pair.1))) - } - Err(err) => { - return Err(Error::custom(format!( - "while retrieving map element: {}", - err - ))) - } - } - } - let mut deser = MapDeserializer::new(pairs.into_iter()); - let seq = match visitor.visit_map(&mut deser) { - Ok(seq) => seq, - Err(err) => return Err(err), - }; - - deser.end()?; - Ok(seq) - } -} - -macro_rules! int { - ($name:ident, $ty:ty, $visit:ident) => { - fn $name(self, v: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::Integer(i) => v.$visit(i.try_into().map_err(|e| { - Error::custom(format!( - "lua Integer value {} doesn't fit \ - in specified type: {}", - i, e - )) - })?), - _ => Err(serde::de::Error::invalid_type( - unexpected(&self.0), - &"integer", - )), - } - } - }; -} - -impl<'de, 'lua> Deserializer<'de> for ValueWrapper<'lua> { - type Error = Error; - - fn deserialize_any(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::Nil => visitor.visit_unit(), - Value::Boolean(v) => visitor.visit_bool(v), - Value::Integer(i) => visitor.visit_i64(i), - Value::Number(n) => visitor.visit_f64(n), - Value::String(s) => match s.to_str() { - Ok(s) => visitor.visit_str(s), - Err(_) => visitor.visit_bytes(s.as_bytes()), - }, - Value::Table(t) => visit_table(t, visitor, None, None), - Value::UserData(_) | Value::LightUserData(_) => Err(Error::custom( - "cannot represent userdata in the serde data model", - )), - Value::Thread(_) => Err(Error::custom( - "cannot represent thread in the serde data model", - )), - Value::Function(_) => Err(Error::custom( - "cannot represent lua function in the serde data model", - )), - Value::Error(e) => Err(Error::custom(format!( - "cannot represent lua error {} in the serde data model", - e - ))), - } - } - - fn deserialize_bool(self, v: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::Boolean(b) => v.visit_bool(b), - _ => Err(serde::de::Error::invalid_type(unexpected(&self.0), &"bool")), - } - } - - fn deserialize_option(self, v: V) -> Result - where - V: Visitor<'de>, - { - match &self.0 { - Value::Nil => v.visit_none(), - Value::Table(t) => { - let mut iter = t.clone().pairs::(); - if iter.next().is_none() { - v.visit_none() - } else { - v.visit_some(self) - } - } - _ => v.visit_some(self), - } - } - - fn deserialize_unit(self, v: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::Nil => v.visit_unit(), - _ => v.visit_some(self), - } - } - - fn deserialize_ignored_any(self, v: V) -> Result - where - V: Visitor<'de>, - { - v.visit_unit() - } - - fn deserialize_unit_struct(self, _name: &str, v: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_unit(v) - } - - fn deserialize_newtype_struct(self, _name: &str, v: V) -> Result - where - V: Visitor<'de>, - { - v.visit_newtype_struct(self) - } - - fn deserialize_char(self, v: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_string(v) - } - - fn deserialize_str(self, v: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_string(v) - } - - fn deserialize_identifier(self, v: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_string(v) - } - - fn deserialize_string(self, v: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::String(s) => match s.to_str() { - Ok(s) => v.visit_str(s), - Err(_) => Err(Error::custom( - "expected String but found a non-UTF8 lua string", - )), - }, - _ => Err(serde::de::Error::invalid_type( - unexpected(&self.0), - &"string", - )), - } - } - - fn deserialize_bytes(self, v: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_byte_buf(v) - } - - fn deserialize_byte_buf(self, v: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::String(s) => match s.to_str() { - Ok(s) => v.visit_str(s), - Err(_) => v.visit_bytes(s.as_bytes()), - }, - _ => Err(serde::de::Error::invalid_type( - unexpected(&self.0), - &"bytes", - )), - } - } - - fn deserialize_seq(self, v: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::Table(t) => visit_table(t, v, None, None), - _ => Err(serde::de::Error::invalid_type( - unexpected(&self.0), - &"sequence/array", - )), - } - } - - fn deserialize_tuple(self, _len: usize, v: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_seq(v) - } - - fn deserialize_tuple_struct( - self, - _name: &'static str, - _len: usize, - v: V, - ) -> Result - where - V: Visitor<'de>, - { - self.deserialize_seq(v) - } - - int!(deserialize_i8, i8, visit_i8); - int!(deserialize_u8, u8, visit_u8); - int!(deserialize_i16, i16, visit_i16); - int!(deserialize_u16, u16, visit_u16); - int!(deserialize_i32, i32, visit_i32); - int!(deserialize_u32, u32, visit_u32); - int!(deserialize_i64, i64, visit_i64); - int!(deserialize_u64, u64, visit_u64); - - serde_if_integer128! { - int!(deserialize_i128, i128, visit_i128); - int!(deserialize_u128, u128, visit_u128); - } - - fn deserialize_f64(self, v: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::Number(i) => v.visit_f64(i), - _ => Err(serde::de::Error::invalid_type( - unexpected(&self.0), - &"floating point number", - )), - } - } - - fn deserialize_f32(self, v: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::Number(i) => v.visit_f32(i as f32), - _ => Err(serde::de::Error::invalid_type( - unexpected(&self.0), - &"floating point number", - )), - } - } - - fn deserialize_enum( - self, - name: &str, - variants: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - let saved_value = self.0.clone(); - let (variant, value) = match self.0 { - Value::Table(t) => { - let mut iter = t.pairs::(); - let (variant, value) = match iter.next() { - Some(Ok(v)) => v, - Some(Err(e)) => { - return Err(Error::custom(format!( - "failed to retrieve enum pair from map: {}", - e - ))); - } - None => { - return Err(serde::de::Error::invalid_value( - Unexpected::Map, - &"map with a single key", - )); - } - }; - - // enums are encoded in serde_json as maps with a single - // key:value pair, so we mirror that here - if iter.next().is_some() { - return Err(serde::de::Error::invalid_value( - Unexpected::Map, - &"map with a single key", - )); - } - (variant, Some(value)) - } - Value::String(s) => match s.to_str() { - Ok(s) => (s.to_owned(), None), - Err(_) => { - return Err(serde::de::Error::invalid_value( - Unexpected::Bytes(s.as_bytes()), - &"UTF-8 string key", - )) - } - }, - _ => { - return Err(serde::de::Error::invalid_type( - Unexpected::Other("?"), - &"string or map", - )); - } - }; - - visitor - .visit_enum(EnumDeserializer { variant, value }) - .map_err(|err| { - let mut variants = variants.to_vec(); - variants.sort(); - let mut allowed = String::new(); - for varname in variants.into_iter() { - if !allowed.is_empty() { - allowed.push_str(", "); - } - allowed.push('`'); - allowed.push_str(varname); - allowed.push('`'); - } - Error::custom(format!( - "while processing an enum of type `{}` and \ - value {:?}\nwhich has allowed variants {}\n{}", - name, - ValueWrapper(saved_value), - allowed, - err - )) - }) - } - - fn deserialize_map(self, v: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::Table(t) => visit_table(t, v, None, None), - _ => Err(serde::de::Error::invalid_type( - unexpected(&self.0), - &"a map", - )), - } - } - - fn deserialize_struct( - self, - struct_name: &'static str, - fields: &'static [&'static str], - v: V, - ) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::Table(t) => { - let table = t.clone(); - match visit_table(t, v, Some(struct_name), Some(fields)) { - Ok(v) => Ok(v), - Err(err) => Err(Error::custom(format!( - "while processing a struct of type `{}` with value:\n{:#?}\n{}", - struct_name, - ValueWrapper(Value::Table(table)), - err, - ))), - } - } - _ => Err(serde::de::Error::invalid_type( - unexpected(&self.0), - &"a map", - )), - } - } -} - -struct EnumDeserializer<'lua> { - variant: String, - value: Option>, -} - -impl<'de, 'lua> EnumAccess<'de> for EnumDeserializer<'lua> { - type Error = Error; - type Variant = VariantDeserializer<'lua>; - - fn variant_seed(self, seed: V) -> Result<(V::Value, VariantDeserializer<'lua>), Error> - where - V: DeserializeSeed<'de>, - { - let variant = self.variant.into_deserializer(); - let visitor = VariantDeserializer { value: self.value }; - seed.deserialize(variant).map(|v| (v, visitor)) - } -} - -struct VariantDeserializer<'lua> { - value: Option>, -} - -impl<'de, 'lua> VariantAccess<'de> for VariantDeserializer<'lua> { - type Error = Error; - - fn unit_variant(self) -> Result<(), Error> { - match self.value { - Some(value) => Deserialize::deserialize(ValueWrapper(value)), - None => Ok(()), - } - } - - fn newtype_variant_seed(self, seed: T) -> Result - where - T: DeserializeSeed<'de>, - { - match self.value { - Some(value) => seed.deserialize(ValueWrapper(value)), - None => Err(Error::custom(format!( - "Expected a variant with parameters but got a unit variant instead" - ))), - } - } - - fn tuple_variant(self, _len: usize, visitor: V) -> Result - where - V: Visitor<'de>, - { - match self.value { - Some(Value::Table(table)) => { - if let Ok(true) = table.contains_key(1) { - let mut values = vec![]; - for value in table.sequence_values() { - match value { - Ok(value) => values.push(ValueWrapper(value)), - Err(err) => { - return Err(Error::custom(format!( - "while retrieving an array element: {}", - err - ))) - } - } - } - - let deser = SeqDeserializer::new(values.into_iter()); - serde::Deserializer::deserialize_any(deser, visitor) - } else { - Err(serde::de::Error::invalid_type( - Unexpected::Map, - &"tuple variant", - )) - } - } - Some(v) => Err(serde::de::Error::invalid_type( - unexpected(&v), - &"tuple variant", - )), - None => Err(serde::de::Error::invalid_type( - Unexpected::UnitVariant, - &"tuple variant", - )), - } - } - - fn struct_variant( - self, - _fields: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - match self.value { - Some(Value::Table(table)) => { - if let Ok(false) = table.contains_key(1) { - let mut pairs = vec![]; - for pair in table.pairs::() { - match pair { - Ok(pair) => pairs.push((pair.0, ValueWrapper(pair.1))), - Err(err) => { - return Err(Error::custom(format!( - "while retrieving map element: {}", - err - ))) - } - } - } - let deser = MapDeserializer::new(pairs.into_iter()); - serde::Deserializer::deserialize_any(deser, visitor) - } else { - Err(serde::de::Error::invalid_type( - Unexpected::Seq, - &"struct variant", - )) - } - } - Some(v) => Err(serde::de::Error::invalid_type( - unexpected(&v), - &"struct variant", - )), - _ => Err(serde::de::Error::invalid_type( - Unexpected::UnitVariant, - &"struct variant", - )), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - use mlua::Lua; - use serde::Serialize; - - fn round_trip< - T: Serialize + DeserializeOwned + ?Sized + PartialEq + std::fmt::Debug + Clone, - >( - value: T, - ) { - let lua = Lua::new(); - let lua_value: Value = ser::to_lua_value(&lua, value.clone()).unwrap(); - let round_tripped: T = from_lua_value(lua_value).unwrap(); - assert_eq!(value, round_tripped); - } - - #[test] - fn test_bool() { - let lua = Lua::new(); - let res: bool = from_lua_value(lua.load("true").eval().unwrap()).unwrap(); - assert_eq!(res, true); - round_trip(res); - - let res: bool = from_lua_value(lua.load("false").eval().unwrap()).unwrap(); - assert_eq!(res, false); - round_trip(res); - } - - #[test] - fn test_nil() { - let lua = Lua::new(); - let res: () = from_lua_value(lua.load("nil").eval().unwrap()).unwrap(); - round_trip(res); - } - - #[test] - fn test_int() { - let lua = Lua::new(); - let res: i64 = from_lua_value(lua.load("123").eval().unwrap()).unwrap(); - assert_eq!(res, 123); - round_trip(res); - - let res: i32 = from_lua_value(lua.load("123").eval().unwrap()).unwrap(); - assert_eq!(res, 123); - round_trip(res); - - let res: i16 = from_lua_value(lua.load("123").eval().unwrap()).unwrap(); - assert_eq!(res, 123); - round_trip(res); - - let res: i8 = from_lua_value(lua.load("123").eval().unwrap()).unwrap(); - assert_eq!(res, 123); - round_trip(res); - } - - #[test] - fn test_float() { - let lua = Lua::new(); - let res: f64 = from_lua_value(lua.load("123.5").eval().unwrap()).unwrap(); - assert_eq!(res, 123.5); - round_trip(res); - } - - #[test] - fn test_string() { - let lua = Lua::new(); - let res: String = from_lua_value(lua.load("\"hello\"").eval().unwrap()).unwrap(); - assert_eq!(res, "hello"); - round_trip(res); - } - - #[test] - fn test_array_table() { - let lua = Lua::new(); - let res: Vec = from_lua_value(lua.load("{1, 2, 3}").eval().unwrap()).unwrap(); - assert_eq!(res, vec![1, 2, 3]); - round_trip(res); - } - - #[test] - fn test_map_table() { - let lua = Lua::new(); - - #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] - struct MyMap { - hello: String, - age: usize, - } - - let res: MyMap = - from_lua_value(lua.load("{hello=\"hello\", age=42}").eval().unwrap()).unwrap(); - assert_eq!( - res, - MyMap { - hello: "hello".to_owned(), - age: 42 - } - ); - round_trip(res); - - let err = from_lua_value::(lua.load("{hello=\"hello\", age=true}").eval().unwrap()) - .unwrap_err(); - assert_eq!( - err.to_string(), - "while processing a struct of type `MyMap` with value:\n\ - {\n \"age\": true,\n \"hello\": \"hello\",\n}\n\ - invalid type: boolean `true`, expected integer", - ); - } - - #[test] - fn test_enum() { - #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] - enum MyEnum { - Foo, - Bar, - } - let lua = Lua::new(); - let res: MyEnum = from_lua_value(lua.load("\"Foo\"").eval().unwrap()).unwrap(); - assert_eq!(res, MyEnum::Foo); - round_trip(res); - - let res: MyEnum = from_lua_value(lua.load("\"Bar\"").eval().unwrap()).unwrap(); - assert_eq!(res, MyEnum::Bar); - round_trip(res); - - let err = from_lua_value::(lua.load("\"Invalid\"").eval().unwrap()).unwrap_err(); - assert_eq!( - err.to_string(), - "while processing an enum of type `MyEnum` and value \"Invalid\"\n\ - which has allowed variants `Bar`, `Foo`\n\ - unknown variant `Invalid`, expected `Foo` or `Bar`", - ); - } - - #[test] - fn test_option_mode() { - #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] - enum MyEnum { - Foo, - Bar, - } - #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] - enum AnotherEnum { - ThisOne(Option), - } - - let lua = Lua::new(); - let res: AnotherEnum = - from_lua_value(lua.load("{ThisOne=\"Foo\"}").eval().unwrap()).unwrap(); - assert_eq!(res, AnotherEnum::ThisOne(Some(MyEnum::Foo))); - round_trip(res); - - let res: AnotherEnum = from_lua_value(lua.load("{ThisOne={}}").eval().unwrap()).unwrap(); - assert_eq!(res, AnotherEnum::ThisOne(None)); - } -} diff --git a/luahelper/src/serde_lua/ser.rs b/luahelper/src/serde_lua/ser.rs deleted file mode 100644 index 12f535d5f..000000000 --- a/luahelper/src/serde_lua/ser.rs +++ /dev/null @@ -1,466 +0,0 @@ -use super::ValueWrapper; -use mlua::{Lua, Table, ToLua, Value}; -use serde::ser::Error as SerError; -use serde::{serde_if_integer128, Serialize, Serializer}; -use thiserror::*; - -pub fn to_lua_value<'lua, T>(lua: &'lua Lua, input: T) -> Result, Error> -where - T: Serialize, -{ - let serializer = LuaSerializer { lua }; - input.serialize(serializer) -} - -#[derive(Debug, Error)] -pub enum Error { - #[error("{}", msg)] - Custom { msg: String }, -} - -impl Error { - fn lua(e: mlua::Error) -> Error { - Error::custom(e) - } -} - -impl From for mlua::Error { - fn from(e: Error) -> mlua::Error { - mlua::Error::external(e) - } -} - -impl SerError for Error { - fn custom(msg: T) -> Self { - Error::Custom { - msg: msg.to_string(), - } - } -} - -struct LuaSerializer<'lua> { - lua: &'lua Lua, -} - -struct LuaSeqSerializer<'lua> { - lua: &'lua Lua, - table: Table<'lua>, - index: usize, -} - -impl<'lua> serde::ser::SerializeSeq for LuaSeqSerializer<'lua> { - type Ok = Value<'lua>; - type Error = Error; - - fn serialize_element(&mut self, value: &T) -> Result<(), Error> { - let value = value.serialize(LuaSerializer { lua: self.lua })?; - self.table.set(self.index, value).map_err(Error::lua)?; - self.index += 1; - Ok(()) - } - - fn end(self) -> Result { - Ok(Value::Table(self.table)) - } -} - -impl<'lua> serde::ser::SerializeTuple for LuaSeqSerializer<'lua> { - type Ok = Value<'lua>; - type Error = Error; - - fn serialize_element(&mut self, value: &T) -> Result<(), Error> { - serde::ser::SerializeSeq::serialize_element(self, value) - } - - fn end(self) -> Result { - serde::ser::SerializeSeq::end(self) - } -} - -impl<'lua> serde::ser::SerializeTupleStruct for LuaSeqSerializer<'lua> { - type Ok = Value<'lua>; - type Error = Error; - - fn serialize_field(&mut self, value: &T) -> Result<(), Error> { - serde::ser::SerializeSeq::serialize_element(self, value) - } - fn end(self) -> Result, Error> { - serde::ser::SerializeSeq::end(self) - } -} - -struct LuaTupleVariantSerializer<'lua> { - lua: &'lua Lua, - table: Table<'lua>, - index: usize, - name: String, -} - -impl<'lua> serde::ser::SerializeTupleVariant for LuaTupleVariantSerializer<'lua> { - type Ok = Value<'lua>; - type Error = Error; - - fn serialize_field(&mut self, value: &T) -> Result<(), Error> { - let value = value.serialize(LuaSerializer { lua: self.lua })?; - self.table.set(self.index, value).map_err(Error::lua)?; - self.index += 1; - Ok(()) - } - - fn end(self) -> Result, Error> { - let map = self.lua.create_table().map_err(Error::lua)?; - map.set(self.name, self.table).map_err(Error::lua)?; - Ok(Value::Table(map)) - } -} - -struct LuaMapSerializer<'lua> { - lua: &'lua Lua, - table: Table<'lua>, - key: Option>, -} - -impl<'lua> serde::ser::SerializeMap for LuaMapSerializer<'lua> { - type Ok = Value<'lua>; - type Error = Error; - - fn serialize_key(&mut self, key: &T) -> Result<(), Error> { - let key = key.serialize(LuaSerializer { lua: self.lua })?; - self.key.replace(key); - Ok(()) - } - - fn serialize_value(&mut self, value: &T) -> Result<(), Error> { - let value = value.serialize(LuaSerializer { lua: self.lua })?; - let key = self - .key - .take() - .expect("serialize_key must be called before serialize_value"); - self.table.set(key, value).map_err(Error::lua)?; - Ok(()) - } - - fn serialize_entry( - &mut self, - key: &K, - value: &V, - ) -> Result<(), Error> { - let key = key.serialize(LuaSerializer { lua: self.lua })?; - let value = value.serialize(LuaSerializer { lua: self.lua })?; - self.table.set(key, value).map_err(Error::lua)?; - Ok(()) - } - - fn end(self) -> Result, Error> { - Ok(Value::Table(self.table)) - } -} - -impl<'lua> serde::ser::SerializeStruct for LuaMapSerializer<'lua> { - type Ok = Value<'lua>; - type Error = Error; - - fn serialize_field( - &mut self, - key: &str, - value: &T, - ) -> Result<(), Error> { - serde::ser::SerializeMap::serialize_entry(self, key, value) - } - - fn end(self) -> Result, Error> { - serde::ser::SerializeMap::end(self) - } -} - -struct LuaStructVariantSerializer<'lua> { - lua: &'lua Lua, - name: String, - table: Table<'lua>, -} - -impl<'lua> serde::ser::SerializeStructVariant for LuaStructVariantSerializer<'lua> { - type Ok = Value<'lua>; - type Error = Error; - - fn serialize_field( - &mut self, - key: &str, - value: &T, - ) -> Result<(), Error> { - let key = key.serialize(LuaSerializer { lua: self.lua })?; - let value = value.serialize(LuaSerializer { lua: self.lua })?; - self.table.set(key, value).map_err(Error::lua)?; - Ok(()) - } - - fn end(self) -> Result, Error> { - let map = self.lua.create_table().map_err(Error::lua)?; - map.set(self.name, self.table).map_err(Error::lua)?; - Ok(Value::Table(map)) - } -} - -impl<'lua> serde::Serializer for LuaSerializer<'lua> { - type Ok = Value<'lua>; - type Error = Error; - type SerializeSeq = LuaSeqSerializer<'lua>; - type SerializeTuple = LuaSeqSerializer<'lua>; - type SerializeTupleStruct = LuaSeqSerializer<'lua>; - type SerializeTupleVariant = LuaTupleVariantSerializer<'lua>; - type SerializeMap = LuaMapSerializer<'lua>; - type SerializeStruct = LuaMapSerializer<'lua>; - type SerializeStructVariant = LuaStructVariantSerializer<'lua>; - - fn serialize_bool(self, b: bool) -> Result, Error> { - b.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_i8(self, i: i8) -> Result, Error> { - i.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_i16(self, i: i16) -> Result, Error> { - i.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_i32(self, i: i32) -> Result, Error> { - i.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_i64(self, i: i64) -> Result, Error> { - i.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_u8(self, i: u8) -> Result, Error> { - i.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_u16(self, i: u16) -> Result, Error> { - i.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_u32(self, i: u32) -> Result, Error> { - i.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_u64(self, i: u64) -> Result, Error> { - i.to_lua(self.lua).map_err(Error::lua) - } - - serde_if_integer128! { - fn serialize_u128(self, i: u128) -> Result, Error> { - i.to_lua(self.lua).map_err(Error::lua) - } - fn serialize_i128(self, i: i128) -> Result, Error> { - i.to_lua(self.lua).map_err(Error::lua) - } - } - - fn serialize_f32(self, f: f32) -> Result, Error> { - f.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_f64(self, f: f64) -> Result, Error> { - f.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_char(self, c: char) -> Result, Error> { - let mut s = String::new(); - s.push(c); - self.serialize_str(&s) - } - - fn serialize_str(self, s: &str) -> Result, Error> { - s.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_bytes(self, b: &[u8]) -> Result, Error> { - let b: &bstr::BStr = b.into(); - b.to_lua(self.lua).map_err(Error::lua) - } - - fn serialize_none(self) -> Result, Error> { - Ok(Value::Nil) - } - - fn serialize_some(self, v: &T) -> Result, Error> { - v.serialize(self) - } - - fn serialize_unit(self) -> Result, Error> { - Ok(Value::Nil) - } - - fn serialize_unit_struct(self, _name: &'static str) -> Result, Error> { - self.serialize_unit() - } - - fn serialize_unit_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - ) -> Result, Error> { - self.serialize_str(variant) - } - - fn serialize_newtype_struct( - self, - _name: &'static str, - value: &T, - ) -> Result, Error> { - value.serialize(self) - } - - fn serialize_newtype_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - value: &T, - ) -> Result, Error> { - let value = value.serialize(LuaSerializer { lua: self.lua })?; - - let table = self.lua.create_table().map_err(Error::lua)?; - table.set(variant, value).map_err(Error::lua)?; - Ok(Value::Table(table)) - } - - fn serialize_seq(self, len: Option) -> Result, Error> { - self.serialize_tuple(len.unwrap_or(0)) - } - - fn serialize_tuple(self, _len: usize) -> Result, Error> { - let table = self.lua.create_table().map_err(Error::lua)?; - Ok(LuaSeqSerializer { - lua: self.lua, - table, - index: 1, - }) - } - - fn serialize_tuple_struct( - self, - _name: &'static str, - len: usize, - ) -> Result, Error> { - self.serialize_tuple(len) - } - - fn serialize_tuple_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - _len: usize, - ) -> Result, Error> { - let table = self.lua.create_table().map_err(Error::lua)?; - Ok(LuaTupleVariantSerializer { - lua: self.lua, - table, - index: 1, - name: variant.to_string(), - }) - } - - fn serialize_map( - self, - _len: std::option::Option, - ) -> Result, Error> { - let table = self.lua.create_table().map_err(Error::lua)?; - Ok(LuaMapSerializer { - lua: self.lua, - table, - key: None, - }) - } - - fn serialize_struct( - self, - _name: &'static str, - len: usize, - ) -> Result, Error> { - self.serialize_map(Some(len)) - } - - fn serialize_struct_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - _len: usize, - ) -> Result, Error> { - let table = self.lua.create_table().map_err(Error::lua)?; - Ok(LuaStructVariantSerializer { - lua: self.lua, - table, - name: variant.to_owned(), - }) - } -} - -impl<'lua> Serialize for ValueWrapper<'lua> { - fn serialize(&self, serializer: S) -> Result { - match &self.0 { - Value::Nil => serializer.serialize_unit(), - Value::Boolean(b) => serializer.serialize_bool(*b), - Value::Integer(i) => serializer.serialize_i64(*i), - Value::Number(n) => serializer.serialize_f64(*n), - Value::String(s) => match s.to_str() { - Ok(s) => serializer.serialize_str(s), - Err(_) => serializer.serialize_bytes(s.as_bytes()), - }, - Value::Table(table) => { - if let Ok(true) = table.contains_key(1) { - let mut values = vec![]; - for value in table.clone().sequence_values() { - match value { - Ok(value) => values.push(ValueWrapper(value)), - Err(err) => { - return Err(S::Error::custom(format!( - "while retrieving an array element: {}", - err - ))) - } - } - } - values.serialize(serializer) - } else { - use serde::ser::SerializeMap; - let mut pairs = vec![]; - for pair in table.clone().pairs::() { - match pair { - Ok(pair) => pairs.push((ValueWrapper(pair.0), ValueWrapper(pair.1))), - Err(err) => { - return Err(S::Error::custom(format!( - "while retrieving map element: {}", - err - ))) - } - } - } - - let mut map = serializer.serialize_map(Some(pairs.len()))?; - for (k, v) in pairs.into_iter() { - map.serialize_entry(&k, &v)?; - } - map.end() - } - } - Value::UserData(_) | Value::LightUserData(_) => Err(S::Error::custom( - "cannot represent userdata in the serde data model", - )), - Value::Thread(_) => Err(S::Error::custom( - "cannot represent thread in the serde data model", - )), - Value::Function(_) => Err(S::Error::custom( - "cannot represent lua function in the serde data model", - )), - Value::Error(e) => Err(S::Error::custom(format!( - "cannot represent lua error {} in the serde data model", - e - ))), - } - } -} diff --git a/procinfo/Cargo.toml b/procinfo/Cargo.toml index d42e33f11..31c5ca70d 100644 --- a/procinfo/Cargo.toml +++ b/procinfo/Cargo.toml @@ -9,7 +9,7 @@ edition = "2021" libc = "0.2" log = "0.4" luahelper = { path = "../luahelper" } -serde = {version="1.0", features = ["derive"]} +wezterm-dynamic = { path = "../wezterm-dynamic" } [target."cfg(windows)".dependencies] ntapi = "0.3" diff --git a/procinfo/src/lib.rs b/procinfo/src/lib.rs index 21798fdd6..461138b01 100644 --- a/procinfo/src/lib.rs +++ b/procinfo/src/lib.rs @@ -1,12 +1,12 @@ -use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::path::PathBuf; +use wezterm_dynamic::{FromDynamic, ToDynamic}; mod linux; mod macos; mod windows; -#[derive(Debug, Serialize, Deserialize, Copy, Clone)] +#[derive(Debug, Copy, Clone, FromDynamic, ToDynamic)] pub enum LocalProcessStatus { Idle, Run, @@ -22,7 +22,7 @@ pub enum LocalProcessStatus { Unknown, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Clone, FromDynamic, ToDynamic)] pub struct LocalProcessInfo { /// The process identifier pub pid: u32, @@ -55,7 +55,7 @@ pub struct LocalProcessInfo { /// Child processes, keyed by pid pub children: HashMap, } -luahelper::impl_lua_conversion!(LocalProcessInfo); +luahelper::impl_lua_conversion_dynamic!(LocalProcessInfo); impl LocalProcessInfo { /// Walk this sub-tree of processes and return a unique set diff --git a/wezterm-gui/src/overlay/debug.rs b/wezterm-gui/src/overlay/debug.rs index 9cdf16778..84190084e 100644 --- a/wezterm-gui/src/overlay/debug.rs +++ b/wezterm-gui/src/overlay/debug.rs @@ -1,7 +1,6 @@ use crate::scripting::guiwin::GuiWin; use chrono::prelude::*; use log::Level; -use luahelper::ValueWrapper; use mlua::Value; use mux::termwiztermtab::TermWizTerminal; use termwiz::cell::{AttributeChange, CellAttributes, Intensity}; @@ -134,7 +133,8 @@ pub fn show_debug_overlay(mut term: TermWizTerminal, gui_win: GuiWin) -> anyhow: let chunk = host.lua.load(&expr); match smol::block_on(chunk.eval_async::()) { Ok(result) => { - let text = format!("{:#?}", ValueWrapper(result)); + let value = luahelper::lua_value_to_dynamic(result); + let text = format!("{:#?}", value); term.render(&[Change::Text(format!("{}\r\n", text.replace("\n", "\r\n")))])?; } Err(err) => {