Fix shared memory definition when only one type is used (#3106)

This commit is contained in:
Marcin Mikołajczyk 2025-06-17 08:37:09 +02:00 committed by GitHub
parent 21d14abaee
commit 9dd35c3a42
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 17 deletions

View file

@ -14,8 +14,10 @@ Id EmitLoadSharedU16(EmitContext& ctx, Id offset) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 2u)}; const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 2u)};
return AccessBoundsCheck<16>(ctx, index, ctx.ConstU32(num_elements), [&] { return AccessBoundsCheck<16>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer = const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, ctx.u32_zero_value, index); ? ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, index);
return ctx.OpLoad(ctx.U16, pointer); return ctx.OpLoad(ctx.U16, pointer);
}); });
} }
@ -26,8 +28,10 @@ Id EmitLoadSharedU32(EmitContext& ctx, Id offset) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 4u)}; const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 4u)};
return AccessBoundsCheck<32>(ctx, index, ctx.ConstU32(num_elements), [&] { return AccessBoundsCheck<32>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer = const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index); ? ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, index);
return ctx.OpLoad(ctx.U32[1], pointer); return ctx.OpLoad(ctx.U32[1], pointer);
}); });
} }
@ -38,8 +42,10 @@ Id EmitLoadSharedU64(EmitContext& ctx, Id offset) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 8u)}; const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 8u)};
return AccessBoundsCheck<64>(ctx, index, ctx.ConstU32(num_elements), [&] { return AccessBoundsCheck<64>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer{ const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, ctx.u32_zero_value, index)}; ? ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, index);
return ctx.OpLoad(ctx.U64, pointer); return ctx.OpLoad(ctx.U64, pointer);
}); });
} }
@ -50,8 +56,10 @@ void EmitWriteSharedU16(EmitContext& ctx, Id offset, Id value) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 2u)}; const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 2u)};
AccessBoundsCheck<16>(ctx, index, ctx.ConstU32(num_elements), [&] { AccessBoundsCheck<16>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer = const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, ctx.u32_zero_value, index); ? ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, index);
ctx.OpStore(pointer, value); ctx.OpStore(pointer, value);
return Id{0}; return Id{0};
}); });
@ -63,8 +71,10 @@ void EmitWriteSharedU32(EmitContext& ctx, Id offset, Id value) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 4u)}; const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 4u)};
AccessBoundsCheck<32>(ctx, index, ctx.ConstU32(num_elements), [&] { AccessBoundsCheck<32>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer = const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index); ? ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, index);
ctx.OpStore(pointer, value); ctx.OpStore(pointer, value);
return Id{0}; return Id{0};
}); });
@ -76,8 +86,10 @@ void EmitWriteSharedU64(EmitContext& ctx, Id offset, Id value) {
const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 8u)}; const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 8u)};
AccessBoundsCheck<64>(ctx, index, ctx.ConstU32(num_elements), [&] { AccessBoundsCheck<64>(ctx, index, ctx.ConstU32(num_elements), [&] {
const Id pointer{ const Id pointer = std::popcount(static_cast<u32>(ctx.info.shared_types)) > 1
ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, ctx.u32_zero_value, index)}; ? ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64,
ctx.u32_zero_value, index)
: ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, index);
ctx.OpStore(pointer, value); ctx.OpStore(pointer, value);
return Id{0}; return Id{0};
}); });

View file

@ -995,19 +995,26 @@ void EmitContext::DefineSharedMemory() {
const u32 num_elements{Common::DivCeil(shared_memory_size, element_size)}; const u32 num_elements{Common::DivCeil(shared_memory_size, element_size)};
const Id array_type{TypeArray(element_type, ConstU32(num_elements))}; const Id array_type{TypeArray(element_type, ConstU32(num_elements))};
Decorate(array_type, spv::Decoration::ArrayStride, element_size);
const auto mem_type = [&] {
if (num_types > 1) {
const Id struct_type{TypeStruct(array_type)}; const Id struct_type{TypeStruct(array_type)};
Decorate(struct_type, spv::Decoration::Block);
MemberDecorate(struct_type, 0u, spv::Decoration::Offset, 0u); MemberDecorate(struct_type, 0u, spv::Decoration::Offset, 0u);
return struct_type;
} else {
return array_type;
}
}();
const Id pointer = TypePointer(spv::StorageClass::Workgroup, struct_type); const Id pointer = TypePointer(spv::StorageClass::Workgroup, mem_type);
const Id element_pointer = TypePointer(spv::StorageClass::Workgroup, element_type); const Id element_pointer = TypePointer(spv::StorageClass::Workgroup, element_type);
const Id variable = AddGlobalVariable(pointer, spv::StorageClass::Workgroup); const Id variable = AddGlobalVariable(pointer, spv::StorageClass::Workgroup);
Name(variable, name); Name(variable, name);
interfaces.push_back(variable); interfaces.push_back(variable);
if (num_types > 1) { if (num_types > 1) {
Decorate(struct_type, spv::Decoration::Block); Decorate(array_type, spv::Decoration::ArrayStride, element_size);
Decorate(variable, spv::Decoration::Aliased); Decorate(variable, spv::Decoration::Aliased);
} }