shader: Add subgroup masks

This commit is contained in:
ReinUsesLisp 2021-04-04 05:17:17 -03:00 committed by ameerj
parent fc93bc2abd
commit da6cf2632c
10 changed files with 169 additions and 45 deletions

View file

@ -6,10 +6,18 @@
namespace Shader::Backend::SPIRV {
namespace {
Id LargeWarpBallot(EmitContext& ctx, Id ballot) {
Id WarpExtract(EmitContext& ctx, Id value) {
const Id shift{ctx.Constant(ctx.U32[1], 5)};
const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)};
return ctx.OpVectorExtractDynamic(ctx.U32[1], ballot, local_index);
return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index);
}
Id LoadMask(EmitContext& ctx, Id mask) {
const Id value{ctx.OpLoad(ctx.U32[4], mask)};
if (!ctx.profile.warp_size_potentially_larger_than_guest) {
return ctx.OpCompositeExtract(ctx.U32[1], value, 0U);
}
return WarpExtract(ctx, value);
}
void SetInBoundsFlag(IR::Inst* inst, Id result) {
@ -47,8 +55,8 @@ Id EmitVoteAll(EmitContext& ctx, Id pred) {
return ctx.OpSubgroupAllKHR(ctx.U1, pred);
}
const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
const Id active_mask{LargeWarpBallot(ctx, mask_ballot)};
const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
const Id active_mask{WarpExtract(ctx, mask_ballot)};
const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)};
return ctx.OpIEqual(ctx.U1, lhs, active_mask);
}
@ -58,8 +66,8 @@ Id EmitVoteAny(EmitContext& ctx, Id pred) {
return ctx.OpSubgroupAnyKHR(ctx.U1, pred);
}
const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
const Id active_mask{LargeWarpBallot(ctx, mask_ballot)};
const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
const Id active_mask{WarpExtract(ctx, mask_ballot)};
const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)};
return ctx.OpINotEqual(ctx.U1, lhs, ctx.u32_zero_value);
}
@ -69,8 +77,8 @@ Id EmitVoteEqual(EmitContext& ctx, Id pred) {
return ctx.OpSubgroupAllEqualKHR(ctx.U1, pred);
}
const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
const Id active_mask{LargeWarpBallot(ctx, mask_ballot)};
const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
const Id active_mask{WarpExtract(ctx, mask_ballot)};
const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
const Id lhs{ctx.OpBitwiseXor(ctx.U32[1], ballot, active_mask)};
return ctx.OpLogicalOr(ctx.U1, ctx.OpIEqual(ctx.U1, lhs, ctx.u32_zero_value),
ctx.OpIEqual(ctx.U1, lhs, active_mask));
@ -81,7 +89,27 @@ Id EmitSubgroupBallot(EmitContext& ctx, Id pred) {
if (!ctx.profile.warp_size_potentially_larger_than_guest) {
return ctx.OpCompositeExtract(ctx.U32[1], ballot, 0U);
}
return LargeWarpBallot(ctx, ballot);
return WarpExtract(ctx, ballot);
}
Id EmitSubgroupEqMask(EmitContext& ctx) {
return LoadMask(ctx, ctx.subgroup_mask_eq);
}
Id EmitSubgroupLtMask(EmitContext& ctx) {
return LoadMask(ctx, ctx.subgroup_mask_lt);
}
Id EmitSubgroupLeMask(EmitContext& ctx) {
return LoadMask(ctx, ctx.subgroup_mask_le);
}
Id EmitSubgroupGtMask(EmitContext& ctx) {
return LoadMask(ctx, ctx.subgroup_mask_gt);
}
Id EmitSubgroupGeMask(EmitContext& ctx) {
return LoadMask(ctx, ctx.subgroup_mask_ge);
}
Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp,