diff options
Diffstat (limited to 'src/shader_recompiler/ir_opt/constant_propagation_pass.cpp')
-rw-r--r-- | src/shader_recompiler/ir_opt/constant_propagation_pass.cpp | 76 |
1 files changed, 64 insertions, 12 deletions
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp index cbde65b9b..f1ad16d60 100644 --- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp +++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp @@ -77,6 +77,16 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { return true; } +template <typename Func> +bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) { + if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) { + return false; + } + using Indices = std::make_index_sequence<LambdaTraits<decltype(func)>::NUM_ARGS>; + inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{})); + return true; +} + void FoldGetRegister(IR::Inst& inst) { if (inst.Arg(0).Reg() == IR::Reg::RZ) { inst.ReplaceUsesWith(IR::Value{u32{0}}); @@ -103,6 +113,52 @@ void FoldAdd(IR::Inst& inst) { } } +void FoldISub32(IR::Inst& inst) { + if (FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a - b; })) { + return; + } + if (inst.Arg(0).IsImmediate() || inst.Arg(1).IsImmediate()) { + return; + } + // ISub32 is generally used to subtract two constant buffers, compare and replace this with + // zero if they equal. + const auto equal_cbuf{[](IR::Inst* a, IR::Inst* b) { + return a->Opcode() == IR::Opcode::GetCbuf && b->Opcode() == IR::Opcode::GetCbuf && + a->Arg(0) == b->Arg(0) && a->Arg(1) == b->Arg(1); + }}; + IR::Inst* op_a{inst.Arg(0).InstRecursive()}; + IR::Inst* op_b{inst.Arg(1).InstRecursive()}; + if (equal_cbuf(op_a, op_b)) { + inst.ReplaceUsesWith(IR::Value{u32{0}}); + return; + } + // It's also possible a value is being added to a cbuf and then subtracted + if (op_b->Opcode() == IR::Opcode::IAdd32) { + // Canonicalize local variables to simplify the following logic + std::swap(op_a, op_b); + } + if (op_b->Opcode() != IR::Opcode::GetCbuf) { + return; + } + IR::Inst* const inst_cbuf{op_b}; + if (op_a->Opcode() != IR::Opcode::IAdd32) { + return; + } + IR::Value add_op_a{op_a->Arg(0)}; + IR::Value add_op_b{op_a->Arg(1)}; + if (add_op_b.IsImmediate()) { + // Canonicalize + std::swap(add_op_a, add_op_b); + } + if (add_op_b.IsImmediate()) { + return; + } + IR::Inst* const add_cbuf{add_op_b.InstRecursive()}; + if (equal_cbuf(add_cbuf, inst_cbuf)) { + inst.ReplaceUsesWith(add_op_a); + } +} + template <typename T> void FoldSelect(IR::Inst& inst) { const IR::Value cond{inst.Arg(0)}; @@ -170,15 +226,6 @@ IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence< return IR::Value{func(Arg<Traits::ArgType<I>>(inst.Arg(I))...)}; } -template <typename Func> -void FoldWhenAllImmediates(IR::Inst& inst, Func&& func) { - if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) { - return; - } - using Indices = std::make_index_sequence<LambdaTraits<decltype(func)>::NUM_ARGS>; - inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{})); -} - void FoldBranchConditional(IR::Inst& inst) { const IR::U1 cond{inst.Arg(0)}; if (cond.IsImmediate()) { @@ -205,6 +252,8 @@ void ConstantPropagation(IR::Inst& inst) { return FoldGetPred(inst); case IR::Opcode::IAdd32: return FoldAdd<u32>(inst); + case IR::Opcode::ISub32: + return FoldISub32(inst); case IR::Opcode::BitCastF32U32: return FoldBitCast<f32, u32>(inst, IR::Opcode::BitCastU32F32); case IR::Opcode::BitCastU32F32: @@ -220,17 +269,20 @@ void ConstantPropagation(IR::Inst& inst) { case IR::Opcode::LogicalNot: return FoldLogicalNot(inst); case IR::Opcode::SLessThan: - return FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; }); + FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; }); + return; case IR::Opcode::ULessThan: - return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); + FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); + return; case IR::Opcode::BitFieldUExtract: - return FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) { + FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) { if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) { throw LogicError("Undefined result in {}({}, {}, {})", IR::Opcode::BitFieldUExtract, base, shift, count); } return (base >> shift) & ((1U << count) - 1); }); + return; case IR::Opcode::BranchConditional: return FoldBranchConditional(inst); default: |