diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp index 6442ae9f8..eff562955 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp @@ -175,19 +175,24 @@ Id EmitReadConst(EmitContext& ctx, IR::Inst* inst) { return ctx.OpLoad(ctx.U32[1], ptr); } -Id EmitReadConstBuffer(EmitContext& ctx, u32 handle, Id index) { +template +Id ReadConstBuffer(EmitContext& ctx, u32 handle, Id index) { const auto& buffer = ctx.buffers[handle]; index = ctx.OpIAdd(ctx.U32[1], index, buffer.offset_dwords); - const auto [id, pointer_type] = buffer[BufferAlias::U32]; + const auto [id, pointer_type] = buffer[alias]; + const auto value_type = alias == BufferAlias::U32 ? ctx.U32[1] : ctx.F32[1]; const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index)}; - const Id result{ctx.OpLoad(ctx.U32[1], ptr)}; + const Id result{ctx.OpLoad(value_type, ptr)}; if (Sirit::ValidId(buffer.size_dwords)) { const Id in_bounds = ctx.OpULessThan(ctx.U1[1], index, buffer.size_dwords); - return ctx.OpSelect(ctx.U32[1], in_bounds, result, ctx.u32_zero_value); - } else { - return result; + return ctx.OpSelect(value_type, in_bounds, result, ctx.u32_zero_value); } + return result; +} + +Id EmitReadConstBuffer(EmitContext& ctx, u32 handle, Id index) { + return ReadConstBuffer(ctx, handle, index); } Id EmitReadStepRate(EmitContext& ctx, int rate_idx) { @@ -246,7 +251,7 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, u32 comp, Id index) { ctx.OpUDiv(ctx.U32[1], ctx.OpLoad(ctx.U32[1], ctx.instance_id), step_rate), ctx.ConstU32(param.num_components)), ctx.ConstU32(comp)); - return EmitReadConstBuffer(ctx, param.buffer_handle, offset); + return ReadConstBuffer(ctx, param.buffer_handle, offset); } Id result;