2022-06-24 19:50:51 -07:00

669 lines
22 KiB

using System.Collections;
using System.Collections.Immutable;
using System.Globalization;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using System.Xml;
using System.Xml.Serialization;
namespace NAPS2.Serialization;
/// <summary>
/// Non-generic base of the XML serializer. Declares the static serializer/type caches, the recursive
/// <c>SerializeInternal</c>/<c>DeserializeInternal</c> routines and the built-in per-type serializers.
/// </summary>
public abstract class XmlSerializer
/// <summary>
/// XML Schema instance namespace. It qualifies the <c>nil</c> and <c>type</c> attributes that
/// <c>SerializeInternal</c> writes and <c>DeserializeInternal</c> reads, and is the URI bound to the
/// <c>xsi</c> prefix on the document root.
/// </summary>
// An empty namespace URI cannot legally be bound to a prefix (Namespaces in XML 1.0), and it would
// turn Xsi + "nil" / Xsi + "type" into unqualified attributes.
protected static readonly XNamespace Xsi = "http://www.w3.org/2001/XMLSchema-instance";
// Explicitly registered serializers keyed by the type they handle (written by RegisterCustomSerializer,
// read by GetCustomSerializer).
protected static readonly Dictionary<Type, CustomXmlSerializer> CustomSerializerCache = new();
// Registered known-type providers per type (appended to by RegisterCustomTypes, read by GetKnownTypes).
protected static readonly Dictionary<Type, List<CustomXmlTypes>> CustomTypesCache = new();
// Generic type definitions whose element name uses the "Array" base name in GetElementNameForType.
// NOTE(review): the initializer contents are not visible in this view — confirm against the full file.
protected static readonly List<Type> ArrayLikeTypes = new()
// Per-type metadata cache, pre-seeded with the built-in serializers below. Open generic definitions
// (List<>, Nullable<>, ...) are looked up by GetTypeInfo for their constructed types.
// This dictionary instance is also the lock object used by every method that touches the caches.
protected static readonly Dictionary<Type, XmlTypeInfo> TypeInfoCache = new()
{ typeof(char), new XmlTypeInfo { CustomSerializer = new CharSerializer() } },
{ typeof(string), new XmlTypeInfo { CustomSerializer = new StringSerializer() } },
{ typeof(bool), new XmlTypeInfo { CustomSerializer = new BooleanSerializer() } },
{ typeof(byte), new XmlTypeInfo { CustomSerializer = new ByteSerializer() } },
{ typeof(sbyte), new XmlTypeInfo { CustomSerializer = new SByteSerializer() } },
{ typeof(short), new XmlTypeInfo { CustomSerializer = new Int16Serializer() } },
{ typeof(ushort), new XmlTypeInfo { CustomSerializer = new UInt16Serializer() } },
{ typeof(int), new XmlTypeInfo { CustomSerializer = new Int32Serializer() } },
{ typeof(uint), new XmlTypeInfo { CustomSerializer = new UInt32Serializer() } },
{ typeof(long), new XmlTypeInfo { CustomSerializer = new Int64Serializer() } },
{ typeof(ulong), new XmlTypeInfo { CustomSerializer = new UInt64Serializer() } },
{ typeof(float), new XmlTypeInfo { CustomSerializer = new SingleSerializer() } },
{ typeof(double), new XmlTypeInfo { CustomSerializer = new DoubleSerializer() } },
{ typeof(decimal), new XmlTypeInfo { CustomSerializer = new DecimalSerializer() } },
{ typeof(IntPtr), new XmlTypeInfo { CustomSerializer = new IntPtrSerializer() } },
{ typeof(UIntPtr), new XmlTypeInfo { CustomSerializer = new UIntPtrSerializer() } },
{ typeof(List<>), new XmlTypeInfo { CustomSerializer = new CollectionSerializer() } },
{ typeof(HashSet<>), new XmlTypeInfo { CustomSerializer = new CollectionSerializer() } },
{ typeof(ImmutableList<>), new XmlTypeInfo { CustomSerializer = new ImmutableListSerializer() } },
{ typeof(ImmutableHashSet<>), new XmlTypeInfo { CustomSerializer = new ImmutableHashSetSerializer() } },
{ typeof(DateTime), new XmlTypeInfo { CustomSerializer = new DateTimeSerializer() } },
{ typeof(Nullable<>), new XmlTypeInfo { CustomSerializer = new NullableSerializer() } },
// Transform has no custom serializer; its known types are the types assignable to Transform.
// NOTE(review): the assembly/type enumeration between "Assembly" and ".Where" is not visible in this
// view — presumably it scans the assembly that declares Transform; confirm against the full file.
typeof(Transform), new XmlTypeInfo
KnownTypes = new HashSet<Type>(Assembly
.Where(t => typeof(Transform).IsAssignableFrom(t)))
/// <summary>
/// Registers <paramref name="customSerializer"/> for <typeparamref name="T"/> by forwarding to the
/// non-generic overload with <c>typeof(T)</c>.
/// </summary>
/// <typeparam name="T">Type the serializer handles.</typeparam>
/// <param name="customSerializer">Serializer to register.</param>
public static void RegisterCustomSerializer<T>(CustomXmlSerializer<T> customSerializer) =>
    RegisterCustomSerializer(typeof(T), customSerializer);
/// <summary>
/// Stores <paramref name="customSerializer"/> as the serializer for <paramref name="type"/>,
/// replacing any serializer previously registered for that type.
/// </summary>
/// <param name="type">Type the serializer handles.</param>
/// <param name="customSerializer">Serializer to register.</param>
public static void RegisterCustomSerializer(Type type, CustomXmlSerializer customSerializer)
{
    // TypeInfoCache doubles as the lock object for all of the static caches.
    lock (TypeInfoCache)
    {
        CustomSerializerCache[type] = customSerializer;
    }
}
/// <summary>
/// Registers a known-types provider for <typeparamref name="T"/> by forwarding to the non-generic
/// overload with <c>typeof(T)</c>.
/// </summary>
/// <typeparam name="T">Type the provider applies to.</typeparam>
/// <param name="customTypes">Provider to register.</param>
public static void RegisterCustomTypes<T>(CustomXmlTypes<T> customTypes) =>
    RegisterCustomTypes(typeof(T), customTypes);
/// <summary>
/// Appends <paramref name="customTypes"/> to the list of known-types providers kept for
/// <paramref name="type"/>.
/// </summary>
/// <param name="type">Type the provider applies to.</param>
/// <param name="customTypes">Provider to add.</param>
public static void RegisterCustomTypes(Type type, CustomXmlTypes customTypes)
{
    lock (TypeInfoCache)
    {
        var providers = CustomTypesCache.GetOrSet(type, new List<CustomXmlTypes>());
        providers.Add(customTypes);
    }
}
/// <summary>
/// Serializes <paramref name="obj"/> into <paramref name="element"/> and returns that element.
/// </summary>
/// <param name="obj">Value to serialize; <c>null</c> is written as an <c>xsi:nil="true"</c> attribute.</param>
/// <param name="element">Element that receives the attributes, text and child elements.</param>
/// <param name="type">Declared (static) type of the value.</param>
/// <returns>The same <paramref name="element"/> instance.</returns>
protected static XElement SerializeInternal(object? obj, XElement element, Type type)
{
    if (obj == null)
    {
        element.SetAttributeValue(Xsi + "nil", "true");
        return element;
    }

    // For a Nullable<T> declaration keep the declared type; otherwise use the runtime type.
    var actualType = Nullable.GetUnderlyingType(type) != null ? type : obj.GetType();
    if (actualType != type)
    {
        // FindType resolves xsi:type by comparing against GetElementNameForType, so the same name
        // must be written here or the value cannot be read back for types whose names differ.
        element.SetAttributeValue(Xsi + "type", GetElementNameForType(actualType));
    }

    var typeInfo = GetTypeInfo(actualType);
    if (typeInfo.CustomSerializer != null)
    {
        // A custom serializer writes the whole element; the property walk below must not run too.
        typeInfo.CustomSerializer.SerializeObject(obj, element, type);
        return element;
    }

    foreach (var propInfo in typeInfo.Properties!)
    {
        var child = SerializeInternal(
            propInfo.Property.GetValue(obj),
            new XElement(propInfo.Property.Name),
            propInfo.Property.PropertyType);
        // Attach the serialized property; without this the element would stay empty.
        element.Add(child);
    }
    return element;
}
/// <summary>
/// Rebuilds an object of <paramref name="type"/> (or of the type named by the element's
/// <c>xsi:type</c> attribute) from <paramref name="element"/>.
/// </summary>
/// <param name="element">Element to read.</param>
/// <param name="type">Declared type of the value.</param>
/// <returns>The deserialized value, or <c>null</c> when the element carries <c>xsi:nil="true"</c>.</returns>
/// <exception cref="InvalidOperationException">
/// The <c>xsi:type</c> attribute names a type that <c>FindType</c> cannot resolve.
/// </exception>
protected static object? DeserializeInternal(XElement element, Type type)
{
    var nilAttribute = element.Attribute(Xsi + "nil");
    if (nilAttribute?.Value == "true")
    {
        return null;
    }

    var typeOverride = element.Attribute(Xsi + "type")?.Value;
    var concreteType = type;
    if (typeOverride != null)
    {
        concreteType = FindType(type, typeOverride)
            ?? throw new InvalidOperationException($"Could not find type {typeOverride} with base type {type.FullName}");
    }

    var typeInfo = GetTypeInfo(concreteType);
    var customSerializer = typeInfo.CustomSerializer;
    if (customSerializer != null)
    {
        return customSerializer.DeserializeObject(element, concreteType);
    }

    // The second argument allows a non-public parameterless constructor.
    var instance = Activator.CreateInstance(concreteType, true);
    foreach (var propInfo in typeInfo.Properties)
    {
        // TODO: Detect unmapped elements
        var property = propInfo.Property;
        var child = element.Element(property.Name);
        if (child == null)
        {
            // Properties without a matching child element keep their constructed value.
            continue;
        }
        property.SetValue(instance, DeserializeInternal(child, property.PropertyType));
    }
    return instance;
}
/// <summary>
/// Looks through the known types of <paramref name="baseType"/> for the one whose element name
/// equals <paramref name="actualTypeName"/>.
/// </summary>
/// <param name="baseType">Declared type whose known types are searched.</param>
/// <param name="actualTypeName">Name to match against <c>GetElementNameForType</c>.</param>
/// <returns>The matching type, or <c>null</c> if there are no known types or none matches.</returns>
protected static Type? FindType(Type baseType, string actualTypeName)
{
    lock (TypeInfoCache)
    {
        var candidates = GetTypeInfo(baseType).KnownTypes;
        return candidates?.SingleOrDefault(x => GetElementNameForType(x) == actualTypeName);
    }
}
/// <summary>
/// Computes the XML name used for <paramref name="type"/>.
/// </summary>
/// <remarks>
/// Resolution order: a non-empty <see cref="XmlTypeAttribute.TypeName"/>; <c>ArrayOf&lt;element&gt;</c>
/// for arrays; <c>&lt;base&gt;Of&lt;args&gt;</c> for generic types, where the base is <c>Array</c> for
/// definitions listed in <c>ArrayLikeTypes</c> and otherwise the type name without its arity suffix;
/// finally <see cref="System.Reflection.MemberInfo.Name"/>.
/// </remarks>
/// <param name="type">Type to name.</param>
/// <returns>The element name for the type.</returns>
protected static string GetElementNameForType(Type type)
{
    // TODO: Cache this?
    var xmlNameAttribute = type.GetCustomAttributes(typeof(XmlTypeAttribute)).Cast<XmlTypeAttribute>().FirstOrDefault();
    // An attribute that sets no TypeName (e.g. only a namespace) must not produce an empty name.
    if (!string.IsNullOrEmpty(xmlNameAttribute?.TypeName))
    {
        return xmlNameAttribute!.TypeName;
    }
    if (type.IsArray)
    {
        return "ArrayOf" + GetElementNameForType(type.GetElementType()!);
    }
    if (type.IsGenericType)
    {
        string baseName;
        if (ArrayLikeTypes.Contains(type.GetGenericTypeDefinition()))
        {
            baseName = "Array";
        }
        else
        {
            // Strip the arity suffix ("List`1" -> "List"). A type nested in a generic type is generic
            // but has no backtick in its own name, so fall back to the full name in that case.
            var backtickIndex = type.Name.IndexOf('`');
            baseName = backtickIndex < 0 ? type.Name : type.Name.Substring(0, backtickIndex);
        }
        return baseName + "Of" + string.Join("", type.GetGenericArguments().Select(GetElementNameForType));
    }
    return type.Name;
}
/// <summary>
/// Returns the cached <c>XmlTypeInfo</c> for <paramref name="type"/>, building and caching it on
/// first use. Runs under the <c>TypeInfoCache</c> lock.
/// </summary>
/// <param name="type">Type to describe.</param>
/// <returns>The type's properties, custom serializer and known types, as far as they apply.</returns>
// NOTE(review): some statements of this method appear to be missing from this view (for example the
// props query below has no terminator and the branch nesting is not visible) — confirm against the
// full file before changing the logic.
protected static XmlTypeInfo GetTypeInfo(Type type)
lock (TypeInfoCache)
return TypeInfoCache.GetOrSet(type, () =>
// Candidate properties: public instance properties with both a getter and a setter that are not
// marked with XmlIgnoreAttribute.
var props = type.GetProperties(BindingFlags.Instance | BindingFlags.Public)
.Where(x => x.GetMethod != null && x.SetMethod != null && !x.GetCustomAttributes(typeof(XmlIgnoreAttribute)).Any())
var typeInfo = new XmlTypeInfo
Properties = props.Select(x => new XmlPropertyInfo(x)).ToArray(),
// An explicitly registered serializer is picked up first.
CustomSerializer = GetCustomSerializer(type)
if (typeInfo.CustomSerializer == null)
if (type.IsGenericType)
// Constructed generics reuse the serializer cached for their open definition (e.g. List<>).
var baseType = type.GetGenericTypeDefinition();
typeInfo.CustomSerializer = TypeInfoCache.Get(baseType)?.CustomSerializer;
if (type.IsArray)
typeInfo.CustomSerializer = new ArraySerializer();
else if (type.IsEnum)
typeInfo.CustomSerializer = new EnumSerializer();
if (typeInfo.CustomSerializer == null)
// Known types: the union of the property types' known types (computed recursively) and the types
// named by XmlIncludeAttribute on this type.
typeInfo.KnownTypes = props
.Select(x => GetTypeInfo(x.PropertyType).KnownTypes)
.Where(x => x != null)
.Append(type.GetCustomAttributes(typeof(XmlIncludeAttribute)).Cast<XmlIncludeAttribute>().Select(x => x.Type))
.Aggregate(new HashSet<Type>(), (x, y) => new HashSet<Type>(x.Union(y)));
return typeInfo;
/// <summary>
/// Returns the <see cref="XmlElementAttribute.Order"/> of the first <see cref="XmlElementAttribute"/>
/// on <paramref name="prop"/>, or <see cref="int.MaxValue"/> when the property has no such attribute.
/// </summary>
/// <param name="prop">Property to inspect.</param>
private static int GetPropertyOrder(PropertyInfo prop)
{
    var elementAttribute = prop.GetCustomAttributes<XmlElementAttribute>().FirstOrDefault();
    if (elementAttribute == null)
    {
        return int.MaxValue;
    }
    return elementAttribute.Order;
}
/// <summary>
/// Collects the known types reported for <paramref name="type"/> by every provider registered
/// through <c>RegisterCustomTypes</c>.
/// </summary>
/// <param name="type">Type whose providers are queried.</param>
/// <returns>The providers' known types, or an empty sequence when none are registered.</returns>
protected static IEnumerable<Type> GetKnownTypes(Type type)
{
    lock (TypeInfoCache)
    {
        var providers = CustomTypesCache.Get(type);
        if (providers == null)
        {
            return Enumerable.Empty<Type>();
        }
        return providers.SelectMany(x => x.GetKnownTypes(type));
    }
}
/// <summary>
/// Looks up the serializer registered for exactly <paramref name="type"/> in
/// <c>CustomSerializerCache</c>.
/// </summary>
/// <param name="type">Type to look up.</param>
/// <returns>The registered serializer, or <c>null</c> if none was registered.</returns>
protected static CustomXmlSerializer? GetCustomSerializer(Type type)
{
    lock (TypeInfoCache)
    {
        var registered = CustomSerializerCache.Get(type);
        return registered;
    }
}
/// <summary>
/// Serialization metadata cached per type in <c>TypeInfoCache</c>.
/// </summary>
protected class XmlTypeInfo
{
    /// <summary>Properties walked by the default serialization path; unset for pre-seeded entries.</summary>
    public XmlPropertyInfo[]? Properties { get; set; }

    /// <summary>Serializer that handles the type as a whole, when one applies.</summary>
    public CustomXmlSerializer? CustomSerializer { get; set; }

    /// <summary>Concrete types an <c>xsi:type</c> attribute may resolve to for this declared type.</summary>
    public HashSet<Type>? KnownTypes { get; set; }
}
/// <summary>Wraps the reflected property that is read during serialization and set during deserialization.</summary>
protected record XmlPropertyInfo(PropertyInfo Property);
/// <summary>
/// Stores a <see cref="char"/> as the element's text.
/// </summary>
protected class CharSerializer : CustomXmlSerializer<char>
{
    /// <summary>Writes the character as a one-character string.</summary>
    protected override void Serialize(char obj, XElement element) => element.Value = obj.ToString();

    /// <summary>Reads the element text back via <see cref="char.Parse(string)"/>.</summary>
    protected override char Deserialize(XElement element) => char.Parse(element.Value);
}
/// <summary>
/// Stores a <see cref="string"/> verbatim as the element's text.
/// </summary>
protected class StringSerializer : CustomXmlSerializer<string>
{
    /// <summary>Copies the string into the element's value.</summary>
    protected override void Serialize(string obj, XElement element) => element.Value = obj;

    /// <summary>Returns the element's value unchanged.</summary>
    protected override string Deserialize(XElement element) => element.Value;
}
/// <summary>
/// Stores a <see cref="bool"/> as the lowercase text <c>true</c> or <c>false</c>.
/// </summary>
protected class BooleanSerializer : CustomXmlSerializer<bool>
{
    /// <summary>Writes the lowercase literal for the value.</summary>
    protected override void Serialize(bool obj, XElement element)
    {
        if (obj)
        {
            element.Value = "true";
        }
        else
        {
            element.Value = "false";
        }
    }

    /// <summary>Reads the element text back via <see cref="bool.Parse(string)"/>.</summary>
    protected override bool Deserialize(XElement element) => bool.Parse(element.Value);
}
/// <summary>
/// Stores a <see cref="byte"/> as culture-invariant decimal text.
/// </summary>
protected class ByteSerializer : CustomXmlSerializer<byte>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text.</summary>
    protected override void Serialize(byte obj, XElement element)
    {
        // Invariant formatting keeps the file independent of the machine's culture.
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number is outside the range of <see cref="byte"/>.</exception>
    protected override byte Deserialize(XElement element)
    {
        return byte.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores an <see cref="sbyte"/> as culture-invariant decimal text.
/// </summary>
protected class SByteSerializer : CustomXmlSerializer<sbyte>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text.</summary>
    protected override void Serialize(sbyte obj, XElement element)
    {
        // Invariant formatting: the negative sign must not depend on the machine's culture.
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number is outside the range of <see cref="sbyte"/>.</exception>
    protected override sbyte Deserialize(XElement element)
    {
        return sbyte.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores a <see cref="short"/> as culture-invariant decimal text.
/// </summary>
protected class Int16Serializer : CustomXmlSerializer<short>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text.</summary>
    protected override void Serialize(short obj, XElement element)
    {
        // Invariant formatting: the negative sign must not depend on the machine's culture.
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number is outside the range of <see cref="short"/>.</exception>
    protected override short Deserialize(XElement element)
    {
        return short.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores a <see cref="ushort"/> as culture-invariant decimal text.
/// </summary>
protected class UInt16Serializer : CustomXmlSerializer<ushort>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text.</summary>
    protected override void Serialize(ushort obj, XElement element)
    {
        // Invariant formatting keeps the file independent of the machine's culture.
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number is outside the range of <see cref="ushort"/>.</exception>
    protected override ushort Deserialize(XElement element)
    {
        return ushort.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores an <see cref="int"/> as culture-invariant decimal text.
/// </summary>
protected class Int32Serializer : CustomXmlSerializer<int>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text.</summary>
    protected override void Serialize(int obj, XElement element)
    {
        // Invariant formatting: the negative sign must not depend on the machine's culture.
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number is outside the range of <see cref="int"/>.</exception>
    protected override int Deserialize(XElement element)
    {
        return int.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores a <see cref="uint"/> as culture-invariant decimal text.
/// </summary>
protected class UInt32Serializer : CustomXmlSerializer<uint>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text.</summary>
    protected override void Serialize(uint obj, XElement element)
    {
        // Invariant formatting keeps the file independent of the machine's culture.
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number is outside the range of <see cref="uint"/>.</exception>
    protected override uint Deserialize(XElement element)
    {
        return uint.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores a <see cref="long"/> as culture-invariant decimal text.
/// </summary>
protected class Int64Serializer : CustomXmlSerializer<long>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text.</summary>
    protected override void Serialize(long obj, XElement element)
    {
        // Invariant formatting: the negative sign must not depend on the machine's culture.
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number is outside the range of <see cref="long"/>.</exception>
    protected override long Deserialize(XElement element)
    {
        return long.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores a <see cref="ulong"/> as culture-invariant decimal text.
/// </summary>
protected class UInt64Serializer : CustomXmlSerializer<ulong>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text.</summary>
    protected override void Serialize(ulong obj, XElement element)
    {
        // Invariant formatting keeps the file independent of the machine's culture.
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number is outside the range of <see cref="ulong"/>.</exception>
    protected override ulong Deserialize(XElement element)
    {
        return ulong.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores a <see cref="float"/> as culture-invariant text.
/// </summary>
protected class SingleSerializer : CustomXmlSerializer<float>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text using the invariant culture.</summary>
    protected override void Serialize(float obj, XElement element)
    {
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text using the invariant culture.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    protected override float Deserialize(XElement element)
    {
        // Must mirror Serialize: parsing with the current culture misreads "1.5" wherever the
        // decimal separator is a comma.
        return float.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores a <see cref="double"/> as culture-invariant text.
/// </summary>
protected class DoubleSerializer : CustomXmlSerializer<double>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text using the invariant culture.</summary>
    protected override void Serialize(double obj, XElement element)
    {
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text using the invariant culture.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    protected override double Deserialize(XElement element)
    {
        // Must mirror Serialize: parsing with the current culture misreads "1.5" wherever the
        // decimal separator is a comma.
        return double.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores a <see cref="decimal"/> as culture-invariant text.
/// </summary>
protected class DecimalSerializer : CustomXmlSerializer<decimal>
{
    /// <summary>Writes <paramref name="obj"/> as the element's text using the invariant culture.</summary>
    protected override void Serialize(decimal obj, XElement element)
    {
        element.Value = obj.ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text using the invariant culture.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number is outside the range of <see cref="decimal"/>.</exception>
    protected override decimal Deserialize(XElement element)
    {
        // Must mirror Serialize: parsing with the current culture misreads "1.5" wherever the
        // decimal separator is a comma.
        return decimal.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores an <see cref="IntPtr"/> as the culture-invariant decimal text of its 64-bit value.
/// </summary>
protected class IntPtrSerializer : CustomXmlSerializer<IntPtr>
{
    /// <summary>Writes the pointer's numeric value as the element's text.</summary>
    protected override void Serialize(IntPtr obj, XElement element)
    {
        // Format through Int64 so the text is produced with an explicit invariant culture.
        element.Value = obj.ToInt64().ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text as a 64-bit integer and converts it to a pointer.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number does not fit the platform's pointer size.</exception>
    protected override IntPtr Deserialize(XElement element)
    {
        return (IntPtr)long.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Stores a <see cref="UIntPtr"/> as the culture-invariant decimal text of its 64-bit value.
/// </summary>
protected class UIntPtrSerializer : CustomXmlSerializer<UIntPtr>
{
    /// <summary>Writes the pointer's numeric value as the element's text.</summary>
    protected override void Serialize(UIntPtr obj, XElement element)
    {
        // Format through UInt64 so the text is produced with an explicit invariant culture.
        element.Value = obj.ToUInt64().ToString(CultureInfo.InvariantCulture);
    }

    /// <summary>Parses the element's text as an unsigned 64-bit integer and converts it to a pointer.</summary>
    /// <exception cref="FormatException">The text is not a valid number.</exception>
    /// <exception cref="OverflowException">The number does not fit the platform's pointer size.</exception>
    protected override UIntPtr Deserialize(XElement element)
    {
        return (UIntPtr)ulong.Parse(element.Value, CultureInfo.InvariantCulture);
    }
}
/// <summary>
/// Serializes arrays as one child element per item, named after the array's element type.
/// </summary>
protected class ArraySerializer : CustomXmlSerializer
{
    /// <summary>Adds one serialized child to <paramref name="element"/> for every array item.</summary>
    /// <exception cref="ArgumentException"><paramref name="type"/> has no element type.</exception>
    public override void SerializeObject(object obj, XElement element, Type type)
    {
        var elementType = type.GetElementType();
        if (elementType == null)
        {
            throw new ArgumentException("Not an array type");
        }

        var items = (IList)obj;
        foreach (var entry in items)
        {
            var child = new XElement(elementType.Name);
            element.Add(SerializeInternal(entry, child, elementType));
        }
    }

    /// <summary>Creates an array with one slot per child element and fills it in document order.</summary>
    /// <exception cref="ArgumentException"><paramref name="type"/> has no element type.</exception>
    public override object DeserializeObject(XElement element, Type type)
    {
        var elementType = type.GetElementType();
        if (elementType == null)
        {
            throw new ArgumentException("Not an array type");
        }

        var children = element.Elements().ToList();
        var result = Array.CreateInstance(elementType, children.Count);
        var index = 0;
        foreach (var child in children)
        {
            result.SetValue(DeserializeInternal(child, elementType), index);
            index++;
        }
        return result;
    }
}
/// <summary>
/// Serializes single-argument generic collections as one child element per item, and rebuilds them
/// through the collection's public <c>Add</c> method.
/// </summary>
protected class CollectionSerializer : CustomXmlSerializer
{
    /// <summary>Adds one serialized child to <paramref name="element"/> for every item in the collection.</summary>
    /// <exception cref="ArgumentException"><paramref name="type"/> does not have exactly one generic argument.</exception>
    public override void SerializeObject(object obj, XElement element, Type type)
    {
        var itemType = GetItemType(type);
        var list = (IEnumerable)obj;
        foreach (var item in list)
        {
            element.Add(SerializeInternal(item, new XElement(itemType.Name), itemType));
        }
    }

    /// <summary>
    /// Returns the single generic argument of <paramref name="type"/>, or <see cref="object"/> when the
    /// type is not generic.
    /// </summary>
    /// <exception cref="ArgumentException">The generic type does not have exactly one argument.</exception>
    protected Type GetItemType(Type type)
    {
        if (type.IsGenericType)
        {
            var typeArgs = type.GetGenericArguments();
            if (typeArgs.Length != 1)
            {
                throw new ArgumentException("Unexpected number of generic arguments for list type");
            }
            return typeArgs[0];
        }
        return typeof(object);
    }

    /// <summary>Creates the collection and adds one deserialized item per child element.</summary>
    /// <exception cref="ArgumentException">The collection has no public instance <c>Add</c> method.</exception>
    public override object DeserializeObject(XElement element, Type type)
    {
        var itemType = GetItemType(type);
        var list = CreateInstance(type, itemType);
        var add = list.GetType().GetMethod("Add", BindingFlags.Public | BindingFlags.Instance) ??
                  throw new ArgumentException("Collection type has no Add method");
        if (add.ReturnType == list.GetType())
        {
            // Handle immutable collections: Add returns the new collection instead of mutating.
            foreach (var itemElement in element.Elements())
            {
                list = add.Invoke(list, new[] { DeserializeInternal(itemElement, itemType) })!;
            }
        }
        else
        {
            // Mutable collections are filled in place. Keeping this branch exclusive avoids
            // deserializing every item a second time for immutable collections.
            foreach (var itemElement in element.Elements())
            {
                add.Invoke(list, new[] { DeserializeInternal(itemElement, itemType) });
            }
        }
        return list;
    }

    /// <summary>
    /// Creates the empty collection to fill. The default uses the parameterless constructor
    /// (non-public allowed); subclasses override this for types without one.
    /// </summary>
    protected virtual object CreateInstance(Type type, Type itemType) => Activator.CreateInstance(type, true);
}
/// <summary>
/// Collection serializer for <c>ImmutableList&lt;T&gt;</c>, which starts from the static <c>Empty</c>
/// instance instead of calling a constructor.
/// </summary>
protected class ImmutableListSerializer : CollectionSerializer
{
    /// <summary>Returns <c>ImmutableList&lt;itemType&gt;.Empty</c>, read through reflection.</summary>
    protected override object CreateInstance(Type type, Type itemType)
    {
        var closedListType = typeof(ImmutableList<>).MakeGenericType(itemType);
        var emptyField = closedListType.GetField("Empty", BindingFlags.Public | BindingFlags.Static);
        if (emptyField == null)
        {
            throw new Exception("No Empty field on ImmutableList");
        }
        return emptyField.GetValue(null);
    }
}
/// <summary>
/// Collection serializer for <c>ImmutableHashSet&lt;T&gt;</c>, which starts from the static
/// <c>Empty</c> instance instead of calling a constructor.
/// </summary>
protected class ImmutableHashSetSerializer : CollectionSerializer
{
    /// <summary>Returns <c>ImmutableHashSet&lt;itemType&gt;.Empty</c>, read through reflection.</summary>
    protected override object CreateInstance(Type type, Type itemType)
    {
        var closedSetType = typeof(ImmutableHashSet<>).MakeGenericType(itemType);
        var emptyField = closedSetType.GetField("Empty", BindingFlags.Public | BindingFlags.Static);
        if (emptyField == null)
        {
            throw new Exception("No Empty field on ImmutableHashSet");
        }
        return emptyField.GetValue(null);
    }
}
/// <summary>
/// Stores a <see cref="DateTime"/> in the round-trip ("O") format.
/// </summary>
protected class DateTimeSerializer : CustomXmlSerializer<DateTime>
{
    /// <summary>Writes <paramref name="obj"/> as round-trip ("O") text using the invariant culture.</summary>
    protected override void Serialize(DateTime obj, XElement element)
    {
        element.Value = obj.ToString("O", CultureInfo.InvariantCulture);
    }

    /// <summary>Parses round-trip text back into a <see cref="DateTime"/>, preserving its kind.</summary>
    /// <exception cref="FormatException">The text is not a valid date and time.</exception>
    protected override DateTime Deserialize(XElement element)
    {
        // RoundtripKind is the counterpart of the "O" format: without it a UTC value ("...Z") is
        // converted to local time and comes back with a different Kind and tick count.
        return DateTime.Parse(element.Value, CultureInfo.InvariantCulture, DateTimeStyles.RoundtripKind);
    }
}
/// <summary>
/// Handles <c>Nullable&lt;T&gt;</c> by delegating to the serialization of the underlying type.
/// </summary>
protected class NullableSerializer : CustomXmlSerializer
{
    /// <summary>Serializes <paramref name="obj"/> as the underlying value type of <paramref name="type"/>.</summary>
    public override void SerializeObject(object? obj, XElement element, Type type)
    {
        var underlyingType = Nullable.GetUnderlyingType(type)!;
        SerializeInternal(obj, element, underlyingType);
    }

    /// <summary>Deserializes the element as the underlying value type of <paramref name="type"/>.</summary>
    public override object? DeserializeObject(XElement element, Type type)
    {
        var underlyingType = Nullable.GetUnderlyingType(type)!;
        return DeserializeInternal(element, underlyingType);
    }
}
/// <summary>
/// Stores enum values as the text returned by <c>ToString()</c> and reads them back with
/// <see cref="Enum.Parse(Type, string)"/>.
/// </summary>
protected class EnumSerializer : CustomXmlSerializer
{
    /// <summary>Writes the value's string form as the element's text.</summary>
    public override void SerializeObject(object? obj, XElement element, Type type)
    {
        var enumValue = obj!;
        element.Value = enumValue.ToString();
    }

    /// <summary>Parses the element's text as a member of the enum <paramref name="type"/>.</summary>
    public override object DeserializeObject(XElement element, Type type) => Enum.Parse(type, element.Value);
}
/// <summary>
/// Typed entry point of the serializer: converts <typeparamref name="T"/> to and from streams,
/// <c>XDocument</c>s and <c>XElement</c>s using the shared logic of the non-generic base class.
/// </summary>
/// <typeparam name="T">Root type being serialized.</typeparam>
public class XmlSerializer<T> : XmlSerializer, ISerializer<T>
/// <summary>
/// Writes <paramref name="obj"/> to <paramref name="stream"/> as indented XML, encoded as UTF-8
/// without a byte order mark.
/// </summary>
/// <param name="stream">Destination stream; it is left open when the method returns.</param>
/// <param name="obj">Object to serialize.</param>
public void Serialize(Stream stream, T obj)
{
    // The StreamWriter is created with leaveOpen: true, so disposing the XML writer flushes the
    // output without closing the caller's stream.
    using var writer = new XmlTextWriter(new StreamWriter(stream, new UTF8Encoding(false), 1024, true)) { Formatting = Formatting.Indented };
    // Without this call the writer is created and disposed but nothing reaches the stream.
    SerializeToXDocument(obj).Save(writer);
}
/// <summary>
/// Serializes <paramref name="obj"/> into a new document whose root element is named after
/// <typeparamref name="T"/> and declares the <c>xsi</c> namespace prefix.
/// </summary>
/// <param name="obj">Object to serialize.</param>
/// <returns>A document containing the serialized root element.</returns>
public XDocument SerializeToXDocument(T obj)
{
    var doc = new XDocument();
    var root = SerializeToXElement(obj, GetElementNameForType(typeof(T)));
    // Declare the prefix once on the root so the nil/type attributes are written as xsi:nil / xsi:type.
    root.SetAttributeValue(XNamespace.Xmlns + "xsi", Xsi.NamespaceName);
    // Attach the root; otherwise an empty document would be returned.
    doc.Add(root);
    return doc;
}
/// <summary>
/// Serializes <paramref name="obj"/> into a newly created element named <paramref name="elementName"/>.
/// </summary>
/// <param name="obj">Object to serialize.</param>
/// <param name="elementName">Name of the element to create.</param>
/// <returns>The populated element.</returns>
public XElement SerializeToXElement(T obj, string elementName)
{
    var target = new XElement(elementName);
    return SerializeInternal(obj, target, typeof(T));
}
/// <summary>
/// Serializes <paramref name="obj"/> into the supplied <paramref name="element"/>.
/// </summary>
/// <param name="obj">Object to serialize.</param>
/// <param name="element">Existing element to populate.</param>
/// <returns>The element returned by <c>SerializeInternal</c>.</returns>
public XElement SerializeToXElement(T obj, XElement element) =>
    SerializeInternal(obj, element, typeof(T));
/// <summary>
/// Loads an XML document from <paramref name="stream"/> and deserializes its root as
/// <typeparamref name="T"/>.
/// </summary>
/// <param name="stream">Stream containing the XML document.</param>
/// <returns>The deserialized object.</returns>
public T? Deserialize(Stream stream)
{
    var document = XDocument.Load(stream);
    return DeserializeFromXDocument(document);
}
/// <summary>
/// Deserializes the root element of <paramref name="doc"/> after checking that its name is the
/// element name expected for <typeparamref name="T"/>.
/// </summary>
/// <param name="doc">Document to read.</param>
/// <returns>The deserialized object.</returns>
/// <exception cref="InvalidOperationException">
/// The document has no root, or the root's name differs from the expected element name.
/// </exception>
public T? DeserializeFromXDocument(XDocument doc)
{
    var root = doc.Root;
    var expectedName = GetElementNameForType(typeof(T));
    if (root?.Name != expectedName)
    {
        throw new InvalidOperationException(
            $"Could not map XML element <{root?.Name}> to {typeof(T).FullName}. Expected <{expectedName}>.");
    }
    return DeserializeFromXElement(root!);
}
/// <summary>
/// Deserializes <paramref name="element"/> as <typeparamref name="T"/>.
/// </summary>
/// <param name="element">Element to read.</param>
/// <returns>The deserialized object, cast to <typeparamref name="T"/>.</returns>
public T? DeserializeFromXElement(XElement element)
{
    var result = DeserializeInternal(element, typeof(T));
    return (T?)result;
}