Use vector outputs for texture operations (#3939)

* Change AggregateType to include vector type counts

* Replace VariableType uses with AggregateType and delete VariableType

* Support new local vector types on SPIR-V and GLSL

* Start using vector outputs for texture operations

* Use vectors on more texture operations

* Use vector output for ImageLoad operations

* Replace all uses of single destination texture constructors with multi destination ones

* Update textureGatherOffsets replacement to split vector operations

* Shader cache version bump

Co-authored-by: Ac_K <Acoustik666@gmail.com>
This commit is contained in:
gdkchan 2022-12-29 12:09:34 -03:00 committed by GitHub
parent 52c115a1f8
commit 9dfe81770a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
37 changed files with 1100 additions and 747 deletions

View file

@ -241,6 +241,29 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
throw new NotImplementedException(node.GetType().Name);
}
/// <summary>
/// Produces the SPIR-V value for an AST node without forcing a destination type,
/// and reports the type the value actually has.
/// </summary>
/// <param name="node">Operation or local variable operand to evaluate.</param>
/// <param name="type">Receives the type of the returned value: the generated operation's result type, or the local's declared type.</param>
/// <returns>The instruction holding the value.</returns>
/// <exception cref="ArgumentException">The node is an operand that is not a local variable.</exception>
/// <exception cref="NotImplementedException">The node is neither an operation nor an operand.</exception>
public Instruction GetWithType(IAstNode node, out AggregateType type)
{
    switch (node)
    {
        case AstOperation operation:
            // Let the instruction generator decide the result type and hand it back to the caller.
            var generated = Instructions.Generate(this, operation);
            type = generated.Type;
            return generated.Value;

        case AstOperand operand when operand.Type == IrOperandType.LocalVariable:
            // Loading with the local's own type means no bitcast is requested.
            type = operand.VarType;
            return GetLocal(type, operand);

        case AstOperand operand:
            throw new ArgumentException($"Invalid operand type \"{operand.Type}\".");

        default:
            throw new NotImplementedException(node.GetType().Name);
    }
}
private Instruction GetUndefined(AggregateType type)
{
return type switch
@ -325,7 +348,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
if (components > 1)
{
attrOffset &= ~0xf;
type = AggregateType.Vector | AggregateType.FP32;
type = components switch
{
2 => AggregateType.Vector2 | AggregateType.FP32,
3 => AggregateType.Vector3 | AggregateType.FP32,
4 => AggregateType.Vector4 | AggregateType.FP32,
_ => AggregateType.FP32
};
attrInfo = new AttributeInfo(attrOffset, (attr - attrOffset) / 4, components, type, false);
}
}
@ -335,7 +365,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
bool isIndexed = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr) && (!attrInfo.IsBuiltin || AttributeInfo.IsArrayBuiltIn(attr));
if ((type & (AggregateType.Array | AggregateType.Vector)) == 0)
if ((type & (AggregateType.Array | AggregateType.ElementCountMask)) == 0)
{
if (invocationId != null)
{
@ -452,7 +482,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
elemType = attrInfo.Type & AggregateType.ElementTypeMask;
if ((attrInfo.Type & (AggregateType.Array | AggregateType.Vector)) == 0)
if ((attrInfo.Type & (AggregateType.Array | AggregateType.ElementCountMask)) == 0)
{
return ioVariable;
}
@ -533,13 +563,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
public Instruction GetLocal(AggregateType dstType, AstOperand local)
{
var srcType = local.VarType.Convert();
var srcType = local.VarType;
return BitcastIfNeeded(dstType, srcType, Load(GetType(srcType), GetLocalPointer(local)));
}
public Instruction GetArgument(AggregateType dstType, AstOperand funcArg)
{
var srcType = funcArg.VarType.Convert();
var srcType = funcArg.VarType;
return BitcastIfNeeded(dstType, srcType, Load(GetType(srcType), GetArgumentPointer(funcArg)));
}
@ -550,13 +580,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
public Instruction GetType(AggregateType type, int length = 1)
{
if (type.HasFlag(AggregateType.Array))
if ((type & AggregateType.Array) != 0)
{
return TypeArray(GetType(type & ~AggregateType.Array), Constant(TypeU32(), length));
}
else if (type.HasFlag(AggregateType.Vector))
else if ((type & AggregateType.ElementCountMask) != 0)
{
return TypeVector(GetType(type & ~AggregateType.Vector), length);
int vectorLength = (type & AggregateType.ElementCountMask) switch
{
AggregateType.Vector2 => 2,
AggregateType.Vector3 => 3,
AggregateType.Vector4 => 4,
_ => 1
};
return TypeVector(GetType(type & ~AggregateType.ElementCountMask), vectorLength);
}
return type switch

View file

@ -23,11 +23,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
DeclareParameters(context, function.OutArguments, function.InArguments.Length);
}
private static void DeclareParameters(CodeGenContext context, IEnumerable<VariableType> argTypes, int argIndex)
private static void DeclareParameters(CodeGenContext context, IEnumerable<AggregateType> argTypes, int argIndex)
{
foreach (var argType in argTypes)
{
var argPointerType = context.TypePointer(StorageClass.Function, context.GetType(argType.Convert()));
var argPointerType = context.TypePointer(StorageClass.Function, context.GetType(argType));
var spvArg = context.FunctionParameter(argPointerType);
context.DeclareArgument(argIndex++, spvArg);
@ -38,7 +38,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
{
foreach (AstOperand local in function.Locals)
{
var localPointerType = context.TypePointer(StorageClass.Function, context.GetType(local.VarType.Convert()));
var localPointerType = context.TypePointer(StorageClass.Function, context.GetType(local.VarType));
var spvLocal = context.Variable(localPointerType, StorageClass.Function);
context.AddLocalVariable(spvLocal);
@ -62,7 +62,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
for (int i = 0; i < function.InArguments.Length; i++)
{
var type = function.GetArgumentType(i).Convert();
var type = function.GetArgumentType(i);
var localPointerType = context.TypePointer(StorageClass.Function, context.GetType(type));
var spvLocal = context.Variable(localPointerType, StorageClass.Function);
@ -303,7 +303,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var dim = GetDim(descriptor.Type);
var imageType = context.TypeImage(
context.GetType(meta.Format.GetComponentType().Convert()),
context.GetType(meta.Format.GetComponentType()),
dim,
descriptor.Type.HasFlag(SamplerType.Shadow),
descriptor.Type.HasFlag(SamplerType.Array),
@ -652,7 +652,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
if (components > 1)
{
attr &= ~0xf;
type = AggregateType.Vector | AggregateType.FP32;
type = components switch
{
2 => AggregateType.Vector2 | AggregateType.FP32,
3 => AggregateType.Vector3 | AggregateType.FP32,
4 => AggregateType.Vector4 | AggregateType.FP32,
_ => AggregateType.FP32
};
hasComponent = false;
}
}

View file

@ -1,5 +1,4 @@
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using Ryujinx.Graphics.Shader.Translation;
using System;
using static Spv.Specification;
@ -7,20 +6,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
{
static class EnumConversion
{
public static AggregateType Convert(this VariableType type)
{
return type switch
{
VariableType.None => AggregateType.Void,
VariableType.Bool => AggregateType.Bool,
VariableType.F32 => AggregateType.FP32,
VariableType.F64 => AggregateType.FP64,
VariableType.S32 => AggregateType.S32,
VariableType.U32 => AggregateType.U32,
_ => throw new ArgumentException($"Invalid variable type \"{type}\".")
};
}
public static ExecutionModel Convert(this ShaderStage stage)
{
return stage switch

View file

@ -4,6 +4,7 @@ using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Numerics;
using static Spv.Specification;
namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
@ -146,6 +147,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
Add(Instruction.Truncate, GenerateTruncate);
Add(Instruction.UnpackDouble2x32, GenerateUnpackDouble2x32);
Add(Instruction.UnpackHalf2x16, GenerateUnpackHalf2x16);
Add(Instruction.VectorExtract, GenerateVectorExtract);
Add(Instruction.VoteAll, GenerateVoteAll);
Add(Instruction.VoteAllEqual, GenerateVoteAllEqual);
Add(Instruction.VoteAny, GenerateVoteAny);
@ -317,7 +319,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
}
else
{
var type = function.GetArgumentType(i).Convert();
var type = function.GetArgumentType(i);
var value = context.Get(type, operand);
var spvLocal = spvLocals[i];
@ -327,7 +329,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
}
}
var retType = function.ReturnType.Convert();
var retType = function.ReturnType;
var result = context.FunctionCall(context.GetType(retType), spvFunc, args);
return new OperationResult(retType, result);
}
@ -604,10 +606,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
// TODO: Bindless texture support. For now we just return 0/do nothing.
if (isBindless)
{
return new OperationResult(componentType.Convert(), componentType switch
return new OperationResult(componentType, componentType switch
{
VariableType.S32 => context.Constant(context.TypeS32(), 0),
VariableType.U32 => context.Constant(context.TypeU32(), 0u),
AggregateType.S32 => context.Constant(context.TypeS32(), 0),
AggregateType.U32 => context.Constant(context.TypeU32(), 0u),
_ => context.Constant(context.TypeFP32(), 0f),
});
}
@ -652,13 +654,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
pCoords = Src(AggregateType.S32);
}
SpvInstruction value = Src(componentType.Convert());
SpvInstruction value = Src(componentType);
(var imageType, var imageVariable) = context.Images[new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format)];
var image = context.Load(imageType, imageVariable);
SpvInstruction resultType = context.GetType(componentType.Convert());
SpvInstruction resultType = context.GetType(componentType);
SpvInstruction imagePointerType = context.TypePointer(StorageClass.Image, resultType);
var pointer = context.ImageTexelPointer(imagePointerType, imageVariable, pCoords, context.Constant(context.TypeU32(), 0));
@ -668,10 +670,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var result = (texOp.Flags & TextureFlags.AtomicMask) switch
{
TextureFlags.Add => context.AtomicIAdd(resultType, pointer, one, zero, value),
TextureFlags.Minimum => componentType == VariableType.S32
TextureFlags.Minimum => componentType == AggregateType.S32
? context.AtomicSMin(resultType, pointer, one, zero, value)
: context.AtomicUMin(resultType, pointer, one, zero, value),
TextureFlags.Maximum => componentType == VariableType.S32
TextureFlags.Maximum => componentType == AggregateType.S32
? context.AtomicSMax(resultType, pointer, one, zero, value)
: context.AtomicUMax(resultType, pointer, one, zero, value),
TextureFlags.Increment => context.AtomicIIncrement(resultType, pointer, one, zero),
@ -680,11 +682,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
TextureFlags.BitwiseOr => context.AtomicOr(resultType, pointer, one, zero, value),
TextureFlags.BitwiseXor => context.AtomicXor(resultType, pointer, one, zero, value),
TextureFlags.Swap => context.AtomicExchange(resultType, pointer, one, zero, value),
TextureFlags.CAS => context.AtomicCompareExchange(resultType, pointer, one, zero, zero, Src(componentType.Convert()), value),
TextureFlags.CAS => context.AtomicCompareExchange(resultType, pointer, one, zero, zero, Src(componentType), value),
_ => context.AtomicIAdd(resultType, pointer, one, zero, value),
};
return new OperationResult(componentType.Convert(), result);
return new OperationResult(componentType, result);
}
private static OperationResult GenerateImageLoad(CodeGenContext context, AstOperation operation)
@ -698,14 +700,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
// TODO: Bindless texture support. For now we just return 0/do nothing.
if (isBindless)
{
var zero = componentType switch
{
VariableType.S32 => context.Constant(context.TypeS32(), 0),
VariableType.U32 => context.Constant(context.TypeU32(), 0u),
_ => context.Constant(context.TypeFP32(), 0f),
};
return new OperationResult(componentType.Convert(), zero);
return GetZeroOperationResult(context, texOp, componentType, isVector: true);
}
bool isArray = (texOp.Type & SamplerType.Array) != 0;
@ -753,12 +748,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
(var imageType, var imageVariable) = context.Images[new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format)];
var image = context.Load(imageType, imageVariable);
var imageComponentType = context.GetType(componentType.Convert());
var imageComponentType = context.GetType(componentType);
var swizzledResultType = texOp.GetVectorType(componentType);
var texel = context.ImageRead(context.TypeVector(imageComponentType, 4), image, pCoords, ImageOperandsMask.MaskNone);
var result = context.CompositeExtract(imageComponentType, texel, (SpvLiteralInteger)texOp.Index);
var result = GetSwizzledResult(context, texel, swizzledResultType, texOp.Index);
return new OperationResult(componentType.Convert(), result);
return new OperationResult(componentType, result);
}
private static OperationResult GenerateImageStore(CodeGenContext context, AstOperation operation)
@ -823,20 +819,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
{
if (srcIndex < texOp.SourcesCount)
{
cElems[i] = Src(componentType.Convert());
cElems[i] = Src(componentType);
}
else
{
cElems[i] = componentType switch
{
VariableType.S32 => context.Constant(context.TypeS32(), 0),
VariableType.U32 => context.Constant(context.TypeU32(), 0u),
AggregateType.S32 => context.Constant(context.TypeS32(), 0),
AggregateType.U32 => context.Constant(context.TypeU32(), 0u),
_ => context.Constant(context.TypeFP32(), 0f),
};
}
}
var texel = context.CompositeConstruct(context.TypeVector(context.GetType(componentType.Convert()), ComponentsCount), cElems);
var texel = context.CompositeConstruct(context.TypeVector(context.GetType(componentType), ComponentsCount), cElems);
(var imageType, var imageVariable) = context.Images[new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format)];
@ -1238,7 +1234,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var validLocal = (AstOperand)operation.GetSource(3);
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid));
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
return new OperationResult(AggregateType.FP32, result);
}
@ -1268,7 +1264,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var validLocal = (AstOperand)operation.GetSource(3);
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid));
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
return new OperationResult(AggregateType.FP32, result);
}
@ -1294,7 +1290,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var validLocal = (AstOperand)operation.GetSource(3);
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid));
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
return new OperationResult(AggregateType.FP32, result);
}
@ -1324,7 +1320,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var validLocal = (AstOperand)operation.GetSource(3);
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid));
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
return new OperationResult(AggregateType.FP32, result);
}
@ -1485,10 +1481,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
bool isMultisample = (texOp.Type & SamplerType.Multisample) != 0;
bool isShadow = (texOp.Type & SamplerType.Shadow) != 0;
bool colorIsVector = isGather || !isShadow;
// TODO: Bindless texture support. For now we just return 0.
if (isBindless)
{
return new OperationResult(AggregateType.FP32, context.Constant(context.TypeFP32(), 0f));
return GetZeroOperationResult(context, texOp, AggregateType.FP32, colorIsVector);
}
// This combination is valid, but not available on GLSL.
@ -1705,7 +1703,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
operandsList.Add(sample);
}
bool colorIsVector = isGather || !isShadow;
var resultType = colorIsVector ? context.TypeVector(context.TypeFP32(), 4) : context.TypeFP32();
var meta = new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format);
@ -1758,12 +1755,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
result = context.ImageSampleImplicitLod(resultType, image, pCoords, operandsMask, operands);
}
var swizzledResultType = AggregateType.FP32;
if (colorIsVector)
{
result = context.CompositeExtract(context.TypeFP32(), result, (SpvLiteralInteger)texOp.Index);
swizzledResultType = texOp.GetVectorType(swizzledResultType);
result = GetSwizzledResult(context, result, swizzledResultType, texOp.Index);
}
return new OperationResult(AggregateType.FP32, result);
return new OperationResult(swizzledResultType, result);
}
private static OperationResult GenerateTextureSize(CodeGenContext context, AstOperation operation)
@ -1862,6 +1863,26 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
return new OperationResult(AggregateType.FP32, result);
}
/// <summary>
/// Generates code that extracts one element from a vector.
/// Source 0 is the vector and source 1 is the element index.
/// </summary>
/// <param name="context">Code generation context the instructions are emitted into.</param>
/// <param name="operation">The vector extract operation.</param>
/// <returns>The extracted element, typed as the vector type with its element count bits cleared.</returns>
private static OperationResult GenerateVectorExtract(CodeGenContext context, AstOperation operation)
{
    var source = context.GetWithType(operation.GetSource(0), out AggregateType sourceType);

    // Clearing the element count bits leaves the scalar type of the vector elements.
    var elementType = sourceType & ~AggregateType.ElementCountMask;
    var spvElementType = context.GetType(elementType);

    var selector = operation.GetSource(1);

    // A constant index can be encoded as a literal; anything else needs the dynamic form.
    if (selector is AstOperand { Type: OperandType.Constant } constantIndex)
    {
        var element = context.CompositeExtract(spvElementType, source, (SpvLiteralInteger)constantIndex.Value);

        return new OperationResult(elementType, element);
    }

    var dynamicIndex = context.Get(AggregateType.S32, selector);
    var dynamicElement = context.VectorExtractDynamic(spvElementType, source, dynamicIndex);

    return new OperationResult(elementType, dynamicElement);
}
private static OperationResult GenerateVoteAll(CodeGenContext context, AstOperation operation)
{
var execution = context.Constant(context.TypeU32(), Scope.Subgroup);
@ -2044,6 +2065,64 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
context.AddLabel(loopEnd);
}
/// <summary>
/// Builds an all-zero placeholder result for a texture operation
/// (the callers in this file use it for bindless operations, which are not supported yet).
/// </summary>
/// <param name="context">Code generation context the constants are created in.</param>
/// <param name="texOp">Texture operation; its <c>Index</c> bits determine how many components a vector result has.</param>
/// <param name="scalarType">Component type. S32 and U32 produce integer zeros, anything else produces a FP32 zero.</param>
/// <param name="isVector">True if the operation may produce a vector result.</param>
/// <returns>
/// A zero-filled vector of the type returned by <c>texOp.GetVectorType(scalarType)</c> when that type
/// carries an element count and <paramref name="isVector"/> is set, otherwise a scalar zero.
/// </returns>
private static OperationResult GetZeroOperationResult(
    CodeGenContext context,
    AstTextureOperation texOp,
    AggregateType scalarType,
    bool isVector)
{
    SpvInstruction zeroConstant;

    switch (scalarType)
    {
        case AggregateType.S32:
            zeroConstant = context.Constant(context.TypeS32(), 0);
            break;
        case AggregateType.U32:
            zeroConstant = context.Constant(context.TypeU32(), 0u);
            break;
        default:
            zeroConstant = context.Constant(context.TypeFP32(), 0f);
            break;
    }

    if (!isVector)
    {
        return new OperationResult(scalarType, zeroConstant);
    }

    AggregateType vectorType = texOp.GetVectorType(scalarType);

    // No element count bits set means the operation result is a scalar after all.
    if ((vectorType & AggregateType.ElementCountMask) == 0)
    {
        return new OperationResult(scalarType, zeroConstant);
    }

    // One zero per bit set on the component mask.
    var elements = new SpvInstruction[BitOperations.PopCount((uint)texOp.Index)];
    Array.Fill(elements, zeroConstant);

    return new OperationResult(vectorType, context.ConstantComposite(context.GetType(vectorType), elements));
}
/// <summary>
/// Selects the components of a vector that are enabled on a component mask.
/// </summary>
/// <param name="context">Code generation context the instructions are emitted into.</param>
/// <param name="vector">Vector to take the components from.</param>
/// <param name="swizzledResultType">Type of the value to produce; element count bits select between a vector and a scalar result.</param>
/// <param name="mask">Component mask; only bits 0 to 3 are turned into shuffle components.</param>
/// <returns>
/// A shuffled vector with the enabled components in ascending order when the result type is a vector,
/// otherwise the single component at the position of the lowest set bit of the mask.
/// </returns>
private static SpvInstruction GetSwizzledResult(CodeGenContext context, SpvInstruction vector, AggregateType swizzledResultType, int mask)
{
    bool isScalarResult = (swizzledResultType & AggregateType.ElementCountMask) == 0;

    if (isScalarResult)
    {
        // Scalar result: pick the component for the lowest enabled bit.
        int lowestEnabled = BitOperations.TrailingZeroCount(mask);

        return context.CompositeExtract(context.GetType(swizzledResultType), vector, (SpvLiteralInteger)lowestEnabled);
    }

    var selected = new SpvLiteralInteger[BitOperations.PopCount((uint)mask)];
    int selectedCount = 0;

    for (int bit = 0; bit < 4; bit++)
    {
        if (((mask >> bit) & 1) != 0)
        {
            selected[selectedCount++] = bit;
        }
    }

    // The same vector is passed for both shuffle inputs; only indices 0-3 are ever used.
    return context.VectorShuffle(context.GetType(swizzledResultType), vector, vector, selected);
}
private static SpvInstruction GetStorageElemPointer(CodeGenContext context, AstOperation operation)
{
var sbVariable = context.StorageBuffersArray;

View file

@ -104,13 +104,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
for (int funcIndex = 0; funcIndex < info.Functions.Count; funcIndex++)
{
var function = info.Functions[funcIndex];
var retType = context.GetType(function.ReturnType.Convert());
var retType = context.GetType(function.ReturnType);
var funcArgs = new SpvInstruction[function.InArguments.Length + function.OutArguments.Length];
for (int argIndex = 0; argIndex < funcArgs.Length; argIndex++)
{
var argType = context.GetType(function.GetArgumentType(argIndex).Convert());
var argType = context.GetType(function.GetArgumentType(argIndex));
var argPointerType = context.TypePointer(StorageClass.Function, argType);
funcArgs[argIndex] = argPointerType;
}
@ -387,7 +387,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
if (dest.Type == OperandType.LocalVariable)
{
var source = context.Get(dest.VarType.Convert(), assignment.Source);
var source = context.Get(dest.VarType, assignment.Source);
context.Store(context.GetLocalPointer(dest), source);
}
else if (dest.Type == OperandType.Attribute || dest.Type == OperandType.AttributePerPatch)
@ -407,7 +407,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
}
else if (dest.Type == OperandType.Argument)
{
var source = context.Get(dest.VarType.Convert(), assignment.Source);
var source = context.Get(dest.VarType, assignment.Source);
context.Store(context.GetArgumentPointer(dest), source);
}
else