diff --git a/src/core/cpu_patches.cpp b/src/core/cpu_patches.cpp index 76a65348b..1b159d32b 100644 --- a/src/core/cpu_patches.cpp +++ b/src/core/cpu_patches.cpp @@ -658,26 +658,15 @@ static PatchModule* GetModule(const void* ptr) { return &(std::prev(upper_bound)->second); } -static bool TryPatch(void* code_address) { - auto* code = static_cast(code_address); - auto* module = GetModule(code); - if (module == nullptr) { - return false; - } - - std::unique_lock lock{module->mutex}; - - // Return early if already patched, in case multiple threads signaled at the same time. - if (std::ranges::find(module->patched, code) != module->patched.end()) { - return true; - } - +/// Returns a boolean indicating whether the instruction was patched, and the offset to advance past +/// whatever is at the current code pointer. +static std::pair TryPatch(u8* code, PatchModule* module) { ZydisDecodedInstruction instruction; ZydisDecodedOperand operands[ZYDIS_MAX_OPERAND_COUNT]; const auto status = ZydisDecoderDecodeFull(&instr_decoder, code, module->end - code, &instruction, operands); if (!ZYAN_SUCCESS(status)) { - return false; + return std::make_pair(false, 1); } if (Patches.contains(instruction.mnemonic)) { @@ -717,20 +706,52 @@ static bool TryPatch(void* code_address) { module->patched.insert(code); LOG_DEBUG(Core, "Patched instruction '{}' at: {}", ZydisMnemonicGetString(instruction.mnemonic), fmt::ptr(code)); - return true; + return std::make_pair(true, instruction.length); } } } - return false; + return std::make_pair(false, instruction.length); +} + +static bool TryPatchJit(void* code_address) { + auto* code = static_cast(code_address); + auto* module = GetModule(code); + if (module == nullptr) { + return false; + } + + std::unique_lock lock{module->mutex}; + + // Return early if already patched, in case multiple threads signaled at the same time. + if (std::ranges::find(module->patched, code) != module->patched.end()) { + return true; + } + + return TryPatch(code, module).first; +} + +static void TryPatchAot(void* code_address, u64 code_size) { + auto* code = static_cast(code_address); + auto* module = GetModule(code); + if (module == nullptr) { + return; + } + + std::unique_lock lock{module->mutex}; + + const auto* end = code + code_size; + while (code < end) { + code += TryPatch(code, module).second; + } } static bool PatchesAccessViolationHandler(void* code_address, void* fault_address, bool is_write) { - return TryPatch(code_address); + return TryPatchJit(code_address); } static bool PatchesIllegalInstructionHandler(void* code_address) { - return TryPatch(code_address); + return TryPatchJit(code_address); } static void PatchesInit() { @@ -757,17 +778,23 @@ void RegisterPatchModule(void* module_ptr, u64 module_size, void* trampoline_are } void PrePatchInstructions(u64 segment_addr, u64 segment_size) { -#ifdef __APPLE__ +#if defined(__APPLE__) // HACK: For some reason patching in the signal handler at the start of a page does not work // under Rosetta 2. Patch any instructions at the start of a page ahead of time. if (!Patches.empty()) { auto* code_page = reinterpret_cast(Common::AlignUp(segment_addr, 0x1000)); const auto* end_page = code_page + Common::AlignUp(segment_size, 0x1000); while (code_page < end_page) { - TryPatch(code_page); + TryPatchJit(code_page); code_page += 0x1000; } } +#elif !defined(_WIN32) + // Linux and others have an FS segment pointing to valid memory, so continue to do full + // ahead-of-time patching for now until a better solution is worked out. + if (!Patches.empty()) { + TryPatchAot(reinterpret_cast(segment_addr), segment_size); + } #endif }