Fix vote and shuffle shader instructions on AMD GPUs (#5540)

* Move shuffle handling out of the backend to a transform pass

* Handle subgroup sizes higher than 32

* Stop using the subgroup size control extension

* Make GenerateShuffleFunction static

* Shader cache version bump
This commit is contained in:
gdkchan 2023-08-16 21:31:07 -03:00 committed by GitHub
parent 64079c034c
commit 6ed613a6e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 445 additions and 265 deletions

View file

@ -76,7 +76,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
switch (op.SReg)
{
case SReg.LaneId:
src = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
src = EmitLoadSubgroupLaneId(context);
break;
case SReg.InvocationId:
@ -146,19 +146,19 @@ namespace Ryujinx.Graphics.Shader.Instructions
break;
case SReg.EqMask:
src = context.Load(StorageKind.Input, IoVariable.SubgroupEqMask, null, Const(0));
src = EmitLoadSubgroupMask(context, IoVariable.SubgroupEqMask);
break;
case SReg.LtMask:
src = context.Load(StorageKind.Input, IoVariable.SubgroupLtMask, null, Const(0));
src = EmitLoadSubgroupMask(context, IoVariable.SubgroupLtMask);
break;
case SReg.LeMask:
src = context.Load(StorageKind.Input, IoVariable.SubgroupLeMask, null, Const(0));
src = EmitLoadSubgroupMask(context, IoVariable.SubgroupLeMask);
break;
case SReg.GtMask:
src = context.Load(StorageKind.Input, IoVariable.SubgroupGtMask, null, Const(0));
src = EmitLoadSubgroupMask(context, IoVariable.SubgroupGtMask);
break;
case SReg.GeMask:
src = context.Load(StorageKind.Input, IoVariable.SubgroupGeMask, null, Const(0));
src = EmitLoadSubgroupMask(context, IoVariable.SubgroupGeMask);
break;
default:
@ -169,6 +169,52 @@ namespace Ryujinx.Graphics.Shader.Instructions
context.Copy(GetDest(op.Dest), src);
}
private static Operand EmitLoadSubgroupLaneId(EmitterContext context)
{
if (context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize() <= 32)
{
return context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
}
return context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x1f));
}
private static Operand EmitLoadSubgroupMask(EmitterContext context, IoVariable ioVariable)
{
int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize();
if (subgroupSize <= 32)
{
return context.Load(StorageKind.Input, ioVariable, null, Const(0));
}
else if (subgroupSize == 64)
{
Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
Operand low = context.Load(StorageKind.Input, ioVariable, null, Const(0));
Operand high = context.Load(StorageKind.Input, ioVariable, null, Const(1));
return context.ConditionalSelect(context.BitwiseAnd(laneId, Const(32)), high, low);
}
else
{
Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
Operand element = context.ShiftRightU32(laneId, Const(5));
Operand res = context.Load(StorageKind.Input, ioVariable, null, Const(0));
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(1)),
context.Load(StorageKind.Input, ioVariable, null, Const(1)), res);
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(2)),
context.Load(StorageKind.Input, ioVariable, null, Const(2)), res);
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(3)),
context.Load(StorageKind.Input, ioVariable, null, Const(3)), res);
return res;
}
}
public static void SelR(EmitterContext context)
{
InstSelR op = context.GetOp<InstSelR>();

View file

@ -50,20 +50,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
InstVote op = context.GetOp<InstVote>();
Operand pred = GetPredicate(context, op.SrcPred, op.SrcPredInv);
Operand res = null;
switch (op.VoteMode)
{
case VoteMode.All:
res = context.VoteAll(pred);
break;
case VoteMode.Any:
res = context.VoteAny(pred);
break;
case VoteMode.Eq:
res = context.VoteAllEqual(pred);
break;
}
Operand res = EmitVote(context, op.VoteMode, pred);
if (res != null)
{
@ -76,7 +63,81 @@ namespace Ryujinx.Graphics.Shader.Instructions
if (op.Dest != RegisterConsts.RegisterZeroIndex)
{
context.Copy(GetDest(op.Dest), context.Ballot(pred));
context.Copy(GetDest(op.Dest), EmitBallot(context, pred));
}
}
private static Operand EmitVote(EmitterContext context, VoteMode voteMode, Operand pred)
{
int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize();
if (subgroupSize <= 32)
{
return voteMode switch
{
VoteMode.All => context.VoteAll(pred),
VoteMode.Any => context.VoteAny(pred),
VoteMode.Eq => context.VoteAllEqual(pred),
_ => null,
};
}
// Emulate vote with ballot masks.
// We do that when the GPU thread count is not 32,
// since the shader code assumes it is 32.
// allInvocations => ballot(pred) == ballot(true),
// anyInvocation => ballot(pred) != 0,
// allInvocationsEqual => ballot(pred) == balot(true) || ballot(pred) == 0
Operand ballotMask = EmitBallot(context, pred);
Operand AllTrue() => context.ICompareEqual(ballotMask, EmitBallot(context, Const(IrConsts.True)));
return voteMode switch
{
VoteMode.All => AllTrue(),
VoteMode.Any => context.ICompareNotEqual(ballotMask, Const(0)),
VoteMode.Eq => context.BitwiseOr(AllTrue(), context.ICompareEqual(ballotMask, Const(0))),
_ => null,
};
}
private static Operand EmitBallot(EmitterContext context, Operand pred)
{
int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize();
if (subgroupSize <= 32)
{
return context.Ballot(pred, 0);
}
else if (subgroupSize == 64)
{
// TODO: Add support for vector destination and do that with a single operation.
Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
Operand low = context.Ballot(pred, 0);
Operand high = context.Ballot(pred, 1);
return context.ConditionalSelect(context.BitwiseAnd(laneId, Const(32)), high, low);
}
else
{
// TODO: Add support for vector destination and do that with a single operation.
Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
Operand element = context.ShiftRightU32(laneId, Const(5));
Operand res = context.Ballot(pred, 0);
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(1)),
context.Ballot(pred, 1), res);
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(2)),
context.Ballot(pred, 2), res);
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(3)),
context.Ballot(pred, 3), res);
return res;
}
}
}