diff --git a/Kernel/FileSystem.cpp b/Kernel/FileSystem.cpp index 75e6a83e045..85c23f5389c 100644 --- a/Kernel/FileSystem.cpp +++ b/Kernel/FileSystem.cpp @@ -4,6 +4,7 @@ #include #include "FileSystem.h" #include "MemoryManager.h" +#include static dword s_lastFileSystemID; static HashMap* s_fs_map; @@ -152,3 +153,17 @@ void Inode::set_vmo(VMObject& vmo) { m_vmo = vmo.make_weak_ptr(); } + +bool Inode::bind_socket(LocalSocket& socket) +{ + ASSERT(!m_socket); + m_socket = socket; + return true; +} + +bool Inode::unbind_socket() +{ + ASSERT(m_socket); + m_socket = nullptr; + return true; +} diff --git a/Kernel/FileSystem.h b/Kernel/FileSystem.h index a6564b0de9c..1bbd1d27acb 100644 --- a/Kernel/FileSystem.h +++ b/Kernel/FileSystem.h @@ -20,6 +20,7 @@ static const dword mepoch = 476763780; class Inode; class FileDescriptor; +class LocalSocket; class VMObject; class FS : public Retainable { @@ -92,6 +93,11 @@ public: virtual size_t directory_entry_count() const = 0; virtual bool chmod(mode_t, int& error) = 0; + LocalSocket* socket() { return m_socket.ptr(); } + const LocalSocket* socket() const { return m_socket.ptr(); } + bool bind_socket(LocalSocket&); + bool unbind_socket(); + bool is_metadata_dirty() const { return m_metadata_dirty; } virtual int set_atime(time_t); @@ -120,6 +126,7 @@ private: FS& m_fs; unsigned m_index { 0 }; WeakPtr m_vmo; + RetainPtr m_socket; bool m_metadata_dirty { false }; }; diff --git a/Kernel/LocalSocket.cpp b/Kernel/LocalSocket.cpp index abd3201cc89..2fb0f3de1cb 100644 --- a/Kernel/LocalSocket.cpp +++ b/Kernel/LocalSocket.cpp @@ -31,11 +31,11 @@ bool LocalSocket::get_address(sockaddr* address, socklen_t* address_size) bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& error) { + ASSERT(!m_connected); if (address_size != sizeof(sockaddr_un)) { error = -EINVAL; return false; } - if (address->sa_family != AF_LOCAL) { error = -EINVAL; return false; @@ -51,10 +51,46 @@ bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& err if (!m_file) { if (error == -EEXIST) error = -EADDRINUSE; - return error; + return false; } + ASSERT(m_file->inode()); + m_file->inode()->bind_socket(*this); + m_address = local_address; m_bound = true; return true; } + +RetainPtr LocalSocket::connect(const sockaddr* address, socklen_t address_size, int& error) +{ + ASSERT(!m_bound); + if (address_size != sizeof(sockaddr_un)) { + error = -EINVAL; + return nullptr; + } + if (address->sa_family != AF_LOCAL) { + error = -EINVAL; + return nullptr; + } + + const sockaddr_un& local_address = *reinterpret_cast(address); + char safe_address[sizeof(local_address.sun_path) + 1]; + memcpy(safe_address, local_address.sun_path, sizeof(local_address.sun_path)); + + kprintf("%s(%u) LocalSocket{%p} connect(%s)\n", current->name().characters(), current->pid(), safe_address); + + m_file = VFS::the().open(safe_address, error, 0, 0, *current->cwd_inode()); + if (!m_file) { + error = -ECONNREFUSED; + return nullptr; + } + + ASSERT(m_file->inode()); + ASSERT(m_file->inode()->socket()); + + m_peer = m_file->inode()->socket(); + m_address = local_address; + m_connected = true; + return m_peer; +} diff --git a/Kernel/LocalSocket.h b/Kernel/LocalSocket.h index ea5440694fa..7ec70b97418 100644 --- a/Kernel/LocalSocket.h +++ b/Kernel/LocalSocket.h @@ -11,14 +11,18 @@ public: virtual ~LocalSocket() override; virtual bool bind(const sockaddr*, socklen_t, int& error) override; + virtual RetainPtr connect(const sockaddr*, socklen_t, int& error) override; virtual bool get_address(sockaddr*, socklen_t*) override; private: explicit LocalSocket(int type); + virtual bool is_local() const override { return true; } RetainPtr m_file; + RetainPtr m_peer; bool m_bound { false }; + bool m_connected { false }; sockaddr_un m_address; DoubleBuffer m_for_client; diff --git a/Kernel/Process.cpp b/Kernel/Process.cpp index 2ddf71d6bd7..234a92de271 100644 --- a/Kernel/Process.cpp +++ b/Kernel/Process.cpp @@ -2329,7 +2329,28 @@ int Process::sys$accept(int sockfd, sockaddr* address, socklen_t* address_size) return fd; } -int Process::sys$connect(int sockfd, const sockaddr*, socklen_t) +int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_size) { - return -ENOTIMPL; + if (!validate_read(address, address_size)) + return -EFAULT; + if (number_of_open_file_descriptors() >= m_max_open_file_descriptors) + return -EMFILE; + int fd = 0; + for (; fd < (int)m_max_open_file_descriptors; ++fd) { + if (!m_fds[fd]) + break; + } + auto* descriptor = file_descriptor(sockfd); + if (!descriptor) + return -EBADF; + if (!descriptor->is_socket()) + return -ENOTSOCK; + auto& socket = *descriptor->socket(); + int error; + auto server = socket.connect(address, address_size, error); + if (!server) + return error; + auto server_descriptor = FileDescriptor::create(move(server), SocketRole::Connected); + m_fds[fd].set(move(server_descriptor)); + return fd; } diff --git a/Kernel/Socket.h b/Kernel/Socket.h index d5e91cfb525..2afd6a53f27 100644 --- a/Kernel/Socket.h +++ b/Kernel/Socket.h @@ -23,7 +23,9 @@ public: bool listen(int backlog, int& error); virtual bool bind(const sockaddr*, socklen_t, int& error) = 0; + virtual RetainPtr connect(const sockaddr*, socklen_t, int& error) = 0; virtual bool get_address(sockaddr*, socklen_t*) = 0; + virtual bool is_local() const { return false; } protected: Socket(int domain, int type, int protocol); diff --git a/LibC/errno_numbers.h b/LibC/errno_numbers.h index a787dd75aae..b296dddf494 100644 --- a/LibC/errno_numbers.h +++ b/LibC/errno_numbers.h @@ -46,6 +46,7 @@ __ERROR(EADDRINUSE, "Address in use") \ __ERROR(EWHYTHO, "Failed without setting an error code (Bug!)") \ __ERROR(ENOTEMPTY, "Directory not empty") \ + __ERROR(ECONNREFUSED, "Connection refused") \ enum __errno_values {