diff --git a/src/lib/arch/Arch.cpp b/src/lib/arch/Arch.cpp index d06a873a..d76aca91 100644 --- a/src/lib/arch/Arch.cpp +++ b/src/lib/arch/Arch.cpp @@ -30,6 +30,11 @@ Arch::Arch() s_instance = this; } +Arch::Arch(Arch* arch) +{ + s_instance = arch; +} + Arch::~Arch() { #if SYSAPI_WIN32 diff --git a/src/lib/arch/Arch.h b/src/lib/arch/Arch.h index ad0719bf..bc624131 100644 --- a/src/lib/arch/Arch.h +++ b/src/lib/arch/Arch.h @@ -99,6 +99,7 @@ class Arch : public ARCH_CONSOLE, public ARCH_TIME { public: Arch(); + Arch(Arch* arch); virtual ~Arch(); //! Call init on other arch classes. diff --git a/src/lib/arch/IArchPlugin.h b/src/lib/arch/IArchPlugin.h index 68938848..478b66cc 100644 --- a/src/lib/arch/IArchPlugin.h +++ b/src/lib/arch/IArchPlugin.h @@ -42,11 +42,17 @@ public: */ virtual void load() = 0; - //! Init plugins + //! Init the common parts /*! - Initializes loaded plugins. + Initializes common parts like log and arch. */ - virtual void init(void* eventTarget, IEventQueue* events) = 0; + virtual void init(void* log, void* arch) = 0; + + //! Init the event part + /*! + Initializes event parts. + */ + virtual void initEvent(void* eventTarget, IEventQueue* events) = 0; //! Check if exists /*! @@ -60,7 +66,7 @@ public: */ virtual void* invoke(const char* plugin, const char* command, - void* args) = 0; + void** args) = 0; //@} diff --git a/src/lib/arch/unix/ArchPluginUnix.cpp b/src/lib/arch/unix/ArchPluginUnix.cpp index 18d21345..905672c8 100644 --- a/src/lib/arch/unix/ArchPluginUnix.cpp +++ b/src/lib/arch/unix/ArchPluginUnix.cpp @@ -28,7 +28,8 @@ #include #include -typedef int (*initFunc)(void (*sendEvent)(const char*, void*), void (*log)(const char*)); +typedef void (*initFunc)(void*, void*); +typedef int (*initEventFunc)(void (*sendEvent)(const char*, void*)); typedef void* (*invokeFunc)(const char*, void*); void* g_eventTarget = NULL; @@ -84,15 +85,25 @@ ArchPluginUnix::load() } void -ArchPluginUnix::init(void* eventTarget, IEventQueue* events) +ArchPluginUnix::init(void* log, void* arch) +{ + PluginTable::iterator it; + for (it = m_pluginTable.begin(); it != m_pluginTable.end(); it++) { + initFunc initPlugin = (initFunc)dlsym(it->second, "init"); + initPlugin(log, arch); + } +} + +void +ArchPluginUnix::initEvent(void* eventTarget, IEventQueue* events) { g_eventTarget = eventTarget; g_events = events; PluginTable::iterator it; for (it = m_pluginTable.begin(); it != m_pluginTable.end(); it++) { - initFunc initPlugin = (initFunc)dlsym(it->second, "init"); - initPlugin(&sendEvent, &log); + initEventFunc initEventPlugin = (initEventFunc)dlsym(it->second, "initEvent"); + initEventPlugin(&sendEvent); } } @@ -108,7 +119,7 @@ void* ArchPluginUnix::invoke( const char* plugin, const char* command, - void* args) + void** args) { PluginTable::iterator it; it = m_pluginTable.find(plugin); diff --git a/src/lib/arch/unix/ArchPluginUnix.h b/src/lib/arch/unix/ArchPluginUnix.h index d9b30922..15e9b1a6 100644 --- a/src/lib/arch/unix/ArchPluginUnix.h +++ b/src/lib/arch/unix/ArchPluginUnix.h @@ -32,11 +32,12 @@ public: // IArchPlugin overrides void load(); - void init(void* eventTarget, IEventQueue* events); + void init(void* log, void* arch); + void initEvent(void* eventTarget, IEventQueue* events); bool exists(const char* name); virtual void* invoke(const char* pluginName, const char* functionName, - void* args); + void** args); private: String getPluginsDir(); diff --git a/src/lib/arch/win32/ArchPluginWindows.cpp b/src/lib/arch/win32/ArchPluginWindows.cpp index c9f8cee8..1f6aecfd 100644 --- a/src/lib/arch/win32/ArchPluginWindows.cpp +++ b/src/lib/arch/win32/ArchPluginWindows.cpp @@ -27,8 +27,9 @@ #include #include -typedef int (*initFunc)(void (*sendEvent)(const char*, void*), void (*log)(const char*)); -typedef void* (*invokeFunc)(const char*, void*); +typedef int (*initFunc)(void*, void*); +typedef int (*initEventFunc)(void (*sendEvent)(const char*, void*)); +typedef void* (*invokeFunc)(const char*, void**); void* g_eventTarget = NULL; IEventQueue* g_events = NULL; @@ -68,7 +69,19 @@ ArchPluginWindows::load() } void -ArchPluginWindows::init(void* eventTarget, IEventQueue* events) +ArchPluginWindows::init(void* log, void* arch) +{ + PluginTable::iterator it; + HINSTANCE lib; + for (it = m_pluginTable.begin(); it != m_pluginTable.end(); it++) { + lib = reinterpret_cast(it->second); + initFunc initPlugin = (initFunc)GetProcAddress(lib, "init"); + initPlugin(log, arch); + } +} + +void +ArchPluginWindows::initEvent(void* eventTarget, IEventQueue* events) { g_eventTarget = eventTarget; g_events = events; @@ -77,8 +90,8 @@ ArchPluginWindows::init(void* eventTarget, IEventQueue* events) HINSTANCE lib; for (it = m_pluginTable.begin(); it != m_pluginTable.end(); it++) { lib = reinterpret_cast(it->second); - initFunc initPlugin = (initFunc)GetProcAddress(lib, "init"); - initPlugin(&sendEvent, &log); + initEventFunc initEventPlugin = (initEventFunc)GetProcAddress(lib, "initEvent"); + initEventPlugin(&sendEvent); } } @@ -95,7 +108,7 @@ void* ArchPluginWindows::invoke( const char* plugin, const char* command, - void* args) + void** args) { PluginTable::iterator it; it = m_pluginTable.find(plugin); diff --git a/src/lib/arch/win32/ArchPluginWindows.h b/src/lib/arch/win32/ArchPluginWindows.h index 53490506..7a252e60 100644 --- a/src/lib/arch/win32/ArchPluginWindows.h +++ b/src/lib/arch/win32/ArchPluginWindows.h @@ -35,11 +35,12 @@ public: // IArchPlugin overrides void load(); - void init(void* eventTarget, IEventQueue* events); + void init(void* log, void* arch); + void initEvent(void* eventTarget, IEventQueue* events); bool exists(const char* name); void* invoke(const char* pluginName, const char* functionName, - void* args); + void** args); private: String getModuleDir(); diff --git a/src/lib/base/Log.cpp b/src/lib/base/Log.cpp index 2ea91ffb..7017d972 100644 --- a/src/lib/base/Log.cpp +++ b/src/lib/base/Log.cpp @@ -74,6 +74,11 @@ Log::Log() s_log = this; } +Log::Log(Log* src) +{ + s_log = src; +} + Log::~Log() { // clean up diff --git a/src/lib/base/Log.h b/src/lib/base/Log.h index 3c8974ad..fd5273b9 100644 --- a/src/lib/base/Log.h +++ b/src/lib/base/Log.h @@ -41,6 +41,7 @@ LOGC() provide convenient access. class Log { public: Log(); + Log(Log* src); ~Log(); //! @name manipulators diff --git a/src/lib/base/String.cpp b/src/lib/base/String.cpp index c981eae9..4ce38899 100644 --- a/src/lib/base/String.cpp +++ b/src/lib/base/String.cpp @@ -171,7 +171,7 @@ findReplaceAll( String removeFileExt(String filename) { - unsigned dot = filename.find_last_of('.'); + size_t dot = filename.find_last_of('.'); if (dot == String::npos) { return filename; diff --git a/src/lib/client/Client.cpp b/src/lib/client/Client.cpp index 77e38c3d..9227834b 100644 --- a/src/lib/client/Client.cpp +++ b/src/lib/client/Client.cpp @@ -82,8 +82,7 @@ Client::Client( m_crypto(crypto), m_sendFileThread(NULL), m_writeToDropDirThread(NULL), - m_enableDragDrop(enableDragDrop), - m_secureSocket(NULL) + m_enableDragDrop(enableDragDrop) { assert(m_socketFactory != NULL); assert(m_screen != NULL); @@ -108,11 +107,6 @@ Client::Client( new TMethodEventJob(this, &Client::handleFileRecieveCompleted)); } - - if (ARCH->plugin().exists(s_networkSecurity)) { - m_secureSocket = static_cast( - ARCH->plugin().invoke("ns", "getSecureSocket", NULL)); - } } Client::~Client() @@ -163,14 +157,16 @@ Client::connect() } // create the socket - IDataSocket* socket = m_socketFactory->create(); + bool useSecureSocket = ARCH->plugin().exists(s_networkSecurity); + IDataSocket* socket = m_socketFactory->create(useSecureSocket); // filter socket messages, including a packetizing filter m_stream = socket; + bool adopt = !useSecureSocket; if (m_streamFilterFactory != NULL) { - m_stream = m_streamFilterFactory->create(m_stream, true); + m_stream = m_streamFilterFactory->create(m_stream, adopt); } - m_stream = new PacketStreamFilter(m_events, m_stream, true); + m_stream = new PacketStreamFilter(m_events, m_stream, adopt); if (m_crypto.m_mode != kDisabled) { m_cryptoStream = new CryptoStream( @@ -187,8 +183,7 @@ Client::connect() catch (XBase& e) { cleanupTimer(); cleanupConnecting(); - delete m_stream; - m_stream = NULL; + cleanupStream(); LOG((CLOG_DEBUG1 "connection failed")); sendConnectionFailedEvent(e.what()); return; @@ -545,8 +540,7 @@ Client::cleanupConnection() m_stream->getEventTarget()); m_events->removeHandler(m_events->forISocket().disconnected(), m_stream->getEventTarget()); - delete m_stream; - m_stream = NULL; + cleanupStream(); } } @@ -577,6 +571,16 @@ Client::cleanupTimer() } } +void +Client::cleanupStream() +{ + bool useSecureSocket = ARCH->plugin().exists(s_networkSecurity); + if (!useSecureSocket) { + delete m_stream; + m_stream = NULL; + } +} + void Client::handleConnected(const Event&, void*) { @@ -600,8 +604,7 @@ Client::handleConnectionFailed(const Event& event, void*) cleanupTimer(); cleanupConnecting(); - delete m_stream; - m_stream = NULL; + cleanupStream(); LOG((CLOG_DEBUG1 "connection failed")); sendConnectionFailedEvent(info->m_what.c_str()); delete info; @@ -613,8 +616,7 @@ Client::handleConnectTimeout(const Event&, void*) cleanupTimer(); cleanupConnecting(); cleanupConnection(); - delete m_stream; - m_stream = NULL; + cleanupStream(); LOG((CLOG_DEBUG1 "connection timed out")); sendConnectionFailedEvent("Timed out"); } diff --git a/src/lib/client/Client.h b/src/lib/client/Client.h index 3145a948..37c7f770 100644 --- a/src/lib/client/Client.h +++ b/src/lib/client/Client.h @@ -190,6 +190,7 @@ private: void cleanupConnection(); void cleanupScreen(); void cleanupTimer(); + void cleanupStream(); void handleConnected(const Event&, void*); void handleConnectionFailed(const Event&, void*); void handleConnectTimeout(const Event&, void*); @@ -205,34 +206,34 @@ private: void onFileRecieveCompleted(); public: - bool m_mock; + bool m_mock; private: - String m_name; - NetworkAddress m_serverAddress; - ISocketFactory* m_socketFactory; - IStreamFilterFactory* m_streamFilterFactory; - synergy::Screen* m_screen; - synergy::IStream* m_stream; - EventQueueTimer* m_timer; - ServerProxy* m_server; - bool m_ready; - bool m_active; - bool m_suspended; - bool m_connectOnResume; - bool m_ownClipboard[kClipboardEnd]; - bool m_sentClipboard[kClipboardEnd]; - IClipboard::Time m_timeClipboard[kClipboardEnd]; - String m_dataClipboard[kClipboardEnd]; - IEventQueue* m_events; - CryptoStream* m_cryptoStream; - CryptoOptions m_crypto; - std::size_t m_expectedFileSize; - String m_receivedFileData; - DragFileList m_dragFileList; - String m_dragFileExt; + String m_name; + NetworkAddress m_serverAddress; + ISocketFactory* m_socketFactory; + IStreamFilterFactory* + m_streamFilterFactory; + synergy::Screen* m_screen; + synergy::IStream* m_stream; + EventQueueTimer* m_timer; + ServerProxy* m_server; + bool m_ready; + bool m_active; + bool m_suspended; + bool m_connectOnResume; + bool m_ownClipboard[kClipboardEnd]; + bool m_sentClipboard[kClipboardEnd]; + IClipboard::Time m_timeClipboard[kClipboardEnd]; + String m_dataClipboard[kClipboardEnd]; + IEventQueue* m_events; + CryptoStream* m_cryptoStream; + CryptoOptions m_crypto; + std::size_t m_expectedFileSize; + String m_receivedFileData; + DragFileList m_dragFileList; + String m_dragFileExt; Thread* m_sendFileThread; Thread* m_writeToDropDirThread; - bool m_enableDragDrop; - SecureSocket* m_secureSocket; + bool m_enableDragDrop; }; diff --git a/src/lib/net/ISocketFactory.h b/src/lib/net/ISocketFactory.h index aca45d54..5c86a0bc 100644 --- a/src/lib/net/ISocketFactory.h +++ b/src/lib/net/ISocketFactory.h @@ -34,10 +34,10 @@ public: //@{ //! Create data socket - virtual IDataSocket* create() const = 0; + virtual IDataSocket* create(bool secure) const = 0; //! Create listen socket - virtual IListenSocket* createListen() const = 0; + virtual IListenSocket* createListen(bool secure) const = 0; //@} }; diff --git a/src/lib/net/TCPListenSocket.h b/src/lib/net/TCPListenSocket.h index e93a0e8e..b81b38fb 100644 --- a/src/lib/net/TCPListenSocket.h +++ b/src/lib/net/TCPListenSocket.h @@ -33,7 +33,7 @@ A listen socket using TCP. class TCPListenSocket : public IListenSocket { public: TCPListenSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer); - ~TCPListenSocket(); + virtual ~TCPListenSocket(); // ISocket overrides virtual void bind(const NetworkAddress&); @@ -44,12 +44,14 @@ public: virtual IDataSocket* accept(); -private: + ArchSocket& getSocket() { return m_socket; } + +public: ISocketMultiplexerJob* serviceListening(ISocketMultiplexerJob*, bool, bool, bool); -private: +protected: ArchSocket m_socket; Mutex* m_mutex; IEventQueue* m_events; diff --git a/src/lib/net/TCPSocket.h b/src/lib/net/TCPSocket.h index a83c5756..a026b608 100644 --- a/src/lib/net/TCPSocket.h +++ b/src/lib/net/TCPSocket.h @@ -38,7 +38,7 @@ class TCPSocket : public IDataSocket { public: TCPSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer); TCPSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer, ArchSocket socket); - ~TCPSocket(); + virtual ~TCPSocket(); // ISocket overrides virtual void bind(const NetworkAddress&); @@ -57,15 +57,19 @@ public: // IDataSocket overrides virtual void connect(const NetworkAddress&); +protected: + virtual void onConnected(); + ArchSocket getSocket() { return m_socket; } + private: void init(); void setJob(ISocketMultiplexerJob*); - ISocketMultiplexerJob* newJob(); + ISocketMultiplexerJob* + newJob(); void sendConnectionFailedEvent(const char*); void sendEvent(Event::Type); - void onConnected(); void onInputShutdown(); void onOutputShutdown(); void onDisconnected(); diff --git a/src/lib/net/TCPSocketFactory.cpp b/src/lib/net/TCPSocketFactory.cpp index 74bd467c..2f272d94 100644 --- a/src/lib/net/TCPSocketFactory.cpp +++ b/src/lib/net/TCPSocketFactory.cpp @@ -20,11 +20,19 @@ #include "net/TCPSocket.h" #include "net/TCPListenSocket.h" +#include "arch/Arch.h" +#include "base/Log.h" // // TCPSocketFactory // +#if defined _WIN32 +static const char s_networkSecurity[] = { "ns" }; +#else +static const char s_networkSecurity[] = { "libns" }; +#endif + TCPSocketFactory::TCPSocketFactory(IEventQueue* events, SocketMultiplexer* socketMultiplexer) : m_events(events), m_socketMultiplexer(socketMultiplexer) @@ -38,13 +46,43 @@ TCPSocketFactory::~TCPSocketFactory() } IDataSocket* -TCPSocketFactory::create() const +TCPSocketFactory::create(bool secure) const { - return new TCPSocket(m_events, m_socketMultiplexer); + IDataSocket* socket = NULL; + if (secure) { + void* args[4] = { + m_events, + m_socketMultiplexer, + Log::getInstance(), + Arch::getInstance() + }; + socket = static_cast( + ARCH->plugin().invoke(s_networkSecurity, "getSecureSocket", args)); + } + else { + socket = new TCPSocket(m_events, m_socketMultiplexer); + } + + return socket; } IListenSocket* -TCPSocketFactory::createListen() const +TCPSocketFactory::createListen(bool secure) const { - return new TCPListenSocket(m_events, m_socketMultiplexer); + IListenSocket* socket = NULL; + if (secure) { + void* args[4] = { + m_events, + m_socketMultiplexer, + Log::getInstance(), + Arch::getInstance() + }; + socket = static_cast( + ARCH->plugin().invoke(s_networkSecurity, "getSecureListenSocket", args)); + } + else { + socket = new TCPListenSocket(m_events, m_socketMultiplexer); + } + + return socket; } diff --git a/src/lib/net/TCPSocketFactory.h b/src/lib/net/TCPSocketFactory.h index 827a74ac..714ded5b 100644 --- a/src/lib/net/TCPSocketFactory.h +++ b/src/lib/net/TCPSocketFactory.h @@ -31,9 +31,9 @@ public: // ISocketFactory overrides virtual IDataSocket* - create() const; + create(bool secure) const; virtual IListenSocket* - createListen() const; + createListen(bool secure) const; private: IEventQueue* m_events; diff --git a/src/lib/server/ClientListener.cpp b/src/lib/server/ClientListener.cpp index a486774d..a4fd7638 100644 --- a/src/lib/server/ClientListener.cpp +++ b/src/lib/server/ClientListener.cpp @@ -36,6 +36,12 @@ // ClientListener // +#if defined _WIN32 +static const char s_networkSecurity[] = { "ns" }; +#else +static const char s_networkSecurity[] = { "libns" }; +#endif + ClientListener::ClientListener(const NetworkAddress& address, ISocketFactory* socketFactory, IStreamFilterFactory* streamFilterFactory, @@ -51,7 +57,8 @@ ClientListener::ClientListener(const NetworkAddress& address, try { // create listen socket - m_listen = m_socketFactory->createListen(); + bool useSecureSocket = ARCH->plugin().exists(s_networkSecurity); + m_listen = m_socketFactory->createListen(useSecureSocket); // bind listen address LOG((CLOG_DEBUG1 "binding listen socket")); diff --git a/src/lib/synergy/ClientApp.cpp b/src/lib/synergy/ClientApp.cpp index 0d684efa..fe10b697 100644 --- a/src/lib/synergy/ClientApp.cpp +++ b/src/lib/synergy/ClientApp.cpp @@ -460,6 +460,8 @@ ClientApp::mainLoop() // load all available plugins. ARCH->plugin().load(); + // pass log and arch into plugins. + ARCH->plugin().init(Log::getInstance(), Arch::getInstance()); // start client, etc appUtil().startNode(); @@ -470,8 +472,8 @@ ClientApp::mainLoop() initIpcClient(); } - // init all available plugins. - ARCH->plugin().init(m_clientScreen->getEventTarget(), m_events); + // init event for all available plugins. + ARCH->plugin().initEvent(m_clientScreen->getEventTarget(), m_events); // run event loop. if startClient() failed we're supposed to retry // later. the timer installed by startClient() will take care of diff --git a/src/lib/synergy/ServerApp.cpp b/src/lib/synergy/ServerApp.cpp index 1d6e4443..3adf0bfe 100644 --- a/src/lib/synergy/ServerApp.cpp +++ b/src/lib/synergy/ServerApp.cpp @@ -707,6 +707,11 @@ ServerApp::mainLoop() return kExitFailed; } + // load all available plugins. + ARCH->plugin().load(); + // pass log and arch into plugins. + ARCH->plugin().init(Log::getInstance(), Arch::getInstance()); + // start server, etc appUtil().startNode(); @@ -716,8 +721,8 @@ ServerApp::mainLoop() initIpcClient(); } - // load all available plugins. - ARCH->plugin().init(m_serverScreen->getEventTarget(), m_events); + // init event for all available plugins. + ARCH->plugin().initEvent(m_serverScreen->getEventTarget(), m_events); // handle hangup signal by reloading the server's configuration ARCH->setSignalHandler(Arch::kHANGUP, &reloadSignalHandler, NULL); diff --git a/src/plugin/ns/SecureListenSocket.cpp b/src/plugin/ns/SecureListenSocket.cpp new file mode 100644 index 00000000..e8c5d0c9 --- /dev/null +++ b/src/plugin/ns/SecureListenSocket.cpp @@ -0,0 +1,77 @@ +/* + * synergy -- mouse and keyboard sharing utility + * Copyright (C) 2015 Synergy Si Ltd. + * + * This package is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * found in the file COPYING that should have accompanied this file. + * + * This package is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "SecureListenSocket.h" + +#include "SecureSocket.h" +#include "net/NetworkAddress.h" +#include "net/SocketMultiplexer.h" +#include "net/TSocketMultiplexerMethodJob.h" +#include "arch/XArch.h" +#include "base/Log.h" + +// +// SecureListenSocket +// + +SecureListenSocket::SecureListenSocket( + IEventQueue* events, + SocketMultiplexer* socketMultiplexer) : + TCPListenSocket(events, socketMultiplexer) +{ +} + +SecureListenSocket::~SecureListenSocket() +{ +} + +IDataSocket* +SecureListenSocket::accept() +{ + SecureSocket* socket = NULL; + try { + socket = new SecureSocket( + m_events, + m_socketMultiplexer, + ARCH->acceptSocket(m_socket, NULL)); + socket->initSsl(true); + // TODO: customized certificate path + socket->loadCertificates("C:\\Temp\\synergy.pem"); + + if (socket != NULL) { + m_socketMultiplexer->addSocket(this, + new TSocketMultiplexerMethodJob( + this, &TCPListenSocket::serviceListening, + m_socket, true, false)); + + socket->acceptSecureSocket(); + } + return dynamic_cast(socket); + } + catch (XArchNetwork&) { + if (socket != NULL) { + delete socket; + } + return NULL; + } + catch (std::exception &ex) { + if (socket != NULL) { + delete socket; + } + throw ex; + } +} diff --git a/src/plugin/ns/SecureListenSocket.h b/src/plugin/ns/SecureListenSocket.h new file mode 100644 index 00000000..4a4c1496 --- /dev/null +++ b/src/plugin/ns/SecureListenSocket.h @@ -0,0 +1,34 @@ +/* + * synergy -- mouse and keyboard sharing utility + * Copyright (C) 2015 Synergy Si Ltd. + * + * This package is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * found in the file COPYING that should have accompanied this file. + * + * This package is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#include "net/TCPListenSocket.h" + +class IEventQueue; +class SocketMultiplexer; + +class SecureListenSocket : public TCPListenSocket{ +public: + SecureListenSocket(IEventQueue* events, + SocketMultiplexer* socketMultiplexer); + ~SecureListenSocket(); + + // IListenSocket overrides + virtual IDataSocket* + accept(); +}; diff --git a/src/plugin/ns/SecureSocket.cpp b/src/plugin/ns/SecureSocket.cpp index b0bb2629..8c3d61c6 100644 --- a/src/plugin/ns/SecureSocket.cpp +++ b/src/plugin/ns/SecureSocket.cpp @@ -17,26 +17,42 @@ #include "SecureSocket.h" -#include "base/String.h" +#include "net/TCPSocket.h" +#include "arch/XArch.h" +#include "base/Log.h" + #include #include +#include +#include +#include // // SecureSocket // + +#define MAX_ERROR_SIZE 65535 + struct Ssl { SSL_CTX* m_context; SSL* m_ssl; }; -SecureSocket::SecureSocket() : - m_ready(false), - m_errorSize(65535) +SecureSocket::SecureSocket( + IEventQueue* events, + SocketMultiplexer* socketMultiplexer) : + TCPSocket(events, socketMultiplexer), + m_ready(false) +{ +} + +SecureSocket::SecureSocket( + IEventQueue* events, + SocketMultiplexer* socketMultiplexer, + ArchSocket socket) : + TCPSocket(events, socketMultiplexer, socket), + m_ready(false) { - m_ssl = new Ssl(); - m_ssl->m_context = NULL; - m_ssl->m_ssl = NULL; - m_error = new char[m_errorSize]; } SecureSocket::~SecureSocket() @@ -51,6 +67,81 @@ SecureSocket::~SecureSocket() delete[] m_error; } +UInt32 +SecureSocket::read(void* buffer, UInt32 n) +{ + bool retry = false; + int r = 0; + if (m_ssl != NULL) { + r = SSL_read(m_ssl->m_ssl, buffer, n); + retry = checkResult(r); + if (retry) { + r = 0; + } + } + + return r > 0 ? (UInt32)r : 0; +} + +void +SecureSocket::write(const void* buffer, UInt32 n) +{ + bool retry = false; + int r = 0; + if (m_ssl != NULL) { + r = SSL_write(m_ssl->m_ssl, buffer, n); + retry = checkResult(r); + if (retry) { + r = 0; + } + } +} + +bool +SecureSocket::isReady() const +{ + return m_ready; +} + +void +SecureSocket::connectSecureSocket() +{ +#ifdef SYSAPI_WIN32 + secureConnect(static_cast(getSocket()->m_socket)); +#elif SYSAPI_UNIX + secureConnect(getSocket()->m_fd); +#endif +} + +void +SecureSocket::acceptSecureSocket() +{ +#ifdef SYSAPI_WIN32 + secureAccept(static_cast(getSocket()->m_socket)); +#elif SYSAPI_UNIX + secureAccept(getSocket()->m_fd); +#endif +} + +void +SecureSocket::initSsl(bool server) +{ + m_ssl = new Ssl(); + m_ssl->m_context = NULL; + m_ssl->m_ssl = NULL; + m_error = new char[MAX_ERROR_SIZE]; + + initContext(server); +} + +void +SecureSocket::onConnected() +{ + TCPSocket::onConnected(); + + connectSecureSocket(); +} + void SecureSocket::initContext(bool server) { @@ -111,7 +202,7 @@ SecureSocket::createSSL() } void -SecureSocket::accept(int socket) +SecureSocket::secureAccept(int socket) { createSSL(); @@ -124,6 +215,7 @@ SecureSocket::accept(int socket) bool retry = checkResult(r); while (retry) { ARCH->sleep(.5f); + LOG((CLOG_INFO "secureAccept sleep .5s")); r = SSL_accept(m_ssl->m_ssl); retry = checkResult(r); } @@ -132,7 +224,7 @@ SecureSocket::accept(int socket) } void -SecureSocket::connect(int socket) +SecureSocket::secureConnect(int socket) { createSSL(); @@ -152,38 +244,6 @@ SecureSocket::connect(int socket) showCertificate(); } -size_t -SecureSocket::write(const void* buffer, int size) -{ - bool retry = false; - int n = 0; - if (m_ssl != NULL) { - n = SSL_write(m_ssl->m_ssl, buffer, size); - retry = checkResult(n); - if (retry) { - n = 0; - } - } - - return n > 0 ? n : 0; -} - -size_t -SecureSocket::read(void* buffer, int size) -{ - bool retry = false; - int n = 0; - if (m_ssl != NULL) { - n = SSL_read(m_ssl->m_ssl, buffer, size); - retry = checkResult(n); - if (retry) { - n = 0; - } - } - - return n > 0 ? n : 0; -} - void SecureSocket::showCertificate() { @@ -277,7 +337,7 @@ SecureSocket::getError() bool errorUpdated = false; if (e != 0) { - ERR_error_string_n(e, m_error, m_errorSize); + ERR_error_string_n(e, m_error, MAX_ERROR_SIZE); errorUpdated = true; } else { @@ -286,9 +346,3 @@ SecureSocket::getError() return errorUpdated; } - -bool -SecureSocket::isReady() -{ - return m_ready; -} diff --git a/src/plugin/ns/SecureSocket.h b/src/plugin/ns/SecureSocket.h index 73b3a0e0..e6aea016 100644 --- a/src/plugin/ns/SecureSocket.h +++ b/src/plugin/ns/SecureSocket.h @@ -17,34 +17,48 @@ #pragma once +#include "net/TCPSocket.h" #include "base/XBase.h" -#include "base/Log.h" + +class IEventQueue; +class SocketMultiplexer; + +struct Ssl; //! Generic socket exception XBASE_SUBCLASS(XSecureSocket, XBase); -//! SSL + +//! Secure socket /*! -Secure socket layer using OpenSSL. +A secure socket using SSL. */ - -struct Ssl; - -class SecureSocket { +class SecureSocket : public TCPSocket { public: - SecureSocket(); + SecureSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer); + SecureSocket(IEventQueue* events, + SocketMultiplexer* socketMultiplexer, + ArchSocket socket); ~SecureSocket(); - void initContext(bool server); + // IStream overrides + virtual UInt32 read(void* buffer, UInt32 n); + virtual void write(const void* buffer, UInt32 n); + virtual bool isReady() const; + + void connectSecureSocket(); + void acceptSecureSocket(); + void initSsl(bool server); void loadCertificates(const char* CertFile); - void createSSL(); - void accept(int s); - void connect(int s); - size_t write(const void* buffer, int size); - size_t read(void* buffer, int size); - bool isReady(); private: + void onConnected(); + + // SSL + void initContext(bool server); + void createSSL(); + void secureAccept(int s); + void secureConnect(int s); void showCertificate(); bool checkResult(int n); void showError(); @@ -55,5 +69,4 @@ private: Ssl* m_ssl; bool m_ready; char* m_error; - const size_t m_errorSize; }; diff --git a/src/plugin/ns/ns.cpp b/src/plugin/ns/ns.cpp index 16eb872d..44a5ac36 100644 --- a/src/plugin/ns/ns.cpp +++ b/src/plugin/ns/ns.cpp @@ -18,28 +18,62 @@ #include "ns.h" #include "SecureSocket.h" +#include "SecureListenSocket.h" +#include "arch/Arch.h" +#include "base/Log.h" #include SecureSocket* g_secureSocket = NULL; +SecureListenSocket* g_secureListenSocket = NULL; +Arch* g_arch = NULL; +Log* g_log = NULL; extern "C" { +void +init(void* log, void* arch) +{ + if (g_log == NULL) { + g_log = new Log(reinterpret_cast(log)); + } + + if (g_arch == NULL) { + g_arch = new Arch(reinterpret_cast(arch)); + } +} + int -init(void (*sendEvent)(const char*, void*), void (*log)(const char*)) +initEvent(void (*sendEvent)(const char*, void*)) { return 0; } void* -invoke(const char* command, void* args) +invoke(const char* command, void** args) { + IEventQueue* arg1 = NULL; + SocketMultiplexer* arg2 = NULL; + if (args != NULL) { + arg1 = reinterpret_cast(args[0]); + arg2 = reinterpret_cast(args[1]); + } + if (strcmp(command, "getSecureSocket") == 0) { - if (g_secureSocket == NULL) { - g_secureSocket = new SecureSocket(); + if (g_secureSocket != NULL) { + delete g_secureSocket; } + g_secureSocket = new SecureSocket(arg1, arg2); + g_secureSocket->initSsl(false); return g_secureSocket; } + else if (strcmp(command, "getSecureListenSocket") == 0) { + if (g_secureListenSocket != NULL) { + delete g_secureListenSocket; + } + g_secureListenSocket = new SecureListenSocket(arg1, arg2); + return g_secureListenSocket; + } else { return NULL; } @@ -52,6 +86,10 @@ cleanup() delete g_secureSocket; } + if (g_secureListenSocket != NULL) { + delete g_secureListenSocket; + } + return 0; } diff --git a/src/plugin/ns/ns.h b/src/plugin/ns/ns.h index 582275e4..fa5a444f 100644 --- a/src/plugin/ns/ns.h +++ b/src/plugin/ns/ns.h @@ -33,8 +33,9 @@ extern "C" { -NS_API int init(void (*sendEvent)(const char*, void*), void (*log)(const char*)); -NS_API void* invoke(const char* command, void* args); +NS_API void init(void* log, void* arch); +NS_API int initEvent(void (*sendEvent)(const char*, void*)); +NS_API void* invoke(const char* command, void** args); NS_API int cleanup(); } diff --git a/src/plugin/winmmjoy/winmmjoy.cpp b/src/plugin/winmmjoy/winmmjoy.cpp index 96fc6670..0ae535e7 100644 --- a/src/plugin/winmmjoy/winmmjoy.cpp +++ b/src/plugin/winmmjoy/winmmjoy.cpp @@ -36,12 +36,15 @@ static void (*s_log)(const char*) = NULL; extern "C" { +void +init(void* log, void* arch) +{ +} + int -init(void (*sendEvent)(const char*, void*), void (*log)(const char*)) +initEvent(void (*sendEvent)(const char*, void*)) { s_sendEvent = sendEvent; - s_log = log; - LOG("init"); CreateThread(NULL, 0, mainLoop, NULL, 0, NULL); return 0; } diff --git a/src/plugin/winmmjoy/winmmjoy.h b/src/plugin/winmmjoy/winmmjoy.h index 9addac0f..38e8518c 100644 --- a/src/plugin/winmmjoy/winmmjoy.h +++ b/src/plugin/winmmjoy/winmmjoy.h @@ -29,7 +29,8 @@ extern "C" { -WINMMJOY_API int init(void (*sendEvent)(const char*, void*), void (*log)(const char*)); +WINMMJOY_API void init(void* log, void* arch); +WINMMJOY_API int initEvent(void (*sendEvent)(const char*, void*)); WINMMJOY_API int cleanup(); }