Implemented load_buffer_format_* conversions (#295)

* Implemented load_buffer_format_* conversions

* clang-format insists on ugly things
This commit is contained in:
Vladislav Mikhalin 2024-07-16 15:03:07 +03:00 committed by GitHub
parent c6cdfcfb0b
commit f9e96793cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 475 additions and 91 deletions

View file

@ -327,6 +327,22 @@ Value IREmitter::LoadBuffer(int num_dwords, const Value& handle, const Value& ad
}
}
Value IREmitter::LoadBufferFormat(int num_dwords, const Value& handle, const Value& address,
BufferInstInfo info) {
switch (num_dwords) {
case 1:
return Inst(Opcode::LoadBufferFormatF32, Flags{info}, handle, address);
case 2:
return Inst(Opcode::LoadBufferFormatF32x2, Flags{info}, handle, address);
case 3:
return Inst(Opcode::LoadBufferFormatF32x3, Flags{info}, handle, address);
case 4:
return Inst(Opcode::LoadBufferFormatF32x4, Flags{info}, handle, address);
default:
UNREACHABLE_MSG("Invalid number of dwords {}", num_dwords);
}
}
void IREmitter::StoreBuffer(int num_dwords, const Value& handle, const Value& address,
const Value& data, BufferInstInfo info) {
switch (num_dwords) {

View file

@ -89,6 +89,8 @@ public:
[[nodiscard]] Value LoadBuffer(int num_dwords, const Value& handle, const Value& address,
BufferInstInfo info);
[[nodiscard]] Value LoadBufferFormat(int num_dwords, const Value& handle, const Value& address,
BufferInstInfo info);
void StoreBuffer(int num_dwords, const Value& handle, const Value& address, const Value& data,
BufferInstInfo info);

View file

@ -79,6 +79,10 @@ OPCODE(LoadBufferF32, F32, Opaq
OPCODE(LoadBufferF32x2, F32x2, Opaque, Opaque, )
OPCODE(LoadBufferF32x3, F32x3, Opaque, Opaque, )
OPCODE(LoadBufferF32x4, F32x4, Opaque, Opaque, )
OPCODE(LoadBufferFormatF32, F32, Opaque, Opaque, )
OPCODE(LoadBufferFormatF32x2, F32x2, Opaque, Opaque, )
OPCODE(LoadBufferFormatF32x3, F32x3, Opaque, Opaque, )
OPCODE(LoadBufferFormatF32x4, F32x4, Opaque, Opaque, )
OPCODE(LoadBufferU32, U32, Opaque, Opaque, )
OPCODE(StoreBufferF32, Void, Opaque, Opaque, F32, )
OPCODE(StoreBufferF32x2, Void, Opaque, Opaque, F32x2, )

View file

@ -27,6 +27,10 @@ bool IsBufferInstruction(const IR::Inst& inst) {
case IR::Opcode::LoadBufferF32x2:
case IR::Opcode::LoadBufferF32x3:
case IR::Opcode::LoadBufferF32x4:
case IR::Opcode::LoadBufferFormatF32:
case IR::Opcode::LoadBufferFormatF32x2:
case IR::Opcode::LoadBufferFormatF32x3:
case IR::Opcode::LoadBufferFormatF32x4:
case IR::Opcode::LoadBufferU32:
case IR::Opcode::ReadConstBuffer:
case IR::Opcode::ReadConstBufferU32:
@ -41,8 +45,49 @@ bool IsBufferInstruction(const IR::Inst& inst) {
}
}
IR::Type BufferDataType(const IR::Inst& inst) {
static bool UseFP16(AmdGpu::DataFormat data_format, AmdGpu::NumberFormat num_format) {
switch (num_format) {
case AmdGpu::NumberFormat::Float:
switch (data_format) {
case AmdGpu::DataFormat::Format16:
case AmdGpu::DataFormat::Format16_16:
case AmdGpu::DataFormat::Format16_16_16_16:
return true;
default:
return false;
}
case AmdGpu::NumberFormat::Unorm:
case AmdGpu::NumberFormat::Snorm:
case AmdGpu::NumberFormat::Uscaled:
case AmdGpu::NumberFormat::Sscaled:
case AmdGpu::NumberFormat::Uint:
case AmdGpu::NumberFormat::Sint:
case AmdGpu::NumberFormat::SnormNz:
default:
return false;
}
}
IR::Type BufferDataType(const IR::Inst& inst, AmdGpu::NumberFormat num_format) {
switch (inst.GetOpcode()) {
case IR::Opcode::LoadBufferFormatF32:
case IR::Opcode::LoadBufferFormatF32x2:
case IR::Opcode::LoadBufferFormatF32x3:
case IR::Opcode::LoadBufferFormatF32x4:
switch (num_format) {
case AmdGpu::NumberFormat::Unorm:
case AmdGpu::NumberFormat::Snorm:
case AmdGpu::NumberFormat::Uscaled:
case AmdGpu::NumberFormat::Sscaled:
case AmdGpu::NumberFormat::Uint:
case AmdGpu::NumberFormat::Sint:
case AmdGpu::NumberFormat::SnormNz:
return IR::Type::U32;
case AmdGpu::NumberFormat::Float:
return IR::Type::F32;
default:
UNREACHABLE();
}
case IR::Opcode::LoadBufferF32:
case IR::Opcode::LoadBufferF32x2:
case IR::Opcode::LoadBufferF32x3:
@ -141,7 +186,7 @@ public:
desc.inline_cbuf == existing.inline_cbuf;
})};
auto& buffer = buffer_resources[index];
ASSERT(buffer.stride == desc.stride && buffer.num_records == desc.num_records);
ASSERT(buffer.length == desc.length);
buffer.is_storage |= desc.is_storage;
buffer.used_types |= desc.used_types;
return index;
@ -263,6 +308,41 @@ SharpLocation TrackSharp(const IR::Inst* inst) {
static constexpr size_t MaxUboSize = 65536;
static bool IsLoadBufferFormat(const IR::Inst& inst) {
switch (inst.GetOpcode()) {
case IR::Opcode::LoadBufferFormatF32:
case IR::Opcode::LoadBufferFormatF32x2:
case IR::Opcode::LoadBufferFormatF32x3:
case IR::Opcode::LoadBufferFormatF32x4:
return true;
default:
return false;
}
}
static bool IsReadConstBuffer(const IR::Inst& inst) {
switch (inst.GetOpcode()) {
case IR::Opcode::ReadConstBuffer:
case IR::Opcode::ReadConstBufferU32:
return true;
default:
return false;
}
}
static u32 BufferLength(const AmdGpu::Buffer& buffer) {
const auto stride = buffer.GetStride();
if (stride < sizeof(f32)) {
ASSERT(sizeof(f32) % stride == 0);
return (((buffer.num_records - 1) / sizeof(f32)) + 1) * stride;
} else if (stride == sizeof(f32)) {
return buffer.num_records;
} else {
ASSERT(stride % sizeof(f32) == 0);
return buffer.num_records * (stride / sizeof(f32));
}
}
s32 TryHandleInlineCbuf(IR::Inst& inst, Info& info, Descriptors& descriptors,
AmdGpu::Buffer& cbuf) {
@ -298,9 +378,8 @@ s32 TryHandleInlineCbuf(IR::Inst& inst, Info& info, Descriptors& descriptors,
return descriptors.Add(BufferResource{
.sgpr_base = std::numeric_limits<u32>::max(),
.dword_offset = 0,
.stride = cbuf.GetStride(),
.num_records = u32(cbuf.num_records),
.used_types = BufferDataType(inst),
.length = BufferLength(cbuf),
.used_types = BufferDataType(inst, cbuf.GetNumberFmt()),
.inline_cbuf = cbuf,
.is_storage = IsBufferStore(inst) || cbuf.GetSize() > MaxUboSize,
});
@ -318,9 +397,8 @@ void PatchBufferInstruction(IR::Block& block, IR::Inst& inst, Info& info,
binding = descriptors.Add(BufferResource{
.sgpr_base = sharp.sgpr_base,
.dword_offset = sharp.dword_offset,
.stride = buffer.GetStride(),
.num_records = u32(buffer.num_records),
.used_types = BufferDataType(inst),
.length = BufferLength(buffer),
.used_types = BufferDataType(inst, buffer.GetNumberFmt()),
.is_storage = IsBufferStore(inst) || buffer.GetSize() > MaxUboSize,
});
}
@ -337,24 +415,31 @@ void PatchBufferInstruction(IR::Block& block, IR::Inst& inst, Info& info,
inst_info.dmft == AmdGpu::DataFormat::Format32_32 ||
inst_info.dmft == AmdGpu::DataFormat::Format32));
}
if (inst.GetOpcode() == IR::Opcode::ReadConstBuffer ||
inst.GetOpcode() == IR::Opcode::ReadConstBufferU32) {
if (IsReadConstBuffer(inst)) {
return;
}
// Calculate buffer address.
const u32 dword_stride = buffer.GetStrideElements(sizeof(u32));
const u32 dword_offset = inst_info.inst_offset.Value() / sizeof(u32);
IR::U32 address = ir.Imm32(dword_offset);
if (inst_info.index_enable && inst_info.offset_enable) {
const IR::U32 offset{ir.CompositeExtract(inst.Arg(1), 1)};
const IR::U32 index{ir.CompositeExtract(inst.Arg(1), 0)};
address = ir.IAdd(ir.IMul(index, ir.Imm32(dword_stride)), address);
address = ir.IAdd(address, ir.ShiftRightLogical(offset, ir.Imm32(2)));
} else if (inst_info.index_enable) {
const IR::U32 index{inst.Arg(1)};
address = ir.IAdd(ir.IMul(index, ir.Imm32(dword_stride)), address);
} else if (inst_info.offset_enable) {
const IR::U32 offset{inst.Arg(1)};
if (IsLoadBufferFormat(inst)) {
if (UseFP16(buffer.GetDataFmt(), buffer.GetNumberFmt())) {
info.uses_fp16 = true;
}
} else {
const u32 stride = buffer.GetStride();
ASSERT_MSG(stride >= 4, "non-formatting load_buffer_* is not implemented for stride {}",
stride);
}
IR::U32 address = ir.Imm32(inst_info.inst_offset.Value());
if (inst_info.index_enable) {
const IR::U32 index = inst_info.offset_enable ? IR::U32{ir.CompositeExtract(inst.Arg(1), 0)}
: IR::U32{inst.Arg(1)};
address = ir.IAdd(address, ir.IMul(index, ir.Imm32(buffer.GetStride())));
}
if (inst_info.offset_enable) {
const IR::U32 offset = inst_info.index_enable ? IR::U32{ir.CompositeExtract(inst.Arg(1), 1)}
: IR::U32{inst.Arg(1)};
address = ir.IAdd(address, offset);
}
inst.SetArg(1, address);
}