Metal: Compute Shaders (#19)
* check for too bix texture bindings * implement lod query * print shader stage name * always have fragment input * resolve merge conflicts * fix: lod query * fix: casting texture coords * support non-array memories * use structure types for buffers * implement compute pipeline cache * compute dispatch * improve error message * rebind compute state * bind compute textures * pass local size as an argument to dispatch * implement texture buffers * hack: change vertex index to vertex id * pass support buffer as an argument to every function * return at the end of function * fix: certain missing compute bindings * implement texture base * improve texture binding system * remove useless exception * move texture handle to texture base * fix: segfault when using disposed textures --------- Co-authored-by: Samuliak <samuliak77@gmail.com> Co-authored-by: SamoZ256 <96914946+SamoZ256@users.noreply.github.com>
This commit is contained in:
parent
131ab75d55
commit
b064d76a4f
26 changed files with 718 additions and 224 deletions
|
@ -8,6 +8,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
{
|
||||
public const string Tab = " ";
|
||||
|
||||
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
|
||||
public const int additionalArgCount = 1;
|
||||
|
||||
public StructuredFunction CurrentFunction { get; set; }
|
||||
|
||||
public StructuredProgramInfo Info { get; }
|
||||
|
|
|
@ -54,6 +54,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
DeclareInputAttributes(context, info.IoDefinitions.Where(x => IsUserDefined(x, StorageKind.Input)));
|
||||
context.AppendLine();
|
||||
DeclareOutputAttributes(context, info.IoDefinitions.Where(x => x.StorageKind == StorageKind.Output));
|
||||
context.AppendLine();
|
||||
DeclareBufferStructures(context, context.Properties.ConstantBuffers.Values);
|
||||
DeclareBufferStructures(context, context.Properties.StorageBuffers.Values);
|
||||
}
|
||||
|
||||
static bool IsUserDefined(IoDefinition ioDefinition, StorageKind storageKind)
|
||||
|
@ -111,8 +114,41 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
{
|
||||
foreach (var memory in memories)
|
||||
{
|
||||
string arraySize = "";
|
||||
if ((memory.Type & AggregateType.Array) != 0)
|
||||
{
|
||||
arraySize = $"[{memory.ArrayLength}]";
|
||||
}
|
||||
var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array);
|
||||
context.AppendLine($"{typeName} {memory.Name}[{memory.ArrayLength}];");
|
||||
context.AppendLine($"{typeName} {memory.Name}{arraySize};");
|
||||
}
|
||||
}
|
||||
|
||||
private static void DeclareBufferStructures(CodeGenContext context, IEnumerable<BufferDefinition> buffers)
|
||||
{
|
||||
foreach (BufferDefinition buffer in buffers)
|
||||
{
|
||||
context.AppendLine($"struct Struct_{buffer.Name}");
|
||||
context.EnterScope();
|
||||
|
||||
foreach (StructureField field in buffer.Type.Fields)
|
||||
{
|
||||
if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0)
|
||||
{
|
||||
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
|
||||
|
||||
context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];");
|
||||
}
|
||||
else
|
||||
{
|
||||
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
|
||||
|
||||
context.AppendLine($"{typeName} {field.Name};");
|
||||
}
|
||||
}
|
||||
|
||||
context.LeaveScope(";");
|
||||
context.AppendLine();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -124,7 +160,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
}
|
||||
else
|
||||
{
|
||||
if (inputs.Any())
|
||||
if (inputs.Any() || context.Definitions.Stage == ShaderStage.Fragment)
|
||||
{
|
||||
string prefix = "";
|
||||
|
||||
|
@ -136,9 +172,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
case ShaderStage.Fragment:
|
||||
context.AppendLine($"struct FragmentIn");
|
||||
break;
|
||||
case ShaderStage.Compute:
|
||||
context.AppendLine($"struct KernelIn");
|
||||
break;
|
||||
}
|
||||
|
||||
context.EnterScope();
|
||||
|
|
|
@ -134,7 +134,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
case Instruction.Load:
|
||||
return Load(context, operation);
|
||||
case Instruction.Lod:
|
||||
return "|| LOD ||";
|
||||
return Lod(context, operation);
|
||||
case Instruction.MemoryBarrier:
|
||||
return "|| MEMORY BARRIER ||";
|
||||
case Instruction.Store:
|
||||
|
|
|
@ -12,11 +12,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
|
||||
var functon = context.GetFunction(funcId.Value);
|
||||
|
||||
string[] args = new string[operation.SourcesCount - 1];
|
||||
int argCount = operation.SourcesCount - 1;
|
||||
string[] args = new string[argCount + CodeGenContext.additionalArgCount];
|
||||
|
||||
for (int i = 0; i < args.Length; i++)
|
||||
// Additional arguments
|
||||
args[0] = "support_buffer";
|
||||
|
||||
int argIndex = CodeGenContext.additionalArgCount;
|
||||
for (int i = 0; i < argCount; i++)
|
||||
{
|
||||
args[i] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
|
||||
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
|
||||
}
|
||||
|
||||
return $"{functon.Name}({string.Join(", ", args)})";
|
||||
|
|
|
@ -24,6 +24,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
inputsCount--;
|
||||
}
|
||||
|
||||
string fieldName = "";
|
||||
switch (storageKind)
|
||||
{
|
||||
case StorageKind.ConstantBuffer:
|
||||
|
@ -45,6 +46,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
|
||||
StructureField field = buffer.Type.Fields[fieldIndex.Value];
|
||||
varName = buffer.Name;
|
||||
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0)
|
||||
{
|
||||
// Unsized array, the buffer is indexed instead of the field
|
||||
fieldName = "." + field.Name;
|
||||
}
|
||||
else
|
||||
{
|
||||
varName += "->" + field.Name;
|
||||
}
|
||||
varType = field.Type;
|
||||
break;
|
||||
|
||||
|
@ -126,6 +136,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
varName += $"[{GetSourceExpr(context, src, AggregateType.S32)}]";
|
||||
}
|
||||
}
|
||||
varName += fieldName;
|
||||
|
||||
if (isStore)
|
||||
{
|
||||
|
@ -141,6 +152,37 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
return GenerateLoadOrStore(context, operation, isStore: false);
|
||||
}
|
||||
|
||||
// TODO: check this
|
||||
public static string Lod(CodeGenContext context, AstOperation operation)
|
||||
{
|
||||
AstTextureOperation texOp = (AstTextureOperation)operation;
|
||||
|
||||
int coordsCount = texOp.Type.GetDimensions();
|
||||
int coordsIndex = 0;
|
||||
|
||||
string samplerName = GetSamplerName(context.Properties, texOp);
|
||||
|
||||
string coordsExpr;
|
||||
|
||||
if (coordsCount > 1)
|
||||
{
|
||||
string[] elems = new string[coordsCount];
|
||||
|
||||
for (int index = 0; index < coordsCount; index++)
|
||||
{
|
||||
elems[index] = GetSourceExpr(context, texOp.GetSource(coordsIndex + index), AggregateType.FP32);
|
||||
}
|
||||
|
||||
coordsExpr = "float" + coordsCount + "(" + string.Join(", ", elems) + ")";
|
||||
}
|
||||
else
|
||||
{
|
||||
coordsExpr = GetSourceExpr(context, texOp.GetSource(coordsIndex), AggregateType.FP32);
|
||||
}
|
||||
|
||||
return $"tex_{samplerName}.calculate_unclamped_lod(samp_{samplerName}, {coordsExpr}){GetMaskMultiDest(texOp.Index)}";
|
||||
}
|
||||
|
||||
public static string Store(CodeGenContext context, AstOperation operation)
|
||||
{
|
||||
return GenerateLoadOrStore(context, operation, isStore: true);
|
||||
|
@ -176,11 +218,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
}
|
||||
else
|
||||
{
|
||||
texCall += "sample";
|
||||
|
||||
if (isGather)
|
||||
{
|
||||
texCall += "_gather";
|
||||
texCall += "gather";
|
||||
}
|
||||
else
|
||||
{
|
||||
texCall += "sample";
|
||||
}
|
||||
|
||||
if (isShadow)
|
||||
|
@ -188,22 +232,31 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
texCall += "_compare";
|
||||
}
|
||||
|
||||
texCall += $"(samp_{samplerName}";
|
||||
texCall += $"(samp_{samplerName}, ";
|
||||
}
|
||||
|
||||
int coordsCount = texOp.Type.GetDimensions();
|
||||
|
||||
int pCount = coordsCount;
|
||||
|
||||
bool appended = false;
|
||||
void Append(string str)
|
||||
{
|
||||
texCall += ", " + str;
|
||||
if (appended)
|
||||
{
|
||||
texCall += ", ";
|
||||
}
|
||||
else {
|
||||
appended = true;
|
||||
}
|
||||
texCall += str;
|
||||
}
|
||||
|
||||
AggregateType coordType = intCoords ? AggregateType.S32 : AggregateType.FP32;
|
||||
|
||||
string AssemblePVector(int count)
|
||||
{
|
||||
string coords;
|
||||
if (count > 1)
|
||||
{
|
||||
string[] elems = new string[count];
|
||||
|
@ -213,14 +266,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
elems[index] = Src(coordType);
|
||||
}
|
||||
|
||||
string prefix = intCoords ? "int" : "float";
|
||||
|
||||
return prefix + count + "(" + string.Join(", ", elems) + ")";
|
||||
coords = string.Join(", ", elems);
|
||||
}
|
||||
else
|
||||
{
|
||||
return Src(coordType);
|
||||
coords = Src(coordType);
|
||||
}
|
||||
|
||||
string prefix = intCoords ? "uint" : "float";
|
||||
|
||||
return prefix + (count > 1 ? count : "") + "(" + coords + ")";
|
||||
}
|
||||
|
||||
Append(AssemblePVector(pCount));
|
||||
|
@ -254,6 +309,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
|
||||
private static string GetMaskMultiDest(int mask)
|
||||
{
|
||||
if (mask == 0x0)
|
||||
{
|
||||
return "";
|
||||
}
|
||||
|
||||
string swizzle = ".";
|
||||
|
||||
for (int i = 0; i < 4; i++)
|
||||
|
|
|
@ -35,7 +35,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
|
||||
IoVariable.VertexId => ("vertex_id", AggregateType.S32),
|
||||
// gl_VertexIndex does not have a direct equivalent in MSL
|
||||
IoVariable.VertexIndex => ("vertex_index", AggregateType.U32),
|
||||
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),
|
||||
IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32),
|
||||
IoVariable.FragmentCoord => ("in.position", AggregateType.Vector4 | AggregateType.FP32),
|
||||
_ => (null, AggregateType.Invalid),
|
||||
|
|
|
@ -48,6 +48,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
|
||||
PrintBlock(context, function.MainBlock, isMainFunc);
|
||||
|
||||
// In case the shader hasn't returned, return
|
||||
if (isMainFunc && stage != ShaderStage.Compute)
|
||||
{
|
||||
context.AppendLine("return out;");
|
||||
}
|
||||
|
||||
context.LeaveScope();
|
||||
}
|
||||
|
||||
|
@ -57,11 +63,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
ShaderStage stage,
|
||||
bool isMainFunc = false)
|
||||
{
|
||||
string[] args = new string[function.InArguments.Length + function.OutArguments.Length];
|
||||
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.additionalArgCount;
|
||||
|
||||
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
|
||||
|
||||
// All non-main functions need to be able to access the support_buffer as well
|
||||
if (!isMainFunc)
|
||||
{
|
||||
args[0] = "constant Struct_support_buffer* support_buffer";
|
||||
}
|
||||
|
||||
int argIndex = additionalArgCount;
|
||||
for (int i = 0; i < function.InArguments.Length; i++)
|
||||
{
|
||||
args[i] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
|
||||
args[argIndex++] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
|
||||
}
|
||||
|
||||
for (int i = 0; i < function.OutArguments.Length; i++)
|
||||
|
@ -69,7 +84,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
int j = i + function.InArguments.Length;
|
||||
|
||||
// Likely need to be made into pointers
|
||||
args[j] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}";
|
||||
args[argIndex++] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}";
|
||||
}
|
||||
|
||||
string funcKeyword = "inline";
|
||||
|
@ -97,20 +112,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
returnType = "void";
|
||||
}
|
||||
|
||||
if (context.AttributeUsage.UsedInputAttributes != 0)
|
||||
if (stage == ShaderStage.Vertex)
|
||||
{
|
||||
if (stage == ShaderStage.Vertex)
|
||||
if (context.AttributeUsage.UsedInputAttributes != 0)
|
||||
{
|
||||
args = args.Prepend("VertexIn in [[stage_in]]").ToArray();
|
||||
}
|
||||
else if (stage == ShaderStage.Fragment)
|
||||
{
|
||||
args = args.Prepend("FragmentIn in [[stage_in]]").ToArray();
|
||||
}
|
||||
else if (stage == ShaderStage.Compute)
|
||||
{
|
||||
args = args.Prepend("KernelIn in [[stage_in]]").ToArray();
|
||||
}
|
||||
}
|
||||
else if (stage == ShaderStage.Fragment)
|
||||
{
|
||||
args = args.Prepend("FragmentIn in [[stage_in]]").ToArray();
|
||||
}
|
||||
|
||||
// TODO: add these only if they are used
|
||||
|
@ -119,18 +130,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
args = args.Append("uint vertex_id [[vertex_id]]").ToArray();
|
||||
args = args.Append("uint instance_id [[instance_id]]").ToArray();
|
||||
}
|
||||
else if (stage == ShaderStage.Compute)
|
||||
{
|
||||
args = args.Append("uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]]").ToArray();
|
||||
args = args.Append("uint3 thread_position_in_grid [[thread_position_in_grid]]").ToArray();
|
||||
args = args.Append("uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]").ToArray();
|
||||
}
|
||||
|
||||
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
|
||||
{
|
||||
var varType = constantBuffer.Type.Fields[0].Type & ~AggregateType.Array;
|
||||
args = args.Append($"constant {Declarations.GetVarTypeName(context, varType)} *{constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
|
||||
args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
|
||||
}
|
||||
|
||||
foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
|
||||
{
|
||||
var varType = storageBuffers.Type.Fields[0].Type & ~AggregateType.Array;
|
||||
// Offset the binding by 15 to avoid clashing with the constant buffers
|
||||
args = args.Append($"device {Declarations.GetVarTypeName(context, varType)} *{storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
|
||||
args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
|
||||
}
|
||||
|
||||
foreach (var texture in context.Properties.Textures.Values)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue