Replace constant buffer access on shader with new Load instruction (#4646)

This commit is contained in:
gdkchan 2023-05-20 16:19:26 -03:00 committed by GitHub
parent fb27042e01
commit 402f05b8ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
42 changed files with 788 additions and 625 deletions

View file

@ -98,7 +98,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
Add(Instruction.ImageStore, GenerateImageStore);
Add(Instruction.IsNan, GenerateIsNan);
Add(Instruction.Load, GenerateLoad);
Add(Instruction.LoadConstant, GenerateLoadConstant);
Add(Instruction.LoadLocal, GenerateLoadLocal);
Add(Instruction.LoadShared, GenerateLoadShared);
Add(Instruction.LoadStorage, GenerateLoadStorage);
@ -313,10 +312,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
for (int i = 0; i < args.Length; i++)
{
var operand = (AstOperand)operation.GetSource(i + 1);
var operand = operation.GetSource(i + 1);
if (i >= function.InArguments.Length)
{
args[i] = context.GetLocalPointer(operand);
args[i] = context.GetLocalPointer((AstOperand)operand);
}
else
{
@ -867,68 +867,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
return GenerateLoadOrStore(context, operation, isStore: false);
}
private static OperationResult GenerateLoadConstant(CodeGenContext context, AstOperation operation)
{
var src1 = operation.GetSource(0);
var src2 = context.Get(AggregateType.S32, operation.GetSource(1));
var i1 = context.Constant(context.TypeS32(), 0);
var i2 = context.ShiftRightArithmetic(context.TypeS32(), src2, context.Constant(context.TypeS32(), 2));
var i3 = context.BitwiseAnd(context.TypeS32(), src2, context.Constant(context.TypeS32(), 3));
SpvInstruction value = null;
if (context.Config.GpuAccessor.QueryHostHasVectorIndexingBug())
{
// Test for each component individually.
for (int i = 0; i < 4; i++)
{
var component = context.Constant(context.TypeS32(), i);
SpvInstruction elemPointer;
if (context.UniformBuffersArray != null)
{
var ubVariable = context.UniformBuffersArray;
var i0 = context.Get(AggregateType.S32, src1);
elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i0, i1, i2, component);
}
else
{
var ubVariable = context.UniformBuffers[((AstOperand)src1).Value];
elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i1, i2, component);
}
SpvInstruction newValue = context.Load(context.TypeFP32(), elemPointer);
value = value != null ? context.Select(context.TypeFP32(), context.IEqual(context.TypeBool(), i3, component), newValue, value) : newValue;
}
}
else
{
SpvInstruction elemPointer;
if (context.UniformBuffersArray != null)
{
var ubVariable = context.UniformBuffersArray;
var i0 = context.Get(AggregateType.S32, src1);
elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i0, i1, i2, i3);
}
else
{
var ubVariable = context.UniformBuffers[((AstOperand)src1).Value];
elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i1, i2, i3);
}
value = context.Load(context.TypeFP32(), elemPointer);
}
return new OperationResult(AggregateType.FP32, value);
}
private static OperationResult GenerateLoadLocal(CodeGenContext context, AstOperation operation)
{
return GenerateLoadLocalOrShared(context, operation, StorageClass.Private, context.LocalMemory);
@ -1990,12 +1928,32 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
{
StorageKind storageKind = operation.StorageKind;
SpvInstruction pointer;
StorageClass storageClass;
SpvInstruction baseObj;
AggregateType varType;
int srcIndex = 0;
switch (storageKind)
{
case StorageKind.ConstantBuffer:
if (!(operation.GetSource(srcIndex++) is AstOperand bindingIndex) || bindingIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand.");
}
if (!(operation.GetSource(srcIndex) is AstOperand fieldIndex) || fieldIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"Second input of {operation.Inst} with {storageKind} storage must be a constant operand.");
}
BufferDefinition buffer = context.Config.Properties.ConstantBuffers[bindingIndex.Value];
StructureField field = buffer.Type.Fields[fieldIndex.Value];
storageClass = StorageClass.Uniform;
varType = field.Type & AggregateType.ElementTypeMask;
baseObj = context.ConstantBuffers[bindingIndex.Value];
break;
case StorageKind.Input:
case StorageKind.InputPerPatch:
case StorageKind.Output:
@ -2038,33 +1996,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
{
varType = context.Config.GetFragmentOutputColorType(location);
}
else if (ioVariable == IoVariable.FragmentOutputIsBgra)
{
var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeU32());
var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 1), elemIndex);
varType = AggregateType.U32;
break;
}
else if (ioVariable == IoVariable.SupportBlockRenderScale)
{
var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeFP32());
var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 4), elemIndex);
varType = AggregateType.FP32;
break;
}
else if (ioVariable == IoVariable.SupportBlockViewInverse)
{
var pointerType = context.TypePointer(StorageClass.Uniform, context.TypeFP32());
var elemIndex = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
pointer = context.AccessChain(pointerType, context.SupportBuffer, context.Constant(context.TypeU32(), 2), elemIndex);
varType = AggregateType.FP32;
break;
}
else
{
(_, varType) = IoMap.GetSpirvBuiltIn(ioVariable);
@ -2072,55 +2003,57 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
varType &= AggregateType.ElementTypeMask;
int inputsCount = (isStore ? operation.SourcesCount - 1 : operation.SourcesCount) - srcIndex;
var storageClass = isOutput ? StorageClass.Output : StorageClass.Input;
storageClass = isOutput ? StorageClass.Output : StorageClass.Input;
var ioDefinition = new IoDefinition(storageKind, ioVariable, location, component);
var dict = isPerPatch
? (isOutput ? context.OutputsPerPatch : context.InputsPerPatch)
: (isOutput ? context.Outputs : context.Inputs);
SpvInstruction baseObj = dict[ioDefinition];
SpvInstruction e0, e1, e2;
switch (inputsCount)
{
case 0:
pointer = baseObj;
break;
case 1:
e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0);
break;
case 2:
e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1);
break;
case 3:
e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
e2 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1, e2);
break;
default:
var indexes = new SpvInstruction[inputsCount];
int index = 0;
for (; index < inputsCount; srcIndex++, index++)
{
indexes[index] = context.Get(AggregateType.S32, operation.GetSource(srcIndex));
}
pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, indexes);
break;
}
baseObj = dict[ioDefinition];
break;
default:
throw new InvalidOperationException($"Invalid storage kind {storageKind}.");
}
int inputsCount = (isStore ? operation.SourcesCount - 1 : operation.SourcesCount) - srcIndex;
SpvInstruction e0, e1, e2;
SpvInstruction pointer;
switch (inputsCount)
{
case 0:
pointer = baseObj;
break;
case 1:
e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0);
break;
case 2:
e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1);
break;
case 3:
e0 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
e1 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
e2 = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, e0, e1, e2);
break;
default:
var indexes = new SpvInstruction[inputsCount];
int index = 0;
for (; index < inputsCount; srcIndex++, index++)
{
indexes[index] = context.Get(AggregateType.S32, operation.GetSource(srcIndex));
}
pointer = context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, indexes);
break;
}
if (isStore)
{
context.Store(pointer, context.Get(varType, operation.GetSource(srcIndex)));