shader_recompiler: Various fixes to shared memory and atomics. (#3075)

* shader_recompiler: Various fixes to shared memory and atomics.

* shader_recompiler: Re-type non-32bit load/stores.
Author: squidbus, 2025-06-10 15:41:58 -07:00 (committed by GitHub)
parent b49340dff8
commit ca92e72efe
17 changed files with 391 additions and 227 deletions
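
The diff below converges on a single bounds-checking policy for buffer and shared-memory accesses: when the buffer size is known, the access is guarded, an out-of-bounds access is skipped, and any result is replaced with zero. A minimal standalone C++ model of that policy (illustrative only; the names are stand-ins, not repository code):

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <vector>

    // Out-of-bounds accesses execute nothing and yield zero, mirroring the
    // OpPhi/OpSelect merges the SPIR-V backend emits below.
    std::uint32_t AccessBoundsCheck(const std::vector<std::uint32_t>& buf, std::size_t index,
                                    const std::function<std::uint32_t()>& access) {
        if (index < buf.size()) {
            return access(); // in bounds: perform the load/store/atomic
        }
        return 0; // out of bounds: skip the access, produce zero
    }

    int main() {
        const std::vector<std::uint32_t> buf{7, 8, 9};
        std::cout << AccessBoundsCheck(buf, 1, [&] { return buf[1]; }) << '\n'; // 8
        std::cout << AccessBoundsCheck(buf, 9, [&] { return buf[9]; }) << '\n'; // 0
    }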

View file

@@ -27,6 +27,19 @@ Id SharedAtomicU32(EmitContext& ctx, Id offset, Id value,
     });
 }
 
+Id SharedAtomicU32IncDec(EmitContext& ctx, Id offset,
+                         Id (Sirit::Module::*atomic_func)(Id, Id, Id, Id)) {
+    const Id shift_id{ctx.ConstU32(2U)};
+    const Id index{ctx.OpShiftRightLogical(ctx.U32[1], offset, shift_id)};
+    const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 4u)};
+    const Id pointer{
+        ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index)};
+    const auto [scope, semantics]{AtomicArgs(ctx)};
+    return AccessBoundsCheck<32>(ctx, index, ctx.ConstU32(num_elements), [&] {
+        return (ctx.*atomic_func)(ctx.U32[1], pointer, scope, semantics);
+    });
+}
+
 Id SharedAtomicU64(EmitContext& ctx, Id offset, Id value,
                    Id (Sirit::Module::*atomic_func)(Id, Id, Id, Id, Id)) {
     const Id shift_id{ctx.ConstU32(3U)};
@@ -40,19 +53,6 @@ Id SharedAtomicU64(EmitContext& ctx, Id offset, Id value,
     });
 }
 
-Id SharedAtomicU32_IncDec(EmitContext& ctx, Id offset,
-                          Id (Sirit::Module::*atomic_func)(Id, Id, Id, Id)) {
-    const Id shift_id{ctx.ConstU32(2U)};
-    const Id index{ctx.OpShiftRightLogical(ctx.U32[1], offset, shift_id)};
-    const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 4u)};
-    const Id pointer{
-        ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index)};
-    const auto [scope, semantics]{AtomicArgs(ctx)};
-    return AccessBoundsCheck<32>(ctx, index, ctx.ConstU32(num_elements), [&] {
-        return (ctx.*atomic_func)(ctx.U32[1], pointer, scope, semantics);
-    });
-}
-
 Id BufferAtomicU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value,
                    Id (Sirit::Module::*atomic_func)(Id, Id, Id, Id, Id)) {
     const auto& buffer = ctx.buffers[handle];
@@ -68,6 +68,21 @@ Id BufferAtomicU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id
     });
 }
 
+Id BufferAtomicU32IncDec(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address,
+                         Id (Sirit::Module::*atomic_func)(Id, Id, Id, Id)) {
+    const auto& buffer = ctx.buffers[handle];
+    if (Sirit::ValidId(buffer.offset)) {
+        address = ctx.OpIAdd(ctx.U32[1], address, buffer.offset);
+    }
+    const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
+    const auto [id, pointer_type] = buffer[EmitContext::PointerType::U32];
+    const Id ptr = ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index);
+    const auto [scope, semantics]{AtomicArgs(ctx)};
+    return AccessBoundsCheck<32>(ctx, index, buffer.size_dwords, [&] {
+        return (ctx.*atomic_func)(ctx.U32[1], ptr, scope, semantics);
+    });
+}
+
 Id BufferAtomicU32CmpSwap(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value,
                           Id cmp_value,
                           Id (Sirit::Module::*atomic_func)(Id, Id, Id, Id, Id, Id, Id)) {
@@ -156,12 +171,12 @@ Id EmitSharedAtomicISub32(EmitContext& ctx, Id offset, Id value) {
     return SharedAtomicU32(ctx, offset, value, &Sirit::Module::OpAtomicISub);
 }
 
-Id EmitSharedAtomicIIncrement32(EmitContext& ctx, Id offset) {
-    return SharedAtomicU32_IncDec(ctx, offset, &Sirit::Module::OpAtomicIIncrement);
+Id EmitSharedAtomicInc32(EmitContext& ctx, Id offset) {
+    return SharedAtomicU32IncDec(ctx, offset, &Sirit::Module::OpAtomicIIncrement);
 }
 
-Id EmitSharedAtomicIDecrement32(EmitContext& ctx, Id offset) {
-    return SharedAtomicU32_IncDec(ctx, offset, &Sirit::Module::OpAtomicIDecrement);
+Id EmitSharedAtomicDec32(EmitContext& ctx, Id offset) {
+    return SharedAtomicU32IncDec(ctx, offset, &Sirit::Module::OpAtomicIDecrement);
 }
 
 Id EmitBufferAtomicIAdd32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) {
@@ -172,6 +187,10 @@ Id EmitBufferAtomicIAdd64(EmitContext& ctx, IR::Inst* inst, u32 handle, Id addre
     return BufferAtomicU64(ctx, inst, handle, address, value, &Sirit::Module::OpAtomicIAdd);
 }
 
+Id EmitBufferAtomicISub32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) {
+    return BufferAtomicU32(ctx, inst, handle, address, value, &Sirit::Module::OpAtomicISub);
+}
+
 Id EmitBufferAtomicSMin32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) {
     return BufferAtomicU32(ctx, inst, handle, address, value, &Sirit::Module::OpAtomicSMin);
 }
@@ -188,14 +207,12 @@ Id EmitBufferAtomicUMax32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id addre
     return BufferAtomicU32(ctx, inst, handle, address, value, &Sirit::Module::OpAtomicUMax);
 }
 
-Id EmitBufferAtomicInc32(EmitContext&, IR::Inst*, u32, Id, Id) {
-    // TODO
-    UNREACHABLE_MSG("Unsupported BUFFER_ATOMIC opcode: ", IR::Opcode::BufferAtomicInc32);
+Id EmitBufferAtomicInc32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
+    return BufferAtomicU32IncDec(ctx, inst, handle, address, &Sirit::Module::OpAtomicIIncrement);
 }
 
-Id EmitBufferAtomicDec32(EmitContext&, IR::Inst*, u32, Id, Id) {
-    // TODO
-    UNREACHABLE_MSG("Unsupported BUFFER_ATOMIC opcode: ", IR::Opcode::BufferAtomicDec32);
+Id EmitBufferAtomicDec32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
+    return BufferAtomicU32IncDec(ctx, inst, handle, address, &Sirit::Module::OpAtomicIDecrement);
 }
 
 Id EmitBufferAtomicAnd32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) {
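
Note how the two new helpers take the SPIR-V opcode as a pointer to a Sirit::Module member function, so one body serves both increment and decrement, which (unlike the other atomics) carry no value operand. A standalone sketch of that dispatch pattern, using stand-in types rather than the real Sirit API:

    #include <cstdint>
    #include <iostream>

    struct Id {
        std::uint32_t value;
    };

    struct Module {
        // Same shape as Sirit's OpAtomicIIncrement/OpAtomicIDecrement:
        // (result type, pointer, scope, semantics) -> result id.
        Id OpAtomicIIncrement(Id, Id, Id, Id) { return Id{1}; }
        Id OpAtomicIDecrement(Id, Id, Id, Id) { return Id{2}; }
    };

    // One emitter body, parameterized on the opcode, as in
    // SharedAtomicU32IncDec/BufferAtomicU32IncDec above.
    Id EmitIncDec(Module& ctx, Id (Module::*atomic_func)(Id, Id, Id, Id)) {
        const Id type{}, pointer{}, scope{}, semantics{};
        return (ctx.*atomic_func)(type, pointer, scope, semantics);
    }

    int main() {
        Module m;
        std::cout << EmitIncDec(m, &Module::OpAtomicIIncrement).value << '\n'; // 1
        std::cout << EmitIncDec(m, &Module::OpAtomicIDecrement).value << '\n'; // 2
    }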

View file

@@ -1,31 +1,54 @@
 // SPDX-FileCopyrightText: Copyright 2025 shadPS4 Emulator Project
 // SPDX-License-Identifier: GPL-2.0-or-later
 
-#include "shader_recompiler/backend/spirv/emit_spirv_instructions.h"
+#pragma once
+
 #include "shader_recompiler/backend/spirv/spirv_emit_context.h"
 
 namespace Shader::Backend::SPIRV {
 
-template <u32 bit_size>
-auto AccessBoundsCheck(EmitContext& ctx, Id index, Id buffer_size, auto emit_func) {
-    Id zero_value{};
+template <u32 bit_size, u32 num_components = 1, bool is_float = false>
+std::tuple<Id, Id> ResolveTypeAndZero(EmitContext& ctx) {
     Id result_type{};
-    if constexpr (bit_size == 64) {
-        zero_value = ctx.u64_zero_value;
+    Id zero_value{};
+    if constexpr (bit_size == 64 && num_components == 1 && !is_float) {
         result_type = ctx.U64;
+        zero_value = ctx.u64_zero_value;
     } else if constexpr (bit_size == 32) {
-        zero_value = ctx.u32_zero_value;
-        result_type = ctx.U32[1];
-    } else if constexpr (bit_size == 16) {
-        zero_value = ctx.u16_zero_value;
+        if (is_float) {
+            result_type = ctx.F32[num_components];
+            zero_value = ctx.f32_zero_value;
+        } else {
+            result_type = ctx.U32[num_components];
+            zero_value = ctx.u32_zero_value;
+        }
+    } else if constexpr (bit_size == 16 && num_components == 1 && !is_float) {
         result_type = ctx.U16;
+        zero_value = ctx.u16_zero_value;
+    } else if constexpr (bit_size == 8 && num_components == 1 && !is_float) {
+        result_type = ctx.U8;
+        zero_value = ctx.u8_zero_value;
     } else {
-        static_assert(false, "type not supported");
+        static_assert(false, "Type not supported.");
     }
+    if (num_components > 1) {
+        std::array<Id, num_components> zero_ids;
+        zero_ids.fill(zero_value);
+        zero_value = ctx.ConstantComposite(result_type, zero_ids);
+    }
+    return {result_type, zero_value};
+}
+
+template <u32 bit_size, u32 num_components = 1, bool is_float = false>
+auto AccessBoundsCheck(EmitContext& ctx, Id index, Id buffer_size, auto emit_func) {
     if (Sirit::ValidId(buffer_size)) {
         // Bounds checking enabled, wrap in a conditional branch to make sure that
         // the atomic is not mistakenly executed when the index is out of bounds.
-        const Id in_bounds = ctx.OpULessThan(ctx.U1[1], index, buffer_size);
+        auto compare_index = index;
+        if (num_components > 1) {
+            compare_index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(num_components - 1));
+        }
+        const Id in_bounds = ctx.OpULessThan(ctx.U1[1], compare_index, buffer_size);
         const Id ib_label = ctx.OpLabel();
         const Id end_label = ctx.OpLabel();
         ctx.OpSelectionMerge(end_label, spv::SelectionControlMask::MaskNone);
@@ -36,6 +59,8 @@ auto AccessBoundsCheck(EmitContext& ctx, Id index, Id buffer_size, auto emit_fun
         ctx.OpBranch(end_label);
         ctx.AddLabel(end_label);
         if (Sirit::ValidId(ib_result)) {
+            const auto [result_type, zero_value] =
+                ResolveTypeAndZero<bit_size, num_components, is_float>(ctx);
             return ctx.OpPhi(result_type, ib_result, ib_label, zero_value, last_label);
         } else {
             return Id{0};
@@ -45,4 +70,21 @@ auto AccessBoundsCheck(EmitContext& ctx, Id index, Id buffer_size, auto emit_fun
     return emit_func();
 }
 
+template <u32 bit_size, u32 num_components = 1, bool is_float = false>
+static Id LoadAccessBoundsCheck(EmitContext& ctx, Id index, Id buffer_size, Id result) {
+    if (Sirit::ValidId(buffer_size)) {
+        // Bounds checking enabled, wrap in a select.
+        auto compare_index = index;
+        if (num_components > 1) {
+            compare_index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(num_components - 1));
+        }
+        const Id in_bounds = ctx.OpULessThan(ctx.U1[1], compare_index, buffer_size);
+        const auto [result_type, zero_value] =
+            ResolveTypeAndZero<bit_size, num_components, is_float>(ctx);
+        return ctx.OpSelect(result_type, in_bounds, result, zero_value);
+    }
+    // Bounds checking not enabled, just return the plain value.
+    return result;
+}
+
 } // namespace Shader::Backend::SPIRV
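
ResolveTypeAndZero folds bit width, component count, and float-versus-integer into one compile-time lookup shared by both check paths: AccessBoundsCheck merges the skipped branch back with an OpPhi against the zero constant, while LoadAccessBoundsCheck stays branchless with an OpSelect. A simplified standalone model of the compile-time selection (strings stand in for SPIR-V type ids; not repository code):

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <tuple>

    template <std::uint32_t bit_size, std::uint32_t num_components = 1, bool is_float = false>
    std::tuple<std::string, std::string> ResolveTypeAndZero() {
        if constexpr (bit_size == 64 && num_components == 1 && !is_float) {
            return {"U64", "u64_zero"};
        } else if constexpr (bit_size == 32) {
            if constexpr (is_float) {
                return {"F32[" + std::to_string(num_components) + "]", "f32_zero"};
            } else {
                return {"U32[" + std::to_string(num_components) + "]", "u32_zero"};
            }
        } else if constexpr (bit_size == 16 && num_components == 1 && !is_float) {
            return {"U16", "u16_zero"};
        } else if constexpr (bit_size == 8 && num_components == 1 && !is_float) {
            return {"U8", "u8_zero"};
        } else {
            // Like the static_assert in the header above, this only fires if
            // an unsupported combination is actually instantiated.
            static_assert(bit_size != bit_size, "Type not supported.");
        }
    }

    int main() {
        const auto [type, zero] = ResolveTypeAndZero<32, 4, true>();
        std::cout << type << " / " << zero << '\n'; // F32[4] / f32_zero
    }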

View file

@@ -11,6 +11,8 @@
 
 #include <magic_enum/magic_enum.hpp>
 
+#include "emit_spirv_bounds.h"
+
 namespace Shader::Backend::SPIRV {
 namespace {
@@ -239,8 +241,8 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, u32 comp, Id index) {
     }
 
     if (IR::IsParam(attr)) {
-        const u32 index{u32(attr) - u32(IR::Attribute::Param0)};
-        const auto& param{ctx.input_params.at(index)};
+        const u32 param_index{u32(attr) - u32(IR::Attribute::Param0)};
+        const auto& param{ctx.input_params.at(param_index)};
         if (param.buffer_handle >= 0) {
             const auto step_rate = EmitReadStepRate(ctx, param.id.value);
             const auto offset = ctx.OpIAdd(
@@ -415,27 +417,6 @@ void EmitSetPatch(EmitContext& ctx, IR::Patch patch, Id value) {
     ctx.OpStore(pointer, value);
 }
 
-template <u32 N>
-static Id EmitLoadBufferBoundsCheck(EmitContext& ctx, Id index, Id buffer_size, Id result,
-                                    bool is_float) {
-    if (Sirit::ValidId(buffer_size)) {
-        // Bounds checking enabled, wrap in a select.
-        const auto result_type = is_float ? ctx.F32[N] : ctx.U32[N];
-        auto compare_index = index;
-        auto zero_value = is_float ? ctx.f32_zero_value : ctx.u32_zero_value;
-        if (N > 1) {
-            compare_index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(N - 1));
-            std::array<Id, N> zero_ids;
-            zero_ids.fill(zero_value);
-            zero_value = ctx.ConstantComposite(result_type, zero_ids);
-        }
-        const Id in_bounds = ctx.OpULessThan(ctx.U1[1], compare_index, buffer_size);
-        return ctx.OpSelect(result_type, in_bounds, result, zero_value);
-    }
-    // Bounds checking not enabled, just return the plain value.
-    return result;
-}
-
 template <u32 N, PointerType alias>
 static Id EmitLoadBufferB32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
     const auto flags = inst->Flags<IR::BufferInstInfo>();
@@ -454,8 +435,9 @@ static Id EmitLoadBufferB32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id a
             const Id result_i = ctx.OpLoad(data_types[1], ptr_i);
             if (!flags.typed) {
                 // Untyped loads have bounds checking per-component.
-                ids.push_back(EmitLoadBufferBoundsCheck<1>(ctx, index_i, spv_buffer.size_dwords,
-                                                           result_i, alias == PointerType::F32));
+                ids.push_back(LoadAccessBoundsCheck < 32, 1,
+                              alias ==
+                                  PointerType::F32 > (ctx, index_i, spv_buffer.size_dwords, result_i));
             } else {
                 ids.push_back(result_i);
             }
@@ -464,8 +446,8 @@ static Id EmitLoadBufferB32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id a
     const Id result = N == 1 ? ids[0] : ctx.OpCompositeConstruct(data_types[N], ids);
     if (flags.typed) {
         // Typed loads have single bounds check for the whole load.
-        return EmitLoadBufferBoundsCheck<N>(ctx, index, spv_buffer.size_dwords, result,
-                                            alias == PointerType::F32);
+        return LoadAccessBoundsCheck < 32, N,
+               alias == PointerType::F32 > (ctx, index, spv_buffer.size_dwords, result);
     }
     return result;
 }
@@ -477,8 +459,8 @@ Id EmitLoadBufferU8(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
     }
     const auto [id, pointer_type] = spv_buffer[PointerType::U8];
     const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, address)};
-    const Id result{ctx.OpUConvert(ctx.U32[1], ctx.OpLoad(ctx.U8, ptr))};
-    return EmitLoadBufferBoundsCheck<1>(ctx, address, spv_buffer.size, result, false);
+    const Id result{ctx.OpLoad(ctx.U8, ptr)};
+    return LoadAccessBoundsCheck<8>(ctx, address, spv_buffer.size, result);
 }
 
 Id EmitLoadBufferU16(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
@@ -489,8 +471,8 @@ Id EmitLoadBufferU16(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
     const auto [id, pointer_type] = spv_buffer[PointerType::U16];
     const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(1u));
     const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index)};
-    const Id result{ctx.OpUConvert(ctx.U32[1], ctx.OpLoad(ctx.U16, ptr))};
-    return EmitLoadBufferBoundsCheck<1>(ctx, index, spv_buffer.size_shorts, result, false);
+    const Id result{ctx.OpLoad(ctx.U16, ptr)};
+    return LoadAccessBoundsCheck<16>(ctx, index, spv_buffer.size_shorts, result);
 }
 
 Id EmitLoadBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
@@ -509,6 +491,18 @@ Id EmitLoadBufferU32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address)
     return EmitLoadBufferB32xN<4, PointerType::U32>(ctx, inst, handle, address);
 }
 
+Id EmitLoadBufferU64(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
+    const auto& spv_buffer = ctx.buffers[handle];
+    if (Sirit::ValidId(spv_buffer.offset)) {
+        address = ctx.OpIAdd(ctx.U32[1], address, spv_buffer.offset);
+    }
+    const auto [id, pointer_type] = spv_buffer[PointerType::U64];
+    const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(3u));
+    const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u64_zero_value, index)};
+    const Id result{ctx.OpLoad(ctx.U64, ptr)};
+    return LoadAccessBoundsCheck<64>(ctx, index, spv_buffer.size_qwords, result);
+}
+
 Id EmitLoadBufferF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
     return EmitLoadBufferB32xN<1, PointerType::F32>(ctx, inst, handle, address);
 }
@@ -529,29 +523,6 @@ Id EmitLoadBufferFormatF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id addr
     UNREACHABLE_MSG("SPIR-V instruction");
 }
 
-template <u32 N>
-void EmitStoreBufferBoundsCheck(EmitContext& ctx, Id index, Id buffer_size, auto emit_func) {
-    if (Sirit::ValidId(buffer_size)) {
-        // Bounds checking enabled, wrap in a conditional branch.
-        auto compare_index = index;
-        if (N > 1) {
-            compare_index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(N - 1));
-        }
-        const Id in_bounds = ctx.OpULessThan(ctx.U1[1], compare_index, buffer_size);
-        const Id in_bounds_label = ctx.OpLabel();
-        const Id merge_label = ctx.OpLabel();
-        ctx.OpSelectionMerge(merge_label, spv::SelectionControlMask::MaskNone);
-        ctx.OpBranchConditional(in_bounds, in_bounds_label, merge_label);
-        ctx.AddLabel(in_bounds_label);
-        emit_func();
-        ctx.OpBranch(merge_label);
-        ctx.AddLabel(merge_label);
-        return;
-    }
-    // Bounds checking not enabled, just perform the store.
-    emit_func();
-}
-
 template <u32 N, PointerType alias>
 static void EmitStoreBufferB32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address,
                                  Id value) {
@@ -569,19 +540,25 @@ static void EmitStoreBufferB32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, I
             const Id index_i = i == 0 ? index : ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
             const Id ptr_i = ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index_i);
             const Id value_i = N == 1 ? value : ctx.OpCompositeExtract(data_types[1], value, i);
-            auto store_i = [&]() { ctx.OpStore(ptr_i, value_i); };
+            auto store_i = [&] {
+                ctx.OpStore(ptr_i, value_i);
+                return Id{};
+            };
             if (!flags.typed) {
                 // Untyped stores have bounds checking per-component.
-                EmitStoreBufferBoundsCheck<1>(ctx, index_i, spv_buffer.size_dwords, store_i);
+                AccessBoundsCheck<32, 1, alias == PointerType::F32>(
+                    ctx, index_i, spv_buffer.size_dwords, store_i);
            } else {
                 store_i();
             }
         }
+        return Id{};
     };
     if (flags.typed) {
         // Typed stores have single bounds check for the whole store.
-        EmitStoreBufferBoundsCheck<N>(ctx, index, spv_buffer.size_dwords, store);
+        AccessBoundsCheck<32, N, alias == PointerType::F32>(ctx, index, spv_buffer.size_dwords,
+                                                            store);
    } else {
         store();
     }
@@ -594,8 +571,10 @@ void EmitStoreBufferU8(EmitContext& ctx, IR::Inst*, u32 handle, Id address, Id v
     }
     const auto [id, pointer_type] = spv_buffer[PointerType::U8];
     const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, address)};
-    const Id result{ctx.OpUConvert(ctx.U8, value)};
-    EmitStoreBufferBoundsCheck<1>(ctx, address, spv_buffer.size, [&] { ctx.OpStore(ptr, result); });
+    AccessBoundsCheck<8>(ctx, address, spv_buffer.size, [&] {
+        ctx.OpStore(ptr, value);
+        return Id{};
+    });
 }
 
 void EmitStoreBufferU16(EmitContext& ctx, IR::Inst*, u32 handle, Id address, Id value) {
@@ -606,9 +585,10 @@ void EmitStoreBufferU16(EmitContext& ctx, IR::Inst*, u32 handle, Id address, Id
     const auto [id, pointer_type] = spv_buffer[PointerType::U16];
     const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(1u));
     const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index)};
-    const Id result{ctx.OpUConvert(ctx.U16, value)};
-    EmitStoreBufferBoundsCheck<1>(ctx, index, spv_buffer.size_shorts,
-                                  [&] { ctx.OpStore(ptr, result); });
+    AccessBoundsCheck<16>(ctx, index, spv_buffer.size_shorts, [&] {
+        ctx.OpStore(ptr, value);
+        return Id{};
+    });
 }
 
 void EmitStoreBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) {
@@ -627,6 +607,20 @@ void EmitStoreBufferU32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id addre
     EmitStoreBufferB32xN<4, PointerType::U32>(ctx, inst, handle, address, value);
 }
 
+void EmitStoreBufferU64(EmitContext& ctx, IR::Inst*, u32 handle, Id address, Id value) {
+    const auto& spv_buffer = ctx.buffers[handle];
+    if (Sirit::ValidId(spv_buffer.offset)) {
+        address = ctx.OpIAdd(ctx.U32[1], address, spv_buffer.offset);
+    }
+    const auto [id, pointer_type] = spv_buffer[PointerType::U64];
+    const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(3u));
+    const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u64_zero_value, index)};
+    AccessBoundsCheck<64>(ctx, index, spv_buffer.size_qwords, [&] {
+        ctx.OpStore(ptr, value);
+        return Id{};
+    });
+}
+
 void EmitStoreBufferF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) {
     EmitStoreBufferB32xN<1, PointerType::F32>(ctx, inst, handle, address, value);
 }

View file

@@ -263,4 +263,12 @@ Id EmitConvertU32U16(EmitContext& ctx, Id value) {
     return ctx.OpUConvert(ctx.U32[1], value);
 }
 
+Id EmitConvertU8U32(EmitContext& ctx, Id value) {
+    return ctx.OpUConvert(ctx.U8, value);
+}
+
+Id EmitConvertU32U8(EmitContext& ctx, Id value) {
+    return ctx.OpUConvert(ctx.U32[1], value);
+}
+
 } // namespace Shader::Backend::SPIRV
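
Both new conversions lower to OpUConvert, which truncates when narrowing and zero-extends when widening. A rough per-lane scalar analogue (illustrative only):

    #include <cstdint>
    #include <iostream>

    std::uint8_t ConvertU8U32(std::uint32_t v) {
        return static_cast<std::uint8_t>(v); // narrow: keep the low 8 bits
    }

    std::uint32_t ConvertU32U8(std::uint8_t v) {
        return static_cast<std::uint32_t>(v); // widen: zero-extend
    }

    int main() {
        std::cout << unsigned(ConvertU8U32(0x12345678u)) << '\n'; // 120 (0x78)
        std::cout << ConvertU32U8(0xFF) << '\n';                  // 255
    }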

View file

@@ -69,6 +69,7 @@ Id EmitLoadBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
 Id EmitLoadBufferU32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
 Id EmitLoadBufferU32x3(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
 Id EmitLoadBufferU32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
+Id EmitLoadBufferU64(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
 Id EmitLoadBufferF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
 Id EmitLoadBufferF32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
 Id EmitLoadBufferF32x3(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
@@ -80,6 +81,7 @@ void EmitStoreBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address
 void EmitStoreBufferU32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 void EmitStoreBufferU32x3(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 void EmitStoreBufferU32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
+void EmitStoreBufferU64(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 void EmitStoreBufferF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 void EmitStoreBufferF32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 void EmitStoreBufferF32x3(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
@@ -87,12 +89,13 @@ void EmitStoreBufferF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id addre
 void EmitStoreBufferFormatF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 Id EmitBufferAtomicIAdd32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 Id EmitBufferAtomicIAdd64(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
+Id EmitBufferAtomicISub32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 Id EmitBufferAtomicSMin32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 Id EmitBufferAtomicUMin32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 Id EmitBufferAtomicSMax32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 Id EmitBufferAtomicUMax32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
-Id EmitBufferAtomicInc32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
-Id EmitBufferAtomicDec32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
+Id EmitBufferAtomicInc32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
+Id EmitBufferAtomicDec32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
 Id EmitBufferAtomicAnd32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 Id EmitBufferAtomicOr32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
 Id EmitBufferAtomicXor32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
@@ -136,8 +139,8 @@ Id EmitSharedAtomicSMin32(EmitContext& ctx, Id offset, Id value);
 Id EmitSharedAtomicAnd32(EmitContext& ctx, Id offset, Id value);
 Id EmitSharedAtomicOr32(EmitContext& ctx, Id offset, Id value);
 Id EmitSharedAtomicXor32(EmitContext& ctx, Id offset, Id value);
-Id EmitSharedAtomicIIncrement32(EmitContext& ctx, Id offset);
-Id EmitSharedAtomicIDecrement32(EmitContext& ctx, Id offset);
+Id EmitSharedAtomicInc32(EmitContext& ctx, Id offset);
+Id EmitSharedAtomicDec32(EmitContext& ctx, Id offset);
 Id EmitSharedAtomicISub32(EmitContext& ctx, Id offset, Id value);
 
 Id EmitCompositeConstructU32x2(EmitContext& ctx, IR::Inst* inst, Id e1, Id e2);
@@ -461,6 +464,8 @@ Id EmitConvertF64U32(EmitContext& ctx, Id value);
 Id EmitConvertF64U64(EmitContext& ctx, Id value);
 Id EmitConvertU16U32(EmitContext& ctx, Id value);
 Id EmitConvertU32U16(EmitContext& ctx, Id value);
+Id EmitConvertU8U32(EmitContext& ctx, Id value);
+Id EmitConvertU32U8(EmitContext& ctx, Id value);
 
 Id EmitImageSampleRaw(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address1, Id address2,
                       Id address3, Id address4);

View file

@@ -216,34 +216,38 @@ void Translator::DS_WRITE(int bit_size, bool is_signed, bool is_pair, bool strid
     if (is_pair) {
         const u32 adj = (bit_size == 32 ? 4 : 8) * (stride64 ? 64 : 1);
         const IR::U32 addr0 = ir.IAdd(addr, ir.Imm32(u32(inst.control.ds.offset0 * adj)));
-        if (bit_size == 32) {
-            ir.WriteShared(32, ir.GetVectorReg(data0), addr0);
-        } else {
+        if (bit_size == 64) {
             ir.WriteShared(64,
                            ir.PackUint2x32(ir.CompositeConstruct(ir.GetVectorReg(data0),
                                                                  ir.GetVectorReg(data0 + 1))),
                            addr0);
+        } else if (bit_size == 32) {
+            ir.WriteShared(32, ir.GetVectorReg(data0), addr0);
+        } else if (bit_size == 16) {
+            ir.WriteShared(16, ir.UConvert(16, ir.GetVectorReg(data0)), addr0);
         }
         const IR::U32 addr1 = ir.IAdd(addr, ir.Imm32(u32(inst.control.ds.offset1 * adj)));
-        if (bit_size == 32) {
-            ir.WriteShared(32, ir.GetVectorReg(data1), addr1);
-        } else {
+        if (bit_size == 64) {
             ir.WriteShared(64,
                            ir.PackUint2x32(ir.CompositeConstruct(ir.GetVectorReg(data1),
                                                                  ir.GetVectorReg(data1 + 1))),
                            addr1);
+        } else if (bit_size == 32) {
+            ir.WriteShared(32, ir.GetVectorReg(data1), addr1);
+        } else if (bit_size == 16) {
+            ir.WriteShared(16, ir.UConvert(16, ir.GetVectorReg(data1)), addr1);
         }
-    } else if (bit_size == 64) {
-        const IR::U32 addr0 = ir.IAdd(addr, ir.Imm32(offset));
-        const IR::Value data =
-            ir.CompositeConstruct(ir.GetVectorReg(data0), ir.GetVectorReg(data0 + 1));
-        ir.WriteShared(bit_size, ir.PackUint2x32(data), addr0);
-    } else if (bit_size == 16) {
-        const IR::U32 addr0 = ir.IAdd(addr, ir.Imm32(offset));
-        ir.WriteShared(bit_size, ir.GetVectorReg(data0), addr0);
     } else {
         const IR::U32 addr0 = ir.IAdd(addr, ir.Imm32(offset));
-        ir.WriteShared(bit_size, ir.GetVectorReg(data0), addr0);
+        if (bit_size == 64) {
+            const IR::Value data =
+                ir.CompositeConstruct(ir.GetVectorReg(data0), ir.GetVectorReg(data0 + 1));
+            ir.WriteShared(bit_size, ir.PackUint2x32(data), addr0);
+        } else if (bit_size == 32) {
+            ir.WriteShared(bit_size, ir.GetVectorReg(data0), addr0);
+        } else if (bit_size == 16) {
+            ir.WriteShared(bit_size, ir.UConvert(16, ir.GetVectorReg(data0)), addr0);
+        }
     }
 }
@@ -264,7 +268,7 @@ void Translator::DS_INC_U32(const GcnInst& inst, bool rtn) {
     const IR::U32 offset =
         ir.Imm32((u32(inst.control.ds.offset1) << 8u) + u32(inst.control.ds.offset0));
     const IR::U32 addr_offset = ir.IAdd(addr, offset);
-    const IR::Value original_val = ir.SharedAtomicIIncrement(addr_offset);
+    const IR::Value original_val = ir.SharedAtomicInc(addr_offset);
     if (rtn) {
         SetDst(inst.dst[0], IR::U32{original_val});
     }
@@ -275,7 +279,7 @@ void Translator::DS_DEC_U32(const GcnInst& inst, bool rtn) {
     const IR::U32 offset =
         ir.Imm32((u32(inst.control.ds.offset1) << 8u) + u32(inst.control.ds.offset0));
     const IR::U32 addr_offset = ir.IAdd(addr, offset);
-    const IR::Value original_val = ir.SharedAtomicIDecrement(addr_offset);
+    const IR::Value original_val = ir.SharedAtomicDec(addr_offset);
     if (rtn) {
         SetDst(inst.dst[0], IR::U32{original_val});
     }
@@ -309,36 +313,38 @@ void Translator::DS_READ(int bit_size, bool is_signed, bool is_pair, bool stride
         const u32 adj = (bit_size == 32 ? 4 : 8) * (stride64 ? 64 : 1);
         const IR::U32 addr0 = ir.IAdd(addr, ir.Imm32(u32(inst.control.ds.offset0 * adj)));
         const IR::Value data0 = ir.LoadShared(bit_size, is_signed, addr0);
-        if (bit_size == 32) {
-            ir.SetVectorReg(dst_reg++, IR::U32{data0});
-        } else {
+        if (bit_size == 64) {
             const auto vector = ir.UnpackUint2x32(IR::U64{data0});
             ir.SetVectorReg(dst_reg++, IR::U32{ir.CompositeExtract(vector, 0)});
             ir.SetVectorReg(dst_reg++, IR::U32{ir.CompositeExtract(vector, 1)});
+        } else if (bit_size == 32) {
+            ir.SetVectorReg(dst_reg++, IR::U32{data0});
+        } else if (bit_size == 16) {
+            ir.SetVectorReg(dst_reg++, IR::U32{ir.UConvert(32, IR::U16{data0})});
        }
         const IR::U32 addr1 = ir.IAdd(addr, ir.Imm32(u32(inst.control.ds.offset1 * adj)));
         const IR::Value data1 = ir.LoadShared(bit_size, is_signed, addr1);
-        if (bit_size == 32) {
-            ir.SetVectorReg(dst_reg++, IR::U32{data1});
-        } else {
+        if (bit_size == 64) {
             const auto vector = ir.UnpackUint2x32(IR::U64{data1});
             ir.SetVectorReg(dst_reg++, IR::U32{ir.CompositeExtract(vector, 0)});
             ir.SetVectorReg(dst_reg++, IR::U32{ir.CompositeExtract(vector, 1)});
+        } else if (bit_size == 32) {
+            ir.SetVectorReg(dst_reg++, IR::U32{data1});
+        } else if (bit_size == 16) {
+            ir.SetVectorReg(dst_reg++, IR::U32{ir.UConvert(32, IR::U16{data1})});
         }
-    } else if (bit_size == 64) {
-        const IR::U32 addr0 = ir.IAdd(addr, ir.Imm32(offset));
-        const IR::Value data = ir.LoadShared(bit_size, is_signed, addr0);
-        const auto vector = ir.UnpackUint2x32(IR::U64{data});
-        ir.SetVectorReg(dst_reg, IR::U32{ir.CompositeExtract(vector, 0)});
-        ir.SetVectorReg(dst_reg + 1, IR::U32{ir.CompositeExtract(vector, 1)});
-    } else if (bit_size == 16) {
-        const IR::U32 addr0 = ir.IAdd(addr, ir.Imm32(offset));
-        const IR::U16 data = IR::U16{ir.LoadShared(bit_size, is_signed, addr0)};
-        ir.SetVectorReg(dst_reg, ir.UConvert(32, data));
     } else {
         const IR::U32 addr0 = ir.IAdd(addr, ir.Imm32(offset));
-        const IR::U32 data = IR::U32{ir.LoadShared(bit_size, is_signed, addr0)};
-        ir.SetVectorReg(dst_reg, data);
+        const IR::Value data = ir.LoadShared(bit_size, is_signed, addr0);
+        if (bit_size == 64) {
+            const auto vector = ir.UnpackUint2x32(IR::U64{data});
+            ir.SetVectorReg(dst_reg, IR::U32{ir.CompositeExtract(vector, 0)});
+            ir.SetVectorReg(dst_reg + 1, IR::U32{ir.CompositeExtract(vector, 1)});
+        } else if (bit_size == 32) {
+            ir.SetVectorReg(dst_reg, IR::U32{data});
+        } else if (bit_size == 16) {
+            ir.SetVectorReg(dst_reg++, IR::U32{ir.UConvert(32, IR::U16{data})});
+        }
     }
 }
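
DS_READ and DS_WRITE move 64-bit values through pairs of 32-bit vector registers via PackUint2x32/UnpackUint2x32, and now route 16-bit accesses through explicit UConvert steps. A host-side analogue of the 64-bit round trip (illustrative, not emulator code):

    #include <cstdint>
    #include <iostream>
    #include <utility>

    // Two 32-bit registers -> one 64-bit shared-memory value.
    std::uint64_t PackUint2x32(std::uint32_t lo, std::uint32_t hi) {
        return std::uint64_t{lo} | (std::uint64_t{hi} << 32);
    }

    // One 64-bit shared-memory value -> two 32-bit registers.
    std::pair<std::uint32_t, std::uint32_t> UnpackUint2x32(std::uint64_t v) {
        return {static_cast<std::uint32_t>(v), static_cast<std::uint32_t>(v >> 32)};
    }

    int main() {
        const auto packed = PackUint2x32(0xDEADBEEFu, 0xCAFEBABEu);
        const auto [lo, hi] = UnpackUint2x32(packed);
        std::cout << std::hex << packed << ' ' << lo << ' ' << hi << '\n';
    }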

View file

@@ -354,9 +354,9 @@ void Translator::BUFFER_ATOMIC(AtomicOp op, const GcnInst& inst) {
     case AtomicOp::Xor:
         return ir.BufferAtomicXor(handle, address, vdata_val, buffer_info);
     case AtomicOp::Inc:
-        return ir.BufferAtomicInc(handle, address, vdata_val, buffer_info);
+        return ir.BufferAtomicInc(handle, address, buffer_info);
     case AtomicOp::Dec:
-        return ir.BufferAtomicDec(handle, address, vdata_val, buffer_info);
+        return ir.BufferAtomicDec(handle, address, buffer_info);
     default:
         UNREACHABLE();
     }

View file

@@ -353,12 +353,12 @@ U32 IREmitter::SharedAtomicXor(const U32& address, const U32& data) {
     return Inst<U32>(Opcode::SharedAtomicXor32, address, data);
 }
 
-U32 IREmitter::SharedAtomicIIncrement(const U32& address) {
-    return Inst<U32>(Opcode::SharedAtomicIIncrement32, address);
+U32 IREmitter::SharedAtomicInc(const U32& address) {
+    return Inst<U32>(Opcode::SharedAtomicInc32, address);
 }
 
-U32 IREmitter::SharedAtomicIDecrement(const U32& address) {
-    return Inst<U32>(Opcode::SharedAtomicIDecrement32, address);
+U32 IREmitter::SharedAtomicDec(const U32& address) {
+    return Inst<U32>(Opcode::SharedAtomicDec32, address);
 }
 
 U32 IREmitter::SharedAtomicISub(const U32& address, const U32& data) {
@@ -373,12 +373,12 @@ U32 IREmitter::ReadConstBuffer(const Value& handle, const U32& index) {
     return Inst<U32>(Opcode::ReadConstBuffer, handle, index);
 }
 
-U32 IREmitter::LoadBufferU8(const Value& handle, const Value& address, BufferInstInfo info) {
-    return Inst<U32>(Opcode::LoadBufferU8, Flags{info}, handle, address);
+U8 IREmitter::LoadBufferU8(const Value& handle, const Value& address, BufferInstInfo info) {
+    return Inst<U8>(Opcode::LoadBufferU8, Flags{info}, handle, address);
 }
 
-U32 IREmitter::LoadBufferU16(const Value& handle, const Value& address, BufferInstInfo info) {
-    return Inst<U32>(Opcode::LoadBufferU16, Flags{info}, handle, address);
+U16 IREmitter::LoadBufferU16(const Value& handle, const Value& address, BufferInstInfo info) {
+    return Inst<U16>(Opcode::LoadBufferU16, Flags{info}, handle, address);
 }
 
 Value IREmitter::LoadBufferU32(int num_dwords, const Value& handle, const Value& address,
@@ -397,6 +397,10 @@ Value IREmitter::LoadBufferU32(int num_dwords, const Value& handle, const Value&
     }
 }
 
+U64 IREmitter::LoadBufferU64(const Value& handle, const Value& address, BufferInstInfo info) {
+    return Inst<U64>(Opcode::LoadBufferU64, Flags{info}, handle, address);
+}
+
 Value IREmitter::LoadBufferF32(int num_dwords, const Value& handle, const Value& address,
                                BufferInstInfo info) {
     switch (num_dwords) {
@@ -417,12 +421,12 @@ Value IREmitter::LoadBufferFormat(const Value& handle, const Value& address, Buf
     return Inst(Opcode::LoadBufferFormatF32, Flags{info}, handle, address);
 }
 
-void IREmitter::StoreBufferU8(const Value& handle, const Value& address, const U32& data,
+void IREmitter::StoreBufferU8(const Value& handle, const Value& address, const U8& data,
                               BufferInstInfo info) {
     Inst(Opcode::StoreBufferU8, Flags{info}, handle, address, data);
 }
 
-void IREmitter::StoreBufferU16(const Value& handle, const Value& address, const U32& data,
+void IREmitter::StoreBufferU16(const Value& handle, const Value& address, const U16& data,
                                BufferInstInfo info) {
     Inst(Opcode::StoreBufferU16, Flags{info}, handle, address, data);
 }
@@ -447,6 +451,11 @@ void IREmitter::StoreBufferU32(int num_dwords, const Value& handle, const Value&
     }
 }
 
+void IREmitter::StoreBufferU64(const Value& handle, const Value& address, const U64& data,
+                               BufferInstInfo info) {
+    Inst(Opcode::StoreBufferU64, Flags{info}, handle, address, data);
+}
+
 void IREmitter::StoreBufferF32(int num_dwords, const Value& handle, const Value& address,
                                const Value& data, BufferInstInfo info) {
     switch (num_dwords) {
@@ -474,7 +483,19 @@ void IREmitter::StoreBufferFormat(const Value& handle, const Value& address, con
 
 Value IREmitter::BufferAtomicIAdd(const Value& handle, const Value& address, const Value& value,
                                   BufferInstInfo info) {
-    return Inst(Opcode::BufferAtomicIAdd32, Flags{info}, handle, address, value);
+    switch (value.Type()) {
+    case Type::U32:
+        return Inst(Opcode::BufferAtomicIAdd32, Flags{info}, handle, address, value);
+    case Type::U64:
+        return Inst(Opcode::BufferAtomicIAdd64, Flags{info}, handle, address, value);
+    default:
+        ThrowInvalidType(value.Type());
+    }
+}
+
+Value IREmitter::BufferAtomicISub(const Value& handle, const Value& address, const Value& value,
+                                  BufferInstInfo info) {
+    return Inst(Opcode::BufferAtomicISub32, Flags{info}, handle, address, value);
 }
 
 Value IREmitter::BufferAtomicIMin(const Value& handle, const Value& address, const Value& value,
@@ -489,14 +510,12 @@ Value IREmitter::BufferAtomicIMax(const Value& handle, const Value& address, con
                  : Inst(Opcode::BufferAtomicUMax32, Flags{info}, handle, address, value);
 }
 
-Value IREmitter::BufferAtomicInc(const Value& handle, const Value& address, const Value& value,
-                                 BufferInstInfo info) {
-    return Inst(Opcode::BufferAtomicInc32, Flags{info}, handle, address, value);
+Value IREmitter::BufferAtomicInc(const Value& handle, const Value& address, BufferInstInfo info) {
+    return Inst(Opcode::BufferAtomicInc32, Flags{info}, handle, address);
 }
 
-Value IREmitter::BufferAtomicDec(const Value& handle, const Value& address, const Value& value,
-                                 BufferInstInfo info) {
-    return Inst(Opcode::BufferAtomicDec32, Flags{info}, handle, address, value);
+Value IREmitter::BufferAtomicDec(const Value& handle, const Value& address, BufferInstInfo info) {
+    return Inst(Opcode::BufferAtomicDec32, Flags{info}, handle, address);
 }
 
 Value IREmitter::BufferAtomicAnd(const Value& handle, const Value& address, const Value& value,
@@ -1804,8 +1823,15 @@ F32F64 IREmitter::ConvertIToF(size_t dest_bitsize, size_t src_bitsize, bool is_s
                      : ConvertUToF(dest_bitsize, src_bitsize, value);
 }
 
-U16U32U64 IREmitter::UConvert(size_t result_bitsize, const U16U32U64& value) {
+U8U16U32U64 IREmitter::UConvert(size_t result_bitsize, const U8U16U32U64& value) {
     switch (result_bitsize) {
+    case 8:
+        switch (value.Type()) {
+        case Type::U32:
+            return Inst<U8>(Opcode::ConvertU8U32, value);
+        default:
+            break;
+        }
     case 16:
         switch (value.Type()) {
         case Type::U32:
@@ -1815,6 +1841,8 @@ U16U32U64 IREmitter::UConvert(size_t result_bitsize, const U16U32U64& value) {
     }
     case 32:
         switch (value.Type()) {
+        case Type::U8:
+            return Inst<U32>(Opcode::ConvertU32U8, value);
         case Type::U16:
             return Inst<U32>(Opcode::ConvertU32U16, value);
         default:
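
UConvert now dispatches on both the requested width and the source type, including the 8-bit cases added here. A minimal standalone model of that dispatch (strings in place of IR opcodes; not repository code):

    #include <iostream>
    #include <stdexcept>
    #include <string>

    enum class Type { U8, U16, U32, U64 };

    std::string UConvertOpcode(unsigned result_bitsize, Type src) {
        switch (result_bitsize) {
        case 8:
            if (src == Type::U32) return "ConvertU8U32";
            break;
        case 16:
            if (src == Type::U32) return "ConvertU16U32";
            break;
        case 32:
            if (src == Type::U8) return "ConvertU32U8";
            if (src == Type::U16) return "ConvertU32U16";
            break;
        }
        throw std::invalid_argument{"invalid type"}; // mirrors ThrowInvalidType
    }

    int main() {
        std::cout << UConvertOpcode(8, Type::U32) << '\n'; // ConvertU8U32
        std::cout << UConvertOpcode(32, Type::U8) << '\n'; // ConvertU32U8
    }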

View file

@@ -100,33 +100,35 @@ public:
     void WriteShared(int bit_size, const Value& value, const U32& offset);
 
     [[nodiscard]] U32U64 SharedAtomicIAdd(const U32& address, const U32U64& data);
+    [[nodiscard]] U32 SharedAtomicISub(const U32& address, const U32& data);
     [[nodiscard]] U32 SharedAtomicIMin(const U32& address, const U32& data, bool is_signed);
     [[nodiscard]] U32 SharedAtomicIMax(const U32& address, const U32& data, bool is_signed);
+    [[nodiscard]] U32 SharedAtomicInc(const U32& address);
+    [[nodiscard]] U32 SharedAtomicDec(const U32& address);
     [[nodiscard]] U32 SharedAtomicAnd(const U32& address, const U32& data);
     [[nodiscard]] U32 SharedAtomicOr(const U32& address, const U32& data);
     [[nodiscard]] U32 SharedAtomicXor(const U32& address, const U32& data);
-    [[nodiscard]] U32 SharedAtomicIIncrement(const U32& address);
-    [[nodiscard]] U32 SharedAtomicIDecrement(const U32& address);
-    [[nodiscard]] U32 SharedAtomicISub(const U32& address, const U32& data);
 
     [[nodiscard]] U32 ReadConst(const Value& base, const U32& offset);
     [[nodiscard]] U32 ReadConstBuffer(const Value& handle, const U32& index);
-    [[nodiscard]] U32 LoadBufferU8(const Value& handle, const Value& address, BufferInstInfo info);
-    [[nodiscard]] U32 LoadBufferU16(const Value& handle, const Value& address, BufferInstInfo info);
+    [[nodiscard]] U8 LoadBufferU8(const Value& handle, const Value& address, BufferInstInfo info);
+    [[nodiscard]] U16 LoadBufferU16(const Value& handle, const Value& address, BufferInstInfo info);
     [[nodiscard]] Value LoadBufferU32(int num_dwords, const Value& handle, const Value& address,
                                       BufferInstInfo info);
+    [[nodiscard]] U64 LoadBufferU64(const Value& handle, const Value& address, BufferInstInfo info);
     [[nodiscard]] Value LoadBufferF32(int num_dwords, const Value& handle, const Value& address,
                                       BufferInstInfo info);
     [[nodiscard]] Value LoadBufferFormat(const Value& handle, const Value& address,
                                          BufferInstInfo info);
-    void StoreBufferU8(const Value& handle, const Value& address, const U32& data,
+    void StoreBufferU8(const Value& handle, const Value& address, const U8& data,
                        BufferInstInfo info);
-    void StoreBufferU16(const Value& handle, const Value& address, const U32& data,
+    void StoreBufferU16(const Value& handle, const Value& address, const U16& data,
                         BufferInstInfo info);
     void StoreBufferU32(int num_dwords, const Value& handle, const Value& address,
                         const Value& data, BufferInstInfo info);
+    void StoreBufferU64(const Value& handle, const Value& address, const U64& data,
+                        BufferInstInfo info);
     void StoreBufferF32(int num_dwords, const Value& handle, const Value& address,
                         const Value& data, BufferInstInfo info);
     void StoreBufferFormat(const Value& handle, const Value& address, const Value& data,
@@ -134,14 +136,16 @@ public:
 
     [[nodiscard]] Value BufferAtomicIAdd(const Value& handle, const Value& address,
                                          const Value& value, BufferInstInfo info);
+    [[nodiscard]] Value BufferAtomicISub(const Value& handle, const Value& address,
+                                         const Value& value, BufferInstInfo info);
     [[nodiscard]] Value BufferAtomicIMin(const Value& handle, const Value& address,
                                          const Value& value, bool is_signed, BufferInstInfo info);
     [[nodiscard]] Value BufferAtomicIMax(const Value& handle, const Value& address,
                                          const Value& value, bool is_signed, BufferInstInfo info);
     [[nodiscard]] Value BufferAtomicInc(const Value& handle, const Value& address,
-                                        const Value& value, BufferInstInfo info);
+                                        BufferInstInfo info);
     [[nodiscard]] Value BufferAtomicDec(const Value& handle, const Value& address,
-                                        const Value& value, BufferInstInfo info);
+                                        BufferInstInfo info);
     [[nodiscard]] Value BufferAtomicAnd(const Value& handle, const Value& address,
                                         const Value& value, BufferInstInfo info);
     [[nodiscard]] Value BufferAtomicOr(const Value& handle, const Value& address,
@@ -309,7 +313,7 @@ public:
     [[nodiscard]] F32F64 ConvertIToF(size_t dest_bitsize, size_t src_bitsize, bool is_signed,
                                      const Value& value);
-    [[nodiscard]] U16U32U64 UConvert(size_t result_bitsize, const U16U32U64& value);
+    [[nodiscard]] U8U16U32U64 UConvert(size_t result_bitsize, const U8U16U32U64& value);
     [[nodiscard]] F16F32F64 FPConvert(size_t result_bitsize, const F16F32F64& value);
 
     [[nodiscard]] Value ImageAtomicIAdd(const Value& handle, const Value& coords,

View file

@@ -60,12 +60,15 @@ bool Inst::MayHaveSideEffects() const noexcept {
     case Opcode::StoreBufferU32x2:
     case Opcode::StoreBufferU32x3:
    case Opcode::StoreBufferU32x4:
+    case Opcode::StoreBufferU64:
     case Opcode::StoreBufferF32:
     case Opcode::StoreBufferF32x2:
     case Opcode::StoreBufferF32x3:
     case Opcode::StoreBufferF32x4:
     case Opcode::StoreBufferFormatF32:
     case Opcode::BufferAtomicIAdd32:
+    case Opcode::BufferAtomicIAdd64:
+    case Opcode::BufferAtomicISub32:
     case Opcode::BufferAtomicSMin32:
     case Opcode::BufferAtomicUMin32:
     case Opcode::BufferAtomicSMax32:
@@ -76,15 +79,21 @@ bool Inst::MayHaveSideEffects() const noexcept {
     case Opcode::BufferAtomicOr32:
     case Opcode::BufferAtomicXor32:
     case Opcode::BufferAtomicSwap32:
+    case Opcode::BufferAtomicCmpSwap32:
     case Opcode::DataAppend:
     case Opcode::DataConsume:
-    case Opcode::WriteSharedU64:
+    case Opcode::WriteSharedU16:
     case Opcode::WriteSharedU32:
+    case Opcode::WriteSharedU64:
     case Opcode::SharedAtomicIAdd32:
+    case Opcode::SharedAtomicIAdd64:
+    case Opcode::SharedAtomicISub32:
     case Opcode::SharedAtomicSMin32:
     case Opcode::SharedAtomicUMin32:
     case Opcode::SharedAtomicSMax32:
     case Opcode::SharedAtomicUMax32:
+    case Opcode::SharedAtomicInc32:
+    case Opcode::SharedAtomicDec32:
     case Opcode::SharedAtomicAnd32:
     case Opcode::SharedAtomicOr32:
     case Opcode::SharedAtomicXor32:

View file

@ -35,21 +35,21 @@ OPCODE(LoadSharedU32, U32, U32,
OPCODE(LoadSharedU64, U64, U32, ) OPCODE(LoadSharedU64, U64, U32, )
OPCODE(WriteSharedU16, Void, U32, U16, ) OPCODE(WriteSharedU16, Void, U32, U16, )
OPCODE(WriteSharedU32, Void, U32, U32, ) OPCODE(WriteSharedU32, Void, U32, U32, )
OPCODE(WriteSharedU64, Void, U32, U64, ) OPCODE(WriteSharedU64, Void, U32, U64, )
// Shared atomic operations // Shared atomic operations
OPCODE(SharedAtomicIAdd32, U32, U32, U32, ) OPCODE(SharedAtomicIAdd32, U32, U32, U32, )
OPCODE(SharedAtomicIAdd64, U64, U32, U64, ) OPCODE(SharedAtomicIAdd64, U64, U32, U64, )
OPCODE(SharedAtomicISub32, U32, U32, U32, )
OPCODE(SharedAtomicSMin32, U32, U32, U32, ) OPCODE(SharedAtomicSMin32, U32, U32, U32, )
OPCODE(SharedAtomicUMin32, U32, U32, U32, ) OPCODE(SharedAtomicUMin32, U32, U32, U32, )
OPCODE(SharedAtomicSMax32, U32, U32, U32, ) OPCODE(SharedAtomicSMax32, U32, U32, U32, )
OPCODE(SharedAtomicUMax32, U32, U32, U32, ) OPCODE(SharedAtomicUMax32, U32, U32, U32, )
OPCODE(SharedAtomicInc32, U32, U32, )
OPCODE(SharedAtomicDec32, U32, U32, )
OPCODE(SharedAtomicAnd32, U32, U32, U32, ) OPCODE(SharedAtomicAnd32, U32, U32, U32, )
OPCODE(SharedAtomicOr32, U32, U32, U32, ) OPCODE(SharedAtomicOr32, U32, U32, U32, )
OPCODE(SharedAtomicXor32, U32, U32, U32, ) OPCODE(SharedAtomicXor32, U32, U32, U32, )
OPCODE(SharedAtomicISub32, U32, U32, U32, )
OPCODE(SharedAtomicIIncrement32, U32, U32, )
OPCODE(SharedAtomicIDecrement32, U32, U32, )
// Context getters/setters // Context getters/setters
OPCODE(GetUserData, U32, ScalarReg, ) OPCODE(GetUserData, U32, ScalarReg, )
@ -94,23 +94,25 @@ OPCODE(UndefU32, U32,
OPCODE(UndefU64, U64, ) OPCODE(UndefU64, U64, )
// Buffer operations // Buffer operations
OPCODE(LoadBufferU8, U32, Opaque, Opaque, ) OPCODE(LoadBufferU8, U8, Opaque, Opaque, )
OPCODE(LoadBufferU16, U32, Opaque, Opaque, ) OPCODE(LoadBufferU16, U16, Opaque, Opaque, )
OPCODE(LoadBufferU32, U32, Opaque, Opaque, ) OPCODE(LoadBufferU32, U32, Opaque, Opaque, )
OPCODE(LoadBufferU32x2, U32x2, Opaque, Opaque, ) OPCODE(LoadBufferU32x2, U32x2, Opaque, Opaque, )
OPCODE(LoadBufferU32x3, U32x3, Opaque, Opaque, ) OPCODE(LoadBufferU32x3, U32x3, Opaque, Opaque, )
OPCODE(LoadBufferU32x4, U32x4, Opaque, Opaque, ) OPCODE(LoadBufferU32x4, U32x4, Opaque, Opaque, )
OPCODE(LoadBufferU64, U64, Opaque, Opaque, )
OPCODE(LoadBufferF32, F32, Opaque, Opaque, ) OPCODE(LoadBufferF32, F32, Opaque, Opaque, )
OPCODE(LoadBufferF32x2, F32x2, Opaque, Opaque, ) OPCODE(LoadBufferF32x2, F32x2, Opaque, Opaque, )
OPCODE(LoadBufferF32x3, F32x3, Opaque, Opaque, ) OPCODE(LoadBufferF32x3, F32x3, Opaque, Opaque, )
OPCODE(LoadBufferF32x4, F32x4, Opaque, Opaque, ) OPCODE(LoadBufferF32x4, F32x4, Opaque, Opaque, )
OPCODE(LoadBufferFormatF32, F32x4, Opaque, Opaque, ) OPCODE(LoadBufferFormatF32, F32x4, Opaque, Opaque, )
OPCODE(StoreBufferU8, Void, Opaque, Opaque, U32, ) OPCODE(StoreBufferU8, Void, Opaque, Opaque, U8, )
OPCODE(StoreBufferU16, Void, Opaque, Opaque, U32, ) OPCODE(StoreBufferU16, Void, Opaque, Opaque, U16, )
OPCODE(StoreBufferU32, Void, Opaque, Opaque, U32, ) OPCODE(StoreBufferU32, Void, Opaque, Opaque, U32, )
OPCODE(StoreBufferU32x2, Void, Opaque, Opaque, U32x2, ) OPCODE(StoreBufferU32x2, Void, Opaque, Opaque, U32x2, )
OPCODE(StoreBufferU32x3, Void, Opaque, Opaque, U32x3, ) OPCODE(StoreBufferU32x3, Void, Opaque, Opaque, U32x3, )
OPCODE(StoreBufferU32x4, Void, Opaque, Opaque, U32x4, ) OPCODE(StoreBufferU32x4, Void, Opaque, Opaque, U32x4, )
OPCODE(StoreBufferU64, Void, Opaque, Opaque, U64, )
OPCODE(StoreBufferF32, Void, Opaque, Opaque, F32, ) OPCODE(StoreBufferF32, Void, Opaque, Opaque, F32, )
OPCODE(StoreBufferF32x2, Void, Opaque, Opaque, F32x2, ) OPCODE(StoreBufferF32x2, Void, Opaque, Opaque, F32x2, )
OPCODE(StoreBufferF32x3, Void, Opaque, Opaque, F32x3, ) OPCODE(StoreBufferF32x3, Void, Opaque, Opaque, F32x3, )
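
Note: with LoadBufferU8/LoadBufferU16 re-typed to return genuine U8/U16
values, and U64 load/store variants added alongside, 32-bit consumers now
need explicit width changes; that is what the ConvertU32U8/ConvertU8U32 pair
added further down provides. A scalar model of the two directions:

    #include <cstdint>

    // ConvertU32U8 analogue: zero-extend an 8-bit raw load for 32-bit math.
    uint32_t WidenU8(uint8_t raw) {
        return static_cast<uint32_t>(raw);
    }

    // ConvertU8U32 analogue: truncate back down before an 8-bit store.
    uint8_t NarrowU32(uint32_t value) {
        return static_cast<uint8_t>(value & 0xFFu);
    }
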
@@ -120,12 +122,13 @@ OPCODE(StoreBufferFormatF32, Void, Opaq
// Buffer atomic operations // Buffer atomic operations
OPCODE(BufferAtomicIAdd32, U32, Opaque, Opaque, U32 ) OPCODE(BufferAtomicIAdd32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicIAdd64, U64, Opaque, Opaque, U64 ) OPCODE(BufferAtomicIAdd64, U64, Opaque, Opaque, U64 )
OPCODE(BufferAtomicISub32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicSMin32, U32, Opaque, Opaque, U32 ) OPCODE(BufferAtomicSMin32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicUMin32, U32, Opaque, Opaque, U32 ) OPCODE(BufferAtomicUMin32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicSMax32, U32, Opaque, Opaque, U32 ) OPCODE(BufferAtomicSMax32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicUMax32, U32, Opaque, Opaque, U32 ) OPCODE(BufferAtomicUMax32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicInc32, U32, Opaque, Opaque, U32, ) OPCODE(BufferAtomicInc32, U32, Opaque, Opaque, )
OPCODE(BufferAtomicDec32, U32, Opaque, Opaque, U32, ) OPCODE(BufferAtomicDec32, U32, Opaque, Opaque, )
OPCODE(BufferAtomicAnd32, U32, Opaque, Opaque, U32, ) OPCODE(BufferAtomicAnd32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicOr32, U32, Opaque, Opaque, U32, ) OPCODE(BufferAtomicOr32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicXor32, U32, Opaque, Opaque, U32, ) OPCODE(BufferAtomicXor32, U32, Opaque, Opaque, U32, )
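
Note: BufferAtomicCmpSwap32, newly classified as a buffer atomic earlier in
this change, is the one operation in this family with a second data operand,
the comparand. A scalar model of the assumed return contract (the originally
observed value comes back whether or not the swap happened):

    #include <atomic>
    #include <cstdint>

    // Scalar model of a 32-bit compare-and-swap buffer atomic.
    uint32_t BufferAtomicCmpSwap32(std::atomic<uint32_t>& slot, uint32_t cmp,
                                   uint32_t value) {
        uint32_t observed = cmp;
        // On success `observed` keeps cmp (the original contents); on
        // failure it is overwritten with the value actually found.
        slot.compare_exchange_strong(observed, value);
        return observed;
    }
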
@@ -405,6 +408,8 @@ OPCODE(ConvertF64U32, F64, U32,
OPCODE(ConvertF32U16, F32, U16, ) OPCODE(ConvertF32U16, F32, U16, )
OPCODE(ConvertU16U32, U16, U32, ) OPCODE(ConvertU16U32, U16, U32, )
OPCODE(ConvertU32U16, U32, U16, ) OPCODE(ConvertU32U16, U32, U16, )
OPCODE(ConvertU8U32, U8, U32, )
OPCODE(ConvertU32U8, U32, U8, )
// Image operations // Image operations
OPCODE(ImageSampleRaw, F32x4, Opaque, F32x4, F32x4, F32x4, F32, ) OPCODE(ImageSampleRaw, F32x4, Opaque, F32x4, F32x4, F32x4, F32, )

View file

@@ -438,7 +438,9 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
IR::IREmitter ir{*block, IR::Block::InstructionList::s_iterator_to(inst)}; IR::IREmitter ir{*block, IR::Block::InstructionList::s_iterator_to(inst)};
const u32 num_dwords = opcode == IR::Opcode::WriteSharedU32 ? 1 : 2; const u32 num_dwords = opcode == IR::Opcode::WriteSharedU32 ? 1 : 2;
const IR::U32 addr{inst.Arg(0)}; const IR::U32 addr{inst.Arg(0)};
const IR::U32 data{inst.Arg(1).Resolve()}; const IR::Value data = num_dwords == 2
? ir.UnpackUint2x32(IR::U64{inst.Arg(1).Resolve()})
: inst.Arg(1).Resolve();
const auto SetOutput = [&](IR::U32 addr, IR::U32 value, AttributeRegion output_kind, const auto SetOutput = [&](IR::U32 addr, IR::U32 value, AttributeRegion output_kind,
u32 off_dw) { u32 off_dw) {
@@ -466,10 +468,10 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
AttributeRegion region = GetAttributeRegionKind(&inst, info, runtime_info); AttributeRegion region = GetAttributeRegionKind(&inst, info, runtime_info);
if (num_dwords == 1) { if (num_dwords == 1) {
SetOutput(addr, data, region, 0); SetOutput(addr, IR::U32{data}, region, 0);
} else { } else {
for (auto i = 0; i < num_dwords; i++) { for (auto i = 0; i < num_dwords; i++) {
SetOutput(addr, IR::U32{data.Inst()->Arg(i)}, region, i); SetOutput(addr, IR::U32{ir.CompositeExtract(data, i)}, region, i);
} }
} }
inst.Invalidate(); inst.Invalidate();
@@ -499,7 +501,7 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
ReadTessControlPointAttribute(addr, stride, ir, i, is_tcs_output_read); ReadTessControlPointAttribute(addr, stride, ir, i, is_tcs_output_read);
read_components.push_back(ir.BitCast<IR::U32>(component)); read_components.push_back(ir.BitCast<IR::U32>(component));
} }
attr_read = ir.CompositeConstruct(read_components); attr_read = ir.PackUint2x32(ir.CompositeConstruct(read_components));
} }
inst.ReplaceUsesWithAndRemove(attr_read); inst.ReplaceUsesWithAndRemove(attr_read);
break; break;
@@ -578,7 +580,7 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
const IR::F32 component = GetInput(addr, i); const IR::F32 component = GetInput(addr, i);
read_components.push_back(ir.BitCast<IR::U32>(component)); read_components.push_back(ir.BitCast<IR::U32>(component));
} }
attr_read = ir.CompositeConstruct(read_components); attr_read = ir.PackUint2x32(ir.CompositeConstruct(read_components));
} }
inst.ReplaceUsesWithAndRemove(attr_read); inst.ReplaceUsesWithAndRemove(attr_read);
break; break;
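
Note: the hull/domain changes stop reaching into a composite's operands via
data.Inst()->Arg(i) and instead round-trip 64-bit values through explicit
pack/unpack, which stays valid even when the U64 was not produced by a
CompositeConstruct. A scalar model of the pair, assuming the usual
low-dword-first convention of packUint2x32:

    #include <cstdint>
    #include <utility>

    // UnpackUint2x32 analogue: split a u64 into (low, high) dwords.
    std::pair<uint32_t, uint32_t> UnpackUint2x32(uint64_t value) {
        return {static_cast<uint32_t>(value), static_cast<uint32_t>(value >> 32)};
    }

    // PackUint2x32 analogue: rebuild the u64 from the two dwords.
    uint64_t PackUint2x32(uint32_t lo, uint32_t hi) {
        return static_cast<uint64_t>(hi) << 32 | lo;
    }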

View file

@@ -34,13 +34,13 @@ static IR::Value LoadBufferFormat(IR::IREmitter& ir, const IR::Value handle, con
interpreted = ir.Imm32(0.f); interpreted = ir.Imm32(0.f);
break; break;
case AmdGpu::DataFormat::Format8: { case AmdGpu::DataFormat::Format8: {
const auto unpacked = const auto raw = ir.UConvert(32, ir.LoadBufferU8(handle, address, info));
ir.Unpack4x8(format_info.num_format, ir.LoadBufferU8(handle, address, info)); const auto unpacked = ir.Unpack4x8(format_info.num_format, raw);
interpreted = ir.CompositeExtract(unpacked, 0); interpreted = ir.CompositeExtract(unpacked, 0);
break; break;
} }
case AmdGpu::DataFormat::Format8_8: { case AmdGpu::DataFormat::Format8_8: {
const auto raw = ir.LoadBufferU16(handle, address, info); const auto raw = ir.UConvert(32, ir.LoadBufferU16(handle, address, info));
const auto unpacked = ir.Unpack4x8(format_info.num_format, raw); const auto unpacked = ir.Unpack4x8(format_info.num_format, raw);
interpreted = ir.CompositeConstruct(ir.CompositeExtract(unpacked, 0), interpreted = ir.CompositeConstruct(ir.CompositeExtract(unpacked, 0),
ir.CompositeExtract(unpacked, 1)); ir.CompositeExtract(unpacked, 1));
@@ -51,8 +51,8 @@ static IR::Value LoadBufferFormat(IR::IREmitter& ir, const IR::Value handle, con
IR::U32{ir.LoadBufferU32(1, handle, address, info)}); IR::U32{ir.LoadBufferU32(1, handle, address, info)});
break; break;
case AmdGpu::DataFormat::Format16: { case AmdGpu::DataFormat::Format16: {
const auto unpacked = const auto raw = ir.UConvert(32, ir.LoadBufferU16(handle, address, info));
ir.Unpack2x16(format_info.num_format, ir.LoadBufferU16(handle, address, info)); const auto unpacked = ir.Unpack2x16(format_info.num_format, raw);
interpreted = ir.CompositeExtract(unpacked, 0); interpreted = ir.CompositeExtract(unpacked, 0);
break; break;
} }
@@ -126,7 +126,7 @@ static void StoreBufferFormat(IR::IREmitter& ir, const IR::Value handle, const I
const auto packed = const auto packed =
ir.Pack4x8(format_info.num_format, ir.CompositeConstruct(real_value, ir.Imm32(0.f), ir.Pack4x8(format_info.num_format, ir.CompositeConstruct(real_value, ir.Imm32(0.f),
ir.Imm32(0.f), ir.Imm32(0.f))); ir.Imm32(0.f), ir.Imm32(0.f)));
ir.StoreBufferU8(handle, address, packed, info); ir.StoreBufferU8(handle, address, ir.UConvert(8, packed), info);
break; break;
} }
case AmdGpu::DataFormat::Format8_8: { case AmdGpu::DataFormat::Format8_8: {
@@ -134,7 +134,7 @@ static void StoreBufferFormat(IR::IREmitter& ir, const IR::Value handle, const I
ir.CompositeConstruct(ir.CompositeExtract(real_value, 0), ir.CompositeConstruct(ir.CompositeExtract(real_value, 0),
ir.CompositeExtract(real_value, 1), ir.CompositeExtract(real_value, 1),
ir.Imm32(0.f), ir.Imm32(0.f))); ir.Imm32(0.f), ir.Imm32(0.f)));
ir.StoreBufferU16(handle, address, packed, info); ir.StoreBufferU16(handle, address, ir.UConvert(16, packed), info);
break; break;
} }
case AmdGpu::DataFormat::Format8_8_8_8: { case AmdGpu::DataFormat::Format8_8_8_8: {
@@ -145,7 +145,7 @@ static void StoreBufferFormat(IR::IREmitter& ir, const IR::Value handle, const I
case AmdGpu::DataFormat::Format16: { case AmdGpu::DataFormat::Format16: {
const auto packed = const auto packed =
ir.Pack2x16(format_info.num_format, ir.CompositeConstruct(real_value, ir.Imm32(0.f))); ir.Pack2x16(format_info.num_format, ir.CompositeConstruct(real_value, ir.Imm32(0.f)));
ir.StoreBufferU16(handle, address, packed, info); ir.StoreBufferU16(handle, address, ir.UConvert(16, packed), info);
break; break;
} }
case AmdGpu::DataFormat::Format16_16: { case AmdGpu::DataFormat::Format16_16: {
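
Note: both unpack helpers consume a full 32-bit word, so the now-narrow
U8/U16 raw loads are widened with UConvert(32, ...) first, and packed results
are narrowed again before the U8/U16 stores. A scalar model of the UNORM
flavor of Unpack4x8 shows why the input must already be a dword:

    #include <array>
    #include <cstdint>

    // Four bytes of one 32-bit word become four normalized floats. An 8-bit
    // load must be zero-extended into such a word before this runs; only
    // byte 0 carries meaningful data in that case.
    std::array<float, 4> Unpack4x8Unorm(uint32_t packed) {
        std::array<float, 4> out{};
        for (int i = 0; i < 4; ++i) {
            out[i] = static_cast<float>((packed >> (8 * i)) & 0xFFu) / 255.0f;
        }
        return out;
    }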

View file

@@ -17,6 +17,8 @@ using SharpLocation = u32;
bool IsBufferAtomic(const IR::Inst& inst) { bool IsBufferAtomic(const IR::Inst& inst) {
switch (inst.GetOpcode()) { switch (inst.GetOpcode()) {
case IR::Opcode::BufferAtomicIAdd32: case IR::Opcode::BufferAtomicIAdd32:
case IR::Opcode::BufferAtomicIAdd64:
case IR::Opcode::BufferAtomicISub32:
case IR::Opcode::BufferAtomicSMin32: case IR::Opcode::BufferAtomicSMin32:
case IR::Opcode::BufferAtomicUMin32: case IR::Opcode::BufferAtomicUMin32:
case IR::Opcode::BufferAtomicSMax32: case IR::Opcode::BufferAtomicSMax32:
@@ -27,6 +29,7 @@ bool IsBufferAtomic(const IR::Inst& inst) {
case IR::Opcode::BufferAtomicOr32: case IR::Opcode::BufferAtomicOr32:
case IR::Opcode::BufferAtomicXor32: case IR::Opcode::BufferAtomicXor32:
case IR::Opcode::BufferAtomicSwap32: case IR::Opcode::BufferAtomicSwap32:
case IR::Opcode::BufferAtomicCmpSwap32:
return true; return true;
default: default:
return false; return false;
@@ -41,6 +44,7 @@ bool IsBufferStore(const IR::Inst& inst) {
case IR::Opcode::StoreBufferU32x2: case IR::Opcode::StoreBufferU32x2:
case IR::Opcode::StoreBufferU32x3: case IR::Opcode::StoreBufferU32x3:
case IR::Opcode::StoreBufferU32x4: case IR::Opcode::StoreBufferU32x4:
case IR::Opcode::StoreBufferU64:
case IR::Opcode::StoreBufferF32: case IR::Opcode::StoreBufferF32:
case IR::Opcode::StoreBufferF32x2: case IR::Opcode::StoreBufferF32x2:
case IR::Opcode::StoreBufferF32x3: case IR::Opcode::StoreBufferF32x3:
@@ -60,6 +64,7 @@ bool IsBufferInstruction(const IR::Inst& inst) {
case IR::Opcode::LoadBufferU32x2: case IR::Opcode::LoadBufferU32x2:
case IR::Opcode::LoadBufferU32x3: case IR::Opcode::LoadBufferU32x3:
case IR::Opcode::LoadBufferU32x4: case IR::Opcode::LoadBufferU32x4:
case IR::Opcode::LoadBufferU64:
case IR::Opcode::LoadBufferF32: case IR::Opcode::LoadBufferF32:
case IR::Opcode::LoadBufferF32x2: case IR::Opcode::LoadBufferF32x2:
case IR::Opcode::LoadBufferF32x3: case IR::Opcode::LoadBufferF32x3:
@@ -85,6 +90,10 @@ IR::Type BufferDataType(const IR::Inst& inst, AmdGpu::NumberFormat num_format) {
case IR::Opcode::LoadBufferU16: case IR::Opcode::LoadBufferU16:
case IR::Opcode::StoreBufferU16: case IR::Opcode::StoreBufferU16:
return IR::Type::U16; return IR::Type::U16;
case IR::Opcode::LoadBufferU64:
case IR::Opcode::StoreBufferU64:
case IR::Opcode::BufferAtomicIAdd64:
return IR::Type::U64;
case IR::Opcode::LoadBufferFormatF32: case IR::Opcode::LoadBufferFormatF32:
case IR::Opcode::StoreBufferFormatF32: case IR::Opcode::StoreBufferFormatF32:
// Formatted buffer loads can use a variety of types. // Formatted buffer loads can use a variety of types.
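
Note: these predicates feed resource tracking, where a store or atomic marks
the backing buffer as written and BufferDataType widens the binding's used
types. A reduced sketch of that consumption; the types and helper below are
illustrative, only the idea comes from this file:

    #include <cstdint>

    enum class Op { LoadBufferU64, StoreBufferU64, BufferAtomicIAdd64 };

    struct BufferBinding {
        bool is_written = false;
        uint32_t used_types = 0; // bitmask in the spirit of IR::Type
    };

    bool IsStoreOrAtomic(Op op) {
        return op == Op::StoreBufferU64 || op == Op::BufferAtomicIAdd64;
    }

    // All three U64 opcodes report IR::Type::U64, modeled here as bit 3.
    void Track(BufferBinding& binding, Op op) {
        binding.is_written |= IsStoreOrAtomic(op);
        binding.used_types |= 1u << 3;
    }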

View file

@@ -9,12 +9,14 @@
namespace Shader::Optimization { namespace Shader::Optimization {
static bool IsLoadShared(const IR::Inst& inst) { static bool IsLoadShared(const IR::Inst& inst) {
return inst.GetOpcode() == IR::Opcode::LoadSharedU32 || return inst.GetOpcode() == IR::Opcode::LoadSharedU16 ||
inst.GetOpcode() == IR::Opcode::LoadSharedU32 ||
inst.GetOpcode() == IR::Opcode::LoadSharedU64; inst.GetOpcode() == IR::Opcode::LoadSharedU64;
} }
static bool IsWriteShared(const IR::Inst& inst) { static bool IsWriteShared(const IR::Inst& inst) {
return inst.GetOpcode() == IR::Opcode::WriteSharedU32 || return inst.GetOpcode() == IR::Opcode::WriteSharedU16 ||
inst.GetOpcode() == IR::Opcode::WriteSharedU32 ||
inst.GetOpcode() == IR::Opcode::WriteSharedU64; inst.GetOpcode() == IR::Opcode::WriteSharedU64;
} }

View file

@@ -10,18 +10,23 @@ namespace Shader::Optimization {
static bool IsSharedAccess(const IR::Inst& inst) { static bool IsSharedAccess(const IR::Inst& inst) {
const auto opcode = inst.GetOpcode(); const auto opcode = inst.GetOpcode();
switch (opcode) { switch (opcode) {
case IR::Opcode::LoadSharedU16:
case IR::Opcode::LoadSharedU32: case IR::Opcode::LoadSharedU32:
case IR::Opcode::LoadSharedU64: case IR::Opcode::LoadSharedU64:
case IR::Opcode::WriteSharedU16:
case IR::Opcode::WriteSharedU32: case IR::Opcode::WriteSharedU32:
case IR::Opcode::WriteSharedU64: case IR::Opcode::WriteSharedU64:
case IR::Opcode::SharedAtomicAnd32:
case IR::Opcode::SharedAtomicIAdd32: case IR::Opcode::SharedAtomicIAdd32:
case IR::Opcode::SharedAtomicIAdd64: case IR::Opcode::SharedAtomicIAdd64:
case IR::Opcode::SharedAtomicOr32: case IR::Opcode::SharedAtomicISub32:
case IR::Opcode::SharedAtomicSMax32:
case IR::Opcode::SharedAtomicUMax32:
case IR::Opcode::SharedAtomicSMin32: case IR::Opcode::SharedAtomicSMin32:
case IR::Opcode::SharedAtomicUMin32: case IR::Opcode::SharedAtomicUMin32:
case IR::Opcode::SharedAtomicSMax32:
case IR::Opcode::SharedAtomicUMax32:
case IR::Opcode::SharedAtomicInc32:
case IR::Opcode::SharedAtomicDec32:
case IR::Opcode::SharedAtomicAnd32:
case IR::Opcode::SharedAtomicOr32:
case IR::Opcode::SharedAtomicXor32: case IR::Opcode::SharedAtomicXor32:
return true; return true;
default: default:
@@ -41,14 +46,8 @@ void SharedMemoryToStoragePass(IR::Program& program, const RuntimeInfo& runtime_
profile.supports_workgroup_explicit_memory_layout)) { profile.supports_workgroup_explicit_memory_layout)) {
return; return;
} }
// Add buffer binding for shared memory storage buffer.
const u32 binding = static_cast<u32>(program.info.buffers.size()); const u32 binding = static_cast<u32>(program.info.buffers.size());
program.info.buffers.push_back({ IR::Type used_types{};
.used_types = IR::Type::U32,
.inline_cbuf = AmdGpu::Buffer::Null(),
.buffer_type = BufferType::SharedMemory,
.is_written = true,
});
for (IR::Block* const block : program.blocks) { for (IR::Block* const block : program.blocks) {
for (IR::Inst& inst : block->Instructions()) { for (IR::Inst& inst : block->Instructions()) {
if (!IsSharedAccess(inst)) { if (!IsSharedAccess(inst)) {
@@ -56,73 +55,106 @@ void SharedMemoryToStoragePass(IR::Program& program, const RuntimeInfo& runtime_
} }
IR::IREmitter ir{*block, IR::Block::InstructionList::s_iterator_to(inst)}; IR::IREmitter ir{*block, IR::Block::InstructionList::s_iterator_to(inst)};
const IR::U32 handle = ir.Imm32(binding); const IR::U32 handle = ir.Imm32(binding);
const IR::U32 offset = ir.IMul(ir.GetAttributeU32(IR::Attribute::WorkgroupIndex),
ir.Imm32(shared_memory_size));
const IR::U32 address = ir.IAdd(IR::U32{inst.Arg(0)}, offset);
// Replace shared atomics first // Replace shared atomics first
switch (inst.GetOpcode()) { switch (inst.GetOpcode()) {
case IR::Opcode::SharedAtomicAnd32:
inst.ReplaceUsesWithAndRemove(
ir.BufferAtomicAnd(handle, inst.Arg(0), inst.Arg(1), {}));
continue;
case IR::Opcode::SharedAtomicIAdd32: case IR::Opcode::SharedAtomicIAdd32:
inst.ReplaceUsesWithAndRemove(
ir.BufferAtomicIAdd(handle, address, inst.Arg(1), {}));
used_types |= IR::Type::U32;
continue;
case IR::Opcode::SharedAtomicIAdd64: case IR::Opcode::SharedAtomicIAdd64:
inst.ReplaceUsesWithAndRemove( inst.ReplaceUsesWithAndRemove(
ir.BufferAtomicIAdd(handle, inst.Arg(0), inst.Arg(1), {})); ir.BufferAtomicIAdd(handle, address, inst.Arg(1), {}));
used_types |= IR::Type::U64;
continue; continue;
case IR::Opcode::SharedAtomicOr32: case IR::Opcode::SharedAtomicISub32:
inst.ReplaceUsesWithAndRemove( inst.ReplaceUsesWithAndRemove(
ir.BufferAtomicOr(handle, inst.Arg(0), inst.Arg(1), {})); ir.BufferAtomicISub(handle, address, inst.Arg(1), {}));
used_types |= IR::Type::U32;
continue; continue;
case IR::Opcode::SharedAtomicSMax32:
case IR::Opcode::SharedAtomicUMax32: {
const bool is_signed = inst.GetOpcode() == IR::Opcode::SharedAtomicSMax32;
inst.ReplaceUsesWithAndRemove(
ir.BufferAtomicIMax(handle, inst.Arg(0), inst.Arg(1), is_signed, {}));
continue;
}
case IR::Opcode::SharedAtomicSMin32: case IR::Opcode::SharedAtomicSMin32:
case IR::Opcode::SharedAtomicUMin32: { case IR::Opcode::SharedAtomicUMin32: {
const bool is_signed = inst.GetOpcode() == IR::Opcode::SharedAtomicSMin32; const bool is_signed = inst.GetOpcode() == IR::Opcode::SharedAtomicSMin32;
inst.ReplaceUsesWithAndRemove( inst.ReplaceUsesWithAndRemove(
ir.BufferAtomicIMin(handle, inst.Arg(0), inst.Arg(1), is_signed, {})); ir.BufferAtomicIMin(handle, address, inst.Arg(1), is_signed, {}));
used_types |= IR::Type::U32;
continue; continue;
} }
case IR::Opcode::SharedAtomicXor32: case IR::Opcode::SharedAtomicSMax32:
case IR::Opcode::SharedAtomicUMax32: {
const bool is_signed = inst.GetOpcode() == IR::Opcode::SharedAtomicSMax32;
inst.ReplaceUsesWithAndRemove( inst.ReplaceUsesWithAndRemove(
ir.BufferAtomicXor(handle, inst.Arg(0), inst.Arg(1), {})); ir.BufferAtomicIMax(handle, address, inst.Arg(1), is_signed, {}));
used_types |= IR::Type::U32;
continue;
}
case IR::Opcode::SharedAtomicInc32:
inst.ReplaceUsesWithAndRemove(ir.BufferAtomicInc(handle, address, {}));
used_types |= IR::Type::U32;
continue;
case IR::Opcode::SharedAtomicDec32:
inst.ReplaceUsesWithAndRemove(ir.BufferAtomicDec(handle, address, {}));
used_types |= IR::Type::U32;
continue;
case IR::Opcode::SharedAtomicAnd32:
inst.ReplaceUsesWithAndRemove(ir.BufferAtomicAnd(handle, address, inst.Arg(1), {}));
used_types |= IR::Type::U32;
continue;
case IR::Opcode::SharedAtomicOr32:
inst.ReplaceUsesWithAndRemove(ir.BufferAtomicOr(handle, address, inst.Arg(1), {}));
used_types |= IR::Type::U32;
continue;
case IR::Opcode::SharedAtomicXor32:
inst.ReplaceUsesWithAndRemove(ir.BufferAtomicXor(handle, address, inst.Arg(1), {}));
used_types |= IR::Type::U32;
continue; continue;
default: default:
break; break;
} }
// Replace shared operations. // Replace shared operations.
const IR::U32 offset = ir.IMul(ir.GetAttributeU32(IR::Attribute::WorkgroupIndex),
ir.Imm32(shared_memory_size));
const IR::U32 address = ir.IAdd(IR::U32{inst.Arg(0)}, offset);
switch (inst.GetOpcode()) { switch (inst.GetOpcode()) {
case IR::Opcode::LoadSharedU16: case IR::Opcode::LoadSharedU16:
inst.ReplaceUsesWithAndRemove(ir.LoadBufferU16(handle, address, {})); inst.ReplaceUsesWithAndRemove(ir.LoadBufferU16(handle, address, {}));
used_types |= IR::Type::U16;
break; break;
case IR::Opcode::LoadSharedU32: case IR::Opcode::LoadSharedU32:
inst.ReplaceUsesWithAndRemove(ir.LoadBufferU32(1, handle, address, {})); inst.ReplaceUsesWithAndRemove(ir.LoadBufferU32(1, handle, address, {}));
used_types |= IR::Type::U32;
break; break;
case IR::Opcode::LoadSharedU64: case IR::Opcode::LoadSharedU64:
inst.ReplaceUsesWithAndRemove(ir.LoadBufferU32(2, handle, address, {})); inst.ReplaceUsesWithAndRemove(ir.LoadBufferU64(handle, address, {}));
used_types |= IR::Type::U64;
break; break;
case IR::Opcode::WriteSharedU16: case IR::Opcode::WriteSharedU16:
ir.StoreBufferU16(handle, address, IR::U32{inst.Arg(1)}, {}); ir.StoreBufferU16(handle, address, IR::U16{inst.Arg(1)}, {});
inst.Invalidate(); inst.Invalidate();
used_types |= IR::Type::U16;
break; break;
case IR::Opcode::WriteSharedU32: case IR::Opcode::WriteSharedU32:
ir.StoreBufferU32(1, handle, address, inst.Arg(1), {}); ir.StoreBufferU32(1, handle, address, inst.Arg(1), {});
inst.Invalidate(); inst.Invalidate();
used_types |= IR::Type::U32;
break; break;
case IR::Opcode::WriteSharedU64: case IR::Opcode::WriteSharedU64:
ir.StoreBufferU32(2, handle, address, inst.Arg(1), {}); ir.StoreBufferU64(handle, address, IR::U64{inst.Arg(1)}, {});
inst.Invalidate(); inst.Invalidate();
used_types |= IR::Type::U64;
break; break;
default: default:
break; break;
} }
} }
} }
// Add buffer binding for shared memory storage buffer.
program.info.buffers.push_back({
.used_types = used_types,
.inline_cbuf = AmdGpu::Buffer::Null(),
.buffer_type = BufferType::SharedMemory,
.is_written = true,
});
} }
} // namespace Shader::Optimization } // namespace Shader::Optimization
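
Note: two shapes changed in this pass. The per-workgroup address is computed
once up front so the atomic cases can share it, and the buffer binding is now
pushed after the walk, with used_types accumulated from the accesses actually
seen instead of hardcoded to IR::Type::U32. The address math itself reduces
to a one-line formula (a sketch; the names are illustrative):

    #include <cstdint>

    // Each workgroup owns a shared_memory_size-byte slice of one storage
    // buffer, so a shared-memory offset becomes a buffer address by adding
    // the workgroup's base.
    uint32_t LowerSharedAddress(uint32_t workgroup_index,
                                uint32_t shared_memory_size,
                                uint32_t ds_offset) {
        return workgroup_index * shared_memory_size + ds_offset;
    }

Accumulating used_types this way lets the backend declare only the access
widths the shader really needs.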

View file

@@ -265,6 +265,7 @@ using U32F32 = TypedValue<Type::U32 | Type::F32>;
using U64F64 = TypedValue<Type::U64 | Type::F64>; using U64F64 = TypedValue<Type::U64 | Type::F64>;
using U32U64 = TypedValue<Type::U32 | Type::U64>; using U32U64 = TypedValue<Type::U32 | Type::U64>;
using U16U32U64 = TypedValue<Type::U16 | Type::U32 | Type::U64>; using U16U32U64 = TypedValue<Type::U16 | Type::U32 | Type::U64>;
using U8U16U32U64 = TypedValue<Type::U8 | Type::U16 | Type::U32 | Type::U64>;
using F32F64 = TypedValue<Type::F32 | Type::F64>; using F32F64 = TypedValue<Type::F32 | Type::F64>;
using F16F32F64 = TypedValue<Type::F16 | Type::F32 | Type::F64>; using F16F32F64 = TypedValue<Type::F16 | Type::F32 | Type::F64>;
using UAny = TypedValue<Type::U8 | Type::U16 | Type::U32 | Type::U64>; using UAny = TypedValue<Type::U8 | Type::U16 | Type::U32 | Type::U64>;
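
Note: U8U16U32U64 follows the existing UAny pattern; IR::Type is a flag enum,
so these aliases are compile-time masks restricting which runtime types a
TypedValue may hold. A reduced sketch of the mechanism (illustrative
definitions, not the repository's):

    #include <cassert>
    #include <cstdint>

    enum class Type : uint32_t { U8 = 1 << 0, U16 = 1 << 1, U32 = 1 << 2, U64 = 1 << 3 };

    constexpr Type operator|(Type a, Type b) {
        return static_cast<Type>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
    }

    // A value tagged with its runtime type; the template parameter is the
    // set of types the wrapper may legally hold.
    template <Type mask>
    struct TypedValue {
        explicit TypedValue(Type t) : type{t} {
            assert((static_cast<uint32_t>(t) & static_cast<uint32_t>(mask)) != 0);
        }
        Type type;
    };

    using U8U16U32U64 = TypedValue<Type::U8 | Type::U16 | Type::U32 | Type::U64>;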