diff --git a/src/regex_impl.cc b/src/regex_impl.cc index 2cb945ff7..26b0a7382 100644 --- a/src/regex_impl.cc +++ b/src/regex_impl.cc @@ -29,6 +29,10 @@ struct ParsedRegex SubjectBegin, SubjectEnd, ResetStart, + LookAhead, + LookBehind, + NegativeLookAhead, + NegativeLookBehind, }; struct Quantifier @@ -116,9 +120,9 @@ private: return res; } - AstNodePtr alternative() + AstNodePtr alternative(ParsedRegex::Op op = ParsedRegex::Sequence) { - AstNodePtr res = new_node(ParsedRegex::Sequence); + AstNodePtr res = new_node(op); while (auto node = term()) res->children.push_back(std::move(node)); if (res->children.empty()) @@ -175,8 +179,44 @@ private: case '.': ++m_pos; return new_node(ParsedRegex::AnyChar); case '(': { - ++m_pos; - auto content = disjunction(m_parsed_regex.capture_count++); + auto advance = [&]() { + if (++m_pos == m_regex.end()) + parse_error("unclosed parenthesis"); + return *m_pos; + }; + + AstNodePtr content; + if (advance() == '?') + { + auto c = advance(); + if (c == ':') + content = disjunction(-1); + else if (contains("=!<", c)) + { + bool behind = false; + if (c == '<') + { + advance(); + behind = true; + } + + auto type = *m_pos++; + if (type == '=') + content = alternative(behind ? ParsedRegex::LookBehind + : ParsedRegex::LookAhead); + else if (type == '!') + content = alternative(behind ? ParsedRegex::NegativeLookBehind + : ParsedRegex::NegativeLookAhead); + else + parse_error("invalid disjunction"); + + validate_lookaround(content); + } + else + parse_error("invalid disjunction"); + } + else + content = disjunction(m_parsed_regex.capture_count++); if (at_end() or *m_pos != ')') parse_error("unclosed parenthesis"); @@ -245,7 +285,7 @@ private: if (contains("^$\\.*+?()[]{}|", cp)) // SyntaxCharacter return new_node(ParsedRegex::Literal, cp); - parse_error("unknown atom escape"); + parse_error(format("unknown atom escape '{}'", cp)); } AstNodePtr character_class() @@ -395,6 +435,13 @@ private: StringView{m_pos.base(), m_regex.end()})); } + void validate_lookaround(const AstNodePtr& node) + { + for (auto& child : node->children) + if (child->op != ParsedRegex::Literal) + parse_error("Lookaround can only contain literals"); + } + ParsedRegex m_parsed_regex; StringView m_regex; Iterator m_pos; @@ -406,6 +453,11 @@ private: bool neg; }; + StringView peek(ByteCount count) + { + return StringView{m_pos.base(), m_regex.end()}.substr(0, count); + } + static const CharacterClassEscape character_class_escapes[8]; }; @@ -439,6 +491,10 @@ struct CompiledRegex NotWordBoundary, SubjectBegin, SubjectEnd, + LookAhead, + LookBehind, + NegativeLookAhead, + NegativeLookBehind, }; using Offset = unsigned; @@ -516,6 +572,22 @@ private: break; } + case ParsedRegex::LookAhead: + push_op(CompiledRegex::LookAhead); + push_string(node->children); + break; + case ParsedRegex::LookBehind: + push_op(CompiledRegex::LookBehind); + push_string(node->children, true); + break; + case ParsedRegex::NegativeLookAhead: + push_op(CompiledRegex::NegativeLookAhead); + push_string(node->children); + break; + case ParsedRegex::NegativeLookBehind: + push_op(CompiledRegex::NegativeLookBehind); + push_string(node->children, true); + break; case ParsedRegex::LineStart: push_op(CompiledRegex::LineStart); break; @@ -631,6 +703,20 @@ private: utf8::dump(std::back_inserter(m_program.bytecode), cp); } + void push_string(const Vector& codepoints, bool reversed = false) + { + if (codepoints.size() > 127) + throw runtime_error{"Too long literal string"}; + + push_byte(codepoints.size()); + if (reversed) + for (auto& cp : codepoints | reverse()) + push_codepoint(cp->value); + else + for (auto& cp : codepoints) + push_codepoint(cp->value); + } + CompiledRegex m_program; const ParsedRegex& m_parsed_regex; }; @@ -687,6 +773,27 @@ void dump_regex(const CompiledRegex& program) case CompiledRegex::SubjectEnd: printf("subject end\n"); break; + case CompiledRegex::LookAhead: + case CompiledRegex::NegativeLookAhead: + case CompiledRegex::LookBehind: + case CompiledRegex::NegativeLookBehind: + { + int count = *pos++; + StringView str{pos, pos + count}; + const char* name = nullptr; + if (op == CompiledRegex::LookAhead) + name = "look ahead"; + if (op == CompiledRegex::NegativeLookAhead) + name = "negative look ahead"; + if (op == CompiledRegex::LookBehind) + name = "look behind"; + if (op == CompiledRegex::NegativeLookBehind) + name = "negative look behind"; + + printf("%s (%s)\n", name, (const char*)str.zstr()); + pos += count; + break; + } case CompiledRegex::Match: printf("match\n"); } @@ -783,6 +890,32 @@ struct ThreadedRegexVM if (m_pos != m_end) return StepResult::Failed; break; + case CompiledRegex::LookAhead: + case CompiledRegex::NegativeLookAhead: + { + int count = *thread.inst++; + for (auto it = m_pos; count and it != m_end; ++it, --count) + if (*it != utf8::read(thread.inst)) + break; + if ((op == CompiledRegex::LookAhead and count != 0) or + (op == CompiledRegex::NegativeLookAhead and count == 0)) + return StepResult::Failed; + thread.inst = utf8::advance(thread.inst, prog_end, CharCount{count - 1}); + break; + } + case CompiledRegex::LookBehind: + case CompiledRegex::NegativeLookBehind: + { + int count = *thread.inst++; + for (auto it = m_pos-1; count and it >= m_begin; --it, --count) + if (*it != utf8::read(thread.inst)) + break; + if ((op == CompiledRegex::LookBehind and count != 0) or + (op == CompiledRegex::NegativeLookBehind and count == 0)) + return StepResult::Failed; + thread.inst = utf8::advance(thread.inst, prog_end, CharCount{count - 1}); + break; + } case CompiledRegex::Match: thread.inst = nullptr; return StepResult::Matched; @@ -823,7 +956,7 @@ struct ThreadedRegexVM m_threads.erase(std::remove_if(m_threads.begin(), m_threads.end(), [](const Thread& t) { return t.inst == nullptr; }), m_threads.end()); if (m_threads.empty()) - return false; + return found_match; } // Step remaining threads to see if they match without consuming anything else @@ -1034,6 +1167,30 @@ auto test_regex = UnitTest{[]{ kak_assert(vm.exec("foooo", true, true)); kak_assert(StringView{vm.m_captures[2], vm.m_captures[3]} == "fo"); } + + { + TestVM vm{R"((?=foo).)"}; + kak_assert(vm.exec("barfoo", false, true)); + kak_assert(StringView{vm.m_captures[0], vm.m_captures[1]} == "f"); + } + + { + TestVM vm{R"((?!foo)...)"}; + kak_assert(not vm.exec("foo")); + kak_assert(vm.exec("qux")); + } + + { + TestVM vm{R"(...(?<=foo))"}; + kak_assert(vm.exec("foo")); + kak_assert(not vm.exec("qux")); + } + + { + TestVM vm{R"(...(?