LibRegex: Generate a search tree when patterns would benefit from it

This takes the previous alternation optimisation and applies it to all
the alternation blocks instead of just the few instructions at the
start.
By generating a trie of instructions, the optimiser consolidates all
logically equivalent instructions into a single node, allowing the
engine to avoid checking the same thing multiple times.
For instance, given the pattern /abc|ac|ab/, this optimisation would
generate the following tree:
    - a
    | - b
    | | - c
    | | | - <accept>
    | | - <accept>
    | - c
    | | - <accept>
which will attempt to match 'a' or 'b' only once, and also limits the
amount of backtracking performed when an alternative fails to match.
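
To make the consolidation concrete, here is a minimal standalone sketch
of the same idea over plain characters instead of bytecode instructions
(the TrieNode type and insert() helper are illustrative, not LibRegex
API):

    #include <map>
    #include <memory>
    #include <string>

    struct TrieNode {
        std::map<char, std::unique_ptr<TrieNode>> children;
        bool accepting { false };
    };

    static void insert(TrieNode& root, std::string const& alternative)
    {
        auto* node = &root;
        for (char c : alternative) {
            // Logically equivalent prefixes land on the same child node,
            // so each character is checked at most once per position.
            auto& child = node->children[c];
            if (!child)
                child = std::make_unique<TrieNode>();
            node = child.get();
        }
        node->accepting = true; // Corresponds to an <accept> leaf above.
    }

    int main()
    {
        TrieNode root;
        // /abc|ac|ab/: "abc" and "ab" share the a -> b path.
        for (auto const* alt : { "abc", "ac", "ab" })
            insert(root, alt);
    }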

This optimisation is currently gated behind a simple cost model that
estimates the number of instructions generated. The model is pessimistic
for small patterns, though the performance difference for such patterns
is not particularly large.
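
The gate boils down to comparing two estimates that appear later in the
diff; condensed into a standalone helper (the function name is
illustrative):

    #include <cstddef>

    // Every trie node that is not shared may need an extra two-entry
    // ForkJump, so the tree layout only pays off once enough nodes are
    // shared between alternatives.
    static bool should_use_tree(size_t total_nodes, size_t common_hits,
        size_t alternative_count)
    {
        size_t tree_cost = (total_nodes - common_hits) * 2;
        size_t chain_cost = total_nodes + alternative_count * 2;
        return common_hits > 0 && tree_cost <= chain_cost;
    }

For /abc|ac|ab/ the shared 'a' and 'b' nodes tip the balance towards
the tree; for a pattern with no shared prefixes, common_hits is zero
and the plain chain of alternatives is kept.
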
Ali Mohammad Pur 2023-07-28 21:02:34 +03:30 committed by Andreas Kling
parent 18f4b6c670
commit 4e69eb89e8
4 changed files with 347 additions and 152 deletions

AK/DisjointChunks.h

@@ -94,6 +94,14 @@ public:
DisjointSpans& operator=(DisjointSpans&&) = default;
DisjointSpans& operator=(DisjointSpans const&) = default;
Span<T> singular_span() const
{
VERIFY(m_spans.size() == 1);
return m_spans[0];
}
SpanContainer const& individual_spans() const { return m_spans; }
bool operator==(DisjointSpans const& other) const
{
if (other.size() != size())
@@ -440,8 +448,22 @@ private:
Vector<ChunkType, InlineCapacity> m_chunks;
};
}
template<typename T>
struct Traits<DisjointSpans<T>> : public GenericTraits<DisjointSpans<T>> {
static unsigned hash(DisjointSpans<T> const& span)
{
unsigned hash = 0;
for (auto const& value : span) {
auto value_hash = Traits<T>::hash(value);
hash = pair_int_hash(hash, value_hash);
}
return hash;
}
constexpr static bool is_trivial() { return false; }
};
}
#if USING_AK_GLOBALLY
using AK::DisjointChunks;
using AK::DisjointSpans;

Tests/LibRegex/Regex.cpp

@@ -1047,6 +1047,8 @@ TEST_CASE(optimizer_alternation)
Array tests {
// Pattern, Subject, Expected length
Tuple { "a|"sv, "a"sv, 1u },
Tuple { "a|a|a|a|a|a|a|a|a|b"sv, "a"sv, 1u },
Tuple { "ab|ac|ad|bc"sv, "bc"sv, 2u },
};
for (auto& test : tests) {

Userland/Libraries/LibRegex/RegexByteCode.h

@@ -185,6 +185,23 @@ public:
Base::first_chunk().prepend(forward<T>(value));
}
void append(Span<ByteCodeValueType const> value)
{
if (is_empty())
Base::append({});
auto& last = Base::last_chunk();
last.ensure_capacity(value.size());
for (auto v : value)
last.unchecked_append(v);
}
void ensure_capacity(size_t capacity)
{
if (is_empty())
Base::append({});
Base::last_chunk().ensure_capacity(capacity);
}
void last_chunk() const = delete;
void first_chunk() const = delete;
@@ -210,20 +227,11 @@ public:
void insert_bytecode_compare_string(StringView view)
{
ByteCode bytecode;
bytecode.empend(static_cast<ByteCodeValueType>(OpCodeId::Compare));
bytecode.empend(static_cast<u64>(1)); // number of arguments
ByteCode arguments;
arguments.empend(static_cast<ByteCodeValueType>(CharacterCompareType::String));
arguments.insert_string(view);
bytecode.empend(arguments.size()); // size of arguments
bytecode.extend(move(arguments));
extend(move(bytecode));
empend(static_cast<ByteCodeValueType>(OpCodeId::Compare));
empend(static_cast<u64>(1)); // number of arguments
empend(2 + view.length()); // size of arguments
empend(static_cast<ByteCodeValueType>(CharacterCompareType::String));
insert_string(view);
}
void insert_bytecode_group_capture_left(size_t capture_groups_count)

Userland/Libraries/LibRegex/RegexOptimizer.cpp

@@ -5,9 +5,12 @@
*/
#include <AK/Debug.h>
#include <AK/Function.h>
#include <AK/Queue.h>
#include <AK/QuickSort.h>
#include <AK/RedBlackTree.h>
#include <AK/Stack.h>
#include <AK/Trie.h>
#include <LibRegex/Regex.h>
#include <LibRegex/RegexBytecodeStreamOptimizer.h>
#include <LibUnicode/CharacterTypes.h>
@@ -815,6 +818,9 @@ void Optimizer::append_alternation(ByteCode& target, ByteCode&& left, ByteCode&&
append_alternation(target, alternatives);
}
template<typename K, typename V, typename KTraits>
using OrderedHashMapForTrie = OrderedHashMap<K, V, KTraits>;
void Optimizer::append_alternation(ByteCode& target, Span<ByteCode> alternatives)
{
if (alternatives.size() == 0)
@@ -846,154 +852,311 @@ void Optimizer::append_alternation(ByteCode& target, Span<ByteCode> alternatives
};
#endif
Vector<Vector<Detail::Block>> basic_blocks;
basic_blocks.ensure_capacity(alternatives.size());
// First, find incoming jump edges.
// We need them for two reasons:
// - We need to distinguish between insn-A-jumped-to-by-insn-B and insn-A-jumped-to-by-insn-C (as otherwise we'd break trie invariants)
// - We need to know which jumps to patch when we're done
for (auto& entry : alternatives)
basic_blocks.append(Regex<PosixBasicParser>::split_basic_blocks(entry));
Optional<size_t> left_skip;
size_t shared_block_count = basic_blocks.first().size();
for (auto& entry : basic_blocks)
shared_block_count = min(shared_block_count, entry.size());
struct JumpEdge {
Span<ByteCodeValueType const> jump_insn;
};
Vector<HashMap<size_t, Vector<JumpEdge>>> incoming_jump_edges_for_each_alternative;
incoming_jump_edges_for_each_alternative.resize(alternatives.size());
MatchState state;
for (size_t block_index = 0; block_index < shared_block_count; block_index++) {
auto& left_block = basic_blocks.first()[block_index];
auto left_end = block_index + 1 == basic_blocks.first().size() ? left_block.end : basic_blocks.first()[block_index + 1].start;
auto can_continue = true;
for (size_t i = 0; i < alternatives.size(); ++i) {
auto& alternative = alternatives[i];
// Add a jump to the "end" of the block; this is implicit in the bytecode, but we need it to be explicit in the trie.
// Jump{offset=0}
alternative.append(static_cast<ByteCodeValueType>(OpCodeId::Jump));
alternative.append(0);
auto& incoming_jump_edges = incoming_jump_edges_for_each_alternative[i];
auto alternative_bytes = alternative.spans<1>().singular_span();
for (state.instruction_position = 0; state.instruction_position < alternative.size();) {
auto& opcode = alternative.get_opcode(state);
auto opcode_bytes = alternative_bytes.slice(state.instruction_position, opcode.size());
switch (opcode.opcode_id()) {
case OpCodeId::Jump:
incoming_jump_edges.ensure(static_cast<OpCode_Jump const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
break;
case OpCodeId::JumpNonEmpty:
incoming_jump_edges.ensure(static_cast<OpCode_JumpNonEmpty const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
break;
case OpCodeId::ForkJump:
incoming_jump_edges.ensure(static_cast<OpCode_ForkJump const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
break;
case OpCodeId::ForkStay:
incoming_jump_edges.ensure(static_cast<OpCode_ForkStay const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
break;
case OpCodeId::ForkReplaceJump:
incoming_jump_edges.ensure(static_cast<OpCode_ForkReplaceJump const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
break;
case OpCodeId::ForkReplaceStay:
incoming_jump_edges.ensure(static_cast<OpCode_ForkReplaceStay const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
break;
case OpCodeId::Repeat:
incoming_jump_edges.ensure(static_cast<OpCode_Repeat const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
break;
default:
break;
}
state.instruction_position += opcode.size();
}
}
struct QualifiedIP {
size_t alternative_index;
size_t instruction_position;
};
using Tree = Trie<DisjointSpans<ByteCodeValueType const>, Vector<QualifiedIP>, Traits<DisjointSpans<ByteCodeValueType const>>, void, OrderedHashMapForTrie>;
Tree trie { {} }; // Root node is empty, key{ instruction_bytes, dependent_instruction_bytes... } -> IP
size_t common_hits = 0;
size_t total_nodes = 0;
size_t total_bytecode_entries_in_tree = 0;
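// Walk each alternative instruction by instruction: alternatives whose
// (instruction bytes, incoming jump bytes) keys match share trie nodes,
// which is what consolidates equivalent instructions.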
for (size_t i = 0; i < alternatives.size(); ++i) {
auto& alternative = alternatives[i];
auto& incoming_jump_edges = incoming_jump_edges_for_each_alternative[i];
auto* active_node = &trie;
auto alternative_span = alternative.spans<1>().singular_span();
for (state.instruction_position = 0; state.instruction_position < alternative_span.size();) {
total_nodes += 1;
auto& opcode = alternative.get_opcode(state);
auto opcode_bytes = alternative_span.slice(state.instruction_position, opcode.size());
Vector<Span<ByteCodeValueType const>> node_key_bytes;
node_key_bytes.append(opcode_bytes);
if (auto edges = incoming_jump_edges.get(state.instruction_position); edges.has_value()) {
for (auto& edge : *edges)
node_key_bytes.append(edge.jump_insn);
}
active_node = static_cast<decltype(active_node)>(MUST(active_node->ensure_child(DisjointSpans<ByteCodeValueType const> { move(node_key_bytes) })));
if (active_node->has_metadata()) {
active_node->metadata_value().append({ i, state.instruction_position });
common_hits += 1;
} else {
active_node->set_metadata(Vector<QualifiedIP> { QualifiedIP { i, state.instruction_position } });
total_bytecode_entries_in_tree += opcode.size();
}
state.instruction_position += opcode.size();
}
}
if constexpr (REGEX_DEBUG) {
Function<void(decltype(trie)&, size_t)> print_tree = [&](decltype(trie)& node, size_t indent = 0) mutable {
DeprecatedString name = "(no ip)";
DeprecatedString insn;
if (node.has_metadata()) {
name = DeprecatedString::formatted(
"{}@{} ({} node{})",
node.metadata_value().first().instruction_position,
node.metadata_value().first().alternative_index,
node.metadata_value().size(),
node.metadata_value().size() == 1 ? "" : "s");
MatchState state;
state.instruction_position = node.metadata_value().first().instruction_position;
auto& opcode = alternatives[node.metadata_value().first().alternative_index].get_opcode(state);
insn = DeprecatedString::formatted("{} {}", opcode.to_deprecated_string(), opcode.arguments_string());
}
dbgln("{:->{}}| {} -- {}", "", indent * 2, name, insn);
for (auto& child : node.children())
print_tree(static_cast<decltype(trie)&>(*child.value), indent + 1);
};
print_tree(trie, 0);
}
// This is really only worth it if we don't blow up the size by the 2-extra-instruction-per-node scheme; similarly, if no nodes are shared, we're better off not using a tree.
auto tree_cost = (total_nodes - common_hits) * 2;
auto chain_cost = total_nodes + alternatives.size() * 2;
dbgln_if(REGEX_DEBUG, "Total nodes: {}, common hits: {} (tree cost = {}, chain cost = {})", total_nodes, common_hits, tree_cost, chain_cost);
if (common_hits == 0 || tree_cost > chain_cost) {
// It's better to lay these out as a normal sequence of instructions.
auto patch_start = target.size();
for (size_t i = 1; i < alternatives.size(); ++i) {
auto& right_blocks = basic_blocks[i];
auto& right_block = right_blocks[block_index];
auto right_end = block_index + 1 == right_blocks.size() ? right_block.end : right_blocks[block_index + 1].start;
if (left_end - left_block.start != right_end - right_block.start) {
can_continue = false;
break;
}
if (alternatives[0].spans().slice(left_block.start, left_end - left_block.start) != alternatives[i].spans().slice(right_block.start, right_end - right_block.start)) {
can_continue = false;
break;
}
}
if (!can_continue)
break;
size_t i = 0;
for (auto& entry : alternatives) {
auto& blocks = basic_blocks[i++];
auto& block = blocks[block_index];
auto end = block_index + 1 == blocks.size() ? block.end : blocks[block_index + 1].start;
state.instruction_position = block.start;
size_t skip = 0;
while (state.instruction_position < end) {
auto& opcode = entry.get_opcode(state);
state.instruction_position += opcode.size();
skip = state.instruction_position;
}
if (left_skip.has_value())
left_skip = min(skip, *left_skip);
else
left_skip = skip;
}
}
// Remove forward jumps as they no longer make sense.
state.instruction_position = 0;
for (size_t i = 0; i < left_skip.value_or(0);) {
auto& opcode = alternatives[0].get_opcode(state);
switch (opcode.opcode_id()) {
case OpCodeId::Jump:
case OpCodeId::ForkJump:
case OpCodeId::JumpNonEmpty:
case OpCodeId::ForkStay:
case OpCodeId::ForkReplaceJump:
case OpCodeId::ForkReplaceStay:
if (opcode.argument(0) + opcode.size() > left_skip.value_or(0)) {
left_skip = i;
goto break_out;
}
break;
default:
break;
}
i += opcode.size();
}
break_out:;
dbgln_if(REGEX_DEBUG, "Skipping {}/{} bytecode entries from {}", left_skip, 0, alternatives[0].size());
if (left_skip.has_value() && *left_skip > 0) {
target.extend(alternatives[0].release_slice(basic_blocks.first().first().start, *left_skip));
auto first = true;
for (auto& entry : alternatives) {
if (first) {
first = false;
continue;
}
entry = entry.release_slice(*left_skip);
}
}
if (all_of(alternatives, [](auto& entry) { return entry.is_empty(); }))
return;
size_t patch_start = target.size();
for (size_t i = 1; i < alternatives.size(); ++i) {
target.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
target.empend(0u); // To be filled later.
}
size_t size_to_jump = 0;
bool seen_one_empty = false;
for (size_t i = alternatives.size(); i > 0; --i) {
auto& entry = alternatives[i - 1];
if (entry.is_empty()) {
if (seen_one_empty)
continue;
seen_one_empty = true;
target.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
target.empend(0u); // To be filled later.
}
auto is_first = i == 1;
auto instruction_size = entry.size() + (is_first ? 0 : 2); // Jump; -> +2
size_to_jump += instruction_size;
if (!is_first)
target[patch_start + (i - 2) * 2 + 1] = size_to_jump + (alternatives.size() - i) * 2;
dbgln_if(REGEX_DEBUG, "{} size = {}, cum={}", i - 1, instruction_size, size_to_jump);
}
seen_one_empty = false;
for (size_t i = alternatives.size(); i > 0; --i) {
auto& chunk = alternatives[i - 1];
if (chunk.is_empty()) {
if (seen_one_empty)
continue;
seen_one_empty = true;
}
ByteCode* previous_chunk = nullptr;
size_t j = i - 1;
auto seen_one_empty_before = chunk.is_empty();
while (j >= 1) {
--j;
auto& candidate_chunk = alternatives[j];
if (candidate_chunk.is_empty()) {
if (seen_one_empty_before)
size_t size_to_jump = 0;
bool seen_one_empty = false;
for (size_t i = alternatives.size(); i > 0; --i) {
auto& entry = alternatives[i - 1];
if (entry.is_empty()) {
if (seen_one_empty)
continue;
seen_one_empty = true;
}
previous_chunk = &candidate_chunk;
break;
auto is_first = i == 1;
auto instruction_size = entry.size() + (is_first ? 0 : 2); // Jump; -> +2
size_to_jump += instruction_size;
if (!is_first)
target[patch_start + (i - 2) * 2 + 1] = size_to_jump + (alternatives.size() - i) * 2;
dbgln_if(REGEX_DEBUG, "{} size = {}, cum={}", i - 1, instruction_size, size_to_jump);
}
size_to_jump -= chunk.size() + (previous_chunk ? 2 : 0);
seen_one_empty = false;
for (size_t i = alternatives.size(); i > 0; --i) {
auto& chunk = alternatives[i - 1];
if (chunk.is_empty()) {
if (seen_one_empty)
continue;
seen_one_empty = true;
}
target.extend(move(chunk));
target.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
target.empend(size_to_jump); // Jump to the _END label
ByteCode* previous_chunk = nullptr;
size_t j = i - 1;
auto seen_one_empty_before = chunk.is_empty();
while (j >= 1) {
--j;
auto& candidate_chunk = alternatives[j];
if (candidate_chunk.is_empty()) {
if (seen_one_empty_before)
continue;
}
previous_chunk = &candidate_chunk;
break;
}
size_to_jump -= chunk.size() + (previous_chunk ? 2 : 0);
target.extend(move(chunk));
target.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
target.empend(size_to_jump); // Jump to the _END label
}
} else {
target.ensure_capacity(total_bytecode_entries_in_tree + common_hits * 6);
auto node_is = [](Tree const* node, QualifiedIP ip) {
if (!node->has_metadata())
return false;
for (auto& node_ip : node->metadata_value()) {
if (node_ip.alternative_index == ip.alternative_index && node_ip.instruction_position == ip.instruction_position)
return true;
}
return false;
};
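// A Patch records a jump whose target hasn't been emitted yet: once the
// node that owns source_ip is emitted, the placeholder at target_ip is
// filled in with the real offset.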
struct Patch {
QualifiedIP source_ip;
size_t target_ip;
bool done { false };
};
Vector<Patch> patch_locations;
patch_locations.ensure_capacity(total_nodes);
auto add_patch_point = [&](Tree const* node, size_t target_ip) {
if (!node->has_metadata())
return;
auto& node_ip = node->metadata_value().first();
patch_locations.append({ node_ip, target_ip });
};
Queue<Tree*> nodes_to_visit;
nodes_to_visit.enqueue(&trie);
// Each node is laid out as:
//   the node's own instruction bytes
//   ForkJump child1
//   ForkJump child2
//   ...
while (!nodes_to_visit.is_empty()) {
auto const* node = nodes_to_visit.dequeue();
for (auto& patch : patch_locations) {
if (!patch.done && node_is(node, patch.source_ip)) {
auto value = static_cast<ByteCodeValueType>(target.size() - patch.target_ip - 1);
target[patch.target_ip] = value;
patch.done = true;
}
}
if (!node->value().individual_spans().is_empty()) {
auto insn_bytes = node->value().individual_spans().first();
target.ensure_capacity(target.size() + insn_bytes.size());
state.instruction_position = target.size();
target.append(insn_bytes);
auto& opcode = target.get_opcode(state);
ssize_t jump_offset;
auto is_jump = true;
auto patch_location = state.instruction_position + 1;
switch (opcode.opcode_id()) {
case OpCodeId::Jump:
jump_offset = static_cast<OpCode_Jump const&>(opcode).offset();
break;
case OpCodeId::JumpNonEmpty:
jump_offset = static_cast<OpCode_JumpNonEmpty const&>(opcode).offset();
break;
case OpCodeId::ForkJump:
jump_offset = static_cast<OpCode_ForkJump const&>(opcode).offset();
break;
case OpCodeId::ForkStay:
jump_offset = static_cast<OpCode_ForkStay const&>(opcode).offset();
break;
case OpCodeId::ForkReplaceJump:
jump_offset = static_cast<OpCode_ForkReplaceJump const&>(opcode).offset();
break;
case OpCodeId::ForkReplaceStay:
jump_offset = static_cast<OpCode_ForkReplaceStay const&>(opcode).offset();
break;
case OpCodeId::Repeat:
jump_offset = static_cast<ssize_t>(0) - static_cast<ssize_t>(static_cast<OpCode_Repeat const&>(opcode).offset());
break;
default:
is_jump = false;
break;
}
if (is_jump) {
VERIFY(node->has_metadata());
auto& ip = node->metadata_value().first();
patch_locations.append({ QualifiedIP { ip.alternative_index, ip.instruction_position + jump_offset + opcode.size() }, patch_location });
}
}
for (auto const& child : node->children()) {
auto* child_node = static_cast<Tree*>(child.value.ptr());
target.append(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
add_patch_point(child_node, target.size());
target.append(static_cast<ByteCodeValueType>(0));
nodes_to_visit.enqueue(child_node);
}
}
for (auto& patch : patch_locations) {
if (patch.done)
continue;
auto& alternative = alternatives[patch.source_ip.alternative_index];
if (patch.source_ip.instruction_position >= alternative.size()) {
// This just wants to jump to the end of the alternative, which is fine.
// Patch it to jump to the end of the target instead.
target[patch.target_ip] = static_cast<ByteCodeValueType>(target.size() - patch.target_ip - 1);
continue;
}
dbgln("Regex Tree / Unpatched jump: {}@{} -> {}@{}",
patch.source_ip.instruction_position,
patch.source_ip.alternative_index,
patch.target_ip,
target[patch.target_ip]);
VERIFY_NOT_REACHED();
}
}
}