Implement switch statements and special function calls (break, syscall, trigger event) in live recompiler

Mr-Wiseguy 2024-10-13 01:47:39 -04:00
parent ce46ef7ed1
commit 514ad596ae
8 changed files with 237 additions and 68 deletions


@@ -10,6 +10,8 @@
#include "sljitLir.h"
constexpr uint64_t rdram_offset = 0xFFFFFFFF80000000ULL;
void N64Recomp::live_recompiler_init() {
RabbitizerConfig_Cfg.pseudos.pseudoMove = false;
RabbitizerConfig_Cfg.pseudos.pseudoBeqz = false;
@@ -19,7 +21,7 @@ void N64Recomp::live_recompiler_init() {
}
namespace Registers {
constexpr int rdram = SLJIT_S0; // stores (rdram - 0xFFFFFFFF80000000)
constexpr int rdram = SLJIT_S0; // stores (rdram - rdram_offset)
constexpr int ctx = SLJIT_S1; // stores ctx
constexpr int c1cs = SLJIT_S2; // stores cop1_cs
constexpr int hi = SLJIT_S3; // stores hi
@@ -40,11 +42,22 @@ struct ReferenceSymbolCall {
uint16_t reference;
};
struct SwitchErrorJump {
uint32_t instr_vram;
uint32_t jtbl_vram;
sljit_jump* jump;
};
struct N64Recomp::LiveGeneratorContext {
std::string function_name;
std::unordered_map<std::string, sljit_label*> labels;
std::unordered_map<std::string, std::vector<sljit_jump*>> pending_jumps;
std::vector<sljit_label*> func_labels;
std::vector<InnerCall> inner_calls;
std::vector<std::vector<std::string>> switch_jump_labels;
// See LiveGeneratorOutput::jump_tables for info.
std::vector<void**> jump_tables;
std::vector<SwitchErrorJump> switch_error_jumps;
sljit_jump* cur_branch_jump;
};
@@ -77,6 +90,47 @@ N64Recomp::LiveGeneratorOutput N64Recomp::LiveGenerator::finish() {
sljit_set_label(call.jump, target_func_label);
}
// Generate the switch error jump targets and assign the jump labels.
if (!context->switch_error_jumps.empty()) {
// Allocate the function name and place it in the literals.
char* func_name = new char[context->function_name.size() + 1];
memcpy(func_name, context->function_name.c_str(), context->function_name.size());
func_name[context->function_name.size()] = '\x00';
ret.string_literals.emplace_back(func_name);
std::vector<sljit_jump*> switch_error_return_jumps{};
switch_error_return_jumps.resize(context->switch_error_jumps.size());
// Generate and assign the labels for the switch error jumps.
for (size_t i = 0; i < context->switch_error_jumps.size(); i++) {
const auto& cur_error_jump = context->switch_error_jumps[i];
// Generate a label and assign it to the jump.
sljit_set_label(cur_error_jump.jump, sljit_emit_label(compiler));
// Load the arguments (function name, instruction vram, jump table vram)
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_IMM, sljit_sw(func_name));
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_R1, 0, SLJIT_IMM, sljit_sw(cur_error_jump.instr_vram));
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_R2, 0, SLJIT_IMM, sljit_sw(cur_error_jump.jtbl_vram));
// Call switch_error.
sljit_emit_icall(compiler, SLJIT_CALL, SLJIT_ARGS3V(P, 32, 32), SLJIT_IMM, sljit_sw(inputs.switch_error));
// Jump to the return statement.
switch_error_return_jumps[i] = sljit_emit_jump(compiler, SLJIT_JUMP);
}
// Generate the return statement.
sljit_label* return_label = sljit_emit_label(compiler);
sljit_emit_return_void(compiler);
// Assign the label for all the return jumps.
for (sljit_jump* cur_jump : switch_error_return_jumps) {
sljit_set_label(cur_jump, return_label);
}
}
context->switch_error_jumps.clear();
// Generate the code.
ret.code = sljit_generate_code(compiler, 0, NULL);
ret.code_size = sljit_get_generated_code_size(compiler);
@@ -92,6 +146,34 @@ N64Recomp::LiveGeneratorOutput N64Recomp::LiveGenerator::finish() {
}
}
// Populate all the switch case addresses.
bool invalid_switch = false;
for (size_t switch_index = 0; switch_index < context->switch_jump_labels.size(); switch_index++) {
const std::vector<std::string>& cur_labels = context->switch_jump_labels[switch_index];
void** cur_jump_table = context->jump_tables[switch_index];
for (size_t case_index = 0; case_index < cur_labels.size(); case_index++) {
// Find the label.
auto find_it = context->labels.find(cur_labels[case_index]);
if (find_it == context->labels.end()) {
// Label not found, invalid switch.
// Don't return immediately, as we need to ensure that all the jump tables end up in ret
// so that it cleans them up in its destructor.
invalid_switch = true;
break;
}
// Get the label's address and place it in the jump table.
cur_jump_table[case_index] = reinterpret_cast<void*>(sljit_get_label_addr(find_it->second));
}
ret.jump_tables.emplace_back(cur_jump_table);
}
context->switch_jump_labels.clear();
context->jump_tables.clear();
if (invalid_switch) {
return { };
}
sljit_free_compiler(compiler);
compiler = nullptr;
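The two finish() additions above defer all switch fix-up to the end of code generation: every recorded SwitchErrorJump gets a stub that reports the bad case and returns, and every jump table allocated in emit_switch is back-filled with the host addresses of its case labels once all labels exist. As a rough C++ sketch of what one generated error stub amounts to (handler signature inferred from the SLJIT_ARGS3V(P, 32, 32) call above; names here are illustrative, not generator code):

#include <cstdint>

// Assumed handler signature, matching SLJIT_ARGS3V(P, 32, 32).
using switch_error_fn = void (*)(const char* func, uint32_t vram, uint32_t jtbl);

// Sketch only: each stub calls the handler with the function name literal and
// the two vrams recorded in SwitchErrorJump, then takes the shared void return
// instead of jumping through the table.
void switch_error_stub_sketch(switch_error_fn switch_error, const char* func_name,
                              uint32_t instr_vram, uint32_t jtbl_vram) {
    switch_error(func_name, instr_vram, jtbl_vram);
}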
@@ -101,10 +183,18 @@ N64Recomp::LiveGeneratorOutput N64Recomp::LiveGenerator::finish() {
N64Recomp::LiveGeneratorOutput::~LiveGeneratorOutput() {
if (code != nullptr) {
sljit_free_code(code, nullptr);
code = nullptr;
}
for (const char* literal : string_literals) {
delete[] literal;
}
string_literals.clear();
for (void** jump_table : jump_tables) {
delete[] jump_table;
}
jump_tables.clear();
}
constexpr int get_gpr_context_offset(int gpr_index) {
@@ -703,9 +793,10 @@ void N64Recomp::LiveGenerator::process_store_op(const StoreOp& op, const Instruc
}
void N64Recomp::LiveGenerator::emit_function_start(const std::string& function_name, size_t func_index) const {
context->function_name = function_name;
context->func_labels[func_index] = sljit_emit_label(compiler);
sljit_emit_enter(compiler, 0, SLJIT_ARGS2V(P, P), 4, 5, 0);
sljit_emit_op2(compiler, SLJIT_SUB, Registers::rdram, 0, Registers::rdram, 0, SLJIT_IMM, 0xFFFFFFFF80000000);
sljit_emit_op2(compiler, SLJIT_SUB, Registers::rdram, 0, Registers::rdram, 0, SLJIT_IMM, rdram_offset);
}
void N64Recomp::LiveGenerator::emit_function_end() const {
@@ -723,7 +814,7 @@ void N64Recomp::LiveGenerator::emit_function_call_lookup(uint32_t addr) const {
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_R0, 0);
// Load rdram and ctx into R0 and R1.
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, Registers::rdram, 0, SLJIT_IMM, 0xFFFFFFFF80000000);
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, Registers::rdram, 0, SLJIT_IMM, rdram_offset);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R1, 0, Registers::ctx, 0);
// Call the function.
@@ -741,7 +832,7 @@ void N64Recomp::LiveGenerator::emit_function_call_by_register(int reg) const {
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_R0, 0);
// Load rdram and ctx into R0 and R1.
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, Registers::rdram, 0, SLJIT_IMM, 0xFFFFFFFF80000000);
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, Registers::rdram, 0, SLJIT_IMM, rdram_offset);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R1, 0, Registers::ctx, 0);
// Call the function.
@@ -754,8 +845,10 @@ void N64Recomp::LiveGenerator::emit_function_call_reference_symbol(const Context
}
void N64Recomp::LiveGenerator::emit_function_call(const Context& recompiler_context, size_t function_index) const {
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, Registers::rdram, 0, SLJIT_IMM, 0xFFFFFFFF80000000);
// Load rdram and ctx into R0 and R1.
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, Registers::rdram, 0, SLJIT_IMM, rdram_offset);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R1, 0, Registers::ctx, 0);
// Call the function and save the jump to set its label later on.
sljit_jump* call_jump = sljit_emit_call(compiler, SLJIT_CALL, SLJIT_ARGS2V(P, P));
context->inner_calls.emplace_back(InnerCall{ .target_func_index = function_index, .jump = call_jump });
}
@@ -790,8 +883,8 @@ void N64Recomp::LiveGenerator::emit_label(const std::string& label_name) const {
context->labels.emplace(label_name, label);
}
void N64Recomp::LiveGenerator::emit_variable_declaration(const std::string& var_name, int reg) const {
assert(false);
void N64Recomp::LiveGenerator::emit_jtbl_addend_declaration(const JumpTable& jtbl, int reg) const {
// Nothing to do here, the live recompiler performs a subtraction to get the switch's case.
}
void N64Recomp::LiveGenerator::emit_branch_condition(const ConditionalBranchOp& op, const InstructionContext& ctx) const {
@@ -869,20 +962,52 @@ void N64Recomp::LiveGenerator::emit_branch_close() const {
context->cur_branch_jump = nullptr;
}
void N64Recomp::LiveGenerator::emit_switch(const std::string& jump_variable, int shift_amount) const {
assert(false);
void N64Recomp::LiveGenerator::emit_switch(const JumpTable& jtbl, int reg) const {
// Populate the switch's labels.
std::vector<std::string> cur_labels{};
cur_labels.resize(jtbl.entries.size());
for (size_t i = 0; i < cur_labels.size(); i++) {
cur_labels[i] = fmt::format("L_{:08X}", jtbl.entries[i]);
}
context->switch_jump_labels.emplace_back(std::move(cur_labels));
// Allocate the jump table. Must be manually allocated to prevent the address from changing.
void** cur_jump_table = new void*[jtbl.entries.size()];
context->jump_tables.emplace_back(cur_jump_table);
/// Codegen
// Load the jump target register. The lw instruction was patched into an addiu, so this holds
// the address of the jump table entry instead of the actual jump target.
sljit_emit_op1(compiler, SLJIT_MOV, Registers::arithmetic_temp1, 0, SLJIT_MEM1(Registers::ctx), get_gpr_context_offset(reg));
// Subtract the jump table's address from the jump target to get the jump table addend.
// Sign extend the jump table address to 64 bits so that the entire register's contents are used instead of just the lower 32 bits.
sljit_emit_op2(compiler, SLJIT_SUB, Registers::arithmetic_temp1, 0, Registers::arithmetic_temp1, 0, SLJIT_IMM, (sljit_sw)((int32_t)jtbl.vram));
// Bounds check the addend. If it's greater than or equal to the jump table size (entries * sizeof(u32)) then jump to the switch error.
sljit_jump* switch_error_jump = sljit_emit_cmp(compiler, SLJIT_GREATER_EQUAL, Registers::arithmetic_temp1, 0, SLJIT_IMM, jtbl.entries.size() * sizeof(uint32_t));
context->switch_error_jumps.emplace_back(SwitchErrorJump{.instr_vram = jtbl.jr_vram, .jtbl_vram = jtbl.vram, .jump = switch_error_jump});
// Multiply the jump table addend by 2 to get the addend for the real jump table. (4 bytes per entry to 8 bytes per entry).
sljit_emit_op2(compiler, SLJIT_ADD, Registers::arithmetic_temp1, 0, Registers::arithmetic_temp1, 0, Registers::arithmetic_temp1, 0);
// Load the real jump table address.
sljit_emit_op1(compiler, SLJIT_MOV, Registers::arithmetic_temp2, 0, SLJIT_IMM, (sljit_sw)cur_jump_table);
// Load the real jump entry.
sljit_emit_op1(compiler, SLJIT_MOV, Registers::arithmetic_temp1, 0, SLJIT_MEM2(Registers::arithmetic_temp1, Registers::arithmetic_temp2), 0);
// Jump to the loaded entry.
sljit_emit_ijump(compiler, SLJIT_JUMP, Registers::arithmetic_temp1, 0);
}
void N64Recomp::LiveGenerator::emit_case(int case_index, const std::string& target_label) const {
assert(false);
// Nothing to do here, the jump table is built in emit_switch.
}
void N64Recomp::LiveGenerator::emit_switch_error(uint32_t instr_vram, uint32_t jtbl_vram) const {
assert(false);
// Nothing to do here, the switch error jump is emitted in emit_switch and its handler is generated in finish().
}
void N64Recomp::LiveGenerator::emit_switch_close() const {
assert(false);
// Nothing to do here, the jump table is built in emit_switch.
}
void N64Recomp::LiveGenerator::emit_return() const {
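For reference, the dispatch encoded by emit_switch above boils down to the following, sketched here in plain C++ rather than generator calls (variable names are illustrative; the host table is the void* array allocated in emit_switch):

#include <cstddef>
#include <cstdint>

// Sketch of the switch dispatch. `gpr` holds the jump table entry address
// (the lw was patched into an addiu), so subtracting the sign-extended table
// vram yields a byte offset into the guest's 32-bit jump table.
void* switch_target_sketch(uint64_t gpr, uint32_t jtbl_vram, size_t num_entries,
                           void** host_table, void* error_target) {
    uint64_t addend = gpr - (uint64_t)(int64_t)(int32_t)jtbl_vram;
    // Unsigned compare also catches negative addends.
    if (addend >= num_entries * sizeof(uint32_t)) {
        return error_target; // taken by the recorded SwitchErrorJump
    }
    // Doubling the addend converts the 4-byte guest stride into the 8-byte
    // host pointer stride before indexing the allocated table.
    return host_table[(addend * 2) / sizeof(void*)];
}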
@@ -1005,11 +1130,8 @@ void N64Recomp::LiveGenerator::emit_muldiv(InstrId instr_id, int reg1, int reg2)
}
// If the denominator is 0, skip the division and jump to the special handling for that case.
// Set the zero flag if the denominator is zero by AND'ing it with itself.
sljit_emit_op2u(compiler, SLJIT_AND | SLJIT_SET_Z, SLJIT_R1, 0, SLJIT_R1, 0);
// Branch past the division if the zero flag is 0.
sljit_jump* jump_skip_division = sljit_emit_jump(compiler, SLJIT_ZERO);
// Branch past the division if the divisor is 0.
sljit_jump* jump_skip_division = sljit_emit_cmp(compiler, SLJIT_EQUAL, SLJIT_R1, 0, SLJIT_IMM, 0);
// Perform the division.
sljit_emit_op0(compiler, div_opcode);
@@ -1078,19 +1200,37 @@ void N64Recomp::LiveGenerator::emit_muldiv(InstrId instr_id, int reg1, int reg2)
}
void N64Recomp::LiveGenerator::emit_syscall(uint32_t instr_vram) const {
assert(false);
// Load rdram and ctx into R0 and R1.
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, Registers::rdram, 0, SLJIT_IMM, rdram_offset);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R1, 0, Registers::ctx, 0);
// Load the vram into R2.
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_R2, 0, SLJIT_IMM, instr_vram);
// Call syscall_handler.
sljit_emit_icall(compiler, SLJIT_CALL, SLJIT_ARGS3V(P, P, 32), SLJIT_IMM, sljit_sw(inputs.syscall_handler));
}
void N64Recomp::LiveGenerator::emit_do_break(uint32_t instr_vram) const {
assert(false);
// Load the vram into R0.
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_R0, 0, SLJIT_IMM, instr_vram);
// Call do_break.
sljit_emit_icall(compiler, SLJIT_CALL, SLJIT_ARGS1V(32), SLJIT_IMM, sljit_sw(inputs.do_break));
}
void N64Recomp::LiveGenerator::emit_pause_self() const {
assert(false);
// Load rdram into R0.
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, Registers::rdram, 0, SLJIT_IMM, rdram_offset);
// Call pause_self.
sljit_emit_icall(compiler, SLJIT_CALL, SLJIT_ARGS1V(P), SLJIT_IMM, sljit_sw(inputs.pause_self));
}
void N64Recomp::LiveGenerator::emit_trigger_event(size_t event_index) const {
assert(false);
void N64Recomp::LiveGenerator::emit_trigger_event(uint32_t event_index) const {
// Load rdram and ctx into R0 and R1.
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, Registers::rdram, 0, SLJIT_IMM, rdram_offset);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R1, 0, Registers::ctx, 0);
// Load the global event index into R2.
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_R2, 0, SLJIT_IMM, event_index + inputs.base_event_index);
// Call trigger_event.
sljit_emit_icall(compiler, SLJIT_CALL, SLJIT_ARGS3V(P, P, 32), SLJIT_IMM, sljit_sw(inputs.trigger_event));
}
void N64Recomp::LiveGenerator::emit_comment(const std::string& comment) const {
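The special-call emitters above (switch error, syscall, break, pause_self, trigger_event) all funnel into host callbacks held in inputs. The field names appear in the diff; the signatures below are only inferred from the SLJIT_ARGS*V forms and the registers loaded, so treat this struct and its parameter types as an assumption rather than the real header:

#include <cstdint>

// Hedged sketch of the callback table the generated calls target; the real
// types (e.g. the recompiler context struct) will differ from the
// void*/uint8_t* stand-ins used here.
struct LiveGeneratorInputsSketch {
    uint32_t base_event_index; // added to event_index in emit_trigger_event
    void (*switch_error)(const char* func, uint32_t vram, uint32_t jtbl); // SLJIT_ARGS3V(P, 32, 32)
    void (*syscall_handler)(uint8_t* rdram, void* ctx, uint32_t vram);    // SLJIT_ARGS3V(P, P, 32)
    void (*do_break)(uint32_t vram);                                      // SLJIT_ARGS1V(32)
    void (*pause_self)(uint8_t* rdram);                                   // SLJIT_ARGS1V(P)
    void (*trigger_event)(uint8_t* rdram, void* ctx, uint32_t event);     // SLJIT_ARGS3V(P, P, 32)
};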