From 4e69eb89e802fee0c25fde13c452b8b8a2cfd39e Mon Sep 17 00:00:00 2001
From: Ali Mohammad Pur <mpfard@serenityos.org>
Date: Fri, 28 Jul 2023 21:02:34 +0330
Subject: [PATCH] LibRegex: Generate a search tree when patterns would benefit from it

This takes the previous alternation optimisation and applies it to all
the alternation blocks instead of just the few instructions at the
start.

By generating a trie of instructions, all logically equivalent
instructions are consolidated into a single node, allowing the engine
to avoid checking the same thing multiple times. For instance, given
the pattern /abc|ac|ab/, this optimisation would generate the
following tree:

- a
| - b
| | - c
| | | - <end>
| | - <end>
| - c
| | - <end>

which will attempt to match 'a' or 'b' only once, and will also limit
the number of backtracks performed should an alternative fail to
match.

This optimisation is currently gated behind a simple cost model that
estimates the number of instructions generated; that model is
pessimistic for small patterns, though the change in performance for
such patterns is not particularly large either way.
---
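[Editor's note, not part of the commit: the tree above is simply a trie keyed
on instructions. As a rough illustration of the shape being built -- a
character-level sketch in plain C++, not LibRegex code; `Node` and `insert`
are hypothetical names -- consider:

    // Character-level trie over the alternatives of /abc|ac|ab/: shared
    // prefixes collapse into a single node, which is why 'a' (and the 'b'
    // of "ab...") is only ever tested once.
    #include <map>
    #include <memory>
    #include <string>

    struct Node {
        std::map<char, std::unique_ptr<Node>> children;
        bool terminal { false }; // an alternative may end here (the "<end>" leaves)
    };

    static void insert(Node& root, std::string const& alternative)
    {
        Node* node = &root;
        for (char c : alternative) {
            auto& child = node->children[c]; // reuses the node for a shared prefix
            if (!child)
                child = std::make_unique<Node>();
            node = child.get();
        }
        node->terminal = true;
    }

    int main()
    {
        Node root;
        for (auto const* alternative : { "abc", "ac", "ab" })
            insert(root, alternative);
        // root -> 'a' -> { 'b' -> { <end>, 'c' -> <end> }, 'c' -> <end> }:
        // the same tree as in the message above.
    }

The actual optimisation below builds this trie over bytecode instructions
(plus their incoming jump edges) rather than characters.]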
 AK/DisjointChunks.h                         |  24 +-
 Tests/LibRegex/Regex.cpp                    |   2 +
 Userland/Libraries/LibRegex/RegexByteCode.h |  36 +-
 .../Libraries/LibRegex/RegexOptimizer.cpp   | 437 ++++++++++++------
 4 files changed, 347 insertions(+), 152 deletions(-)

diff --git a/AK/DisjointChunks.h b/AK/DisjointChunks.h
index 8b6033f9455..f95bdb52256 100644
--- a/AK/DisjointChunks.h
+++ b/AK/DisjointChunks.h
@@ -94,6 +94,14 @@ public:
     DisjointSpans& operator=(DisjointSpans&&) = default;
     DisjointSpans& operator=(DisjointSpans const&) = default;
 
+    Span<T> singular_span() const
+    {
+        VERIFY(m_spans.size() == 1);
+        return m_spans[0];
+    }
+
+    SpanContainer const& individual_spans() const { return m_spans; }
+
     bool operator==(DisjointSpans const& other) const
     {
         if (other.size() != size())
@@ -440,8 +448,22 @@ private:
     Vector<ChunkType> m_chunks;
 };
 
-}
+template<typename T>
+struct Traits<DisjointSpans<T>> : public GenericTraits<DisjointSpans<T>> {
+    static unsigned hash(DisjointSpans<T> const& span)
+    {
+        unsigned hash = 0;
+        for (auto const& value : span) {
+            auto value_hash = Traits<T>::hash(value);
+            hash = pair_int_hash(hash, value_hash);
+        }
+        return hash;
+    }
+    constexpr static bool is_trivial() { return false; }
+};
+
+}
 
 #if USING_AK_GLOBALLY
 using AK::DisjointChunks;
 using AK::DisjointSpans;
diff --git a/Tests/LibRegex/Regex.cpp b/Tests/LibRegex/Regex.cpp
index d88f9cd6a3d..40639842324 100644
--- a/Tests/LibRegex/Regex.cpp
+++ b/Tests/LibRegex/Regex.cpp
@@ -1047,6 +1047,8 @@ TEST_CASE(optimizer_alternation)
     Array tests {
         // Pattern, Subject, Expected length
         Tuple { "a|"sv, "a"sv, 1u },
+        Tuple { "a|a|a|a|a|a|a|a|a|b"sv, "a"sv, 1u },
+        Tuple { "ab|ac|ad|bc"sv, "bc"sv, 2u },
     };
 
     for (auto& test : tests) {
diff --git a/Userland/Libraries/LibRegex/RegexByteCode.h b/Userland/Libraries/LibRegex/RegexByteCode.h
index 38aba05d93a..6dc0b8845ac 100644
--- a/Userland/Libraries/LibRegex/RegexByteCode.h
+++ b/Userland/Libraries/LibRegex/RegexByteCode.h
@@ -185,6 +185,23 @@ public:
         Base::first_chunk().prepend(forward<T>(value));
     }
 
+    void append(Span<ByteCodeValueType const> value)
+    {
+        if (is_empty())
+            Base::append({});
+        auto& last = Base::last_chunk();
+        last.ensure_capacity(value.size());
+        for (auto v : value)
+            last.unchecked_append(v);
+    }
+
+    void ensure_capacity(size_t capacity)
+    {
+        if (is_empty())
+            Base::append({});
+        Base::last_chunk().ensure_capacity(capacity);
+    }
+
     void last_chunk() const = delete;
     void first_chunk() const = delete;
 
@@ -210,20 +227,11 @@ public:
 
     void insert_bytecode_compare_string(StringView view)
     {
-        ByteCode bytecode;
-
-        bytecode.empend(static_cast<ByteCodeValueType>(OpCodeId::Compare));
-        bytecode.empend(static_cast<u64>(1)); // number of arguments
-
-        ByteCode arguments;
-
-        arguments.empend(static_cast<ByteCodeValueType>(CharacterCompareType::String));
-        arguments.insert_string(view);
-
-        bytecode.empend(arguments.size()); // size of arguments
-        bytecode.extend(move(arguments));
-
-        extend(move(bytecode));
+        empend(static_cast<ByteCodeValueType>(OpCodeId::Compare));
+        empend(static_cast<u64>(1)); // number of arguments
+        empend(2 + view.length()); // size of arguments
+        empend(static_cast<ByteCodeValueType>(CharacterCompareType::String));
+        insert_string(view);
     }
 
     void insert_bytecode_group_capture_left(size_t capture_groups_count)
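[Editor's note, not part of the commit: the rewritten
insert_bytecode_compare_string above emits the same flat encoding the old
two-buffer version produced, minus the intermediate ByteCode copies. A sketch
of that layout, assuming one 64-bit bytecode word per value and an
insert_string() that writes the length followed by one word per character
(the enum values below are made up, not LibRegex's):

    #include <cassert>
    #include <cstdint>
    #include <string_view>
    #include <vector>

    enum : uint64_t { Compare = 1, String = 2 }; // hypothetical encodings

    static std::vector<uint64_t> compare_string(std::string_view view)
    {
        std::vector<uint64_t> code;
        code.push_back(Compare);           // opcode
        code.push_back(1);                 // number of arguments
        code.push_back(2 + view.length()); // size of arguments, in words
        code.push_back(String);            // the single argument is an inline string:
        code.push_back(view.length());     //   its length,
        for (char c : view)                //   then one word per character
            code.push_back(static_cast<uint64_t>(c));
        return code;
    }

    int main()
    {
        auto code = compare_string("abc");
        // 3 header words plus "size of arguments" words span the whole
        // instruction, so the interpreter can skip it in one step.
        assert(code.size() == 3 + code[2]);
    }

The "2 + view.length()" matches the old version, where the argument buffer
held the CharacterCompareType word plus whatever insert_string() appended.]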
diff --git a/Userland/Libraries/LibRegex/RegexOptimizer.cpp b/Userland/Libraries/LibRegex/RegexOptimizer.cpp
index e765902464e..b12f3034fdd 100644
--- a/Userland/Libraries/LibRegex/RegexOptimizer.cpp
+++ b/Userland/Libraries/LibRegex/RegexOptimizer.cpp
@@ -5,9 +5,12 @@
  */
 
 #include <AK/Debug.h>
+#include <AK/DisjointChunks.h>
+#include <AK/Queue.h>
 #include <AK/QuickSort.h>
 #include <AK/RedBlackTree.h>
 #include <AK/Stack.h>
+#include <AK/Trie.h>
 #include <LibRegex/Regex.h>
 #include <LibRegex/RegexBytecodeStreamOptimizer.h>
 #include <LibUnicode/CharacterTypes.h>
@@ -815,6 +818,9 @@ void Optimizer::append_alternation(ByteCode& target, ByteCode&& left, ByteCode&& right)
     append_alternation(target, alternatives);
 }
 
+template<typename K, typename V, typename KTraits>
+using OrderedHashMapForTrie = OrderedHashMap<K, V, KTraits>;
+
 void Optimizer::append_alternation(ByteCode& target, Span<ByteCode> alternatives)
 {
     if (alternatives.size() == 0)
@@ -846,154 +852,311 @@ void Optimizer::append_alternation(ByteCode& target, Span<ByteCode> alternatives)
     };
 #endif
 
-    Vector<Vector<Detail::Block>> basic_blocks;
-    basic_blocks.ensure_capacity(alternatives.size());
-
-    for (auto& entry : alternatives)
-        basic_blocks.append(Regex<PosixBasicParser>::split_basic_blocks(entry));
-
-    Optional<size_t> left_skip;
-    size_t shared_block_count = basic_blocks.first().size();
-    for (auto& entry : basic_blocks)
-        shared_block_count = min(shared_block_count, entry.size());
-
-    MatchState state;
-    for (size_t block_index = 0; block_index < shared_block_count; block_index++) {
-        auto& left_block = basic_blocks.first()[block_index];
-        auto left_end = block_index + 1 == basic_blocks.first().size() ? left_block.end : basic_blocks.first()[block_index + 1].start;
-        auto can_continue = true;
-        for (size_t i = 1; i < alternatives.size(); ++i) {
-            auto& right_blocks = basic_blocks[i];
-            auto& right_block = right_blocks[block_index];
-            auto right_end = block_index + 1 == right_blocks.size() ? right_block.end : right_blocks[block_index + 1].start;
-
-            if (left_end - left_block.start != right_end - right_block.start) {
-                can_continue = false;
-                break;
-            }
-
-            if (alternatives[0].spans().slice(left_block.start, left_end - left_block.start) != alternatives[i].spans().slice(right_block.start, right_end - right_block.start)) {
-                can_continue = false;
-                break;
-            }
-        }
-        if (!can_continue)
-            break;
-
-        size_t i = 0;
-        for (auto& entry : alternatives) {
-            auto& blocks = basic_blocks[i++];
-            auto& block = blocks[block_index];
-            auto end = block_index + 1 == blocks.size() ? block.end : blocks[block_index + 1].start;
-            state.instruction_position = block.start;
-            size_t skip = 0;
-            while (state.instruction_position < end) {
-                auto& opcode = entry.get_opcode(state);
-                state.instruction_position += opcode.size();
-                skip = state.instruction_position;
-            }
-
-            if (left_skip.has_value())
-                left_skip = min(skip, *left_skip);
-            else
-                left_skip = skip;
-        }
-    }
-
-    // Remove forward jumps as they no longer make sense.
-    state.instruction_position = 0;
-    for (size_t i = 0; i < left_skip.value_or(0);) {
-        auto& opcode = alternatives[0].get_opcode(state);
-        switch (opcode.opcode_id()) {
-        case OpCodeId::Jump:
-        case OpCodeId::ForkJump:
-        case OpCodeId::JumpNonEmpty:
-        case OpCodeId::ForkStay:
-        case OpCodeId::ForkReplaceJump:
-        case OpCodeId::ForkReplaceStay:
-            if (opcode.argument(0) + opcode.size() > left_skip.value_or(0)) {
-                left_skip = i;
-                goto break_out;
-            }
-            break;
-        default:
-            break;
-        }
-        i += opcode.size();
-    }
-break_out:;
-
-    dbgln_if(REGEX_DEBUG, "Skipping {}/{} bytecode entries from {}", left_skip, 0, alternatives[0].size());
-
-    if (left_skip.has_value() && *left_skip > 0) {
-        target.extend(alternatives[0].release_slice(basic_blocks.first().first().start, *left_skip));
-        auto first = true;
-        for (auto& entry : alternatives) {
-            if (first) {
-                first = false;
-                continue;
-            }
-            entry = entry.release_slice(*left_skip);
-        }
-    }
-
-    if (all_of(alternatives, [](auto& entry) { return entry.is_empty(); }))
-        return;
-
-    size_t patch_start = target.size();
-    for (size_t i = 1; i < alternatives.size(); ++i) {
-        target.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
-        target.empend(0u); // To be filled later.
-    }
-
-    size_t size_to_jump = 0;
-    bool seen_one_empty = false;
-    for (size_t i = alternatives.size(); i > 0; --i) {
-        auto& entry = alternatives[i - 1];
-        if (entry.is_empty()) {
-            if (seen_one_empty)
-                continue;
-            seen_one_empty = true;
-        }
-
-        auto is_first = i == 1;
-        auto instruction_size = entry.size() + (is_first ? 0 : 2); // Jump; -> +2
-        size_to_jump += instruction_size;
-
-        if (!is_first)
-            target[patch_start + (i - 2) * 2 + 1] = size_to_jump + (alternatives.size() - i) * 2;
-
-        dbgln_if(REGEX_DEBUG, "{} size = {}, cum={}", i - 1, instruction_size, size_to_jump);
-    }
-
-    seen_one_empty = false;
-    for (size_t i = alternatives.size(); i > 0; --i) {
-        auto& chunk = alternatives[i - 1];
-        if (chunk.is_empty()) {
-            if (seen_one_empty)
-                continue;
-            seen_one_empty = true;
-        }
-
-        ByteCode* previous_chunk = nullptr;
-        size_t j = i - 1;
-        auto seen_one_empty_before = chunk.is_empty();
-        while (j >= 1) {
-            --j;
-            auto& candidate_chunk = alternatives[j];
-            if (candidate_chunk.is_empty()) {
-                if (seen_one_empty_before)
-                    continue;
-            }
-            previous_chunk = &candidate_chunk;
-            break;
-        }
-
-        size_to_jump -= chunk.size() + (previous_chunk ? 2 : 0);
-
-        target.extend(move(chunk));
-        target.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
-        target.empend(size_to_jump); // Jump to the _END label
-    }
+    // First, find incoming jump edges.
+    // We need them for two reasons:
+    // - We need to distinguish between insn-A-jumped-to-by-insn-B and insn-A-jumped-to-by-insn-C (as otherwise we'd break trie invariants)
+    // - We need to know which jumps to patch when we're done
+
+    struct JumpEdge {
+        Span<ByteCodeValueType const> jump_insn;
+    };
+    Vector<HashMap<size_t, Vector<JumpEdge>>> incoming_jump_edges_for_each_alternative;
+    incoming_jump_edges_for_each_alternative.resize(alternatives.size());
+
+    MatchState state;
+
+    for (size_t i = 0; i < alternatives.size(); ++i) {
+        auto& alternative = alternatives[i];
+        // Add a jump to the "end" of the block; this is implicit in the bytecode, but we need it to be explicit in the trie.
+        // Jump{offset=0}
+        alternative.append(static_cast<ByteCodeValueType>(OpCodeId::Jump));
+        alternative.append(0);
+
+        auto& incoming_jump_edges = incoming_jump_edges_for_each_alternative[i];
+
+        auto alternative_bytes = alternative.spans<1>().singular_span();
+        for (state.instruction_position = 0; state.instruction_position < alternative.size();) {
+            auto& opcode = alternative.get_opcode(state);
+            auto opcode_bytes = alternative_bytes.slice(state.instruction_position, opcode.size());
+
+            switch (opcode.opcode_id()) {
+            case OpCodeId::Jump:
+                incoming_jump_edges.ensure(static_cast<OpCode_Jump const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::JumpNonEmpty:
+                incoming_jump_edges.ensure(static_cast<OpCode_JumpNonEmpty const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::ForkJump:
+                incoming_jump_edges.ensure(static_cast<OpCode_ForkJump const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::ForkStay:
+                incoming_jump_edges.ensure(static_cast<OpCode_ForkStay const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::ForkReplaceJump:
+                incoming_jump_edges.ensure(static_cast<OpCode_ForkReplaceJump const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::ForkReplaceStay:
+                incoming_jump_edges.ensure(static_cast<OpCode_ForkReplaceStay const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::Repeat:
+                incoming_jump_edges.ensure(static_cast<OpCode_Repeat const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            default:
+                break;
+            }
+            state.instruction_position += opcode.size();
+        }
+    }
+
+    struct QualifiedIP {
+        size_t alternative_index;
+        size_t instruction_position;
+    };
+    using Tree = Trie<DisjointSpans<ByteCodeValueType const>, Vector<QualifiedIP>, Traits<DisjointSpans<ByteCodeValueType const>>, void, OrderedHashMapForTrie>;
+    Tree trie { {} }; // Root node is empty; key { instruction_bytes, dependent_instruction_bytes... } -> IP
+
+    size_t common_hits = 0;
+    size_t total_nodes = 0;
+    size_t total_bytecode_entries_in_tree = 0;
+    for (size_t i = 0; i < alternatives.size(); ++i) {
+        auto& alternative = alternatives[i];
+        auto& incoming_jump_edges = incoming_jump_edges_for_each_alternative[i];
+
+        auto* active_node = &trie;
+        auto alternative_span = alternative.spans<1>().singular_span();
+        for (state.instruction_position = 0; state.instruction_position < alternative_span.size();) {
+            total_nodes += 1;
+            auto& opcode = alternative.get_opcode(state);
+            auto opcode_bytes = alternative_span.slice(state.instruction_position, opcode.size());
+            Vector<Span<ByteCodeValueType const>> node_key_bytes;
+            node_key_bytes.append(opcode_bytes);
+
+            if (auto edges = incoming_jump_edges.get(state.instruction_position); edges.has_value()) {
+                for (auto& edge : *edges)
+                    node_key_bytes.append(edge.jump_insn);
+            }
+
+            active_node = static_cast<Tree*>(MUST(active_node->ensure_child(DisjointSpans<ByteCodeValueType const> { move(node_key_bytes) })));
+
+            if (active_node->has_metadata()) {
+                active_node->metadata_value().append({ i, state.instruction_position });
+                common_hits += 1;
+            } else {
+                active_node->set_metadata(Vector<QualifiedIP> { QualifiedIP { i, state.instruction_position } });
+                total_bytecode_entries_in_tree += opcode.size();
+            }
+            state.instruction_position += opcode.size();
+        }
+    }
+
+    if constexpr (REGEX_DEBUG) {
+        Function<void(decltype(trie)&, size_t)> print_tree = [&](decltype(trie)& node, size_t indent = 0) mutable {
+            DeprecatedString name = "(no ip)";
+            DeprecatedString insn;
+            if (node.has_metadata()) {
+                name = DeprecatedString::formatted(
+                    "{}@{} ({} node{})",
+                    node.metadata_value().first().instruction_position,
+                    node.metadata_value().first().alternative_index,
+                    node.metadata_value().size(),
+                    node.metadata_value().size() == 1 ? "" : "s");
+
+                MatchState state;
+                state.instruction_position = node.metadata_value().first().instruction_position;
+                auto& opcode = alternatives[node.metadata_value().first().alternative_index].get_opcode(state);
+                insn = DeprecatedString::formatted("{} {}", opcode.to_deprecated_string(), opcode.arguments_string());
+            }
+            dbgln("{:->{}}| {} -- {}", "", indent * 2, name, insn);
+            for (auto& child : node.children())
+                print_tree(static_cast<decltype(trie)&>(*child.value), indent + 1);
+        };
+
+        print_tree(trie, 0);
+    }
+
+    // This is really only worth it if we don't blow up the size with the 2-extra-instructions-per-node scheme;
+    // similarly, if no nodes are shared, we're better off not using a tree.
+    auto tree_cost = (total_nodes - common_hits) * 2;
+    auto chain_cost = total_nodes + alternatives.size() * 2;
+    dbgln_if(REGEX_DEBUG, "Total nodes: {}, common hits: {} (tree cost = {}, chain cost = {})", total_nodes, common_hits, tree_cost, chain_cost);
+
+    if (common_hits == 0 || tree_cost > chain_cost) {
+        // It's better to lay these out as a normal sequence of instructions.
+        auto patch_start = target.size();
+        for (size_t i = 1; i < alternatives.size(); ++i) {
+            target.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
+            target.empend(0u); // To be filled later.
+        }
+
+        size_t size_to_jump = 0;
+        bool seen_one_empty = false;
+        for (size_t i = alternatives.size(); i > 0; --i) {
+            auto& entry = alternatives[i - 1];
+            if (entry.is_empty()) {
+                if (seen_one_empty)
+                    continue;
+                seen_one_empty = true;
+            }
+
+            auto is_first = i == 1;
+            auto instruction_size = entry.size() + (is_first ? 0 : 2); // Jump; -> +2
+            size_to_jump += instruction_size;
+
+            if (!is_first)
+                target[patch_start + (i - 2) * 2 + 1] = size_to_jump + (alternatives.size() - i) * 2;
+
+            dbgln_if(REGEX_DEBUG, "{} size = {}, cum={}", i - 1, instruction_size, size_to_jump);
+        }
+
+        seen_one_empty = false;
+        for (size_t i = alternatives.size(); i > 0; --i) {
+            auto& chunk = alternatives[i - 1];
+            if (chunk.is_empty()) {
+                if (seen_one_empty)
+                    continue;
+                seen_one_empty = true;
+            }
+
+            ByteCode* previous_chunk = nullptr;
+            size_t j = i - 1;
+            auto seen_one_empty_before = chunk.is_empty();
+            while (j >= 1) {
+                --j;
+                auto& candidate_chunk = alternatives[j];
+                if (candidate_chunk.is_empty()) {
+                    if (seen_one_empty_before)
+                        continue;
+                }
+                previous_chunk = &candidate_chunk;
+                break;
+            }
+
+            size_to_jump -= chunk.size() + (previous_chunk ? 2 : 0);
+
+            target.extend(move(chunk));
+            target.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
+            target.empend(size_to_jump); // Jump to the _END label
+        }
+    } else {
+        target.ensure_capacity(total_bytecode_entries_in_tree + common_hits * 6);
+
+        auto node_is = [](Tree const* node, QualifiedIP ip) {
+            if (!node->has_metadata())
+                return false;
+            for (auto& node_ip : node->metadata_value()) {
+                if (node_ip.alternative_index == ip.alternative_index && node_ip.instruction_position == ip.instruction_position)
+                    return true;
+            }
+            return false;
+        };
+
+        struct Patch {
+            QualifiedIP source_ip;
+            size_t target_ip;
+            bool done { false };
+        };
+        Vector<Patch> patch_locations;
+        patch_locations.ensure_capacity(total_nodes);
+
+        auto add_patch_point = [&](Tree const* node, size_t target_ip) {
+            if (!node->has_metadata())
+                return;
+            auto& node_ip = node->metadata_value().first();
+            patch_locations.append({ node_ip, target_ip });
+        };
+
+        Queue<Tree*> nodes_to_visit;
+        nodes_to_visit.enqueue(&trie);
+
+        // Emit each node as:
+        //     the node's own instruction
+        //     forkjump child1
+        //     forkjump child2
+        //     ...
+        while (!nodes_to_visit.is_empty()) {
+            auto const* node = nodes_to_visit.dequeue();
+            for (auto& patch : patch_locations) {
+                if (!patch.done && node_is(node, patch.source_ip)) {
+                    auto value = static_cast<ByteCodeValueType>(target.size() - patch.target_ip - 1);
+                    target[patch.target_ip] = value;
+                    patch.done = true;
+                }
+            }
+
+            if (!node->value().individual_spans().is_empty()) {
+                auto insn_bytes = node->value().individual_spans().first();
+
+                target.ensure_capacity(target.size() + insn_bytes.size());
+                state.instruction_position = target.size();
+                target.append(insn_bytes);
+
+                auto& opcode = target.get_opcode(state);
+
+                ssize_t jump_offset;
+                auto is_jump = true;
+                auto patch_location = state.instruction_position + 1;
+
+                switch (opcode.opcode_id()) {
+                case OpCodeId::Jump:
+                    jump_offset = static_cast<OpCode_Jump const&>(opcode).offset();
+                    break;
+                case OpCodeId::JumpNonEmpty:
+                    jump_offset = static_cast<OpCode_JumpNonEmpty const&>(opcode).offset();
+                    break;
+                case OpCodeId::ForkJump:
+                    jump_offset = static_cast<OpCode_ForkJump const&>(opcode).offset();
+                    break;
+                case OpCodeId::ForkStay:
+                    jump_offset = static_cast<OpCode_ForkStay const&>(opcode).offset();
+                    break;
+                case OpCodeId::ForkReplaceJump:
+                    jump_offset = static_cast<OpCode_ForkReplaceJump const&>(opcode).offset();
+                    break;
+                case OpCodeId::ForkReplaceStay:
+                    jump_offset = static_cast<OpCode_ForkReplaceStay const&>(opcode).offset();
+                    break;
+                case OpCodeId::Repeat:
+                    jump_offset = static_cast<ssize_t>(0) - static_cast<ssize_t>(static_cast<OpCode_Repeat const&>(opcode).offset());
+                    break;
+                default:
+                    is_jump = false;
+                    break;
+                }
+
+                if (is_jump) {
+                    VERIFY(node->has_metadata());
+                    auto& ip = node->metadata_value().first();
+                    patch_locations.append({ QualifiedIP { ip.alternative_index, ip.instruction_position + jump_offset + opcode.size() }, patch_location });
+                }
+            }
+
+            for (auto const& child : node->children()) {
+                auto* child_node = static_cast<Tree*>(child.value.ptr());
+                target.append(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
+                add_patch_point(child_node, target.size());
+                target.append(static_cast<ByteCodeValueType>(0));
+                nodes_to_visit.enqueue(child_node);
+            }
+        }
+
+        for (auto& patch : patch_locations) {
+            if (patch.done)
+                continue;
+
+            auto& alternative = alternatives[patch.source_ip.alternative_index];
+            if (patch.source_ip.instruction_position >= alternative.size()) {
+                // This just wants to jump to the end of the alternative, which is fine.
+                // Patch it to jump to the end of the target instead.
+                target[patch.target_ip] = static_cast<ByteCodeValueType>(target.size() - patch.target_ip - 1);
+                continue;
+            }
+
+            dbgln("Regex Tree / Unpatched jump: {}@{} -> {}@{}",
+                patch.source_ip.instruction_position,
+                patch.source_ip.alternative_index,
+                patch.target_ip,
+                target[patch.target_ip]);
+            VERIFY_NOT_REACHED();
+        }
+    }
 }
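[Editor's note, not part of the patch: the cost-model gate can be checked by
hand. Below is a paraphrase of the comparison in append_alternation above,
with made-up numbers for illustration; it is a sketch, not the LibRegex
source:

    #include <cstddef>
    #include <cstdio>

    // Mirrors the commit's gate: every non-shared trie node may need an
    // extra ForkJump + offset pair (2 words), while the flat chain spends
    // one ForkJump/Jump pair (2 words) per alternative.
    static bool should_use_tree(size_t total_nodes, size_t common_hits, size_t alternative_count)
    {
        auto tree_cost = (total_nodes - common_hits) * 2;
        auto chain_cost = total_nodes + alternative_count * 2;
        return common_hits != 0 && tree_cost <= chain_cost;
    }

    int main()
    {
        // Suppose /abc|ac|ab/ expands to 12 instructions across its 3
        // alternatives, 5 of which land on shared trie nodes:
        //   tree  = (12 - 5) * 2 = 14
        //   chain = 12 + 3 * 2   = 18
        // so the tree layout wins; with no shared nodes (common_hits == 0)
        // the plain chain is kept, as the commit message's last paragraph
        // explains.
        std::printf("%d\n", should_use_tree(12, 5, 3));
    }
]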