use context management when processing Tiger data

This commit is contained in:
Sarah Hoffmann 2022-05-11 09:45:15 +02:00
parent ae6b029543
commit 5d5f40a82f

View File

@ -21,33 +21,57 @@ from nominatim.indexer.place_info import PlaceInfo
LOG = logging.getLogger()
def handle_tarfile_or_directory(data_dir):
""" Handles tarfile or directory for importing tiger data
class TigerInput:
""" Context manager that goes through Tiger input files which may
either be in a directory or gzipped together in a tar file.
"""
tar = None
def __init__(self, data_dir):
self.tar_handle = None
self.files = []
if data_dir.endswith('.tar.gz'):
try:
tar = tarfile.open(data_dir)
self.tar_handle = tarfile.open(data_dir) # pylint: disable=consider-using-with
except tarfile.ReadError as err:
LOG.fatal("Cannot open '%s'. Is this a tar file?", data_dir)
raise UsageError("Cannot open Tiger data file.") from err
csv_files = [i for i in tar.getmembers() if i.name.endswith('.csv')]
LOG.warning("Found %d CSV files in tarfile with path %s", len(csv_files), data_dir)
if not csv_files:
LOG.warning("Tiger data import selected but no files in tarfile's path %s", data_dir)
return None, None
self.files = [i for i in self.tar_handle.getmembers() if i.name.endswith('.csv')]
LOG.warning("Found %d CSV files in tarfile with path %s", len(self.files), data_dir)
else:
files = os.listdir(data_dir)
csv_files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
LOG.warning("Found %d CSV files in path %s", len(csv_files), data_dir)
if not csv_files:
LOG.warning("Tiger data import selected but no files found in path %s", data_dir)
return None, None
self.files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
LOG.warning("Found %d CSV files in path %s", len(self.files), data_dir)
return csv_files, tar
if not self.files:
LOG.warning("Tiger data import selected but no files found at %s", data_dir)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.tar_handle:
self.tar_handle.close()
self.tar_handle = None
def next_file(self):
""" Return a file handle to the next file to be processed.
Raises an IndexError if there is no file left.
"""
fname = self.files.pop(0)
if self.tar_handle is not None:
return io.TextIOWrapper(self.tar_handle.extractfile(fname))
return open(fname, encoding='utf-8')
def __len__(self):
return len(self.files)
def handle_threaded_sql_statements(pool, fd, analyzer):
@ -79,9 +103,9 @@ def add_tiger_data(data_dir, config, threads, tokenizer):
""" Import tiger data from directory or tar file `data dir`.
"""
dsn = config.get_libpq_dsn()
files, tar = handle_tarfile_or_directory(data_dir)
if not files:
with TigerInput(data_dir) as tar:
if not tar:
return
with connect(dsn) as conn:
@ -94,19 +118,12 @@ def add_tiger_data(data_dir, config, threads, tokenizer):
with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
with tokenizer.name_analyzer() as analyzer:
for fname in files:
if not tar:
fd = open(fname)
else:
fd = io.TextIOWrapper(tar.extractfile(fname))
while tar:
with tar.next_file() as fd:
handle_threaded_sql_statements(pool, fd, analyzer)
fd.close()
if tar:
tar.close()
print('\n')
LOG.warning("Creating indexes on Tiger data")
with connect(dsn) as conn:
sql = SQLPreprocessor(conn, config)