Add function hooks to mod symbol format

This commit is contained in:
Mr-Wiseguy 2025-01-05 02:05:06 -05:00
parent 36b5d9ae33
commit 985c02e979
6 changed files with 106 additions and 6 deletions

View file

@ -573,6 +573,8 @@ N64Recomp::Context build_mod_context(const N64Recomp::Context& input_context, bo
bool event_section = cur_section.name == N64Recomp::EventSectionName;
bool import_section = cur_section.name.starts_with(N64Recomp::ImportSectionPrefix);
bool callback_section = cur_section.name.starts_with(N64Recomp::CallbackSectionPrefix);
bool hook_section = cur_section.name.starts_with(N64Recomp::HookSectionPrefix);
bool hook_return_section = cur_section.name.starts_with(N64Recomp::HookReturnSectionPrefix);
// Add the functions from the current input section to the current output section.
auto& section_out = ret.sections[output_section_index];
@ -638,6 +640,42 @@ N64Recomp::Context build_mod_context(const N64Recomp::Context& input_context, bo
);
}
if (hook_section || hook_return_section) {
// Get the name of the hooked function.
size_t section_prefix_length = hook_section ? N64Recomp::HookSectionPrefix.size() : N64Recomp::HookReturnSectionPrefix.size();
std::string hooked_function_name = cur_section.name.substr(section_prefix_length);
// Find the corresponding symbol in the reference symbols.
N64Recomp::SymbolReference cur_reference;
bool original_func_exists = input_context.find_regular_reference_symbol(hooked_function_name, cur_reference);
// Check that the function being patched exists in the original reference symbols.
if (!original_func_exists) {
fmt::print(stderr, "Function {} hooks a function ({}) that doesn't exist in the original ROM.\n", cur_func.name, hooked_function_name);
return {};
}
// Check that the reference symbol is actually a function.
const auto& reference_symbol = input_context.get_reference_symbol(cur_reference);
if (!reference_symbol.is_function) {
fmt::print(stderr, "Function {0} hooks {1}, but {1} was a variable in the original ROM.\n", cur_func.name, hooked_function_name);
return {};
}
uint32_t reference_section_vram = input_context.get_reference_section_vram(reference_symbol.section_index);
uint32_t reference_section_rom = input_context.get_reference_section_rom(reference_symbol.section_index);
// Add a replacement for this function to the output context.
ret.hooks.emplace_back(
N64Recomp::FunctionHook {
.func_index = (uint32_t)output_func_index,
.original_section_vrom = reference_section_rom,
.original_vram = reference_section_vram + reference_symbol.section_offset,
.flags = hook_return_section ? N64Recomp::HookFlags::AtReturn : N64Recomp::HookFlags{}
}
);
}
std::string name_out;
if (export_section) {

View file

@ -85,6 +85,8 @@ namespace N64Recomp {
constexpr std::string_view EventSectionName = ".recomp_event";
constexpr std::string_view ImportSectionPrefix = ".recomp_import.";
constexpr std::string_view CallbackSectionPrefix = ".recomp_callback.";
constexpr std::string_view HookSectionPrefix = ".recomp_hook.";
constexpr std::string_view HookReturnSectionPrefix = ".recomp_hook_return.";
// Special dependency names.
constexpr std::string_view DependencySelf = ".";
@ -183,6 +185,19 @@ namespace N64Recomp {
ReplacementFlags flags;
};
enum class HookFlags : uint32_t {
AtReturn = 1 << 0,
};
inline HookFlags operator&(HookFlags lhs, HookFlags rhs) { return HookFlags(uint32_t(lhs) & uint32_t(rhs)); }
inline HookFlags operator|(HookFlags lhs, HookFlags rhs) { return HookFlags(uint32_t(lhs) | uint32_t(rhs)); }
struct FunctionHook {
uint32_t func_index;
uint32_t original_section_vrom;
uint32_t original_vram;
HookFlags flags;
};
class Context {
private:
//// Reference symbols (used for populating relocations for patches)
@ -236,6 +251,8 @@ namespace N64Recomp {
std::vector<Callback> callbacks;
// List of symbols from events, which contains the names of events that this context provides.
std::vector<EventSymbol> event_symbols;
// List of hooks, which contains the original function to hook and the function index to call at the hook.
std::vector<FunctionHook> hooks;
// Causes functions to print their name to the console the first time they're called.
bool trace_mode;

View file

@ -201,8 +201,8 @@ std::vector<N64Recomp::InstructionPatch> get_instruction_patches(const toml::tab
return ret;
}
std::vector<N64Recomp::FunctionHook> get_function_hooks(const toml::table* patches_data) {
std::vector<N64Recomp::FunctionHook> ret;
std::vector<N64Recomp::FunctionTextHook> get_function_hooks(const toml::table* patches_data) {
std::vector<N64Recomp::FunctionTextHook> ret;
// Check if the function hook array exists.
const toml::node_view func_hook_data = (*patches_data)["hook"];
@ -230,7 +230,7 @@ std::vector<N64Recomp::FunctionHook> get_function_hooks(const toml::table* patch
throw toml::parse_error("before_vram is not word-aligned", el.source());
}
ret.push_back(N64Recomp::FunctionHook{
ret.push_back(N64Recomp::FunctionTextHook{
.func_name = func_name.value(),
.before_vram = before_vram.has_value() ? (int32_t)before_vram.value() : 0,
.text = text.value(),

View file

@ -12,7 +12,7 @@ namespace N64Recomp {
uint32_t value;
};
struct FunctionHook {
struct FunctionTextHook {
std::string func_name;
int32_t before_vram;
std::string text;
@ -57,7 +57,7 @@ namespace N64Recomp {
std::vector<std::string> ignored_funcs;
std::vector<std::string> renamed_funcs;
std::vector<InstructionPatch> instruction_patches;
std::vector<FunctionHook> function_hooks;
std::vector<FunctionTextHook> function_hooks;
std::vector<FunctionSize> manual_func_sizes;
std::vector<ManualFunction> manual_functions;
std::string bss_section_suffix;

View file

@ -536,7 +536,7 @@ int main(int argc, char** argv) {
}
// Apply any function hooks.
for (const N64Recomp::FunctionHook& patch : config.function_hooks) {
for (const N64Recomp::FunctionTextHook& patch : config.function_hooks) {
// Check if the specified function exists.
auto func_find = context.functions_by_name.find(patch.func_name);
if (func_find == context.functions_by_name.end()) {

View file

@ -16,6 +16,7 @@ struct FileSubHeaderV1 {
uint32_t num_exports;
uint32_t num_callbacks;
uint32_t num_provided_events;
uint32_t num_hooks;
uint32_t string_data_size;
};
@ -89,6 +90,13 @@ struct EventV1 {
uint32_t name_size;
};
struct HookV1 {
uint32_t func_index;
uint32_t original_section_vrom;
uint32_t original_vram;
uint32_t flags; // end
};
template <typename T>
const T* reinterpret_data(std::span<const char> data, size_t& offset, size_t count = 1) {
if (offset + (sizeof(T) * count) > data.size()) {
@ -126,6 +134,7 @@ bool parse_v1(std::span<const char> data, const std::unordered_map<uint32_t, uin
size_t num_exports = subheader->num_exports;
size_t num_callbacks = subheader->num_callbacks;
size_t num_provided_events = subheader->num_provided_events;
size_t num_hooks = subheader->num_hooks;
size_t string_data_size = subheader->string_data_size;
if (string_data_size & 0b11) {
@ -147,6 +156,7 @@ bool parse_v1(std::span<const char> data, const std::unordered_map<uint32_t, uin
mod_context.exported_funcs.resize(num_exports); // Add method
mod_context.callbacks.reserve(num_callbacks);
mod_context.event_symbols.reserve(num_provided_events);
mod_context.hooks.reserve(num_provided_events);
for (size_t section_index = 0; section_index < num_sections; section_index++) {
const SectionHeaderV1* section_header = reinterpret_data<SectionHeaderV1>(data, offset);
@ -434,6 +444,22 @@ bool parse_v1(std::span<const char> data, const std::unordered_map<uint32_t, uin
mod_context.add_event_symbol(std::string{import_name});
}
const HookV1* hooks = reinterpret_data<HookV1>(data, offset, num_hooks);
if (hooks == nullptr) {
printf("Failed to read hooks (count: %zu)\n", num_hooks);
return false;
}
for (size_t hook_index = 0; hook_index < num_hooks; hook_index++) {
const HookV1& hook_in = hooks[hook_index];
N64Recomp::FunctionHook& hook_out = mod_context.hooks.emplace_back();
hook_out.func_index = hook_in.func_index;
hook_out.original_section_vrom = hook_in.original_section_vrom;
hook_out.original_vram = hook_in.original_vram;
hook_out.flags = static_cast<N64Recomp::HookFlags>(hook_in.flags);
}
return offset == data.size();
}
@ -512,6 +538,7 @@ std::vector<uint8_t> N64Recomp::symbols_to_bin_v1(const N64Recomp::Context& cont
size_t num_events = context.event_symbols.size();
size_t num_callbacks = context.callbacks.size();
size_t num_provided_events = context.event_symbols.size();
size_t num_hooks = context.hooks.size();
FileSubHeaderV1 sub_header {
.num_sections = static_cast<uint32_t>(context.sections.size()),
@ -522,6 +549,7 @@ std::vector<uint8_t> N64Recomp::symbols_to_bin_v1(const N64Recomp::Context& cont
.num_exports = static_cast<uint32_t>(num_exported_funcs),
.num_callbacks = static_cast<uint32_t>(num_callbacks),
.num_provided_events = static_cast<uint32_t>(num_provided_events),
.num_hooks = static_cast<uint32_t>(num_hooks),
.string_data_size = 0,
};
@ -757,5 +785,22 @@ std::vector<uint8_t> N64Recomp::symbols_to_bin_v1(const N64Recomp::Context& cont
vec_put(ret, &event_out);
}
// Write the hooks.
for (const FunctionHook& cur_hook : context.hooks) {
uint32_t flags = 0;
if ((cur_hook.flags & HookFlags::AtReturn) == HookFlags::AtReturn) {
flags |= 0x1;
}
HookV1 hook_out {
.func_index = cur_hook.func_index,
.original_section_vrom = cur_hook.original_section_vrom,
.original_vram = cur_hook.original_vram,
.flags = flags
};
vec_put(ret, &hook_out);
}
return ret;
}