Metal: Compute Shaders (#19)

* check for too bix texture bindings

* implement lod query

* print shader stage name

* always have fragment input

* resolve merge conflicts

* fix: lod query

* fix: casting texture coords

* support non-array memories

* use structure types for buffers

* implement compute pipeline cache

* compute dispatch

* improve error message

* rebind compute state

* bind compute textures

* pass local size as an argument to dispatch

* implement texture buffers

* hack: change vertex index to vertex id

* pass support buffer as an argument to every function

* return at the end of function

* fix: certain missing compute bindings

* implement texture base

* improve texture binding system

* remove useless exception

* move texture handle to texture base

* fix: segfault when using disposed textures

---------

Co-authored-by: Samuliak <samuliak77@gmail.com>
Co-authored-by: SamoZ256 <96914946+SamoZ256@users.noreply.github.com>
This commit is contained in:
Isaac Marovitz 2024-05-29 16:21:59 +01:00
parent 131ab75d55
commit b064d76a4f
26 changed files with 718 additions and 224 deletions

View file

@ -8,6 +8,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
public const string Tab = " ";
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
public const int additionalArgCount = 1;
public StructuredFunction CurrentFunction { get; set; }
public StructuredProgramInfo Info { get; }

View file

@ -54,6 +54,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
DeclareInputAttributes(context, info.IoDefinitions.Where(x => IsUserDefined(x, StorageKind.Input)));
context.AppendLine();
DeclareOutputAttributes(context, info.IoDefinitions.Where(x => x.StorageKind == StorageKind.Output));
context.AppendLine();
DeclareBufferStructures(context, context.Properties.ConstantBuffers.Values);
DeclareBufferStructures(context, context.Properties.StorageBuffers.Values);
}
static bool IsUserDefined(IoDefinition ioDefinition, StorageKind storageKind)
@ -111,8 +114,41 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
foreach (var memory in memories)
{
string arraySize = "";
if ((memory.Type & AggregateType.Array) != 0)
{
arraySize = $"[{memory.ArrayLength}]";
}
var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {memory.Name}[{memory.ArrayLength}];");
context.AppendLine($"{typeName} {memory.Name}{arraySize};");
}
}
private static void DeclareBufferStructures(CodeGenContext context, IEnumerable<BufferDefinition> buffers)
{
foreach (BufferDefinition buffer in buffers)
{
context.AppendLine($"struct Struct_{buffer.Name}");
context.EnterScope();
foreach (StructureField field in buffer.Type.Fields)
{
if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0)
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];");
}
else
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {field.Name};");
}
}
context.LeaveScope(";");
context.AppendLine();
}
}
@ -124,7 +160,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
}
else
{
if (inputs.Any())
if (inputs.Any() || context.Definitions.Stage == ShaderStage.Fragment)
{
string prefix = "";
@ -136,9 +172,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
case ShaderStage.Fragment:
context.AppendLine($"struct FragmentIn");
break;
case ShaderStage.Compute:
context.AppendLine($"struct KernelIn");
break;
}
context.EnterScope();

View file

@ -134,7 +134,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
case Instruction.Load:
return Load(context, operation);
case Instruction.Lod:
return "|| LOD ||";
return Lod(context, operation);
case Instruction.MemoryBarrier:
return "|| MEMORY BARRIER ||";
case Instruction.Store:

View file

@ -12,11 +12,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
var functon = context.GetFunction(funcId.Value);
string[] args = new string[operation.SourcesCount - 1];
int argCount = operation.SourcesCount - 1;
string[] args = new string[argCount + CodeGenContext.additionalArgCount];
for (int i = 0; i < args.Length; i++)
// Additional arguments
args[0] = "support_buffer";
int argIndex = CodeGenContext.additionalArgCount;
for (int i = 0; i < argCount; i++)
{
args[i] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
}
return $"{functon.Name}({string.Join(", ", args)})";

View file

@ -24,6 +24,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
inputsCount--;
}
string fieldName = "";
switch (storageKind)
{
case StorageKind.ConstantBuffer:
@ -45,6 +46,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
StructureField field = buffer.Type.Fields[fieldIndex.Value];
varName = buffer.Name;
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0)
{
// Unsized array, the buffer is indexed instead of the field
fieldName = "." + field.Name;
}
else
{
varName += "->" + field.Name;
}
varType = field.Type;
break;
@ -126,6 +136,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
varName += $"[{GetSourceExpr(context, src, AggregateType.S32)}]";
}
}
varName += fieldName;
if (isStore)
{
@ -141,6 +152,37 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
return GenerateLoadOrStore(context, operation, isStore: false);
}
// TODO: check this
public static string Lod(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
int coordsCount = texOp.Type.GetDimensions();
int coordsIndex = 0;
string samplerName = GetSamplerName(context.Properties, texOp);
string coordsExpr;
if (coordsCount > 1)
{
string[] elems = new string[coordsCount];
for (int index = 0; index < coordsCount; index++)
{
elems[index] = GetSourceExpr(context, texOp.GetSource(coordsIndex + index), AggregateType.FP32);
}
coordsExpr = "float" + coordsCount + "(" + string.Join(", ", elems) + ")";
}
else
{
coordsExpr = GetSourceExpr(context, texOp.GetSource(coordsIndex), AggregateType.FP32);
}
return $"tex_{samplerName}.calculate_unclamped_lod(samp_{samplerName}, {coordsExpr}){GetMaskMultiDest(texOp.Index)}";
}
public static string Store(CodeGenContext context, AstOperation operation)
{
return GenerateLoadOrStore(context, operation, isStore: true);
@ -176,11 +218,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
}
else
{
texCall += "sample";
if (isGather)
{
texCall += "_gather";
texCall += "gather";
}
else
{
texCall += "sample";
}
if (isShadow)
@ -188,22 +232,31 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
texCall += "_compare";
}
texCall += $"(samp_{samplerName}";
texCall += $"(samp_{samplerName}, ";
}
int coordsCount = texOp.Type.GetDimensions();
int pCount = coordsCount;
bool appended = false;
void Append(string str)
{
texCall += ", " + str;
if (appended)
{
texCall += ", ";
}
else {
appended = true;
}
texCall += str;
}
AggregateType coordType = intCoords ? AggregateType.S32 : AggregateType.FP32;
string AssemblePVector(int count)
{
string coords;
if (count > 1)
{
string[] elems = new string[count];
@ -213,14 +266,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
elems[index] = Src(coordType);
}
string prefix = intCoords ? "int" : "float";
return prefix + count + "(" + string.Join(", ", elems) + ")";
coords = string.Join(", ", elems);
}
else
{
return Src(coordType);
coords = Src(coordType);
}
string prefix = intCoords ? "uint" : "float";
return prefix + (count > 1 ? count : "") + "(" + coords + ")";
}
Append(AssemblePVector(pCount));
@ -254,6 +309,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
private static string GetMaskMultiDest(int mask)
{
if (mask == 0x0)
{
return "";
}
string swizzle = ".";
for (int i = 0; i < 4; i++)

View file

@ -35,7 +35,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
IoVariable.VertexId => ("vertex_id", AggregateType.S32),
// gl_VertexIndex does not have a direct equivalent in MSL
IoVariable.VertexIndex => ("vertex_index", AggregateType.U32),
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),
IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32),
IoVariable.FragmentCoord => ("in.position", AggregateType.Vector4 | AggregateType.FP32),
_ => (null, AggregateType.Invalid),

View file

@ -48,6 +48,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
PrintBlock(context, function.MainBlock, isMainFunc);
// In case the shader hasn't returned, return
if (isMainFunc && stage != ShaderStage.Compute)
{
context.AppendLine("return out;");
}
context.LeaveScope();
}
@ -57,11 +63,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
ShaderStage stage,
bool isMainFunc = false)
{
string[] args = new string[function.InArguments.Length + function.OutArguments.Length];
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.additionalArgCount;
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
// All non-main functions need to be able to access the support_buffer as well
if (!isMainFunc)
{
args[0] = "constant Struct_support_buffer* support_buffer";
}
int argIndex = additionalArgCount;
for (int i = 0; i < function.InArguments.Length; i++)
{
args[i] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
args[argIndex++] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
}
for (int i = 0; i < function.OutArguments.Length; i++)
@ -69,7 +84,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
int j = i + function.InArguments.Length;
// Likely need to be made into pointers
args[j] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}";
args[argIndex++] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}";
}
string funcKeyword = "inline";
@ -97,20 +112,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
returnType = "void";
}
if (context.AttributeUsage.UsedInputAttributes != 0)
if (stage == ShaderStage.Vertex)
{
if (stage == ShaderStage.Vertex)
if (context.AttributeUsage.UsedInputAttributes != 0)
{
args = args.Prepend("VertexIn in [[stage_in]]").ToArray();
}
else if (stage == ShaderStage.Fragment)
{
args = args.Prepend("FragmentIn in [[stage_in]]").ToArray();
}
else if (stage == ShaderStage.Compute)
{
args = args.Prepend("KernelIn in [[stage_in]]").ToArray();
}
}
else if (stage == ShaderStage.Fragment)
{
args = args.Prepend("FragmentIn in [[stage_in]]").ToArray();
}
// TODO: add these only if they are used
@ -119,18 +130,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
args = args.Append("uint vertex_id [[vertex_id]]").ToArray();
args = args.Append("uint instance_id [[instance_id]]").ToArray();
}
else if (stage == ShaderStage.Compute)
{
args = args.Append("uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]]").ToArray();
args = args.Append("uint3 thread_position_in_grid [[thread_position_in_grid]]").ToArray();
args = args.Append("uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]").ToArray();
}
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
{
var varType = constantBuffer.Type.Fields[0].Type & ~AggregateType.Array;
args = args.Append($"constant {Declarations.GetVarTypeName(context, varType)} *{constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
}
foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
{
var varType = storageBuffers.Type.Fields[0].Type & ~AggregateType.Array;
// Offset the binding by 15 to avoid clashing with the constant buffers
args = args.Append($"device {Declarations.GetVarTypeName(context, varType)} *{storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
}
foreach (var texture in context.Properties.Textures.Values)