From 2d58789d66f1b63ad63304584c7ac43284b540b8 Mon Sep 17 00:00:00 2001 From: Mattes D Date: Wed, 6 Jul 2016 20:52:04 +0200 Subject: Converted cLuaState::cTableRef to use cTrackedRef. This makes the table-based callbacks resistent to LuaState unloads and safer to use. --- src/Bindings/LuaState.cpp | 161 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 128 insertions(+), 33 deletions(-) (limited to 'src/Bindings/LuaState.cpp') diff --git a/src/Bindings/LuaState.cpp b/src/Bindings/LuaState.cpp index bc9447cc2..e6a94091e 100644 --- a/src/Bindings/LuaState.cpp +++ b/src/Bindings/LuaState.cpp @@ -122,9 +122,9 @@ cLuaStateTracker & cLuaStateTracker::Get(void) //////////////////////////////////////////////////////////////////////////////// -// cLuaState::cCallback: +// cLuaState::cTrackedRef: -cLuaState::cCallback::cCallback(void): +cLuaState::cTrackedRef::cTrackedRef(void): m_CS(nullptr) { } @@ -133,20 +133,14 @@ cLuaState::cCallback::cCallback(void): -bool cLuaState::cCallback::RefStack(cLuaState & a_LuaState, int a_StackPos) +bool cLuaState::cTrackedRef::RefStack(cLuaState & a_LuaState, int a_StackPos) { - // Check if the stack contains a function: - if (!lua_isfunction(a_LuaState, a_StackPos)) - { - return false; - } - // Clear any previous callback: Clear(); // Add self to LuaState's callback-tracking: auto canonState = a_LuaState.QueryCanonLuaState(); - canonState->TrackCallback(*this); + canonState->TrackRef(*this); // Store the new callback: m_CS = &(canonState->m_CS); @@ -158,9 +152,9 @@ bool cLuaState::cCallback::RefStack(cLuaState & a_LuaState, int a_StackPos) -void cLuaState::cCallback::Clear(void) +void cLuaState::cTrackedRef::Clear(void) { - // Free the callback reference: + // Free the reference: lua_State * luaState = nullptr; { auto cs = m_CS; @@ -175,20 +169,21 @@ void cLuaState::cCallback::Clear(void) m_Ref.UnRef(); } } + m_CS = nullptr; // Remove from LuaState's callback-tracking: if (luaState == nullptr) { return; } - cLuaState(luaState).UntrackCallback(*this); + cLuaState(luaState).UntrackRef(*this); } -bool cLuaState::cCallback::IsValid(void) +bool cLuaState::cTrackedRef::IsValid(void) { auto cs = m_CS; if (cs == nullptr) @@ -203,7 +198,7 @@ bool cLuaState::cCallback::IsValid(void) -bool cLuaState::cCallback::IsSameLuaState(cLuaState & a_LuaState) +bool cLuaState::cTrackedRef::IsSameLuaState(cLuaState & a_LuaState) { auto cs = m_CS; if (cs == nullptr) @@ -227,7 +222,7 @@ bool cLuaState::cCallback::IsSameLuaState(cLuaState & a_LuaState) -void cLuaState::cCallback::Invalidate(void) +void cLuaState::cTrackedRef::Invalidate(void) { auto cs = m_CS; if (cs == nullptr) @@ -244,6 +239,43 @@ void cLuaState::cCallback::Invalidate(void) return; } m_Ref.UnRef(); + m_CS = nullptr; +} + + + + + +//////////////////////////////////////////////////////////////////////////////// +// cLuaState::cCallback: + +bool cLuaState::cCallback::RefStack(cLuaState & a_LuaState, int a_StackPos) +{ + // Check if the stack contains a function: + if (!lua_isfunction(a_LuaState, a_StackPos)) + { + return false; + } + + return Super::RefStack(a_LuaState, a_StackPos); +} + + + + + +//////////////////////////////////////////////////////////////////////////////// +// cLuaState::cTableRef: + +bool cLuaState::cTableRef::RefStack(cLuaState & a_LuaState, int a_StackPos) +{ + // Check if the stack contains a table: + if (!lua_istable(a_LuaState, a_StackPos)) + { + return false; + } + + return Super::RefStack(a_LuaState, a_StackPos); } @@ -365,12 +397,12 @@ void cLuaState::Close(void) return; } - // Invalidate all callbacks: + // Invalidate all tracked refs: { - cCSLock Lock(m_CSTrackedCallbacks); - for (auto & c: m_TrackedCallbacks) + cCSLock Lock(m_CSTrackedRefs); + for (auto & r: m_TrackedRefs) { - c->Invalidate(); + r->Invalidate(); } } @@ -592,22 +624,28 @@ bool cLuaState::PushFunction(const cRef & a_FnRef) -bool cLuaState::PushFunction(const cTableRef & a_TableRef) +bool cLuaState::PushFunction(const cRef & a_TableRef, const char * a_FnName) { ASSERT(IsValid()); ASSERT(m_NumCurrentFunctionArgs == -1); // If not, there's already something pushed onto the stack + if (!a_TableRef.IsValid()) + { + return false; + } + // Push the error handler for lua_pcall() lua_pushcfunction(m_LuaState, &ReportFnCallErrors); - lua_rawgeti(m_LuaState, LUA_REGISTRYINDEX, a_TableRef.GetTableRef()); // Get the table ref + // Get the function from the table: + lua_rawgeti(m_LuaState, LUA_REGISTRYINDEX, static_cast(a_TableRef)); if (!lua_istable(m_LuaState, -1)) { // Not a table, bail out lua_pop(m_LuaState, 2); return false; } - lua_getfield(m_LuaState, -1, a_TableRef.GetFnName()); + lua_getfield(m_LuaState, -1, a_FnName); if (lua_isnil(m_LuaState, -1) || !lua_isfunction(m_LuaState, -1)) { // Not a valid function, bail out @@ -618,7 +656,7 @@ bool cLuaState::PushFunction(const cTableRef & a_TableRef) // Pop the table off the stack: lua_remove(m_LuaState, -2); - Printf(m_CurrentFunctionName, "", a_TableRef.GetFnName()); + Printf(m_CurrentFunctionName, "", a_FnName); m_NumCurrentFunctionArgs = 0; return true; } @@ -1061,6 +1099,28 @@ bool cLuaState::GetStackValue(int a_StackPos, cCallbackSharedPtr & a_Callback) +bool cLuaState::GetStackValue(int a_StackPos, cTableRef & a_TableRef) +{ + return a_TableRef.RefStack(*this, a_StackPos); +} + + + + + +bool cLuaState::GetStackValue(int a_StackPos, cTableRefPtr & a_TableRef) +{ + if (a_TableRef == nullptr) + { + a_TableRef = cpp14::make_unique(); + } + return a_TableRef->RefStack(*this, a_StackPos); +} + + + + + bool cLuaState::GetStackValue(int a_StackPos, cPluginManager::CommandResult & a_Result) { if (lua_isnumber(m_LuaState, a_StackPos)) @@ -1085,6 +1145,41 @@ bool cLuaState::GetStackValue(int a_StackPos, cRef & a_Ref) +bool cLuaState::GetStackValue(int a_StackPos, cTrackedRef & a_Ref) +{ + return a_Ref.RefStack(*this, a_StackPos); +} + + + + + +bool cLuaState::GetStackValue(int a_StackPos, cTrackedRefPtr & a_Ref) +{ + if (a_Ref == nullptr) + { + a_Ref = cpp14::make_unique(); + } + return a_Ref->RefStack(*this, a_StackPos); +} + + + + + +bool cLuaState::GetStackValue(int a_StackPos, cTrackedRefSharedPtr & a_Ref) +{ + if (a_Ref == nullptr) + { + a_Ref = std::make_shared(); + } + return a_Ref->RefStack(*this, a_StackPos); +} + + + + + bool cLuaState::GetStackValue(int a_StackPos, double & a_ReturnedVal) { if (lua_isnumber(m_LuaState, a_StackPos)) @@ -1930,7 +2025,7 @@ int cLuaState::BreakIntoDebugger(lua_State * a_LuaState) -void cLuaState::TrackCallback(cCallback & a_Callback) +void cLuaState::TrackRef(cTrackedRef & a_Ref) { // Get the CanonLuaState global from Lua: auto canonState = QueryCanonLuaState(); @@ -1941,15 +2036,15 @@ void cLuaState::TrackCallback(cCallback & a_Callback) } // Add the callback: - cCSLock Lock(canonState->m_CSTrackedCallbacks); - canonState->m_TrackedCallbacks.push_back(&a_Callback); + cCSLock Lock(canonState->m_CSTrackedRefs); + canonState->m_TrackedRefs.push_back(&a_Ref); } -void cLuaState::UntrackCallback(cCallback & a_Callback) +void cLuaState::UntrackRef(cTrackedRef & a_Ref) { // Get the CanonLuaState global from Lua: auto canonState = QueryCanonLuaState(); @@ -1960,12 +2055,12 @@ void cLuaState::UntrackCallback(cCallback & a_Callback) } // Remove the callback: - cCSLock Lock(canonState->m_CSTrackedCallbacks); - auto & trackedCallbacks = canonState->m_TrackedCallbacks; - trackedCallbacks.erase(std::remove_if(trackedCallbacks.begin(), trackedCallbacks.end(), - [&a_Callback](cCallback * a_StoredCallback) + cCSLock Lock(canonState->m_CSTrackedRefs); + auto & trackedRefs = canonState->m_TrackedRefs; + trackedRefs.erase(std::remove_if(trackedRefs.begin(), trackedRefs.end(), + [&a_Ref](cTrackedRef * a_StoredRef) { - return (a_StoredCallback == &a_Callback); + return (a_StoredRef == &a_Ref); } )); } -- cgit v1.2.3