add address counts to tokens

Sarah Hoffmann 2024-03-15 10:54:13 +01:00
parent bb5de9b955
commit 07b7fd1dbb
10 changed files with 32 additions and 21 deletions
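In short: query tokens gain an addr_count field next to the existing count, filled from the addr_count entry of the word table's info column, and update_statistics() accepts a new threads parameter. Below is a minimal sketch of the resulting token layout, simplified from the Token hunk further down; it is not the literal class from the source, and the comments on the two counters reflect how the updated statistics test reads name_vector and nameaddress_vector.

    import dataclasses


    @dataclasses.dataclass
    class Token:
        """ Simplified stand-in for the query Token type after this commit:
            addr_count sits between count and lookup_word and has no default,
            so every constructor call now has to supply it.
        """
        penalty: float
        token: int
        count: int       # occurrences of the word in search_name.name_vector
        addr_count: int  # occurrences of the word in search_name.nameaddress_vector
        lookup_word: str
        is_indexed: bool


    # Positional construction therefore changes from
    #     Token(0.5, 1, 10, 'foo', True)
    # to
    #     Token(0.5, 1, 10, 1, 'foo', True)
    t = Token(0.5, 1, 10, 1, 'foo', True)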

@@ -97,6 +97,7 @@ class ICUToken(qmod.Token):
         """ Create a ICUToken from the row of the word table.
         """
         count = 1 if row.info is None else row.info.get('count', 1)
+        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
 
         penalty = 0.0
         if row.type == 'w':
@@ -123,7 +124,8 @@
         return ICUToken(penalty=penalty, token=row.word_id, count=count,
                         lookup_word=lookup_word, is_indexed=True,
-                        word_token=row.word_token, info=row.info)
+                        word_token=row.word_token, info=row.info,
+                        addr_count=addr_count)
@@ -257,7 +259,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
             if len(part.token) <= 4 and part[0].isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
-                                ICUToken(0.5, 0, 1, part.token, True, part.token, None))
+                                ICUToken(0.5, 0, 1, 1, part.token, True, part.token, None))
 
     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
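Both counters fall back to 1 when a word row carries no statistics in its info column yet. A small standalone illustration of the fallback used in ICUToken.from_db_row above (Row is a stand-in for the word-table row, not a Nominatim class):

    class Row:
        """ Stand-in for a word-table row; info mirrors the JSON 'info' column. """
        def __init__(self, info):
            self.info = info

    def counts_from_row(row):
        # same fallback as in from_db_row: missing info or a missing key yields 1
        count = 1 if row.info is None else row.info.get('count', 1)
        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
        return count, addr_count

    assert counts_from_row(Row(None)) == (1, 1)
    assert counts_from_row(Row({'count': 42})) == (42, 1)
    assert counts_from_row(Row({'count': 42, 'addr_count': 7})) == (42, 7)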

@@ -210,6 +210,7 @@ class LegacyQueryAnalyzer(AbstractQueryAnalyzer):
         return LegacyToken(penalty=penalty, token=row.word_id,
                            count=row.search_name_count or 1,
+                           addr_count=1, # not supported
                            lookup_word=lookup_word,
                            word_token=row.word_token.strip(),
                            category=(rowclass, row.type) if rowclass is not None else None,
@@ -226,7 +227,7 @@
             if len(part) <= 4 and part.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
-                                LegacyToken(penalty=0.5, token=0, count=1,
+                                LegacyToken(penalty=0.5, token=0, count=1, addr_count=1,
                                             lookup_word=part, word_token=part,
                                             category=None, country=None,
                                             operator=None, is_indexed=True))

@@ -99,10 +99,10 @@ class Token(ABC):
     penalty: float
     token: int
     count: int
+    addr_count: int
     lookup_word: str
     is_indexed: bool
-    addr_count: int = 1
 
     @abstractmethod
     def get_category(self) -> Tuple[str, str]:

@@ -201,7 +201,7 @@ class AbstractTokenizer(ABC):
     @abstractmethod
-    def update_statistics(self, config: Configuration) -> None:
+    def update_statistics(self, config: Configuration, threads: int = 1) -> None:
         """ Recompute any tokenizer statistics necessary for efficient lookup.
             This function is meant to be called from time to time by the user
             to improve performance. However, the tokenizer must not depend on
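The abstract signature now carries an optional thread count. A standalone sketch of an implementation of the new signature (AbstractTokenizer is re-declared here in simplified form and DemoTokenizer is hypothetical; the body is illustrative, not taken from this commit):

    from abc import ABC, abstractmethod

    class AbstractTokenizer(ABC):            # simplified stand-in for the real base class
        @abstractmethod
        def update_statistics(self, config, threads: int = 1) -> None:
            """ Recompute any tokenizer statistics necessary for efficient lookup. """

    class DemoTokenizer(AbstractTokenizer):  # hypothetical implementation
        def update_statistics(self, config, threads: int = 1) -> None:
            # a real tokenizer would recompute the count/addr_count statistics
            # in the word table here, possibly spread over several threads
            print(f"recomputing word statistics with {threads} thread(s)")

    DemoTokenizer().update_statistics(config=None)             # old single-argument style
    DemoTokenizer().update_statistics(config=None, threads=4)  # new optional parallelism

Because threads defaults to 1, callers that pass only a configuration keep working unchanged.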

@@ -210,7 +210,7 @@ class LegacyTokenizer(AbstractTokenizer):
             self._save_config(conn, config)
 
-    def update_statistics(self, _: Configuration) -> None:
+    def update_statistics(self, config: Configuration, threads: int = 1) -> None:
         """ Recompute the frequency of full words.
         """
         with connect(self.dsn) as conn:

@@ -18,7 +18,8 @@ class MyToken(query.Token):
 def mktoken(tid: int):
-    return MyToken(3.0, tid, 1, 'foo', True)
+    return MyToken(penalty=3.0, token=tid, count=1, addr_count=1,
+                   lookup_word='foo', is_indexed=True)
 
 @pytest.mark.parametrize('ptype,ttype', [('NONE', 'WORD'),

@@ -31,7 +31,9 @@ def make_query(*args):
         for end, ttype, tinfo in tlist:
             for tid, word in tinfo:
                 q.add_token(TokenRange(start, end), ttype,
-                            MyToken(0.5 if ttype == TokenType.PARTIAL else 0.0, tid, 1, word, True))
+                            MyToken(penalty=0.5 if ttype == TokenType.PARTIAL else 0.0,
+                                    token=tid, count=1, addr_count=1,
+                                    lookup_word=word, is_indexed=True))
 
     return q
@@ -395,14 +397,14 @@ def make_counted_searches(name_part, name_full, address_part, address_full,
     q.add_node(BreakType.END, PhraseType.NONE)
 
     q.add_token(TokenRange(0, 1), TokenType.PARTIAL,
-                MyToken(0.5, 1, name_part, 'name_part', True))
+                MyToken(0.5, 1, name_part, 1, 'name_part', True))
     q.add_token(TokenRange(0, 1), TokenType.WORD,
-                MyToken(0, 101, name_full, 'name_full', True))
+                MyToken(0, 101, name_full, 1, 'name_full', True))
     for i in range(num_address_parts):
         q.add_token(TokenRange(i + 1, i + 2), TokenType.PARTIAL,
-                    MyToken(0.5, 2, address_part, 'address_part', True))
+                    MyToken(0.5, 2, address_part, 1, 'address_part', True))
         q.add_token(TokenRange(i + 1, i + 2), TokenType.WORD,
-                    MyToken(0, 102, address_full, 'address_full', True))
+                    MyToken(0, 102, address_full, 1, 'address_full', True))
 
     builder = SearchBuilder(q, SearchDetails())

@@ -19,7 +19,8 @@ class MyToken(Token):
 def make_query(*args):
     q = QueryStruct([Phrase(args[0][1], '')])
-    dummy = MyToken(3.0, 45, 1, 'foo', True)
+    dummy = MyToken(penalty=3.0, token=45, count=1, addr_count=1,
+                    lookup_word='foo', is_indexed=True)
 
     for btype, ptype, _ in args[1:]:
         q.add_node(btype, ptype)

@@ -32,16 +32,16 @@ class DummyTokenizer:
         self.update_statistics_called = False
         self.update_word_tokens_called = False
 
-    def update_sql_functions(self, *args):
+    def update_sql_functions(self, *args, **kwargs):
         self.update_sql_functions_called = True
 
-    def finalize_import(self, *args):
+    def finalize_import(self, *args, **kwargs):
         self.finalize_import_called = True
 
-    def update_statistics(self, *args):
+    def update_statistics(self, *args, **kwargs):
         self.update_statistics_called = True
 
-    def update_word_tokens(self, *args):
+    def update_word_tokens(self, *args, **kwargs):
         self.update_word_tokens_called = True
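Widening the stubs to *args, **kwargs keeps them compatible with code under test that now passes the new arguments by keyword; with only *args such a call raises a TypeError. A quick standalone illustration (threads as in the new update_statistics signature):

    class OldStub:
        def update_statistics(self, *args):
            self.called = True

    class NewStub:
        def update_statistics(self, *args, **kwargs):
            self.called = True

    NewStub().update_statistics(object(), threads=4)      # accepted
    try:
        OldStub().update_statistics(object(), threads=4)  # rejected
    except TypeError as exc:
        print(exc)  # ... got an unexpected keyword argument 'threads'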

@@ -227,16 +227,20 @@ def test_update_statistics_reverse_only(word_table, tokenizer_factory, test_conf
 def test_update_statistics(word_table, table_factory, temp_db_cursor,
                            tokenizer_factory, test_config):
     word_table.add_full_word(1000, 'hello')
+    word_table.add_full_word(1001, 'bye')
     table_factory('search_name',
-                  'place_id BIGINT, name_vector INT[]',
-                  [(12, [1000])])
+                  'place_id BIGINT, name_vector INT[], nameaddress_vector INT[]',
+                  [(12, [1000], [1001])])
     tok = tokenizer_factory()
     tok.update_statistics(test_config)
 
     assert temp_db_cursor.scalar("""SELECT count(*) FROM word
-                                    WHERE type = 'W' and
-                                          (info->>'count')::int > 0""") > 0
+                                    WHERE type = 'W' and word_id = 1000 and
+                                          (info->>'count')::int > 0""") == 1
+    assert temp_db_cursor.scalar("""SELECT count(*) FROM word
+                                    WHERE type = 'W' and word_id = 1001 and
+                                          (info->>'addr_count')::int > 0""") == 1
 
 def test_normalize_postcode(analyzer):
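The reworked test seeds one word ('hello', id 1000) into name_vector and one ('bye', id 1001) into nameaddress_vector, then expects update_statistics() to record a positive count for the former and a positive addr_count for the latter. A plain-Python sketch of the counting behaviour the test pins down (the real implementation works on the database tables; this only mirrors the expected outcome):

    from collections import Counter

    # fixture from the test above: (place_id, name_vector, nameaddress_vector)
    search_name = [(12, [1000], [1001])]

    name_counts = Counter(wid for _, names, _ in search_name for wid in names)
    addr_counts = Counter(wid for _, _, addrs in search_name for wid in addrs)

    assert name_counts[1000] > 0   # stored as info->>'count' on word 1000 ('hello')
    assert addr_counts[1001] > 0   # stored as info->>'addr_count' on word 1001 ('bye')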