summaryrefslogtreecommitdiffstats
path: root/src/mbedTLS++
diff options
context:
space:
mode:
Diffstat (limited to 'src/mbedTLS++')
-rw-r--r--src/mbedTLS++/AesCfb128Decryptor.cpp55
-rw-r--r--src/mbedTLS++/AesCfb128Decryptor.h48
-rw-r--r--src/mbedTLS++/AesCfb128Encryptor.cpp55
-rw-r--r--src/mbedTLS++/AesCfb128Encryptor.h47
-rw-r--r--src/mbedTLS++/BlockingSslClientSocket.cpp377
-rw-r--r--src/mbedTLS++/BlockingSslClientSocket.h119
-rw-r--r--src/mbedTLS++/BufferedSslContext.cpp93
-rw-r--r--src/mbedTLS++/BufferedSslContext.h53
-rw-r--r--src/mbedTLS++/CMakeLists.txt44
-rw-r--r--src/mbedTLS++/CallbackSslContext.cpp60
-rw-r--r--src/mbedTLS++/CallbackSslContext.h65
-rw-r--r--src/mbedTLS++/CryptoKey.cpp149
-rw-r--r--src/mbedTLS++/CryptoKey.h76
-rw-r--r--src/mbedTLS++/CtrDrbgContext.cpp51
-rw-r--r--src/mbedTLS++/CtrDrbgContext.h63
-rw-r--r--src/mbedTLS++/EntropyContext.cpp29
-rw-r--r--src/mbedTLS++/EntropyContext.h31
-rw-r--r--src/mbedTLS++/ErrorCodes.h18
-rw-r--r--src/mbedTLS++/RsaPrivateKey.cpp174
-rw-r--r--src/mbedTLS++/RsaPrivateKey.h67
-rw-r--r--src/mbedTLS++/Sha1Checksum.cpp138
-rw-r--r--src/mbedTLS++/Sha1Checksum.h52
-rw-r--r--src/mbedTLS++/SslConfig.cpp287
-rw-r--r--src/mbedTLS++/SslConfig.h93
-rw-r--r--src/mbedTLS++/SslContext.cpp157
-rw-r--r--src/mbedTLS++/SslContext.h124
-rw-r--r--src/mbedTLS++/X509Cert.cpp38
-rw-r--r--src/mbedTLS++/X509Cert.h41
28 files changed, 2604 insertions, 0 deletions
diff --git a/src/mbedTLS++/AesCfb128Decryptor.cpp b/src/mbedTLS++/AesCfb128Decryptor.cpp
new file mode 100644
index 000000000..78a7ab9c5
--- /dev/null
+++ b/src/mbedTLS++/AesCfb128Decryptor.cpp
@@ -0,0 +1,55 @@
+
+// AesCfb128Decryptor.cpp
+
+// Implements the cAesCfb128Decryptor class decrypting data using AES CFB-128
+
+#include "Globals.h"
+#include "AesCfb128Decryptor.h"
+
+
+
+
+
+cAesCfb128Decryptor::cAesCfb128Decryptor(void):
+ m_IsValid(false)
+{
+ mbedtls_aes_init(&m_Aes);
+}
+
+
+
+
+
+cAesCfb128Decryptor::~cAesCfb128Decryptor()
+{
+ // Clear the leftover in-memory data, so that they can't be accessed by a backdoor
+ mbedtls_aes_free(&m_Aes);
+}
+
+
+
+
+
+void cAesCfb128Decryptor::Init(const Byte a_Key[16], const Byte a_IV[16])
+{
+ ASSERT(!IsValid()); // Cannot Init twice
+
+ memcpy(m_IV, a_IV, 16);
+ mbedtls_aes_setkey_enc(&m_Aes, a_Key, 128);
+ m_IsValid = true;
+}
+
+
+
+
+
+void cAesCfb128Decryptor::ProcessData(Byte * a_DecryptedOut, const Byte * a_EncryptedIn, size_t a_Length)
+{
+ ASSERT(IsValid()); // Must Init() first
+ mbedtls_aes_crypt_cfb8(&m_Aes, MBEDTLS_AES_DECRYPT, a_Length, m_IV, a_EncryptedIn, a_DecryptedOut);
+}
+
+
+
+
+
diff --git a/src/mbedTLS++/AesCfb128Decryptor.h b/src/mbedTLS++/AesCfb128Decryptor.h
new file mode 100644
index 000000000..54c5536ea
--- /dev/null
+++ b/src/mbedTLS++/AesCfb128Decryptor.h
@@ -0,0 +1,48 @@
+
+// AesCfb128Decryptor.h
+
+// Declares the cAesCfb128Decryptor class decrypting data using AES CFB-128
+
+
+
+
+
+#pragma once
+
+#include "mbedtls/aes.h"
+
+
+
+
+
+/** Decrypts data using the AES / CFB 128 algorithm */
+class cAesCfb128Decryptor
+{
+public:
+
+ cAesCfb128Decryptor(void);
+ ~cAesCfb128Decryptor();
+
+ /** Initializes the decryptor with the specified Key / IV */
+ void Init(const Byte a_Key[16], const Byte a_IV[16]);
+
+ /** Decrypts a_Length bytes of the encrypted data; produces a_Length output bytes */
+ void ProcessData(Byte * a_DecryptedOut, const Byte * a_EncryptedIn, size_t a_Length);
+
+ /** Returns true if the object has been initialized with the Key / IV */
+ bool IsValid(void) const { return m_IsValid; }
+
+protected:
+ mbedtls_aes_context m_Aes;
+
+ /** The InitialVector, used by the CFB mode decryption */
+ Byte m_IV[16];
+
+ /** Indicates whether the object has been initialized with the Key / IV */
+ bool m_IsValid;
+} ;
+
+
+
+
+
diff --git a/src/mbedTLS++/AesCfb128Encryptor.cpp b/src/mbedTLS++/AesCfb128Encryptor.cpp
new file mode 100644
index 000000000..11582fc19
--- /dev/null
+++ b/src/mbedTLS++/AesCfb128Encryptor.cpp
@@ -0,0 +1,55 @@
+
+// AesCfb128Encryptor.cpp
+
+// Implements the cAesCfb128Encryptor class encrypting data using AES CFB-128
+
+#include "Globals.h"
+#include "AesCfb128Encryptor.h"
+
+
+
+
+
+cAesCfb128Encryptor::cAesCfb128Encryptor(void):
+ m_IsValid(false)
+{
+ mbedtls_aes_init(&m_Aes);
+}
+
+
+
+
+
+cAesCfb128Encryptor::~cAesCfb128Encryptor()
+{
+ // Clear the leftover in-memory data, so that they can't be accessed by a backdoor
+ mbedtls_aes_free(&m_Aes);
+}
+
+
+
+
+
+void cAesCfb128Encryptor::Init(const Byte a_Key[16], const Byte a_IV[16])
+{
+ ASSERT(!IsValid()); // Cannot Init twice
+
+ memcpy(m_IV, a_IV, 16);
+ mbedtls_aes_setkey_enc(&m_Aes, a_Key, 128);
+ m_IsValid = true;
+}
+
+
+
+
+
+void cAesCfb128Encryptor::ProcessData(Byte * a_EncryptedOut, const Byte * a_PlainIn, size_t a_Length)
+{
+ ASSERT(IsValid()); // Must Init() first
+ mbedtls_aes_crypt_cfb8(&m_Aes, MBEDTLS_AES_ENCRYPT, a_Length, m_IV, a_PlainIn, a_EncryptedOut);
+}
+
+
+
+
+
diff --git a/src/mbedTLS++/AesCfb128Encryptor.h b/src/mbedTLS++/AesCfb128Encryptor.h
new file mode 100644
index 000000000..6bfa6b5c9
--- /dev/null
+++ b/src/mbedTLS++/AesCfb128Encryptor.h
@@ -0,0 +1,47 @@
+
+// AesCfb128Encryptor.h
+
+// Declares the cAesCfb128Encryptor class encrypting data using AES CFB-128
+
+
+
+
+
+#pragma once
+
+#include "mbedtls/aes.h"
+
+
+
+
+
+/** Encrypts data using the AES / CFB (128) algorithm */
+class cAesCfb128Encryptor
+{
+public:
+ cAesCfb128Encryptor(void);
+ ~cAesCfb128Encryptor();
+
+ /** Initializes the decryptor with the specified Key / IV */
+ void Init(const Byte a_Key[16], const Byte a_IV[16]);
+
+ /** Encrypts a_Length bytes of the plain data; produces a_Length output bytes */
+ void ProcessData(Byte * a_EncryptedOut, const Byte * a_PlainIn, size_t a_Length);
+
+ /** Returns true if the object has been initialized with the Key / IV */
+ bool IsValid(void) const { return m_IsValid; }
+
+protected:
+ mbedtls_aes_context m_Aes;
+
+ /** The InitialVector, used by the CFB mode encryption */
+ Byte m_IV[16];
+
+ /** Indicates whether the object has been initialized with the Key / IV */
+ bool m_IsValid;
+} ;
+
+
+
+
+
diff --git a/src/mbedTLS++/BlockingSslClientSocket.cpp b/src/mbedTLS++/BlockingSslClientSocket.cpp
new file mode 100644
index 000000000..6f765f607
--- /dev/null
+++ b/src/mbedTLS++/BlockingSslClientSocket.cpp
@@ -0,0 +1,377 @@
+
+// BlockingSslClientSocket.cpp
+
+// Implements the cBlockingSslClientSocket class representing a blocking TCP socket with client SSL encryption over it
+
+#include "Globals.h"
+#include "BlockingSslClientSocket.h"
+
+
+
+
+
+////////////////////////////////////////////////////////////////////////////////
+// cBlockingSslClientSocketConnectCallbacks:
+
+class cBlockingSslClientSocketConnectCallbacks:
+ public cNetwork::cConnectCallbacks
+{
+ /** The socket object that is using this instance of the callbacks. */
+ cBlockingSslClientSocket & m_Socket;
+
+ virtual void OnConnected(cTCPLink & a_Link) override
+ {
+ m_Socket.OnConnected();
+ }
+
+ virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override
+ {
+ m_Socket.OnConnectError(a_ErrorMsg);
+ }
+
+public:
+ cBlockingSslClientSocketConnectCallbacks(cBlockingSslClientSocket & a_Socket):
+ m_Socket(a_Socket)
+ {
+ }
+};
+
+
+
+
+
+////////////////////////////////////////////////////////////////////////////////
+// cBlockingSslClientSocketLinkCallbacks:
+
+class cBlockingSslClientSocketLinkCallbacks:
+ public cTCPLink::cCallbacks
+{
+ cBlockingSslClientSocket & m_Socket;
+
+ virtual void OnLinkCreated(cTCPLinkPtr a_Link) override
+ {
+ m_Socket.SetLink(a_Link);
+ }
+
+
+ virtual void OnReceivedData(const char * a_Data, size_t a_Length) override
+ {
+ m_Socket.OnReceivedData(a_Data, a_Length);
+ }
+
+
+ virtual void OnRemoteClosed(void) override
+ {
+ m_Socket.OnDisconnected();
+ }
+
+
+ virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override
+ {
+ m_Socket.OnDisconnected();
+ }
+
+public:
+
+ cBlockingSslClientSocketLinkCallbacks(cBlockingSslClientSocket & a_Socket):
+ m_Socket(a_Socket)
+ {
+ }
+};
+
+
+
+
+
+////////////////////////////////////////////////////////////////////////////////
+// cBlockingSslClientSocket:
+
+cBlockingSslClientSocket::cBlockingSslClientSocket(void) :
+ m_Ssl(*this),
+ m_IsConnected(false)
+{
+ // Nothing needed yet
+}
+
+
+
+
+
+bool cBlockingSslClientSocket::Connect(const AString & a_ServerName, UInt16 a_Port)
+{
+ // If already connected, report an error:
+ if (m_IsConnected)
+ {
+ // TODO: Handle this better - if connected to the same server and port, and the socket is alive, return success
+ m_LastErrorText = "Already connected";
+ return false;
+ }
+
+ // Connect the underlying socket:
+ m_ServerName = a_ServerName;
+ if (!cNetwork::Connect(a_ServerName, a_Port,
+ std::make_shared<cBlockingSslClientSocketConnectCallbacks>(*this),
+ std::make_shared<cBlockingSslClientSocketLinkCallbacks>(*this))
+ )
+ {
+ return false;
+ }
+
+ // Wait for the connection to succeed or fail:
+ m_Event.Wait();
+ if (!m_IsConnected)
+ {
+ return false;
+ }
+
+ // Initialize the SSL:
+ int ret = 0;
+ if (m_Config != nullptr)
+ {
+ ret = m_Ssl.Initialize(m_Config);
+ }
+ else
+ {
+ ret = m_Ssl.Initialize(true);
+ }
+
+ if (ret != 0)
+ {
+ Printf(m_LastErrorText, "SSL initialization failed: -0x%x", -ret);
+ return false;
+ }
+
+ // If we have been assigned a trusted CA root cert store, push it into the SSL context:
+ if (!m_ExpectedPeerName.empty())
+ {
+ m_Ssl.SetExpectedPeerName(m_ExpectedPeerName);
+ }
+
+ ret = m_Ssl.Handshake();
+ if (ret != 0)
+ {
+ Printf(m_LastErrorText, "SSL handshake failed: -0x%x", -ret);
+ return false;
+ }
+
+ return true;
+}
+
+
+
+
+
+
+void cBlockingSslClientSocket::SetExpectedPeerName(AString a_ExpectedPeerName)
+{
+ ASSERT(!m_IsConnected); // Must be called before connect
+
+ // Warn if used multiple times, but don't signal an error:
+ if (!m_ExpectedPeerName.empty())
+ {
+ LOGWARNING(
+ "SSL: Trying to set multiple expected peer names, only the last one will be used. Name: %s",
+ a_ExpectedPeerName.c_str()
+ );
+ }
+
+ m_ExpectedPeerName = std::move(a_ExpectedPeerName);
+}
+
+
+
+
+
+void cBlockingSslClientSocket::SetSslConfig(std::shared_ptr<const cSslConfig> a_Config)
+{
+ ASSERT(!m_IsConnected); // Must be called before connect
+
+ // Warn if used multiple times, but don't signal an error:
+ if (m_Config != nullptr)
+ {
+ LOGWARNING("SSL: Trying to set multiple configurations, only the last one will be used.");
+ }
+
+ m_Config = std::move(a_Config);
+}
+
+
+
+
+
+bool cBlockingSslClientSocket::Send(const void * a_Data, size_t a_NumBytes)
+{
+ if (!m_IsConnected)
+ {
+ m_LastErrorText = "Socket is closed";
+ return false;
+ }
+
+ // Keep sending the data until all of it is sent:
+ const char * Data = reinterpret_cast<const char *>(a_Data);
+ size_t NumBytes = a_NumBytes;
+ for (;;)
+ {
+ int res = m_Ssl.WritePlain(Data, a_NumBytes);
+ if (res < 0)
+ {
+ ASSERT(res != MBEDTLS_ERR_SSL_WANT_READ); // This should never happen with callback-based SSL
+ ASSERT(res != MBEDTLS_ERR_SSL_WANT_WRITE); // This should never happen with callback-based SSL
+ Printf(m_LastErrorText, "Data cannot be written to SSL context: -0x%x", -res);
+ return false;
+ }
+ else
+ {
+ Data += res;
+ NumBytes -= static_cast<size_t>(res);
+ if (NumBytes == 0)
+ {
+ return true;
+ }
+ }
+ }
+}
+
+
+
+
+
+
+int cBlockingSslClientSocket::Receive(void * a_Data, size_t a_MaxBytes)
+{
+ // Even if m_IsConnected is false (socket disconnected), the SSL context may have more data in the queue
+ int res = m_Ssl.ReadPlain(a_Data, a_MaxBytes);
+ if (res < 0)
+ {
+ Printf(m_LastErrorText, "Data cannot be read form SSL context: -0x%x", -res);
+ }
+ return res;
+}
+
+
+
+
+
+void cBlockingSslClientSocket::Disconnect(void)
+{
+ // Ignore if not connected
+ if (!m_IsConnected)
+ {
+ return;
+ }
+
+ m_Ssl.NotifyClose();
+ m_IsConnected = false;
+
+ // Grab a copy of the socket so that we know it doesn't change under our hands:
+ auto socket = m_Socket;
+ if (socket != nullptr)
+ {
+ socket->Close();
+ }
+
+ m_Socket.reset();
+}
+
+
+
+
+
+int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes)
+{
+ // Wait for any incoming data, if there is none:
+ cCSLock Lock(m_CSIncomingData);
+ while (m_IsConnected && m_IncomingData.empty())
+ {
+ cCSUnlock Unlock(Lock);
+ m_Event.Wait();
+ }
+
+ // If we got disconnected, report an error after processing all data:
+ if (!m_IsConnected && m_IncomingData.empty())
+ {
+ return MBEDTLS_ERR_NET_RECV_FAILED;
+ }
+
+ // Copy the data from the incoming buffer into the specified space:
+ size_t NumToCopy = std::min(a_NumBytes, m_IncomingData.size());
+ memcpy(a_Buffer, m_IncomingData.data(), NumToCopy);
+ m_IncomingData.erase(0, NumToCopy);
+ return static_cast<int>(NumToCopy);
+}
+
+
+
+
+
+int cBlockingSslClientSocket::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes)
+{
+ cTCPLinkPtr Socket(m_Socket); // Make a copy so that multiple threads don't race on deleting the socket.
+ if (Socket == nullptr)
+ {
+ return MBEDTLS_ERR_NET_SEND_FAILED;
+ }
+ if (!Socket->Send(a_Buffer, a_NumBytes))
+ {
+ // mbedTLS's net routines distinguish between connection reset and general failure, we don't need to
+ return MBEDTLS_ERR_NET_SEND_FAILED;
+ }
+ return static_cast<int>(a_NumBytes);
+}
+
+
+
+
+
+void cBlockingSslClientSocket::OnConnected(void)
+{
+ m_IsConnected = true;
+ m_Event.Set();
+}
+
+
+
+
+
+void cBlockingSslClientSocket::OnConnectError(const AString & a_ErrorMsg)
+{
+ LOG("Cannot connect to %s: \"%s\"", m_ServerName.c_str(), a_ErrorMsg.c_str());
+ m_Event.Set();
+}
+
+
+
+
+
+void cBlockingSslClientSocket::OnReceivedData(const char * a_Data, size_t a_Size)
+{
+ {
+ cCSLock Lock(m_CSIncomingData);
+ m_IncomingData.append(a_Data, a_Size);
+ }
+ m_Event.Set();
+}
+
+
+
+
+
+void cBlockingSslClientSocket::SetLink(cTCPLinkPtr a_Link)
+{
+ m_Socket = a_Link;
+}
+
+
+
+
+
+void cBlockingSslClientSocket::OnDisconnected(void)
+{
+ m_IsConnected = false;
+ m_Socket.reset();
+ m_Event.Set();
+}
+
+
+
+
diff --git a/src/mbedTLS++/BlockingSslClientSocket.h b/src/mbedTLS++/BlockingSslClientSocket.h
new file mode 100644
index 000000000..24ee32680
--- /dev/null
+++ b/src/mbedTLS++/BlockingSslClientSocket.h
@@ -0,0 +1,119 @@
+
+// BlockingSslClientSocket.h
+
+// Declares the cBlockingSslClientSocket class representing a blocking TCP socket with client SSL encryption over it
+
+
+
+
+
+#pragma once
+
+#include "OSSupport/Network.h"
+#include "CallbackSslContext.h"
+
+
+
+
+
+class cBlockingSslClientSocket :
+ protected cCallbackSslContext::cDataCallbacks
+{
+public:
+ cBlockingSslClientSocket(void);
+
+ virtual ~cBlockingSslClientSocket(void) override
+ {
+ Disconnect();
+ }
+
+ /** Connects to the specified server and performs SSL handshake.
+ Returns true if successful, false on failure. Sets internal error text on failure. */
+ bool Connect(const AString & a_ServerName, UInt16 a_Port);
+
+ /** Sends the specified data over the connection.
+ Returns true if successful, false on failure. Sets the internal error text on failure. */
+ bool Send(const void * a_Data, size_t a_NumBytes);
+
+ /** Receives data from the connection.
+ Blocks until there is any data available, then returns as much as possible.
+ Returns the number of bytes actually received, negative number on failure.
+ Sets the internal error text on failure. */
+ int Receive(void * a_Data, size_t a_MaxBytes);
+
+ /** Disconnects the connection gracefully, if possible.
+ Note that this also frees the internal SSL context, so all the certificates etc. are lost. */
+ void Disconnect(void);
+
+ /** Sets the Expected peer name.
+ Needs to be used before calling Connect().
+ \param a_ExpectedPeerName Name that we expect to receive in the SSL peer's cert; verification will fail if
+ the presented name is different (possible MITM). */
+ void SetExpectedPeerName(AString a_ExpectedPeerName);
+
+ /** Set the config to be used by the SSL context.
+ Config must not be modified after calling connect. */
+ void SetSslConfig(std::shared_ptr<const cSslConfig> a_Config);
+
+ /** Returns the text of the last error that has occurred in this instance. */
+ const AString & GetLastErrorText(void) const { return m_LastErrorText; }
+
+protected:
+ friend class cBlockingSslClientSocketConnectCallbacks;
+ friend class cBlockingSslClientSocketLinkCallbacks;
+
+ /** The SSL context used for the socket */
+ cCallbackSslContext m_Ssl;
+
+ /** The underlying socket to the SSL server */
+ cTCPLinkPtr m_Socket;
+
+ /** The object used to signal state changes in the socket (the cause of the blocking). */
+ cEvent m_Event;
+
+ /** The configuration to be used by the SSL context. Set by SetSslConfig(). */
+ std::shared_ptr<const cSslConfig> m_Config;
+
+ /** The expected SSL peer's name, if we are to verify the cert strictly. Set by SetExpectedPeerName(). */
+ AString m_ExpectedPeerName;
+
+ /** The hostname to which the socket is connecting (stored for error reporting). */
+ AString m_ServerName;
+
+ /** Text of the last error that has occurred. */
+ AString m_LastErrorText;
+
+ /** Set to true if the connection established successfully. */
+ std::atomic<bool> m_IsConnected;
+
+ /** Protects m_IncomingData against multithreaded access. */
+ cCriticalSection m_CSIncomingData;
+
+ /** Buffer for the data incoming on the network socket.
+ Protected by m_CSIncomingData. */
+ AString m_IncomingData;
+
+
+ /** Called when the connection is established successfully. */
+ void OnConnected(void);
+
+ /** Called when an error occurs while connecting the socket. */
+ void OnConnectError(const AString & a_ErrorMsg);
+
+ /** Called when there's incoming data from the socket. */
+ void OnReceivedData(const char * a_Data, size_t a_Size);
+
+ /** Called when the link for the connection is created. */
+ void SetLink(cTCPLinkPtr a_Link);
+
+ /** Called when the link is disconnected, either gracefully or by an error. */
+ void OnDisconnected(void);
+
+ // cCallbackSslContext::cDataCallbacks overrides:
+ virtual int ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) override;
+ virtual int SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) override;
+} ;
+
+
+
+
diff --git a/src/mbedTLS++/BufferedSslContext.cpp b/src/mbedTLS++/BufferedSslContext.cpp
new file mode 100644
index 000000000..5cdf04323
--- /dev/null
+++ b/src/mbedTLS++/BufferedSslContext.cpp
@@ -0,0 +1,93 @@
+
+// BufferedSslContext.cpp
+
+// Implements the cBufferedSslContext class representing a SSL context with the SSL peer data backed by a cByteBuffer
+
+#include "Globals.h"
+#include "BufferedSslContext.h"
+
+
+
+
+
+cBufferedSslContext::cBufferedSslContext(size_t a_BufferSize):
+ m_OutgoingData(a_BufferSize),
+ m_IncomingData(a_BufferSize)
+{
+}
+
+
+
+
+
+size_t cBufferedSslContext::WriteIncoming(const void * a_Data, size_t a_NumBytes)
+{
+ size_t NumBytes = std::min(m_IncomingData.GetFreeSpace(), a_NumBytes);
+ if (NumBytes > 0)
+ {
+ m_IncomingData.Write(a_Data, NumBytes);
+ return NumBytes;
+ }
+ return 0;
+}
+
+
+
+
+
+size_t cBufferedSslContext::ReadOutgoing(void * a_Data, size_t a_DataMaxSize)
+{
+ size_t NumBytes = std::min(m_OutgoingData.GetReadableSpace(), a_DataMaxSize);
+ if (NumBytes > 0)
+ {
+ m_OutgoingData.ReadBuf(a_Data, NumBytes);
+ m_OutgoingData.CommitRead();
+ return NumBytes;
+ }
+ return 0;
+}
+
+
+
+
+
+int cBufferedSslContext::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes)
+{
+ // Called when mbedTLS wants to read encrypted data from the SSL peer
+ // Read the data from the buffer inside this object, where the owner has stored them using WriteIncoming():
+ size_t NumBytes = std::min(a_NumBytes, m_IncomingData.GetReadableSpace());
+ if (NumBytes == 0)
+ {
+ return MBEDTLS_ERR_SSL_WANT_READ;
+ }
+ if (!m_IncomingData.ReadBuf(a_Buffer, NumBytes))
+ {
+ m_IncomingData.ResetRead();
+ return MBEDTLS_ERR_NET_RECV_FAILED;
+ }
+ m_IncomingData.CommitRead();
+ return static_cast<int>(NumBytes);
+}
+
+
+
+
+
+int cBufferedSslContext::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes)
+{
+ // Called when mbedTLS wants to write encrypted data to the SSL peer
+ // Write the data into the buffer inside this object, where the owner can later read them using ReadOutgoing():
+ if (!m_OutgoingData.CanWriteBytes(a_NumBytes))
+ {
+ return MBEDTLS_ERR_SSL_WANT_WRITE;
+ }
+ if (!m_OutgoingData.Write(reinterpret_cast<const char *>(a_Buffer), a_NumBytes))
+ {
+ return MBEDTLS_ERR_NET_SEND_FAILED;
+ }
+ return static_cast<int>(a_NumBytes);
+}
+
+
+
+
diff --git a/src/mbedTLS++/BufferedSslContext.h b/src/mbedTLS++/BufferedSslContext.h
new file mode 100644
index 000000000..9c9dd8f73
--- /dev/null
+++ b/src/mbedTLS++/BufferedSslContext.h
@@ -0,0 +1,53 @@
+
+// BufferedSslContext.h
+
+// Declares the cBufferedSslContext class representing a SSL context with the SSL peer data backed by a cByteBuffer
+
+
+
+
+
+#pragma once
+
+#include "SslContext.h"
+#include "ErrorCodes.h"
+
+
+
+
+
+class cBufferedSslContext :
+ public cSslContext
+{
+ typedef cSslContext super;
+
+public:
+ /** Creates a new context with the buffers of specified size for the encrypted / decrypted data. */
+ cBufferedSslContext(size_t a_BufferSize = 64000);
+
+ /** Stores the specified data in the "incoming" buffer, to be process by the SSL decryptor.
+ This is the data received from the SSL peer.
+ Returns the number of bytes actually stored. If 0 is returned, owner should check the error state. */
+ size_t WriteIncoming(const void * a_Data, size_t a_NumBytes);
+
+ /** Retrieves data from the "outgoing" buffer, after being processed by the SSL encryptor.
+ This is the data to be sent to the SSL peer.
+ Returns the number of bytes actually retrieved. */
+ size_t ReadOutgoing(void * a_Data, size_t a_DataMaxSize);
+
+protected:
+ /** Buffer for the data that has been encrypted into the SSL stream and should be sent out. */
+ cByteBuffer m_OutgoingData;
+
+ /** Buffer for the data that has come in and needs to be decrypted from the SSL stream. */
+ cByteBuffer m_IncomingData;
+
+
+ // cSslContext overrides:
+ virtual int ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) override;
+ virtual int SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) override;
+} ;
+
+
+
+
diff --git a/src/mbedTLS++/CMakeLists.txt b/src/mbedTLS++/CMakeLists.txt
new file mode 100644
index 000000000..18ef22312
--- /dev/null
+++ b/src/mbedTLS++/CMakeLists.txt
@@ -0,0 +1,44 @@
+project (Cuberite)
+
+include_directories ("${PROJECT_SOURCE_DIR}/../")
+
+set(SRCS
+ AesCfb128Decryptor.cpp
+ AesCfb128Encryptor.cpp
+ BlockingSslClientSocket.cpp
+ BufferedSslContext.cpp
+ CallbackSslContext.cpp
+ CtrDrbgContext.cpp
+ CryptoKey.cpp
+ EntropyContext.cpp
+ RsaPrivateKey.cpp
+ Sha1Checksum.cpp
+ SslConfig.cpp
+ SslContext.cpp
+ X509Cert.cpp
+)
+
+set(HDRS
+ AesCfb128Decryptor.h
+ AesCfb128Encryptor.h
+ BlockingSslClientSocket.h
+ BufferedSslContext.h
+ CallbackSslContext.h
+ CtrDrbgContext.h
+ CryptoKey.h
+ EntropyContext.h
+ ErrorCodes.h
+ RsaPrivateKey.h
+ SslConfig.h
+ SslContext.h
+ Sha1Checksum.h
+ X509Cert.h
+)
+
+if(NOT MSVC)
+ add_library(mbedTLS++ ${SRCS} ${HDRS})
+
+ if (UNIX)
+ target_link_libraries(mbedTLS++ mbedtls)
+ endif()
+endif()
diff --git a/src/mbedTLS++/CallbackSslContext.cpp b/src/mbedTLS++/CallbackSslContext.cpp
new file mode 100644
index 000000000..26bcec2ff
--- /dev/null
+++ b/src/mbedTLS++/CallbackSslContext.cpp
@@ -0,0 +1,60 @@
+
+// CallbackSslContext.cpp
+
+// Declares the cCallbackSslContext class representing a SSL context wrapper that uses callbacks to read and write SSL peer data
+
+#include "Globals.h"
+#include "CallbackSslContext.h"
+
+
+
+
+
+
+cCallbackSslContext::cCallbackSslContext(void) :
+ m_Callbacks(nullptr)
+{
+ // Nothing needed, but the constructor needs to exist so
+}
+
+
+
+
+
+cCallbackSslContext::cCallbackSslContext(cCallbackSslContext::cDataCallbacks & a_Callbacks) :
+ m_Callbacks(&a_Callbacks)
+{
+}
+
+
+
+
+
+int cCallbackSslContext::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes)
+{
+ if (m_Callbacks == nullptr)
+ {
+ LOGWARNING("SSL: Trying to receive data with no callbacks, aborting.");
+ return MBEDTLS_ERR_NET_RECV_FAILED;
+ }
+ return m_Callbacks->ReceiveEncrypted(a_Buffer, a_NumBytes);
+}
+
+
+
+
+
+int cCallbackSslContext::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes)
+{
+ if (m_Callbacks == nullptr)
+ {
+ LOGWARNING("SSL: Trying to send data with no callbacks, aborting.");
+ return MBEDTLS_ERR_NET_SEND_FAILED;
+ }
+ return m_Callbacks->SendEncrypted(a_Buffer, a_NumBytes);
+}
+
+
+
+
+
diff --git a/src/mbedTLS++/CallbackSslContext.h b/src/mbedTLS++/CallbackSslContext.h
new file mode 100644
index 000000000..da1abb707
--- /dev/null
+++ b/src/mbedTLS++/CallbackSslContext.h
@@ -0,0 +1,65 @@
+
+// CallbackSslContext.h
+
+// Declares the cCallbackSslContext class representing a SSL context wrapper that uses callbacks to read and write SSL peer data
+
+
+
+
+
+#pragma once
+
+#include "SslContext.h"
+#include "ErrorCodes.h"
+
+
+
+
+
+class cCallbackSslContext :
+ public cSslContext
+{
+public:
+ /** Interface used as a data sink for the SSL peer data. */
+ class cDataCallbacks
+ {
+ public:
+ // Force a virtual destructor in descendants:
+ virtual ~cDataCallbacks() {}
+
+ /** Called when mbedTLS wants to read encrypted data from the SSL peer.
+ The returned value is the number of bytes received, or a mbedTLS error on failure.
+ The implementation can return MBEDTLS_ERR_SSL_WANT_READ or MBEDTLS_ERR_SSL_WANT_WRITE to indicate
+ that there's currently no more data and that there might be more data in the future. In such cases the
+ SSL operation that invoked this call will terminate with the same return value, so that the owner is
+ notified of this condition and can potentially restart the operation later on. */
+ virtual int ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) = 0;
+
+ /** Called when mbedTLS wants to write encrypted data to the SSL peer.
+ The returned value is the number of bytes sent, or a mbedTLS error on failure.
+ The implementation can return MBEDTLS_ERR_SSL_WANT_READ or MBEDTLS_ERR_SSL_WANT_WRITE to indicate
+ that there's currently no more data and that there might be more data in the future. In such cases the
+ SSL operation that invoked this call will terminate with the same return value, so that the owner is
+ notified of this condition and can potentially restart the operation later on. */
+ virtual int SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) = 0;
+ } ;
+
+
+ /** Creates a new SSL context with no callbacks assigned */
+ cCallbackSslContext(void);
+
+ /** Creates a new SSL context with the specified callbacks */
+ cCallbackSslContext(cDataCallbacks & a_Callbacks);
+
+protected:
+ /** The callbacks to use to send and receive SSL peer data */
+ cDataCallbacks * m_Callbacks;
+
+ // cSslContext overrides:
+ virtual int ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) override;
+ virtual int SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) override;
+};
+
+
+
+
diff --git a/src/mbedTLS++/CryptoKey.cpp b/src/mbedTLS++/CryptoKey.cpp
new file mode 100644
index 000000000..4ebb0f300
--- /dev/null
+++ b/src/mbedTLS++/CryptoKey.cpp
@@ -0,0 +1,149 @@
+
+// CryptoKey.cpp
+
+// Implements the cCryptoKey class representing a RSA public key in mbedTLS
+
+#include "Globals.h"
+#include "CryptoKey.h"
+
+
+
+
+
+cCryptoKey::cCryptoKey(void)
+{
+ mbedtls_pk_init(&m_Pk);
+ m_CtrDrbg.Initialize("rsa_pubkey", 10);
+}
+
+
+
+
+
+cCryptoKey::cCryptoKey(const AString & a_PublicKeyData)
+{
+ mbedtls_pk_init(&m_Pk);
+ m_CtrDrbg.Initialize("rsa_pubkey", 10);
+ int res = ParsePublic(a_PublicKeyData.data(), a_PublicKeyData.size());
+ if (res != 0)
+ {
+ LOGWARNING("Failed to parse public key: -0x%x", res);
+ ASSERT(!"Cannot parse PubKey");
+ return;
+ }
+}
+
+
+
+
+
+cCryptoKey::cCryptoKey(const AString & a_PrivateKeyData, const AString & a_Password)
+{
+ mbedtls_pk_init(&m_Pk);
+ m_CtrDrbg.Initialize("rsa_privkey", 11);
+ int res = ParsePrivate(a_PrivateKeyData.data(), a_PrivateKeyData.size(), a_Password);
+ if (res != 0)
+ {
+ LOGWARNING("Failed to parse private key: -0x%x", res);
+ ASSERT(!"Cannot parse PrivKey");
+ return;
+ }
+}
+
+
+
+
+
+cCryptoKey::~cCryptoKey()
+{
+ mbedtls_pk_free(&m_Pk);
+}
+
+
+
+
+
+int cCryptoKey::Decrypt(const Byte * a_EncryptedData, size_t a_EncryptedLength, Byte * a_DecryptedData, size_t a_DecryptedMaxLength)
+{
+ ASSERT(IsValid());
+
+ size_t DecryptedLen = a_DecryptedMaxLength;
+ int res = mbedtls_pk_decrypt(&m_Pk,
+ a_EncryptedData, a_EncryptedLength,
+ a_DecryptedData, &DecryptedLen, a_DecryptedMaxLength,
+ mbedtls_ctr_drbg_random, m_CtrDrbg.GetInternal()
+ );
+ if (res != 0)
+ {
+ return res;
+ }
+ return static_cast<int>(DecryptedLen);
+}
+
+
+
+
+
+int cCryptoKey::Encrypt(const Byte * a_PlainData, size_t a_PlainLength, Byte * a_EncryptedData, size_t a_EncryptedMaxLength)
+{
+ ASSERT(IsValid());
+
+ size_t EncryptedLength = a_EncryptedMaxLength;
+ int res = mbedtls_pk_encrypt(&m_Pk,
+ a_PlainData, a_PlainLength, a_EncryptedData, &EncryptedLength, a_EncryptedMaxLength,
+ mbedtls_ctr_drbg_random, m_CtrDrbg.GetInternal()
+ );
+ if (res != 0)
+ {
+ return res;
+ }
+ return static_cast<int>(EncryptedLength);
+}
+
+
+
+
+
+
+int cCryptoKey::ParsePublic(const void * a_Data, size_t a_NumBytes)
+{
+ ASSERT(!IsValid()); // Cannot parse a second key
+
+ return mbedtls_pk_parse_public_key(&m_Pk, reinterpret_cast<const unsigned char *>(a_Data), a_NumBytes);
+}
+
+
+
+
+
+
+int cCryptoKey::ParsePrivate(const void * a_Data, size_t a_NumBytes, const AString & a_Password)
+{
+ ASSERT(!IsValid()); // Cannot parse a second key
+
+ if (a_Password.empty())
+ {
+ return mbedtls_pk_parse_key(&m_Pk, reinterpret_cast<const unsigned char *>(a_Data), a_NumBytes, nullptr, 0);
+ }
+ else
+ {
+ return mbedtls_pk_parse_key(
+ &m_Pk,
+ reinterpret_cast<const unsigned char *>(a_Data), a_NumBytes,
+ reinterpret_cast<const unsigned char *>(a_Password.c_str()), a_Password.size()
+ );
+ }
+}
+
+
+
+
+
+bool cCryptoKey::IsValid(void) const
+{
+ return (mbedtls_pk_get_type(&m_Pk) != MBEDTLS_PK_NONE);
+}
+
+
+
+
diff --git a/src/mbedTLS++/CryptoKey.h b/src/mbedTLS++/CryptoKey.h
new file mode 100644
index 000000000..1a74090ac
--- /dev/null
+++ b/src/mbedTLS++/CryptoKey.h
@@ -0,0 +1,76 @@
+
+// CryptoKey.h
+
+// Declares the cCryptoKey class representing a RSA public key in mbedTLS
+
+
+
+
+
+#pragma once
+
+#include "CtrDrbgContext.h"
+#include "mbedtls/pk.h"
+
+
+
+
+
+class cCryptoKey
+{
+ friend class cSslConfig;
+
+public:
+ /** Constructs an empty key instance. Before use, it needs to be filled by ParsePublic() or ParsePrivate() */
+ cCryptoKey(void);
+
+ /** Constructs the public key out of the DER- or PEM-encoded pubkey data */
+ cCryptoKey(const AString & a_PublicKeyData);
+
+ /** Constructs the private key out of the DER- or PEM-encoded privkey data, with the specified password.
+ If a_Password is empty, no password is assumed. */
+ cCryptoKey(const AString & a_PrivateKeyData, const AString & a_Password);
+
+ ~cCryptoKey();
+
+ /** Decrypts the data using the stored public key
+ Both a_EncryptedData and a_DecryptedData must be at least <KeySizeBytes> bytes large.
+ Returns the number of bytes decrypted, or negative number for error. */
+ int Decrypt(const Byte * a_EncryptedData, size_t a_EncryptedLength, Byte * a_DecryptedData, size_t a_DecryptedMaxLength);
+
+ /** Encrypts the data using the stored public key
+ Both a_EncryptedData and a_DecryptedData must be at least <KeySizeBytes> bytes large.
+ Returns the number of bytes decrypted, or negative number for error. */
+ int Encrypt(const Byte * a_PlainData, size_t a_PlainLength, Byte * a_EncryptedData, size_t a_EncryptedMaxLength);
+
+ /** Parses the specified data into a public key representation.
+ The key can be DER- or PEM-encoded.
+ Returns 0 on success, mbedTLS error code on failure. */
+ int ParsePublic(const void * a_Data, size_t a_NumBytes);
+
+ /** Parses the specified data into a private key representation.
+ If a_Password is empty, no password is assumed.
+ The key can be DER- or PEM-encoded.
+ Returns 0 on success, mbedTLS error code on failure. */
+ int ParsePrivate(const void * a_Data, size_t a_NumBytes, const AString & a_Password);
+
+ /** Returns true if the contained key is valid. */
+ bool IsValid(void) const;
+
+protected:
+ /** The mbedTLS representation of the key data */
+ mbedtls_pk_context m_Pk;
+
+ /** The random generator used in encryption and decryption */
+ cCtrDrbgContext m_CtrDrbg;
+
+
+ /** Returns the internal context ptr. Only use in mbedTLS API calls. */
+ mbedtls_pk_context * GetInternal(void) { return &m_Pk; }
+} ;
+
+typedef std::shared_ptr<cCryptoKey> cCryptoKeyPtr;
+
+
+
+
diff --git a/src/mbedTLS++/CtrDrbgContext.cpp b/src/mbedTLS++/CtrDrbgContext.cpp
new file mode 100644
index 000000000..bd4a55000
--- /dev/null
+++ b/src/mbedTLS++/CtrDrbgContext.cpp
@@ -0,0 +1,51 @@
+
+// CtrDrbgContext.cpp
+
+// Implements the cCtrDrbgContext class representing a wrapper over CTR-DRBG implementation in mbedTLS
+
+#include "Globals.h"
+#include "CtrDrbgContext.h"
+#include "EntropyContext.h"
+
+
+
+
+
+cCtrDrbgContext::cCtrDrbgContext(void) :
+ m_EntropyContext(std::make_shared<cEntropyContext>()),
+ m_IsValid(false)
+{
+ mbedtls_ctr_drbg_init(&m_CtrDrbg);
+}
+
+
+
+
+
+cCtrDrbgContext::cCtrDrbgContext(const std::shared_ptr<cEntropyContext> & a_EntropyContext) :
+ m_EntropyContext(a_EntropyContext),
+ m_IsValid(false)
+{
+ mbedtls_ctr_drbg_init(&m_CtrDrbg);
+}
+
+
+
+
+
+int cCtrDrbgContext::Initialize(const void * a_Custom, size_t a_CustomSize)
+{
+ if (m_IsValid)
+ {
+ // Already initialized
+ return 0;
+ }
+
+ int res = mbedtls_ctr_drbg_seed(&m_CtrDrbg, mbedtls_entropy_func, &(m_EntropyContext->m_Entropy), reinterpret_cast<const unsigned char *>(a_Custom), a_CustomSize);
+ m_IsValid = (res == 0);
+ return res;
+}
+
+
+
+
diff --git a/src/mbedTLS++/CtrDrbgContext.h b/src/mbedTLS++/CtrDrbgContext.h
new file mode 100644
index 000000000..21d786c2e
--- /dev/null
+++ b/src/mbedTLS++/CtrDrbgContext.h
@@ -0,0 +1,63 @@
+
+// CtrDrbgContext.h
+
+// Declares the cCtrDrbgContext class representing a wrapper over CTR-DRBG implementation in mbedTLS
+
+
+
+
+
+#pragma once
+
+#include "mbedtls/ctr_drbg.h"
+
+
+
+
+
+// fwd: EntropyContext.h
+class cEntropyContext;
+
+
+
+
+
+class cCtrDrbgContext
+{
+ friend class cSslConfig;
+ friend class cRsaPrivateKey;
+ friend class cCryptoKey;
+
+public:
+ /** Constructs the context with a new entropy context. */
+ cCtrDrbgContext(void);
+
+ /** Constructs the context with the specified entropy context. */
+ cCtrDrbgContext(const std::shared_ptr<cEntropyContext> & a_EntropyContext);
+
+ /** Initializes the context.
+ a_Custom is optional additional data to use for entropy, nullptr is accepted.
+ Returns 0 if successful, mbedTLS error code on failure. */
+ int Initialize(const void * a_Custom, size_t a_CustomSize);
+
+ /** Returns true if the object is valid (has been initialized properly) */
+ bool IsValid(void) const { return m_IsValid; }
+
+protected:
+ /** The entropy source used for generating the random */
+ std::shared_ptr<cEntropyContext> m_EntropyContext;
+
+ /** The random generator context */
+ mbedtls_ctr_drbg_context m_CtrDrbg;
+
+ /** Set to true if the object is valid (has been initialized properly) */
+ bool m_IsValid;
+
+
+ /** Returns the internal context ptr. Only use in mbedTLS API calls. */
+ mbedtls_ctr_drbg_context * GetInternal(void) { return &m_CtrDrbg; }
+} ;
+
+
+
+
diff --git a/src/mbedTLS++/EntropyContext.cpp b/src/mbedTLS++/EntropyContext.cpp
new file mode 100644
index 000000000..aea056f4e
--- /dev/null
+++ b/src/mbedTLS++/EntropyContext.cpp
@@ -0,0 +1,29 @@
+
+// EntropyContext.cpp
+
+// Implements the cEntropyContext class representing a wrapper over entropy contexts in mbedTLS
+
+#include "Globals.h"
+#include "EntropyContext.h"
+
+
+
+
+
+cEntropyContext::cEntropyContext(void)
+{
+ mbedtls_entropy_init(&m_Entropy);
+}
+
+
+
+
+
+cEntropyContext::~cEntropyContext()
+{
+ mbedtls_entropy_free(&m_Entropy);
+}
+
+
+
+
diff --git a/src/mbedTLS++/EntropyContext.h b/src/mbedTLS++/EntropyContext.h
new file mode 100644
index 000000000..37b6f120e
--- /dev/null
+++ b/src/mbedTLS++/EntropyContext.h
@@ -0,0 +1,31 @@
+
+// EntropyContext.h
+
+// Declares the cEntropyContext class representing a wrapper over entropy contexts in mbedTLS
+
+
+
+
+
+#pragma once
+
+#include "mbedtls/entropy.h"
+
+
+
+
+
+class cEntropyContext
+{
+ friend class cCtrDrbgContext;
+public:
+ cEntropyContext(void);
+ ~cEntropyContext();
+
+protected:
+ mbedtls_entropy_context m_Entropy;
+} ;
+
+
+
+
diff --git a/src/mbedTLS++/ErrorCodes.h b/src/mbedTLS++/ErrorCodes.h
new file mode 100644
index 000000000..36ef86fec
--- /dev/null
+++ b/src/mbedTLS++/ErrorCodes.h
@@ -0,0 +1,18 @@
+
+/** Error codes from mbedtls net_sockets.h */
+// TODO: Replace with std::error_code
+
+#define MBEDTLS_ERR_NET_SOCKET_FAILED -0x0042 /**< Failed to open a socket. */
+#define MBEDTLS_ERR_NET_CONNECT_FAILED -0x0044 /**< The connection to the given server / port failed. */
+#define MBEDTLS_ERR_NET_BIND_FAILED -0x0046 /**< Binding of the socket failed. */
+#define MBEDTLS_ERR_NET_LISTEN_FAILED -0x0048 /**< Could not listen on the socket. */
+#define MBEDTLS_ERR_NET_ACCEPT_FAILED -0x004A /**< Could not accept the incoming connection. */
+#define MBEDTLS_ERR_NET_RECV_FAILED -0x004C /**< Reading information from the socket failed. */
+#define MBEDTLS_ERR_NET_SEND_FAILED -0x004E /**< Sending information through the socket failed. */
+#define MBEDTLS_ERR_NET_CONN_RESET -0x0050 /**< Connection was reset by peer. */
+#define MBEDTLS_ERR_NET_UNKNOWN_HOST -0x0052 /**< Failed to get an IP address for the given hostname. */
+#define MBEDTLS_ERR_NET_BUFFER_TOO_SMALL -0x0043 /**< Buffer is too small to hold the data. */
+#define MBEDTLS_ERR_NET_INVALID_CONTEXT -0x0045 /**< The context is invalid, eg because it was free()ed. */
+
+
+
diff --git a/src/mbedTLS++/RsaPrivateKey.cpp b/src/mbedTLS++/RsaPrivateKey.cpp
new file mode 100644
index 000000000..3dfb3bac3
--- /dev/null
+++ b/src/mbedTLS++/RsaPrivateKey.cpp
@@ -0,0 +1,174 @@
+
+// RsaPrivateKey.cpp
+
+#include "Globals.h"
+#include "RsaPrivateKey.h"
+#include "mbedtls/pk.h"
+
+
+
+
+
+cRsaPrivateKey::cRsaPrivateKey(void)
+{
+ mbedtls_rsa_init(&m_Rsa, MBEDTLS_RSA_PKCS_V15, 0);
+ m_CtrDrbg.Initialize("RSA", 3);
+}
+
+
+
+
+
+cRsaPrivateKey::cRsaPrivateKey(const cRsaPrivateKey & a_Other)
+{
+ mbedtls_rsa_init(&m_Rsa, MBEDTLS_RSA_PKCS_V15, 0);
+ mbedtls_rsa_copy(&m_Rsa, &a_Other.m_Rsa);
+ m_CtrDrbg.Initialize("RSA", 3);
+}
+
+
+
+
+
+cRsaPrivateKey::~cRsaPrivateKey()
+{
+ mbedtls_rsa_free(&m_Rsa);
+}
+
+
+
+
+
+bool cRsaPrivateKey::Generate(unsigned a_KeySizeBits)
+{
+ int res = mbedtls_rsa_gen_key(&m_Rsa, mbedtls_ctr_drbg_random, m_CtrDrbg.GetInternal(), a_KeySizeBits, 65537);
+ if (res != 0)
+ {
+ LOG("RSA key generation failed: -0x%x", -res);
+ return false;
+ }
+
+ return true;
+}
+
+
+
+
+
+AString cRsaPrivateKey::GetPubKeyDER(void)
+{
+ class cPubKey
+ {
+ public:
+ cPubKey(mbedtls_rsa_context * a_Rsa) :
+ m_IsValid(false)
+ {
+ mbedtls_pk_init(&m_Key);
+ if (mbedtls_pk_setup(&m_Key, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)) != 0)
+ {
+ ASSERT(!"Cannot init PrivKey context");
+ return;
+ }
+ if (mbedtls_rsa_copy(mbedtls_pk_rsa(m_Key), a_Rsa) != 0)
+ {
+ ASSERT(!"Cannot copy PrivKey to PK context");
+ return;
+ }
+ m_IsValid = true;
+ }
+
+ ~cPubKey()
+ {
+ if (m_IsValid)
+ {
+ mbedtls_pk_free(&m_Key);
+ }
+ }
+
+ operator mbedtls_pk_context * (void) { return &m_Key; }
+
+ protected:
+ bool m_IsValid;
+ mbedtls_pk_context m_Key;
+ } PkCtx(&m_Rsa);
+
+ unsigned char buf[3000];
+ int res = mbedtls_pk_write_pubkey_der(PkCtx, buf, sizeof(buf));
+ if (res < 0)
+ {
+ return AString();
+ }
+ return AString(reinterpret_cast<const char *>(buf + sizeof(buf) - res), static_cast<size_t>(res));
+}
+
+
+
+
+
+int cRsaPrivateKey::Decrypt(const Byte * a_EncryptedData, size_t a_EncryptedLength, Byte * a_DecryptedData, size_t a_DecryptedMaxLength)
+{
+ if (a_EncryptedLength < m_Rsa.len)
+ {
+ LOGD("%s: Invalid a_EncryptedLength: got %u, exp at least %u",
+ __FUNCTION__, static_cast<unsigned>(a_EncryptedLength), static_cast<unsigned>(m_Rsa.len)
+ );
+ ASSERT(!"Invalid a_DecryptedMaxLength!");
+ return -1;
+ }
+ if (a_DecryptedMaxLength < m_Rsa.len)
+ {
+ LOGD("%s: Invalid a_DecryptedMaxLength: got %u, exp at least %u",
+ __FUNCTION__, static_cast<unsigned>(a_EncryptedLength), static_cast<unsigned>(m_Rsa.len)
+ );
+ ASSERT(!"Invalid a_DecryptedMaxLength!");
+ return -1;
+ }
+ size_t DecryptedLength;
+ int res = mbedtls_rsa_pkcs1_decrypt(
+ &m_Rsa, mbedtls_ctr_drbg_random, m_CtrDrbg.GetInternal(), MBEDTLS_RSA_PRIVATE, &DecryptedLength,
+ a_EncryptedData, a_DecryptedData, a_DecryptedMaxLength
+ );
+ if (res != 0)
+ {
+ return -1;
+ }
+ return static_cast<int>(DecryptedLength);
+}
+
+
+
+
+
+int cRsaPrivateKey::Encrypt(const Byte * a_PlainData, size_t a_PlainLength, Byte * a_EncryptedData, size_t a_EncryptedMaxLength)
+{
+ if (a_EncryptedMaxLength < m_Rsa.len)
+ {
+ LOGD("%s: Invalid a_EncryptedMaxLength: got %u, exp at least %u",
+ __FUNCTION__, static_cast<unsigned>(a_EncryptedMaxLength), static_cast<unsigned>(m_Rsa.len)
+ );
+ ASSERT(!"Invalid a_DecryptedMaxLength!");
+ return -1;
+ }
+ if (a_PlainLength < m_Rsa.len)
+ {
+ LOGD("%s: Invalid a_PlainLength: got %u, exp at least %u",
+ __FUNCTION__, static_cast<unsigned>(a_PlainLength), static_cast<unsigned>(m_Rsa.len)
+ );
+ ASSERT(!"Invalid a_PlainLength!");
+ return -1;
+ }
+ int res = mbedtls_rsa_pkcs1_encrypt(
+ &m_Rsa, mbedtls_ctr_drbg_random, m_CtrDrbg.GetInternal(), MBEDTLS_RSA_PRIVATE,
+ a_PlainLength, a_PlainData, a_EncryptedData
+ );
+ if (res != 0)
+ {
+ return -1;
+ }
+ return static_cast<int>(m_Rsa.len);
+}
+
+
+
+
+
diff --git a/src/mbedTLS++/RsaPrivateKey.h b/src/mbedTLS++/RsaPrivateKey.h
new file mode 100644
index 000000000..7be0152b7
--- /dev/null
+++ b/src/mbedTLS++/RsaPrivateKey.h
@@ -0,0 +1,67 @@
+
+// RsaPrivateKey.h
+
+// Declares the cRsaPrivateKey class representing a private key for RSA operations.
+
+
+
+
+
+#pragma once
+
+#include "CtrDrbgContext.h"
+#include "mbedtls/rsa.h"
+
+
+
+
+
+/** Encapsulates an RSA private key used in PKI cryptography */
+class cRsaPrivateKey
+{
+ friend class cSslContext;
+
+public:
+ /** Creates a new empty object, the key is not assigned */
+ cRsaPrivateKey(void);
+
+ /** Deep-copies the key from a_Other */
+ cRsaPrivateKey(const cRsaPrivateKey & a_Other);
+
+ ~cRsaPrivateKey();
+
+ /** Generates a new key within this object, with the specified size in bits.
+ Returns true on success, false on failure. */
+ bool Generate(unsigned a_KeySizeBits = 1024);
+
+ /** Returns the public key part encoded in ASN1 DER encoding */
+ AString GetPubKeyDER(void);
+
+ /** Decrypts the data using RSAES-PKCS#1 algorithm.
+ Both a_EncryptedData and a_DecryptedData must be at least <KeySizeBytes> bytes large.
+ Returns the number of bytes decrypted, or negative number for error. */
+ int Decrypt(const Byte * a_EncryptedData, size_t a_EncryptedLength, Byte * a_DecryptedData, size_t a_DecryptedMaxLength);
+
+ /** Encrypts the data using RSAES-PKCS#1 algorithm.
+ Both a_EncryptedData and a_DecryptedData must be at least <KeySizeBytes> bytes large.
+ Returns the number of bytes decrypted, or negative number for error. */
+ int Encrypt(const Byte * a_PlainData, size_t a_PlainLength, Byte * a_EncryptedData, size_t a_EncryptedMaxLength);
+
+protected:
+ /** The mbedTLS key context */
+ mbedtls_rsa_context m_Rsa;
+
+ /** The random generator used for generating the key and encryption / decryption */
+ cCtrDrbgContext m_CtrDrbg;
+
+
+ /** Returns the internal context ptr. Only use in mbedTLS API calls. */
+ mbedtls_rsa_context * GetInternal(void) { return &m_Rsa; }
+} ;
+
+typedef std::shared_ptr<cRsaPrivateKey> cRsaPrivateKeyPtr;
+
+
+
+
+
diff --git a/src/mbedTLS++/Sha1Checksum.cpp b/src/mbedTLS++/Sha1Checksum.cpp
new file mode 100644
index 000000000..9c82d92fe
--- /dev/null
+++ b/src/mbedTLS++/Sha1Checksum.cpp
@@ -0,0 +1,138 @@
+
+// Sha1Checksum.cpp
+
+// Declares the cSha1Checksum class representing the SHA-1 checksum calculator
+
+#include "Globals.h"
+#include "Sha1Checksum.h"
+
+
+
+
+
+/*
+// Self-test the hash formatting for known values:
+// sha1(Notch) : 4ed1f46bbe04bc756bcb17c0c7ce3e4632f06a48
+// sha1(jeb_) : -7c9d5b0044c130109a5d7b5fb5c317c02b4e28c1
+// sha1(simon) : 88e16a1019277b15d58faf0541e11910eb756f6
+
+static class Test
+{
+public:
+ Test(void)
+ {
+ AString DigestNotch, DigestJeb, DigestSimon;
+ Byte Digest[20];
+ cSha1Checksum Checksum;
+ Checksum.Update((const Byte *)"Notch", 5);
+ Checksum.Finalize(Digest);
+ cSha1Checksum::DigestToJava(Digest, DigestNotch);
+ Checksum.Restart();
+ Checksum.Update((const Byte *)"jeb_", 4);
+ Checksum.Finalize(Digest);
+ cSha1Checksum::DigestToJava(Digest, DigestJeb);
+ Checksum.Restart();
+ Checksum.Update((const Byte *)"simon", 5);
+ Checksum.Finalize(Digest);
+ cSha1Checksum::DigestToJava(Digest, DigestSimon);
+ printf("Notch: \"%s\"\n", DigestNotch.c_str());
+ printf("jeb_: \"%s\"\n", DigestJeb.c_str());
+ printf("simon: \"%s\"\n", DigestSimon.c_str());
+ assert(DigestNotch == "4ed1f46bbe04bc756bcb17c0c7ce3e4632f06a48");
+ assert(DigestJeb == "-7c9d5b0044c130109a5d7b5fb5c317c02b4e28c1");
+ assert(DigestSimon == "88e16a1019277b15d58faf0541e11910eb756f6");
+ }
+} test;
+*/
+
+
+
+
+
+
+////////////////////////////////////////////////////////////////////////////////
+// cSha1Checksum:
+
+cSha1Checksum::cSha1Checksum(void) :
+ m_DoesAcceptInput(true)
+{
+ mbedtls_sha1_starts(&m_Sha1);
+}
+
+
+
+
+
+void cSha1Checksum::Update(const Byte * a_Data, size_t a_Length)
+{
+ ASSERT(m_DoesAcceptInput); // Not Finalize()-d yet, or Restart()-ed
+
+ mbedtls_sha1_update(&m_Sha1, a_Data, a_Length);
+}
+
+
+
+
+
+void cSha1Checksum::Finalize(cSha1Checksum::Checksum & a_Output)
+{
+ ASSERT(m_DoesAcceptInput); // Not Finalize()-d yet, or Restart()-ed
+
+ mbedtls_sha1_finish(&m_Sha1, a_Output);
+ m_DoesAcceptInput = false;
+}
+
+
+
+
+
+void cSha1Checksum::DigestToJava(const Checksum & a_Digest, AString & a_Out)
+{
+ Checksum Digest;
+ memcpy(Digest, a_Digest, sizeof(Digest));
+
+ bool IsNegative = (Digest[0] >= 0x80);
+ if (IsNegative)
+ {
+ // Two's complement:
+ bool carry = true; // Add one to the whole number
+ for (int i = 19; i >= 0; i--)
+ {
+ Digest[i] = ~Digest[i];
+ if (carry)
+ {
+ carry = (Digest[i] == 0xff);
+ Digest[i]++;
+ }
+ }
+ }
+ a_Out.clear();
+ a_Out.reserve(40);
+ for (int i = 0; i < 20; i++)
+ {
+ AppendPrintf(a_Out, "%02x", Digest[i]);
+ }
+ while ((a_Out.length() > 0) && (a_Out[0] == '0'))
+ {
+ a_Out.erase(0, 1);
+ }
+ if (IsNegative)
+ {
+ a_Out.insert(0, "-");
+ }
+}
+
+
+
+
+
+
+void cSha1Checksum::Restart(void)
+{
+ mbedtls_sha1_starts(&m_Sha1);
+ m_DoesAcceptInput = true;
+}
+
+
+
+
diff --git a/src/mbedTLS++/Sha1Checksum.h b/src/mbedTLS++/Sha1Checksum.h
new file mode 100644
index 000000000..43180e531
--- /dev/null
+++ b/src/mbedTLS++/Sha1Checksum.h
@@ -0,0 +1,52 @@
+
+// Sha1Checksum.h
+
+// Declares the cSha1Checksum class representing the SHA-1 checksum calculator
+
+
+
+
+
+#pragma once
+
+#include "mbedtls/sha1.h"
+
+
+
+
+
+/** Calculates a SHA1 checksum for data stream */
+class cSha1Checksum
+{
+public:
+ typedef Byte Checksum[20]; // The type used for storing the checksum
+
+ cSha1Checksum(void);
+
+ /** Adds the specified data to the checksum */
+ void Update(const Byte * a_Data, size_t a_Length);
+
+ /** Calculates and returns the final checksum */
+ void Finalize(Checksum & a_Output);
+
+ /** Returns true if the object is accepts more input data, false if Finalize()-d (need to Restart()) */
+ bool DoesAcceptInput(void) const { return m_DoesAcceptInput; }
+
+ /** Converts a raw 160-bit SHA1 digest into a Java Hex representation
+ According to http://wiki.vg/Protocol_Encryption
+ */
+ static void DigestToJava(const Checksum & a_Digest, AString & a_JavaOut);
+
+ /** Clears the current context and start a new checksum calculation */
+ void Restart(void);
+
+protected:
+ /** True if the object is accepts more input data, false if Finalize()-d (need to Restart()) */
+ bool m_DoesAcceptInput;
+
+ mbedtls_sha1_context m_Sha1;
+} ;
+
+
+
+
diff --git a/src/mbedTLS++/SslConfig.cpp b/src/mbedTLS++/SslConfig.cpp
new file mode 100644
index 000000000..9dec49776
--- /dev/null
+++ b/src/mbedTLS++/SslConfig.cpp
@@ -0,0 +1,287 @@
+
+#include "Globals.h"
+
+#include "mbedTLS++/SslConfig.h"
+#include "EntropyContext.h"
+#include "CtrDrbgContext.h"
+#include "CryptoKey.h"
+#include "X509Cert.h"
+
+
+// This allows us to debug SSL and certificate problems, but produce way too much output,
+// so it's disabled until someone needs it
+// #define ENABLE_SSL_DEBUG_MSG
+
+
+#if defined(_DEBUG) && defined(ENABLE_SSL_DEBUG_MSG)
+ #include "mbedtls/debug.h"
+
+
+ namespace
+ {
+ void SSLDebugMessage(void * a_UserParam, int a_Level, const char * a_Filename, int a_LineNo, const char * a_Text)
+ {
+ if (a_Level > 3)
+ {
+ // Don't want the trace messages
+ return;
+ }
+
+ // Remove the terminating LF:
+ size_t len = strlen(a_Text) - 1;
+ while ((len > 0) && (a_Text[len] <= 32))
+ {
+ len--;
+ }
+ AString Text(a_Text, len + 1);
+
+ LOGD("SSL (%d): %s", a_Level, Text.c_str());
+ }
+
+
+
+
+
+ int SSLVerifyCert(void * a_This, mbedtls_x509_crt * a_Crt, int a_Depth, uint32_t * a_Flags)
+ {
+ char buf[1024];
+ UNUSED(a_This);
+
+ LOG("Verify requested for (Depth %d):", a_Depth);
+ mbedtls_x509_crt_info(buf, sizeof(buf) - 1, "", a_Crt);
+ LOG("%s", buf);
+
+ uint32_t Flags = *a_Flags;
+ if ((Flags & MBEDTLS_X509_BADCERT_EXPIRED) != 0)
+ {
+ LOG(" ! server certificate has expired");
+ }
+
+ if ((Flags & MBEDTLS_X509_BADCERT_REVOKED) != 0)
+ {
+ LOG(" ! server certificate has been revoked");
+ }
+
+ if ((Flags & MBEDTLS_X509_BADCERT_CN_MISMATCH) != 0)
+ {
+ LOG(" ! CN mismatch");
+ }
+
+ if ((Flags & MBEDTLS_X509_BADCERT_NOT_TRUSTED) != 0)
+ {
+ LOG(" ! self-signed or not signed by a trusted CA");
+ }
+
+ if ((Flags & MBEDTLS_X509_BADCRL_NOT_TRUSTED) != 0)
+ {
+ LOG(" ! CRL not trusted");
+ }
+
+ if ((Flags & MBEDTLS_X509_BADCRL_EXPIRED) != 0)
+ {
+ LOG(" ! CRL expired");
+ }
+
+ if ((Flags & MBEDTLS_X509_BADCERT_OTHER) != 0)
+ {
+ LOG(" ! other (unknown) flag");
+ }
+
+ if (Flags == 0)
+ {
+ LOG(" This certificate has no flags");
+ }
+
+ return 0;
+ }
+ }
+#endif // defined(_DEBUG) && defined(ENABLE_SSL_DEBUG_MSG)
+
+
+
+
+////////////////////////////////////////////////////////////////////////////////
+// cSslConfig:
+
+cSslConfig::cSslConfig()
+{
+ mbedtls_ssl_config_init(&m_Config);
+}
+
+
+
+
+
+cSslConfig::~cSslConfig()
+{
+ mbedtls_ssl_config_free(&m_Config);
+}
+
+
+
+
+
+int cSslConfig::InitDefaults(const bool a_IsClient)
+{
+ return mbedtls_ssl_config_defaults(
+ &m_Config,
+ a_IsClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
+ MBEDTLS_SSL_TRANSPORT_STREAM,
+ MBEDTLS_SSL_PRESET_DEFAULT
+ );
+}
+
+
+
+
+
+void cSslConfig::SetAuthMode(const eSslAuthMode a_AuthMode)
+{
+ const int Mode = [=]()
+ {
+ switch (a_AuthMode)
+ {
+ case eSslAuthMode::None: return MBEDTLS_SSL_VERIFY_NONE;
+ case eSslAuthMode::Optional: return MBEDTLS_SSL_VERIFY_OPTIONAL;
+ case eSslAuthMode::Required: return MBEDTLS_SSL_VERIFY_REQUIRED;
+ case eSslAuthMode::Unset: return MBEDTLS_SSL_VERIFY_UNSET;
+ #ifndef __clang__
+ default: return MBEDTLS_SSL_VERIFY_OPTIONAL;
+ #endif
+ }
+ }();
+
+ mbedtls_ssl_conf_authmode(&m_Config, Mode);
+}
+
+
+
+
+
+void cSslConfig::SetRng(cCtrDrbgContextPtr a_CtrDrbg)
+{
+ ASSERT(a_CtrDrbg != nullptr);
+ m_CtrDrbg = std::move(a_CtrDrbg);
+ mbedtls_ssl_conf_rng(&m_Config, mbedtls_ctr_drbg_random, &m_CtrDrbg->m_CtrDrbg);
+}
+
+
+
+
+
+void cSslConfig::SetDebugCallback(cDebugCallback a_CallbackFun, void * a_CallbackData)
+{
+ mbedtls_ssl_conf_dbg(&m_Config, a_CallbackFun, a_CallbackData);
+}
+
+
+
+
+
+void cSslConfig::SetOwnCert(cX509CertPtr a_OwnCert, cCryptoKeyPtr a_OwnCertPrivKey)
+{
+ ASSERT(a_OwnCert != nullptr);
+ ASSERT(a_OwnCertPrivKey != nullptr);
+
+ // Make sure we have the cert stored for later, mbedTLS only uses the cert later on
+ m_OwnCert = std::move(a_OwnCert);
+ m_OwnCertPrivKey = std::move(a_OwnCertPrivKey);
+
+ // Set into the config:
+ mbedtls_ssl_conf_own_cert(&m_Config, m_OwnCert->GetInternal(), m_OwnCertPrivKey->GetInternal());
+}
+
+
+
+
+
+void cSslConfig::SetVerifyCallback(cVerifyCallback a_CallbackFun, void * a_CallbackData)
+{
+ mbedtls_ssl_conf_verify(&m_Config, a_CallbackFun, a_CallbackData);
+}
+
+
+
+
+
+void cSslConfig::SetCipherSuites(std::vector<int> a_CipherSuites)
+{
+ m_CipherSuites = std::move(a_CipherSuites);
+ m_CipherSuites.push_back(0); // Must be null terminated
+ mbedtls_ssl_conf_ciphersuites(&m_Config, m_CipherSuites.data());
+}
+
+
+
+
+
+void cSslConfig::SetCACerts(cX509CertPtr a_CACert)
+{
+ m_CACerts = std::move(a_CACert);
+ mbedtls_ssl_conf_ca_chain(&m_Config, m_CACerts->GetInternal(), nullptr);
+}
+
+
+
+
+
+std::shared_ptr<cSslConfig> cSslConfig::MakeDefaultConfig(bool a_IsClient)
+{
+ // TODO: Default CA chain and SetAuthMode(eSslAuthMode::Required)
+ auto Ret = std::make_shared<cSslConfig>();
+
+ Ret->InitDefaults(a_IsClient);
+
+ {
+ auto CtrDrbg = std::make_shared<cCtrDrbgContext>();
+ CtrDrbg->Initialize("Cuberite", 8);
+ Ret->SetRng(std::move(CtrDrbg));
+ }
+
+ Ret->SetAuthMode(eSslAuthMode::None); // We cannot verify because we don't have a CA chain
+
+ #ifdef _DEBUG
+ #ifdef ENABLE_SSL_DEBUG_MSG
+ Ret->SetDebugCallback(&SSLDebugMessage, nullptr);
+ Ret->SetVerifyCallback(SSLVerifyCert, nullptr);
+ mbedtls_debug_set_threshold(2);
+ #endif
+
+ /*
+ // Set ciphersuite to the easiest one to decode, so that the connection can be wireshark-decoded:
+ Ret->SetCipherSuites(
+ {
+ MBEDTLS_TLS_RSA_WITH_RC4_128_MD5,
+ MBEDTLS_TLS_RSA_WITH_RC4_128_SHA,
+ MBEDTLS_TLS_RSA_WITH_AES_128_CBC_SHA
+ }
+ );
+ */
+ #endif
+
+ return Ret;
+}
+
+
+
+
+
+std::shared_ptr<const cSslConfig> cSslConfig::GetDefaultClientConfig()
+{
+ static const std::shared_ptr<const cSslConfig> ClientConfig = MakeDefaultConfig(true);
+ return ClientConfig;
+}
+
+
+
+
+
+std::shared_ptr<const cSslConfig> cSslConfig::GetDefaultServerConfig()
+{
+ static const std::shared_ptr<const cSslConfig> ServerConfig = MakeDefaultConfig(false);
+ return ServerConfig;
+}
+
+
+
+
diff --git a/src/mbedTLS++/SslConfig.h b/src/mbedTLS++/SslConfig.h
new file mode 100644
index 000000000..47a4f7b59
--- /dev/null
+++ b/src/mbedTLS++/SslConfig.h
@@ -0,0 +1,93 @@
+
+#pragma once
+
+#include "mbedtls/ssl.h"
+
+// fwd:
+class cCryptoKey;
+class cCtrDrbgContext;
+class cX509Cert;
+
+using cCryptoKeyPtr = std::shared_ptr<cCryptoKey>;
+using cCtrDrbgContextPtr = std::shared_ptr<cCtrDrbgContext>;
+using cX509CertPtr = std::shared_ptr<cX509Cert>;
+
+enum class eSslAuthMode
+{
+ None = 0, // MBEDTLS_SSL_VERIFY_NONE
+ Optional = 1, // MBEDTLS_SSL_VERIFY_OPTIONAL
+ Required = 2, // MBEDTLS_SSL_VERIFY_REQUIRED
+ Unset = 3, // MBEDTLS_SSL_VERIFY_UNSET
+};
+
+
+
+class cSslConfig
+{
+ friend class cSslContext;
+public:
+ /** Type of the SSL debug callback.
+ Parameters are:
+ void * Opaque context for the callback
+ int Debug level
+ const char * File name
+ int Line number
+ const char * Message */
+ using cDebugCallback = void(*)(void *, int, const char *, int, const char *);
+
+ /** Type of the SSL certificate verify callback.
+ Parameters are:
+ void * Opaque context for the callback
+ mbedtls_x509_crt * Current cert
+ int Cert chain depth
+ uint32_t * Verification flags */
+ using cVerifyCallback = int(*)(void *, mbedtls_x509_crt *, int, uint32_t *);
+
+ cSslConfig();
+ ~cSslConfig();
+
+ /** Initialize with mbedTLS default settings. */
+ int InitDefaults(bool a_IsClient);
+
+ /** Set the authorization mode. */
+ void SetAuthMode(eSslAuthMode a_AuthMode);
+
+ /** Set the random number generator. */
+ void SetRng(cCtrDrbgContextPtr a_CtrDrbg);
+
+ /** Set the debug callback. */
+ void SetDebugCallback(cDebugCallback a_CallbackFun, void * a_CallbackData);
+
+ /** Set the certificate verify callback. */
+ void SetVerifyCallback(cVerifyCallback a_CallbackFun, void * a_CallbackData);
+
+ /** Set the enabled cipher suites. */
+ void SetCipherSuites(std::vector<int> a_CipherSuites);
+
+ /** Set the certificate to use for connections. */
+ void SetOwnCert(cX509CertPtr a_OwnCert, cCryptoKeyPtr a_OwnCertPrivKey);
+
+ /** Set the trusted certificate authority chain. */
+ void SetCACerts(cX509CertPtr a_CACert);
+
+ /** Creates a new config with some sensible defaults on top of mbedTLS basic settings. */
+ static std::shared_ptr<cSslConfig> MakeDefaultConfig(bool a_IsClient);
+
+ /** Returns the default config for client connections. */
+ static std::shared_ptr<const cSslConfig> GetDefaultClientConfig();
+
+ /** Returns the default config for server connections. */
+ static std::shared_ptr<const cSslConfig> GetDefaultServerConfig();
+
+private:
+
+ /** Returns a pointer to the wrapped mbedtls representation. */
+ const mbedtls_ssl_config * GetInternal() const { return &m_Config; }
+
+ mbedtls_ssl_config m_Config;
+ cCtrDrbgContextPtr m_CtrDrbg;
+ cX509CertPtr m_OwnCert;
+ cCryptoKeyPtr m_OwnCertPrivKey;
+ cX509CertPtr m_CACerts;
+ std::vector<int> m_CipherSuites;
+};
diff --git a/src/mbedTLS++/SslContext.cpp b/src/mbedTLS++/SslContext.cpp
new file mode 100644
index 000000000..e86da3fd2
--- /dev/null
+++ b/src/mbedTLS++/SslContext.cpp
@@ -0,0 +1,157 @@
+
+// SslContext.cpp
+
+// Implements the cSslContext class that holds everything a single SSL context needs to function
+
+#include "Globals.h"
+#include "mbedTLS++/SslContext.h"
+#include "mbedTLS++/SslConfig.h"
+
+
+
+
+
+cSslContext::cSslContext(void) :
+ m_IsValid(false),
+ m_HasHandshaken(false)
+{
+ mbedtls_ssl_init(&m_Ssl);
+}
+
+
+
+
+
+cSslContext::~cSslContext()
+{
+ mbedtls_ssl_free(&m_Ssl);
+}
+
+
+
+
+
+int cSslContext::Initialize(std::shared_ptr<const cSslConfig> a_Config)
+{
+ // Check double-initialization:
+ if (m_IsValid)
+ {
+ LOGWARNING("SSL: Double initialization is not supported.");
+ return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; // There is no return value well-suited for this, reuse this one.
+ }
+
+ // Check the Config:
+ m_Config = a_Config;
+ if (m_Config == nullptr)
+ {
+ ASSERT(!"Config must not be nullptr");
+ return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
+ }
+
+ // Apply the configuration to the ssl context
+ int res = mbedtls_ssl_setup(&m_Ssl, m_Config->GetInternal());
+ if (res != 0)
+ {
+ return res;
+ }
+
+ // Set the io callbacks
+ mbedtls_ssl_set_bio(&m_Ssl, this, SendEncrypted, ReceiveEncrypted, nullptr);
+
+ m_IsValid = true;
+ return 0;
+}
+
+
+
+
+
+int cSslContext::Initialize(bool a_IsClient)
+{
+ if (a_IsClient)
+ {
+ return Initialize(cSslConfig::GetDefaultClientConfig());
+ }
+ else
+ {
+ return Initialize(cSslConfig::GetDefaultServerConfig());
+ }
+}
+
+
+
+
+
+void cSslContext::SetExpectedPeerName(const AString & a_ExpectedPeerName)
+{
+ ASSERT(m_IsValid); // Call Initialize() first
+ mbedtls_ssl_set_hostname(&m_Ssl, a_ExpectedPeerName.c_str());
+}
+
+
+
+
+
+int cSslContext::WritePlain(const void * a_Data, size_t a_NumBytes)
+{
+ ASSERT(m_IsValid); // Need to call Initialize() first
+ if (!m_HasHandshaken)
+ {
+ int res = Handshake();
+ if (res != 0)
+ {
+ return res;
+ }
+ }
+
+ return mbedtls_ssl_write(&m_Ssl, reinterpret_cast<const unsigned char *>(a_Data), a_NumBytes);
+}
+
+
+
+
+
+int cSslContext::ReadPlain(void * a_Data, size_t a_MaxBytes)
+{
+ ASSERT(m_IsValid); // Need to call Initialize() first
+ if (!m_HasHandshaken)
+ {
+ int res = Handshake();
+ if (res != 0)
+ {
+ return res;
+ }
+ }
+
+ return mbedtls_ssl_read(&m_Ssl, reinterpret_cast<unsigned char *>(a_Data), a_MaxBytes);
+}
+
+
+
+
+
+int cSslContext::Handshake(void)
+{
+ ASSERT(m_IsValid); // Need to call Initialize() first
+ ASSERT(!m_HasHandshaken); // Must not call twice
+
+ int res = mbedtls_ssl_handshake(&m_Ssl);
+ if (res == 0)
+ {
+ m_HasHandshaken = true;
+ }
+ return res;
+}
+
+
+
+
+
+int cSslContext::NotifyClose(void)
+{
+ return mbedtls_ssl_close_notify(&m_Ssl);
+}
+
+
+
+
diff --git a/src/mbedTLS++/SslContext.h b/src/mbedTLS++/SslContext.h
new file mode 100644
index 000000000..c51a9f149
--- /dev/null
+++ b/src/mbedTLS++/SslContext.h
@@ -0,0 +1,124 @@
+
+// SslContext.h
+
+// Declares the cSslContext class that holds everything a single SSL context needs to function
+
+
+
+
+
+#pragma once
+
+#include "mbedtls/ssl.h"
+#include "../ByteBuffer.h"
+
+
+
+
+
+// fwd:
+class cCtrDrbgContext;
+class cSslConfig;
+
+
+
+
+
+/**
+Acts as a generic SSL encryptor / decryptor between the two endpoints. The "owner" of this class is expected
+to create it, initialize it and then provide the means of reading and writing data through the SSL link.
+This is an abstract base class, there are descendants that handle the specific aspects of how the SSL peer
+data comes into the system:
+ - cBufferedSslContext uses a cByteBuffer to read and write the data
+ - cCallbackSslContext uses callbacks to provide the data
+*/
+class cSslContext abstract
+{
+public:
+ /** Creates a new uninitialized context */
+ cSslContext(void);
+
+ virtual ~cSslContext();
+
+ /** Initializes the context for use as a server or client.
+ a_Config must not be nullptr and the config must not be changed after this call.
+ Returns 0 on success, mbedTLS error on failure. */
+ int Initialize(std::shared_ptr<const cSslConfig> a_Config);
+
+ /** Initializes the context using the default config. */
+ int Initialize(bool a_IsClient);
+
+ /** Returns true if the object has been initialized properly. */
+ bool IsValid(void) const { return m_IsValid; }
+
+ /** Sets the SSL peer name expected for this context. Must be called after Initialize().
+ \param a_ExpectedPeerName CommonName that we expect the SSL peer to have in its cert,
+ if it is different, the verification will fail. An empty string will disable the CN check. */
+ void SetExpectedPeerName(const AString & a_ExpectedPeerName);
+
+ /** Writes data to be encrypted and sent to the SSL peer. Will perform SSL handshake, if needed.
+ Returns the number of bytes actually written, or mbedTLS error code.
+ If the return value is MBEDTLS_ERR_SSL_WANT_READ or MBEDTLS_ERR_SSL_WANT_WRITE, the owner should send any
+ cached outgoing data to the SSL peer and write any incoming data received from the SSL peer and then call
+ this function again with the same parameters. Note that this may repeat a few times before the data is
+ actually written, mainly due to initial handshake. */
+ int WritePlain(const void * a_Data, size_t a_NumBytes);
+
+ /** Reads data decrypted from the SSL stream. Will perform SSL handshake, if needed.
+ Returns the number of bytes actually read, or mbedTLS error code.
+ If the return value is MBEDTLS_ERR_SSL_WANT_READ or MBEDTLS_ERR_SSL_WANT_WRITE, the owner should send any
+ cached outgoing data to the SSL peer and write any incoming data received from the SSL peer and then call
+ this function again with the same parameters. Note that this may repeat a few times before the data is
+ actually read, mainly due to initial handshake. */
+ int ReadPlain(void * a_Data, size_t a_MaxBytes);
+
+ /** Performs the SSL handshake.
+ Returns zero on success, mbedTLS error code on failure.
+ If the return value is MBEDTLS_ERR_SSL_WANT_READ or MBEDTLS_ERR_SSL_WANT_WRITE, the owner should send any
+ cached outgoing data to the SSL peer and write any incoming data received from the SSL peer and then call
+ this function again. Note that this may repeat a few times before the handshake is completed. */
+ int Handshake(void);
+
+ /** Returns true if the SSL handshake has been completed. */
+ bool HasHandshaken(void) const { return m_HasHandshaken; }
+
+ /** Notifies the SSL peer that the connection is being closed.
+ Returns 0 on success, mbedTLS error code on failure. */
+ int NotifyClose(void);
+
+protected:
+
+ /** Configuration of the SSL context. */
+ std::shared_ptr<const cSslConfig> m_Config;
+
+ /** The SSL context that mbedTLS uses. */
+ mbedtls_ssl_context m_Ssl;
+
+ /** True if the object has been initialized properly. */
+ bool m_IsValid;
+
+ /** True if the SSL handshake has been completed. */
+ bool m_HasHandshaken;
+
+ /** The callback used by mbedTLS when it wants to read encrypted data. */
+ static int ReceiveEncrypted(void * a_This, unsigned char * a_Buffer, size_t a_NumBytes)
+ {
+ return (reinterpret_cast<cSslContext *>(a_This))->ReceiveEncrypted(a_Buffer, a_NumBytes);
+ }
+
+ /** The callback used by mbedTLS when it wants to write encrypted data. */
+ static int SendEncrypted(void * a_This, const unsigned char * a_Buffer, size_t a_NumBytes)
+ {
+ return (reinterpret_cast<cSslContext *>(a_This))->SendEncrypted(a_Buffer, a_NumBytes);
+ }
+
+ /** Called when mbedTLS wants to read encrypted data. */
+ virtual int ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) = 0;
+
+ /** Called when mbedTLS wants to write encrypted data. */
+ virtual int SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) = 0;
+} ;
+
+
+
+
diff --git a/src/mbedTLS++/X509Cert.cpp b/src/mbedTLS++/X509Cert.cpp
new file mode 100644
index 000000000..1e51dd2b7
--- /dev/null
+++ b/src/mbedTLS++/X509Cert.cpp
@@ -0,0 +1,38 @@
+
+// X509Cert.cpp
+
+// Implements the cX509Cert class representing a wrapper over X509 certs in mbedTLS
+
+#include "Globals.h"
+#include "X509Cert.h"
+
+
+
+
+
+cX509Cert::cX509Cert(void)
+{
+ mbedtls_x509_crt_init(&m_Cert);
+}
+
+
+
+
+
+cX509Cert::~cX509Cert()
+{
+ mbedtls_x509_crt_free(&m_Cert);
+}
+
+
+
+
+
+int cX509Cert::Parse(const void * a_CertContents, size_t a_Size)
+{
+ return mbedtls_x509_crt_parse(&m_Cert, reinterpret_cast<const unsigned char *>(a_CertContents), a_Size);
+}
+
+
+
+
diff --git a/src/mbedTLS++/X509Cert.h b/src/mbedTLS++/X509Cert.h
new file mode 100644
index 000000000..4234308ff
--- /dev/null
+++ b/src/mbedTLS++/X509Cert.h
@@ -0,0 +1,41 @@
+
+// X509Cert.h
+
+// Declares the cX509Cert class representing a wrapper over X509 certs in mbedTLS
+
+
+
+
+
+#pragma once
+
+#include "mbedtls/x509_crt.h"
+
+
+
+
+
+class cX509Cert
+{
+ friend class cSslConfig;
+
+public:
+ cX509Cert(void);
+ ~cX509Cert(void);
+
+ /** Parses the certificate chain data into the context.
+ Returns 0 on succes, or mbedTLS error code on failure. */
+ int Parse(const void * a_CertContents, size_t a_Size);
+
+protected:
+ mbedtls_x509_crt m_Cert;
+
+ /** Returns the internal cert ptr. Only use in mbedTLS API calls. */
+ mbedtls_x509_crt * GetInternal(void) { return &m_Cert; }
+} ;
+
+typedef std::shared_ptr<cX509Cert> cX509CertPtr;
+
+
+
+