spirv: Add lower fp16 to fp32 pass

This commit is contained in:
ReinUsesLisp 2021-02-19 18:10:18 -03:00 committed by ameerj
parent 85cce78583
commit 6db69990da
32 changed files with 479 additions and 285 deletions

View file

@ -4,8 +4,8 @@
#pragma once
#include <string>
#include <compare>
#include <string>
#include <fmt/format.h>

View file

@ -547,11 +547,11 @@ F32 IREmitter::FPSqrt(const F32& value) {
F16F32F64 IREmitter::FPSaturate(const F16F32F64& value) {
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<F16>(Opcode::FPSaturate16, value);
case Type::U32:
case Type::F32:
return Inst<F32>(Opcode::FPSaturate32, value);
case Type::U64:
case Type::F64:
return Inst<F64>(Opcode::FPSaturate64, value);
default:
ThrowInvalidType(value.Type());
@ -560,11 +560,11 @@ F16F32F64 IREmitter::FPSaturate(const F16F32F64& value) {
F16F32F64 IREmitter::FPRoundEven(const F16F32F64& value) {
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<F16>(Opcode::FPRoundEven16, value);
case Type::U32:
case Type::F32:
return Inst<F32>(Opcode::FPRoundEven32, value);
case Type::U64:
case Type::F64:
return Inst<F64>(Opcode::FPRoundEven64, value);
default:
ThrowInvalidType(value.Type());
@ -573,11 +573,11 @@ F16F32F64 IREmitter::FPRoundEven(const F16F32F64& value) {
F16F32F64 IREmitter::FPFloor(const F16F32F64& value) {
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<F16>(Opcode::FPFloor16, value);
case Type::U32:
case Type::F32:
return Inst<F32>(Opcode::FPFloor32, value);
case Type::U64:
case Type::F64:
return Inst<F64>(Opcode::FPFloor64, value);
default:
ThrowInvalidType(value.Type());
@ -586,11 +586,11 @@ F16F32F64 IREmitter::FPFloor(const F16F32F64& value) {
F16F32F64 IREmitter::FPCeil(const F16F32F64& value) {
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<F16>(Opcode::FPCeil16, value);
case Type::U32:
case Type::F32:
return Inst<F32>(Opcode::FPCeil32, value);
case Type::U64:
case Type::F64:
return Inst<F64>(Opcode::FPCeil64, value);
default:
ThrowInvalidType(value.Type());
@ -599,11 +599,11 @@ F16F32F64 IREmitter::FPCeil(const F16F32F64& value) {
F16F32F64 IREmitter::FPTrunc(const F16F32F64& value) {
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<F16>(Opcode::FPTrunc16, value);
case Type::U32:
case Type::F32:
return Inst<F32>(Opcode::FPTrunc32, value);
case Type::U64:
case Type::F64:
return Inst<F64>(Opcode::FPTrunc64, value);
default:
ThrowInvalidType(value.Type());
@ -729,33 +729,33 @@ U32U64 IREmitter::ConvertFToS(size_t bitsize, const F16F32F64& value) {
switch (bitsize) {
case 16:
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<U32>(Opcode::ConvertS16F16, value);
case Type::U32:
case Type::F32:
return Inst<U32>(Opcode::ConvertS16F32, value);
case Type::U64:
case Type::F64:
return Inst<U32>(Opcode::ConvertS16F64, value);
default:
ThrowInvalidType(value.Type());
}
case 32:
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<U32>(Opcode::ConvertS32F16, value);
case Type::U32:
case Type::F32:
return Inst<U32>(Opcode::ConvertS32F32, value);
case Type::U64:
case Type::F64:
return Inst<U32>(Opcode::ConvertS32F64, value);
default:
ThrowInvalidType(value.Type());
}
case 64:
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<U64>(Opcode::ConvertS64F16, value);
case Type::U32:
case Type::F32:
return Inst<U64>(Opcode::ConvertS64F32, value);
case Type::U64:
case Type::F64:
return Inst<U64>(Opcode::ConvertS64F64, value);
default:
ThrowInvalidType(value.Type());
@ -769,33 +769,33 @@ U32U64 IREmitter::ConvertFToU(size_t bitsize, const F16F32F64& value) {
switch (bitsize) {
case 16:
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<U32>(Opcode::ConvertU16F16, value);
case Type::U32:
case Type::F32:
return Inst<U32>(Opcode::ConvertU16F32, value);
case Type::U64:
case Type::F64:
return Inst<U32>(Opcode::ConvertU16F64, value);
default:
ThrowInvalidType(value.Type());
}
case 32:
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<U32>(Opcode::ConvertU32F16, value);
case Type::U32:
case Type::F32:
return Inst<U32>(Opcode::ConvertU32F32, value);
case Type::U64:
case Type::F64:
return Inst<U32>(Opcode::ConvertU32F64, value);
default:
ThrowInvalidType(value.Type());
}
case 64:
switch (value.Type()) {
case Type::U16:
case Type::F16:
return Inst<U64>(Opcode::ConvertU64F16, value);
case Type::U32:
case Type::F32:
return Inst<U64>(Opcode::ConvertU64F32, value);
case Type::U64:
case Type::F64:
return Inst<U64>(Opcode::ConvertU64F64, value);
default:
ThrowInvalidType(value.Type());
@ -829,10 +829,10 @@ U32U64 IREmitter::ConvertU(size_t result_bitsize, const U32U64& value) {
case 64:
switch (value.Type()) {
case Type::U32:
return Inst<U64>(Opcode::ConvertU64U32, value);
case Type::U64:
// Nothing to do
return value;
case Type::U64:
return Inst<U64>(Opcode::ConvertU64U32, value);
default:
break;
}

View file

@ -216,6 +216,10 @@ void Inst::ReplaceUsesWith(Value replacement) {
}
}
void Inst::ReplaceOpcode(IR::Opcode opcode) {
op = opcode;
}
void Inst::Use(const Value& value) {
Inst* const inst{value.Inst()};
++inst->use_count;

View file

@ -86,6 +86,8 @@ public:
void ReplaceUsesWith(Value replacement);
void ReplaceOpcode(IR::Opcode opcode);
template <typename FlagsType>
requires(sizeof(FlagsType) <= sizeof(u32) && std::is_trivially_copyable_v<FlagsType>)
[[nodiscard]] FlagsType Flags() const noexcept {

View file

@ -119,8 +119,10 @@ OPCODE(PackUint2x32, U64, U32x
OPCODE(UnpackUint2x32, U32x2, U64, )
OPCODE(PackFloat2x16, U32, F16x2, )
OPCODE(UnpackFloat2x16, F16x2, U32, )
OPCODE(PackDouble2x32, U64, U32x2, )
OPCODE(UnpackDouble2x32, U32x2, U64, )
OPCODE(PackHalf2x16, U32, F32x2, )
OPCODE(UnpackHalf2x16, F32x2, U32, )
OPCODE(PackDouble2x32, F64, U32x2, )
OPCODE(UnpackDouble2x32, U32x2, F64, )
// Pseudo-operation, handled specially at final emit
OPCODE(GetZeroFromOp, U1, Opaque, )

View file

@ -35,4 +35,4 @@ std::string DumpProgram(const Program& program) {
return ret;
}
} // namespace Shader::IR
} // namespace Shader::IR