Metal: Compute Shaders (#19)

* check for too bix texture bindings

* implement lod query

* print shader stage name

* always have fragment input

* resolve merge conflicts

* fix: lod query

* fix: casting texture coords

* support non-array memories

* use structure types for buffers

* implement compute pipeline cache

* compute dispatch

* improve error message

* rebind compute state

* bind compute textures

* pass local size as an argument to dispatch

* implement texture buffers

* hack: change vertex index to vertex id

* pass support buffer as an argument to every function

* return at the end of function

* fix: certain missing compute bindings

* implement texture base

* improve texture binding system

* remove useless exception

* move texture handle to texture base

* fix: segfault when using disposed textures

---------

Co-authored-by: Samuliak <samuliak77@gmail.com>
Co-authored-by: SamoZ256 <96914946+SamoZ256@users.noreply.github.com>
This commit is contained in:
Isaac Marovitz 2024-05-29 16:21:59 +01:00
parent 131ab75d55
commit b064d76a4f
26 changed files with 718 additions and 224 deletions

View file

@ -48,6 +48,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
PrintBlock(context, function.MainBlock, isMainFunc);
// In case the shader hasn't returned, return
if (isMainFunc && stage != ShaderStage.Compute)
{
context.AppendLine("return out;");
}
context.LeaveScope();
}
@ -57,11 +63,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
ShaderStage stage,
bool isMainFunc = false)
{
string[] args = new string[function.InArguments.Length + function.OutArguments.Length];
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.additionalArgCount;
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
// All non-main functions need to be able to access the support_buffer as well
if (!isMainFunc)
{
args[0] = "constant Struct_support_buffer* support_buffer";
}
int argIndex = additionalArgCount;
for (int i = 0; i < function.InArguments.Length; i++)
{
args[i] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
args[argIndex++] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
}
for (int i = 0; i < function.OutArguments.Length; i++)
@ -69,7 +84,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
int j = i + function.InArguments.Length;
// Likely need to be made into pointers
args[j] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}";
args[argIndex++] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}";
}
string funcKeyword = "inline";
@ -97,20 +112,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
returnType = "void";
}
if (context.AttributeUsage.UsedInputAttributes != 0)
if (stage == ShaderStage.Vertex)
{
if (stage == ShaderStage.Vertex)
if (context.AttributeUsage.UsedInputAttributes != 0)
{
args = args.Prepend("VertexIn in [[stage_in]]").ToArray();
}
else if (stage == ShaderStage.Fragment)
{
args = args.Prepend("FragmentIn in [[stage_in]]").ToArray();
}
else if (stage == ShaderStage.Compute)
{
args = args.Prepend("KernelIn in [[stage_in]]").ToArray();
}
}
else if (stage == ShaderStage.Fragment)
{
args = args.Prepend("FragmentIn in [[stage_in]]").ToArray();
}
// TODO: add these only if they are used
@ -119,18 +130,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
args = args.Append("uint vertex_id [[vertex_id]]").ToArray();
args = args.Append("uint instance_id [[instance_id]]").ToArray();
}
else if (stage == ShaderStage.Compute)
{
args = args.Append("uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]]").ToArray();
args = args.Append("uint3 thread_position_in_grid [[thread_position_in_grid]]").ToArray();
args = args.Append("uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]").ToArray();
}
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
{
var varType = constantBuffer.Type.Fields[0].Type & ~AggregateType.Array;
args = args.Append($"constant {Declarations.GetVarTypeName(context, varType)} *{constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
}
foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
{
var varType = storageBuffers.Type.Fields[0].Type & ~AggregateType.Array;
// Offset the binding by 15 to avoid clashing with the constant buffers
args = args.Append($"device {Declarations.GetVarTypeName(context, varType)} *{storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
}
foreach (var texture in context.Properties.Textures.Values)