1
1
mirror of https://github.com/wez/wezterm.git synced 2024-09-19 02:37:51 +03:00

lua: add serde-powered to_lua_value function

This will help in implementing lua helper functions for various
config and other structures.
This commit is contained in:
Wez Furlong 2020-02-29 21:46:59 -08:00
parent a6fb3971c6
commit fa01ca59ca
5 changed files with 504 additions and 3 deletions

12
Cargo.lock generated
View File

@ -283,7 +283,9 @@ version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "502ae1441a0a5adb8fbd38a5955a6416b9493e92b465de5e4a9bde6a539c2c48"
dependencies = [
"lazy_static",
"memchr",
"regex-automata",
]
[[package]]
@ -2437,6 +2439,15 @@ dependencies = [
"thread_local",
]
[[package]]
name = "regex-automata"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92b73c2a1770c255c240eaa4ee600df1704a38dc3feaa6e949e7fcd4f8dc09f9"
dependencies = [
"byteorder",
]
[[package]]
name = "regex-syntax"
version = "0.6.14"
@ -3462,6 +3473,7 @@ dependencies = [
"base64 0.10.1",
"base91",
"bitflags 1.2.1",
"bstr",
"core-foundation 0.7.0",
"core-graphics 0.19.0",
"core-text 15.0.0",

View File

@ -21,6 +21,7 @@ base64 = "0.10"
base91 = { path = "base91" }
rangeset = { path = "rangeset" }
bitflags = "1.0"
bstr = "0.2"
crossbeam = "0.7"
dirs = "1.0"
downcast-rs = "1.0"

View File

@ -5,6 +5,7 @@ use std::path::Path;
mod serde_lua;
pub use serde_lua::from_lua_value;
pub use serde_lua::ser::to_lua_value;
/// Set up a lua context for executing some code.
/// The path to the directory containing the configuration is

View File

@ -8,6 +8,8 @@ use serde::{serde_if_integer128, Deserialize};
use std::convert::TryInto;
use thiserror::*;
pub mod ser;
/// This is the key function from this module; it uses serde to
/// "parse" a lua value into a Rust type that implements Deserialize.
pub fn from_lua_value<T>(value: Value) -> Result<T, Error>
@ -592,21 +594,36 @@ impl<'de, 'lua> VariantAccess<'de> for VariantDeserializer<'lua> {
mod test {
use super::*;
use mlua::Lua;
use serde::Serialize;
fn round_trip<
T: Serialize + DeserializeOwned + ?Sized + PartialEq + std::fmt::Debug + Clone,
>(
value: T,
) {
let lua = Lua::new();
let lua_value: Value = ser::to_lua_value(&lua, value.clone()).unwrap();
let round_tripped: T = from_lua_value(lua_value).unwrap();
assert_eq!(value, round_tripped);
}
#[test]
fn test_bool() {
let lua = Lua::new();
let res: bool = from_lua_value(lua.load("true").eval().unwrap()).unwrap();
assert_eq!(res, true);
round_trip(res);
let res: bool = from_lua_value(lua.load("false").eval().unwrap()).unwrap();
assert_eq!(res, false);
round_trip(res);
}
#[test]
fn test_nil() {
let lua = Lua::new();
let _res: () = from_lua_value(lua.load("nil").eval().unwrap()).unwrap();
let res: () = from_lua_value(lua.load("nil").eval().unwrap()).unwrap();
round_trip(res);
}
#[test]
@ -614,15 +631,19 @@ mod test {
let lua = Lua::new();
let res: i64 = from_lua_value(lua.load("123").eval().unwrap()).unwrap();
assert_eq!(res, 123);
round_trip(res);
let res: i32 = from_lua_value(lua.load("123").eval().unwrap()).unwrap();
assert_eq!(res, 123);
round_trip(res);
let res: i16 = from_lua_value(lua.load("123").eval().unwrap()).unwrap();
assert_eq!(res, 123);
round_trip(res);
let res: i8 = from_lua_value(lua.load("123").eval().unwrap()).unwrap();
assert_eq!(res, 123);
round_trip(res);
}
#[test]
@ -630,6 +651,7 @@ mod test {
let lua = Lua::new();
let res: f64 = from_lua_value(lua.load("123.5").eval().unwrap()).unwrap();
assert_eq!(res, 123.5);
round_trip(res);
}
#[test]
@ -637,6 +659,7 @@ mod test {
let lua = Lua::new();
let res: String = from_lua_value(lua.load("\"hello\"").eval().unwrap()).unwrap();
assert_eq!(res, "hello");
round_trip(res);
}
#[test]
@ -644,13 +667,14 @@ mod test {
let lua = Lua::new();
let res: Vec<i8> = from_lua_value(lua.load("{1, 2, 3}").eval().unwrap()).unwrap();
assert_eq!(res, vec![1, 2, 3]);
round_trip(res);
}
#[test]
fn test_map_table() {
let lua = Lua::new();
#[derive(Deserialize, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
struct MyMap {
hello: String,
age: usize,
@ -665,6 +689,7 @@ mod test {
age: 42
}
);
round_trip(res);
let err = from_lua_value::<MyMap>(lua.load("{hello=\"hello\", age=true}").eval().unwrap())
.unwrap_err();
@ -678,7 +703,7 @@ mod test {
#[test]
fn test_enum() {
#[derive(Deserialize, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
enum MyEnum {
Foo,
Bar,
@ -686,9 +711,11 @@ mod test {
let lua = Lua::new();
let res: MyEnum = from_lua_value(lua.load("\"Foo\"").eval().unwrap()).unwrap();
assert_eq!(res, MyEnum::Foo);
round_trip(res);
let res: MyEnum = from_lua_value(lua.load("\"Bar\"").eval().unwrap()).unwrap();
assert_eq!(res, MyEnum::Bar);
round_trip(res);
let err = from_lua_value::<MyEnum>(lua.load("\"Invalid\"").eval().unwrap()).unwrap_err();
assert_eq!(

View File

@ -0,0 +1,460 @@
use super::ValueWrapper;
use mlua::{Lua, Table, ToLua, Value};
use serde::ser::Error as SerError;
use serde::{serde_if_integer128, Serialize, Serializer};
use thiserror::*;
pub fn to_lua_value<'lua, T>(lua: &'lua Lua, input: T) -> Result<Value<'lua>, Error>
where
T: Serialize,
{
let serializer = LuaSerializer { lua };
input.serialize(serializer)
}
#[derive(Debug, Error)]
pub enum Error {
#[error("{:?}", msg)]
Custom { msg: String },
}
impl Error {
fn lua(e: mlua::Error) -> Error {
Error::custom(e)
}
}
impl SerError for Error {
fn custom<T: std::fmt::Display>(msg: T) -> Self {
Error::Custom {
msg: msg.to_string(),
}
}
}
struct LuaSerializer<'lua> {
lua: &'lua Lua,
}
struct LuaSeqSerializer<'lua> {
lua: &'lua Lua,
table: Table<'lua>,
index: usize,
}
impl<'lua> serde::ser::SerializeSeq for LuaSeqSerializer<'lua> {
type Ok = Value<'lua>;
type Error = Error;
fn serialize_element<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<(), Error> {
let value = value.serialize(LuaSerializer { lua: self.lua })?;
self.table.set(self.index, value).map_err(Error::lua)?;
self.index += 1;
Ok(())
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(Value::Table(self.table))
}
}
impl<'lua> serde::ser::SerializeTuple for LuaSeqSerializer<'lua> {
type Ok = Value<'lua>;
type Error = Error;
fn serialize_element<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<(), Error> {
serde::ser::SerializeSeq::serialize_element(self, value)
}
fn end(self) -> Result<Self::Ok, Self::Error> {
serde::ser::SerializeSeq::end(self)
}
}
impl<'lua> serde::ser::SerializeTupleStruct for LuaSeqSerializer<'lua> {
type Ok = Value<'lua>;
type Error = Error;
fn serialize_field<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<(), Error> {
serde::ser::SerializeSeq::serialize_element(self, value)
}
fn end(self) -> Result<Value<'lua>, Error> {
serde::ser::SerializeSeq::end(self)
}
}
struct LuaTupleVariantSerializer<'lua> {
lua: &'lua Lua,
table: Table<'lua>,
index: usize,
name: String,
}
impl<'lua> serde::ser::SerializeTupleVariant for LuaTupleVariantSerializer<'lua> {
type Ok = Value<'lua>;
type Error = Error;
fn serialize_field<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<(), Error> {
let value = value.serialize(LuaSerializer { lua: self.lua })?;
self.table.set(self.index, value).map_err(Error::lua)?;
self.index += 1;
Ok(())
}
fn end(self) -> Result<Value<'lua>, Error> {
let map = self.lua.create_table().map_err(Error::lua)?;
map.set(self.name, self.table).map_err(Error::lua)?;
Ok(Value::Table(map))
}
}
struct LuaMapSerializer<'lua> {
lua: &'lua Lua,
table: Table<'lua>,
key: Option<Value<'lua>>,
}
impl<'lua> serde::ser::SerializeMap for LuaMapSerializer<'lua> {
type Ok = Value<'lua>;
type Error = Error;
fn serialize_key<T: Serialize + ?Sized>(&mut self, key: &T) -> Result<(), Error> {
let key = key.serialize(LuaSerializer { lua: self.lua })?;
self.key.replace(key);
Ok(())
}
fn serialize_value<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<(), Error> {
let value = value.serialize(LuaSerializer { lua: self.lua })?;
let key = self
.key
.take()
.expect("serialize_key must be called before serialize_value");
self.table.set(key, value).map_err(Error::lua)?;
Ok(())
}
fn serialize_entry<K: Serialize + ?Sized, V: Serialize + ?Sized>(
&mut self,
key: &K,
value: &V,
) -> Result<(), Error> {
let key = key.serialize(LuaSerializer { lua: self.lua })?;
let value = value.serialize(LuaSerializer { lua: self.lua })?;
self.table.set(key, value).map_err(Error::lua)?;
Ok(())
}
fn end(self) -> Result<Value<'lua>, Error> {
Ok(Value::Table(self.table))
}
}
impl<'lua> serde::ser::SerializeStruct for LuaMapSerializer<'lua> {
type Ok = Value<'lua>;
type Error = Error;
fn serialize_field<T: Serialize + ?Sized>(
&mut self,
key: &str,
value: &T,
) -> Result<(), Error> {
serde::ser::SerializeMap::serialize_entry(self, key, value)
}
fn end(self) -> Result<Value<'lua>, Error> {
serde::ser::SerializeMap::end(self)
}
}
struct LuaStructVariantSerializer<'lua> {
lua: &'lua Lua,
name: String,
table: Table<'lua>,
}
impl<'lua> serde::ser::SerializeStructVariant for LuaStructVariantSerializer<'lua> {
type Ok = Value<'lua>;
type Error = Error;
fn serialize_field<T: Serialize + ?Sized>(
&mut self,
key: &str,
value: &T,
) -> Result<(), Error> {
let key = key.serialize(LuaSerializer { lua: self.lua })?;
let value = value.serialize(LuaSerializer { lua: self.lua })?;
self.table.set(key, value).map_err(Error::lua)?;
Ok(())
}
fn end(self) -> Result<Value<'lua>, Error> {
let map = self.lua.create_table().map_err(Error::lua)?;
map.set(self.name, self.table).map_err(Error::lua)?;
Ok(Value::Table(map))
}
}
impl<'lua> serde::Serializer for LuaSerializer<'lua> {
type Ok = Value<'lua>;
type Error = Error;
type SerializeSeq = LuaSeqSerializer<'lua>;
type SerializeTuple = LuaSeqSerializer<'lua>;
type SerializeTupleStruct = LuaSeqSerializer<'lua>;
type SerializeTupleVariant = LuaTupleVariantSerializer<'lua>;
type SerializeMap = LuaMapSerializer<'lua>;
type SerializeStruct = LuaMapSerializer<'lua>;
type SerializeStructVariant = LuaStructVariantSerializer<'lua>;
fn serialize_bool(self, b: bool) -> Result<Value<'lua>, Error> {
b.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_i8(self, i: i8) -> Result<Value<'lua>, Error> {
i.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_i16(self, i: i16) -> Result<Value<'lua>, Error> {
i.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_i32(self, i: i32) -> Result<Value<'lua>, Error> {
i.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_i64(self, i: i64) -> Result<Value<'lua>, Error> {
i.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_u8(self, i: u8) -> Result<Value<'lua>, Error> {
i.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_u16(self, i: u16) -> Result<Value<'lua>, Error> {
i.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_u32(self, i: u32) -> Result<Value<'lua>, Error> {
i.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_u64(self, i: u64) -> Result<Value<'lua>, Error> {
i.to_lua(self.lua).map_err(Error::lua)
}
serde_if_integer128! {
fn serialize_u128(self, i: u128) -> Result<Value<'lua>, Error> {
i.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_i128(self, i: i128) -> Result<Value<'lua>, Error> {
i.to_lua(self.lua).map_err(Error::lua)
}
}
fn serialize_f32(self, f: f32) -> Result<Value<'lua>, Error> {
f.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_f64(self, f: f64) -> Result<Value<'lua>, Error> {
f.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_char(self, c: char) -> Result<Value<'lua>, Error> {
let mut s = String::new();
s.push(c);
self.serialize_str(&s)
}
fn serialize_str(self, s: &str) -> Result<Value<'lua>, Error> {
s.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_bytes(self, b: &[u8]) -> Result<Value<'lua>, Error> {
let b: &bstr::BStr = b.into();
b.to_lua(self.lua).map_err(Error::lua)
}
fn serialize_none(self) -> Result<Value<'lua>, Error> {
Ok(Value::Nil)
}
fn serialize_some<T: Serialize + ?Sized>(self, v: &T) -> Result<Value<'lua>, Error> {
v.serialize(self)
}
fn serialize_unit(self) -> Result<Value<'lua>, Error> {
Ok(Value::Nil)
}
fn serialize_unit_struct(self, _name: &'static str) -> Result<Value<'lua>, Error> {
self.serialize_unit()
}
fn serialize_unit_variant(
self,
_name: &'static str,
_variant_index: u32,
variant: &'static str,
) -> Result<Value<'lua>, Error> {
self.serialize_str(variant)
}
fn serialize_newtype_struct<T: Serialize + ?Sized>(
self,
_name: &'static str,
value: &T,
) -> Result<Value<'lua>, Error> {
value.serialize(self)
}
fn serialize_newtype_variant<T: Serialize + ?Sized>(
self,
_name: &'static str,
_variant_index: u32,
variant: &'static str,
value: &T,
) -> Result<Value<'lua>, Error> {
let value = value.serialize(LuaSerializer { lua: self.lua })?;
let table = self.lua.create_table().map_err(Error::lua)?;
table.set(variant, value).map_err(Error::lua)?;
Ok(Value::Table(table))
}
fn serialize_seq(self, len: Option<usize>) -> Result<LuaSeqSerializer<'lua>, Error> {
self.serialize_tuple(len.unwrap_or(0))
}
fn serialize_tuple(self, _len: usize) -> Result<LuaSeqSerializer<'lua>, Error> {
let table = self.lua.create_table().map_err(Error::lua)?;
Ok(LuaSeqSerializer {
lua: self.lua,
table,
index: 1,
})
}
fn serialize_tuple_struct(
self,
_name: &'static str,
len: usize,
) -> Result<LuaSeqSerializer<'lua>, Error> {
self.serialize_tuple(len)
}
fn serialize_tuple_variant(
self,
_name: &'static str,
_variant_index: u32,
variant: &'static str,
_len: usize,
) -> Result<LuaTupleVariantSerializer<'lua>, Error> {
let table = self.lua.create_table().map_err(Error::lua)?;
Ok(LuaTupleVariantSerializer {
lua: self.lua,
table,
index: 1,
name: variant.to_string(),
})
}
fn serialize_map(
self,
_len: std::option::Option<usize>,
) -> Result<LuaMapSerializer<'lua>, Error> {
let table = self.lua.create_table().map_err(Error::lua)?;
Ok(LuaMapSerializer {
lua: self.lua,
table,
key: None,
})
}
fn serialize_struct(
self,
_name: &'static str,
len: usize,
) -> Result<LuaMapSerializer<'lua>, Error> {
self.serialize_map(Some(len))
}
fn serialize_struct_variant(
self,
_name: &'static str,
_variant_index: u32,
variant: &'static str,
_len: usize,
) -> Result<LuaStructVariantSerializer<'lua>, Error> {
let table = self.lua.create_table().map_err(Error::lua)?;
Ok(LuaStructVariantSerializer {
lua: self.lua,
table,
name: variant.to_owned(),
})
}
}
impl<'lua> Serialize for ValueWrapper<'lua> {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
match &self.0 {
Value::Nil => serializer.serialize_unit(),
Value::Boolean(b) => serializer.serialize_bool(*b),
Value::Integer(i) => serializer.serialize_i64(*i),
Value::Number(n) => serializer.serialize_f64(*n),
Value::String(s) => match s.to_str() {
Ok(s) => serializer.serialize_str(s),
Err(_) => serializer.serialize_bytes(s.as_bytes()),
},
Value::Table(table) => {
if let Ok(true) = table.contains_key(1) {
let mut values = vec![];
for value in table.clone().sequence_values() {
match value {
Ok(value) => values.push(ValueWrapper(value)),
Err(err) => {
return Err(S::Error::custom(format!(
"while retrieving an array element: {}",
err
)))
}
}
}
values.serialize(serializer)
} else {
use serde::ser::SerializeMap;
let mut pairs = vec![];
for pair in table.clone().pairs::<Value, Value>() {
match pair {
Ok(pair) => pairs.push((ValueWrapper(pair.0), ValueWrapper(pair.1))),
Err(err) => {
return Err(S::Error::custom(format!(
"while retrieving map element: {}",
err
)))
}
}
}
let mut map = serializer.serialize_map(Some(pairs.len()))?;
for (k, v) in pairs.into_iter() {
map.serialize_entry(&k, &v)?;
}
map.end()
}
}
Value::UserData(_) | Value::LightUserData(_) => Err(S::Error::custom(
"cannot represent userdata in the serde data model",
)),
Value::Thread(_) => Err(S::Error::custom(
"cannot represent thread in the serde data model",
)),
Value::Function(_) => Err(S::Error::custom(
"cannot represent lua function in the serde data model",
)),
Value::Error(e) => Err(S::Error::custom(format!(
"cannot represent lua error {} in the serde data model",
e
))),
}
}
}