diff options
Diffstat (limited to '')
-rw-r--r-- | src/OSSupport/TCPLinkImpl.cpp | 291 |
1 files changed, 289 insertions, 2 deletions
diff --git a/src/OSSupport/TCPLinkImpl.cpp b/src/OSSupport/TCPLinkImpl.cpp index b15b6282f..df70f3f72 100644 --- a/src/OSSupport/TCPLinkImpl.cpp +++ b/src/OSSupport/TCPLinkImpl.cpp @@ -48,6 +48,9 @@ cTCPLinkImpl::cTCPLinkImpl(evutil_socket_t a_Socket, cTCPLink::cCallbacksPtr a_L cTCPLinkImpl::~cTCPLinkImpl() { + // If the TLS context still exists, free it: + m_TlsContext.reset(); + bufferevent_free(m_BufferEvent); } @@ -129,7 +132,16 @@ bool cTCPLinkImpl::Send(const void * a_Data, size_t a_Length) LOGD("%s: Cannot send data, the link is already shut down.", __FUNCTION__); return false; } - return (bufferevent_write(m_BufferEvent, a_Data, a_Length) == 0); + + // If running in TLS mode, push the data into the TLS context instead: + if (m_TlsContext != nullptr) + { + m_TlsContext->Send(a_Data, a_Length); + return true; + } + + // Send the data: + return SendRaw(a_Data, a_Length); } @@ -138,6 +150,14 @@ bool cTCPLinkImpl::Send(const void * a_Data, size_t a_Length) void cTCPLinkImpl::Shutdown(void) { + // If running in TLS mode, notify the TLS layer: + if (m_TlsContext != nullptr) + { + m_TlsContext->NotifyClose(); + m_TlsContext->ResetSelf(); + m_TlsContext.reset(); + } + // If there's no outgoing data, shutdown the socket directly: if (evbuffer_get_length(bufferevent_get_output(m_BufferEvent)) == 0) { @@ -155,6 +175,14 @@ void cTCPLinkImpl::Shutdown(void) void cTCPLinkImpl::Close(void) { + // If running in TLS mode, notify the TLS layer: + if (m_TlsContext != nullptr) + { + m_TlsContext->NotifyClose(); + m_TlsContext->ResetSelf(); + m_TlsContext.reset(); + } + // Disable all events on the socket, but keep it alive: bufferevent_disable(m_BufferEvent, EV_READ | EV_WRITE); if (m_Server == nullptr) @@ -173,18 +201,98 @@ void cTCPLinkImpl::Close(void) +AString cTCPLinkImpl::StartTLSClient( + cX509CertPtr a_OwnCert, + cCryptoKeyPtr a_OwnPrivKey +) +{ + // Check preconditions: + if (m_TlsContext != nullptr) + { + return "TLS is already active on this link"; + } + if ( + ((a_OwnCert == nullptr) && (a_OwnPrivKey != nullptr)) || + ((a_OwnCert != nullptr) && (a_OwnPrivKey != nullptr)) + ) + { + return "Either provide both the certificate and private key, or neither"; + } + + // Create the TLS context: + m_TlsContext.reset(new cLinkTlsContext(*this)); + m_TlsContext->Initialize(true); + if (a_OwnCert != nullptr) + { + m_TlsContext->SetOwnCert(a_OwnCert, a_OwnPrivKey); + } + m_TlsContext->SetSelf(cLinkTlsContextWPtr(m_TlsContext)); + + // Start the handshake: + m_TlsContext->Handshake(); + return ""; +} + + + + + +AString cTCPLinkImpl::StartTLSServer( + cX509CertPtr a_OwnCert, + cCryptoKeyPtr a_OwnPrivKey, + const AString & a_StartTLSData +) +{ + // Check preconditions: + if (m_TlsContext != nullptr) + { + return "TLS is already active on this link"; + } + if ((a_OwnCert == nullptr) || (a_OwnPrivKey == nullptr)) + { + return "Provide the server certificate and private key"; + } + + // Create the TLS context: + m_TlsContext.reset(new cLinkTlsContext(*this)); + m_TlsContext->Initialize(false); + m_TlsContext->SetOwnCert(a_OwnCert, a_OwnPrivKey); + m_TlsContext->SetSelf(cLinkTlsContextWPtr(m_TlsContext)); + + // Push the initial data: + m_TlsContext->StoreReceivedData(a_StartTLSData.data(), a_StartTLSData.size()); + + // Start the handshake: + m_TlsContext->Handshake(); + return ""; +} + + + + + void cTCPLinkImpl::ReadCallback(bufferevent * a_BufferEvent, void * a_Self) { ASSERT(a_Self != nullptr); cTCPLinkImpl * Self = static_cast<cTCPLinkImpl *>(a_Self); + ASSERT(Self->m_BufferEvent == a_BufferEvent); ASSERT(Self->m_Callbacks != nullptr); // Read all the incoming data, in 1024-byte chunks: char data[1024]; size_t length; + auto tlsContext = Self->m_TlsContext; while ((length = bufferevent_read(a_BufferEvent, data, sizeof(data))) > 0) { - Self->m_Callbacks->OnReceivedData(data, length); + if (tlsContext != nullptr) + { + ASSERT(tlsContext->IsLink(Self)); + tlsContext->StoreReceivedData(data, length); + } + else + { + Self->ReceivedCleartextData(data, length); + } } } @@ -262,6 +370,13 @@ void cTCPLinkImpl::EventCallback(bufferevent * a_BufferEvent, short a_What, void // If the connection has been closed, call the link callback and remove the connection: if (a_What & BEV_EVENT_EOF) { + // If running in TLS mode and there's data left in the TLS contect, report it: + auto tlsContext = Self->m_TlsContext; + if (tlsContext != nullptr) + { + tlsContext->FlushBuffers(); + } + Self->m_Callbacks->OnRemoteClosed(); if (Self->m_Server != nullptr) { @@ -357,6 +472,178 @@ void cTCPLinkImpl::DoActualShutdown(void) +bool cTCPLinkImpl::SendRaw(const void * a_Data, size_t a_Length) +{ + return (bufferevent_write(m_BufferEvent, a_Data, a_Length) == 0); +} + + + + + +void cTCPLinkImpl::ReceivedCleartextData(const char * a_Data, size_t a_Length) +{ + ASSERT(m_Callbacks != nullptr); + m_Callbacks->OnReceivedData(a_Data, a_Length); +} + + + + + +//////////////////////////////////////////////////////////////////////////////// +// cTCPLinkImpl::cLinkTlsContext: + +cTCPLinkImpl::cLinkTlsContext::cLinkTlsContext(cTCPLinkImpl & a_Link): + m_Link(a_Link) +{ +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::SetSelf(cLinkTlsContextWPtr a_Self) +{ + m_Self = a_Self; +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::ResetSelf(void) +{ + m_Self.reset(); +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::StoreReceivedData(const char * a_Data, size_t a_NumBytes) +{ + // Hold self alive for the duration of this function + cLinkTlsContextPtr Self(m_Self); + + m_EncryptedData.append(a_Data, a_NumBytes); + + // Try to finish a pending handshake: + TryFinishHandshaking(); + + // Flush any cleartext data that can be "received": + FlushBuffers(); +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::FlushBuffers(void) +{ + // Hold self alive for the duration of this function + cLinkTlsContextPtr Self(m_Self); + + // If the handshake didn't complete yet, bail out: + if (!HasHandshaken()) + { + return; + } + + char Buffer[1024]; + int NumBytes; + while ((NumBytes = ReadPlain(Buffer, sizeof(Buffer))) > 0) + { + m_Link.ReceivedCleartextData(Buffer, static_cast<size_t>(NumBytes)); + if (m_Self.expired()) + { + // The callback closed the SSL context, bail out + return; + } + } +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::TryFinishHandshaking(void) +{ + // Hold self alive for the duration of this function + cLinkTlsContextPtr Self(m_Self); + + // If the handshake hasn't finished yet, retry: + if (!HasHandshaken()) + { + Handshake(); + // If the handshake succeeded, write all the queued plaintext data: + if (HasHandshaken()) + { + m_Link.GetCallbacks()->OnTlsHandshakeCompleted(); + WritePlain(m_CleartextData.data(), m_CleartextData.size()); + m_CleartextData.clear(); + } + } +} + + + + + +void cTCPLinkImpl::cLinkTlsContext::Send(const void * a_Data, size_t a_Length) +{ + // Hold self alive for the duration of this function + cLinkTlsContextPtr Self(m_Self); + + // If the handshake hasn't completed yet, queue the data: + if (!HasHandshaken()) + { + m_CleartextData.append(reinterpret_cast<const char *>(a_Data), a_Length); + TryFinishHandshaking(); + return; + } + + // The connection is all set up, write the cleartext data into the SSL context: + WritePlain(a_Data, a_Length); + FlushBuffers(); +} + + + + + +int cTCPLinkImpl::cLinkTlsContext::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) +{ + // Hold self alive for the duration of this function + cLinkTlsContextPtr Self(m_Self); + + // If there's nothing queued in the buffer, report empty buffer: + if (m_EncryptedData.empty()) + { + return POLARSSL_ERR_NET_WANT_READ; + } + + // Copy as much data as possible to the provided buffer: + size_t BytesToCopy = std::min(a_NumBytes, m_EncryptedData.size()); + memcpy(a_Buffer, m_EncryptedData.data(), BytesToCopy); + m_EncryptedData.erase(0, BytesToCopy); + return static_cast<int>(BytesToCopy); +} + + + + + +int cTCPLinkImpl::cLinkTlsContext::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) +{ + m_Link.SendRaw(a_Buffer, a_NumBytes); + return static_cast<int>(a_NumBytes); +} + + + + + //////////////////////////////////////////////////////////////////////////////// // cNetwork API: |