Fix shared memory definition when only one type is used (#3106)

This commit is contained in:
Marcin Mikołajczyk 2025-06-17 08:37:09 +02:00 committed by GitHub
parent 21d14abaee
commit 9dd35c3a42
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 17 deletions

View file

@ -14,8 +14,10 @@ Id EmitLoadSharedU16(EmitContext& ctx, Id offset) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 2u)};
return AccessBoundsCheck<16>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer =
ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, ctx.u32_zero_value, index);
const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
? ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, index);
return ctx.OpLoad(ctx.U16, pointer);
});
}
@ -26,8 +28,10 @@ Id EmitLoadSharedU32(EmitContext& ctx, Id offset) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 4u)};
return AccessBoundsCheck<32>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer =
ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index);
const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
? ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, index);
return ctx.OpLoad(ctx.U32[1], pointer);
});
}
@ -38,8 +42,10 @@ Id EmitLoadSharedU64(EmitContext& ctx, Id offset) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 8u)};
return AccessBoundsCheck<64>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer{
ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, ctx.u32_zero_value, index)};
const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
? ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, index);
return ctx.OpLoad(ctx.U64, pointer);
});
}
@ -50,8 +56,10 @@ void EmitWriteSharedU16(EmitContext& ctx, Id offset, Id value) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 2u)};
AccessBoundsCheck<16>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer =
ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, ctx.u32_zero_value, index);
const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
? ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, index);
ctx.OpStore(pointer, value);
return Id{0};
});
@ -63,8 +71,10 @@ void EmitWriteSharedU32(EmitContext& ctx, Id offset, Id value) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 4u)};
AccessBoundsCheck<32>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer =
ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index);
const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
? ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, index);
ctx.OpStore(pointer, value);
return Id{0};
});
@ -76,8 +86,10 @@ void EmitWriteSharedU64(EmitContext& ctx, Id offset, Id value) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 8u)};
AccessBoundsCheck<64>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer{
ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, ctx.u32_zero_value, index)};
const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
? ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, index);
ctx.OpStore(pointer, value);
return Id{0};
});

View file

@ -995,19 +995,26 @@ void EmitContext::DefineSharedMemory() {
const u32 num_elements{Common::DivCeil(shared_memory_size, element_size)};
const Id array_type{TypeArray(element_type, ConstU32(num_elements))};
Decorate(array_type, spv::Decoration::ArrayStride, element_size);
const Id struct_type{TypeStruct(array_type)};
MemberDecorate(struct_type, 0u, spv::Decoration::Offset, 0u);
const auto mem_type = [&] {
if (num_types > 1) {
const Id struct_type{TypeStruct(array_type)};
Decorate(struct_type, spv::Decoration::Block);
MemberDecorate(struct_type, 0u, spv::Decoration::Offset, 0u);
return struct_type;
} else {
return array_type;
}
}();
const Id pointer = TypePointer(spv::StorageClass::Workgroup, struct_type);
const Id pointer = TypePointer(spv::StorageClass::Workgroup, mem_type);
const Id element_pointer = TypePointer(spv::StorageClass::Workgroup, element_type);
const Id variable = AddGlobalVariable(pointer, spv::StorageClass::Workgroup);
Name(variable, name);
interfaces.push_back(variable);
if (num_types > 1) {
Decorate(struct_type, spv::Decoration::Block);
Decorate(array_type, spv::Decoration::ArrayStride, element_size);
Decorate(variable, spv::Decoration::Aliased);
}