factor out async connection handling into separate class

Also adds a test for reconnecting regularly while indexing.
This commit is contained in:
Sarah Hoffmann 2021-04-20 11:16:12 +02:00
parent 26a81654a8
commit 50b6d7298c
2 changed files with 115 additions and 92 deletions

View File

@ -11,6 +11,68 @@ from nominatim.db.connection import connect
LOG = logging.getLogger()
class WorkerPool:
    """ A pool of asynchronous database connections.

        The pool may be used as a context manager.
    """
    # Number of finished commands after which all connections are
    # reopened (guards against slow memory leaks in PostgreSQL sessions).
    REOPEN_CONNECTIONS_AFTER = 100000

    def __init__(self, dsn, pool_size):
        self.threads = [DBConnection(dsn) for _ in range(pool_size)]
        self.free_workers = self._yield_free_worker()


    def finish_all(self):
        """ Wait for all connections to finish.
        """
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()

        # Restart the generator so that the next round of work starts
        # with a fresh worker rotation and command counter.
        self.free_workers = self._yield_free_worker()


    def close(self):
        """ Close all connections and clear the pool.
        """
        for thread in self.threads:
            thread.close()
        self.threads = []
        self.free_workers = None


    def next_free_worker(self):
        """ Get the next free connection.
        """
        return next(self.free_workers)


    def _yield_free_worker(self):
        # Infinite generator handing out connections that are ready to
        # accept a new command.
        ready = self.threads
        command_stat = 0
        while True:
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
                # Refresh the connections occasionally to avoid potential
                # memory leaks in PostgreSQL.
                for thread in self.threads:
                    while not thread.is_done():
                        thread.wait()
                    thread.connect()
                # Reset the counter; without this, every following pass
                # would reopen all connections again (the reset was lost
                # in the refactoring from Indexer.find_free_thread).
                command_stat = 0
                ready = self.threads
            else:
                _, ready, _ = select.select([], self.threads, [])


    def __enter__(self):
        return self


    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
class Indexer:
""" Main indexing routine.
@ -19,17 +81,6 @@ class Indexer:
def __init__(self, dsn, num_threads):
self.dsn = dsn
self.num_threads = num_threads
self.threads = []
def _setup_connections(self):
self.threads = [DBConnection(self.dsn) for _ in range(self.num_threads)]
def _close_connections(self):
for thread in self.threads:
thread.close()
self.threads = []
def index_full(self, analyse=True):
@ -42,27 +93,27 @@ class Indexer:
conn.autocommit = True
if analyse:
def _analyse():
def _analyze():
with conn.cursor() as cur:
cur.execute('ANALYSE')
cur.execute('ANALYZE')
else:
def _analyse():
def _analyze():
pass
self.index_by_rank(0, 4)
_analyse()
_analyze()
self.index_boundaries(0, 30)
_analyse()
_analyze()
self.index_by_rank(5, 25)
_analyse()
_analyze()
self.index_by_rank(26, 30)
_analyse()
_analyze()
self.index_postcodes()
_analyse()
_analyze()
def index_boundaries(self, minrank, maxrank):
@ -71,13 +122,8 @@ class Indexer:
LOG.warning("Starting indexing boundaries using %s threads",
self.num_threads)
self._setup_connections()
try:
for rank in range(max(minrank, 4), min(maxrank, 26)):
self._index(runners.BoundaryRunner(rank))
finally:
self._close_connections()
for rank in range(max(minrank, 4), min(maxrank, 26)):
self._index(runners.BoundaryRunner(rank))
def index_by_rank(self, minrank, maxrank):
""" Index all entries of placex in the given rank range (inclusive)
@ -90,20 +136,15 @@ class Indexer:
LOG.warning("Starting indexing rank (%i to %i) using %i threads",
minrank, maxrank, self.num_threads)
self._setup_connections()
for rank in range(max(1, minrank), maxrank):
self._index(runners.RankRunner(rank))
try:
for rank in range(max(1, minrank), maxrank):
self._index(runners.RankRunner(rank))
if maxrank == 30:
self._index(runners.RankRunner(0))
self._index(runners.InterpolationRunner(), 20)
self._index(runners.RankRunner(30), 20)
else:
self._index(runners.RankRunner(maxrank))
finally:
self._close_connections()
if maxrank == 30:
self._index(runners.RankRunner(0))
self._index(runners.InterpolationRunner(), 20)
self._index(runners.RankRunner(30), 20)
else:
self._index(runners.RankRunner(maxrank))
def index_postcodes(self):
@ -111,12 +152,8 @@ class Indexer:
"""
LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
self._setup_connections()
self._index(runners.PostcodeRunner(), 20)
try:
self._index(runners.PostcodeRunner(), 20)
finally:
self._close_connections()
def update_status_table(self):
""" Update the status in the status table to 'indexed'.
@ -147,48 +184,20 @@ class Indexer:
with conn.cursor(name='places') as cur:
cur.execute(runner.sql_get_objects())
next_thread = self.find_free_thread()
while True:
places = [p[0] for p in cur.fetchmany(batch)]
if not places:
break
with WorkerPool(self.dsn, self.num_threads) as pool:
while True:
places = [p[0] for p in cur.fetchmany(batch)]
if not places:
break
LOG.debug("Processing places: %s", str(places))
thread = next(next_thread)
LOG.debug("Processing places: %s", str(places))
worker = pool.next_free_worker()
thread.perform(runner.sql_index_place(places))
progress.add(len(places))
worker.perform(runner.sql_index_place(places))
progress.add(len(places))
conn.commit()
pool.finish_all()
for thread in self.threads:
thread.wait()
conn.commit()
progress.done()
def find_free_thread(self):
""" Generator that returns the next connection that is free for
sending a query.
"""
ready = self.threads
command_stat = 0
while True:
for thread in ready:
if thread.is_done():
command_stat += 1
yield thread
# refresh the connections occasionally to avoid potential
# memory leaks in Postgresql.
if command_stat > 100000:
for thread in self.threads:
while not thread.is_done():
thread.wait()
thread.connect()
command_stat = 0
ready = self.threads
else:
ready, _, _ = select.select(self.threads, [], [])
assert False, "Unreachable code"

View File

@ -5,7 +5,7 @@ import itertools
import psycopg2
import pytest
from nominatim.indexer.indexer import Indexer
from nominatim.indexer import indexer
class IndexerTestDB:
@ -111,7 +111,7 @@ def test_index_all_by_rank(test_db, threads):
assert 31 == test_db.placex_unindexed()
assert 1 == test_db.osmline_unindexed()
idx = Indexer('dbname=test_nominatim_python_unittest', threads)
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', threads)
idx.index_by_rank(0, 30)
assert 0 == test_db.placex_unindexed()
@ -150,7 +150,7 @@ def test_index_partial_without_30(test_db, threads):
assert 31 == test_db.placex_unindexed()
assert 1 == test_db.osmline_unindexed()
idx = Indexer('dbname=test_nominatim_python_unittest', threads)
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', threads)
idx.index_by_rank(4, 15)
assert 19 == test_db.placex_unindexed()
@ -170,7 +170,7 @@ def test_index_partial_with_30(test_db, threads):
assert 31 == test_db.placex_unindexed()
assert 1 == test_db.osmline_unindexed()
idx = Indexer('dbname=test_nominatim_python_unittest', threads)
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', threads)
idx.index_by_rank(28, 30)
assert 27 == test_db.placex_unindexed()
@ -191,7 +191,7 @@ def test_index_boundaries(test_db, threads):
assert 37 == test_db.placex_unindexed()
assert 1 == test_db.osmline_unindexed()
idx = Indexer('dbname=test_nominatim_python_unittest', threads)
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', threads)
idx.index_boundaries(0, 30)
assert 31 == test_db.placex_unindexed()
@ -209,14 +209,15 @@ def test_index_postcodes(test_db, threads):
for postcode in range(32000, 33000):
test_db.add_postcode('us', postcode)
idx = Indexer('dbname=test_nominatim_python_unittest', threads)
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', threads)
idx.index_postcodes()
assert 0 == test_db.scalar("""SELECT count(*) FROM location_postcode
WHERE indexed_status != 0""")
def test_index_full(test_db):
@pytest.mark.parametrize("analyse", [True, False])
def test_index_full(test_db, analyse):
for rank in range(4, 10):
test_db.add_admin(rank_address=rank, rank_search=rank)
for rank in range(31):
@ -225,10 +226,23 @@ def test_index_full(test_db):
for postcode in range(1000):
test_db.add_postcode('de', postcode)
idx = Indexer('dbname=test_nominatim_python_unittest', 4)
idx.index_full()
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', 4)
idx.index_full(analyse=analyse)
assert 0 == test_db.placex_unindexed()
assert 0 == test_db.osmline_unindexed()
assert 0 == test_db.scalar("""SELECT count(*) FROM location_postcode
WHERE indexed_status != 0""")
@pytest.mark.parametrize("threads", [1, 15])
def test_index_reopen_connection(test_db, threads, monkeypatch):
    """ Indexing must complete correctly when the worker pool reopens its
        connections periodically during the run.
    """
    # Lower the reconnect threshold so that indexing 1000 places forces
    # several reconnect cycles within a single run.
    monkeypatch.setattr(indexer.WorkerPool, "REOPEN_CONNECTIONS_AFTER", 15)

    for _ in range(1000):
        test_db.add_place(rank_address=30, rank_search=30)

    idx = indexer.Indexer('dbname=test_nominatim_python_unittest', threads)
    idx.index_by_rank(28, 30)

    # All places must have been indexed despite the reconnects.
    assert 0 == test_db.placex_unindexed()