mirror of
https://github.com/LadybirdBrowser/ladybird.git
synced 2024-12-29 14:14:45 +03:00
ProtocolServer+LibTLS: Pipe certificate requests from LibTLS to clients
This makes gemini.circumlunar.space (and some more gemini pages) work again :^)
This commit is contained in:
parent
9d3ffa096a
commit
97256ad977
Notes:
sideshowbarker
2024-07-19 04:22:42 +09:00
Author: https://github.com/alimpfard Commit: https://github.com/SerenityOS/serenity/commit/97256ad977d Pull-request: https://github.com/SerenityOS/serenity/pull/2956
@ -63,6 +63,10 @@ void GeminiJob::start()
|
||||
m_socket->on_tls_finished = [this] {
|
||||
finish_up();
|
||||
};
|
||||
m_socket->on_tls_certificate_request = [this](auto&) {
|
||||
if (on_certificate_requested)
|
||||
on_certificate_requested(*this);
|
||||
};
|
||||
bool success = ((TLS::TLSv12&)*m_socket).connect(m_request.url().host(), m_request.url().port());
|
||||
if (!success) {
|
||||
deferred_invoke([this](auto&) {
|
||||
@ -89,6 +93,15 @@ void GeminiJob::read_while_data_available(Function<IterationDecision()> read)
|
||||
}
|
||||
}
|
||||
|
||||
void GeminiJob::set_certificate(String certificate, String private_key)
|
||||
{
|
||||
if (!m_socket->add_client_key(ByteBuffer::wrap(certificate.characters(), certificate.length()), ByteBuffer::wrap(private_key.characters(), private_key.length()))) {
|
||||
dbg() << "LibGemini: Failed to set a client certificate";
|
||||
// FIXME: Do something about this failure
|
||||
ASSERT_NOT_REACHED();
|
||||
}
|
||||
}
|
||||
|
||||
void GeminiJob::register_on_ready_to_read(Function<void()> callback)
|
||||
{
|
||||
m_socket->on_tls_ready_to_read = [callback = move(callback)](auto&) {
|
||||
|
@ -48,6 +48,9 @@ public:
|
||||
|
||||
virtual void start() override;
|
||||
virtual void shutdown() override;
|
||||
void set_certificate(String certificate, String key);
|
||||
|
||||
Function<void(GeminiJob&)> on_certificate_requested;
|
||||
|
||||
protected:
|
||||
virtual void register_on_ready_to_read(Function<void()>) override;
|
||||
|
@ -64,6 +64,10 @@ void HttpsJob::start()
|
||||
m_socket->on_tls_finished = [&] {
|
||||
finish_up();
|
||||
};
|
||||
m_socket->on_tls_certificate_request = [this](auto&) {
|
||||
if (on_certificate_requested)
|
||||
on_certificate_requested(*this);
|
||||
};
|
||||
bool success = ((TLS::TLSv12&)*m_socket).connect(m_request.url().host(), m_request.url().port());
|
||||
if (!success) {
|
||||
deferred_invoke([this](auto&) {
|
||||
@ -82,6 +86,15 @@ void HttpsJob::shutdown()
|
||||
m_socket = nullptr;
|
||||
}
|
||||
|
||||
void HttpsJob::set_certificate(String certificate, String private_key)
|
||||
{
|
||||
if (!m_socket->add_client_key(ByteBuffer::wrap(certificate.characters(), certificate.length()), ByteBuffer::wrap(private_key.characters(), private_key.length()))) {
|
||||
dbg() << "LibHTTP: Failed to set a client certificate";
|
||||
// FIXME: Do something about this failure
|
||||
ASSERT_NOT_REACHED();
|
||||
}
|
||||
}
|
||||
|
||||
void HttpsJob::read_while_data_available(Function<IterationDecision()> read)
|
||||
{
|
||||
while (m_socket->can_read()) {
|
||||
|
@ -49,6 +49,9 @@ public:
|
||||
|
||||
virtual void start() override;
|
||||
virtual void shutdown() override;
|
||||
void set_certificate(String certificate, String key);
|
||||
|
||||
Function<void(HttpsJob&)> on_certificate_requested;
|
||||
|
||||
protected:
|
||||
virtual void register_on_ready_to_read(Function<void()>) override;
|
||||
|
@ -68,6 +68,13 @@ bool Client::stop_download(Badge<Download>, Download& download)
|
||||
return send_sync<Messages::ProtocolServer::StopDownload>(download.id())->success();
|
||||
}
|
||||
|
||||
bool Client::set_certificate(Badge<Download>, Download& download, String certificate, String key)
|
||||
{
|
||||
if (!m_downloads.contains(download.id()))
|
||||
return false;
|
||||
return send_sync<Messages::ProtocolServer::SetCertificate>(download.id(), move(certificate), move(key))->success();
|
||||
}
|
||||
|
||||
void Client::handle(const Messages::ProtocolClient::DownloadFinished& message)
|
||||
{
|
||||
RefPtr<Download> download;
|
||||
@ -85,4 +92,13 @@ void Client::handle(const Messages::ProtocolClient::DownloadProgress& message)
|
||||
}
|
||||
}
|
||||
|
||||
OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(const Messages::ProtocolClient::CertificateRequested& message)
|
||||
{
|
||||
if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) {
|
||||
download->did_request_certificates({});
|
||||
}
|
||||
|
||||
return make<Messages::ProtocolClient::CertificateRequestedResponse>();
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -46,14 +46,15 @@ public:
|
||||
bool is_supported_protocol(const String&);
|
||||
RefPtr<Download> start_download(const String& url, const HashMap<String, String>& request_headers = {});
|
||||
|
||||
|
||||
bool stop_download(Badge<Download>, Download&);
|
||||
bool set_certificate(Badge<Download>, Download&, String, String);
|
||||
|
||||
private:
|
||||
Client();
|
||||
|
||||
virtual void handle(const Messages::ProtocolClient::DownloadProgress&) override;
|
||||
virtual void handle(const Messages::ProtocolClient::DownloadFinished&) override;
|
||||
virtual OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> handle(const Messages::ProtocolClient::CertificateRequested&) override;
|
||||
|
||||
HashMap<i32, RefPtr<Download>> m_downloads;
|
||||
};
|
||||
|
@ -67,4 +67,14 @@ void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloa
|
||||
if (on_progress)
|
||||
on_progress(total_size, downloaded_size);
|
||||
}
|
||||
|
||||
void Download::did_request_certificates(Badge<Client>)
|
||||
{
|
||||
if (on_certificate_requested) {
|
||||
auto result = on_certificate_requested();
|
||||
if (!m_client->set_certificate({}, *this, result.certificate, result.key)) {
|
||||
dbg() << "Download: set_certificate failed";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -40,6 +40,11 @@ class Client;
|
||||
|
||||
class Download : public RefCounted<Download> {
|
||||
public:
|
||||
struct CertificateAndKey {
|
||||
String certificate;
|
||||
String key;
|
||||
};
|
||||
|
||||
static NonnullRefPtr<Download> create_from_id(Badge<Client>, Client& client, i32 download_id)
|
||||
{
|
||||
return adopt(*new Download(client, download_id));
|
||||
@ -50,9 +55,11 @@ public:
|
||||
|
||||
Function<void(bool success, const ByteBuffer& payload, RefPtr<SharedBuffer> payload_storage, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> status_code)> on_finish;
|
||||
Function<void(Optional<u32> total_size, u32 downloaded_size)> on_progress;
|
||||
Function<CertificateAndKey()> on_certificate_requested;
|
||||
|
||||
void did_finish(Badge<Client>, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers);
|
||||
void did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size);
|
||||
void did_request_certificates(Badge<Client>);
|
||||
|
||||
private:
|
||||
explicit Download(Client&, i32 download_id);
|
||||
|
@ -27,6 +27,7 @@
|
||||
#include <LibCore/DateTime.h>
|
||||
#include <LibCore/Timer.h>
|
||||
#include <LibCrypto/ASN1/DER.h>
|
||||
#include <LibCrypto/ASN1/PEM.h>
|
||||
#include <LibCrypto/PK/Code/EMSA_PSS.h>
|
||||
#include <LibTLS/TLSv12.h>
|
||||
|
||||
@ -721,4 +722,28 @@ TLSv12::TLSv12(Core::Object* parent, Version version)
|
||||
}
|
||||
}
|
||||
|
||||
bool TLSv12::add_client_key(const ByteBuffer& certificate_pem_buffer, const ByteBuffer& rsa_key) // FIXME: This should not be bound to RSA
|
||||
{
|
||||
if (certificate_pem_buffer.is_empty() || rsa_key.is_empty()) {
|
||||
return true;
|
||||
}
|
||||
auto decoded_certificate = decode_pem(certificate_pem_buffer.span(), 0);
|
||||
if (decoded_certificate.is_empty()) {
|
||||
dbg() << "Certificate not PEM";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto maybe_certificate = parse_asn1(decoded_certificate);
|
||||
if (!maybe_certificate.has_value()) {
|
||||
dbg() << "Invalid certificate";
|
||||
return false;
|
||||
}
|
||||
|
||||
Crypto::PK::RSA rsa(rsa_key);
|
||||
auto certificate = maybe_certificate.value();
|
||||
certificate.private_key = rsa.private_key();
|
||||
|
||||
return add_client_key(certificate);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -206,6 +206,7 @@ struct Certificate {
|
||||
CertificateKeyAlgorithm ec_algorithm;
|
||||
ByteBuffer exponent;
|
||||
Crypto::PK::RSAPublicKey<Crypto::UnsignedBigInteger> public_key;
|
||||
Crypto::PK::RSAPrivateKey<Crypto::UnsignedBigInteger> private_key;
|
||||
String issuer_country;
|
||||
String issuer_state;
|
||||
String issuer_location;
|
||||
@ -318,6 +319,13 @@ public:
|
||||
bool load_certificates(const ByteBuffer& pem_buffer);
|
||||
bool load_private_key(const ByteBuffer& pem_buffer);
|
||||
|
||||
bool add_client_key(const ByteBuffer& certificate_pem_buffer, const ByteBuffer& key_pem_buffer);
|
||||
bool add_client_key(Certificate certificate)
|
||||
{
|
||||
m_context.client_certificates.append(move(certificate));
|
||||
return true;
|
||||
}
|
||||
|
||||
ByteBuffer finish_build();
|
||||
|
||||
const StringView& alpn() const { return m_context.negotiated_alpn; }
|
||||
@ -349,6 +357,7 @@ public:
|
||||
Function<void(AlertDescription)> on_tls_error;
|
||||
Function<void()> on_tls_connected;
|
||||
Function<void()> on_tls_finished;
|
||||
Function<void(TLSv12&)> on_tls_certificate_request;
|
||||
|
||||
private:
|
||||
explicit TLSv12(Core::Object* parent, Version version = Version::V12);
|
||||
|
@ -179,6 +179,9 @@ void ResourceLoader::load(const URL& url, Function<void(const ByteBuffer&, const
|
||||
}
|
||||
success_callback(ByteBuffer::copy(payload.data(), payload.size()), response_headers);
|
||||
};
|
||||
download->on_certificate_requested = []() -> Protocol::Download::CertificateAndKey {
|
||||
return {};
|
||||
};
|
||||
++m_pending_loads;
|
||||
if (on_load_counter_change)
|
||||
on_load_counter_change();
|
||||
|
@ -111,6 +111,11 @@ void ClientConnection::did_progress_download(Badge<Download>, Download& download
|
||||
post_message(Messages::ProtocolClient::DownloadProgress(download.id(), download.total_size(), download.downloaded_size()));
|
||||
}
|
||||
|
||||
void ClientConnection::did_request_certificates(Badge<Download>, Download& download)
|
||||
{
|
||||
post_message(Messages::ProtocolClient::CertificateRequested(download.id()));
|
||||
}
|
||||
|
||||
OwnPtr<Messages::ProtocolServer::GreetResponse> ClientConnection::handle(const Messages::ProtocolServer::Greet&)
|
||||
{
|
||||
return make<Messages::ProtocolServer::GreetResponse>(client_id());
|
||||
@ -122,4 +127,15 @@ OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> ClientConnection::h
|
||||
return make<Messages::ProtocolServer::DisownSharedBufferResponse>();
|
||||
}
|
||||
|
||||
OwnPtr<Messages::ProtocolServer::SetCertificateResponse> ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message)
|
||||
{
|
||||
auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr));
|
||||
bool success = false;
|
||||
if (download) {
|
||||
download->set_certificate(message.certificate(), message.key());
|
||||
success = true;
|
||||
}
|
||||
return make<Messages::ProtocolServer::SetCertificateResponse>(success);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -28,8 +28,8 @@
|
||||
|
||||
#include <AK/HashMap.h>
|
||||
#include <LibIPC/ClientConnection.h>
|
||||
#include <ProtocolServer/ProtocolServerEndpoint.h>
|
||||
#include <ProtocolServer/Forward.h>
|
||||
#include <ProtocolServer/ProtocolServerEndpoint.h>
|
||||
|
||||
namespace ProtocolServer {
|
||||
|
||||
@ -46,6 +46,7 @@ public:
|
||||
|
||||
void did_finish_download(Badge<Download>, Download&, bool success);
|
||||
void did_progress_download(Badge<Download>, Download&);
|
||||
void did_request_certificates(Badge<Download>, Download&);
|
||||
|
||||
private:
|
||||
virtual OwnPtr<Messages::ProtocolServer::GreetResponse> handle(const Messages::ProtocolServer::Greet&) override;
|
||||
@ -53,6 +54,7 @@ private:
|
||||
virtual OwnPtr<Messages::ProtocolServer::StartDownloadResponse> handle(const Messages::ProtocolServer::StartDownload&) override;
|
||||
virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override;
|
||||
virtual OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> handle(const Messages::ProtocolServer::DisownSharedBuffer&) override;
|
||||
virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&);
|
||||
|
||||
HashMap<i32, OwnPtr<Download>> m_downloads;
|
||||
HashMap<i32, RefPtr<AK::SharedBuffer>> m_shared_buffers;
|
||||
|
@ -25,8 +25,8 @@
|
||||
*/
|
||||
|
||||
#include <AK/Badge.h>
|
||||
#include <ProtocolServer/Download.h>
|
||||
#include <ProtocolServer/ClientConnection.h>
|
||||
#include <ProtocolServer/Download.h>
|
||||
|
||||
namespace ProtocolServer {
|
||||
|
||||
@ -59,6 +59,10 @@ void Download::set_response_headers(const HashMap<String, String, CaseInsensitiv
|
||||
m_response_headers = response_headers;
|
||||
}
|
||||
|
||||
void Download::set_certificate(String, String)
|
||||
{
|
||||
}
|
||||
|
||||
void Download::did_finish(bool success)
|
||||
{
|
||||
m_client.did_finish_download({}, *this, success);
|
||||
@ -71,4 +75,9 @@ void Download::did_progress(Optional<u32> total_size, u32 downloaded_size)
|
||||
m_client.did_progress_download({}, *this);
|
||||
}
|
||||
|
||||
void Download::did_request_certificates()
|
||||
{
|
||||
m_client.did_request_certificates({}, *this);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -49,6 +49,7 @@ public:
|
||||
const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers() const { return m_response_headers; }
|
||||
|
||||
void stop();
|
||||
virtual void set_certificate(String, String);
|
||||
|
||||
protected:
|
||||
explicit Download(ClientConnection&);
|
||||
@ -56,6 +57,7 @@ protected:
|
||||
void did_finish(bool success);
|
||||
void did_progress(Optional<u32> total_size, u32 downloaded_size);
|
||||
void set_status_code(u32 status_code) { m_status_code = status_code; }
|
||||
void did_request_certificates();
|
||||
void set_payload(const ByteBuffer&);
|
||||
void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&);
|
||||
|
||||
|
@ -31,7 +31,9 @@ namespace ProtocolServer {
|
||||
class ClientConnection;
|
||||
class Download;
|
||||
class GeminiProtocol;
|
||||
class HttpDownload;
|
||||
class HttpProtocol;
|
||||
class HttpsDownload;
|
||||
class HttpsProtocol;
|
||||
class Protocol;
|
||||
|
||||
|
@ -59,6 +59,14 @@ GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::G
|
||||
m_job->on_progress = [this](Optional<u32> total, u32 current) {
|
||||
did_progress(total, current);
|
||||
};
|
||||
m_job->on_certificate_requested = [this](auto&) {
|
||||
did_request_certificates();
|
||||
};
|
||||
}
|
||||
|
||||
void GeminiDownload::set_certificate(String certificate, String key)
|
||||
{
|
||||
m_job->set_certificate(move(certificate), move(key));
|
||||
}
|
||||
|
||||
GeminiDownload::~GeminiDownload()
|
||||
|
@ -41,6 +41,8 @@ public:
|
||||
private:
|
||||
explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);
|
||||
|
||||
virtual void set_certificate(String certificate, String key) override;
|
||||
|
||||
NonnullRefPtr<Gemini::GeminiJob> m_job;
|
||||
};
|
||||
|
||||
|
@ -51,6 +51,14 @@ HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::Https
|
||||
m_job->on_progress = [this](Optional<u32> total, u32 current) {
|
||||
did_progress(total, current);
|
||||
};
|
||||
m_job->on_certificate_requested = [this](auto&) {
|
||||
did_request_certificates();
|
||||
};
|
||||
}
|
||||
|
||||
void HttpsDownload::set_certificate(String certificate, String key)
|
||||
{
|
||||
m_job->set_certificate(move(certificate), move(key));
|
||||
}
|
||||
|
||||
HttpsDownload::~HttpsDownload()
|
||||
|
@ -41,6 +41,8 @@ public:
|
||||
private:
|
||||
explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);
|
||||
|
||||
virtual void set_certificate(String certificate, String key) override;
|
||||
|
||||
NonnullRefPtr<HTTP::HttpsJob> m_job;
|
||||
};
|
||||
|
||||
|
@ -3,4 +3,7 @@ endpoint ProtocolClient = 13
|
||||
// Download notifications
|
||||
DownloadProgress(i32 download_id, Optional<u32> total_size, u32 downloaded_size) =|
|
||||
DownloadFinished(i32 download_id, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, IPC::Dictionary response_headers) =|
|
||||
|
||||
// Certificate requests
|
||||
CertificateRequested(i32 download_id) => ()
|
||||
}
|
||||
|
@ -12,4 +12,5 @@ endpoint ProtocolServer = 9
|
||||
// Download API
|
||||
StartDownload(URL url, IPC::Dictionary request_headers) => (i32 download_id)
|
||||
StopDownload(i32 download_id) => (bool success)
|
||||
SetCertificate(i32 download_id, String certificate, String key) => (bool success)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user