summaryrefslogblamecommitdiffstats
path: root/src/Bindings/LuaTCPLink.cpp
blob: 466240a7d7350b6ba1d7d7fbf50bd805f4043fe0 (plain) (tree)
1
2
3
4
5
6
7
8






                                                                                                                
                            




 

                                                                 
 





 

                                                                                                      

                                           








                                               

                            
         
                              


                     







                                              






                                                                              
                                          

                            




                             
                                  








                                           

                            




                              
                                  








                                            

                            




                         
                                    








                                            

                                  




                              
                                   








                                             

                                  




                         
                                     







                                
                                                           

                                  
         





                                                    
                                 
         







                             
                                               

                                  
         





                                                    
                              








                     



















                                                                                         
                                                       




















                                                                                                                       



















































                                                                                                               









                                  


                                  
                                   
         
                                     






                                                                                

         
                                               
         

                                    
                 
                                      

                                       
         


                                                    

 



 

                                                                               
                             
                                                                                      





 

                                                
                             
                                                      







                                                                      
                             
                                                                           
 
                                                
                     

















                                                                      
                                                                             

                                       
         
                                                                


                       
                             
                                                                                    







                                      






                                                                                      
                             
                                                         
 
                                                
                     




 












                                                                                
                                                                      
















                                                  

                                                                                            


                                                            














                                                           


                                                            










                                                                                    




                                                                        








                                                             


                                                            



















                                                                           


                                                            


















                                                                                               


                                                            

























                                                                                                  

// LuaTCPLink.cpp

// Implements the cLuaTCPLink class representing a Lua wrapper for the cTCPLink class and the callbacks it needs

#include "Globals.h"
#include "LuaTCPLink.h"
#include "LuaServerHandle.h"





/** Creates a link wrapper that only takes ownership of the Lua callbacks table.
No server handle is stored by this overload. */
cLuaTCPLink::cLuaTCPLink(cLuaState::cTableRefPtr && a_Callbacks)
	: m_Callbacks(std::move(a_Callbacks))
{
	// Nothing else to set up, the link itself arrives later through OnLinkCreated()
}





/** Creates a link wrapper that takes ownership of the Lua callbacks table
and remembers the server handle that is notified when the link terminates. */
cLuaTCPLink::cLuaTCPLink(cLuaState::cTableRefPtr && a_Callbacks, cLuaServerHandleWPtr a_ServerHandle)
	: m_Callbacks(std::move(a_Callbacks))
	, m_Server(std::move(a_ServerHandle))
{
	// Nothing else to set up, the link itself arrives later through OnLinkCreated()
}





/** Closes the underlying link, if one is still stored, and then runs the common termination cleanup. */
cLuaTCPLink::~cLuaTCPLink()
{
	// Work on a local copy of the pointer while closing:
	auto OpenLink = m_Link;
	if (OpenLink != nullptr)
	{
		OpenLink->Close();
	}

	// Release the callbacks, the server registration, the link and the SSL context:
	Terminated();
}





bool cLuaTCPLink::Send(const AString & a_Data)
{
	// If running in SSL mode, push the data into the SSL context instead:
	if (m_SslContext != nullptr)
	{
		m_SslContext->Send(a_Data);
		return true;
	}

	// Safely grab a copy of the link:
	auto link = m_Link;
	if (link == nullptr)
	{
		return false;
	}

	// Send the data:
	return link->Send(a_Data);
}





/** Returns the local IP address reported by the link, or an empty string if there is no link. */
AString cLuaTCPLink::GetLocalIP(void) const
{
	// Query through a local copy of the link pointer:
	auto CurrentLink = m_Link;
	if (CurrentLink != nullptr)
	{
		return CurrentLink->GetLocalIP();
	}

	// No link, no address:
	return "";
}





/** Returns the local port reported by the link, or 0 if there is no link. */
UInt16 cLuaTCPLink::GetLocalPort(void) const
{
	// Query through a local copy of the link pointer:
	auto CurrentLink = m_Link;
	if (CurrentLink != nullptr)
	{
		return CurrentLink->GetLocalPort();
	}

	// No link, no port:
	return 0;
}





/** Returns the remote peer's IP address reported by the link, or an empty string if there is no link. */
AString cLuaTCPLink::GetRemoteIP(void) const
{
	// Query through a local copy of the link pointer:
	cTCPLinkPtr CurrentLink = m_Link;
	if (CurrentLink != nullptr)
	{
		return CurrentLink->GetRemoteIP();
	}

	// No link, no address:
	return "";
}





/** Returns the remote peer's port reported by the link, or 0 if there is no link. */
UInt16 cLuaTCPLink::GetRemotePort(void) const
{
	// Query through a local copy of the link pointer:
	cTCPLinkPtr CurrentLink = m_Link;
	if (CurrentLink != nullptr)
	{
		return CurrentLink->GetRemotePort();
	}

	// No link, no port:
	return 0;
}





void cLuaTCPLink::Shutdown(void)
{
	// Safely grab a copy of the link and shut it down:
	cTCPLinkPtr link = m_Link;
	if (link != nullptr)
	{
		if (m_SslContext != nullptr)
		{
			m_SslContext->NotifyClose();
			m_SslContext->ResetSelf();
			m_SslContext.reset();
		}
		link->Shutdown();
	}
}





void cLuaTCPLink::Close(void)
{
	// If the link is still open, close it:
	cTCPLinkPtr link = m_Link;
	if (link != nullptr)
	{
		if (m_SslContext != nullptr)
		{
			m_SslContext->NotifyClose();
			m_SslContext->ResetSelf();
			m_SslContext.reset();
		}
		link->Close();
	}

	Terminated();
}





/** Switches the link into TLS client mode and starts the handshake.
a_OwnCertData and a_OwnPrivKeyData hold the optional client certificate and its private key;
either both are given or both are empty. a_OwnPrivKeyPassword is passed to the private key parser.
Returns an empty string once the handshake has been started, or an error message otherwise. */
AString cLuaTCPLink::StartTLSClient(
	const AString & a_OwnCertData,
	const AString & a_OwnPrivKeyData,
	const AString & a_OwnPrivKeyPassword
)
{
	// Check preconditions:
	if (m_SslContext != nullptr)
	{
		return "TLS is already active on this link";
	}
	if (m_Link == nullptr)
	{
		// The handshake sends data through the link, so it cannot be started without one
		// (the link is reset in Terminated(), e. g. after Close())
		return "Cannot start TLS, the link is not open";
	}
	const bool HasCert = !a_OwnCertData.empty();
	const bool HasPrivKey = !a_OwnPrivKeyData.empty();
	if (HasCert != HasPrivKey)
	{
		return "Either provide both the certificate and private key, or neither";
	}

	// Create the SSL context:
	m_SslContext.reset(new cLinkSslContext(*this));
	m_SslContext->Initialize(true);

	// Create the peer cert, if required:
	if (HasCert)
	{
		// NOTE(review): the "-0x%x" format assumes the parsers report failures as negative error codes,
		// so the code is negated to print its magnitude rather than the two's-complement bit pattern
		auto OwnCert = std::make_shared<cX509Cert>();
		int res = OwnCert->Parse(a_OwnCertData.data(), a_OwnCertData.size());
		if (res != 0)
		{
			m_SslContext.reset();
			return Printf("Cannot parse peer certificate: -0x%x", static_cast<unsigned>(-res));
		}
		auto OwnPrivKey = std::make_shared<cCryptoKey>();
		res = OwnPrivKey->ParsePrivate(a_OwnPrivKeyData.data(), a_OwnPrivKeyData.size(), a_OwnPrivKeyPassword);
		if (res != 0)
		{
			m_SslContext.reset();
			return Printf("Cannot parse peer private key: -0x%x", static_cast<unsigned>(-res));
		}
		m_SslContext->SetOwnCert(OwnCert, OwnPrivKey);
	}

	// Let the context know its own shared pointer, its methods use it to keep themselves alive:
	m_SslContext->SetSelf(cLinkSslContextWPtr(m_SslContext));

	// Start the handshake:
	m_SslContext->Handshake();
	return "";
}





/** Switches the link into TLS server mode and starts the handshake.
a_OwnCertData and a_OwnPrivKeyData hold the server certificate and its private key, both are required.
a_OwnPrivKeyPassword is passed to the private key parser.
a_StartTLSData is stored as already-received data before the handshake starts.
Returns an empty string once the handshake has been started, or an error message otherwise. */
AString cLuaTCPLink::StartTLSServer(
	const AString & a_OwnCertData,
	const AString & a_OwnPrivKeyData,
	const AString & a_OwnPrivKeyPassword,
	const AString & a_StartTLSData
)
{
	// Check preconditions:
	if (m_SslContext != nullptr)
	{
		return "TLS is already active on this link";
	}
	if (m_Link == nullptr)
	{
		// The handshake sends data through the link, so it cannot be started without one
		// (the link is reset in Terminated(), e. g. after Close())
		return "Cannot start TLS, the link is not open";
	}
	if (a_OwnCertData.empty()  || a_OwnPrivKeyData.empty())
	{
		return "Provide the server certificate and private key";
	}

	// Create the SSL context:
	m_SslContext.reset(new cLinkSslContext(*this));
	m_SslContext->Initialize(false);

	// Create the peer cert:
	// NOTE(review): the "-0x%x" format assumes the parsers report failures as negative error codes,
	// so the code is negated to print its magnitude rather than the two's-complement bit pattern
	auto OwnCert = std::make_shared<cX509Cert>();
	int res = OwnCert->Parse(a_OwnCertData.data(), a_OwnCertData.size());
	if (res != 0)
	{
		m_SslContext.reset();
		return Printf("Cannot parse server certificate: -0x%x", static_cast<unsigned>(-res));
	}
	auto OwnPrivKey = std::make_shared<cCryptoKey>();
	res = OwnPrivKey->ParsePrivate(a_OwnPrivKeyData.data(), a_OwnPrivKeyData.size(), a_OwnPrivKeyPassword);
	if (res != 0)
	{
		m_SslContext.reset();
		return Printf("Cannot parse server private key: -0x%x", static_cast<unsigned>(-res));
	}
	m_SslContext->SetOwnCert(OwnCert, OwnPrivKey);

	// Let the context know its own shared pointer, its methods use it to keep themselves alive:
	m_SslContext->SetSelf(cLinkSslContextWPtr(m_SslContext));

	// Push the initial data:
	m_SslContext->StoreReceivedData(a_StartTLSData.data(), a_StartTLSData.size());

	// Start the handshake:
	m_SslContext->Handshake();
	return "";
}





/** Common cleanup run when the link stops being usable.
Clears the Lua callbacks, unregisters this link from the managing server (if that still exists),
closes and releases the underlying link and releases the SSL context.
Called from the destructor, Close(), OnError() and OnRemoteClosed(). */
void cLuaTCPLink::Terminated(void)
{
	// Disable the callbacks:
	if (m_Callbacks->IsValid())
	{
		m_Callbacks->Clear();
	}

	// If the managing server is still alive, let it know we're terminating:
	// (m_Server is only set by the two-parameter constructor; lock() yields nullptr otherwise)
	auto Server = m_Server.lock();
	if (Server != nullptr)
	{
		Server->RemoveLink(this);
	}

	// If the link is still open, close it:
	{
		// Close through a local copy, then drop the member so that later calls see no link
		auto link= m_Link;
		if (link != nullptr)
		{
			link->Close();
			m_Link.reset();
		}
	}

	// If the SSL context still exists, free it:
	m_SslContext.reset();
}





/** Hands a_NumBytes bytes of decrypted data, starting at a_Data, to the Lua "OnReceivedData" callback.
Called by the SSL context when it has cleartext data available. */
void cLuaTCPLink::ReceivedCleartextData(const char * a_Data, size_t a_NumBytes)
{
	// The data is passed to Lua as a string:
	m_Callbacks->CallTableFn(
		"OnReceivedData",
		this,
		AString(a_Data, a_NumBytes)
	);
}





/** Forwards the connection notification to the Lua "OnConnected" callback.
a_Link is not used; the callback receives this wrapper instead. */
void cLuaTCPLink::OnConnected(cTCPLink & a_Link)
{
	// Notify the Lua side:
	m_Callbacks->CallTableFn(
		"OnConnected",
		this
	);
}





/** Reports the error code and message to the Lua "OnError" callback and then terminates the link. */
void cLuaTCPLink::OnError(int a_ErrorCode, const AString & a_ErrorMsg)
{
	// Notify the Lua side first, the callbacks get cleared by the termination below:
	m_Callbacks->CallTableFn(
		"OnError",
		this,
		a_ErrorCode,
		a_ErrorMsg
	);

	// No further processing is done on the link after an error:
	Terminated();
}





/** Remembers the newly created cTCPLink so that the other methods can operate on it. */
void cLuaTCPLink::OnLinkCreated(cTCPLinkPtr a_Link)
{
	// The parameter is our own copy of the pointer, so it can be moved into the member:
	m_Link = std::move(a_Link);
}





void cLuaTCPLink::OnReceivedData(const char * a_Data, size_t a_Length)
{
	// If we're running in SSL mode, put the data into the SSL decryptor:
	auto sslContext = m_SslContext;
	if (sslContext != nullptr)
	{
		sslContext->StoreReceivedData(a_Data, a_Length);
		return;
	}

	// Call the callback:
	m_Callbacks->CallTableFn("OnReceivedData", this, AString(a_Data, a_Length));
}





void cLuaTCPLink::OnRemoteClosed(void)
{
	// If running in SSL mode and there's data left in the SSL contect, report it:
	auto sslContext = m_SslContext;
	if (sslContext != nullptr)
	{
		sslContext->FlushBuffers();
	}

	// Call the callback:
	m_Callbacks->CallTableFn("OnRemoteClosed", this);

	// Terminate all processing on the link:
	Terminated();
}





////////////////////////////////////////////////////////////////////////////////
// cLuaTCPLink::cLinkSslContext:

/** Creates an SSL context bound to the specified link wrapper; the wrapper is stored by reference. */
cLuaTCPLink::cLinkSslContext::cLinkSslContext(cLuaTCPLink & a_Link)
	: m_Link(a_Link)
{
	// m_Self is set separately through SetSelf()
}





/** Stores the weak pointer to this context; the other methods use it to keep the context alive while they run. */
void cLuaTCPLink::cLinkSslContext::SetSelf(cLinkSslContextWPtr a_Self)
{
	// The parameter is our own copy of the pointer, so it can be moved into the member:
	m_Self = std::move(a_Self);
}





void cLuaTCPLink::cLinkSslContext::ResetSelf(void)
{
	m_Self.reset();
}





/** Queues a_NumBytes bytes of encrypted data, starting at a_Data, that were received from the link,
then tries to advance the handshake and report any cleartext data that became available. */
void cLuaTCPLink::cLinkSslContext::StoreReceivedData(const char * a_Data, size_t a_NumBytes)
{
	// Keep this object alive until the function returns:
	cLinkSslContextPtr KeepAlive(m_Self);

	// The data waits in the buffer until ReceiveEncrypted() picks it up:
	m_EncryptedData.append(a_Data, a_NumBytes);

	// The new data may allow a pending handshake to complete:
	TryFinishHandshaking();

	// Report whatever cleartext is available now:
	FlushBuffers();
}





/** Reads all currently available cleartext data and reports it to the link wrapper in chunks of up to 1024 bytes.
Does nothing until the handshake has completed. Stops early if a callback resets the context's self-pointer. */
void cLuaTCPLink::cLinkSslContext::FlushBuffers(void)
{
	// Keep this object alive until the function returns:
	cLinkSslContextPtr KeepAlive(m_Self);

	// No cleartext can be read before the handshake completes:
	if (!HasHandshaken())
	{
		return;
	}

	char Chunk[1024];
	for (;;)
	{
		// Stop once there's nothing more to read (or reading fails):
		int NumRead = ReadPlain(Chunk, sizeof(Chunk));
		if (NumRead <= 0)
		{
			break;
		}

		m_Link.ReceivedCleartextData(Chunk, static_cast<size_t>(NumRead));

		// The callback may have closed the SSL context, in which case the self-pointer has been reset:
		if (m_Self.expired())
		{
			break;
		}
	}
}





/** Retries the handshake if it hasn't completed yet.
Once the handshake is complete, writes out the cleartext data queued by Send() and empties the queue. */
void cLuaTCPLink::cLinkSslContext::TryFinishHandshaking(void)
{
	// Keep this object alive until the function returns:
	cLinkSslContextPtr KeepAlive(m_Self);

	// Give a pending handshake another try:
	if (!HasHandshaken())
	{
		Handshake();
	}

	// Still not done, the queued data has to wait:
	if (!HasHandshaken())
	{
		return;
	}

	// The handshake has completed, write all the queued cleartext data:
	WritePlain(m_CleartextData.data(), m_CleartextData.size());
	m_CleartextData.clear();
}





/** Sends cleartext data through the SSL context.
Before the handshake completes, the data is queued and the handshake is retried;
afterwards the data is written directly and the buffers are flushed. */
void cLuaTCPLink::cLinkSslContext::Send(const AString & a_Data)
{
	// Keep this object alive until the function returns:
	cLinkSslContextPtr KeepAlive(m_Self);

	if (HasHandshaken())
	{
		// The connection is all set up, write the cleartext data into the SSL context:
		WritePlain(a_Data.data(), a_Data.size());
		FlushBuffers();
		return;
	}

	// The handshake is still pending, queue the data and try to move the handshake along:
	m_CleartextData.append(a_Data);
	TryFinishHandshaking();
}





/** Moves up to a_NumBytes bytes of queued encrypted data into a_Buffer.
Returns the number of bytes copied, or POLARSSL_ERR_NET_WANT_READ if no data is queued. */
int cLuaTCPLink::cLinkSslContext::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes)
{
	// Keep this object alive until the function returns:
	cLinkSslContextPtr KeepAlive(m_Self);

	// Nothing queued, report that more data needs to be read first:
	const size_t NumQueued = m_EncryptedData.size();
	if (NumQueued == 0)
	{
		return POLARSSL_ERR_NET_WANT_READ;
	}

	// Hand over as much as fits into the caller's buffer and remove it from the queue:
	const size_t NumToCopy = (NumQueued < a_NumBytes) ? NumQueued : a_NumBytes;
	memcpy(a_Buffer, m_EncryptedData.data(), NumToCopy);
	m_EncryptedData.erase(0, NumToCopy);
	return static_cast<int>(NumToCopy);
}





int cLuaTCPLink::cLinkSslContext::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes)
{
	m_Link.m_Link->Send(a_Buffer, a_NumBytes);
	return static_cast<int>(a_NumBytes);
}