diff --git a/RecompModTool/main.cpp b/RecompModTool/main.cpp index 9fbb7d1..cb29b0c 100644 --- a/RecompModTool/main.cpp +++ b/RecompModTool/main.cpp @@ -573,6 +573,8 @@ N64Recomp::Context build_mod_context(const N64Recomp::Context& input_context, bo bool event_section = cur_section.name == N64Recomp::EventSectionName; bool import_section = cur_section.name.starts_with(N64Recomp::ImportSectionPrefix); bool callback_section = cur_section.name.starts_with(N64Recomp::CallbackSectionPrefix); + bool hook_section = cur_section.name.starts_with(N64Recomp::HookSectionPrefix); + bool hook_return_section = cur_section.name.starts_with(N64Recomp::HookReturnSectionPrefix); // Add the functions from the current input section to the current output section. auto& section_out = ret.sections[output_section_index]; @@ -638,6 +640,42 @@ N64Recomp::Context build_mod_context(const N64Recomp::Context& input_context, bo ); } + if (hook_section || hook_return_section) { + // Get the name of the hooked function. + size_t section_prefix_length = hook_section ? N64Recomp::HookSectionPrefix.size() : N64Recomp::HookReturnSectionPrefix.size(); + std::string hooked_function_name = cur_section.name.substr(section_prefix_length); + + // Find the corresponding symbol in the reference symbols. + N64Recomp::SymbolReference cur_reference; + bool original_func_exists = input_context.find_regular_reference_symbol(hooked_function_name, cur_reference); + + // Check that the function being patched exists in the original reference symbols. + if (!original_func_exists) { + fmt::print(stderr, "Function {} hooks a function ({}) that doesn't exist in the original ROM.\n", cur_func.name, hooked_function_name); + return {}; + } + + // Check that the reference symbol is actually a function. + const auto& reference_symbol = input_context.get_reference_symbol(cur_reference); + if (!reference_symbol.is_function) { + fmt::print(stderr, "Function {0} hooks {1}, but {1} was a variable in the original ROM.\n", cur_func.name, hooked_function_name); + return {}; + } + + uint32_t reference_section_vram = input_context.get_reference_section_vram(reference_symbol.section_index); + uint32_t reference_section_rom = input_context.get_reference_section_rom(reference_symbol.section_index); + + // Add a replacement for this function to the output context. + ret.hooks.emplace_back( + N64Recomp::FunctionHook { + .func_index = (uint32_t)output_func_index, + .original_section_vrom = reference_section_rom, + .original_vram = reference_section_vram + reference_symbol.section_offset, + .flags = hook_return_section ? N64Recomp::HookFlags::AtReturn : N64Recomp::HookFlags{} + } + ); + } + std::string name_out; if (export_section) { diff --git a/include/recompiler/context.h b/include/recompiler/context.h index 7dd70c2..8231d75 100644 --- a/include/recompiler/context.h +++ b/include/recompiler/context.h @@ -85,6 +85,8 @@ namespace N64Recomp { constexpr std::string_view EventSectionName = ".recomp_event"; constexpr std::string_view ImportSectionPrefix = ".recomp_import."; constexpr std::string_view CallbackSectionPrefix = ".recomp_callback."; + constexpr std::string_view HookSectionPrefix = ".recomp_hook."; + constexpr std::string_view HookReturnSectionPrefix = ".recomp_hook_return."; // Special dependency names. constexpr std::string_view DependencySelf = "."; @@ -183,6 +185,19 @@ namespace N64Recomp { ReplacementFlags flags; }; + enum class HookFlags : uint32_t { + AtReturn = 1 << 0, + }; + inline HookFlags operator&(HookFlags lhs, HookFlags rhs) { return HookFlags(uint32_t(lhs) & uint32_t(rhs)); } + inline HookFlags operator|(HookFlags lhs, HookFlags rhs) { return HookFlags(uint32_t(lhs) | uint32_t(rhs)); } + + struct FunctionHook { + uint32_t func_index; + uint32_t original_section_vrom; + uint32_t original_vram; + HookFlags flags; + }; + class Context { private: //// Reference symbols (used for populating relocations for patches) @@ -236,6 +251,8 @@ namespace N64Recomp { std::vector callbacks; // List of symbols from events, which contains the names of events that this context provides. std::vector event_symbols; + // List of hooks, which contains the original function to hook and the function index to call at the hook. + std::vector hooks; // Causes functions to print their name to the console the first time they're called. bool trace_mode; diff --git a/src/config.cpp b/src/config.cpp index a54421a..e90b043 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -201,8 +201,8 @@ std::vector get_instruction_patches(const toml::tab return ret; } -std::vector get_function_hooks(const toml::table* patches_data) { - std::vector ret; +std::vector get_function_hooks(const toml::table* patches_data) { + std::vector ret; // Check if the function hook array exists. const toml::node_view func_hook_data = (*patches_data)["hook"]; @@ -230,7 +230,7 @@ std::vector get_function_hooks(const toml::table* patch throw toml::parse_error("before_vram is not word-aligned", el.source()); } - ret.push_back(N64Recomp::FunctionHook{ + ret.push_back(N64Recomp::FunctionTextHook{ .func_name = func_name.value(), .before_vram = before_vram.has_value() ? (int32_t)before_vram.value() : 0, .text = text.value(), diff --git a/src/config.h b/src/config.h index 0f01a33..536c4cc 100644 --- a/src/config.h +++ b/src/config.h @@ -12,7 +12,7 @@ namespace N64Recomp { uint32_t value; }; - struct FunctionHook { + struct FunctionTextHook { std::string func_name; int32_t before_vram; std::string text; @@ -57,7 +57,7 @@ namespace N64Recomp { std::vector ignored_funcs; std::vector renamed_funcs; std::vector instruction_patches; - std::vector function_hooks; + std::vector function_hooks; std::vector manual_func_sizes; std::vector manual_functions; std::string bss_section_suffix; diff --git a/src/main.cpp b/src/main.cpp index 8a8fe91..e7db2ed 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -536,7 +536,7 @@ int main(int argc, char** argv) { } // Apply any function hooks. - for (const N64Recomp::FunctionHook& patch : config.function_hooks) { + for (const N64Recomp::FunctionTextHook& patch : config.function_hooks) { // Check if the specified function exists. auto func_find = context.functions_by_name.find(patch.func_name); if (func_find == context.functions_by_name.end()) { diff --git a/src/mod_symbols.cpp b/src/mod_symbols.cpp index fcfdead..ab80b04 100644 --- a/src/mod_symbols.cpp +++ b/src/mod_symbols.cpp @@ -16,6 +16,7 @@ struct FileSubHeaderV1 { uint32_t num_exports; uint32_t num_callbacks; uint32_t num_provided_events; + uint32_t num_hooks; uint32_t string_data_size; }; @@ -89,6 +90,13 @@ struct EventV1 { uint32_t name_size; }; +struct HookV1 { + uint32_t func_index; + uint32_t original_section_vrom; + uint32_t original_vram; + uint32_t flags; // end +}; + template const T* reinterpret_data(std::span data, size_t& offset, size_t count = 1) { if (offset + (sizeof(T) * count) > data.size()) { @@ -126,6 +134,7 @@ bool parse_v1(std::span data, const std::unordered_mapnum_exports; size_t num_callbacks = subheader->num_callbacks; size_t num_provided_events = subheader->num_provided_events; + size_t num_hooks = subheader->num_hooks; size_t string_data_size = subheader->string_data_size; if (string_data_size & 0b11) { @@ -147,6 +156,7 @@ bool parse_v1(std::span data, const std::unordered_map(data, offset); @@ -434,6 +444,22 @@ bool parse_v1(std::span data, const std::unordered_map(data, offset, num_hooks); + if (hooks == nullptr) { + printf("Failed to read hooks (count: %zu)\n", num_hooks); + return false; + } + + for (size_t hook_index = 0; hook_index < num_hooks; hook_index++) { + const HookV1& hook_in = hooks[hook_index]; + N64Recomp::FunctionHook& hook_out = mod_context.hooks.emplace_back(); + + hook_out.func_index = hook_in.func_index; + hook_out.original_section_vrom = hook_in.original_section_vrom; + hook_out.original_vram = hook_in.original_vram; + hook_out.flags = static_cast(hook_in.flags); + } + return offset == data.size(); } @@ -512,6 +538,7 @@ std::vector N64Recomp::symbols_to_bin_v1(const N64Recomp::Context& cont size_t num_events = context.event_symbols.size(); size_t num_callbacks = context.callbacks.size(); size_t num_provided_events = context.event_symbols.size(); + size_t num_hooks = context.hooks.size(); FileSubHeaderV1 sub_header { .num_sections = static_cast(context.sections.size()), @@ -522,6 +549,7 @@ std::vector N64Recomp::symbols_to_bin_v1(const N64Recomp::Context& cont .num_exports = static_cast(num_exported_funcs), .num_callbacks = static_cast(num_callbacks), .num_provided_events = static_cast(num_provided_events), + .num_hooks = static_cast(num_hooks), .string_data_size = 0, }; @@ -757,5 +785,22 @@ std::vector N64Recomp::symbols_to_bin_v1(const N64Recomp::Context& cont vec_put(ret, &event_out); } + // Write the hooks. + for (const FunctionHook& cur_hook : context.hooks) { + uint32_t flags = 0; + if ((cur_hook.flags & HookFlags::AtReturn) == HookFlags::AtReturn) { + flags |= 0x1; + } + + HookV1 hook_out { + .func_index = cur_hook.func_index, + .original_section_vrom = cur_hook.original_section_vrom, + .original_vram = cur_hook.original_vram, + .flags = flags + }; + + vec_put(ret, &hook_out); + } + return ret; }