Metal: Compute Shaders (#19)

* check for too bix texture bindings

* implement lod query

* print shader stage name

* always have fragment input

* resolve merge conflicts

* fix: lod query

* fix: casting texture coords

* support non-array memories

* use structure types for buffers

* implement compute pipeline cache

* compute dispatch

* improve error message

* rebind compute state

* bind compute textures

* pass local size as an argument to dispatch

* implement texture buffers

* hack: change vertex index to vertex id

* pass support buffer as an argument to every function

* return at the end of function

* fix: certain missing compute bindings

* implement texture base

* improve texture binding system

* remove useless exception

* move texture handle to texture base

* fix: segfault when using disposed textures

---------

Co-authored-by: Samuliak <samuliak77@gmail.com>
Co-authored-by: SamoZ256 <96914946+SamoZ256@users.noreply.github.com>
This commit is contained in:
Isaac Marovitz 2024-05-29 16:21:59 +01:00
parent 131ab75d55
commit b064d76a4f
26 changed files with 718 additions and 224 deletions

View file

@ -134,7 +134,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
case Instruction.Load:
return Load(context, operation);
case Instruction.Lod:
return "|| LOD ||";
return Lod(context, operation);
case Instruction.MemoryBarrier:
return "|| MEMORY BARRIER ||";
case Instruction.Store:

View file

@ -12,11 +12,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
var functon = context.GetFunction(funcId.Value);
string[] args = new string[operation.SourcesCount - 1];
int argCount = operation.SourcesCount - 1;
string[] args = new string[argCount + CodeGenContext.additionalArgCount];
for (int i = 0; i < args.Length; i++)
// Additional arguments
args[0] = "support_buffer";
int argIndex = CodeGenContext.additionalArgCount;
for (int i = 0; i < argCount; i++)
{
args[i] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
}
return $"{functon.Name}({string.Join(", ", args)})";

View file

@ -24,6 +24,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
inputsCount--;
}
string fieldName = "";
switch (storageKind)
{
case StorageKind.ConstantBuffer:
@ -45,6 +46,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
StructureField field = buffer.Type.Fields[fieldIndex.Value];
varName = buffer.Name;
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0)
{
// Unsized array, the buffer is indexed instead of the field
fieldName = "." + field.Name;
}
else
{
varName += "->" + field.Name;
}
varType = field.Type;
break;
@ -126,6 +136,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
varName += $"[{GetSourceExpr(context, src, AggregateType.S32)}]";
}
}
varName += fieldName;
if (isStore)
{
@ -141,6 +152,37 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
return GenerateLoadOrStore(context, operation, isStore: false);
}
// TODO: check this
public static string Lod(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
int coordsCount = texOp.Type.GetDimensions();
int coordsIndex = 0;
string samplerName = GetSamplerName(context.Properties, texOp);
string coordsExpr;
if (coordsCount > 1)
{
string[] elems = new string[coordsCount];
for (int index = 0; index < coordsCount; index++)
{
elems[index] = GetSourceExpr(context, texOp.GetSource(coordsIndex + index), AggregateType.FP32);
}
coordsExpr = "float" + coordsCount + "(" + string.Join(", ", elems) + ")";
}
else
{
coordsExpr = GetSourceExpr(context, texOp.GetSource(coordsIndex), AggregateType.FP32);
}
return $"tex_{samplerName}.calculate_unclamped_lod(samp_{samplerName}, {coordsExpr}){GetMaskMultiDest(texOp.Index)}";
}
public static string Store(CodeGenContext context, AstOperation operation)
{
return GenerateLoadOrStore(context, operation, isStore: true);
@ -176,11 +218,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
}
else
{
texCall += "sample";
if (isGather)
{
texCall += "_gather";
texCall += "gather";
}
else
{
texCall += "sample";
}
if (isShadow)
@ -188,22 +232,31 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
texCall += "_compare";
}
texCall += $"(samp_{samplerName}";
texCall += $"(samp_{samplerName}, ";
}
int coordsCount = texOp.Type.GetDimensions();
int pCount = coordsCount;
bool appended = false;
void Append(string str)
{
texCall += ", " + str;
if (appended)
{
texCall += ", ";
}
else {
appended = true;
}
texCall += str;
}
AggregateType coordType = intCoords ? AggregateType.S32 : AggregateType.FP32;
string AssemblePVector(int count)
{
string coords;
if (count > 1)
{
string[] elems = new string[count];
@ -213,14 +266,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
elems[index] = Src(coordType);
}
string prefix = intCoords ? "int" : "float";
return prefix + count + "(" + string.Join(", ", elems) + ")";
coords = string.Join(", ", elems);
}
else
{
return Src(coordType);
coords = Src(coordType);
}
string prefix = intCoords ? "uint" : "float";
return prefix + (count > 1 ? count : "") + "(" + coords + ")";
}
Append(AssemblePVector(pCount));
@ -254,6 +309,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
private static string GetMaskMultiDest(int mask)
{
if (mask == 0x0)
{
return "";
}
string swizzle = ".";
for (int i = 0; i < 4; i++)

View file

@ -35,7 +35,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
IoVariable.VertexId => ("vertex_id", AggregateType.S32),
// gl_VertexIndex does not have a direct equivalent in MSL
IoVariable.VertexIndex => ("vertex_index", AggregateType.U32),
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),
IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32),
IoVariable.FragmentCoord => ("in.position", AggregateType.Vector4 | AggregateType.FP32),
_ => (null, AggregateType.Invalid),