summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/shader_recompiler/ir_opt/constant_propagation_pass.cpp42
1 files changed, 41 insertions, 1 deletions
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index 7da4d50ef..15e16956e 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -3,9 +3,9 @@
// Refer to the license.txt file included.
#include <algorithm>
+#include <ranges>
#include <tuple>
#include <type_traits>
-#include <ranges>
#include "common/bit_cast.h"
#include "common/bit_util.h"
@@ -332,6 +332,18 @@ void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
}
}
+void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) {
+ const IR::Value value{inst.Arg(0)};
+ if (value.IsImmediate()) {
+ return;
+ }
+ IR::Inst* const arg_inst{value.InstRecursive()};
+ if (arg_inst->Opcode() == reverse) {
+ inst.ReplaceUsesWith(arg_inst->Arg(0));
+ return;
+ }
+}
+
template <typename Func, size_t... I>
IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<I...>) {
using Traits = LambdaTraits<decltype(func)>;
@@ -372,6 +384,10 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
return FoldBitCast<IR::Opcode::BitCastU32F32, u32, f32>(inst, IR::Opcode::BitCastF32U32);
case IR::Opcode::IAdd64:
return FoldAdd<u64>(block, inst);
+ case IR::Opcode::PackHalf2x16:
+ return FoldInverseFunc(inst, IR::Opcode::UnpackHalf2x16);
+ case IR::Opcode::UnpackHalf2x16:
+ return FoldInverseFunc(inst, IR::Opcode::PackHalf2x16);
case IR::Opcode::SelectU1:
case IR::Opcode::SelectU8:
case IR::Opcode::SelectU16:
@@ -395,6 +411,30 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
case IR::Opcode::ULessThan:
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
return;
+ case IR::Opcode::SLessThanEqual:
+ FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a <= b; });
+ return;
+ case IR::Opcode::ULessThanEqual:
+ FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a <= b; });
+ return;
+ case IR::Opcode::SGreaterThan:
+ FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a > b; });
+ return;
+ case IR::Opcode::UGreaterThan:
+ FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a > b; });
+ return;
+ case IR::Opcode::SGreaterThanEqual:
+ FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a >= b; });
+ return;
+ case IR::Opcode::UGreaterThanEqual:
+ FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a >= b; });
+ return;
+ case IR::Opcode::IEqual:
+ FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a == b; });
+ return;
+ case IR::Opcode::INotEqual:
+ FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a != b; });
+ return;
case IR::Opcode::BitFieldUExtract:
FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) {
if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) {