/* * Copyright (C) 2016 Apple Inc. All rights reserved. * Copyright (c) 2021, Gunnar Beutner * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS'' * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF * THE POSSIBILITY OF SUCH DAMAGE. */ #pragma once #include #include #include #include #include #include #include namespace AK { template class Function; template inline constexpr bool IsFunctionPointer = (IsPointer && IsFunction>); // Not a function pointer, and not an lvalue reference. template inline constexpr bool IsFunctionObject = (!IsFunctionPointer && IsRvalueReference); template class Function { AK_MAKE_NONCOPYABLE(Function); public: using ReturnType = Out; Function() = default; Function(std::nullptr_t) { } ~Function() { clear(false); } template Function(CallableType&& callable) requires((IsFunctionObject && IsCallableWithArguments && !IsSame, Function>)) { init_with_callable(forward(callable)); } template Function(FunctionType f) requires((IsFunctionPointer && IsCallableWithArguments, In...> && !IsSame, Function>)) { init_with_callable(move(f)); } Function(Function&& other) { move_from(move(other)); } // Note: Despite this method being const, a mutable lambda _may_ modify its own captures. Out operator()(In... in) const { auto* wrapper = callable_wrapper(); VERIFY(wrapper); ++m_call_nesting_level; ScopeGuard guard([this] { if (--m_call_nesting_level == 0 && m_deferred_clear) const_cast(this)->clear(false); }); return wrapper->call(forward(in)...); } explicit operator bool() const { return !!callable_wrapper(); } template Function& operator=(CallableType&& callable) requires((IsFunctionObject && IsCallableWithArguments)) { clear(); init_with_callable(forward(callable)); return *this; } template Function& operator=(FunctionType f) requires((IsFunctionPointer && IsCallableWithArguments, In...>)) { clear(); if (f) init_with_callable(move(f)); return *this; } Function& operator=(std::nullptr_t) { clear(); return *this; } Function& operator=(Function&& other) { if (this != &other) { clear(); move_from(move(other)); } return *this; } private: class CallableWrapperBase { public: virtual ~CallableWrapperBase() = default; // Note: This is not const to allow storing mutable lambdas. virtual Out call(In...) = 0; virtual void destroy() = 0; virtual void init_and_swap(u8*, size_t) = 0; }; template class CallableWrapper final : public CallableWrapperBase { AK_MAKE_NONMOVABLE(CallableWrapper); AK_MAKE_NONCOPYABLE(CallableWrapper); public: explicit CallableWrapper(CallableType&& callable) : m_callable(move(callable)) { } Out call(In... in) final override { return m_callable(forward(in)...); } void destroy() final override { delete this; } // NOLINTNEXTLINE(readability-non-const-parameter) False positive; destination is used in a placement new expression void init_and_swap(u8* destination, size_t size) final override { VERIFY(size >= sizeof(CallableWrapper)); new (destination) CallableWrapper { move(m_callable) }; } private: CallableType m_callable; }; enum class FunctionKind { NullPointer, Inline, Outline, }; CallableWrapperBase* callable_wrapper() const { switch (m_kind) { case FunctionKind::NullPointer: return nullptr; case FunctionKind::Inline: return bit_cast(&m_storage); case FunctionKind::Outline: return *bit_cast(&m_storage); default: VERIFY_NOT_REACHED(); } } void clear(bool may_defer = true) { bool called_from_inside_function = m_call_nesting_level > 0; // NOTE: This VERIFY could fail because a Function is destroyed from within itself. VERIFY(may_defer || !called_from_inside_function); if (called_from_inside_function && may_defer) { m_deferred_clear = true; return; } m_deferred_clear = false; auto* wrapper = callable_wrapper(); if (m_kind == FunctionKind::Inline) { VERIFY(wrapper); wrapper->~CallableWrapperBase(); } else if (m_kind == FunctionKind::Outline) { VERIFY(wrapper); wrapper->destroy(); } m_kind = FunctionKind::NullPointer; } template void init_with_callable(Callable&& callable) { VERIFY(m_call_nesting_level == 0); using WrapperType = CallableWrapper; #ifndef KERNEL if constexpr (sizeof(WrapperType) > inline_capacity) { *bit_cast(&m_storage) = new WrapperType(forward(callable)); m_kind = FunctionKind::Outline; } else { #endif static_assert(sizeof(WrapperType) <= inline_capacity); new (m_storage) WrapperType(forward(callable)); m_kind = FunctionKind::Inline; #ifndef KERNEL } #endif } void move_from(Function&& other) { VERIFY(m_call_nesting_level == 0 && other.m_call_nesting_level == 0); auto* other_wrapper = other.callable_wrapper(); switch (other.m_kind) { case FunctionKind::NullPointer: break; case FunctionKind::Inline: other_wrapper->init_and_swap(m_storage, inline_capacity); m_kind = FunctionKind::Inline; break; case FunctionKind::Outline: *bit_cast(&m_storage) = other_wrapper; m_kind = FunctionKind::Outline; break; default: VERIFY_NOT_REACHED(); } other.m_kind = FunctionKind::NullPointer; } FunctionKind m_kind { FunctionKind::NullPointer }; bool m_deferred_clear { false }; mutable Atomic m_call_nesting_level { 0 }; #ifndef KERNEL // Empirically determined to fit most lambdas and functions. static constexpr size_t inline_capacity = 4 * sizeof(void*); #else // FIXME: Try to decrease this. static constexpr size_t inline_capacity = 6 * sizeof(void*); #endif alignas(max(alignof(CallableWrapperBase), alignof(CallableWrapperBase*))) u8 m_storage[inline_capacity]; }; } using AK::Function;