Kernel/Net: Rework ephemeral port allocation

Currently, ephemeral port allocation is handled by the
allocate_local_port_if_needed() and protocol_allocate_local_port()
methods. Actually binding the socket to an address (which means
inserting the socket/address pair into a global map) is performed either
in protocol_allocate_local_port() (for ephemeral ports) or in
protocol_listen() (for non-ephemeral ports); the latter will fail with
EADDRINUSE if the address is already used by an existing pair present in
the map.

There used to be a bug where for listen() without an explicit bind(),
the port allocation would conflict with itself: first an ephemeral port
would get allocated and inserted into the map, and then
protocol_listen() would check again for the port being free, find the
just-created map entry, and error out. This was fixed in commit
01e5af487f by passing an additional flag
did_allocate_port into protocol_listen() which specifies whether the
port was just allocated, and skipping the check in protocol_listen() if
the flag is set.

However, this only helps if the socket is bound to an ephemeral port
inside of this very listen() call. But calling bind(sin_port = 0) from
userspace should succeed and bind to an allocated ephemeral port, in the
same way as using an unbound socket for connect() does. The port number
can then be retrieved from userspace by calling getsockname(), and it
should be possible to either connect() or listen() on this socket,
keeping the allocated port number. Also, calling bind() when already
bound (either explicitly or implicitly) should always result in EINVAL.

To untangle this, introduce an explicit m_bound state in IPv4Socket,
just like LocalSocket already has. Once a socket is bound, further
attempts to bind it fail. Some operations cause the socket to implicitly
get bound to an (ephemeral) address; this is implemented by the new
ensure_bound() method. The protocol_allocate_local_port() method is
gone; it is now up to a protocol to assign a port to the socket inside
protocol_bind() if it finds that the socket has local_port() == 0.

protocol_bind() is now called in more cases, such as inside listen() if
the socket wasn't bound before that.
This commit is contained in:
Sergey Bugaev 2023-07-23 15:43:45 +03:00 committed by Andrew Kaster
parent cd4298ae04
commit 95bcffd713
Notes: sideshowbarker 2024-07-17 18:38:54 +09:00
6 changed files with 96 additions and 96 deletions

View File

@ -95,8 +95,23 @@ void IPv4Socket::get_peer_address(sockaddr* address, socklen_t* address_size)
*address_size = sizeof(sockaddr_in);
}
ErrorOr<void> IPv4Socket::ensure_bound()
{
dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket::ensure_bound() m_bound {}", m_bound);
if (m_bound)
return {};
auto result = protocol_bind();
if (!result.is_error())
m_bound = true;
return result;
}
ErrorOr<void> IPv4Socket::bind(Credentials const& credentials, Userspace<sockaddr const*> user_address, socklen_t address_size)
{
if (m_bound)
return set_so_error(EINVAL);
VERIFY(setup_state() == SetupState::Unstarted);
if (address_size != sizeof(sockaddr_in))
return set_so_error(EINVAL);
@ -120,23 +135,20 @@ ErrorOr<void> IPv4Socket::bind(Credentials const& credentials, Userspace<sockadd
dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket::bind {}({}) to {}:{}", class_name(), this, m_local_address, m_local_port);
return protocol_bind();
return ensure_bound();
}
ErrorOr<void> IPv4Socket::listen(size_t backlog)
{
MutexLocker locker(mutex());
auto result = allocate_local_port_if_needed();
if (result.error_or_port.is_error() && result.error_or_port.error().code() != ENOPROTOOPT)
return result.error_or_port.release_error();
TRY(ensure_bound());
set_backlog(backlog);
set_role(Role::Listener);
evaluate_block_conditions();
dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket({}) listening with backlog={}", this, backlog);
return protocol_listen(result.did_allocate);
return protocol_listen();
}
ErrorOr<void> IPv4Socket::connect(Credentials const&, OpenFileDescription& description, Userspace<sockaddr const*> address, socklen_t address_size)
@ -176,18 +188,6 @@ bool IPv4Socket::can_write(OpenFileDescription const&, u64) const
return true;
}
PortAllocationResult IPv4Socket::allocate_local_port_if_needed()
{
MutexLocker locker(mutex());
if (m_local_port)
return { m_local_port, false };
auto port_or_error = protocol_allocate_local_port();
if (port_or_error.is_error())
return { port_or_error.release_error(), false };
m_local_port = port_or_error.release_value();
return { m_local_port, true };
}
ErrorOr<size_t> IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer const& data, size_t data_length, [[maybe_unused]] int flags, Userspace<sockaddr const*> addr, socklen_t addr_length)
{
MutexLocker locker(mutex());
@ -220,8 +220,7 @@ ErrorOr<size_t> IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer cons
if (m_local_address.to_u32() == 0)
m_local_address = routing_decision.adapter->ipv4_address();
if (auto result = allocate_local_port_if_needed(); result.error_or_port.is_error() && result.error_or_port.error().code() != ENOPROTOOPT)
return result.error_or_port.release_error();
TRY(ensure_bound());
dbgln_if(IPV4_SOCKET_DEBUG, "sendto: destination={}:{}", m_peer_address, m_peer_port);

View File

@ -21,11 +21,6 @@ class NetworkAdapter;
class TCPPacket;
class TCPSocket;
struct PortAllocationResult {
ErrorOr<u16> error_or_port;
bool did_allocate;
};
class IPv4Socket : public Socket {
public:
static ErrorOr<NonnullRefPtr<Socket>> create(int type, int protocol);
@ -76,14 +71,14 @@ protected:
IPv4Socket(int type, int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer, OwnPtr<KBuffer> optional_scratch_buffer);
virtual StringView class_name() const override { return "IPv4Socket"sv; }
PortAllocationResult allocate_local_port_if_needed();
void set_bound(bool bound) { m_bound = bound; }
ErrorOr<void> ensure_bound();
virtual ErrorOr<void> protocol_bind() { return {}; }
virtual ErrorOr<void> protocol_listen([[maybe_unused]] bool did_allocate_port) { return {}; }
virtual ErrorOr<void> protocol_listen() { return {}; }
virtual ErrorOr<size_t> protocol_receive(ReadonlyBytes /* raw_ipv4_packet */, UserOrKernelBuffer&, size_t, int) { return ENOTIMPL; }
virtual ErrorOr<size_t> protocol_send(UserOrKernelBuffer const&, size_t) { return ENOTIMPL; }
virtual ErrorOr<void> protocol_connect(OpenFileDescription&) { return {}; }
virtual ErrorOr<u16> protocol_allocate_local_port() { return ENOPROTOOPT; }
virtual ErrorOr<size_t> protocol_size(ReadonlyBytes /* raw_ipv4_packet */) { return ENOTIMPL; }
virtual bool protocol_is_disconnected() const { return false; }
@ -108,6 +103,7 @@ private:
Vector<IPv4Address> m_multicast_memberships;
bool m_multicast_loop { true };
bool m_bound { false };
struct ReceivedPacket {
IPv4Address peer_address;

View File

@ -137,6 +137,7 @@ ErrorOr<NonnullRefPtr<TCPSocket>> TCPSocket::try_create_client(IPv4Address const
client->set_local_port(new_local_port);
client->set_peer_address(new_peer_address);
client->set_peer_port(new_peer_port);
client->set_bound(true);
client->set_direction(Direction::Incoming);
client->set_originator(*this);
@ -414,19 +415,46 @@ NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(IPv4Address const& source, I
ErrorOr<void> TCPSocket::protocol_bind()
{
return m_adapter.with([this](auto& adapter) -> ErrorOr<void> {
dbgln_if(TCP_SOCKET_DEBUG, "TCPSocket::protocol_bind(), local_port() is {}", local_port());
// Check that we do have the address we're trying to bind to.
TRY(m_adapter.with([this](auto& adapter) -> ErrorOr<void> {
if (has_specific_local_address() && !adapter) {
adapter = NetworkingManagement::the().from_ipv4_address(local_address());
if (!adapter)
return set_so_error(EADDRNOTAVAIL);
}
return {};
});
}
}));
ErrorOr<void> TCPSocket::protocol_listen(bool did_allocate_port)
{
if (!did_allocate_port) {
if (local_port() == 0) {
// Allocate an unused ephemeral port.
constexpr u16 first_ephemeral_port = 32768;
constexpr u16 last_ephemeral_port = 60999;
constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size;
return sockets_by_tuple().with_exclusive([&](auto& table) -> ErrorOr<void> {
u16 port = first_scan_port;
while (true) {
IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port());
auto it = table.find(proposed_tuple);
if (it == table.end()) {
set_local_port(port);
table.set(proposed_tuple, this);
dbgln_if(TCP_SOCKET_DEBUG, "...allocated port {}, tuple {}", port, proposed_tuple.to_string());
return {};
}
++port;
if (port > last_ephemeral_port)
port = first_ephemeral_port;
if (port == first_scan_port)
break;
}
return set_so_error(EADDRINUSE);
});
} else {
// Verify that the user-supplied port is not already used by someone else.
bool ok = sockets_by_tuple().with_exclusive([&](auto& table) -> bool {
if (table.contains(tuple()))
return false;
@ -435,8 +463,12 @@ ErrorOr<void> TCPSocket::protocol_listen(bool did_allocate_port)
});
if (!ok)
return set_so_error(EADDRINUSE);
return {};
}
}
ErrorOr<void> TCPSocket::protocol_listen()
{
set_direction(Direction::Passive);
set_state(State::Listen);
set_setup_state(SetupState::Completed);
@ -453,8 +485,7 @@ ErrorOr<void> TCPSocket::protocol_connect(OpenFileDescription& description)
if (!has_specific_local_address())
set_local_address(routing_decision.adapter->ipv4_address());
if (auto result = allocate_local_port_if_needed(); result.error_or_port.is_error())
return result.error_or_port.release_error();
TRY(ensure_bound());
m_sequence_number = get_good_random<u32>();
m_ack_number = 0;
@ -487,33 +518,6 @@ ErrorOr<void> TCPSocket::protocol_connect(OpenFileDescription& description)
return set_so_error(EINPROGRESS);
}
ErrorOr<u16> TCPSocket::protocol_allocate_local_port()
{
constexpr u16 first_ephemeral_port = 32768;
constexpr u16 last_ephemeral_port = 60999;
constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size;
return sockets_by_tuple().with_exclusive([&](auto& table) -> ErrorOr<u16> {
for (u16 port = first_scan_port;;) {
IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port());
auto it = table.find(proposed_tuple);
if (it == table.end()) {
set_local_port(port);
table.set(proposed_tuple, this);
return port;
}
++port;
if (port > last_ephemeral_port)
port = first_ephemeral_port;
if (port == first_scan_port)
break;
}
return set_so_error(EADDRINUSE);
});
}
bool TCPSocket::protocol_is_disconnected() const
{
switch (m_state) {

View File

@ -176,11 +176,10 @@ private:
virtual ErrorOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
virtual ErrorOr<size_t> protocol_send(UserOrKernelBuffer const&, size_t) override;
virtual ErrorOr<void> protocol_connect(OpenFileDescription&) override;
virtual ErrorOr<u16> protocol_allocate_local_port() override;
virtual ErrorOr<size_t> protocol_size(ReadonlyBytes raw_ipv4_packet) override;
virtual bool protocol_is_disconnected() const override;
virtual ErrorOr<void> protocol_bind() override;
virtual ErrorOr<void> protocol_listen(bool did_allocate_port) override;
virtual ErrorOr<void> protocol_listen() override;
void enqueue_for_retransmit();
void dequeue_for_retransmit();

View File

@ -108,44 +108,47 @@ ErrorOr<size_t> UDPSocket::protocol_send(UserOrKernelBuffer const& data, size_t
ErrorOr<void> UDPSocket::protocol_connect(OpenFileDescription&)
{
TRY(ensure_bound());
set_role(Role::Connected);
set_connected(true);
return {};
}
ErrorOr<u16> UDPSocket::protocol_allocate_local_port()
{
constexpr u16 first_ephemeral_port = 32768;
constexpr u16 last_ephemeral_port = 60999;
constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size;
return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr<u16> {
for (u16 port = first_scan_port;;) {
auto it = table.find(port);
if (it == table.end()) {
set_local_port(port);
table.set(port, this);
return port;
}
++port;
if (port > last_ephemeral_port)
port = first_ephemeral_port;
if (port == first_scan_port)
break;
}
return set_so_error(EADDRINUSE);
});
}
ErrorOr<void> UDPSocket::protocol_bind()
{
return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr<void> {
if (table.contains(local_port()))
if (local_port() == 0) {
// Allocate an unused ephemeral port.
constexpr u16 first_ephemeral_port = 32768;
constexpr u16 last_ephemeral_port = 60999;
constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size;
return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr<void> {
u16 port = first_scan_port;
while (true) {
auto it = table.find(port);
if (it == table.end()) {
set_local_port(port);
table.set(port, this);
return {};
}
++port;
if (port > last_ephemeral_port)
port = first_ephemeral_port;
if (port == first_scan_port)
break;
}
return set_so_error(EADDRINUSE);
table.set(local_port(), this);
return {};
});
});
} else {
// Verify that the user-supplied port is not already used by someone else.
return sockets_by_port().with_exclusive([&](auto& table) -> ErrorOr<void> {
if (table.contains(local_port()))
return set_so_error(EADDRINUSE);
table.set(local_port(), this);
return {};
});
}
}
}

View File

@ -30,7 +30,6 @@ private:
virtual ErrorOr<size_t> protocol_send(UserOrKernelBuffer const&, size_t) override;
virtual ErrorOr<size_t> protocol_size(ReadonlyBytes raw_ipv4_packet) override;
virtual ErrorOr<void> protocol_connect(OpenFileDescription&) override;
virtual ErrorOr<u16> protocol_allocate_local_port() override;
virtual ErrorOr<void> protocol_bind() override;
};