From 23bf8bf5e7a82d6c4b1d8dfc5abcef2cde1c5593 Mon Sep 17 00:00:00 2001
From: Paris Oplopoios
Date: Tue, 24 Sep 2024 17:03:32 +0300
Subject: [PATCH] Patch `insertq` (#635)

* Patch `insertq`

* Don't clobber flags, fix asserts a bit

* Format code

* Fixup some edge cases

* A couple nits

* Remove extraneous space
---
Reviewer notes (this section is ignored by `git am`; it is not part of the commit):

  * GenerateINSERTQ emits replacement code for both INSERTQ forms. The
    immediate form folds length/index into constants at generation time; the
    register form reads them at run time from bits 0-5 (length) and bits 8-13
    (index) of the upper qword of the source xmm register.
  * A length field of 0 is treated as 64 (all-ones mask) in every EXTRQ and
    INSERTQ path touched by this patch.
  * Both generated sequences first move rsp down by 128 bytes, then save
    rflags and every general-purpose register they use, and restore them in
    reverse order at the end.
  * The immediate form writes the result back with (v)pinsrq, the register
    form with movq; the in-code comments state the intent for the upper
    64 bits in each case.
  * `length + index <= 64` is enforced with ASSERT_MSG where the values are
    known on the host (immediate-form generators and the fault handler in
    TryExecuteIllegalInstruction). The generated register-form code performs
    no such check.
  * TryExecuteIllegalInstruction emulates only the register form of INSERTQ;
    the immediate form is logged with LOG_CRITICAL and rejected.
  * This copy of the patch had lost its line breaks and every
    angle-bracketed span. Line breaks were restored and checked against the
    hunk headers and the diffstat. The template arguments
    `reinterpret_cast<const Xbyak::Xmm*>` and
    `std::unordered_map<ZydisMnemonic, PatchInfo>` were restored from the
    surrounding code. Exact indentation is reconstructed - TODO confirm
    against upstream. The author e-mail address in `From:` could not be
    recovered and is left out.

 src/core/cpu_patches.cpp | 241 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 231 insertions(+), 10 deletions(-)

diff --git a/src/core/cpu_patches.cpp b/src/core/cpu_patches.cpp
index 24438b6b5..202cfbb85 100644
--- a/src/core/cpu_patches.cpp
+++ b/src/core/cpu_patches.cpp
@@ -620,9 +620,6 @@ static void GenerateEXTRQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenera
     if (immediateForm) {
         u8 length = operands[1].imm.value.u & 0x3F;
         u8 index = operands[2].imm.value.u & 0x3F;
-        if (length == 0) {
-            length = 64;
-        }

         LOG_DEBUG(Core, "Patching immediate form EXTRQ, length: {}, index: {}", length, index);

@@ -635,7 +632,15 @@ static void GenerateEXTRQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenera
         c.push(scratch1);
         c.push(scratch2);

-        u64 mask = (1ULL << length) - 1;
+        u64 mask;
+        if (length == 0) {
+            length = 64; // for the check below
+            mask = 0xFFFF'FFFF'FFFF'FFFF;
+        } else {
+            mask = (1ULL << length) - 1;
+        }
+
+        ASSERT_MSG(length + index <= 64, "length + index must be less than or equal to 64.");

         // Get lower qword from xmm register
         MAYBE_AVX(movq, scratch1, xmm_dst);
@@ -676,6 +681,8 @@ static void GenerateEXTRQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenera
         const Xbyak::Reg64 scratch2 = rcx;
         const Xbyak::Reg64 mask = rdx;

+        Xbyak::Label length_zero, end;
+
         c.lea(rsp, ptr[rsp - 128]);
         c.pushfq();
         c.push(scratch1);
@@ -686,9 +693,18 @@ static void GenerateEXTRQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenera
         MAYBE_AVX(movq, scratch1, xmm_src);
         c.mov(scratch2, scratch1);
         c.and_(scratch2, 0x3F);
+        c.jz(length_zero);
+
+        // mask = (1ULL << length) - 1
         c.mov(mask, 1);
         c.shl(mask, cl);
         c.dec(mask);
+        c.jmp(end);
+
+        c.L(length_zero);
+        c.mov(mask, 0xFFFF'FFFF'FFFF'FFFF);
+
+        c.L(end);

         // Get the shift amount and store it in scratch2
         c.shr(scratch1, 8);
@@ -708,6 +724,149 @@ static void GenerateEXTRQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenera
     }
 }

+static void GenerateINSERTQ(const ZydisDecodedOperand* operands, Xbyak::CodeGenerator& c) {
+    bool immediateForm = operands[2].type == ZYDIS_OPERAND_TYPE_IMMEDIATE &&
+                         operands[3].type == ZYDIS_OPERAND_TYPE_IMMEDIATE;
+
+    ASSERT_MSG(operands[0].type == ZYDIS_OPERAND_TYPE_REGISTER &&
+                   operands[1].type == ZYDIS_OPERAND_TYPE_REGISTER,
+               "operands 0 and 1 must be registers.");
+
+    const auto dst = ZydisToXbyakRegisterOperand(operands[0]);
+    const auto src = ZydisToXbyakRegisterOperand(operands[1]);
+
+    ASSERT_MSG(dst.isXMM() && src.isXMM(), "operands 0 and 1 must be xmm registers.");
+
+    Xbyak::Xmm xmm_dst = *reinterpret_cast<const Xbyak::Xmm*>(&dst);
+    Xbyak::Xmm xmm_src = *reinterpret_cast<const Xbyak::Xmm*>(&src);
+
+    if (immediateForm) {
+        u8 length = operands[2].imm.value.u & 0x3F;
+        u8 index = operands[3].imm.value.u & 0x3F;
+
+        const Xbyak::Reg64 scratch1 = rax;
+        const Xbyak::Reg64 scratch2 = rcx;
+        const Xbyak::Reg64 mask = rdx;
+
+        // Set rsp to before red zone and save scratch registers
+        c.lea(rsp, ptr[rsp - 128]);
+        c.pushfq();
+        c.push(scratch1);
+        c.push(scratch2);
+        c.push(mask);
+
+        u64 mask_value;
+        if (length == 0) {
+            length = 64; // for the check below
+            mask_value = 0xFFFF'FFFF'FFFF'FFFF;
+        } else {
+            mask_value = (1ULL << length) - 1;
+        }
+
+        ASSERT_MSG(length + index <= 64, "length + index must be less than or equal to 64.");
+
+        MAYBE_AVX(movq, scratch1, xmm_src);
+        MAYBE_AVX(movq, scratch2, xmm_dst);
+        c.mov(mask, mask_value);
+
+        // src &= mask
+        c.and_(scratch1, mask);
+
+        // src <<= index
+        c.shl(scratch1, index);
+
+        // dst &= ~(mask << index)
+        mask_value = ~(mask_value << index);
+        c.mov(mask, mask_value);
+        c.and_(scratch2, mask);
+
+        // dst |= src
+        c.or_(scratch2, scratch1);
+
+        // Insert scratch2 into low 64 bits of dst, upper 64 bits are unaffected
+        Cpu cpu;
+        if (cpu.has(Cpu::tAVX)) {
+            c.vpinsrq(xmm_dst, xmm_dst, scratch2, 0);
+        } else {
+            c.pinsrq(xmm_dst, scratch2, 0);
+        }
+
+        c.pop(mask);
+        c.pop(scratch2);
+        c.pop(scratch1);
+        c.popfq();
+        c.lea(rsp, ptr[rsp + 128]);
+    } else {
+        ASSERT_MSG(operands[2].type == ZYDIS_OPERAND_TYPE_UNUSED &&
+                       operands[3].type == ZYDIS_OPERAND_TYPE_UNUSED,
+                   "operands 2 and 3 must be unused for register form.");
+
+        const Xbyak::Reg64 scratch1 = rax;
+        const Xbyak::Reg64 scratch2 = rcx;
+        const Xbyak::Reg64 index = rdx;
+        const Xbyak::Reg64 mask = rbx;
+
+        Xbyak::Label length_zero, end;
+
+        c.lea(rsp, ptr[rsp - 128]);
+        c.pushfq();
+        c.push(scratch1);
+        c.push(scratch2);
+        c.push(index);
+        c.push(mask);
+
+        // Get upper 64 bits of src and copy it to mask and index
+        MAYBE_AVX(pextrq, index, xmm_src, 1);
+        c.mov(mask, index);
+
+        // When length is 0, set it to 64
+        c.and_(mask, 0x3F); // mask now holds the length
+        c.jz(length_zero);  // Check if length is 0 and set mask to all 1s if it is
+
+        // Create a mask out of the length
+        c.mov(cl, mask.cvt8());
+        c.mov(mask, 1);
+        c.shl(mask, cl);
+        c.dec(mask);
+        c.jmp(end);
+
+        c.L(length_zero);
+        c.mov(mask, 0xFFFF'FFFF'FFFF'FFFF);
+
+        c.L(end);
+        // Get index to insert at
+        c.shr(index, 8);
+        c.and_(index, 0x3F);
+
+        // src &= mask
+        MAYBE_AVX(movq, scratch1, xmm_src);
+        c.and_(scratch1, mask);
+
+        // mask = ~(mask << index)
+        c.mov(cl, index.cvt8());
+        c.shl(mask, cl);
+        c.not_(mask);
+
+        // src <<= index
+        c.shl(scratch1, cl);
+
+        // dst = (dst & mask) | src
+        MAYBE_AVX(movq, scratch2, xmm_dst);
+        c.and_(scratch2, mask);
+        c.or_(scratch2, scratch1);
+
+        // Upper 64 bits are undefined in insertq
+        MAYBE_AVX(movq, xmm_dst, scratch2);
+
+        c.pop(mask);
+        c.pop(index);
+        c.pop(scratch2);
+        c.pop(scratch1);
+        c.popfq();
+        c.lea(rsp, ptr[rsp + 128]);
+    }
+}
+
 using PatchFilter = bool (*)(const ZydisDecodedOperand*);
 using InstructionGenerator = void (*)(const ZydisDecodedOperand*, Xbyak::CodeGenerator&);
 struct PatchInfo {
@@ -730,6 +889,7 @@ static const std::unordered_map<ZydisMnemonic, PatchInfo> Patches = {
 #endif

     {ZYDIS_MNEMONIC_EXTRQ, {FilterNoSSE4a, GenerateEXTRQ, true}},
+    {ZYDIS_MNEMONIC_INSERTQ, {FilterNoSSE4a, GenerateINSERTQ, true}},

 #ifdef __APPLE__
     // Patches for instruction sets not supported by Rosetta 2.
@@ -859,8 +1019,8 @@ static bool TryExecuteIllegalInstruction(void* ctx, void* code_address) {
         bool immediateForm = operands[1].type == ZYDIS_OPERAND_TYPE_IMMEDIATE &&
                              operands[2].type == ZYDIS_OPERAND_TYPE_IMMEDIATE;
         if (immediateForm) {
-            LOG_ERROR(Core, "EXTRQ immediate form should have been patched at code address: {}",
-                      fmt::ptr(code_address));
+            LOG_CRITICAL(Core, "EXTRQ immediate form should have been patched at code address: {}",
+                         fmt::ptr(code_address));
             return false;
         } else {
             ASSERT_MSG(operands[0].type == ZYDIS_OPERAND_TYPE_REGISTER &&
@@ -883,12 +1043,19 @@ static bool TryExecuteIllegalInstruction(void* ctx, void* code_address) {
             u64 lowQWordDst;
             memcpy(&lowQWordDst, dst, sizeof(lowQWordDst));

-            u64 mask = lowQWordSrc & 0x3F;
-            mask = (1ULL << mask) - 1;
+            u64 length = lowQWordSrc & 0x3F;
+            u64 mask;
+            if (length == 0) {
+                length = 64; // for the check below
+                mask = 0xFFFF'FFFF'FFFF'FFFF;
+            } else {
+                mask = (1ULL << length) - 1;
+            }

-            u64 shift = (lowQWordSrc >> 8) & 0x3F;
+            u64 index = (lowQWordSrc >> 8) & 0x3F;
+            ASSERT_MSG(length + index <= 64, "length + index must be less than or equal to 64.");

-            lowQWordDst >>= shift;
+            lowQWordDst >>= index;
             lowQWordDst &= mask;

             memcpy(dst, &lowQWordDst, sizeof(lowQWordDst));
@@ -899,6 +1066,60 @@ static bool TryExecuteIllegalInstruction(void* ctx, void* code_address) {
         }
         break;
     }
+    case ZYDIS_MNEMONIC_INSERTQ: {
+        bool immediateForm = operands[2].type == ZYDIS_OPERAND_TYPE_IMMEDIATE &&
+                             operands[3].type == ZYDIS_OPERAND_TYPE_IMMEDIATE;
+        if (immediateForm) {
+            LOG_CRITICAL(Core,
+                         "INSERTQ immediate form should have been patched at code address: {}",
+                         fmt::ptr(code_address));
+            return false;
+        } else {
+            ASSERT_MSG(operands[2].type == ZYDIS_OPERAND_TYPE_UNUSED &&
+                           operands[3].type == ZYDIS_OPERAND_TYPE_UNUSED,
+                       "operands 2 and 3 must be unused for register form.");
+
+            ASSERT_MSG(operands[0].type == ZYDIS_OPERAND_TYPE_REGISTER &&
+                           operands[1].type == ZYDIS_OPERAND_TYPE_REGISTER,
+                       "operands 0 and 1 must be registers.");
+
+            const auto dstIndex = operands[0].reg.value - ZYDIS_REGISTER_XMM0;
+            const auto srcIndex = operands[1].reg.value - ZYDIS_REGISTER_XMM0;
+
+            const auto dst = Common::GetXmmPointer(ctx, dstIndex);
+            const auto src = Common::GetXmmPointer(ctx, srcIndex);
+
+            u64 lowQWordSrc, highQWordSrc;
+            memcpy(&lowQWordSrc, src, sizeof(lowQWordSrc));
+            memcpy(&highQWordSrc, (u8*)src + 8, sizeof(highQWordSrc));
+
+            u64 lowQWordDst;
+            memcpy(&lowQWordDst, dst, sizeof(lowQWordDst));
+
+            u64 length = highQWordSrc & 0x3F;
+            u64 mask;
+            if (length == 0) {
+                length = 64; // for the check below
+                mask = 0xFFFF'FFFF'FFFF'FFFF;
+            } else {
+                mask = (1ULL << length) - 1;
+            }
+
+            u64 index = (highQWordSrc >> 8) & 0x3F;
+            ASSERT_MSG(length + index <= 64, "length + index must be less than or equal to 64.");
+
+            lowQWordSrc &= mask;
+            lowQWordDst &= ~(mask << index);
+            lowQWordDst |= lowQWordSrc << index;
+
+            memcpy(dst, &lowQWordDst, sizeof(lowQWordDst));
+
+            Common::IncrementRip(ctx, instruction.length);
+
+            return true;
+        }
+        break;
+    }
     default: {
         LOG_ERROR(Core, "Unhandled illegal instruction at code address {}: {}",
                   fmt::ptr(code_address), ZydisMnemonicGetString(instruction.mnemonic));