diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_shared_memory.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_shared_memory.cpp index c59406499..a9cf89129 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_shared_memory.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_shared_memory.cpp @@ -14,8 +14,10 @@ Id EmitLoadSharedU16(EmitContext& ctx, Id offset) { const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 2u)}; return AccessBoundsCheck<16>(ctx, index, ctx.ConstU32(num_elements), [&] { - const Id pointer = - ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, ctx.u32_zero_value, index); + const Id pointer = std::popcount(static_cast(ctx.info.shared_types)) > 1 + ? ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, + ctx.u32_zero_value, index) + : ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, index); return ctx.OpLoad(ctx.U16, pointer); }); } @@ -26,8 +28,10 @@ Id EmitLoadSharedU32(EmitContext& ctx, Id offset) { const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 4u)}; return AccessBoundsCheck<32>(ctx, index, ctx.ConstU32(num_elements), [&] { - const Id pointer = - ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index); + const Id pointer = std::popcount(static_cast(ctx.info.shared_types)) > 1 + ? ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, + ctx.u32_zero_value, index) + : ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, index); return ctx.OpLoad(ctx.U32[1], pointer); }); } @@ -38,8 +42,10 @@ Id EmitLoadSharedU64(EmitContext& ctx, Id offset) { const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 8u)}; return AccessBoundsCheck<64>(ctx, index, ctx.ConstU32(num_elements), [&] { - const Id pointer{ - ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, ctx.u32_zero_value, index)}; + const Id pointer = std::popcount(static_cast(ctx.info.shared_types)) > 1 + ? ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, + ctx.u32_zero_value, index) + : ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, index); return ctx.OpLoad(ctx.U64, pointer); }); } @@ -50,8 +56,10 @@ void EmitWriteSharedU16(EmitContext& ctx, Id offset, Id value) { const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 2u)}; AccessBoundsCheck<16>(ctx, index, ctx.ConstU32(num_elements), [&] { - const Id pointer = - ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, ctx.u32_zero_value, index); + const Id pointer = std::popcount(static_cast(ctx.info.shared_types)) > 1 + ? ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, + ctx.u32_zero_value, index) + : ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, index); ctx.OpStore(pointer, value); return Id{0}; }); @@ -63,8 +71,10 @@ void EmitWriteSharedU32(EmitContext& ctx, Id offset, Id value) { const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 4u)}; AccessBoundsCheck<32>(ctx, index, ctx.ConstU32(num_elements), [&] { - const Id pointer = - ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index); + const Id pointer = std::popcount(static_cast(ctx.info.shared_types)) > 1 + ? ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, + ctx.u32_zero_value, index) + : ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, index); ctx.OpStore(pointer, value); return Id{0}; }); @@ -76,8 +86,10 @@ void EmitWriteSharedU64(EmitContext& ctx, Id offset, Id value) { const u32 num_elements{Common::DivCeil(ctx.runtime_info.cs_info.shared_memory_size, 8u)}; AccessBoundsCheck<64>(ctx, index, ctx.ConstU32(num_elements), [&] { - const Id pointer{ - ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, ctx.u32_zero_value, index)}; + const Id pointer = std::popcount(static_cast(ctx.info.shared_types)) > 1 + ? ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, + ctx.u32_zero_value, index) + : ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, index); ctx.OpStore(pointer, value); return Id{0}; }); diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp index 0a8f78f72..030eb6cb0 100644 --- a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp +++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp @@ -995,19 +995,26 @@ void EmitContext::DefineSharedMemory() { const u32 num_elements{Common::DivCeil(shared_memory_size, element_size)}; const Id array_type{TypeArray(element_type, ConstU32(num_elements))}; - Decorate(array_type, spv::Decoration::ArrayStride, element_size); - const Id struct_type{TypeStruct(array_type)}; - MemberDecorate(struct_type, 0u, spv::Decoration::Offset, 0u); + const auto mem_type = [&] { + if (num_types > 1) { + const Id struct_type{TypeStruct(array_type)}; + Decorate(struct_type, spv::Decoration::Block); + MemberDecorate(struct_type, 0u, spv::Decoration::Offset, 0u); + return struct_type; + } else { + return array_type; + } + }(); - const Id pointer = TypePointer(spv::StorageClass::Workgroup, struct_type); + const Id pointer = TypePointer(spv::StorageClass::Workgroup, mem_type); const Id element_pointer = TypePointer(spv::StorageClass::Workgroup, element_type); const Id variable = AddGlobalVariable(pointer, spv::StorageClass::Workgroup); Name(variable, name); interfaces.push_back(variable); if (num_types > 1) { - Decorate(struct_type, spv::Decoration::Block); + Decorate(array_type, spv::Decoration::ArrayStride, element_size); Decorate(variable, spv::Decoration::Aliased); }