LibCore: Add a class for thread-safe promises

Since the existing Promise class is designed with deferred tasks on the
main thread only, we need a new class that will ensure we can handle
promises that are resolved/rejected off the main thread.

This new class ensures that the callbacks are only called on the same
thread that the promise is fulfilled from. If the callbacks are not set
before the thread tries to fulfill the promise, it will spin until they
are so that they will run on that thread.
This commit is contained in:
Zaggy1024 2023-07-11 20:48:56 -05:00 committed by Andrew Kaster
parent 8626404ddb
commit fe672989a9
Notes: sideshowbarker 2024-07-16 21:30:46 +09:00
7 changed files with 321 additions and 4 deletions

View File

@ -662,6 +662,7 @@ if (BUILD_LAGOM)
# LibCore
if ((LINUX OR APPLE) AND NOT EMSCRIPTEN)
lagom_test(../../Tests/LibCore/TestLibCoreFileWatcher.cpp)
lagom_test(../../Tests/LibCore/TestLibCorePromise.cpp LIBS LibThreading)
endif()
# RegexLibC test POSIX <regex.h> and contains many Serenity extensions

View File

@ -12,6 +12,7 @@ foreach(source IN LISTS TEST_SOURCES)
serenity_test("${source}" LibCore)
endforeach()
target_link_libraries(TestLibCorePromise PRIVATE LibThreading)
# NOTE: Required because of the LocalServer tests
target_link_libraries(TestLibCoreStream PRIVATE LibThreading)
target_link_libraries(TestLibCoreSharedSingleProducerCircularQueue PRIVATE LibThreading)

View File

@ -6,7 +6,10 @@
#include <LibCore/EventLoop.h>
#include <LibCore/Promise.h>
#include <LibCore/ThreadedPromise.h>
#include <LibTest/TestSuite.h>
#include <LibThreading/Thread.h>
#include <unistd.h>
TEST_CASE(promise_await_async_event)
{
@ -57,3 +60,108 @@ TEST_CASE(promise_chain_handlers)
EXPECT(resolved);
EXPECT(!rejected);
}
TEST_CASE(threaded_promise_instantly_resolved)
{
Core::EventLoop loop;
bool resolved = false;
bool rejected = true;
Optional<pthread_t> thread_id;
auto promise = Core::ThreadedPromise<int>::create();
auto thread = Threading::Thread::construct([&, promise] {
thread_id = pthread_self();
promise->resolve(42);
return 0;
});
thread->start();
promise
->when_resolved([&](int result) {
EXPECT(thread_id.has_value());
EXPECT(pthread_equal(thread_id.value(), pthread_self()));
resolved = true;
rejected = false;
EXPECT_EQ(result, 42);
})
.when_rejected([](Error&&) {
VERIFY_NOT_REACHED();
});
promise->await();
EXPECT(promise->has_completed());
EXPECT(resolved);
EXPECT(!rejected);
MUST(thread->join());
}
TEST_CASE(threaded_promise_resolved_later)
{
Core::EventLoop loop;
bool unblock_thread = false;
bool resolved = false;
bool rejected = true;
Optional<pthread_t> thread_id;
auto promise = Core::ThreadedPromise<int>::create();
auto thread = Threading::Thread::construct([&, promise] {
thread_id = pthread_self();
while (!unblock_thread)
usleep(500);
promise->resolve(42);
return 0;
});
thread->start();
promise
->when_resolved([&]() {
EXPECT(thread_id.has_value());
EXPECT(pthread_equal(thread_id.value(), pthread_self()));
EXPECT(unblock_thread);
resolved = true;
rejected = false;
})
.when_rejected([](Error&&) {
VERIFY_NOT_REACHED();
});
Core::EventLoop::current().deferred_invoke([&]() { unblock_thread = true; });
promise->await();
EXPECT(promise->has_completed());
EXPECT(unblock_thread);
EXPECT(resolved);
EXPECT(!rejected);
MUST(thread->join());
}
TEST_CASE(threaded_promise_synchronously_resolved)
{
Core::EventLoop loop;
bool resolved = false;
bool rejected = true;
auto thread_id = pthread_self();
auto promise = Core::ThreadedPromise<int>::create();
promise->resolve(1337);
promise
->when_resolved([&]() {
EXPECT(pthread_equal(thread_id, pthread_self()));
resolved = true;
rejected = false;
})
.when_rejected([](Error&&) {
VERIFY_NOT_REACHED();
});
promise->await();
EXPECT(promise->has_completed());
EXPECT(resolved);
EXPECT(!rejected);
}

View File

@ -17,12 +17,17 @@
namespace Core {
namespace {
Vector<EventLoop&>& event_loop_stack()
OwnPtr<Vector<EventLoop&>>& event_loop_stack_uninitialized()
{
thread_local OwnPtr<Vector<EventLoop&>> s_event_loop_stack = nullptr;
if (s_event_loop_stack == nullptr)
s_event_loop_stack = make<Vector<EventLoop&>>();
return *s_event_loop_stack;
return s_event_loop_stack;
}
Vector<EventLoop&>& event_loop_stack()
{
auto& the_stack = event_loop_stack_uninitialized();
if (the_stack == nullptr)
the_stack = make<Vector<EventLoop&>>();
return *the_stack;
}
}
@ -41,6 +46,12 @@ EventLoop::~EventLoop()
}
}
bool EventLoop::is_running()
{
auto& stack = event_loop_stack_uninitialized();
return stack != nullptr && !stack->is_empty();
}
EventLoop& EventLoop::current()
{
return event_loop_stack().last();

View File

@ -92,6 +92,7 @@ public:
};
static void notify_forked(ForkEvent);
static bool is_running();
static EventLoop& current();
EventLoopImplementation& impl() { return *m_impl; }

View File

@ -36,6 +36,8 @@ class ProcessStatisticsReader;
class Socket;
template<typename Result, typename TError = AK::Error>
class Promise;
template<typename Result, typename TError = AK::Error>
class ThreadedPromise;
class SocketAddress;
class TCPServer;
class TCPSocket;

View File

@ -0,0 +1,193 @@
/*
* Copyright (c) 2021, Kyle Pereira <hey@xylepereira.me>
* Copyright (c) 2022, kleines Filmröllchen <filmroellchen@serenityos.org>
* Copyright (c) 2021-2023, Ali Mohammad Pur <mpfard@serenityos.org>
* Copyright (c) 2023, Gregory Bertilson <zaggy1024@gmail.com>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#pragma once
#include <AK/AtomicRefCounted.h>
#include <AK/Concepts.h>
#include <LibCore/EventLoop.h>
#include <LibCore/Object.h>
#include <LibThreading/Mutex.h>
namespace Core {
template<typename TResult, typename TError>
class ThreadedPromise
: public AtomicRefCounted<ThreadedPromise<TResult, TError>> {
public:
static NonnullRefPtr<ThreadedPromise<TResult, TError>> create()
{
return adopt_ref(*new ThreadedPromise<TResult, TError>());
}
using ResultType = Conditional<IsSame<TResult, void>, Empty, TResult>;
using ErrorType = TError;
void resolve(ResultType&& result)
{
when_error_handler_is_ready([self = NonnullRefPtr(*this), result = move(result)]() mutable {
if (self->m_resolution_handler) {
auto handler_result = self->m_resolution_handler(forward<ResultType>(result));
if (handler_result.is_error())
self->m_rejection_handler(handler_result.release_error());
self->m_has_completed = true;
}
});
}
void resolve()
requires IsSame<ResultType, Empty>
{
resolve(Empty());
}
void reject(ErrorType&& error)
{
when_error_handler_is_ready([this, error = move(error)]() mutable {
m_rejection_handler(forward<ErrorType>(error));
m_has_completed = true;
});
}
void reject(ErrorType const& error)
requires IsTriviallyCopyable<ErrorType>
{
reject(ErrorType(error));
}
bool has_completed()
{
Threading::MutexLocker locker { m_mutex };
return m_has_completed;
}
void await()
{
while (!has_completed())
Core::EventLoop::current().pump(EventLoop::WaitMode::PollForEvents);
}
// Set the callback to be called when the promise is resolved. A rejection callback
// must also be provided before any callback will be called.
template<CallableAs<ErrorOr<void>, ResultType&&> ResolvedHandler>
ThreadedPromise& when_resolved(ResolvedHandler handler)
{
Threading::MutexLocker locker { m_mutex };
VERIFY(!m_resolution_handler);
m_resolution_handler = move(handler);
return *this;
}
template<CallableAs<void, ResultType&&> ResolvedHandler>
ThreadedPromise& when_resolved(ResolvedHandler handler)
{
return when_resolved([handler = move(handler)](ResultType&& result) -> ErrorOr<void> {
handler(forward<ResultType>(result));
return {};
});
}
template<CallableAs<ErrorOr<void>> ResolvedHandler>
ThreadedPromise& when_resolved(ResolvedHandler handler)
{
return when_resolved([handler = move(handler)](ResultType&&) -> ErrorOr<void> {
return handler();
});
}
template<CallableAs<void> ResolvedHandler>
ThreadedPromise& when_resolved(ResolvedHandler handler)
{
return when_resolved([handler = move(handler)](ResultType&&) -> ErrorOr<void> {
handler();
return {};
});
}
// Set the callback to be called when the promise is rejected. Setting this callback
// will cause the promise fulfillment to be ready to be handled.
template<CallableAs<void, ErrorType&&> RejectedHandler>
ThreadedPromise& when_rejected(RejectedHandler when_rejected = [](ErrorType&) {})
{
Threading::MutexLocker locker { m_mutex };
VERIFY(!m_rejection_handler);
m_rejection_handler = move(when_rejected);
return *this;
}
template<typename T, CallableAs<NonnullRefPtr<ThreadedPromise<T, ErrorType>>, ResultType&&> ChainedResolution>
NonnullRefPtr<ThreadedPromise<T, ErrorType>> chain_promise(ChainedResolution chained_resolution)
{
auto new_promise = ThreadedPromise<T, ErrorType>::create();
when_resolved([=, chained_resolution = move(chained_resolution)](ResultType&& result) mutable -> ErrorOr<void> {
chained_resolution(forward<ResultType>(result))
->when_resolved([=](auto&& new_result) { new_promise->resolve(move(new_result)); })
.when_rejected([=](ErrorType&& error) { new_promise->reject(move(error)); });
return {};
});
when_rejected([=](ErrorType&& error) { new_promise->reject(move(error)); });
return new_promise;
}
template<typename T, CallableAs<ErrorOr<T, ErrorType>, ResultType&&> MappingFunction>
NonnullRefPtr<ThreadedPromise<T, ErrorType>> map(MappingFunction mapping_function)
{
auto new_promise = ThreadedPromise<T, ErrorType>::create();
when_resolved([=, mapping_function = move(mapping_function)](ResultType&& result) -> ErrorOr<void> {
new_promise->resolve(TRY(mapping_function(forward<ResultType>(result))));
return {};
});
when_rejected([=](ErrorType&& error) { new_promise->reject(move(error)); });
return new_promise;
}
private:
template<typename F>
static void deferred_handler_check(NonnullRefPtr<ThreadedPromise> self, F&& function)
{
Threading::MutexLocker locker { self->m_mutex };
if (self->m_rejection_handler) {
function();
return;
}
EventLoop::current().deferred_invoke([self, function = forward<F>(function)]() mutable {
deferred_handler_check(self, move(function));
});
}
template<typename F>
void when_error_handler_is_ready(F function)
{
if (EventLoop::is_running()) {
deferred_handler_check(NonnullRefPtr(*this), move(function));
} else {
// NOTE: Handlers should always be set almost immediately, so we can expect this
// to spin extremely briefly. Therefore, sleeping the thread should not be
// necessary.
while (true) {
Threading::MutexLocker locker { m_mutex };
if (m_rejection_handler)
break;
}
VERIFY(m_rejection_handler);
function();
}
}
ThreadedPromise() = default;
ThreadedPromise(Object* parent)
: Object(parent)
{
}
Function<ErrorOr<void>(ResultType&&)> m_resolution_handler;
Function<void(ErrorType&&)> m_rejection_handler;
Threading::Mutex m_mutex;
bool m_has_completed;
};
}