shader_recompiler: Implement manual barycentric interpolation path (#1644)

* shader_recompiler: Implement manual barycentric interpolation path

* clang format

* emit_spirv: Fix typo

* emit_spirv: Simplify variable definition

* spirv_emit: clang format
This commit is contained in:
TheTurtle 2024-12-02 23:20:54 +02:00 committed by GitHub
parent fda4f06518
commit eb844b9b63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 129 additions and 58 deletions

View file

@ -206,7 +206,7 @@ Id DefineMain(EmitContext& ctx, const IR::Program& program) {
return main; return main;
} }
void SetupCapabilities(const Info& info, EmitContext& ctx) { void SetupCapabilities(const Info& info, const Profile& profile, EmitContext& ctx) {
ctx.AddCapability(spv::Capability::Image1D); ctx.AddCapability(spv::Capability::Image1D);
ctx.AddCapability(spv::Capability::Sampled1D); ctx.AddCapability(spv::Capability::Sampled1D);
ctx.AddCapability(spv::Capability::ImageQuery); ctx.AddCapability(spv::Capability::ImageQuery);
@ -251,6 +251,10 @@ void SetupCapabilities(const Info& info, EmitContext& ctx) {
if (info.stage == Stage::Geometry) { if (info.stage == Stage::Geometry) {
ctx.AddCapability(spv::Capability::Geometry); ctx.AddCapability(spv::Capability::Geometry);
} }
if (info.stage == Stage::Fragment && profile.needs_manual_interpolation) {
ctx.AddExtension("SPV_KHR_fragment_shader_barycentric");
ctx.AddCapability(spv::Capability::FragmentBarycentricKHR);
}
} }
void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) { void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) {
@ -342,7 +346,7 @@ std::vector<u32> EmitSPIRV(const Profile& profile, const RuntimeInfo& runtime_in
EmitContext ctx{profile, runtime_info, program.info, binding}; EmitContext ctx{profile, runtime_info, program.info, binding};
const Id main{DefineMain(ctx, program)}; const Id main{DefineMain(ctx, program)};
DefineEntryPoint(program, ctx, main); DefineEntryPoint(program, ctx, main);
SetupCapabilities(program.info, ctx); SetupCapabilities(program.info, profile, ctx);
SetupFloatMode(ctx, profile, runtime_info, main); SetupFloatMode(ctx, profile, runtime_info, main);
PatchPhiNodes(program, ctx); PatchPhiNodes(program, ctx);
binding.user_data += program.info.ud_mask.NumRegs(); binding.user_data += program.info.ud_mask.NumRegs();

View file

@ -171,54 +171,38 @@ Id EmitReadStepRate(EmitContext& ctx, int rate_idx) {
rate_idx == 0 ? ctx.u32_zero_value : ctx.u32_one_value)); rate_idx == 0 ? ctx.u32_zero_value : ctx.u32_one_value));
} }
/// Reads a single F32 component of a geometry-stage input attribute.
///
/// @param ctx   Emit context used to build the SPIR-V instructions.
/// @param attr  Attribute to read; only Position0 and the Param* attributes are handled.
/// @param comp  Component of the four-component input vector to load.
/// @param index Element of the per-input array to read (presumably the input vertex index).
/// @return Id of the loaded F32 scalar.
///
/// Any attribute that is neither a position nor a param reaches UNREACHABLE().
Id EmitGetAttributeForGeometry(EmitContext& ctx, IR::Attribute attr, u32 comp, u32 index) {
    // Shared tail of both branches: narrow a pointer to an F32[4] input down to the requested
    // component and load it.
    const auto load_component = [&ctx, comp](Id vec4_ptr) {
        const auto scalar_ptr_type = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]);
        const Id scalar_ptr{ctx.OpAccessChain(scalar_ptr_type, vec4_ptr, ctx.ConstU32(comp))};
        return ctx.OpLoad(ctx.F32[1], scalar_ptr);
    };

    if (IR::IsPosition(attr)) {
        ASSERT(attr == IR::Attribute::Position0);
        const auto vec4_ptr_type = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[4]);
        // Access chain into gl_in: element `index`, then member 0.
        const Id position_ptr{ctx.OpAccessChain(vec4_ptr_type, ctx.gl_in, ctx.ConstU32(index),
                                                ctx.ConstU32(0u))};
        return load_component(position_ptr);
    }

    if (IR::IsParam(attr)) {
        // Params are looked up by their offset from Param0; .at() bounds-checks the slot.
        const u32 param_slot{u32(attr) - u32(IR::Attribute::Param0)};
        const auto param_var = ctx.input_params.at(param_slot).id;
        const auto vec4_ptr_type = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[4]);
        const Id param_ptr{ctx.OpAccessChain(vec4_ptr_type, param_var, ctx.ConstU32(index))};
        return load_component(param_ptr);
    }

    UNREACHABLE();
}
Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, u32 comp, u32 index) { Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, u32 comp, u32 index) {
if (ctx.info.stage == Stage::Geometry) { if (ctx.info.stage == Stage::Geometry) {
if (IR::IsPosition(attr)) { return EmitGetAttributeForGeometry(ctx, attr, comp, index);
ASSERT(attr == IR::Attribute::Position0);
const auto position_arr_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[4]);
const auto pointer{ctx.OpAccessChain(position_arr_ptr, ctx.gl_in, ctx.ConstU32(index),
ctx.ConstU32(0u))};
const auto position_comp_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]);
return ctx.OpLoad(ctx.F32[1],
ctx.OpAccessChain(position_comp_ptr, pointer, ctx.ConstU32(comp)));
}
if (IR::IsParam(attr)) {
const u32 param_id{u32(attr) - u32(IR::Attribute::Param0)};
const auto param = ctx.input_params.at(param_id).id;
const auto param_arr_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[4]);
const auto pointer{ctx.OpAccessChain(param_arr_ptr, param, ctx.ConstU32(index))};
const auto position_comp_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]);
return ctx.OpLoad(ctx.F32[1],
ctx.OpAccessChain(position_comp_ptr, pointer, ctx.ConstU32(comp)));
}
UNREACHABLE();
} }
if (IR::IsParam(attr)) { if (IR::IsParam(attr)) {
const u32 index{u32(attr) - u32(IR::Attribute::Param0)}; const u32 index{u32(attr) - u32(IR::Attribute::Param0)};
const auto& param{ctx.input_params.at(index)}; const auto& param{ctx.input_params.at(index)};
if (param.buffer_handle < 0) { if (param.buffer_handle >= 0) {
if (!ValidId(param.id)) {
// Attribute is disabled or varying component is not written
return ctx.ConstF32(comp == 3 ? 1.0f : 0.0f);
}
Id result;
if (param.is_default) {
result = ctx.OpCompositeExtract(param.component_type, param.id, comp);
} else if (param.num_components > 1) {
const Id pointer{
ctx.OpAccessChain(param.pointer_type, param.id, ctx.ConstU32(comp))};
result = ctx.OpLoad(param.component_type, pointer);
} else {
result = ctx.OpLoad(param.component_type, param.id);
}
if (param.is_integer) {
result = ctx.OpBitcast(ctx.F32[1], result);
}
return result;
} else {
const auto step_rate = EmitReadStepRate(ctx, param.id.value); const auto step_rate = EmitReadStepRate(ctx, param.id.value);
const auto offset = ctx.OpIAdd( const auto offset = ctx.OpIAdd(
ctx.U32[1], ctx.U32[1],
@ -229,7 +213,26 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, u32 comp, u32 index) {
ctx.ConstU32(comp)); ctx.ConstU32(comp));
return EmitReadConstBuffer(ctx, param.buffer_handle, offset); return EmitReadConstBuffer(ctx, param.buffer_handle, offset);
} }
Id result;
if (param.is_loaded) {
// Attribute is either default or manually interpolated. The id points to an already
// loaded vector.
result = ctx.OpCompositeExtract(param.component_type, param.id, comp);
} else if (param.num_components > 1) {
// Attribute is a vector and we need to access a specific component.
const Id pointer{ctx.OpAccessChain(param.pointer_type, param.id, ctx.ConstU32(comp))};
result = ctx.OpLoad(param.component_type, pointer);
} else {
// Attribute is a single float or integer, simply load it.
result = ctx.OpLoad(param.component_type, param.id);
}
if (param.is_integer) {
result = ctx.OpBitcast(ctx.F32[1], result);
}
return result;
} }
switch (attr) { switch (attr) {
case IR::Attribute::FragCoord: { case IR::Attribute::FragCoord: {
const Id coord = ctx.OpLoad( const Id coord = ctx.OpLoad(

View file

@ -8,6 +8,9 @@
namespace Shader::Backend::SPIRV { namespace Shader::Backend::SPIRV {
void EmitPrologue(EmitContext& ctx) { void EmitPrologue(EmitContext& ctx) {
if (ctx.stage == Stage::Fragment) {
ctx.DefineInterpolatedAttribs();
}
ctx.DefineBufferOffsets(); ctx.DefineBufferOffsets();
} }

View file

@ -222,6 +222,36 @@ void EmitContext::DefineBufferOffsets() {
} }
} }
void EmitContext::DefineInterpolatedAttribs() {
if (!profile.needs_manual_interpolation) {
return;
}
// Iterate all input attributes, load them and manually interpolate with barycentric
// coordinates.
for (s32 i = 0; i < runtime_info.fs_info.num_inputs; i++) {
const auto& input = runtime_info.fs_info.inputs[i];
const u32 semantic = input.param_index;
auto& params = input_params[semantic];
if (input.is_flat || params.is_loaded) {
continue;
}
const Id p_array{OpLoad(TypeArray(F32[4], ConstU32(3U)), params.id)};
const Id p0{OpCompositeExtract(F32[4], p_array, 0U)};
const Id p1{OpCompositeExtract(F32[4], p_array, 1U)};
const Id p2{OpCompositeExtract(F32[4], p_array, 2U)};
const Id p10{OpFSub(F32[4], p1, p0)};
const Id p20{OpFSub(F32[4], p2, p0)};
const Id bary_coord{OpLoad(F32[3], gl_bary_coord_id)};
const Id bary_coord_y{OpCompositeExtract(F32[1], bary_coord, 1)};
const Id bary_coord_z{OpCompositeExtract(F32[1], bary_coord, 2)};
const Id p10_y{OpVectorTimesScalar(F32[4], p10, bary_coord_y)};
const Id p20_z{OpVectorTimesScalar(F32[4], p20, bary_coord_z)};
params.id = OpFAdd(F32[4], p0, OpFAdd(F32[4], p10_y, p20_z));
Name(params.id, fmt::format("fs_in_attr{}", semantic));
params.is_loaded = true;
}
}
Id MakeDefaultValue(EmitContext& ctx, u32 default_value) { Id MakeDefaultValue(EmitContext& ctx, u32 default_value) {
switch (default_value) { switch (default_value) {
case 0: case 0:
@ -260,14 +290,14 @@ void EmitContext::DefineInputs() {
input.instance_step_rate == Info::VsInput::InstanceIdType::OverStepRate0 ? 0 input.instance_step_rate == Info::VsInput::InstanceIdType::OverStepRate0 ? 0
: 1; : 1;
// Note that we pass index rather than Id // Note that we pass index rather than Id
input_params[input.binding] = { input_params[input.binding] = SpirvAttribute{
rate_idx, .id = rate_idx,
input_u32, .pointer_type = input_u32,
U32[1], .component_type = U32[1],
input.num_components, .num_components = input.num_components,
true, .is_integer = true,
false, .is_loaded = false,
input.instance_data_buf, .buffer_handle = input.instance_data_buf,
}; };
} else { } else {
Id id{DefineInput(type, input.binding)}; Id id{DefineInput(type, input.binding)};
@ -286,6 +316,10 @@ void EmitContext::DefineInputs() {
frag_coord = DefineVariable(F32[4], spv::BuiltIn::FragCoord, spv::StorageClass::Input); frag_coord = DefineVariable(F32[4], spv::BuiltIn::FragCoord, spv::StorageClass::Input);
frag_depth = DefineVariable(F32[1], spv::BuiltIn::FragDepth, spv::StorageClass::Output); frag_depth = DefineVariable(F32[1], spv::BuiltIn::FragDepth, spv::StorageClass::Output);
front_facing = DefineVariable(U1[1], spv::BuiltIn::FrontFacing, spv::StorageClass::Input); front_facing = DefineVariable(U1[1], spv::BuiltIn::FrontFacing, spv::StorageClass::Input);
if (profile.needs_manual_interpolation) {
gl_bary_coord_id =
DefineVariable(F32[3], spv::BuiltIn::BaryCoordKHR, spv::StorageClass::Input);
}
for (s32 i = 0; i < runtime_info.fs_info.num_inputs; i++) { for (s32 i = 0; i < runtime_info.fs_info.num_inputs; i++) {
const auto& input = runtime_info.fs_info.inputs[i]; const auto& input = runtime_info.fs_info.inputs[i];
const u32 semantic = input.param_index; const u32 semantic = input.param_index;
@ -299,14 +333,21 @@ void EmitContext::DefineInputs() {
const IR::Attribute param{IR::Attribute::Param0 + input.param_index}; const IR::Attribute param{IR::Attribute::Param0 + input.param_index};
const u32 num_components = info.loads.NumComponents(param); const u32 num_components = info.loads.NumComponents(param);
const Id type{F32[num_components]}; const Id type{F32[num_components]};
const Id id{DefineInput(type, semantic)}; Id attr_id{};
if (input.is_flat) { if (profile.needs_manual_interpolation && !input.is_flat) {
Decorate(id, spv::Decoration::Flat); attr_id = DefineInput(TypeArray(type, ConstU32(3U)), semantic);
Decorate(attr_id, spv::Decoration::PerVertexKHR);
Name(attr_id, fmt::format("fs_in_attr{}_p", semantic));
} else {
attr_id = DefineInput(type, semantic);
Name(attr_id, fmt::format("fs_in_attr{}", semantic));
}
if (input.is_flat) {
Decorate(attr_id, spv::Decoration::Flat);
} }
Name(id, fmt::format("fs_in_attr{}", semantic));
input_params[semantic] = input_params[semantic] =
GetAttributeInfo(AmdGpu::NumberFormat::Float, id, num_components, false); GetAttributeInfo(AmdGpu::NumberFormat::Float, attr_id, num_components, false);
interfaces.push_back(id); interfaces.push_back(attr_id);
} }
break; break;
case Stage::Compute: case Stage::Compute:

View file

@ -42,7 +42,9 @@ public:
~EmitContext(); ~EmitContext();
Id Def(const IR::Value& value); Id Def(const IR::Value& value);
void DefineBufferOffsets(); void DefineBufferOffsets();
void DefineInterpolatedAttribs();
[[nodiscard]] Id DefineInput(Id type, u32 location) { [[nodiscard]] Id DefineInput(Id type, u32 location) {
const Id input_id{DefineVar(type, spv::StorageClass::Input)}; const Id input_id{DefineVar(type, spv::StorageClass::Input)};
@ -197,6 +199,9 @@ public:
Id shared_memory_u32_type{}; Id shared_memory_u32_type{};
Id interpolate_func{};
Id gl_bary_coord_id{};
struct TextureDefinition { struct TextureDefinition {
const VectorIds* data_types; const VectorIds* data_types;
Id id; Id id;
@ -241,7 +246,7 @@ public:
Id component_type; Id component_type;
u32 num_components; u32 num_components;
bool is_integer{}; bool is_integer{};
bool is_default{}; bool is_loaded{};
s32 buffer_handle{-1}; s32 buffer_handle{-1};
}; };
std::array<SpirvAttribute, IR::NumParams> input_params{}; std::array<SpirvAttribute, IR::NumParams> input_params{};

View file

@ -24,6 +24,7 @@ struct Profile {
bool support_explicit_workgroup_layout{}; bool support_explicit_workgroup_layout{};
bool has_broken_spirv_clamp{}; bool has_broken_spirv_clamp{};
bool lower_left_origin_mode{}; bool lower_left_origin_mode{};
bool needs_manual_interpolation{};
u64 min_ssbo_alignment{}; u64 min_ssbo_alignment{};
}; };

View file

@ -256,6 +256,7 @@ bool Instance::CreateDevice() {
workgroup_memory_explicit_layout = workgroup_memory_explicit_layout =
add_extension(VK_KHR_WORKGROUP_MEMORY_EXPLICIT_LAYOUT_EXTENSION_NAME); add_extension(VK_KHR_WORKGROUP_MEMORY_EXPLICIT_LAYOUT_EXTENSION_NAME);
vertex_input_dynamic_state = add_extension(VK_EXT_VERTEX_INPUT_DYNAMIC_STATE_EXTENSION_NAME); vertex_input_dynamic_state = add_extension(VK_EXT_VERTEX_INPUT_DYNAMIC_STATE_EXTENSION_NAME);
fragment_shader_barycentric = add_extension(VK_KHR_FRAGMENT_SHADER_BARYCENTRIC_EXTENSION_NAME);
// The next two extensions are required to be available together in order to support write masks // The next two extensions are required to be available together in order to support write masks
color_write_en = add_extension(VK_EXT_COLOR_WRITE_ENABLE_EXTENSION_NAME); color_write_en = add_extension(VK_EXT_COLOR_WRITE_ENABLE_EXTENSION_NAME);
@ -399,6 +400,9 @@ bool Instance::CreateDevice() {
vk::PhysicalDevicePrimitiveTopologyListRestartFeaturesEXT{ vk::PhysicalDevicePrimitiveTopologyListRestartFeaturesEXT{
.primitiveTopologyListRestart = true, .primitiveTopologyListRestart = true,
}, },
vk::PhysicalDeviceFragmentShaderBarycentricFeaturesKHR{
.fragmentShaderBarycentric = true,
},
#ifdef __APPLE__ #ifdef __APPLE__
feature_chain.get<vk::PhysicalDevicePortabilitySubsetFeaturesKHR>(), feature_chain.get<vk::PhysicalDevicePortabilitySubsetFeaturesKHR>(),
#endif #endif
@ -438,6 +442,9 @@ bool Instance::CreateDevice() {
if (!vertex_input_dynamic_state) { if (!vertex_input_dynamic_state) {
device_chain.unlink<vk::PhysicalDeviceVertexInputDynamicStateFeaturesEXT>(); device_chain.unlink<vk::PhysicalDeviceVertexInputDynamicStateFeaturesEXT>();
} }
if (!fragment_shader_barycentric) {
device_chain.unlink<vk::PhysicalDeviceFragmentShaderBarycentricFeaturesKHR>();
}
auto [device_result, dev] = physical_device.createDeviceUnique(device_chain.get()); auto [device_result, dev] = physical_device.createDeviceUnique(device_chain.get());
if (device_result != vk::Result::eSuccess) { if (device_result != vk::Result::eSuccess) {

View file

@ -143,6 +143,11 @@ public:
return maintenance5; return maintenance5;
} }
/// Reports whether VK_KHR_fragment_shader_barycentric is supported, as recorded in the
/// fragment_shader_barycentric flag.
bool IsFragmentShaderBarycentricSupported() const {
    const bool barycentric_supported = fragment_shader_barycentric;
    return barycentric_supported;
}
bool IsListRestartSupported() const { bool IsListRestartSupported() const {
return list_restart; return list_restart;
} }

View file

@ -169,6 +169,8 @@ PipelineCache::PipelineCache(const Instance& instance_, Scheduler& scheduler_,
.support_fp32_denorm_preserve = bool(vk12_props.shaderDenormPreserveFloat32), .support_fp32_denorm_preserve = bool(vk12_props.shaderDenormPreserveFloat32),
.support_fp32_denorm_flush = bool(vk12_props.shaderDenormFlushToZeroFloat32), .support_fp32_denorm_flush = bool(vk12_props.shaderDenormFlushToZeroFloat32),
.support_explicit_workgroup_layout = true, .support_explicit_workgroup_layout = true,
.needs_manual_interpolation = instance.IsFragmentShaderBarycentricSupported() &&
instance.GetDriverID() == vk::DriverId::eNvidiaProprietary,
}; };
auto [cache_result, cache] = instance.GetDevice().createPipelineCacheUnique({}); auto [cache_result, cache] = instance.GetDevice().createPipelineCacheUnique({});
ASSERT_MSG(cache_result == vk::Result::eSuccess, "Failed to create pipeline cache: {}", ASSERT_MSG(cache_result == vk::Result::eSuccess, "Failed to create pipeline cache: {}",