diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Bindings/LuaTCPLink.cpp | 270 | ||||
-rw-r--r-- | src/Bindings/LuaTCPLink.h | 54 | ||||
-rw-r--r-- | src/HTTP/UrlClient.cpp | 74 | ||||
-rw-r--r-- | src/HTTP/UrlClient.h | 28 | ||||
-rw-r--r-- | src/OSSupport/Network.h | 33 | ||||
-rw-r--r-- | src/OSSupport/TCPLinkImpl.cpp | 291 | ||||
-rw-r--r-- | src/OSSupport/TCPLinkImpl.h | 75 |
7 files changed, 510 insertions, 315 deletions
diff --git a/src/Bindings/LuaTCPLink.cpp b/src/Bindings/LuaTCPLink.cpp index 466240a7d..905dfc5ac 100644 --- a/src/Bindings/LuaTCPLink.cpp +++ b/src/Bindings/LuaTCPLink.cpp @@ -6,6 +6,8 @@ #include "Globals.h" #include "LuaTCPLink.h" #include "LuaServerHandle.h" +#include "../PolarSSL++/X509Cert.h" +#include "../PolarSSL++/CryptoKey.h" @@ -48,13 +50,6 @@ cLuaTCPLink::~cLuaTCPLink() bool cLuaTCPLink::Send(const AString & a_Data) { - // If running in SSL mode, push the data into the SSL context instead: - if (m_SslContext != nullptr) - { - m_SslContext->Send(a_Data); - return true; - } - // Safely grab a copy of the link: auto link = m_Link; if (link == nullptr) @@ -144,12 +139,6 @@ void cLuaTCPLink::Shutdown(void) cTCPLinkPtr link = m_Link; if (link != nullptr) { - if (m_SslContext != nullptr) - { - m_SslContext->NotifyClose(); - m_SslContext->ResetSelf(); - m_SslContext.reset(); - } link->Shutdown(); } } @@ -164,12 +153,6 @@ void cLuaTCPLink::Close(void) cTCPLinkPtr link = m_Link; if (link != nullptr) { - if (m_SslContext != nullptr) - { - m_SslContext->NotifyClose(); - m_SslContext->ResetSelf(); - m_SslContext.reset(); - } link->Close(); } @@ -186,46 +169,31 @@ AString cLuaTCPLink::StartTLSClient( const AString & a_OwnPrivKeyPassword ) { - // Check preconditions: - if (m_SslContext != nullptr) - { - return "TLS is already active on this link"; - } - if ( - (a_OwnCertData.empty() && !a_OwnPrivKeyData.empty()) || - (!a_OwnCertData.empty() && a_OwnPrivKeyData.empty()) - ) - { - return "Either provide both the certificate and private key, or neither"; - } - - // Create the SSL context: - m_SslContext.reset(new cLinkSslContext(*this)); - m_SslContext->Initialize(true); - - // Create the peer cert, if required: - if (!a_OwnCertData.empty() && !a_OwnPrivKeyData.empty()) + auto link = m_Link; + if (link != nullptr) { - auto OwnCert = std::make_shared<cX509Cert>(); - int res = OwnCert->Parse(a_OwnCertData.data(), a_OwnCertData.size()); - if (res != 0) + cX509CertPtr ownCert; + if (!a_OwnCertData.empty()) { - m_SslContext.reset(); - return Printf("Cannot parse peer certificate: -0x%x", res); + ownCert = std::make_shared<cX509Cert>(); + auto res = ownCert->Parse(a_OwnCertData.data(), a_OwnCertData.size()); + if (res != 0) + { + return Printf("Cannot parse client certificate: -0x%x", res); + } } - auto OwnPrivKey = std::make_shared<cCryptoKey>(); - res = OwnPrivKey->ParsePrivate(a_OwnPrivKeyData.data(), a_OwnPrivKeyData.size(), a_OwnPrivKeyPassword); - if (res != 0) + cCryptoKeyPtr ownPrivKey; + if (!a_OwnPrivKeyData.empty()) { - m_SslContext.reset(); - return Printf("Cannot parse peer private key: -0x%x", res); + ownPrivKey = std::make_shared<cCryptoKey>(); + auto res = ownPrivKey->ParsePrivate(a_OwnPrivKeyData.data(), a_OwnPrivKeyData.size(), a_OwnPrivKeyPassword); + if (res != 0) + { + return Printf("Cannot parse client private key: -0x%x", res); + } } - m_SslContext->SetOwnCert(OwnCert, OwnPrivKey); + return link->StartTLSClient(ownCert, ownPrivKey); } - m_SslContext->SetSelf(cLinkSslContextWPtr(m_SslContext)); - - // Start the handshake: - m_SslContext->Handshake(); return ""; } @@ -240,43 +208,25 @@ AString cLuaTCPLink::StartTLSServer( const AString & a_StartTLSData ) { - // Check preconditions: - if (m_SslContext != nullptr) - { - return "TLS is already active on this link"; - } - if (a_OwnCertData.empty() || a_OwnPrivKeyData.empty()) + auto link = m_Link; + if (link != nullptr) { - return "Provide the server certificate and private key"; - } - - // Create the SSL context: - m_SslContext.reset(new cLinkSslContext(*this)); - m_SslContext->Initialize(false); - // Create the peer cert: auto OwnCert = std::make_shared<cX509Cert>(); int res = OwnCert->Parse(a_OwnCertData.data(), a_OwnCertData.size()); if (res != 0) { - m_SslContext.reset(); return Printf("Cannot parse server certificate: -0x%x", res); } auto OwnPrivKey = std::make_shared<cCryptoKey>(); res = OwnPrivKey->ParsePrivate(a_OwnPrivKeyData.data(), a_OwnPrivKeyData.size(), a_OwnPrivKeyPassword); if (res != 0) { - m_SslContext.reset(); return Printf("Cannot parse server private key: -0x%x", res); } - m_SslContext->SetOwnCert(OwnCert, OwnPrivKey); - m_SslContext->SetSelf(cLinkSslContextWPtr(m_SslContext)); - // Push the initial data: - m_SslContext->StoreReceivedData(a_StartTLSData.data(), a_StartTLSData.size()); - - // Start the handshake: - m_SslContext->Handshake(); + return link->StartTLSServer(OwnCert, OwnPrivKey, a_StartTLSData); + } return ""; } @@ -308,9 +258,6 @@ void cLuaTCPLink::Terminated(void) m_Link.reset(); } } - - // If the SSL context still exists, free it: - m_SslContext.reset(); } @@ -362,14 +309,6 @@ void cLuaTCPLink::OnLinkCreated(cTCPLinkPtr a_Link) void cLuaTCPLink::OnReceivedData(const char * a_Data, size_t a_Length) { - // If we're running in SSL mode, put the data into the SSL decryptor: - auto sslContext = m_SslContext; - if (sslContext != nullptr) - { - sslContext->StoreReceivedData(a_Data, a_Length); - return; - } - // Call the callback: m_Callbacks->CallTableFn("OnReceivedData", this, AString(a_Data, a_Length)); } @@ -380,13 +319,6 @@ void cLuaTCPLink::OnReceivedData(const char * a_Data, size_t a_Length) void cLuaTCPLink::OnRemoteClosed(void) { - // If running in SSL mode and there's data left in the SSL contect, report it: - auto sslContext = m_SslContext; - if (sslContext != nullptr) - { - sslContext->FlushBuffers(); - } - // Call the callback: m_Callbacks->CallTableFn("OnRemoteClosed", this); @@ -398,155 +330,3 @@ void cLuaTCPLink::OnRemoteClosed(void) -//////////////////////////////////////////////////////////////////////////////// -// cLuaTCPLink::cLinkSslContext: - -cLuaTCPLink::cLinkSslContext::cLinkSslContext(cLuaTCPLink & a_Link): - m_Link(a_Link) -{ -} - - - - - -void cLuaTCPLink::cLinkSslContext::SetSelf(cLinkSslContextWPtr a_Self) -{ - m_Self = a_Self; -} - - - - - -void cLuaTCPLink::cLinkSslContext::ResetSelf(void) -{ - m_Self.reset(); -} - - - - - -void cLuaTCPLink::cLinkSslContext::StoreReceivedData(const char * a_Data, size_t a_NumBytes) -{ - // Hold self alive for the duration of this function - cLinkSslContextPtr Self(m_Self); - - m_EncryptedData.append(a_Data, a_NumBytes); - - // Try to finish a pending handshake: - TryFinishHandshaking(); - - // Flush any cleartext data that can be "received": - FlushBuffers(); -} - - - - - -void cLuaTCPLink::cLinkSslContext::FlushBuffers(void) -{ - // Hold self alive for the duration of this function - cLinkSslContextPtr Self(m_Self); - - // If the handshake didn't complete yet, bail out: - if (!HasHandshaken()) - { - return; - } - - char Buffer[1024]; - int NumBytes; - while ((NumBytes = ReadPlain(Buffer, sizeof(Buffer))) > 0) - { - m_Link.ReceivedCleartextData(Buffer, static_cast<size_t>(NumBytes)); - if (m_Self.expired()) - { - // The callback closed the SSL context, bail out - return; - } - } -} - - - - - -void cLuaTCPLink::cLinkSslContext::TryFinishHandshaking(void) -{ - // Hold self alive for the duration of this function - cLinkSslContextPtr Self(m_Self); - - // If the handshake hasn't finished yet, retry: - if (!HasHandshaken()) - { - Handshake(); - } - - // If the handshake succeeded, write all the queued plaintext data: - if (HasHandshaken()) - { - WritePlain(m_CleartextData.data(), m_CleartextData.size()); - m_CleartextData.clear(); - } -} - - - - - -void cLuaTCPLink::cLinkSslContext::Send(const AString & a_Data) -{ - // Hold self alive for the duration of this function - cLinkSslContextPtr Self(m_Self); - - // If the handshake hasn't completed yet, queue the data: - if (!HasHandshaken()) - { - m_CleartextData.append(a_Data); - TryFinishHandshaking(); - return; - } - - // The connection is all set up, write the cleartext data into the SSL context: - WritePlain(a_Data.data(), a_Data.size()); - FlushBuffers(); -} - - - - - -int cLuaTCPLink::cLinkSslContext::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) -{ - // Hold self alive for the duration of this function - cLinkSslContextPtr Self(m_Self); - - // If there's nothing queued in the buffer, report empty buffer: - if (m_EncryptedData.empty()) - { - return POLARSSL_ERR_NET_WANT_READ; - } - - // Copy as much data as possible to the provided buffer: - size_t BytesToCopy = std::min(a_NumBytes, m_EncryptedData.size()); - memcpy(a_Buffer, m_EncryptedData.data(), BytesToCopy); - m_EncryptedData.erase(0, BytesToCopy); - return static_cast<int>(BytesToCopy); -} - - - - - -int cLuaTCPLink::cLinkSslContext::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) -{ - m_Link.m_Link->Send(a_Buffer, a_NumBytes); - return static_cast<int>(a_NumBytes); -} - - - - diff --git a/src/Bindings/LuaTCPLink.h b/src/Bindings/LuaTCPLink.h index b8c886ef8..f4ca67018 100644 --- a/src/Bindings/LuaTCPLink.h +++ b/src/Bindings/LuaTCPLink.h @@ -10,7 +10,6 @@ #pragma once #include "../OSSupport/Network.h" -#include "../PolarSSL++/SslContext.h" #include "LuaState.h" @@ -90,54 +89,6 @@ public: ); protected: - // fwd: - class cLinkSslContext; - typedef SharedPtr<cLinkSslContext> cLinkSslContextPtr; - typedef WeakPtr<cLinkSslContext> cLinkSslContextWPtr; - - /** Wrapper around cSslContext that is used when this link is being encrypted by SSL. */ - class cLinkSslContext : - public cSslContext - { - cLuaTCPLink & m_Link; - - /** Buffer for storing the incoming encrypted data until it is requested by the SSL decryptor. */ - AString m_EncryptedData; - - /** Buffer for storing the outgoing cleartext data until the link has finished handshaking. */ - AString m_CleartextData; - - /** Shared ownership of self, so that this object can keep itself alive for as long as it needs. */ - cLinkSslContextWPtr m_Self; - - public: - cLinkSslContext(cLuaTCPLink & a_Link); - - /** Shares ownership of self, so that this object can keep itself alive for as long as it needs. */ - void SetSelf(cLinkSslContextWPtr a_Self); - - /** Removes the self ownership so that we can detect the SSL closure. */ - void ResetSelf(void); - - /** Stores the specified block of data into the buffer of the data to be decrypted (incoming from remote). - Also flushes the SSL buffers by attempting to read any data through the SSL context. */ - void StoreReceivedData(const char * a_Data, size_t a_NumBytes); - - /** Tries to read any cleartext data available through the SSL, reports it in the link. */ - void FlushBuffers(void); - - /** Tries to finish handshaking the SSL. */ - void TryFinishHandshaking(void); - - /** Sends the specified cleartext data over the SSL to the remote peer. - If the handshake hasn't been completed yet, queues the data for sending when it completes. */ - void Send(const AString & a_Data); - - // cSslContext overrides: - virtual int ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) override; - virtual int SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) override; - }; - /** The Lua table that holds the callbacks to be invoked. */ cLuaState::cTableRefPtr m_Callbacks; @@ -149,11 +100,6 @@ protected: /** The server that is responsible for this link, if any. */ cLuaServerHandleWPtr m_Server; - /** The SSL context used for encryption, if this link uses SSL. - If valid, the link uses encryption through this context. */ - cLinkSslContextPtr m_SslContext; - - /** Common code called when the link is considered as terminated. Releases m_Link, m_Callbacks and this from m_Server, each when applicable. */ void Terminated(void); diff --git a/src/HTTP/UrlClient.cpp b/src/HTTP/UrlClient.cpp index f9e642b22..9346882f1 100644 --- a/src/HTTP/UrlClient.cpp +++ b/src/HTTP/UrlClient.cpp @@ -7,6 +7,8 @@ #include "UrlClient.h" #include "UrlParser.h" #include "HTTPMessageParser.h" +#include "../PolarSSL++/X509Cert.h" +#include "../PolarSSL++/CryptoKey.h" @@ -67,6 +69,38 @@ public: bool ShouldAllowRedirects() const; + cX509CertPtr GetOwnCert() const + { + auto itr = m_Options.find("OwnCert"); + if (itr == m_Options.end()) + { + return nullptr; + } + cX509CertPtr cert; + if (!cert->Parse(itr->second.data(), itr->second.size())) + { + LOGD("OwnCert failed to parse"); + return nullptr; + } + return cert; + } + + cCryptoKeyPtr GetOwnPrivKey() const + { + auto itr = m_Options.find("OwnPrivKey"); + if (itr == m_Options.end()) + { + return nullptr; + } + cCryptoKeyPtr key; + auto passItr = m_Options.find("OwnPrivKeyPassword"); + auto pass = (passItr == m_Options.end()) ? AString() : passItr->second; + if (!key->ParsePrivate(itr->second.data(), itr->second.size(), pass)) + { + return nullptr; + } + return key; + } protected: @@ -148,6 +182,9 @@ protected: } + // cTCPLink::cCallbacks override: TLS handshake completed on the link: + virtual void OnTlsHandshakeCompleted(void) override; + /** Called when there's data incoming from the remote peer. */ virtual void OnReceivedData(const char * a_Data, size_t a_Length) override; @@ -188,6 +225,9 @@ public: /** Called when there's data incoming from the remote peer. */ virtual void OnReceivedData(const char * a_Data, size_t a_Length) = 0; + /** Called when the TLS handshake has completed on the underlying link. */ + virtual void OnTlsHandshakeCompleted(void) = 0; + /** Called when the remote end closes the connection. The link is still available for connection information query (IP / port). Sending data on the link is not an error, but the data won't be delivered. */ @@ -223,7 +263,7 @@ public: m_Link = &a_Link; if (m_IsTls) { - // TODO: Start TLS + m_Link->StartTLSClient(m_ParentRequest.GetOwnCert(), m_ParentRequest.GetOwnPrivKey()); } else { @@ -231,9 +271,12 @@ public: } } + + /** Sends the HTTP request over the link. + Common code for both HTTP and HTTPS. */ void SendRequest() { - // Send the request: + // Compose the request line: auto requestLine = m_ParentRequest.m_UrlPath; if (requestLine.empty()) { @@ -245,6 +288,8 @@ public: requestLine.append(m_ParentRequest.m_UrlQuery); } m_Link->Send(Printf("%s %s HTTP/1.1\r\n", m_ParentRequest.m_Method.c_str(), requestLine.c_str())); + + // Send the headers: m_Link->Send(Printf("Host: %s\r\n", m_ParentRequest.m_UrlHost.c_str())); m_Link->Send(Printf("Content-Length: %u\r\n", static_cast<unsigned>(m_ParentRequest.m_Body.size()))); for (auto itr = m_ParentRequest.m_Headers.cbegin(), end = m_ParentRequest.m_Headers.cend(); itr != end; ++itr) @@ -252,6 +297,8 @@ public: m_Link->Send(Printf("%s: %s\r\n", itr->first.c_str(), itr->second.c_str())); } // for itr - m_Headers[] m_Link->Send("\r\n", 2); + + // Send the body: m_Link->Send(m_ParentRequest.m_Body); // Notify the callbacks that the request has been sent: @@ -270,6 +317,12 @@ public: } + virtual void OnTlsHandshakeCompleted(void) override + { + SendRequest(); + } + + virtual void OnRemoteClosed(void) override { m_Link = nullptr; @@ -385,12 +438,12 @@ protected: /** The network link. */ cTCPLink * m_Link; - /** If true, the TLS should be started on the link before sending the request (used for https). */ - bool m_IsTls; - /** Parser of the HTTP response message. */ cHTTPMessageParser m_Parser; + /** If true, the TLS should be started on the link before sending the request (used for https). */ + bool m_IsTls; + /** Set to true if the first line contains a redirecting HTTP status code and the options specify to follow redirects. If true, and the parent request allows redirects, neither headers not the body contents are reported through the callbacks, and after the entire request is parsed, the redirect is attempted. */ @@ -475,6 +528,17 @@ void cUrlClientRequest::OnConnected(cTCPLink & a_Link) +void cUrlClientRequest::OnTlsHandshakeCompleted(void) +{ + // Notify the scheme handler and the callbacks: + m_SchemeHandler->OnTlsHandshakeCompleted(); + m_Callbacks.OnTlsHandshakeCompleted(); +} + + + + + void cUrlClientRequest::OnReceivedData(const char * a_Data, size_t a_Length) { auto handler = m_SchemeHandler; diff --git a/src/HTTP/UrlClient.h b/src/HTTP/UrlClient.h index 42086a4f1..652cc76f7 100644 --- a/src/HTTP/UrlClient.h +++ b/src/HTTP/UrlClient.h @@ -5,7 +5,10 @@ /* Options that can be set via the Options parameter to the cUrlClient calls: -"MaxRedirects": The maximum number of allowed redirects before the client refuses a redirect with an error +"MaxRedirects": The maximum number of allowed redirects before the client refuses a redirect with an error +"OwnCert": The client certificate to use, if requested by the server. Any string that can be parsed by cX509Cert. +"OwnPrivKey": The private key appropriate for OwnCert. Any string that can be parsed by cCryptoKey. +"OwnPrivKeyPassword": The password for OwnPrivKey. If not present or empty, no password is assumed. Behavior: - If a redirect is received, and redirection is allowed, the redirection is reported via OnRedirecting() callback @@ -34,8 +37,11 @@ public: class cCallbacks { public: + // Force a virtual destructor in descendants: + virtual ~cCallbacks() {} + /** Called when the TCP connection is established. */ - virtual void OnConnected(cTCPLink & a_Link) {}; + virtual void OnConnected(cTCPLink & a_Link) {} /** Called for TLS connections, when the server certificate is received. Return true to continue with the request, false to abort. @@ -43,30 +49,34 @@ public: TODO: The certificate parameter needs a representation! */ virtual bool OnCertificateReceived() { return true; } + /** Called for TLS connections, when the TLS handshake has been completed. + An empty default implementation is provided so that clients don't need to reimplement it unless they are interested in the event. */ + virtual void OnTlsHandshakeCompleted() { } + /** Called after the entire request has been sent to the remote peer. */ - virtual void OnRequestSent() {}; + virtual void OnRequestSent() {} /** Called after the first line of the response is parsed, unless the response is an allowed redirect. */ virtual void OnStatusLine(const AString & a_HttpVersion, int a_StatusCode, const AString & a_Rest) {} /** Called when a single HTTP header is received and parsed, unless the response is an allowed redirect Called once for each incoming header. */ - virtual void OnHeader(const AString & a_Key, const AString & a_Value) {}; + virtual void OnHeader(const AString & a_Key, const AString & a_Value) {} /** Called when the HTTP headers have been fully parsed, unless the response is an allowed redirect. There will be no more OnHeader() calls. */ - virtual void OnHeadersFinished() {}; + virtual void OnHeadersFinished() {} /** Called when the next fragment of the response body is received, unless the response is an allowed redirect. This can be called multiple times, as data arrives over the network. */ - virtual void OnBodyData(const void * a_Data, size_t a_Size) {}; + virtual void OnBodyData(const void * a_Data, size_t a_Size) {} /** Called after the response body has been fully reported by OnBody() calls, unless the response is an allowed redirect. There will be no more OnBody() calls. */ - virtual void OnBodyFinished() {}; + virtual void OnBodyFinished() {} /** Called when an asynchronous error is encountered. */ - virtual void OnError(const AString & a_ErrorMsg) {}; + virtual void OnError(const AString & a_ErrorMsg) {} /** Called when a redirect is to be followed. This is called even if the redirecting is prohibited by the options; in such an event, this call will be @@ -74,7 +84,7 @@ public: If a response indicates a redirect (and the request allows redirecting), the regular callbacks OnStatusLine(), OnHeader(), OnHeadersFinished(), OnBodyData() and OnBodyFinished() are not called for such a response; instead, the redirect is silently attempted. */ - virtual void OnRedirecting(const AString & a_NewLocation) {}; + virtual void OnRedirecting(const AString & a_NewLocation) {} }; diff --git a/src/OSSupport/Network.h b/src/OSSupport/Network.h index 1162d7fc6..78c5e92f0 100644 --- a/src/OSSupport/Network.h +++ b/src/OSSupport/Network.h @@ -20,6 +20,11 @@ typedef std::vector<cTCPLinkPtr> cTCPLinkPtrs; class cServerHandle; typedef SharedPtr<cServerHandle> cServerHandlePtr; typedef std::vector<cServerHandlePtr> cServerHandlePtrs; +class cCryptoKey; +typedef SharedPtr<cCryptoKey> cCryptoKeyPtr; +class cX509Cert; +typedef SharedPtr<cX509Cert> cX509CertPtr; + @@ -49,6 +54,10 @@ public: Sending data on the link is not an error, but the data won't be delivered. */ virtual void OnRemoteClosed(void) = 0; + /** Called when the TLS handshake has been completed and communication can continue regularly. + Has an empty default implementation, so that link callback descendants don't need to specify TLS handlers when they don't use TLS at all. */ + virtual void OnTlsHandshakeCompleted(void) {} + /** Called when an error is detected on the connection. */ virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) = 0; }; @@ -90,6 +99,30 @@ public: Sends the RST packet, queued outgoing and incoming data is lost. */ virtual void Close(void) = 0; + /** Starts a TLS handshake as a client connection. + If a client certificate should be used for the connection, set the certificate into a_OwnCertData and + its corresponding private key to a_OwnPrivKeyData. If both are empty, no client cert is presented. + a_OwnPrivKeyPassword is the password to be used for decoding PrivKey, empty if not passworded. + Returns empty string on success, non-empty error description on failure. */ + virtual AString StartTLSClient( + cX509CertPtr a_OwnCert, + cCryptoKeyPtr a_OwnPrivKey + ) = 0; + + /** Starts a TLS handshake as a server connection. + Set the server certificate into a_CertData and its corresponding private key to a_OwnPrivKeyData. + a_OwnPrivKeyPassword is the password to be used for decoding PrivKey, empty if not passworded. + a_StartTLSData is any data that should be pushed into the TLS before reading more data from the remote. + This is used mainly for protocols starting TLS in the middle of communication, when the TLS start command + can be received together with the TLS Client Hello message in one OnReceivedData() call, to re-queue the + Client Hello message into the TLS handshake buffer. + Returns empty string on success, non-empty error description on failure. */ + virtual AString StartTLSServer( + cX509CertPtr a_OwnCert, + cCryptoKeyPtr a_OwnPrivKey, + const AString & a_StartTLSData + ) = 0; + /** Returns the callbacks that are used. */ cCallbacksPtr GetCallbacks(void) const { return m_Callbacks; } diff --git a/src/OSSupport/TCPLinkImpl.cpp b/src/OSSupport/TCPLinkImpl.cpp index b15b6282f..df70f3f72 100644 --- a/src/OSSupport/TCPLinkImpl.cpp +++ b/src/OSSupport/TCPLinkImpl.cpp @@ -48,6 +48,9 @@ cTCPLinkImpl::cTCPLinkImpl(evutil_socket_t a_Socket, cTCPLink::cCallbacksPtr a_L cTCPLinkImpl::~cTCPLinkImpl() { + // If the TLS context still exists, free it: + m_TlsContext.reset(); + bufferevent_free(m_BufferEvent); } @@ -129,7 +132,16 @@ bool cTCPLinkImpl::Send(const void * a_Data, size_t a_Length) LOGD("%s: Cannot send data, the link is already shut down.", __FUNCTION__); return false; } - return (bufferevent_write(m_BufferEvent, a_Data, a_Length) == 0); + + // If running in TLS mode, push the data into the TLS context instead: + if (m_TlsContext != nullptr) + { + m_TlsContext->Send(a_Data, a_Length); + return true; + } + + // Send the data: + return SendRaw(a_Data, a_Length); } @@ -138,6 +150,14 @@ bool cTCPLinkImpl::Send(const void * a_Data, size_t a_Length) void cTCPLinkImpl::Shutdown(void) { + // If running in TLS mode, notify the TLS layer: + if (m_TlsContext != nullptr) + { + m_TlsContext->NotifyClose(); + m_TlsContext->ResetSelf(); + m_TlsContext.reset(); + } + // If there's no outgoing data, shutdown the socket directly: if (evbuffer_get_length(bufferevent_get_output(m_BufferEvent)) == 0) { @@ -155,6 +175,14 @@ void cTCPLinkImpl::Shutdown(void) void cTCPLinkImpl::Close(void) { + // If running in TLS mode, notify the TLS layer: + if (m_TlsContext != nullptr) + { + m_TlsContext->NotifyClose(); + m_TlsContext->ResetSelf(); + m_TlsContext.reset(); + } + // Disable all events on the socket, but keep it alive: bufferevent_disable(m_BufferEvent, EV_READ | EV_WRITE); if (m_Server == nullptr) @@ -173,18 +201,98 @@ void cTCPLinkImpl::Close(void) +AString cTCPLinkImpl::StartTLSClient( + cX509CertPtr a_OwnCert, + cCryptoKeyPtr a_OwnPrivKey +) +{ + // Check preconditions: + if (m_TlsContext != nullptr) + { + return "TLS is already active on this link"; + } + if ( + ((a_OwnCert == nullptr) && (a_OwnPrivKey != nullptr)) || + ((a_OwnCert != nullptr) && (a_OwnPrivKey != nullptr)) + ) + { + return "Either provide both the certificate and private key, or neither"; + } + + // Create the TLS context: + m_TlsContext.reset(new cLinkTlsContext(*this)); + m_TlsContext->Initialize(true); + if (a_OwnCert != nullptr) + { + m_TlsContext->SetOwnCert(a_OwnCert, a_OwnPrivKey); + } + m_TlsContext->SetSelf(cLinkTlsContextWPtr(m_TlsContext)); + + // Start the handshake: + m_TlsContext->Handshake(); + return ""; +} + + + + + +AString cTCPLinkImpl::StartTLSServer( + cX509CertPtr a_OwnCert, + cCryptoKeyPtr a_OwnPrivKey, + const AString & a_StartTLSData +) +{ + // Check preconditions: + if (m_TlsContext != nullptr) + { + return "TLS is already active on this link"; + } + if ((a_OwnCert == nullptr) || (a_OwnPrivKey == nullptr)) + { + return "Provide the server certificate and private key"; + } + + // Create the TLS context: + m_TlsContext.reset(new cLinkTlsContext(*this)); + m_TlsContext->Initialize(false); + m_TlsContext->SetOwnCert(a_OwnCert, a_OwnPrivKey); + m_TlsContext->SetSelf(cLinkTlsContextWPtr(m_TlsContext)); + + // Push the initial data: + m_TlsContext->StoreReceivedData(a_StartTLSData.data(), a_StartTLSData.size()); + + // Start the handshake: + m_TlsContext->Handshake(); + return ""; +} + + + + + void cTCPLinkImpl::ReadCallback(bufferevent * a_BufferEvent, void * a_Self) { ASSERT(a_Self != nullptr); cTCPLinkImpl * Self = static_cast<cTCPLinkImpl *>(a_Self); + ASSERT(Self->m_BufferEvent == a_BufferEvent); ASSERT(Self->m_Callbacks != nullptr); // Read all the incoming data, in 1024-byte chunks: char data[1024]; size_t length; + auto tlsContext = Self->m_TlsContext; while ((length = bufferevent_read(a_BufferEvent, data, sizeof(data))) > 0) { - Self->m_Callbacks->OnReceivedData(data, length); + if (tlsContext != nullptr) + { + ASSERT(tlsContext->IsLink(Self)); + tlsContext->StoreReceivedData(data, length); + } + else + { + Self->ReceivedCleartextData(data, length); + } } } @@ -262,6 +370,13 @@ void cTCPLinkImpl::EventCallback(bufferevent * a_BufferEvent, short a_What, void // If the connection has been closed, call the link callback and remove the connection: if (a_What & BEV_EVENT_EOF) { + // If running in TLS mode and there's data left in the TLS contect, report it: + auto tlsContext = Self->m_TlsContext; + if (tlsContext != nullptr) + { + tlsContext->FlushBuffers(); + } + Self->m_Callbacks->OnRemoteClosed(); if (Self->m_Server != nullptr) { @@ -357,6 +472,178 @@ void cTCPLinkImpl::DoActualShutdown(void) +bool cTCPLinkImpl::SendRaw(const void * a_Data, size_t a_Length) +{ + return (bufferevent_write(m_BufferEvent, a_Data, a_Length) == 0); +} + + + + + +void cTCPLinkImpl::ReceivedCleartextData(const char * a_Data, size_t a_Length) +{ + ASSERT(m_Callbacks != nullptr); + m_Callbacks->OnReceivedData(a_Data, a_Length); +} + + + + + +//////////////////////////////////////////////////////////////////////////////// +// cTCPLinkImpl::cLinkTlsContext: + +cTCPLinkImpl::cLinkTlsContext::cLinkTlsContext(cTCPLinkImpl & a_Link): + m_Link(a_Link) +{ +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::SetSelf(cLinkTlsContextWPtr a_Self) +{ + m_Self = a_Self; +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::ResetSelf(void) +{ + m_Self.reset(); +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::StoreReceivedData(const char * a_Data, size_t a_NumBytes) +{ + // Hold self alive for the duration of this function + cLinkTlsContextPtr Self(m_Self); + + m_EncryptedData.append(a_Data, a_NumBytes); + + // Try to finish a pending handshake: + TryFinishHandshaking(); + + // Flush any cleartext data that can be "received": + FlushBuffers(); +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::FlushBuffers(void) +{ + // Hold self alive for the duration of this function + cLinkTlsContextPtr Self(m_Self); + + // If the handshake didn't complete yet, bail out: + if (!HasHandshaken()) + { + return; + } + + char Buffer[1024]; + int NumBytes; + while ((NumBytes = ReadPlain(Buffer, sizeof(Buffer))) > 0) + { + m_Link.ReceivedCleartextData(Buffer, static_cast<size_t>(NumBytes)); + if (m_Self.expired()) + { + // The callback closed the SSL context, bail out + return; + } + } +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::TryFinishHandshaking(void) +{ + // Hold self alive for the duration of this function + cLinkTlsContextPtr Self(m_Self); + + // If the handshake hasn't finished yet, retry: + if (!HasHandshaken()) + { + Handshake(); + // If the handshake succeeded, write all the queued plaintext data: + if (HasHandshaken()) + { + m_Link.GetCallbacks()->OnTlsHandshakeCompleted(); + WritePlain(m_CleartextData.data(), m_CleartextData.size()); + m_CleartextData.clear(); + } + } +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::Send(const void * a_Data, size_t a_Length) +{ + // Hold self alive for the duration of this function + cLinkTlsContextPtr Self(m_Self); + + // If the handshake hasn't completed yet, queue the data: + if (!HasHandshaken()) + { + m_CleartextData.append(reinterpret_cast<const char *>(a_Data), a_Length); + TryFinishHandshaking(); + return; + } + + // The connection is all set up, write the cleartext data into the SSL context: + WritePlain(a_Data, a_Length); + FlushBuffers(); +} + + + + + +int cTCPLinkImpl::cLinkTlsContext::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) +{ + // Hold self alive for the duration of this function + cLinkTlsContextPtr Self(m_Self); + + // If there's nothing queued in the buffer, report empty buffer: + if (m_EncryptedData.empty()) + { + return POLARSSL_ERR_NET_WANT_READ; + } + + // Copy as much data as possible to the provided buffer: + size_t BytesToCopy = std::min(a_NumBytes, m_EncryptedData.size()); + memcpy(a_Buffer, m_EncryptedData.data(), BytesToCopy); + m_EncryptedData.erase(0, BytesToCopy); + return static_cast<int>(BytesToCopy); +} + + + + + +int cTCPLinkImpl::cLinkTlsContext::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) +{ + m_Link.SendRaw(a_Buffer, a_NumBytes); + return static_cast<int>(a_NumBytes); +} + + + + + //////////////////////////////////////////////////////////////////////////////// // cNetwork API: diff --git a/src/OSSupport/TCPLinkImpl.h b/src/OSSupport/TCPLinkImpl.h index bea21aeff..b54c1a2cc 100644 --- a/src/OSSupport/TCPLinkImpl.h +++ b/src/OSSupport/TCPLinkImpl.h @@ -14,6 +14,7 @@ #include "Network.h" #include <event2/event.h> #include <event2/bufferevent.h> +#include "../PolarSSL++/SslContext.h" @@ -64,9 +65,73 @@ public: virtual UInt16 GetRemotePort(void) const override { return m_RemotePort; } virtual void Shutdown(void) override; virtual void Close(void) override; + virtual AString StartTLSClient( + cX509CertPtr a_OwnCert, + cCryptoKeyPtr a_OwnPrivKey + ) override; + virtual AString StartTLSServer( + cX509CertPtr a_OwnCert, + cCryptoKeyPtr a_OwnPrivKey, + const AString & a_StartTLSData + ) override; protected: + // fwd: + class cLinkTlsContext; + typedef SharedPtr<cLinkTlsContext> cLinkTlsContextPtr; + typedef WeakPtr<cLinkTlsContext> cLinkTlsContextWPtr; + + /** Wrapper around cSslContext that is used when this link is being encrypted by SSL. */ + class cLinkTlsContext : + public cSslContext + { + cTCPLinkImpl & m_Link; + + /** Buffer for storing the incoming encrypted data until it is requested by the SSL decryptor. */ + AString m_EncryptedData; + + /** Buffer for storing the outgoing cleartext data until the link has finished handshaking. */ + AString m_CleartextData; + + /** Shared ownership of self, so that this object can keep itself alive for as long as it needs. */ + cLinkTlsContextWPtr m_Self; + + public: + cLinkTlsContext(cTCPLinkImpl & a_Link); + + /** Shares ownership of self, so that this object can keep itself alive for as long as it needs. */ + void SetSelf(cLinkTlsContextWPtr a_Self); + + /** Removes the self ownership so that we can detect the SSL closure. */ + void ResetSelf(void); + + /** Stores the specified block of data into the buffer of the data to be decrypted (incoming from remote). + Also flushes the SSL buffers by attempting to read any data through the SSL context. */ + void StoreReceivedData(const char * a_Data, size_t a_NumBytes); + + /** Tries to read any cleartext data available through the SSL, reports it in the link. */ + void FlushBuffers(void); + + /** Tries to finish handshaking the SSL. */ + void TryFinishHandshaking(void); + + /** Sends the specified cleartext data over the SSL to the remote peer. + If the handshake hasn't been completed yet, queues the data for sending when it completes. */ + void Send(const void * a_Data, size_t a_Length); + + // cSslContext overrides: + virtual int ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) override; + virtual int SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) override; + + /** Returns true if the context's associated TCP link is the same link as a_Link. */ + bool IsLink(cTCPLinkImpl * a_Link) + { + return (a_Link == &m_Link); + } + }; + + /** Callbacks to call when the connection is established. May be NULL if not used. Only used for outgoing connections (cNetwork::Connect()). */ cNetwork::cConnectCallbacksPtr m_ConnectCallbacks; @@ -99,6 +164,10 @@ protected: data is sent to the OS TCP stack, the socket gets shut down. */ bool m_ShouldShutdown; + /** The SSL context used for encryption, if this link uses SSL. + If valid, the link uses encryption through this context. */ + cLinkTlsContextPtr m_TlsContext; + /** Creates a new link to be queued to connect to a specified host:port. Used for outgoing connections created using cNetwork::Connect(). @@ -127,6 +196,12 @@ protected: /** Calls shutdown on the link and disables LibEvent writing. Called after all data from LibEvent buffers is sent to the OS TCP stack and shutdown() has been called before. */ void DoActualShutdown(void); + + /** Sends the data directly to the socket (without the optional TLS). */ + bool SendRaw(const void * a_Data, size_t a_Length); + + /** Called by the TLS when it has decoded a piece of incoming cleartext data from the socket. */ + void ReceivedCleartextData(const char * a_Data, size_t a_Length); }; |