summaryrefslogtreecommitdiffstats
path: root/src/PolarSSL++
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/PolarSSL++/BlockingSslClientSocket.cpp176
-rw-r--r--src/PolarSSL++/BlockingSslClientSocket.h35
-rw-r--r--src/PolarSSL++/SslContext.cpp2
3 files changed, 198 insertions, 15 deletions
diff --git a/src/PolarSSL++/BlockingSslClientSocket.cpp b/src/PolarSSL++/BlockingSslClientSocket.cpp
index 59e1281ac..821125b31 100644
--- a/src/PolarSSL++/BlockingSslClientSocket.cpp
+++ b/src/PolarSSL++/BlockingSslClientSocket.cpp
@@ -10,6 +10,80 @@
+////////////////////////////////////////////////////////////////////////////////
+// cBlockingSslClientSocketConnectCallbacks:
+
+class cBlockingSslClientSocketConnectCallbacks:
+ public cNetwork::cConnectCallbacks
+{
+ /** The socket object that is using this instance of the callbacks. */
+ cBlockingSslClientSocket & m_Socket;
+
+ virtual void OnConnected(cTCPLink & a_Link) override
+ {
+ m_Socket.OnConnected();
+ }
+
+ virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override
+ {
+ m_Socket.OnConnectError(a_ErrorMsg);
+ }
+
+public:
+ cBlockingSslClientSocketConnectCallbacks(cBlockingSslClientSocket & a_Socket):
+ m_Socket(a_Socket)
+ {
+ }
+};
+
+
+
+
+
+////////////////////////////////////////////////////////////////////////////////
+// cBlockingSslClientSocketLinkCallbacks:
+
+class cBlockingSslClientSocketLinkCallbacks:
+ public cTCPLink::cCallbacks
+{
+ cBlockingSslClientSocket & m_Socket;
+
+ virtual void OnLinkCreated(cTCPLinkPtr a_Link) override
+ {
+ m_Socket.SetLink(a_Link);
+ }
+
+
+ virtual void OnReceivedData(const char * a_Data, size_t a_Length)
+ {
+ m_Socket.OnReceivedData(a_Data, a_Length);
+ }
+
+
+ virtual void OnRemoteClosed(void)
+ {
+ m_Socket.OnDisconnected();
+ }
+
+
+ virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg)
+ {
+ m_Socket.OnDisconnected();
+ }
+public:
+ cBlockingSslClientSocketLinkCallbacks(cBlockingSslClientSocket & a_Socket):
+ m_Socket(a_Socket)
+ {
+ }
+};
+
+
+
+
+
+////////////////////////////////////////////////////////////////////////////////
+// cBlockingSslClientSocket:
+
cBlockingSslClientSocket::cBlockingSslClientSocket(void) :
m_Ssl(*this),
m_IsConnected(false)
@@ -32,10 +106,19 @@ bool cBlockingSslClientSocket::Connect(const AString & a_ServerName, UInt16 a_Po
}
// Connect the underlying socket:
- m_Socket.CreateSocket(cSocket::IPv4);
- if (!m_Socket.ConnectIPv4(a_ServerName.c_str(), a_Port))
+ m_ServerName = a_ServerName;
+ if (!cNetwork::Connect(a_ServerName, a_Port,
+ std::make_shared<cBlockingSslClientSocketConnectCallbacks>(*this),
+ std::make_shared<cBlockingSslClientSocketLinkCallbacks>(*this))
+ )
+ {
+ return false;
+ }
+
+ // Wait for the connection to succeed or fail:
+ m_Event.Wait();
+ if (!m_IsConnected)
{
- Printf(m_LastErrorText, "Socket connect failed: %s", m_Socket.GetLastErrorString().c_str());
return false;
}
@@ -102,7 +185,7 @@ bool cBlockingSslClientSocket::Send(const void * a_Data, size_t a_NumBytes)
ASSERT(m_IsConnected);
// Keep sending the data until all of it is sent:
- const char * Data = (const char *)a_Data;
+ const char * Data = reinterpret_cast<const char *>(a_Data);
size_t NumBytes = a_NumBytes;
for (;;)
{
@@ -156,7 +239,8 @@ void cBlockingSslClientSocket::Disconnect(void)
}
m_Ssl.NotifyClose();
- m_Socket.CloseSocket();
+ m_Socket->Close();
+ m_Socket.reset();
m_IsConnected = false;
}
@@ -166,13 +250,25 @@ void cBlockingSslClientSocket::Disconnect(void)
int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes)
{
- int res = m_Socket.Receive((char *)a_Buffer, a_NumBytes, 0);
- if (res < 0)
+ // Wait for any incoming data, if there is none:
+ cCSLock Lock(m_CSIncomingData);
+ while (m_IsConnected && m_IncomingData.empty())
+ {
+ cCSUnlock Unlock(Lock);
+ m_Event.Wait();
+ }
+
+ // If we got disconnected, report an error after processing all data:
+ if (!m_IsConnected && m_IncomingData.empty())
{
- // PolarSSL's net routines distinguish between connection reset and general failure, we don't need to
return POLARSSL_ERR_NET_RECV_FAILED;
}
- return res;
+
+ // Copy the data from the incoming buffer into the specified space:
+ size_t NumToCopy = std::min(a_NumBytes, m_IncomingData.size());
+ memcpy(a_Buffer, m_IncomingData.data(), NumToCopy);
+ m_IncomingData.erase(0, NumToCopy);
+ return static_cast<int>(NumToCopy);
}
@@ -181,13 +277,69 @@ int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t
int cBlockingSslClientSocket::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes)
{
- int res = m_Socket.Send((const char *)a_Buffer, a_NumBytes);
- if (res < 0)
+ cTCPLinkPtr Socket(m_Socket); // Make a copy so that multiple threads don't race on deleting the socket.
+ if (Socket == nullptr)
+ {
+ return POLARSSL_ERR_NET_SEND_FAILED;
+ }
+ if (!Socket->Send(a_Buffer, a_NumBytes))
{
// PolarSSL's net routines distinguish between connection reset and general failure, we don't need to
return POLARSSL_ERR_NET_SEND_FAILED;
}
- return res;
+ return static_cast<int>(a_NumBytes);
+}
+
+
+
+
+void cBlockingSslClientSocket::OnConnected(void)
+{
+ m_IsConnected = true;
+ m_Event.Set();
+}
+
+
+
+
+
+void cBlockingSslClientSocket::OnConnectError(const AString & a_ErrorMsg)
+{
+ LOG("Cannot connect to %s: %s", m_ServerName.c_str(), a_ErrorMsg.c_str());
+ m_Event.Set();
+}
+
+
+
+
+
+void cBlockingSslClientSocket::OnReceivedData(const char * a_Data, size_t a_Size)
+{
+ {
+ cCSLock Lock(m_CSIncomingData);
+ m_IncomingData.append(a_Data, a_Size);
+ }
+ m_Event.Set();
+}
+
+
+
+
+
+void cBlockingSslClientSocket::SetLink(cTCPLinkPtr a_Link)
+{
+ m_Socket = a_Link;
+}
+
+
+
+
+
+void cBlockingSslClientSocket::OnDisconnected(void)
+{
+ m_Socket.reset();
+ m_IsConnected = false;
+ m_Event.Set();
}
diff --git a/src/PolarSSL++/BlockingSslClientSocket.h b/src/PolarSSL++/BlockingSslClientSocket.h
index 7af897582..319e82bf2 100644
--- a/src/PolarSSL++/BlockingSslClientSocket.h
+++ b/src/PolarSSL++/BlockingSslClientSocket.h
@@ -9,8 +9,8 @@
#pragma once
+#include "OSSupport/Network.h"
#include "CallbackSslContext.h"
-#include "../OSSupport/Socket.h"
@@ -51,25 +51,56 @@ public:
const AString & GetLastErrorText(void) const { return m_LastErrorText; }
protected:
+ friend class cBlockingSslClientSocketConnectCallbacks;
+ friend class cBlockingSslClientSocketLinkCallbacks;
+
/** The SSL context used for the socket */
cCallbackSslContext m_Ssl;
/** The underlying socket to the SSL server */
- cSocket m_Socket;
+ cTCPLinkPtr m_Socket;
+
+ /** The object used to signal state changes in the socket (the cause of the blocking). */
+ cEvent m_Event;
/** The trusted CA root cert store, if we are to verify the cert strictly. Set by SetTrustedRootCertsFromString(). */
cX509CertPtr m_CACerts;
/** The expected SSL peer's name, if we are to verify the cert strictly. Set by SetTrustedRootCertsFromString(). */
AString m_ExpectedPeerName;
+
+ /** The hostname to which the socket is connecting (stored for error reporting). */
+ AString m_ServerName;
/** Text of the last error that has occurred. */
AString m_LastErrorText;
/** Set to true if the connection established successfully. */
bool m_IsConnected;
+
+ /** Protects m_IncomingData against multithreaded access. */
+ cCriticalSection m_CSIncomingData;
+
+ /** Buffer for the data incoming on the network socket.
+ Protected by m_CSIncomingData. */
+ AString m_IncomingData;
+ /** Called when the connection is established successfully. */
+ void OnConnected(void);
+
+ /** Called when an error occurs while connecting the socket. */
+ void OnConnectError(const AString & a_ErrorMsg);
+
+ /** Called when there's incoming data from the socket. */
+ void OnReceivedData(const char * a_Data, size_t a_Size);
+
+ /** Called when the link for the connection is created. */
+ void SetLink(cTCPLinkPtr a_Link);
+
+ /** Called when the link is disconnected, either gracefully or by an error. */
+ void OnDisconnected(void);
+
// cCallbackSslContext::cDataCallbacks overrides:
virtual int ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) override;
virtual int SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) override;
diff --git a/src/PolarSSL++/SslContext.cpp b/src/PolarSSL++/SslContext.cpp
index 902267f90..66dfefc65 100644
--- a/src/PolarSSL++/SslContext.cpp
+++ b/src/PolarSSL++/SslContext.cpp
@@ -70,7 +70,7 @@ int cSslContext::Initialize(bool a_IsClient, const SharedPtr<cCtrDrbgContext> &
// so they're disabled until someone needs them
ssl_set_dbg(&m_Ssl, &SSLDebugMessage, this);
ssl_set_verify(&m_Ssl, &SSLVerifyCert, this);
- */
+ //*/
/*
// Set ciphersuite to the easiest one to decode, so that the connection can be wireshark-decoded: