diff --git a/src/core/signals.h b/src/core/signals.h index 6ee525e10..0409b73ae 100644 --- a/src/core/signals.h +++ b/src/core/signals.h @@ -5,6 +5,7 @@ #include #include "common/singleton.h" +#include "common/types.h" namespace Core { diff --git a/src/shader_recompiler/ir/passes/flatten_extended_userdata_pass.cpp b/src/shader_recompiler/ir/passes/flatten_extended_userdata_pass.cpp index bbf3fe8fb..7253e18c1 100644 --- a/src/shader_recompiler/ir/passes/flatten_extended_userdata_pass.cpp +++ b/src/shader_recompiler/ir/passes/flatten_extended_userdata_pass.cpp @@ -10,6 +10,8 @@ #include "common/io_file.h" #include "common/logging/log.h" #include "common/path_util.h" +#include "common/signal_context.h" +#include "core/signals.h" #include "shader_recompiler/info.h" #include "shader_recompiler/ir/breadth_first_search.h" #include "shader_recompiler/ir/opcodes.h" @@ -24,6 +26,7 @@ using namespace Xbyak::util; static Xbyak::CodeGenerator g_srt_codegen(32_MB); +static const u8* g_srt_codegen_start = nullptr; namespace { @@ -54,6 +57,57 @@ static void DumpSrtProgram(const Shader::Info& info, const u8* code, size_t code #endif } +static bool SrtWalkerSignalHandler(void* context, void* fault_address) { + // Only handle if the fault address is within the SRT code range + const u8* code_start = g_srt_codegen_start; + const u8* code_end = code_start + g_srt_codegen.getSize(); + const void* code = Common::GetRip(context); + if (code < code_start || code >= code_end) { + return false; // Not in SRT code range + } + + // Patch instruction to zero register + ZydisDecodedInstruction instruction; + ZydisDecodedOperand operands[ZYDIS_MAX_OPERAND_COUNT]; + ZyanStatus status = Common::Decoder::Instance()->decodeInstruction(instruction, operands, + const_cast(code), 15); + + ASSERT(ZYAN_SUCCESS(status) && instruction.mnemonic == ZYDIS_MNEMONIC_MOV && + operands[0].type == ZYDIS_OPERAND_TYPE_REGISTER && + operands[1].type == ZYDIS_OPERAND_TYPE_MEMORY); + + size_t len = instruction.length; + const size_t patch_size = 3; + u8* code_patch = const_cast(reinterpret_cast(code)); + + // We can only encounter rdi or r10d as the first operand in a + // fault memory access for SRT walker. + switch (operands[0].reg.value) { + case ZYDIS_REGISTER_RDI: + // mov rdi, [rdi + (off_dw << 2)] -> xor rdi, rdi + code_patch[0] = 0x48; + code_patch[1] = 0x31; + code_patch[2] = 0xFF; + break; + case ZYDIS_REGISTER_R10D: + // mov r10d, [rdi + (off_dw << 2)] -> xor r10d, r10d + code_patch[0] = 0x45; + code_patch[1] = 0x31; + code_patch[2] = 0xD2; + break; + default: + UNREACHABLE_MSG("Unsupported register for SRT walker patch"); + return false; + } + + // Fill nops + memset(code_patch + patch_size, 0x90, len - patch_size); + + LOG_DEBUG(Render_Recompiler, "Patched SRT walker at {}", code); + + return true; +} + using namespace Shader; struct PassInfo { @@ -141,6 +195,15 @@ static void GenerateSrtProgram(Info& info, PassInfo& pass_info) { return; } + // Register the signal handler for SRT walker, if not already registered + if (g_srt_codegen_start == nullptr) { + g_srt_codegen_start = c.getCurr(); + auto* signals = Core::Signals::Instance(); + // Call after the memory invalidation handler + constexpr u32 priority = 1; + signals->RegisterAccessViolationHandler(SrtWalkerSignalHandler, priority); + } + info.srt_info.walker_func = c.getCurr(); pass_info.dst_off_dw = NumUserDataRegs;