From 2d58789d66f1b63ad63304584c7ac43284b540b8 Mon Sep 17 00:00:00 2001 From: Mattes D Date: Wed, 6 Jul 2016 20:52:04 +0200 Subject: Converted cLuaState::cTableRef to use cTrackedRef. This makes the table-based callbacks resistent to LuaState unloads and safer to use. --- src/Bindings/LuaNameLookup.cpp | 50 +------- src/Bindings/LuaNameLookup.h | 11 +- src/Bindings/LuaServerHandle.cpp | 54 ++------- src/Bindings/LuaServerHandle.h | 11 +- src/Bindings/LuaState.cpp | 161 +++++++++++++++++++------ src/Bindings/LuaState.h | 201 ++++++++++++++++++++++---------- src/Bindings/LuaTCPLink.cpp | 146 +++++++---------------- src/Bindings/LuaTCPLink.h | 13 +-- src/Bindings/LuaUDPEndpoint.cpp | 89 +++++--------- src/Bindings/LuaUDPEndpoint.h | 11 +- src/Bindings/ManualBindings.cpp | 52 +++++---- src/Bindings/ManualBindings_Network.cpp | 121 ++++++++++--------- 12 files changed, 457 insertions(+), 463 deletions(-) diff --git a/src/Bindings/LuaNameLookup.cpp b/src/Bindings/LuaNameLookup.cpp index 3cbdbb5cf..3f55e7bc8 100644 --- a/src/Bindings/LuaNameLookup.cpp +++ b/src/Bindings/LuaNameLookup.cpp @@ -10,9 +10,8 @@ -cLuaNameLookup::cLuaNameLookup(const AString & a_Query, cPluginLua & a_Plugin, int a_CallbacksTableStackPos): - m_Plugin(a_Plugin), - m_Callbacks(cPluginLua::cOperation(a_Plugin)(), a_CallbacksTableStackPos), +cLuaNameLookup::cLuaNameLookup(const AString & a_Query, cLuaState::cTableRefPtr && a_Callbacks): + m_Callbacks(std::move(a_Callbacks)), m_Query(a_Query) { } @@ -23,20 +22,7 @@ cLuaNameLookup::cLuaNameLookup(const AString & a_Query, cPluginLua & a_Plugin, i void cLuaNameLookup::OnNameResolved(const AString & a_Name, const AString & a_IP) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - - // Call the callback: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnNameResolved"), a_Name, a_IP)) - { - LOGINFO("cNetwork name lookup OnNameResolved callback failed in plugin %s looking up %s. %s resolves to %s.", - m_Plugin.GetName().c_str(), m_Query.c_str(), a_Name.c_str(), a_IP.c_str() - ); - } + m_Callbacks->CallTableFn("OnNameResolved", a_Name, a_IP); } @@ -45,20 +31,7 @@ void cLuaNameLookup::OnNameResolved(const AString & a_Name, const AString & a_IP void cLuaNameLookup::OnError(int a_ErrorCode, const AString & a_ErrorMsg) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - - // Call the callback: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnError"), m_Query, a_ErrorCode, a_ErrorMsg)) - { - LOGINFO("cNetwork name lookup OnError callback failed in plugin %s looking up %s. The error is %d (%s)", - m_Plugin.GetName().c_str(), m_Query.c_str(), a_ErrorCode, a_ErrorMsg.c_str() - ); - } + m_Callbacks->CallTableFn("OnError", m_Query, a_ErrorCode, a_ErrorMsg); } @@ -67,20 +40,7 @@ void cLuaNameLookup::OnError(int a_ErrorCode, const AString & a_ErrorMsg) void cLuaNameLookup::OnFinished(void) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - - // Call the callback: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnFinished"), m_Query)) - { - LOGINFO("cNetwork name lookup OnFinished callback failed in plugin %s, looking up %s.", - m_Plugin.GetName().c_str(), m_Query.c_str() - ); - } + m_Callbacks->CallTableFn("OnFinished", m_Query); } diff --git a/src/Bindings/LuaNameLookup.h b/src/Bindings/LuaNameLookup.h index e4cdb9f53..0eef108c8 100644 --- a/src/Bindings/LuaNameLookup.h +++ b/src/Bindings/LuaNameLookup.h @@ -10,7 +10,7 @@ #pragma once #include "../OSSupport/Network.h" -#include "PluginLua.h" +#include "LuaState.h" @@ -21,15 +21,12 @@ class cLuaNameLookup: { public: /** Creates a new instance of the lookup callbacks for the specified query, - attached to the specified lua plugin and wrapping the callbacks that are in a table at the specified stack pos. */ - cLuaNameLookup(const AString & a_Query, cPluginLua & a_Plugin, int a_CallbacksTableStackPos); + using the callbacks that are in the specified table. */ + cLuaNameLookup(const AString & a_Query, cLuaState::cTableRefPtr && a_Callbacks); protected: - /** The plugin for which the query is created. */ - cPluginLua & m_Plugin; - /** The Lua table that holds the callbacks to be invoked. */ - cLuaState::cRef m_Callbacks; + cLuaState::cTableRefPtr m_Callbacks; /** The query used to start the lookup (either hostname or IP). */ AString m_Query; diff --git a/src/Bindings/LuaServerHandle.cpp b/src/Bindings/LuaServerHandle.cpp index 9cc8ad350..d32b8fe03 100644 --- a/src/Bindings/LuaServerHandle.cpp +++ b/src/Bindings/LuaServerHandle.cpp @@ -12,9 +12,8 @@ -cLuaServerHandle::cLuaServerHandle(UInt16 a_Port, cPluginLua & a_Plugin, int a_CallbacksTableStackPos): - m_Plugin(a_Plugin), - m_Callbacks(cPluginLua::cOperation(a_Plugin)(), a_CallbacksTableStackPos), +cLuaServerHandle::cLuaServerHandle(UInt16 a_Port, cLuaState::cTableRefPtr && a_Callbacks): + m_Callbacks(std::move(a_Callbacks)), m_Port(a_Port) { } @@ -127,28 +126,19 @@ void cLuaServerHandle::Release(void) cTCPLink::cCallbacksPtr cLuaServerHandle::OnIncomingConnection(const AString & a_RemoteIPAddress, UInt16 a_RemotePort) { - // If not valid anymore, drop the connection: - if (!m_Callbacks.IsValid()) - { - return nullptr; - } - // Ask the plugin for link callbacks: - cPluginLua::cOperation Op(m_Plugin); - cLuaState::cRef LinkCallbacks; + cLuaState::cTableRefPtr LinkCallbacks; if ( - !Op().Call(cLuaState::cTableRef(m_Callbacks, "OnIncomingConnection"), a_RemoteIPAddress, a_RemotePort, m_Port, cLuaState::Return, LinkCallbacks) || - !LinkCallbacks.IsValid() + !m_Callbacks->CallTableFn("OnIncomingConnection", a_RemoteIPAddress, a_RemotePort, m_Port, cLuaState::Return, LinkCallbacks) || + !LinkCallbacks->IsValid() ) { - LOGINFO("cNetwork server (port %d) OnIncomingConnection callback failed in plugin %s. Dropping connection.", - m_Port, m_Plugin.GetName().c_str() - ); + LOGINFO("cNetwork server (port %d) OnIncomingConnection callback failed. Dropping connection.", m_Port); return nullptr; } // Create the link wrapper to use with the callbacks: - auto res = std::make_shared(m_Plugin, std::move(LinkCallbacks), m_Self); + auto res = std::make_shared(std::move(LinkCallbacks), m_Self); // Add the link to the list of our connections: cCSLock Lock(m_CSConnections); @@ -163,21 +153,8 @@ cTCPLink::cCallbacksPtr cLuaServerHandle::OnIncomingConnection(const AString & a void cLuaServerHandle::OnAccepted(cTCPLink & a_Link) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - // Notify the plugin: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnAccepted"), static_cast(a_Link.GetCallbacks().get()))) - { - LOGINFO("cNetwork server (port %d) OnAccepted callback failed in plugin %s, connection to %s:%d.", - m_Port, m_Plugin.GetName().c_str(), a_Link.GetRemoteIP().c_str(), a_Link.GetRemotePort() - ); - return; - } + m_Callbacks->CallTableFn("OnAccepted", static_cast(a_Link.GetCallbacks().get())); } @@ -186,21 +163,8 @@ void cLuaServerHandle::OnAccepted(cTCPLink & a_Link) void cLuaServerHandle::OnError(int a_ErrorCode, const AString & a_ErrorMsg) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - // Notify the plugin: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnError"), a_ErrorCode, a_ErrorMsg)) - { - LOGINFO("cNetwork server (port %d) OnError callback failed in plugin %s. The error is %d (%s).", - m_Port, m_Plugin.GetName().c_str(), a_ErrorCode, a_ErrorMsg.c_str() - ); - return; - } + m_Callbacks->CallTableFn("OnError", a_ErrorCode, a_ErrorMsg); } diff --git a/src/Bindings/LuaServerHandle.h b/src/Bindings/LuaServerHandle.h index 9325bca3e..7632669ff 100644 --- a/src/Bindings/LuaServerHandle.h +++ b/src/Bindings/LuaServerHandle.h @@ -10,7 +10,7 @@ #pragma once #include "../OSSupport/Network.h" -#include "PluginLua.h" +#include "LuaState.h" @@ -31,8 +31,8 @@ class cLuaServerHandle: { public: /** Creates a new instance of the server handle, - attached to the specified lua plugin and wrapping the (listen-) callbacks that are in a table at the specified stack pos. */ - cLuaServerHandle(UInt16 a_Port, cPluginLua & a_Plugin, int a_CallbacksTableStackPos); + wrapping the (listen-) callbacks that are in the specified table. */ + cLuaServerHandle(UInt16 a_Port, cLuaState::cTableRefPtr && a_Callbacks); ~cLuaServerHandle(); @@ -54,11 +54,8 @@ public: void Release(void); protected: - /** The plugin for which the server is created. */ - cPluginLua & m_Plugin; - /** The Lua table that holds the callbacks to be invoked. */ - cLuaState::cRef m_Callbacks; + cLuaState::cTableRefPtr m_Callbacks; /** The port on which the server is listening. Used mainly for better error reporting. */ diff --git a/src/Bindings/LuaState.cpp b/src/Bindings/LuaState.cpp index bc9447cc2..e6a94091e 100644 --- a/src/Bindings/LuaState.cpp +++ b/src/Bindings/LuaState.cpp @@ -122,9 +122,9 @@ cLuaStateTracker & cLuaStateTracker::Get(void) //////////////////////////////////////////////////////////////////////////////// -// cLuaState::cCallback: +// cLuaState::cTrackedRef: -cLuaState::cCallback::cCallback(void): +cLuaState::cTrackedRef::cTrackedRef(void): m_CS(nullptr) { } @@ -133,20 +133,14 @@ cLuaState::cCallback::cCallback(void): -bool cLuaState::cCallback::RefStack(cLuaState & a_LuaState, int a_StackPos) +bool cLuaState::cTrackedRef::RefStack(cLuaState & a_LuaState, int a_StackPos) { - // Check if the stack contains a function: - if (!lua_isfunction(a_LuaState, a_StackPos)) - { - return false; - } - // Clear any previous callback: Clear(); // Add self to LuaState's callback-tracking: auto canonState = a_LuaState.QueryCanonLuaState(); - canonState->TrackCallback(*this); + canonState->TrackRef(*this); // Store the new callback: m_CS = &(canonState->m_CS); @@ -158,9 +152,9 @@ bool cLuaState::cCallback::RefStack(cLuaState & a_LuaState, int a_StackPos) -void cLuaState::cCallback::Clear(void) +void cLuaState::cTrackedRef::Clear(void) { - // Free the callback reference: + // Free the reference: lua_State * luaState = nullptr; { auto cs = m_CS; @@ -175,20 +169,21 @@ void cLuaState::cCallback::Clear(void) m_Ref.UnRef(); } } + m_CS = nullptr; // Remove from LuaState's callback-tracking: if (luaState == nullptr) { return; } - cLuaState(luaState).UntrackCallback(*this); + cLuaState(luaState).UntrackRef(*this); } -bool cLuaState::cCallback::IsValid(void) +bool cLuaState::cTrackedRef::IsValid(void) { auto cs = m_CS; if (cs == nullptr) @@ -203,7 +198,7 @@ bool cLuaState::cCallback::IsValid(void) -bool cLuaState::cCallback::IsSameLuaState(cLuaState & a_LuaState) +bool cLuaState::cTrackedRef::IsSameLuaState(cLuaState & a_LuaState) { auto cs = m_CS; if (cs == nullptr) @@ -227,7 +222,7 @@ bool cLuaState::cCallback::IsSameLuaState(cLuaState & a_LuaState) -void cLuaState::cCallback::Invalidate(void) +void cLuaState::cTrackedRef::Invalidate(void) { auto cs = m_CS; if (cs == nullptr) @@ -244,6 +239,43 @@ void cLuaState::cCallback::Invalidate(void) return; } m_Ref.UnRef(); + m_CS = nullptr; +} + + + + + +//////////////////////////////////////////////////////////////////////////////// +// cLuaState::cCallback: + +bool cLuaState::cCallback::RefStack(cLuaState & a_LuaState, int a_StackPos) +{ + // Check if the stack contains a function: + if (!lua_isfunction(a_LuaState, a_StackPos)) + { + return false; + } + + return Super::RefStack(a_LuaState, a_StackPos); +} + + + + + +//////////////////////////////////////////////////////////////////////////////// +// cLuaState::cTableRef: + +bool cLuaState::cTableRef::RefStack(cLuaState & a_LuaState, int a_StackPos) +{ + // Check if the stack contains a table: + if (!lua_istable(a_LuaState, a_StackPos)) + { + return false; + } + + return Super::RefStack(a_LuaState, a_StackPos); } @@ -365,12 +397,12 @@ void cLuaState::Close(void) return; } - // Invalidate all callbacks: + // Invalidate all tracked refs: { - cCSLock Lock(m_CSTrackedCallbacks); - for (auto & c: m_TrackedCallbacks) + cCSLock Lock(m_CSTrackedRefs); + for (auto & r: m_TrackedRefs) { - c->Invalidate(); + r->Invalidate(); } } @@ -592,22 +624,28 @@ bool cLuaState::PushFunction(const cRef & a_FnRef) -bool cLuaState::PushFunction(const cTableRef & a_TableRef) +bool cLuaState::PushFunction(const cRef & a_TableRef, const char * a_FnName) { ASSERT(IsValid()); ASSERT(m_NumCurrentFunctionArgs == -1); // If not, there's already something pushed onto the stack + if (!a_TableRef.IsValid()) + { + return false; + } + // Push the error handler for lua_pcall() lua_pushcfunction(m_LuaState, &ReportFnCallErrors); - lua_rawgeti(m_LuaState, LUA_REGISTRYINDEX, a_TableRef.GetTableRef()); // Get the table ref + // Get the function from the table: + lua_rawgeti(m_LuaState, LUA_REGISTRYINDEX, static_cast(a_TableRef)); if (!lua_istable(m_LuaState, -1)) { // Not a table, bail out lua_pop(m_LuaState, 2); return false; } - lua_getfield(m_LuaState, -1, a_TableRef.GetFnName()); + lua_getfield(m_LuaState, -1, a_FnName); if (lua_isnil(m_LuaState, -1) || !lua_isfunction(m_LuaState, -1)) { // Not a valid function, bail out @@ -618,7 +656,7 @@ bool cLuaState::PushFunction(const cTableRef & a_TableRef) // Pop the table off the stack: lua_remove(m_LuaState, -2); - Printf(m_CurrentFunctionName, "", a_TableRef.GetFnName()); + Printf(m_CurrentFunctionName, "", a_FnName); m_NumCurrentFunctionArgs = 0; return true; } @@ -1061,6 +1099,28 @@ bool cLuaState::GetStackValue(int a_StackPos, cCallbackSharedPtr & a_Callback) +bool cLuaState::GetStackValue(int a_StackPos, cTableRef & a_TableRef) +{ + return a_TableRef.RefStack(*this, a_StackPos); +} + + + + + +bool cLuaState::GetStackValue(int a_StackPos, cTableRefPtr & a_TableRef) +{ + if (a_TableRef == nullptr) + { + a_TableRef = cpp14::make_unique(); + } + return a_TableRef->RefStack(*this, a_StackPos); +} + + + + + bool cLuaState::GetStackValue(int a_StackPos, cPluginManager::CommandResult & a_Result) { if (lua_isnumber(m_LuaState, a_StackPos)) @@ -1085,6 +1145,41 @@ bool cLuaState::GetStackValue(int a_StackPos, cRef & a_Ref) +bool cLuaState::GetStackValue(int a_StackPos, cTrackedRef & a_Ref) +{ + return a_Ref.RefStack(*this, a_StackPos); +} + + + + + +bool cLuaState::GetStackValue(int a_StackPos, cTrackedRefPtr & a_Ref) +{ + if (a_Ref == nullptr) + { + a_Ref = cpp14::make_unique(); + } + return a_Ref->RefStack(*this, a_StackPos); +} + + + + + +bool cLuaState::GetStackValue(int a_StackPos, cTrackedRefSharedPtr & a_Ref) +{ + if (a_Ref == nullptr) + { + a_Ref = std::make_shared(); + } + return a_Ref->RefStack(*this, a_StackPos); +} + + + + + bool cLuaState::GetStackValue(int a_StackPos, double & a_ReturnedVal) { if (lua_isnumber(m_LuaState, a_StackPos)) @@ -1930,7 +2025,7 @@ int cLuaState::BreakIntoDebugger(lua_State * a_LuaState) -void cLuaState::TrackCallback(cCallback & a_Callback) +void cLuaState::TrackRef(cTrackedRef & a_Ref) { // Get the CanonLuaState global from Lua: auto canonState = QueryCanonLuaState(); @@ -1941,15 +2036,15 @@ void cLuaState::TrackCallback(cCallback & a_Callback) } // Add the callback: - cCSLock Lock(canonState->m_CSTrackedCallbacks); - canonState->m_TrackedCallbacks.push_back(&a_Callback); + cCSLock Lock(canonState->m_CSTrackedRefs); + canonState->m_TrackedRefs.push_back(&a_Ref); } -void cLuaState::UntrackCallback(cCallback & a_Callback) +void cLuaState::UntrackRef(cTrackedRef & a_Ref) { // Get the CanonLuaState global from Lua: auto canonState = QueryCanonLuaState(); @@ -1960,12 +2055,12 @@ void cLuaState::UntrackCallback(cCallback & a_Callback) } // Remove the callback: - cCSLock Lock(canonState->m_CSTrackedCallbacks); - auto & trackedCallbacks = canonState->m_TrackedCallbacks; - trackedCallbacks.erase(std::remove_if(trackedCallbacks.begin(), trackedCallbacks.end(), - [&a_Callback](cCallback * a_StoredCallback) + cCSLock Lock(canonState->m_CSTrackedRefs); + auto & trackedRefs = canonState->m_TrackedRefs; + trackedRefs.erase(std::remove_if(trackedRefs.begin(), trackedRefs.end(), + [&a_Ref](cTrackedRef * a_StoredRef) { - return (a_StoredCallback == &a_Callback); + return (a_StoredRef == &a_Ref); } )); } diff --git a/src/Bindings/LuaState.h b/src/Bindings/LuaState.h index c34feca9d..bc88fbf1b 100644 --- a/src/Bindings/LuaState.h +++ b/src/Bindings/LuaState.h @@ -120,46 +120,84 @@ public: } ; - /** Used for calling functions stored in a reference-stored table */ - class cTableRef + /** Represents a reference to a Lua object that has a tracked lifetime - + - when the Lua state to which it belongs is closed, the object is kept alive, but invalidated. + Is thread-safe and unload-safe. + To receive the cTrackedRef instance from the Lua side, use RefStack() or (better) cLuaState::GetStackValue(). + Note that instances of this class are tracked in the canon LuaState instance, so that + they can be invalidated when the LuaState is unloaded; due to multithreading issues they can only be tracked + by-ptr, which has an unfortunate effect of disabling the copy and move constructors. */ + class cTrackedRef { - int m_TableRef; - const char * m_FnName; + friend class ::cLuaState; public: - cTableRef(int a_TableRef, const char * a_FnName) : - m_TableRef(a_TableRef), - m_FnName(a_FnName) - { - } + /** Creates an unbound ref instance. */ + cTrackedRef(void); - cTableRef(const cRef & a_TableRef, const char * a_FnName) : - m_TableRef(static_cast(a_TableRef)), - m_FnName(a_FnName) + ~cTrackedRef() { + Clear(); } - int GetTableRef(void) const { return m_TableRef; } - const char * GetFnName(void) const { return m_FnName; } - } ; + /** Set the contained reference to the object at the specified Lua state's stack position. + If another reference has been previously contained, it is freed first. */ + bool RefStack(cLuaState & a_LuaState, int a_StackPos); + + /** Frees the contained reference, if any. */ + void Clear(void); + + /** Returns true if the contained reference is valid. */ + bool IsValid(void); + + /** Returns true if the reference resides in the specified Lua state. + Internally, compares the reference's canon Lua state. */ + bool IsSameLuaState(cLuaState & a_LuaState); + + protected: + friend class cLuaState; + + /** The mutex protecting m_Ref against multithreaded access */ + cCriticalSection * m_CS; + + /** Reference to the Lua callback */ + cRef m_Ref; + + /** Invalidates the callback, without untracking it from the cLuaState. + Called only from cLuaState when closing the Lua state. */ + void Invalidate(void); + + /** Returns the internal reference. + Only to be used when the cLuaState's CS is held and the cLuaState is known to be valid. */ + cRef & GetRef() { return m_Ref; } + + /** This class cannot be copied, because it is tracked in the LuaState by-ptr. + Use a smart pointer for a copyable object. */ + cTrackedRef(const cTrackedRef &) = delete; - /** Represents a callback to Lua that C++ code can call. + /** This class cannot be moved, because it is tracked in the LuaState by-ptr. + Use a smart pointer for a copyable object. */ + cTrackedRef(cTrackedRef &&) = delete; + }; + typedef UniquePtr cTrackedRefPtr; + typedef SharedPtr cTrackedRefSharedPtr; + + + /** Represents a stored callback to Lua that C++ code can call. Is thread-safe and unload-safe. When the Lua state is unloaded, the callback returns an error instead of calling into non-existent code. To receive the callback instance from the Lua side, use RefStack() or (better) cLuaState::GetStackValue() with a cCallbackPtr. Note that instances of this class are tracked in the canon LuaState instance, so that they can be invalidated when the LuaState is unloaded; due to multithreading issues they can only be tracked by-ptr, which has an unfortunate effect of disabling the copy and move constructors. */ - class cCallback + class cCallback: + public cTrackedRef { + typedef cTrackedRef Super; + public: - /** Creates an unbound callback instance. */ - cCallback(void); - ~cCallback() - { - Clear(); - } + cCallback(void) {} /** Calls the Lua callback, if still available. Returns true if callback has been called. @@ -181,32 +219,11 @@ public: } /** Set the contained callback to the function in the specified Lua state's stack position. - If a callback has been previously contained, it is freed first. */ + If a callback has been previously contained, it is unreferenced first. + Returns true on success, false on failure (not a function at the specified stack pos). */ bool RefStack(cLuaState & a_LuaState, int a_StackPos); - /** Frees the contained callback, if any. */ - void Clear(void); - - /** Returns true if the contained callback is valid. */ - bool IsValid(void); - - /** Returns true if the callback resides in the specified Lua state. - Internally, compares the callback's canon Lua state. */ - bool IsSameLuaState(cLuaState & a_LuaState); - protected: - friend class cLuaState; - - /** The mutex protecting m_Ref against multithreaded access */ - cCriticalSection * m_CS; - - /** Reference to the Lua callback */ - cRef m_Ref; - - - /** Invalidates the callback, without untracking it from the cLuaState. - Called only from cLuaState when closing the Lua state. */ - void Invalidate(void); /** This class cannot be copied, because it is tracked in the LuaState by-ptr. Use cCallbackPtr for a copyable object. */ @@ -220,6 +237,47 @@ public: typedef SharedPtr cCallbackSharedPtr; + /** Represents a stored Lua table with callback functions that C++ code can call. + Is thread-safe and unload-safe. + When Lua state is unloaded, the CallFn() will return failure instead of calling into non-existent code. + To receive the callback instance from the Lua side, use RefStack() or (better) cLuaState::GetStackValue() + with a cCallbackPtr. Note that instances of this class are tracked in the canon LuaState instance, so that + they can be invalidated when the LuaState is unloaded; due to multithreading issues they can only be tracked + by-ptr, which has an unfortunate effect of disabling the copy and move constructors. */ + class cTableRef: + public cTrackedRef + { + typedef cTrackedRef Super; + public: + cTableRef(void) {} + + /** Calls the Lua function stored under the specified name in the referenced table, if still available. + Returns true if callback has been called. + Returns false if the Lua state isn't valid anymore, or the function doesn't exist. */ + template + bool CallTableFn(const char * a_FnName, Args &&... args) + { + auto cs = m_CS; + if (cs == nullptr) + { + return false; + } + cCSLock Lock(*cs); + if (!m_Ref.IsValid()) + { + return false; + } + return cLuaState(m_Ref.GetLuaState()).CallTableFn(m_Ref, a_FnName, std::forward(args)...); + } + + /** Set the contained reference to the table in the specified Lua state's stack position. + If another table has been previously contained, it is unreferenced first. + Returns true on success, false on failure (not a table at the specified stack pos). */ + bool RefStack(cLuaState & a_LuaState, int a_StackPos); + }; + typedef UniquePtr cTableRefPtr; + + /** A dummy class that's used only to delimit function args from return values for cLuaState::Call() */ class cRet { @@ -381,8 +439,13 @@ public: bool GetStackValue(int a_StackPos, cCallback & a_Callback); bool GetStackValue(int a_StackPos, cCallbackPtr & a_Callback); bool GetStackValue(int a_StackPos, cCallbackSharedPtr & a_Callback); + bool GetStackValue(int a_StackPos, cTableRef & a_TableRef); + bool GetStackValue(int a_StackPos, cTableRefPtr & a_TableRef); bool GetStackValue(int a_StackPos, cPluginManager::CommandResult & a_Result); bool GetStackValue(int a_StackPos, cRef & a_Ref); + bool GetStackValue(int a_StackPos, cTrackedRef & a_Ref); + bool GetStackValue(int a_StackPos, cTrackedRefPtr & a_Ref); + bool GetStackValue(int a_StackPos, cTrackedRefSharedPtr & a_Ref); bool GetStackValue(int a_StackPos, double & a_Value); bool GetStackValue(int a_StackPos, eBlockFace & a_Value); bool GetStackValue(int a_StackPos, eWeather & a_Value); @@ -583,15 +646,30 @@ protected: /** Number of arguments currently pushed (for the Push / Call chain) */ int m_NumCurrentFunctionArgs; - /** The tracked callbacks. - This object will invalidate all of these when it is about to be closed. - Protected against multithreaded access by m_CSTrackedCallbacks. */ - std::vector m_TrackedCallbacks; + /** The tracked references. + The cLuaState will invalidate all of these when it is about to be closed. + Protected against multithreaded access by m_CSTrackedRefs. */ + std::vector m_TrackedRefs; - /** Protects m_TrackedTallbacks against multithreaded access. */ - cCriticalSection m_CSTrackedCallbacks; + /** Protects m_TrackedRefs against multithreaded access. */ + cCriticalSection m_CSTrackedRefs; + /** Call the Lua function specified by name in the table stored as a reference. + Returns true if call succeeded, false if there was an error (not a table ref, function name not found). + A special param of cRet & signifies the end of param list and the start of return values. + Example call: CallTableFn(TableRef, "FnName", Param1, Param2, Param3, cLuaState::Return, Ret1, Ret2) */ + template + bool CallTableFn(const cRef & a_TableRef, const char * a_FnName, Args &&... args) + { + if (!PushFunction(a_TableRef, a_FnName)) + { + // Pushing the function failed + return false; + } + return PushCallPop(std::forward(args)...); + } + /** Variadic template terminator: If there's nothing more to push / pop, just call the function. Note that there are no return values either, because those are prefixed by a cRet value, so the arg list is never empty. */ bool PushCallPop(void) @@ -646,10 +724,9 @@ protected: */ bool PushFunction(const cRef & a_FnRef); - /** Pushes a function that is stored in a referenced table by name - Returns true if successful. Logs a warning on failure - */ - bool PushFunction(const cTableRef & a_TableRef); + /** Pushes a function that is stored under the specified name in a table that has been saved as a reference. + Returns true if successful. */ + bool PushFunction(const cRef & a_TableRef, const char * a_FnName); /** Pushes a usertype of the specified class type onto the stack */ // void PushUserType(void * a_Object, const char * a_Type); @@ -667,13 +744,13 @@ protected: /** Tries to break into the MobDebug debugger, if it is installed. */ static int BreakIntoDebugger(lua_State * a_LuaState); - /** Adds the specified callback to tracking. - The callback will be invalidated when this Lua state is about to be closed. */ - void TrackCallback(cCallback & a_Callback); + /** Adds the specified reference to tracking. + The reference will be invalidated when this Lua state is about to be closed. */ + void TrackRef(cTrackedRef & a_Ref); - /** Removes the specified callback from tracking. - The callback will no longer be invalidated when this Lua state is about to be closed. */ - void UntrackCallback(cCallback & a_Callback); + /** Removes the specified reference from tracking. + The reference will no longer be invalidated when this Lua state is about to be closed. */ + void UntrackRef(cTrackedRef & a_Ref); } ; diff --git a/src/Bindings/LuaTCPLink.cpp b/src/Bindings/LuaTCPLink.cpp index 4b04a1c02..466240a7d 100644 --- a/src/Bindings/LuaTCPLink.cpp +++ b/src/Bindings/LuaTCPLink.cpp @@ -11,35 +11,19 @@ -cLuaTCPLink::cLuaTCPLink(cPluginLua & a_Plugin, int a_CallbacksTableStackPos): - m_Plugin(a_Plugin), - m_Callbacks(cPluginLua::cOperation(a_Plugin)(), a_CallbacksTableStackPos) +cLuaTCPLink::cLuaTCPLink(cLuaState::cTableRefPtr && a_Callbacks): + m_Callbacks(std::move(a_Callbacks)) { - // Warn if the callbacks aren't valid: - if (!m_Callbacks.IsValid()) - { - LOGWARNING("cTCPLink in plugin %s: callbacks could not be retrieved", m_Plugin.GetName().c_str()); - cPluginLua::cOperation Op(m_Plugin); - Op().LogStackTrace(); - } } -cLuaTCPLink::cLuaTCPLink(cPluginLua & a_Plugin, cLuaState::cRef && a_CallbacksTableRef, cLuaServerHandleWPtr a_ServerHandle): - m_Plugin(a_Plugin), - m_Callbacks(std::move(a_CallbacksTableRef)), +cLuaTCPLink::cLuaTCPLink(cLuaState::cTableRefPtr && a_Callbacks, cLuaServerHandleWPtr a_ServerHandle): + m_Callbacks(std::move(a_Callbacks)), m_Server(std::move(a_ServerHandle)) { - // Warn if the callbacks aren't valid: - if (!m_Callbacks.IsValid()) - { - LOGWARNING("cTCPLink in plugin %s: callbacks could not be retrieved", m_Plugin.GetName().c_str()); - cPluginLua::cOperation Op(m_Plugin); - Op().LogStackTrace(); - } } @@ -49,10 +33,10 @@ cLuaTCPLink::cLuaTCPLink(cPluginLua & a_Plugin, cLuaState::cRef && a_CallbacksTa cLuaTCPLink::~cLuaTCPLink() { // If the link is still open, close it: - cTCPLinkPtr Link = m_Link; - if (Link != nullptr) + auto link = m_Link; + if (link != nullptr) { - Link->Close(); + link->Close(); } Terminated(); @@ -72,14 +56,14 @@ bool cLuaTCPLink::Send(const AString & a_Data) } // Safely grab a copy of the link: - cTCPLinkPtr Link = m_Link; - if (Link == nullptr) + auto link = m_Link; + if (link == nullptr) { return false; } // Send the data: - return Link->Send(a_Data); + return link->Send(a_Data); } @@ -89,14 +73,14 @@ bool cLuaTCPLink::Send(const AString & a_Data) AString cLuaTCPLink::GetLocalIP(void) const { // Safely grab a copy of the link: - cTCPLinkPtr Link = m_Link; - if (Link == nullptr) + auto link = m_Link; + if (link == nullptr) { return ""; } // Get the IP address: - return Link->GetLocalIP(); + return link->GetLocalIP(); } @@ -106,14 +90,14 @@ AString cLuaTCPLink::GetLocalIP(void) const UInt16 cLuaTCPLink::GetLocalPort(void) const { // Safely grab a copy of the link: - cTCPLinkPtr Link = m_Link; - if (Link == nullptr) + auto link = m_Link; + if (link == nullptr) { return 0; } // Get the port: - return Link->GetLocalPort(); + return link->GetLocalPort(); } @@ -123,14 +107,14 @@ UInt16 cLuaTCPLink::GetLocalPort(void) const AString cLuaTCPLink::GetRemoteIP(void) const { // Safely grab a copy of the link: - cTCPLinkPtr Link = m_Link; - if (Link == nullptr) + cTCPLinkPtr link = m_Link; + if (link == nullptr) { return ""; } // Get the IP address: - return Link->GetRemoteIP(); + return link->GetRemoteIP(); } @@ -140,14 +124,14 @@ AString cLuaTCPLink::GetRemoteIP(void) const UInt16 cLuaTCPLink::GetRemotePort(void) const { // Safely grab a copy of the link: - cTCPLinkPtr Link = m_Link; - if (Link == nullptr) + cTCPLinkPtr link = m_Link; + if (link == nullptr) { return 0; } // Get the port: - return Link->GetRemotePort(); + return link->GetRemotePort(); } @@ -157,8 +141,8 @@ UInt16 cLuaTCPLink::GetRemotePort(void) const void cLuaTCPLink::Shutdown(void) { // Safely grab a copy of the link and shut it down: - cTCPLinkPtr Link = m_Link; - if (Link != nullptr) + cTCPLinkPtr link = m_Link; + if (link != nullptr) { if (m_SslContext != nullptr) { @@ -166,7 +150,7 @@ void cLuaTCPLink::Shutdown(void) m_SslContext->ResetSelf(); m_SslContext.reset(); } - Link->Shutdown(); + link->Shutdown(); } } @@ -177,8 +161,8 @@ void cLuaTCPLink::Shutdown(void) void cLuaTCPLink::Close(void) { // If the link is still open, close it: - cTCPLinkPtr Link = m_Link; - if (Link != nullptr) + cTCPLinkPtr link = m_Link; + if (link != nullptr) { if (m_SslContext != nullptr) { @@ -186,7 +170,7 @@ void cLuaTCPLink::Close(void) m_SslContext->ResetSelf(); m_SslContext.reset(); } - Link->Close(); + link->Close(); } Terminated(); @@ -303,9 +287,9 @@ AString cLuaTCPLink::StartTLSServer( void cLuaTCPLink::Terminated(void) { // Disable the callbacks: - if (m_Callbacks.IsValid()) + if (m_Callbacks->IsValid()) { - m_Callbacks.UnRef(); + m_Callbacks->Clear(); } // If the managing server is still alive, let it know we're terminating: @@ -317,10 +301,10 @@ void cLuaTCPLink::Terminated(void) // If the link is still open, close it: { - cTCPLinkPtr Link = m_Link; - if (Link != nullptr) + auto link= m_Link; + if (link != nullptr) { - Link->Close(); + link->Close(); m_Link.reset(); } } @@ -335,18 +319,8 @@ void cLuaTCPLink::Terminated(void) void cLuaTCPLink::ReceivedCleartextData(const char * a_Data, size_t a_NumBytes) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - // Call the callback: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnReceivedData"), this, AString(a_Data, a_NumBytes))) - { - LOGINFO("cTCPLink OnReceivedData callback failed in plugin %s.", m_Plugin.GetName().c_str()); - } + m_Callbacks->CallTableFn("OnReceivedData", this, AString(a_Data, a_NumBytes)); } @@ -355,18 +329,8 @@ void cLuaTCPLink::ReceivedCleartextData(const char * a_Data, size_t a_NumBytes) void cLuaTCPLink::OnConnected(cTCPLink & a_Link) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - // Call the callback: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnConnected"), this)) - { - LOGINFO("cTCPLink OnConnected() callback failed in plugin %s.", m_Plugin.GetName().c_str()); - } + m_Callbacks->CallTableFn("OnConnected", this); } @@ -375,21 +339,10 @@ void cLuaTCPLink::OnConnected(cTCPLink & a_Link) void cLuaTCPLink::OnError(int a_ErrorCode, const AString & a_ErrorMsg) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - // Call the callback: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnError"), this, a_ErrorCode, a_ErrorMsg)) - { - LOGINFO("cTCPLink OnError() callback failed in plugin %s; the link error is %d (%s).", - m_Plugin.GetName().c_str(), a_ErrorCode, a_ErrorMsg.c_str() - ); - } + m_Callbacks->CallTableFn("OnError", this, a_ErrorCode, a_ErrorMsg); + // Terminate all processing on the link: Terminated(); } @@ -409,12 +362,6 @@ void cLuaTCPLink::OnLinkCreated(cTCPLinkPtr a_Link) void cLuaTCPLink::OnReceivedData(const char * a_Data, size_t a_Length) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - // If we're running in SSL mode, put the data into the SSL decryptor: auto sslContext = m_SslContext; if (sslContext != nullptr) @@ -424,11 +371,7 @@ void cLuaTCPLink::OnReceivedData(const char * a_Data, size_t a_Length) } // Call the callback: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnReceivedData"), this, AString(a_Data, a_Length))) - { - LOGINFO("cTCPLink OnReceivedData callback failed in plugin %s.", m_Plugin.GetName().c_str()); - } + m_Callbacks->CallTableFn("OnReceivedData", this, AString(a_Data, a_Length)); } @@ -437,12 +380,6 @@ void cLuaTCPLink::OnReceivedData(const char * a_Data, size_t a_Length) void cLuaTCPLink::OnRemoteClosed(void) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - // If running in SSL mode and there's data left in the SSL contect, report it: auto sslContext = m_SslContext; if (sslContext != nullptr) @@ -451,12 +388,9 @@ void cLuaTCPLink::OnRemoteClosed(void) } // Call the callback: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnRemoteClosed"), this)) - { - LOGINFO("cTCPLink OnRemoteClosed() callback failed in plugin %s.", m_Plugin.GetName().c_str()); - } + m_Callbacks->CallTableFn("OnRemoteClosed", this); + // Terminate all processing on the link: Terminated(); } diff --git a/src/Bindings/LuaTCPLink.h b/src/Bindings/LuaTCPLink.h index c8ae776fe..b8c886ef8 100644 --- a/src/Bindings/LuaTCPLink.h +++ b/src/Bindings/LuaTCPLink.h @@ -10,8 +10,8 @@ #pragma once #include "../OSSupport/Network.h" -#include "PluginLua.h" #include "../PolarSSL++/SslContext.h" +#include "LuaState.h" @@ -30,11 +30,11 @@ class cLuaTCPLink: public cTCPLink::cCallbacks { public: - /** Creates a new instance of the link, attached to the specified plugin and wrapping the callbacks that are in a table at the specified stack pos. */ - cLuaTCPLink(cPluginLua & a_Plugin, int a_CallbacksTableStackPos); + /** Creates a new instance of the link, wrapping the callbacks that are in the specified table. */ + cLuaTCPLink(cLuaState::cTableRefPtr && a_Callbacks); /** Creates a new instance of the link, attached to the specified plugin and wrapping the callbacks that are in the specified referenced table. */ - cLuaTCPLink(cPluginLua & a_Plugin, cLuaState::cRef && a_CallbacksTableRef, cLuaServerHandleWPtr a_Server); + cLuaTCPLink(cLuaState::cTableRefPtr && a_Callbacks, cLuaServerHandleWPtr a_Server); ~cLuaTCPLink(); @@ -139,11 +139,8 @@ protected: }; - /** The plugin for which the link is created. */ - cPluginLua & m_Plugin; - /** The Lua table that holds the callbacks to be invoked. */ - cLuaState::cRef m_Callbacks; + cLuaState::cTableRefPtr m_Callbacks; /** The underlying link representing the connection. May be nullptr. */ diff --git a/src/Bindings/LuaUDPEndpoint.cpp b/src/Bindings/LuaUDPEndpoint.cpp index ed8f4e87f..779efcbce 100644 --- a/src/Bindings/LuaUDPEndpoint.cpp +++ b/src/Bindings/LuaUDPEndpoint.cpp @@ -10,17 +10,9 @@ -cLuaUDPEndpoint::cLuaUDPEndpoint(cPluginLua & a_Plugin, int a_CallbacksTableStackPos): - m_Plugin(a_Plugin), - m_Callbacks(cPluginLua::cOperation(a_Plugin)(), a_CallbacksTableStackPos) +cLuaUDPEndpoint::cLuaUDPEndpoint(cLuaState::cTableRefPtr && a_Callbacks): + m_Callbacks(std::move(a_Callbacks)) { - // Warn if the callbacks aren't valid: - if (!m_Callbacks.IsValid()) - { - LOGWARNING("cLuaUDPEndpoint in plugin %s: callbacks could not be retrieved", m_Plugin.GetName().c_str()); - cPluginLua::cOperation Op(m_Plugin); - Op().LogStackTrace(); - } } @@ -30,10 +22,10 @@ cLuaUDPEndpoint::cLuaUDPEndpoint(cPluginLua & a_Plugin, int a_CallbacksTableStac cLuaUDPEndpoint::~cLuaUDPEndpoint() { // If the endpoint is still open, close it: - cUDPEndpointPtr Endpoint = m_Endpoint; - if (Endpoint != nullptr) + auto endpoint = m_Endpoint; + if (endpoint != nullptr) { - Endpoint->Close(); + endpoint->Close(); } Terminated(); @@ -60,14 +52,14 @@ bool cLuaUDPEndpoint::Open(UInt16 a_Port, cLuaUDPEndpointPtr a_Self) bool cLuaUDPEndpoint::Send(const AString & a_Data, const AString & a_RemotePeer, UInt16 a_RemotePort) { // Safely grab a copy of the endpoint: - cUDPEndpointPtr Endpoint = m_Endpoint; - if (Endpoint == nullptr) + auto endpoint = m_Endpoint; + if (endpoint == nullptr) { return false; } // Send the data: - return Endpoint->Send(a_Data, a_RemotePeer, a_RemotePort); + return endpoint->Send(a_Data, a_RemotePeer, a_RemotePort); } @@ -77,14 +69,14 @@ bool cLuaUDPEndpoint::Send(const AString & a_Data, const AString & a_RemotePeer, UInt16 cLuaUDPEndpoint::GetPort(void) const { // Safely grab a copy of the endpoint: - cUDPEndpointPtr Endpoint = m_Endpoint; - if (Endpoint == nullptr) + auto endpoint = m_Endpoint; + if (endpoint == nullptr) { return 0; } // Get the port: - return Endpoint->GetPort(); + return endpoint->GetPort(); } @@ -94,15 +86,15 @@ UInt16 cLuaUDPEndpoint::GetPort(void) const bool cLuaUDPEndpoint::IsOpen(void) const { // Safely grab a copy of the endpoint: - cUDPEndpointPtr Endpoint = m_Endpoint; - if (Endpoint == nullptr) + auto endpoint = m_Endpoint; + if (endpoint == nullptr) { // No endpoint means that we're not open return false; } // Get the state: - return Endpoint->IsOpen(); + return endpoint->IsOpen(); } @@ -112,10 +104,10 @@ bool cLuaUDPEndpoint::IsOpen(void) const void cLuaUDPEndpoint::Close(void) { // If the endpoint is still open, close it: - cUDPEndpointPtr Endpoint = m_Endpoint; - if (Endpoint != nullptr) + auto endpoint = m_Endpoint; + if (endpoint != nullptr) { - Endpoint->Close(); + endpoint->Close(); m_Endpoint.reset(); } @@ -129,10 +121,10 @@ void cLuaUDPEndpoint::Close(void) void cLuaUDPEndpoint::EnableBroadcasts(void) { // Safely grab a copy of the endpoint: - cUDPEndpointPtr Endpoint = m_Endpoint; - if (Endpoint != nullptr) + auto endpoint = m_Endpoint; + if (endpoint != nullptr) { - Endpoint->EnableBroadcasts(); + endpoint->EnableBroadcasts(); } } @@ -156,17 +148,14 @@ void cLuaUDPEndpoint::Release(void) void cLuaUDPEndpoint::Terminated(void) { // Disable the callbacks: - if (m_Callbacks.IsValid()) - { - m_Callbacks.UnRef(); - } + m_Callbacks.reset(); // If the endpoint is still open, close it: { - cUDPEndpointPtr Endpoint = m_Endpoint; - if (Endpoint != nullptr) + auto endpoint = m_Endpoint; + if (endpoint != nullptr) { - Endpoint->Close(); + endpoint->Close(); m_Endpoint.reset(); } } @@ -178,18 +167,7 @@ void cLuaUDPEndpoint::Terminated(void) void cLuaUDPEndpoint::OnReceivedData(const char * a_Data, size_t a_NumBytes, const AString & a_RemotePeer, UInt16 a_RemotePort) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - - // Call the callback: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnReceivedData"), this, AString(a_Data, a_NumBytes), a_RemotePeer, a_RemotePort)) - { - LOGINFO("cUDPEndpoint OnReceivedData callback failed in plugin %s.", m_Plugin.GetName().c_str()); - } + m_Callbacks->CallTableFn("OnReceivedData", this, AString(a_Data, a_NumBytes), a_RemotePeer, a_RemotePort); } @@ -198,21 +176,10 @@ void cLuaUDPEndpoint::OnReceivedData(const char * a_Data, size_t a_NumBytes, con void cLuaUDPEndpoint::OnError(int a_ErrorCode, const AString & a_ErrorMsg) { - // Check if we're still valid: - if (!m_Callbacks.IsValid()) - { - return; - } - - // Call the callback: - cPluginLua::cOperation Op(m_Plugin); - if (!Op().Call(cLuaState::cTableRef(m_Callbacks, "OnError"), a_ErrorCode, a_ErrorMsg)) - { - LOGINFO("cUDPEndpoint OnError() callback failed in plugin %s; the endpoint error is %d (%s).", - m_Plugin.GetName().c_str(), a_ErrorCode, a_ErrorMsg.c_str() - ); - } + // Notify the plugin: + m_Callbacks->CallTableFn("OnError", a_ErrorCode, a_ErrorMsg); + // Terminate all processing on the endpoint: Terminated(); } diff --git a/src/Bindings/LuaUDPEndpoint.h b/src/Bindings/LuaUDPEndpoint.h index 0587491ab..338ea6648 100644 --- a/src/Bindings/LuaUDPEndpoint.h +++ b/src/Bindings/LuaUDPEndpoint.h @@ -10,7 +10,7 @@ #pragma once #include "../OSSupport/Network.h" -#include "PluginLua.h" +#include "LuaState.h" @@ -28,8 +28,8 @@ class cLuaUDPEndpoint: public cUDPEndpoint::cCallbacks { public: - /** Creates a new instance of the endpoint, attached to the specified plugin and wrapping the callbacks that are in a table at the specified stack pos. */ - cLuaUDPEndpoint(cPluginLua & a_Plugin, int a_CallbacksTableStackPos); + /** Creates a new instance of the endpoint, wrapping the callbacks that are in the specified table. */ + cLuaUDPEndpoint(cLuaState::cTableRefPtr && a_Callbacks); ~cLuaUDPEndpoint(); @@ -58,11 +58,8 @@ public: void Release(void); protected: - /** The plugin for which the link is created. */ - cPluginLua & m_Plugin; - /** The Lua table that holds the callbacks to be invoked. */ - cLuaState::cRef m_Callbacks; + cLuaState::cTableRefPtr m_Callbacks; /** SharedPtr to self, so that the object can keep itself alive for as long as it needs (for Lua). */ cLuaUDPEndpointPtr m_Self; diff --git a/src/Bindings/ManualBindings.cpp b/src/Bindings/ManualBindings.cpp index bc5352a84..b6a10adf3 100644 --- a/src/Bindings/ManualBindings.cpp +++ b/src/Bindings/ManualBindings.cpp @@ -2647,17 +2647,16 @@ class cLuaBlockTracerCallbacks : public cBlockTracer::cCallbacks { public: - cLuaBlockTracerCallbacks(cLuaState & a_LuaState, int a_ParamNum) : - m_LuaState(a_LuaState), - m_TableRef(a_LuaState, a_ParamNum) + cLuaBlockTracerCallbacks(cLuaState::cTableRefPtr && a_Callbacks): + m_Callbacks(std::move(a_Callbacks)) { } virtual bool OnNextBlock(int a_BlockX, int a_BlockY, int a_BlockZ, BLOCKTYPE a_BlockType, NIBBLETYPE a_BlockMeta, char a_EntryFace) override { bool res = false; - if (!m_LuaState.Call( - cLuaState::cTableRef(m_TableRef, "OnNextBlock"), + if (!m_Callbacks->CallTableFn( + "OnNextBlock", a_BlockX, a_BlockY, a_BlockZ, a_BlockType, a_BlockMeta, a_EntryFace, cLuaState::Return, res )) @@ -2671,8 +2670,8 @@ public: virtual bool OnNextBlockNoData(int a_BlockX, int a_BlockY, int a_BlockZ, char a_EntryFace) override { bool res = false; - if (!m_LuaState.Call( - cLuaState::cTableRef(m_TableRef, "OnNextBlockNoData"), + if (!m_Callbacks->CallTableFn( + "OnNextBlockNoData", a_BlockX, a_BlockY, a_BlockZ, a_EntryFace, cLuaState::Return, res )) @@ -2686,8 +2685,8 @@ public: virtual bool OnOutOfWorld(double a_BlockX, double a_BlockY, double a_BlockZ) override { bool res = false; - if (!m_LuaState.Call( - cLuaState::cTableRef(m_TableRef, "OnOutOfWorld"), + if (!m_Callbacks->CallTableFn( + "OnOutOfWorld", a_BlockX, a_BlockY, a_BlockZ, cLuaState::Return, res )) @@ -2701,8 +2700,8 @@ public: virtual bool OnIntoWorld(double a_BlockX, double a_BlockY, double a_BlockZ) override { bool res = false; - if (!m_LuaState.Call( - cLuaState::cTableRef(m_TableRef, "OnIntoWorld"), + if (!m_Callbacks->CallTableFn( + "OnIntoWorld", a_BlockX, a_BlockY, a_BlockZ, cLuaState::Return, res )) @@ -2715,17 +2714,16 @@ public: virtual void OnNoMoreHits(void) override { - m_LuaState.Call(cLuaState::cTableRef(m_TableRef, "OnNoMoreHits")); + m_Callbacks->CallTableFn("OnNoMoreHits"); } virtual void OnNoChunk(void) override { - m_LuaState.Call(cLuaState::cTableRef(m_TableRef, "OnNoChunk")); + m_Callbacks->CallTableFn("OnNoChunk"); } protected: - cLuaState & m_LuaState; - cLuaState::cRef m_TableRef; + cLuaState::cTableRefPtr m_Callbacks; } ; @@ -2759,16 +2757,22 @@ static int tolua_cLineBlockTracer_Trace(lua_State * tolua_S) return 0; } + // Get the params: + cWorld * world; + double startX, startY, startZ; + double endX, endY, endZ; + cLuaState::cTableRefPtr callbacks; + if (!L.GetStackValues(idx, world, callbacks, startX, startY, startZ, endX, endY, endZ)) + { + LOGWARNING("cLineBlockTracer:Trace(): Cannot read parameters (starting at idx %d), aborting the trace.", idx); + L.LogStackTrace(); + L.LogStackValues("Values on the stack"); + return 0; + } + // Trace: - cWorld * World = reinterpret_cast(tolua_tousertype(L, idx, nullptr)); - cLuaBlockTracerCallbacks Callbacks(L, idx + 1); - double StartX = tolua_tonumber(L, idx + 2, 0); - double StartY = tolua_tonumber(L, idx + 3, 0); - double StartZ = tolua_tonumber(L, idx + 4, 0); - double EndX = tolua_tonumber(L, idx + 5, 0); - double EndY = tolua_tonumber(L, idx + 6, 0); - double EndZ = tolua_tonumber(L, idx + 7, 0); - bool res = cLineBlockTracer::Trace(*World, Callbacks, StartX, StartY, StartZ, EndX, EndY, EndZ); + cLuaBlockTracerCallbacks tracerCallbacks(std::move(callbacks)); + bool res = cLineBlockTracer::Trace(*world, tracerCallbacks, startX, startY, startZ, endX, endY, endZ); tolua_pushboolean(L, res ? 1 : 0); return 1; } diff --git a/src/Bindings/ManualBindings_Network.cpp b/src/Bindings/ManualBindings_Network.cpp index 576fe94b7..68eba5870 100644 --- a/src/Bindings/ManualBindings_Network.cpp +++ b/src/Bindings/ManualBindings_Network.cpp @@ -38,33 +38,34 @@ static int tolua_cNetwork_Connect(lua_State * L) return 0; } - // Get the plugin instance: - cPluginLua * Plugin = cManualBindings::GetLuaPlugin(L); - if (Plugin == nullptr) + // Read the params: + AString host; + int port = 0; + cLuaState::cTableRefPtr callbacks; + if (!S.GetStackValues(2, host, port, callbacks)) { - // An error message has been already printed in GetLuaPlugin() + LOGWARNING("cNetwork::Connect() cannot read its parameters, failing the request."); + S.LogStackTrace(); + S.LogStackValues("Values on the stack"); S.Push(false); return 1; } - // Read the params: - AString Host; - int Port = 0; - S.GetStackValues(2, Host, Port); - // Check validity: - if ((Port < 0) || (Port > 65535)) + if ((port < 0) || (port > 65535)) { - LOGWARNING("cNetwork:Connect() called with invalid port (%d), failing the request.", Port); + LOGWARNING("cNetwork:Connect() called with invalid port (%d), failing the request.", port); + S.LogStackTrace(); S.Push(false); return 1; } + ASSERT(callbacks != nullptr); // Invalid callbacks would have resulted in GetStackValues() returning false // Create the LuaTCPLink glue class: - auto Link = std::make_shared(*Plugin, 4); + auto link = std::make_shared(std::move(callbacks)); // Try to connect: - bool res = cNetwork::Connect(Host, static_cast(Port), Link, Link); + bool res = cNetwork::Connect(host, static_cast(port), link, link); S.Push(res); return 1; @@ -91,36 +92,38 @@ static int tolua_cNetwork_CreateUDPEndpoint(lua_State * L) return 0; } - // Get the plugin instance: - cPluginLua * Plugin = cManualBindings::GetLuaPlugin(L); - if (Plugin == nullptr) + // Read the params: + UInt16 port; + cLuaState::cTableRefPtr callbacks; + if (!S.GetStackValues(2, port, callbacks)) { - // An error message has been already printed in GetLuaPlugin() + LOGWARNING("cNetwork:CreateUDPEndpoint() cannot read its parameters, failing the request."); + S.LogStackTrace(); + S.LogStackValues("Values on the stack"); S.Push(false); return 1; } - // Read the params: - UInt16 Port; - // Check validity: - if (!S.GetStackValues(2, Port)) + if ((port < 0) || (port > 65535)) { - LOGWARNING("cNetwork:CreateUDPEndpoint() called with invalid port, failing the request."); + LOGWARNING("cNetwork:CreateUDPEndpoint() called with invalid port (%d), failing the request.", port); + S.LogStackTrace(); S.Push(false); return 1; } + ASSERT(callbacks != nullptr); // Invalid callbacks would have resulted in GetStackValues() returning false // Create the LuaUDPEndpoint glue class: - auto Endpoint = std::make_shared(*Plugin, 3); - Endpoint->Open(Port, Endpoint); + auto endpoint = std::make_shared(std::move(callbacks)); + endpoint->Open(port, endpoint); // Register the endpoint to be garbage-collected by Lua: - tolua_pushusertype(L, Endpoint.get(), "cUDPEndpoint"); + tolua_pushusertype(L, endpoint.get(), "cUDPEndpoint"); tolua_register_gc(L, lua_gettop(L)); // Return the endpoint object: - S.Push(Endpoint.get()); + S.Push(endpoint.get()); return 1; } @@ -169,21 +172,21 @@ static int tolua_cNetwork_HostnameToIP(lua_State * L) return 0; } - // Get the plugin instance: - cPluginLua * Plugin = cManualBindings::GetLuaPlugin(L); - if (Plugin == nullptr) + // Read the params: + AString host; + cLuaState::cTableRefPtr callbacks; + if (!S.GetStackValues(2, host, callbacks)) { - // An error message has been already printed in GetLuaPlugin() + LOGWARNING("cNetwork::HostnameToIP() cannot read its parameters, failing the request."); + S.LogStackTrace(); + S.LogStackValues("Values on the stack"); S.Push(false); return 1; } - - // Read the params: - AString Host; - S.GetStackValue(2, Host); + ASSERT(callbacks != nullptr); // Invalid callbacks would have resulted in GetStackValues() returning false // Try to look up: - bool res = cNetwork::HostnameToIP(Host, std::make_shared(Host, *Plugin, 3)); + bool res = cNetwork::HostnameToIP(host, std::make_shared(host, std::move(callbacks))); S.Push(res); return 1; @@ -210,21 +213,21 @@ static int tolua_cNetwork_IPToHostname(lua_State * L) return 0; } - // Get the plugin instance: - cPluginLua * Plugin = cManualBindings::GetLuaPlugin(L); - if (Plugin == nullptr) + // Read the params: + AString ip; + cLuaState::cTableRefPtr callbacks; + if (!S.GetStackValues(2, ip, callbacks)) { - // An error message has been already printed in GetLuaPlugin() + LOGWARNING("cNetwork::IPToHostname() cannot read its parameters, failing the request."); + S.LogStackTrace(); + S.LogStackValues("Values on the stack"); S.Push(false); return 1; } - - // Read the params: - AString Host; - S.GetStackValue(2, Host); + ASSERT(callbacks != nullptr); // Invalid callbacks would have resulted in GetStackValues() returning false // Try to look up: - bool res = cNetwork::IPToHostName(Host, std::make_shared(Host, *Plugin, 3)); + bool res = cNetwork::IPToHostName(ip, std::make_shared(ip, std::move(callbacks))); S.Push(res); return 1; @@ -251,38 +254,40 @@ static int tolua_cNetwork_Listen(lua_State * L) return 0; } - // Get the plugin instance: - cPluginLua * Plugin = cManualBindings::GetLuaPlugin(L); - if (Plugin == nullptr) + // Read the params: + int port = 0; + cLuaState::cTableRefPtr callbacks; + if (!S.GetStackValues(2, port, callbacks)) { - // An error message has been already printed in GetLuaPlugin() + LOGWARNING("cNetwork::Listen() cannot read its parameters, failing the request."); + S.LogStackTrace(); + S.LogStackValues("Values on the stack"); S.Push(false); return 1; } - // Read the params: - int Port = 0; - S.GetStackValues(2, Port); - if ((Port < 0) || (Port > 65535)) + // Check the validity: + if ((port < 0) || (port > 65535)) { - LOGWARNING("cNetwork:Listen() called with invalid port (%d), failing the request.", Port); + LOGWARNING("cNetwork:Listen() called with invalid port (%d), failing the request.", port); + S.LogStackTrace(); S.Push(false); return 1; } - UInt16 Port16 = static_cast(Port); + auto port16 = static_cast(port); // Create the LuaTCPLink glue class: - auto Srv = std::make_shared(Port16, *Plugin, 3); + auto srv = std::make_shared(port16, std::move(callbacks)); // Listen: - Srv->SetServerHandle(cNetwork::Listen(Port16, Srv), Srv); + srv->SetServerHandle(cNetwork::Listen(port16, srv), srv); // Register the server to be garbage-collected by Lua: - tolua_pushusertype(L, Srv.get(), "cServerHandle"); + tolua_pushusertype(L, srv.get(), "cServerHandle"); tolua_register_gc(L, lua_gettop(L)); // Return the server handle wrapper: - S.Push(Srv.get()); + S.Push(srv.get()); return 1; } -- cgit v1.2.3 From dd5567a90c47e52b19bacae5eddce88b5c3d4cd6 Mon Sep 17 00:00:00 2001 From: Mattes D Date: Wed, 6 Jul 2016 20:52:43 +0200 Subject: IPLookup: Fixed a soft memory leak when looking up invalid IPs. --- src/OSSupport/IPLookup.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/OSSupport/IPLookup.cpp b/src/OSSupport/IPLookup.cpp index 8cdc5132d..2722d4722 100644 --- a/src/OSSupport/IPLookup.cpp +++ b/src/OSSupport/IPLookup.cpp @@ -103,7 +103,13 @@ bool cNetwork::IPToHostName( { auto res = std::make_shared(a_Callbacks); cNetworkSingleton::Get().AddIPLookup(res); - return res->Lookup(a_IP); + if (!res->Lookup(a_IP)) + { + // Lookup failed early on, remove the object completely: + cNetworkSingleton::Get().RemoveIPLookup(res.get()); + return false; + } + return true; } -- cgit v1.2.3