diff options
-rw-r--r-- | src/video_core/renderer_opengl/gl_shader_decompiler.cpp | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp index a5cc1a86f..c399fab0f 100644 --- a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp +++ b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp @@ -325,6 +325,7 @@ public: DeclareRegisters(); DeclarePredicates(); DeclareLocalMemory(); + DeclareSharedMemory(); DeclareInternalFlags(); DeclareInputAttributes(); DeclareOutputAttributes(); @@ -500,6 +501,13 @@ private: code.AddNewLine(); } + void DeclareSharedMemory() { + if (stage != ProgramType::Compute) { + return; + } + code.AddLine("shared uint {}[];", GetSharedMemory()); + } + void DeclareInternalFlags() { for (u32 flag = 0; flag < static_cast<u32>(InternalFlag::Amount); flag++) { const auto flag_code = static_cast<InternalFlag>(flag); @@ -858,6 +866,12 @@ private: Type::Uint}; } + if (const auto smem = std::get_if<SmemNode>(&*node)) { + return { + fmt::format("{}[{} >> 2]", GetSharedMemory(), Visit(smem->GetAddress()).AsUint()), + Type::Uint}; + } + if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) { return {GetInternalFlag(internal_flag->GetFlag()), Type::Bool}; } @@ -1195,6 +1209,11 @@ private: target = { fmt::format("{}[{} >> 2]", GetLocalMemory(), Visit(lmem->GetAddress()).AsUint()), Type::Uint}; + } else if (const auto smem = std::get_if<SmemNode>(&*dest)) { + ASSERT(stage == ProgramType::Compute); + target = { + fmt::format("{}[{} >> 2]", GetSharedMemory(), Visit(smem->GetAddress()).AsUint()), + Type::Uint}; } else if (const auto gmem = std::get_if<GmemNode>(&*dest)) { const std::string real = Visit(gmem->GetRealAddress()).AsUint(); const std::string base = Visit(gmem->GetBaseAddress()).AsUint(); @@ -2076,6 +2095,10 @@ private: return "lmem_" + suffix; } + std::string GetSharedMemory() const { + return fmt::format("smem_{}", suffix); + } + std::string GetInternalFlag(InternalFlag flag) const { constexpr std::array InternalFlagNames = {"zero_flag", "sign_flag", "carry_flag", "overflow_flag"}; |