Add a pass to turn global memory access into storage access, and do all storage related transformations on IR

This commit is contained in:
gdk 2019-11-30 23:53:09 -03:00 committed by Thog
parent 396768f3b4
commit 6a98c643ca
28 changed files with 532 additions and 282 deletions

View file

@ -11,7 +11,6 @@ namespace Ryujinx.Graphics.Shader.Instructions
{
/// <summary>
/// Memory region selector passed to the region-parameterized load/store emitters,
/// which switch on it to pick the matching load or store operation.
/// </summary>
private enum MemoryRegion
{
Global,
Local,
Shared
}
@ -60,13 +59,20 @@ namespace Ryujinx.Graphics.Shader.Instructions
{
OpCodeAtom op = (OpCodeAtom)context.CurrOp;
Operand mem = context.ShiftRightU32(GetSrcA(context), Const(2));
Operand offset = context.ShiftRightU32(GetSrcA(context), Const(2));
mem = context.IAdd(mem, Const(op.Offset));
offset = context.IAdd(offset, Const(op.Offset));
Operand value = GetSrcB(context);
Operand res = EmitAtomicOp(context, Instruction.MrShared, op.AtomicOp, op.Type, mem, value);
Operand res = EmitAtomicOp(
context,
Instruction.MrShared,
op.AtomicOp,
op.Type,
offset,
Const(0),
value);
context.Copy(GetDest(context), res);
}
@ -148,7 +154,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
/// <summary>
/// Emits a global memory load for the current instruction.
/// </summary>
/// <param name="context">Emitter context holding the instruction being translated.</param>
public static void Ldg(EmitterContext context)
{
    // Global loads go through the dedicated emitter only, which carries the
    // (low, high) address pair; emitting through the region-based path as well
    // would produce a second, redundant load.
    EmitLoadGlobal(context);
}
public static void Lds(EmitterContext context)
@ -183,11 +189,16 @@ namespace Ryujinx.Graphics.Shader.Instructions
{
OpCodeRed op = (OpCodeRed)context.CurrOp;
Operand offset = context.IAdd(GetSrcA(context), Const(op.Offset));
(Operand addrLow, Operand addrHigh) = Get40BitsAddress(context, op.Ra, op.Extended, op.Offset);
Operand mem = context.ShiftRightU32(offset, Const(2));
EmitAtomicOp(context, Instruction.MrGlobal, op.AtomicOp, op.Type, mem, GetDest(context));
EmitAtomicOp(
context,
Instruction.MrGlobal,
op.AtomicOp,
op.Type,
addrLow,
addrHigh,
GetDest(context));
}
public static void St(EmitterContext context)
@ -197,7 +208,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
/// <summary>
/// Emits a global memory store for the current instruction.
/// </summary>
/// <param name="context">Emitter context holding the instruction being translated.</param>
public static void Stg(EmitterContext context)
{
    // Global stores go through the dedicated emitter only, which carries the
    // (low, high) address pair; emitting through the region-based path as well
    // would produce a second, redundant store.
    EmitStoreGlobal(context);
}
public static void Sts(EmitterContext context)
@ -210,7 +221,8 @@ namespace Ryujinx.Graphics.Shader.Instructions
Instruction mr,
AtomicOp op,
ReductionType type,
Operand mem,
Operand addrLow,
Operand addrHigh,
Operand value)
{
Operand res = Const(0);
@ -220,7 +232,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
case AtomicOp.Add:
if (type == ReductionType.S32 || type == ReductionType.U32)
{
res = context.AtomicAdd(mr, mem, value);
res = context.AtomicAdd(mr, addrLow, addrHigh, value);
}
else
{
@ -230,7 +242,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
case AtomicOp.BitwiseAnd:
if (type == ReductionType.S32 || type == ReductionType.U32)
{
res = context.AtomicAnd(mr, mem, value);
res = context.AtomicAnd(mr, addrLow, addrHigh, value);
}
else
{
@ -240,7 +252,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
case AtomicOp.BitwiseExclusiveOr:
if (type == ReductionType.S32 || type == ReductionType.U32)
{
res = context.AtomicXor(mr, mem, value);
res = context.AtomicXor(mr, addrLow, addrHigh, value);
}
else
{
@ -250,7 +262,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
case AtomicOp.BitwiseOr:
if (type == ReductionType.S32 || type == ReductionType.U32)
{
res = context.AtomicOr(mr, mem, value);
res = context.AtomicOr(mr, addrLow, addrHigh, value);
}
else
{
@ -260,11 +272,11 @@ namespace Ryujinx.Graphics.Shader.Instructions
case AtomicOp.Maximum:
if (type == ReductionType.S32)
{
res = context.AtomicMaxS32(mr, mem, value);
res = context.AtomicMaxS32(mr, addrLow, addrHigh, value);
}
else if (type == ReductionType.U32)
{
res = context.AtomicMaxU32(mr, mem, value);
res = context.AtomicMaxU32(mr, addrLow, addrHigh, value);
}
else
{
@ -274,11 +286,11 @@ namespace Ryujinx.Graphics.Shader.Instructions
case AtomicOp.Minimum:
if (type == ReductionType.S32)
{
res = context.AtomicMinS32(mr, mem, value);
res = context.AtomicMinS32(mr, addrLow, addrHigh, value);
}
else if (type == ReductionType.U32)
{
res = context.AtomicMinU32(mr, mem, value);
res = context.AtomicMinU32(mr, addrLow, addrHigh, value);
}
else
{
@ -331,7 +343,6 @@ namespace Ryujinx.Graphics.Shader.Instructions
switch (region)
{
case MemoryRegion.Global: value = context.LoadGlobal(offset); break;
case MemoryRegion.Local: value = context.LoadLocal (offset); break;
case MemoryRegion.Shared: value = context.LoadShared(offset); break;
}
@ -345,6 +356,38 @@ namespace Ryujinx.Graphics.Shader.Instructions
}
}
/// <summary>
/// Emits the global loads for the current memory instruction, one per destination
/// register, and copies each loaded value into its register.
/// </summary>
/// <remarks>
/// The number of registers comes from <c>GetVectorCount</c>. Loading stops at the
/// first destination register that is RZ. For sizes below <c>IntegerSize.B32</c> the
/// loaded word is passed through <c>ExtractSmallInt</c> using the bit offset derived
/// from the low address word.
/// </remarks>
/// <param name="context">Emitter context holding the instruction being translated.</param>
private static void EmitLoadGlobal(EmitterContext context)
{
    OpCodeMemory memOp = (OpCodeMemory)context.CurrOp;

    bool needsExtract = memOp.Size < IntegerSize.B32;
    int wordCount = GetVectorCount(memOp.Size);

    (Operand lowAddr, Operand highAddr) = Get40BitsAddress(context, memOp.Ra, memOp.Extended, memOp.Offset);

    Operand shift = GetBitOffset(context, lowAddr);

    int word = 0;

    while (word < wordCount)
    {
        Register dest = new Register(memOp.Rd.Index + word, RegisterType.Gpr);

        if (dest.IsRZ)
        {
            // Nothing is loaded for RZ or for any register after it.
            break;
        }

        // Each successive register reads from the low address advanced by 4.
        Operand wordAddr = context.IAdd(lowAddr, Const(word * 4));
        Operand loaded = context.LoadGlobal(wordAddr, highAddr);

        Operand result = needsExtract
            ? ExtractSmallInt(context, memOp.Size, shift, loaded)
            : loaded;

        context.Copy(Register(dest), result);

        word++;
    }
}
private static void EmitStore(EmitterContext context, MemoryRegion region)
{
OpCodeMemory op = (OpCodeMemory)context.CurrOp;
@ -384,7 +427,6 @@ namespace Ryujinx.Graphics.Shader.Instructions
switch (region)
{
case MemoryRegion.Global: word = context.LoadGlobal(offset); break;
case MemoryRegion.Local: word = context.LoadLocal (offset); break;
case MemoryRegion.Shared: word = context.LoadShared(offset); break;
}
@ -394,7 +436,6 @@ namespace Ryujinx.Graphics.Shader.Instructions
switch (region)
{
case MemoryRegion.Global: context.StoreGlobal(offset, value); break;
case MemoryRegion.Local: context.StoreLocal (offset, value); break;
case MemoryRegion.Shared: context.StoreShared(offset, value); break;
}
@ -406,9 +447,89 @@ namespace Ryujinx.Graphics.Shader.Instructions
}
}
/// <summary>
/// Emits the global stores for the current memory instruction, one per source
/// register.
/// </summary>
/// <remarks>
/// The number of registers comes from <c>GetVectorCount</c>. For sizes below
/// <c>IntegerSize.B32</c> the existing word is loaded from the base address and the
/// register value is merged into it with <c>InsertSmallInt</c> before storing.
/// The RZ check runs after the store, so the value of an RZ register is still
/// written once and no further registers are processed.
/// </remarks>
/// <param name="context">Emitter context holding the instruction being translated.</param>
private static void EmitStoreGlobal(EmitterContext context)
{
    OpCodeMemory memOp = (OpCodeMemory)context.CurrOp;

    bool needsInsert = memOp.Size < IntegerSize.B32;
    int wordCount = GetVectorCount(memOp.Size);

    (Operand lowAddr, Operand highAddr) = Get40BitsAddress(context, memOp.Ra, memOp.Extended, memOp.Offset);

    Operand shift = GetBitOffset(context, lowAddr);

    for (int word = 0; word < wordCount; word++)
    {
        Register src = new Register(memOp.Rd.Index + word, RegisterType.Gpr);

        Operand toStore = Register(src);

        if (needsInsert)
        {
            // Read-modify-write: fetch the current word and merge the small value into it.
            Operand current = context.LoadGlobal(lowAddr, highAddr);

            toStore = InsertSmallInt(context, memOp.Size, shift, current, toStore);
        }

        // Each successive register writes to the low address advanced by 4.
        Operand wordAddr = context.IAdd(lowAddr, Const(word * 4));

        context.StoreGlobal(wordAddr, highAddr, toStore);

        if (src.IsRZ)
        {
            // Nothing follows the loop, so leaving the method equals leaving the loop.
            return;
        }
    }
}
/// <summary>
/// Returns how many registers a memory access of the given size spans.
/// </summary>
/// <param name="size">Access size of the memory instruction.</param>
/// <returns>
/// 2 for <c>B64</c>, 4 for <c>B128</c> and <c>UB128</c>, and 1 for every other size.
/// </returns>
private static int GetVectorCount(IntegerSize size)
{
    if (size == IntegerSize.B64)
    {
        return 2;
    }

    bool isQuad = size == IntegerSize.B128 || size == IntegerSize.UB128;

    return isQuad ? 4 : 1;
}
/// <summary>
/// Builds the (low, high) operand pair for the address of the current memory
/// instruction, with the immediate offset added in.
/// </summary>
/// <param name="context">Emitter context holding the instruction being translated.</param>
/// <param name="ra">Base address register; when extended, the next register supplies the high word.</param>
/// <param name="extended">True if the address uses a second register for its high word.</param>
/// <param name="offset">Signed immediate offset added to the address.</param>
/// <returns>The low and high address words, in that order.</returns>
private static (Operand, Operand) Get40BitsAddress(
    EmitterContext context,
    Register ra,
    bool extended,
    int offset)
{
    Operand addrLow = GetSrcA(context);
    Operand addrHigh;

    if (extended && !ra.IsRZ)
    {
        addrHigh = Register(ra.Index + 1, RegisterType.Gpr);
    }
    else
    {
        addrHigh = Const(0);
    }

    Operand offs = Const(offset);

    addrLow = context.IAdd(addrLow, offs);

    if (extended)
    {
        // Unsigned overflow of the low-word addition: the sum wrapped around
        // if it is smaller than one of its operands.
        Operand carry = context.ICompareLessUnsigned(addrLow, offs);

        addrHigh = context.IAdd(addrHigh, context.ConditionalSelect(carry, Const(1), Const(0)));

        if (offset < 0)
        {
            // A negative offset sign-extends to an all-ones high word, so the
            // wide addition must also add -1 to the high word. Without this
            // the high word ends up one too large for every negative offset.
            // Non-negative offsets take the exact same path as before.
            addrHigh = context.IAdd(addrHigh, Const(-1));
        }
    }

    return (addrLow, addrHigh);
}
private static Operand GetBitOffset(EmitterContext context, Operand baseOffset)
{
// Note: bit offset = (baseOffset & 0b11) * 8.
// Addresses should always be aligned to the integer type,
// so we don't need to take unaligned addresses into account.
return context.ShiftLeft(context.BitwiseAnd(baseOffset, Const(3)), Const(3));