shader_recompilers: Improvements to SSA phi generation and lane instruction elimination (#1667)

* shader_recompiler: Add use tracking for Insts

* ssa_rewrite: Recursively remove phis

* ssa_rewrite: Correct recursive trivial phi elimination

* ir: Improve read lane folding pass

* control_flow: Avoid adding unnecessary divergant blocks

* clang format

* externals: Update ext-boost

---------

Co-authored-by: Frodo Baggins <baggins31084@proton.me>
This commit is contained in:
TheTurtle 2024-12-05 23:14:16 +02:00 committed by GitHub
parent 874508f8c2
commit 22a2741ea0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 175 additions and 76 deletions

View file

@ -43,7 +43,7 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) {
if (is_lhs_immediate && is_rhs_immediate) {
const auto result{imm_fn(Arg<T>(lhs), Arg<T>(rhs))};
inst.ReplaceUsesWith(IR::Value{result});
inst.ReplaceUsesWithAndRemove(IR::Value{result});
return false;
}
if (is_lhs_immediate && !is_rhs_immediate) {
@ -75,7 +75,7 @@ bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
return false;
}
using Indices = std::make_index_sequence<Common::LambdaTraits<decltype(func)>::NUM_ARGS>;
inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
inst.ReplaceUsesWithAndRemove(EvalImmediates(inst, func, Indices{}));
return true;
}
@ -83,12 +83,12 @@ template <IR::Opcode op, typename Dest, typename Source>
void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
const IR::Value value{inst.Arg(0)};
if (value.IsImmediate()) {
inst.ReplaceUsesWith(IR::Value{std::bit_cast<Dest>(Arg<Source>(value))});
inst.ReplaceUsesWithAndRemove(IR::Value{std::bit_cast<Dest>(Arg<Source>(value))});
return;
}
IR::Inst* const arg_inst{value.InstRecursive()};
if (arg_inst->GetOpcode() == reverse) {
inst.ReplaceUsesWith(arg_inst->Arg(0));
inst.ReplaceUsesWithAndRemove(arg_inst->Arg(0));
return;
}
}
@ -131,7 +131,7 @@ void FoldCompositeExtract(IR::Inst& inst, IR::Opcode construct, IR::Opcode inser
if (!result) {
return;
}
inst.ReplaceUsesWith(*result);
inst.ReplaceUsesWithAndRemove(*result);
}
void FoldConvert(IR::Inst& inst, IR::Opcode opposite) {
@ -141,7 +141,7 @@ void FoldConvert(IR::Inst& inst, IR::Opcode opposite) {
}
IR::Inst* const producer{value.InstRecursive()};
if (producer->GetOpcode() == opposite) {
inst.ReplaceUsesWith(producer->Arg(0));
inst.ReplaceUsesWithAndRemove(producer->Arg(0));
}
}
@ -152,9 +152,9 @@ void FoldLogicalAnd(IR::Inst& inst) {
const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate()) {
if (rhs.U1()) {
inst.ReplaceUsesWith(inst.Arg(0));
inst.ReplaceUsesWithAndRemove(inst.Arg(0));
} else {
inst.ReplaceUsesWith(IR::Value{false});
inst.ReplaceUsesWithAndRemove(IR::Value{false});
}
}
}
@ -162,7 +162,7 @@ void FoldLogicalAnd(IR::Inst& inst) {
void FoldSelect(IR::Inst& inst) {
const IR::Value cond{inst.Arg(0)};
if (cond.IsImmediate()) {
inst.ReplaceUsesWith(cond.U1() ? inst.Arg(1) : inst.Arg(2));
inst.ReplaceUsesWithAndRemove(cond.U1() ? inst.Arg(1) : inst.Arg(2));
}
}
@ -173,9 +173,9 @@ void FoldLogicalOr(IR::Inst& inst) {
const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate()) {
if (rhs.U1()) {
inst.ReplaceUsesWith(IR::Value{true});
inst.ReplaceUsesWithAndRemove(IR::Value{true});
} else {
inst.ReplaceUsesWith(inst.Arg(0));
inst.ReplaceUsesWithAndRemove(inst.Arg(0));
}
}
}
@ -183,12 +183,12 @@ void FoldLogicalOr(IR::Inst& inst) {
void FoldLogicalNot(IR::Inst& inst) {
const IR::U1 value{inst.Arg(0)};
if (value.IsImmediate()) {
inst.ReplaceUsesWith(IR::Value{!value.U1()});
inst.ReplaceUsesWithAndRemove(IR::Value{!value.U1()});
return;
}
IR::Inst* const arg{value.InstRecursive()};
if (arg->GetOpcode() == IR::Opcode::LogicalNot) {
inst.ReplaceUsesWith(arg->Arg(0));
inst.ReplaceUsesWithAndRemove(arg->Arg(0));
}
}
@ -199,7 +199,7 @@ void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) {
}
IR::Inst* const arg_inst{value.InstRecursive()};
if (arg_inst->GetOpcode() == reverse) {
inst.ReplaceUsesWith(arg_inst->Arg(0));
inst.ReplaceUsesWithAndRemove(arg_inst->Arg(0));
return;
}
}
@ -211,7 +211,7 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) {
}
const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
inst.ReplaceUsesWith(inst.Arg(0));
inst.ReplaceUsesWithAndRemove(inst.Arg(0));
return;
}
}
@ -226,21 +226,58 @@ void FoldCmpClass(IR::Block& block, IR::Inst& inst) {
} else if ((class_mask & IR::FloatClassFunc::Finite) == IR::FloatClassFunc::Finite) {
IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
const IR::F32 value = IR::F32{inst.Arg(0)};
inst.ReplaceUsesWith(ir.LogicalNot(ir.LogicalOr(ir.FPIsInf(value), ir.FPIsInf(value))));
inst.ReplaceUsesWithAndRemove(
ir.LogicalNot(ir.LogicalOr(ir.FPIsInf(value), ir.FPIsInf(value))));
} else {
UNREACHABLE();
}
}
void FoldReadLane(IR::Inst& inst) {
void FoldReadLane(IR::Block& block, IR::Inst& inst) {
const u32 lane = inst.Arg(1).U32();
IR::Inst* prod = inst.Arg(0).InstRecursive();
while (prod->GetOpcode() == IR::Opcode::WriteLane) {
if (prod->Arg(2).U32() == lane) {
inst.ReplaceUsesWith(prod->Arg(1));
const auto search_chain = [lane](const IR::Inst* prod) -> IR::Value {
while (prod->GetOpcode() == IR::Opcode::WriteLane) {
if (prod->Arg(2).U32() == lane) {
return prod->Arg(1);
}
prod = prod->Arg(0).InstRecursive();
}
return {};
};
if (prod->GetOpcode() == IR::Opcode::WriteLane) {
if (const IR::Value value = search_chain(prod); !value.IsEmpty()) {
inst.ReplaceUsesWith(value);
}
return;
}
if (prod->GetOpcode() == IR::Opcode::Phi) {
boost::container::small_vector<IR::Value, 2> phi_args;
for (size_t arg_index = 0; arg_index < prod->NumArgs(); ++arg_index) {
const IR::Inst* arg{prod->Arg(arg_index).InstRecursive()};
if (arg->GetOpcode() != IR::Opcode::WriteLane) {
return;
}
const IR::Value value = search_chain(arg);
if (value.IsEmpty()) {
continue;
}
phi_args.emplace_back(value);
}
if (std::ranges::all_of(phi_args, [&](IR::Value value) { return value == phi_args[0]; })) {
inst.ReplaceUsesWith(phi_args[0]);
return;
}
prod = prod->Arg(0).InstRecursive();
const auto insert_point = IR::Block::InstructionList::s_iterator_to(*prod);
IR::Inst* const new_phi{&*block.PrependNewInst(insert_point, IR::Opcode::Phi)};
new_phi->SetFlags(IR::Type::U32);
for (size_t arg_index = 0; arg_index < phi_args.size(); arg_index++) {
new_phi->AddPhiOperand(prod->PhiBlock(arg_index), phi_args[arg_index]);
}
inst.ReplaceUsesWith(IR::Value{new_phi});
}
}
@ -290,7 +327,7 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
case IR::Opcode::SelectF64:
return FoldSelect(inst);
case IR::Opcode::ReadLane:
return FoldReadLane(inst);
return FoldReadLane(block, inst);
case IR::Opcode::FPNeg32:
FoldWhenAllImmediates(inst, [](f32 a) { return -a; });
return;

View file

@ -25,7 +25,7 @@ void LowerSharedMemToRegisters(IR::Program& program) {
});
ASSERT(it != ds_writes.end());
// Replace data read with value written.
inst.ReplaceUsesWith((*it)->Arg(1));
inst.ReplaceUsesWithAndRemove((*it)->Arg(1));
}
}
}

View file

@ -596,7 +596,7 @@ void PatchImageSampleInstruction(IR::Block& block, IR::Inst& inst, Info& info,
}
return ir.ImageSampleImplicitLod(handle, coords, bias, offset, inst_info);
}();
inst.ReplaceUsesWith(new_inst);
inst.ReplaceUsesWithAndRemove(new_inst);
}
void PatchImageInstruction(IR::Block& block, IR::Inst& inst, Info& info, Descriptors& descriptors) {

View file

@ -164,7 +164,6 @@ IR::Opcode UndefOpcode(const FlagTag) noexcept {
enum class Status {
Start,
SetValue,
PreparePhiArgument,
PushPhiArgument,
};
@ -253,12 +252,10 @@ public:
IR::Inst* const phi{stack.back().phi};
phi->AddPhiOperand(*stack.back().pred_it, stack.back().result);
++stack.back().pred_it;
}
[[fallthrough]];
case Status::PreparePhiArgument:
prepare_phi_operand();
break;
}
}
} while (stack.size() > 1);
return stack.back().result;
}
@ -266,9 +263,7 @@ public:
void SealBlock(IR::Block* block) {
const auto it{incomplete_phis.find(block)};
if (it != incomplete_phis.end()) {
for (auto& pair : it->second) {
auto& variant{pair.first};
auto& phi{pair.second};
for (auto& [variant, phi] : it->second) {
std::visit([&](auto& variable) { AddPhiOperands(variable, *phi, block); }, variant);
}
}
@ -289,7 +284,7 @@ private:
const size_t num_args{phi.NumArgs()};
for (size_t arg_index = 0; arg_index < num_args; ++arg_index) {
const IR::Value& op{phi.Arg(arg_index)};
if (op.Resolve() == same.Resolve() || op == IR::Value{&phi}) {
if (op.Resolve() == same.Resolve() || op.Resolve() == IR::Value{&phi}) {
// Unique value or self-reference
continue;
}
@ -314,9 +309,15 @@ private:
++reinsert_point;
}
// Reinsert the phi node and reroute all its uses to the "same" value
const auto users = phi.Uses();
list.insert(reinsert_point, phi);
phi.ReplaceUsesWith(same);
// TODO: Try to recursively remove all phi users, which might have become trivial
// Try to recursively remove all phi users, which might have become trivial
for (const auto& [user, arg_index] : users) {
if (user->GetOpcode() == IR::Opcode::Phi) {
TryRemoveTrivialPhi(*user, user->GetParent(), undef_opcode);
}
}
return same;
}