Create a search query parser

This commit is contained in:
Kovid Goyal 2022-04-12 19:26:25 +05:30
parent afebea8635
commit 8f92c594f2
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 302 additions and 0 deletions

View File

@ -0,0 +1,275 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import re
from enum import Enum
from functools import lru_cache
from gettext import gettext as _
from typing import (
Callable, Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
TypeVar, Union
)
class ParseException(Exception):

    """Raised when a search expression cannot be parsed."""

    @property
    def msg(self) -> str:
        # The first positional argument, when present, is the error message
        args = self.args
        return str(args[0]) if args else ""
# Kinds of nodes in a parsed search expression tree. Built with the Enum
# functional API; members get auto-numbered values starting at 1, exactly
# matching the explicit 1..4 numbering used previously.
ExpressionType = Enum('ExpressionType', 'OR AND NOT TOKEN')
# Lexical categories produced by the scanner. Functional Enum API with
# auto-numbered values starting at 1, matching the previous explicit values.
TokenType = Enum('TokenType', 'OPCODE WORD QUOTED_WORD EOF')
# Type of the items being searched over (opaque to this module).
T = TypeVar('T')
# Callback that performs the actual matching for a single term:
# get_matches(location, query, candidates) -> subset of candidates that match.
GetMatches = Callable[[str, str, Set[T]], Set[T]]
class SearchTreeNode:
    """Base class for nodes in a parsed search expression tree.

    Calling a node with a set of candidates and a get_matches function
    returns the subset of candidates matched by the (sub)expression rooted
    at that node.
    """
    # Default node kind; subclasses override this class attribute.
    type = ExpressionType.OR

    def __init__(self, type: ExpressionType) -> None:
        # NOTE(review): the subclasses below set lhs/rhs themselves and do
        # not call this __init__ — it appears unused in practice.
        self.type = type

    def search(self, universal_set: Set[T], get_matches: GetMatches[T]) -> Set[T]:
        # Evaluate the whole tree against the universal set of candidates.
        return self(universal_set, get_matches)

    def __call__(self, candidates: Set[T], get_matches: GetMatches[T]) -> Set[T]:
        # Stub; concrete node classes implement the actual evaluation.
        ...

    def iter_token_nodes(self) -> Iterator['TokenNode']:
        # Stub; concrete node classes yield all TokenNode leaves beneath them.
        ...
class OrNode(SearchTreeNode):

    """Matches candidates satisfying either of two sub-expressions."""

    def __init__(self, lhs: SearchTreeNode, rhs: SearchTreeNode) -> None:
        self.lhs = lhs
        self.rhs = rhs

    def __call__(self, candidates: Set[T], get_matches: GetMatches[T]) -> Set[T]:
        # Candidates already matched on the left are excluded from the right
        # hand evaluation, then the two result sets are merged.
        matched = self.lhs(candidates, get_matches)
        return matched | self.rhs(candidates - matched, get_matches)

    def iter_token_nodes(self) -> Iterator['TokenNode']:
        for child in (self.lhs, self.rhs):
            yield from child.iter_token_nodes()
class AndNode(SearchTreeNode):

    """Matches candidates satisfying both sub-expressions."""

    type = ExpressionType.AND

    def __init__(self, lhs: SearchTreeNode, rhs: SearchTreeNode) -> None:
        self.lhs = lhs
        self.rhs = rhs

    def __call__(self, candidates: Set[T], get_matches: GetMatches[T]) -> Set[T]:
        # Feed the left hand matches into the right hand expression, so the
        # right side only ever examines already-matching candidates.
        return self.rhs(self.lhs(candidates, get_matches), get_matches)

    def iter_token_nodes(self) -> Iterator['TokenNode']:
        for child in (self.lhs, self.rhs):
            yield from child.iter_token_nodes()
class NotNode(SearchTreeNode):

    """Matches candidates that do NOT satisfy the sub-expression."""

    type = ExpressionType.NOT

    def __init__(self, rhs: SearchTreeNode) -> None:
        self.rhs = rhs

    def __call__(self, candidates: Set[T], get_matches: GetMatches[T]) -> Set[T]:
        excluded = self.rhs(candidates, get_matches)
        return candidates - excluded

    def iter_token_nodes(self) -> Iterator['TokenNode']:
        yield from self.rhs.iter_token_nodes()
class TokenNode(SearchTreeNode):

    """Leaf node representing a single location:query term."""

    type = ExpressionType.TOKEN

    def __init__(self, location: str, query: str) -> None:
        self.location, self.query = location, query

    def __call__(self, candidates: Set[T], get_matches: GetMatches[T]) -> Set[T]:
        # All actual matching is delegated to the caller supplied function.
        return get_matches(self.location, self.query, candidates)

    def iter_token_nodes(self) -> Iterator['TokenNode']:
        return iter((self,))
class Token(NamedTuple):
    """A single lexical token: its category and (already unescaped) text."""
    type: TokenType
    val: str
# re.Scanner is not part of re's documented public API, hence the getattr()
# — presumably to keep type checkers quiet (TODO confirm). It repeatedly
# applies the first matching pattern at the current position; patterns whose
# action is None (whitespace) are discarded.
lex_scanner = getattr(re, 'Scanner')([
    (r'[()]', lambda x, t: Token(TokenType.OPCODE, t)),
    (r'@.+?:[^")\s]+', lambda x, t: Token(TokenType.WORD, str(t))),
    (r'[^"()\s]+', lambda x, t: Token(TokenType.WORD, str(t))),
    (r'".*?((?<!\\)")', lambda x, t: Token(TokenType.QUOTED_WORD, t[1:-1])),
    (r'\s+', None)
], flags=re.DOTALL)
# Map each backslash-escaped special character (\\ \" \( \)) to a distinct
# control-character placeholder (chr(1)..chr(4)) so the scanner never sees
# the escaped form; Parser.tokenize() substitutes them back afterwards.
REPLACEMENTS = tuple(('\\' + x, chr(i + 1)) for i, x in enumerate('\\"()'))
class Parser:

    """Recursive descent parser for search expressions.

    Grammar, in decreasing order of precedence::

        expr          := and_expr [ 'or' expr ]
        and_expr      := not_expr [ ['and'] and_expr ]
        not_expr      := ['not'] location_expr
        location_expr := '(' expr ')' | base_token
        base_token    := quoted_word | [location ':'] word

    Use parse() as the entry point; all other methods are parser internals.
    """

    def __init__(self, allow_no_location: bool = False) -> None:
        # Index into self.tokens of the token currently being examined.
        self.current_token = 0
        self.tokens: List[Token] = []
        # When True a bare term with no location is searched in the special
        # 'all' location instead of being a parse error.
        self.allow_no_location = allow_no_location

    def token(self, advance: bool = False) -> Optional[str]:
        """Text of the current token, or None at end of input."""
        if self.is_eof():
            return None
        res = self.tokens[self.current_token].val
        if advance:
            self.current_token += 1
        return res

    def lcase_token(self, advance: bool = False) -> Optional[str]:
        """Like token() but lower-cased, for caseless keyword comparison."""
        res = self.token(advance)
        return None if res is None else res.lower()

    def token_type(self) -> TokenType:
        """Type of the current token, or EOF when input is exhausted."""
        if self.is_eof():
            return TokenType.EOF
        return self.tokens[self.current_token].type

    def is_eof(self) -> bool:
        return self.current_token >= len(self.tokens)

    def advance(self) -> None:
        self.current_token += 1

    def tokenize(self, expr: str) -> List[Token]:
        """Lex expr into a list of Tokens, handling backslash escapes."""
        # Strip out escaped backslashes, quotes and parens so that the
        # lex scanner doesn't get confused. We put them back later.
        for k, v in REPLACEMENTS:
            expr = expr.replace(k, v)
        tokens = lex_scanner.scan(expr)[0]

        def unescape(x: str) -> str:
            # Restore the escaped characters, without their backslashes.
            for k, v in REPLACEMENTS:
                x = x.replace(v, k[1:])
            return x

        return [
            Token(tt, unescape(tv) if tt in (TokenType.WORD, TokenType.QUOTED_WORD) else tv)
            for tt, tv in tokens
        ]

    def parse(self, expr: str, locations: Sequence[str]) -> SearchTreeNode:
        """Parse expr recognizing the given location (lookup) names.

        :raises ParseException: when expr is not a valid search expression.
        """
        self.locations = locations
        self.tokens = self.tokenize(expr)
        self.current_token = 0
        prog = self.or_expression()
        if not self.is_eof():
            raise ParseException(_('Extra characters at end of search'))
        return prog

    def or_expression(self) -> SearchTreeNode:
        lhs = self.and_expression()
        if self.lcase_token() == 'or':
            self.advance()
            return OrNode(lhs, self.or_expression())
        return lhs

    def and_expression(self) -> SearchTreeNode:
        lhs = self.not_expression()
        if self.lcase_token() == 'and':
            self.advance()
            return AndNode(lhs, self.and_expression())
        # Account for the optional 'and': two adjacent terms (or a term
        # followed by a parenthesized group) are implicitly AND-ed unless
        # the next token is the 'or' keyword.
        if ((self.token_type() in (TokenType.WORD, TokenType.QUOTED_WORD) or self.token() == '(') and self.lcase_token() != 'or'):
            return AndNode(lhs, self.and_expression())
        return lhs

    def not_expression(self) -> SearchTreeNode:
        if self.lcase_token() == 'not':
            self.advance()
            return NotNode(self.not_expression())
        return self.location_expression()

    def location_expression(self) -> SearchTreeNode:
        if self.token_type() == TokenType.OPCODE and self.token() == '(':
            self.advance()
            res = self.or_expression()
            if self.token_type() != TokenType.OPCODE or self.token(advance=True) != ')':
                raise ParseException(_('missing )'))
            return res
        if self.token_type() not in (TokenType.WORD, TokenType.QUOTED_WORD):
            raise ParseException(_('Invalid syntax. Expected a lookup name or a word'))
        return self.base_token()

    def base_token(self) -> SearchTreeNode:
        # A quoted string is always a literal query, never a location
        # specification, so it must be recognized BEFORE the token is
        # consumed and split on colons. (Checking after consuming would
        # inspect the following token and would let a quoted word such as
        # "id:1" be misparsed as location 'id' with query '1'.)
        if self.token_type() is TokenType.QUOTED_WORD:
            tt = self.token(advance=True)
            assert tt is not None
            if self.allow_no_location:
                return TokenNode('all', tt)
            raise ParseException(_('No location specified before {}').format(tt))
        tt = self.token(advance=True)
        assert tt is not None
        words = tt.split(':')
        # The complexity here comes from having colon-separated search
        # values. That forces us to check that the first "word" in a colon-
        # separated group is a valid location. If not, then the token must
        # be reconstructed. We also have the problem that locations can be
        # followed by quoted strings that appear as the next token. and that
        # tokens can be a sequence of colons.
        # We have a location if there is more than one word and the first
        # word is in locations. This check could produce a "wrong" answer if
        # the search string is something like 'author: "foo"' because it
        # will be interpreted as 'author:"foo"'. I am choosing to accept the
        # possible error. The expression should be written '"author:" foo'
        if len(words) > 1 and words[0].lower() in self.locations:
            loc = words[0].lower()
            words = words[1:]
            if len(words) == 1 and self.token_type() == TokenType.QUOTED_WORD:
                tt = self.token(advance=True)
                assert tt is not None
                return TokenNode(loc, tt)
            # loc is already lower-cased above, no need to lower it again.
            return TokenNode(loc, ':'.join(words))
        if self.allow_no_location:
            return TokenNode('all', ':'.join(words))
        raise ParseException(_('No location specified before {}').format(tt))
@lru_cache(maxsize=64)
def build_tree(query: str, locations: Union[str, Tuple[str, ...]], allow_no_location: bool = False) -> SearchTreeNode:
    """Parse query into a search tree, memoizing recently built trees.

    locations may be given either as a whitespace separated string or as a
    tuple of location names (a hashable form is required for the cache).
    """
    locs = tuple(locations.split()) if isinstance(locations, str) else locations
    try:
        return Parser(allow_no_location).parse(query, locs)
    except RuntimeError as e:
        # Pathologically deep expressions blow the recursion limit, which
        # surfaces as a RuntimeError; report it as a parse failure instead.
        raise ParseException(f'Failed to parse {query!r}, too much recursion required') from e
def search(
    query: str, locations: Union[str, Tuple[str, ...]], universal_set: Set[T], get_matches: GetMatches[T],
    allow_no_location: bool = False,
) -> Set[T]:
    """Convenience wrapper: parse query and evaluate it over universal_set."""
    tree = build_tree(query, locations, allow_no_location)
    return tree.search(universal_set, get_matches)

View File

@ -0,0 +1,27 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
from . import BaseTest
class TestSQP(BaseTest):

    def test_search_query_parser(self):
        """Exercise the search query parser over a trivial integer universe."""
        from kitty.search_query_parser import search
        # A single location named 'id'; candidates are plain integers.
        locations = 'id'
        universal_set = {1, 2, 3, 4, 5}
        def get_matches(location, query, candidates):
            # A candidate matches when the query equals its decimal form.
            return {x for x in candidates if query == str(x)}
        def t(q, expected):
            # Run the query and compare the matched set with expectations.
            actual = search(q, locations, universal_set, get_matches)
            self.ae(actual, expected)
        t('id:1', {1})
        t('id:1 and id:1', {1})
        t('id:1 or id:2', {1, 2})
        t('id:1 and id:2', set())
        t('not id:1', universal_set - {1})
        t('(id:1 or id:2) and id:1', {1})