diff --git a/AK/Variant.h b/AK/Variant.h index 8962b7ee0a0..eb6500bc442 100644 --- a/AK/Variant.h +++ b/AK/Variant.h @@ -406,30 +406,50 @@ public: } template - Variant downcast() && + decltype(auto) downcast() && { - Variant instance { Variant::invalid_index, Detail::VariantConstructTag {} }; - visit([&](auto& value) { - if constexpr (Variant::template can_contain>()) - instance.set(move(value), Detail::VariantNoClearTag {}); - }); - VERIFY(instance.m_index != instance.invalid_index); - return instance; + if constexpr (sizeof...(NewTs) == 1 && (IsSpecializationOf && ...)) { + return move(*this).template downcast_variant(); + } else { + Variant instance { Variant::invalid_index, Detail::VariantConstructTag {} }; + visit([&](auto& value) { + if constexpr (Variant::template can_contain>()) + instance.set(move(value), Detail::VariantNoClearTag {}); + }); + VERIFY(instance.m_index != instance.invalid_index); + return instance; + } } template - Variant downcast() const& + decltype(auto) downcast() const& { - Variant instance { Variant::invalid_index, Detail::VariantConstructTag {} }; - visit([&](auto const& value) { - if constexpr (Variant::template can_contain>()) - instance.set(value, Detail::VariantNoClearTag {}); - }); - VERIFY(instance.m_index != instance.invalid_index); - return instance; + if constexpr (sizeof...(NewTs) == 1 && (IsSpecializationOf && ...)) { + return (*this).template downcast_variant(TypeWrapper {}); + } else { + Variant instance { Variant::invalid_index, Detail::VariantConstructTag {} }; + visit([&](auto const& value) { + if constexpr (Variant::template can_contain>()) + instance.set(value, Detail::VariantNoClearTag {}); + }); + VERIFY(instance.m_index != instance.invalid_index); + return instance; + } } private: + template + Variant downcast_variant(TypeWrapper>) && + { + return move(*this).template downcast(); + } + + template + Variant downcast_variant(TypeWrapper>) const& + { + return (*this).template downcast(); + } + static constexpr auto data_size = Detail::integer_sequence_generate_array(0, IntegerSequence()).max(); static constexpr auto data_alignment = Detail::integer_sequence_generate_array(0, IntegerSequence()).max(); using Helper = Detail::Variant; diff --git a/Tests/AK/TestVariant.cpp b/Tests/AK/TestVariant.cpp index f55d176537e..ea7998a0992 100644 --- a/Tests/AK/TestVariant.cpp +++ b/Tests/AK/TestVariant.cpp @@ -125,6 +125,13 @@ TEST_CASE(verify_cast) EXPECT(one_integer_to_rule_them_all.has()); EXPECT_EQ(fake_integer.get(), 60); EXPECT_EQ(one_integer_to_rule_them_all.get(), 60); + + using SomeFancyType = Variant; + one_integer_to_rule_them_all = fake_integer.downcast(); + EXPECT(fake_integer.has()); + EXPECT(one_integer_to_rule_them_all.has()); + EXPECT_EQ(fake_integer.get(), 60); + EXPECT_EQ(one_integer_to_rule_them_all.get(), 60); } TEST_CASE(moved_from_state)