summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/shader_recompiler/ir_opt/constant_propagation_pass.cpp82
1 files changed, 77 insertions, 5 deletions
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index f1ad16d60..9eb61b54c 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -9,6 +9,7 @@
#include "common/bit_cast.h"
#include "common/bit_util.h"
#include "shader_recompiler/exception.h"
+#include "shader_recompiler/frontend/ir/ir_emitter.h"
#include "shader_recompiler/frontend/ir/microinstruction.h"
#include "shader_recompiler/ir_opt/passes.h"
@@ -99,8 +100,71 @@ void FoldGetPred(IR::Inst& inst) {
}
}
+/// Replaces the pattern generated by two XMAD multiplications
+bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
+ /*
+ * We are looking for this pattern:
+ * %rhs_bfe = BitFieldUExtract %factor_a, #0, #16 (uses: 1)
+ * %rhs_mul = IMul32 %rhs_bfe, %factor_b (uses: 1)
+ * %lhs_bfe = BitFieldUExtract %factor_a, #16, #16 (uses: 1)
+ * %rhs_mul = IMul32 %lhs_bfe, %factor_b (uses: 1)
+ * %lhs_shl = ShiftLeftLogical32 %rhs_mul, #16 (uses: 1)
+ * %result = IAdd32 %lhs_shl, %rhs_mul (uses: 10)
+ *
+ * And replacing it with
+ * %result = IMul32 %factor_a, %factor_b
+ *
+ * This optimization has been proven safe by LLVM and MSVC.
+ */
+ const IR::Value lhs_arg{inst.Arg(0)};
+ const IR::Value rhs_arg{inst.Arg(1)};
+ if (lhs_arg.IsImmediate() || rhs_arg.IsImmediate()) {
+ return false;
+ }
+ IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};
+ if (lhs_shl->Opcode() != IR::Opcode::ShiftLeftLogical32 || lhs_shl->Arg(1) != IR::Value{16U}) {
+ return false;
+ }
+ if (lhs_shl->Arg(0).IsImmediate()) {
+ return false;
+ }
+ IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()};
+ IR::Inst* const rhs_mul{rhs_arg.InstRecursive()};
+ if (lhs_mul->Opcode() != IR::Opcode::IMul32 || rhs_mul->Opcode() != IR::Opcode::IMul32) {
+ return false;
+ }
+ if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) {
+ return false;
+ }
+ const IR::U32 factor_b{lhs_mul->Arg(1)};
+ if (lhs_mul->Arg(0).IsImmediate() || rhs_mul->Arg(0).IsImmediate()) {
+ return false;
+ }
+ IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()};
+ IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()};
+ if (lhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
+ return false;
+ }
+ if (rhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
+ return false;
+ }
+ if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) {
+ return false;
+ }
+ if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) {
+ return false;
+ }
+ if (lhs_bfe->Arg(0).Resolve() != rhs_bfe->Arg(0).Resolve()) {
+ return false;
+ }
+ const IR::U32 factor_a{lhs_bfe->Arg(0)};
+ IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
+ inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b));
+ return true;
+}
+
template <typename T>
-void FoldAdd(IR::Inst& inst) {
+void FoldAdd(IR::Block& block, IR::Inst& inst) {
if (inst.HasAssociatedPseudoOperation()) {
return;
}
@@ -110,6 +174,12 @@ void FoldAdd(IR::Inst& inst) {
const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
inst.ReplaceUsesWith(inst.Arg(0));
+ return;
+ }
+ if constexpr (std::is_same_v<T, u32>) {
+ if (FoldXmadMultiply(block, inst)) {
+ return;
+ }
}
}
@@ -244,14 +314,14 @@ void FoldBranchConditional(IR::Inst& inst) {
}
}
-void ConstantPropagation(IR::Inst& inst) {
+void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
switch (inst.Opcode()) {
case IR::Opcode::GetRegister:
return FoldGetRegister(inst);
case IR::Opcode::GetPred:
return FoldGetPred(inst);
case IR::Opcode::IAdd32:
- return FoldAdd<u32>(inst);
+ return FoldAdd<u32>(block, inst);
case IR::Opcode::ISub32:
return FoldISub32(inst);
case IR::Opcode::BitCastF32U32:
@@ -259,7 +329,7 @@ void ConstantPropagation(IR::Inst& inst) {
case IR::Opcode::BitCastU32F32:
return FoldBitCast<u32, f32>(inst, IR::Opcode::BitCastF32U32);
case IR::Opcode::IAdd64:
- return FoldAdd<u64>(inst);
+ return FoldAdd<u64>(block, inst);
case IR::Opcode::Select32:
return FoldSelect<u32>(inst);
case IR::Opcode::LogicalAnd:
@@ -292,7 +362,9 @@ void ConstantPropagation(IR::Inst& inst) {
} // Anonymous namespace
void ConstantPropagationPass(IR::Block& block) {
- std::ranges::for_each(block, ConstantPropagation);
+ for (IR::Inst& inst : block) {
+ ConstantPropagation(block, inst);
+ }
}
} // namespace Shader::Optimization