Reduced amount of switch cases during instruction decoding

This commit is contained in:
ivan999 2025-06-16 02:59:58 +03:00
parent 213ca72fa1
commit 2a1a1045ec
3 changed files with 173 additions and 282 deletions

View file

@ -16,76 +16,6 @@ T extract(T value, u32 lst, u32 fst) {
}
} // namespace bit
InstEncoding GetInstructionEncoding(u32 token) {
auto encoding = static_cast<InstEncoding>(token & (u32)EncodingMask::MASK_9bit);
switch (encoding) {
case InstEncoding::SOP1:
case InstEncoding::SOPP:
case InstEncoding::SOPC:
return encoding;
default:
break;
}
encoding = static_cast<InstEncoding>(token & (u32)EncodingMask::MASK_7bit);
switch (encoding) {
case InstEncoding::VOP1:
case InstEncoding::VOPC:
return encoding;
default:
break;
}
encoding = static_cast<InstEncoding>(token & (u32)EncodingMask::MASK_6bit);
switch (encoding) {
case InstEncoding::VOP3:
case InstEncoding::EXP:
case InstEncoding::VINTRP:
case InstEncoding::DS:
case InstEncoding::MUBUF:
case InstEncoding::MTBUF:
case InstEncoding::MIMG:
return encoding;
default:
break;
}
encoding = static_cast<InstEncoding>(token & (u32)EncodingMask::MASK_5bit);
switch (encoding) {
case InstEncoding::SMRD:
return encoding;
default:
break;
}
encoding = static_cast<InstEncoding>(token & (u32)EncodingMask::MASK_4bit);
switch (encoding) {
case InstEncoding::SOPK:
return encoding;
default:
break;
}
encoding = static_cast<InstEncoding>(token & (u32)EncodingMask::MASK_2bit);
switch (encoding) {
case InstEncoding::SOP2:
return encoding;
default:
break;
}
encoding = static_cast<InstEncoding>(token & (u32)EncodingMask::MASK_1bit);
switch (encoding) {
case InstEncoding::VOP2:
return encoding;
default:
break;
}
UNREACHABLE();
return InstEncoding::ILLEGAL;
}
bool HasAdditionalLiteral(InstEncoding encoding, Opcode opcode) {
switch (encoding) {
case InstEncoding::SOPK: {
@ -107,128 +37,149 @@ bool IsVop3BEncoding(Opcode opcode) {
opcode == Opcode::V_MAD_U64_U32 || opcode == Opcode::V_MAD_I64_I32;
}
inline void GcnDecodeContext::decodeInstruction32(
void (GcnDecodeContext::*decodeFunc)(u32), OpcodeMap opcodeMap, GcnCodeSlice& code
) {
u32 instruction = code.readu32();
// Decode instruction using the provided decode function.
(this->*decodeFunc)(instruction);
// Update instruction meta info.
updateInstructionMeta(opcodeMap, sizeof(uint32_t));
// Detect literal constant. Only 32 bits instructions may have literal constant.
// Note: Literal constant decode must be performed after meta info updated.
decodeLiteralConstant(opcodeMap, code);
}
inline void GcnDecodeContext::decodeInstruction64(
void (GcnDecodeContext::*decodeFunc)(uint64_t), OpcodeMap opcodeMap, GcnCodeSlice& code
) {
uint64_t instruction = code.readu64();
// Decode instruction using the provided decode function.
(this->*decodeFunc)(instruction);
// Update instruction meta info.
updateInstructionMeta(opcodeMap, sizeof(uint64_t));
}
inline void GcnDecodeContext::decodeInstructionFromMask9bit(GcnCodeSlice& code) {
m_instruction.encoding = static_cast<InstEncoding>(code.at(0) & (u32)EncodingMask::MASK_9bit);
switch (m_instruction.encoding) {
case InstEncoding::SOP1:
decodeInstruction32(&GcnDecodeContext::decodeInstructionSOP1, OpcodeMap::OP_MAP_SOP1, code);
break;
case InstEncoding::SOPP:
decodeInstruction32(&GcnDecodeContext::decodeInstructionSOPP, OpcodeMap::OP_MAP_SOPP, code);
break;
case InstEncoding::SOPC:
decodeInstruction32(&GcnDecodeContext::decodeInstructionSOPC, OpcodeMap::OP_MAP_SOPC, code);
break;
default:
decodeInstructionFromMask7bit(code);
}
}
inline void GcnDecodeContext::decodeInstructionFromMask7bit(GcnCodeSlice& code) {
m_instruction.encoding = static_cast<InstEncoding>(code.at(0) & (u32)EncodingMask::MASK_7bit);
switch (m_instruction.encoding) {
case InstEncoding::VOP1:
decodeInstruction32(&GcnDecodeContext::decodeInstructionVOP1, OpcodeMap::OP_MAP_VOP1, code);
break;
case InstEncoding::VOPC:
decodeInstruction32(&GcnDecodeContext::decodeInstructionVOPC, OpcodeMap::OP_MAP_VOPC, code);
break;
default:
decodeInstructionFromMask6bit(code);
}
}
inline void GcnDecodeContext::decodeInstructionFromMask6bit(GcnCodeSlice& code) {
m_instruction.encoding = static_cast<InstEncoding>(code.at(0) & (u32)EncodingMask::MASK_6bit);
switch (m_instruction.encoding) {
case InstEncoding::VINTRP:
decodeInstruction32(&GcnDecodeContext::decodeInstructionVINTRP, OpcodeMap::OP_MAP_VINTRP, code);
break;
case InstEncoding::VOP3:
decodeInstruction64(&GcnDecodeContext::decodeInstructionVOP3, OpcodeMap::OP_MAP_VOP3, code);
break;
case InstEncoding::EXP:
decodeInstruction64(&GcnDecodeContext::decodeInstructionEXP, OpcodeMap::OP_MAP_EXP, code);
break;
case InstEncoding::DS:
decodeInstruction64(&GcnDecodeContext::decodeInstructionDS, OpcodeMap::OP_MAP_DS, code);
break;
case InstEncoding::MUBUF:
decodeInstruction64(&GcnDecodeContext::decodeInstructionMUBUF, OpcodeMap::OP_MAP_MUBUF, code);
break;
case InstEncoding::MTBUF:
decodeInstruction64(&GcnDecodeContext::decodeInstructionMTBUF, OpcodeMap::OP_MAP_MTBUF, code);
break;
case InstEncoding::MIMG:
decodeInstruction64(&GcnDecodeContext::decodeInstructionMIMG, OpcodeMap::OP_MAP_MIMG, code);
break;
default:
decodeInstructionFromMask5bit(code);
}
}
inline void GcnDecodeContext::decodeInstructionFromMask5bit(GcnCodeSlice& code) {
m_instruction.encoding = static_cast<InstEncoding>(code.at(0) & (u32)EncodingMask::MASK_5bit);
switch (m_instruction.encoding) {
case InstEncoding::SMRD:
decodeInstruction32(&GcnDecodeContext::decodeInstructionSMRD, OpcodeMap::OP_MAP_SMRD, code);
break;
default:
decodeInstructionFromMask4bit(code);
}
}
inline void GcnDecodeContext::decodeInstructionFromMask4bit(GcnCodeSlice& code) {
m_instruction.encoding = static_cast<InstEncoding>(code.at(0) & (u32)EncodingMask::MASK_4bit);
switch (m_instruction.encoding) {
case InstEncoding::SOPK:
decodeInstruction32(&GcnDecodeContext::decodeInstructionSOPK, OpcodeMap::OP_MAP_SOPK, code);
break;
default:
decodeInstructionFromMask2bit(code);
}
}
inline void GcnDecodeContext::decodeInstructionFromMask2bit(GcnCodeSlice& code) {
m_instruction.encoding = static_cast<InstEncoding>(code.at(0) & (u32)EncodingMask::MASK_2bit);
switch (m_instruction.encoding) {
case InstEncoding::SOP2:
decodeInstruction32(&GcnDecodeContext::decodeInstructionSOP2, OpcodeMap::OP_MAP_SOP2, code);
break;
default:
decodeInstructionFromMask1bit(code);
}
}
inline void GcnDecodeContext::decodeInstructionFromMask1bit(GcnCodeSlice& code) {
m_instruction.encoding = static_cast<InstEncoding>(code.at(0) & (u32)EncodingMask::MASK_1bit);
switch (m_instruction.encoding) {
case InstEncoding::VOP2:
decodeInstruction32(&GcnDecodeContext::decodeInstructionVOP2, OpcodeMap::OP_MAP_VOP2, code);
break;
default:
UNREACHABLE();
ASSERT_MSG("illegal encoding");
}
}
GcnInst GcnDecodeContext::decodeInstruction(GcnCodeSlice& code) {
const uint32_t token = code.at(0);
InstEncoding encoding = GetInstructionEncoding(token);
ASSERT_MSG(encoding != InstEncoding::ILLEGAL, "illegal encoding");
uint32_t encodingLen = getEncodingLength(encoding);
// Clear the instruction
m_instruction = GcnInst();
// Decode
if (encodingLen == sizeof(uint32_t)) {
decodeInstruction32(encoding, code);
} else {
decodeInstruction64(encoding, code);
}
// Update instruction meta info.
updateInstructionMeta(encoding);
// Detect literal constant. Only 32 bits instructions may have literal constant.
// Note: Literal constant decode must be performed after meta info updated.
if (encodingLen == sizeof(u32)) {
decodeLiteralConstant(encoding, code);
}
decodeInstructionFromMask9bit(code);
repairOperandType();
return m_instruction;
}
uint32_t GcnDecodeContext::getEncodingLength(InstEncoding encoding) {
uint32_t instLength = 0;
switch (encoding) {
case InstEncoding::SOP1:
case InstEncoding::SOPP:
case InstEncoding::SOPC:
case InstEncoding::SOPK:
case InstEncoding::SOP2:
case InstEncoding::VOP1:
case InstEncoding::VOPC:
case InstEncoding::VOP2:
case InstEncoding::SMRD:
case InstEncoding::VINTRP:
instLength = sizeof(uint32_t);
break;
case InstEncoding::VOP3:
case InstEncoding::MUBUF:
case InstEncoding::MTBUF:
case InstEncoding::MIMG:
case InstEncoding::DS:
case InstEncoding::EXP:
instLength = sizeof(uint64_t);
break;
default:
break;
}
return instLength;
}
uint32_t GcnDecodeContext::getOpMapOffset(InstEncoding encoding) {
uint32_t offset = 0;
switch (encoding) {
case InstEncoding::SOP1:
offset = (uint32_t)OpcodeMap::OP_MAP_SOP1;
break;
case InstEncoding::SOPP:
offset = (uint32_t)OpcodeMap::OP_MAP_SOPP;
break;
case InstEncoding::SOPC:
offset = (uint32_t)OpcodeMap::OP_MAP_SOPC;
break;
case InstEncoding::VOP1:
offset = (uint32_t)OpcodeMap::OP_MAP_VOP1;
break;
case InstEncoding::VOPC:
offset = (uint32_t)OpcodeMap::OP_MAP_VOPC;
break;
case InstEncoding::VOP3:
offset = (uint32_t)OpcodeMap::OP_MAP_VOP3;
break;
case InstEncoding::EXP:
offset = (uint32_t)OpcodeMap::OP_MAP_EXP;
break;
case InstEncoding::VINTRP:
offset = (uint32_t)OpcodeMap::OP_MAP_VINTRP;
break;
case InstEncoding::DS:
offset = (uint32_t)OpcodeMap::OP_MAP_DS;
break;
case InstEncoding::MUBUF:
offset = (uint32_t)OpcodeMap::OP_MAP_MUBUF;
break;
case InstEncoding::MTBUF:
offset = (uint32_t)OpcodeMap::OP_MAP_MTBUF;
break;
case InstEncoding::MIMG:
offset = (uint32_t)OpcodeMap::OP_MAP_MIMG;
break;
case InstEncoding::SMRD:
offset = (uint32_t)OpcodeMap::OP_MAP_SMRD;
break;
case InstEncoding::SOPK:
offset = (uint32_t)OpcodeMap::OP_MAP_SOPK;
break;
case InstEncoding::SOP2:
offset = (uint32_t)OpcodeMap::OP_MAP_SOP2;
break;
case InstEncoding::VOP2:
offset = (uint32_t)OpcodeMap::OP_MAP_VOP2;
break;
default:
break;
}
return offset;
}
uint32_t GcnDecodeContext::mapEncodingOp(InstEncoding encoding, Opcode opcode) {
uint32_t GcnDecodeContext::mapEncodingOp(OpcodeMap opcodeMap, Opcode opcode) {
// Map from uniform opcode to encoding specific opcode.
uint32_t encodingOp = 0;
if (encoding == InstEncoding::VOP3) {
if (m_instruction.encoding == InstEncoding::VOP3) {
if (opcode >= Opcode::V_CMP_F_F32 && opcode <= Opcode::V_CMPX_T_U64) {
uint32_t op =
static_cast<uint32_t>(opcode) - static_cast<uint32_t>(OpcodeMap::OP_MAP_VOPC);
@ -246,28 +197,27 @@ uint32_t GcnDecodeContext::mapEncodingOp(InstEncoding encoding, Opcode opcode) {
static_cast<uint32_t>(opcode) - static_cast<uint32_t>(OpcodeMap::OP_MAP_VOP3);
}
} else {
uint32_t mapOffset = getOpMapOffset(encoding);
uint32_t mapOffset = static_cast<uint32_t>(opcodeMap);
encodingOp = static_cast<uint32_t>(opcode) - mapOffset;
}
return encodingOp;
}
void GcnDecodeContext::updateInstructionMeta(InstEncoding encoding) {
uint32_t encodingOp = mapEncodingOp(encoding, m_instruction.opcode);
InstFormat instFormat = InstructionFormat(encoding, encodingOp);
void GcnDecodeContext::updateInstructionMeta(OpcodeMap opcodeMap, uint32_t encodingLength) {
uint32_t encodingOp = mapEncodingOp(opcodeMap, m_instruction.opcode);
InstFormat instFormat = InstructionFormat(m_instruction.encoding, encodingOp);
ASSERT_MSG(instFormat.src_type != ScalarType::Undefined &&
instFormat.dst_type != ScalarType::Undefined,
"Instruction format table incomplete for opcode {} ({}, encoding = 0x{:x})",
magic_enum::enum_name(m_instruction.opcode), u32(m_instruction.opcode),
u32(encoding));
u32(m_instruction.encoding));
m_instruction.inst_class = instFormat.inst_class;
m_instruction.category = instFormat.inst_category;
m_instruction.encoding = encoding;
m_instruction.src_count = instFormat.src_count;
m_instruction.length = getEncodingLength(encoding);
m_instruction.length = encodingLength;
// Update src operand scalar type.
auto setOperandType = [&instFormat](InstOperand& src) {
@ -337,74 +287,10 @@ OperandField GcnDecodeContext::getOperandField(uint32_t code) {
return field;
}
void GcnDecodeContext::decodeInstruction32(InstEncoding encoding, GcnCodeSlice& code) {
u32 hexInstruction = code.readu32();
switch (encoding) {
case InstEncoding::SOP1:
decodeInstructionSOP1(hexInstruction);
break;
case InstEncoding::SOPP:
decodeInstructionSOPP(hexInstruction);
break;
case InstEncoding::SOPC:
decodeInstructionSOPC(hexInstruction);
break;
case InstEncoding::SOPK:
decodeInstructionSOPK(hexInstruction);
break;
case InstEncoding::SOP2:
decodeInstructionSOP2(hexInstruction);
break;
case InstEncoding::VOP1:
decodeInstructionVOP1(hexInstruction);
break;
case InstEncoding::VOPC:
decodeInstructionVOPC(hexInstruction);
break;
case InstEncoding::VOP2:
decodeInstructionVOP2(hexInstruction);
break;
case InstEncoding::SMRD:
decodeInstructionSMRD(hexInstruction);
break;
case InstEncoding::VINTRP:
decodeInstructionVINTRP(hexInstruction);
break;
default:
break;
}
}
void GcnDecodeContext::decodeInstruction64(InstEncoding encoding, GcnCodeSlice& code) {
uint64_t hexInstruction = code.readu64();
switch (encoding) {
case InstEncoding::VOP3:
decodeInstructionVOP3(hexInstruction);
break;
case InstEncoding::MUBUF:
decodeInstructionMUBUF(hexInstruction);
break;
case InstEncoding::MTBUF:
decodeInstructionMTBUF(hexInstruction);
break;
case InstEncoding::MIMG:
decodeInstructionMIMG(hexInstruction);
break;
case InstEncoding::DS:
decodeInstructionDS(hexInstruction);
break;
case InstEncoding::EXP:
decodeInstructionEXP(hexInstruction);
break;
default:
break;
}
}
void GcnDecodeContext::decodeLiteralConstant(InstEncoding encoding, GcnCodeSlice& code) {
if (HasAdditionalLiteral(encoding, m_instruction.opcode)) {
u32 encoding_op = mapEncodingOp(encoding, m_instruction.opcode);
InstFormat instFormat = InstructionFormat(encoding, encoding_op);
void GcnDecodeContext::decodeLiteralConstant(OpcodeMap opcodeMap, GcnCodeSlice& code) {
if (HasAdditionalLiteral(m_instruction.encoding, m_instruction.opcode)) {
u32 encoding_op = mapEncodingOp(opcodeMap, m_instruction.opcode);
InstFormat instFormat = InstructionFormat(m_instruction.encoding, encoding_op);
m_instruction.src[m_instruction.src_count].field = OperandField::LiteralConst;
m_instruction.src[m_instruction.src_count].type = instFormat.src_type;
m_instruction.src[m_instruction.src_count].code = code.readu32();

View file

@ -16,13 +16,7 @@ struct InstFormat {
ScalarType dst_type = ScalarType::Undefined;
};
InstEncoding GetInstructionEncoding(u32 token);
u32 GetEncodingLength(InstEncoding encoding);
InstFormat InstructionFormat(InstEncoding encoding, u32 opcode);
Opcode DecodeOpcode(u32 token);
InstFormat InstructionFormat(InstEncoding encoding, uint32_t opcode);
class GcnCodeSlice {
public:
@ -58,30 +52,43 @@ public:
GcnInst decodeInstruction(GcnCodeSlice& code);
private:
uint32_t getEncodingLength(InstEncoding encoding);
uint32_t getOpMapOffset(InstEncoding encoding);
uint32_t mapEncodingOp(InstEncoding encoding, Opcode opcode);
void updateInstructionMeta(InstEncoding encoding);
uint32_t mapEncodingOp(OpcodeMap opcodeMap, Opcode opcode);
void updateInstructionMeta(OpcodeMap opcodeMap, uint32_t encodingLength);
uint32_t getMimgModifier(Opcode opcode);
void repairOperandType();
OperandField getOperandField(uint32_t code);
void decodeInstruction32(void (
GcnDecodeContext::*decodeFunc)(u32),
OpcodeMap opcodeMap, GcnCodeSlice& code
);
void decodeInstruction64(void (
GcnDecodeContext::*decodeFunc)(uint64_t),
OpcodeMap opcodeMap, GcnCodeSlice& code
);
void decodeInstruction32(InstEncoding encoding, GcnCodeSlice& code);
void decodeInstruction64(InstEncoding encoding, GcnCodeSlice& code);
void decodeLiteralConstant(InstEncoding encoding, GcnCodeSlice& code);
void decodeInstructionFromMask9bit(GcnCodeSlice& code);
void decodeInstructionFromMask7bit(GcnCodeSlice& code);
void decodeInstructionFromMask6bit(GcnCodeSlice& code);
void decodeInstructionFromMask5bit(GcnCodeSlice& code);
void decodeInstructionFromMask4bit(GcnCodeSlice& code);
void decodeInstructionFromMask2bit(GcnCodeSlice& code);
void decodeInstructionFromMask1bit(GcnCodeSlice& code);
void decodeLiteralConstant(OpcodeMap opcodeMap, GcnCodeSlice& code);
// 32 bits encodings
void decodeInstructionSOP1(uint32_t hexInstruction);
void decodeInstructionSOPP(uint32_t hexInstruction);
void decodeInstructionSOPC(uint32_t hexInstruction);
void decodeInstructionSOPK(uint32_t hexInstruction);
void decodeInstructionSOP2(uint32_t hexInstruction);
void decodeInstructionVOP1(uint32_t hexInstruction);
void decodeInstructionVOPC(uint32_t hexInstruction);
void decodeInstructionVOP2(uint32_t hexInstruction);
void decodeInstructionSMRD(uint32_t hexInstruction);
void decodeInstructionVINTRP(uint32_t hexInstruction);
void decodeInstructionSOP1(u32 hexInstruction);
void decodeInstructionSOPP(u32 hexInstruction);
void decodeInstructionSOPC(u32 hexInstruction);
void decodeInstructionSOPK(u32 hexInstruction);
void decodeInstructionSOP2(u32 hexInstruction);
void decodeInstructionVOP1(u32 hexInstruction);
void decodeInstructionVOPC(u32 hexInstruction);
void decodeInstructionVOP2(u32 hexInstruction);
void decodeInstructionSMRD(u32 hexInstruction);
void decodeInstructionVINTRP(u32 hexInstruction);
// 64 bits encodings
void decodeInstructionVOP3(uint64_t hexInstruction);
void decodeInstructionMUBUF(uint64_t hexInstruction);

View file

@ -2244,8 +2244,6 @@ enum class InstEncoding : u32 {
/// InstructionEncodingMask_1bit
/// bits [31:31] - (0)
VOP2 = 0x00000000u << 31,
ILLEGAL
};
enum class InstClass : u32 {