spirv: Initial bindings support

This commit is contained in:
ReinUsesLisp 2021-02-16 04:10:22 -03:00 committed by ameerj
parent d5d468cf2c
commit b5d7279d87
23 changed files with 679 additions and 300 deletions

View file

@ -77,6 +77,16 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) {
return true;
}
template <typename Func>
bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) {
return false;
}
using Indices = std::make_index_sequence<LambdaTraits<decltype(func)>::NUM_ARGS>;
inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
return true;
}
void FoldGetRegister(IR::Inst& inst) {
if (inst.Arg(0).Reg() == IR::Reg::RZ) {
inst.ReplaceUsesWith(IR::Value{u32{0}});
@ -103,6 +113,52 @@ void FoldAdd(IR::Inst& inst) {
}
}
void FoldISub32(IR::Inst& inst) {
if (FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a - b; })) {
return;
}
if (inst.Arg(0).IsImmediate() || inst.Arg(1).IsImmediate()) {
return;
}
// ISub32 is generally used to subtract two constant buffers, compare and replace this with
// zero if they equal.
const auto equal_cbuf{[](IR::Inst* a, IR::Inst* b) {
return a->Opcode() == IR::Opcode::GetCbuf && b->Opcode() == IR::Opcode::GetCbuf &&
a->Arg(0) == b->Arg(0) && a->Arg(1) == b->Arg(1);
}};
IR::Inst* op_a{inst.Arg(0).InstRecursive()};
IR::Inst* op_b{inst.Arg(1).InstRecursive()};
if (equal_cbuf(op_a, op_b)) {
inst.ReplaceUsesWith(IR::Value{u32{0}});
return;
}
// It's also possible a value is being added to a cbuf and then subtracted
if (op_b->Opcode() == IR::Opcode::IAdd32) {
// Canonicalize local variables to simplify the following logic
std::swap(op_a, op_b);
}
if (op_b->Opcode() != IR::Opcode::GetCbuf) {
return;
}
IR::Inst* const inst_cbuf{op_b};
if (op_a->Opcode() != IR::Opcode::IAdd32) {
return;
}
IR::Value add_op_a{op_a->Arg(0)};
IR::Value add_op_b{op_a->Arg(1)};
if (add_op_b.IsImmediate()) {
// Canonicalize
std::swap(add_op_a, add_op_b);
}
if (add_op_b.IsImmediate()) {
return;
}
IR::Inst* const add_cbuf{add_op_b.InstRecursive()};
if (equal_cbuf(add_cbuf, inst_cbuf)) {
inst.ReplaceUsesWith(add_op_a);
}
}
template <typename T>
void FoldSelect(IR::Inst& inst) {
const IR::Value cond{inst.Arg(0)};
@ -170,15 +226,6 @@ IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<
return IR::Value{func(Arg<Traits::ArgType<I>>(inst.Arg(I))...)};
}
template <typename Func>
void FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) {
return;
}
using Indices = std::make_index_sequence<LambdaTraits<decltype(func)>::NUM_ARGS>;
inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
}
void FoldBranchConditional(IR::Inst& inst) {
const IR::U1 cond{inst.Arg(0)};
if (cond.IsImmediate()) {
@ -205,6 +252,8 @@ void ConstantPropagation(IR::Inst& inst) {
return FoldGetPred(inst);
case IR::Opcode::IAdd32:
return FoldAdd<u32>(inst);
case IR::Opcode::ISub32:
return FoldISub32(inst);
case IR::Opcode::BitCastF32U32:
return FoldBitCast<f32, u32>(inst, IR::Opcode::BitCastU32F32);
case IR::Opcode::BitCastU32F32:
@ -220,17 +269,20 @@ void ConstantPropagation(IR::Inst& inst) {
case IR::Opcode::LogicalNot:
return FoldLogicalNot(inst);
case IR::Opcode::SLessThan:
return FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; });
FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; });
return;
case IR::Opcode::ULessThan:
return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
return;
case IR::Opcode::BitFieldUExtract:
return FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) {
FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) {
if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) {
throw LogicError("Undefined result in {}({}, {}, {})", IR::Opcode::BitFieldUExtract,
base, shift, count);
}
return (base >> shift) & ((1U << count) - 1);
});
return;
case IR::Opcode::BranchConditional:
return FoldBranchConditional(inst);
default: