mirror of
https://github.com/osm-search/Nominatim.git
synced 2024-09-21 07:58:07 +03:00
add function to set up libpq environment
Instead of parsing the DSN for each external libpq program we are going to execute, provide a function that feeds them all necessary parameters through the environment. osm2pgsql is the first user.
This commit is contained in:
parent
e520613362
commit
af7226393a
@ -3,6 +3,7 @@ Specialised connection and cursor functions.
|
||||
"""
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
@ -10,6 +11,8 @@ import psycopg2.extras
|
||||
|
||||
from ..errors import UsageError
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
class _Cursor(psycopg2.extras.DictCursor):
|
||||
""" A cursor returning dict-like objects and providing specialised
|
||||
execution functions.
|
||||
@ -18,8 +21,7 @@ class _Cursor(psycopg2.extras.DictCursor):
|
||||
def execute(self, query, args=None): # pylint: disable=W0221
|
||||
""" Query execution that logs the SQL query when debugging is enabled.
|
||||
"""
|
||||
logger = logging.getLogger()
|
||||
logger.debug(self.mogrify(query, args).decode('utf-8'))
|
||||
LOG.debug(self.mogrify(query, args).decode('utf-8'))
|
||||
|
||||
super().execute(query, args)
|
||||
|
||||
@ -96,3 +98,52 @@ def connect(dsn):
|
||||
return ctxmgr
|
||||
except psycopg2.OperationalError as err:
|
||||
raise UsageError("Cannot connect to database: {}".format(err)) from err
|
||||
|
||||
|
||||
# Translation from PG connection string parameters to PG environment variables.
|
||||
# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
|
||||
_PG_CONNECTION_STRINGS = {
|
||||
'host': 'PGHOST',
|
||||
'hostaddr': 'PGHOSTADDR',
|
||||
'port': 'PGPORT',
|
||||
'dbname': 'PGDATABASE',
|
||||
'user': 'PGUSER',
|
||||
'password': 'PGPASSWORD',
|
||||
'passfile': 'PGPASSFILE',
|
||||
'channel_binding': 'PGCHANNELBINDING',
|
||||
'service': 'PGSERVICE',
|
||||
'options': 'PGOPTIONS',
|
||||
'application_name': 'PGAPPNAME',
|
||||
'sslmode': 'PGSSLMODE',
|
||||
'requiressl': 'PGREQUIRESSL',
|
||||
'sslcompression': 'PGSSLCOMPRESSION',
|
||||
'sslcert': 'PGSSLCERT',
|
||||
'sslkey': 'PGSSLKEY',
|
||||
'sslrootcert': 'PGSSLROOTCERT',
|
||||
'sslcrl': 'PGSSLCRL',
|
||||
'requirepeer': 'PGREQUIREPEER',
|
||||
'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
|
||||
'ssl_min_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
|
||||
'gssencmode': 'PGGSSENCMODE',
|
||||
'krbsrvname': 'PGKRBSRVNAME',
|
||||
'gsslib': 'PGGSSLIB',
|
||||
'connect_timeout': 'PGCONNECT_TIMEOUT',
|
||||
'target_session_attrs': 'PGTARGETSESSIONATTRS',
|
||||
}
|
||||
|
||||
|
||||
def get_pg_env(dsn, base_env=None):
|
||||
""" Return a copy of `base_env` with the environment variables for
|
||||
PostgresSQL set up from the given database connection string.
|
||||
If `base_env` is None, then the OS environment is used as a base
|
||||
environment.
|
||||
"""
|
||||
env = base_env if base_env is not None else os.environ
|
||||
|
||||
for param, value in psycopg2.extensions.parse_dsn(dsn).items():
|
||||
if param in _PG_CONNECTION_STRINGS:
|
||||
env[_PG_CONNECTION_STRINGS[param]] = value
|
||||
else:
|
||||
LOG.error("Unknown connection parameter '%s' ignored.", param)
|
||||
|
||||
return env
|
||||
|
@ -10,6 +10,7 @@ from urllib.parse import urlencode
|
||||
from psycopg2.extensions import parse_dsn
|
||||
|
||||
from ..version import NOMINATIM_VERSION
|
||||
from ..db.connection import get_pg_env
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
@ -100,7 +101,7 @@ def run_php_server(server_address, base_dir):
|
||||
def run_osm2pgsql(options):
|
||||
""" Run osm2pgsql with the given options.
|
||||
"""
|
||||
env = os.environ
|
||||
env = get_pg_env(options['dsn'])
|
||||
cmd = [options['osm2pgsql'],
|
||||
'--hstore', '--latlon', '--slim',
|
||||
'--with-forward-dependencies', 'false',
|
||||
@ -116,17 +117,6 @@ def run_osm2pgsql(options):
|
||||
if options['flatnode_file']:
|
||||
cmd.extend(('--flat-nodes', options['flatnode_file']))
|
||||
|
||||
dsn = parse_dsn(options['dsn'])
|
||||
if 'password' in dsn:
|
||||
env['PGPASSWORD'] = dsn['password']
|
||||
if 'dbname' in dsn:
|
||||
cmd.extend(('-d', dsn['dbname']))
|
||||
if 'user' in dsn:
|
||||
cmd.extend(('--username', dsn['user']))
|
||||
for param in ('host', 'port'):
|
||||
if param in dsn:
|
||||
cmd.extend(('--' + param, dsn[param]))
|
||||
|
||||
if options.get('disable_jit', False):
|
||||
env['PGOPTIONS'] = '-c jit=off -c max_parallel_workers_per_gather=0'
|
||||
|
||||
|
@ -3,7 +3,7 @@ Tests for specialised conenction and cursor classes.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from nominatim.db.connection import connect
|
||||
from nominatim.db.connection import connect, get_pg_env
|
||||
|
||||
@pytest.fixture
|
||||
def db(temp_db):
|
||||
@ -48,3 +48,24 @@ def test_cursor_scalar_many_rows(db):
|
||||
with db.cursor() as cur:
|
||||
with pytest.raises(RuntimeError):
|
||||
cur.scalar('SELECT * FROM pg_tables')
|
||||
|
||||
|
||||
def test_get_pg_env_add_variable(monkeypatch):
|
||||
monkeypatch.delenv('PGPASSWORD', raising=False)
|
||||
env = get_pg_env('user=fooF')
|
||||
|
||||
assert env['PGUSER'] == 'fooF'
|
||||
assert 'PGPASSWORD' not in env
|
||||
|
||||
|
||||
def test_get_pg_env_overwrite_variable(monkeypatch):
|
||||
monkeypatch.setenv('PGUSER', 'some default')
|
||||
env = get_pg_env('user=overwriter')
|
||||
|
||||
assert env['PGUSER'] == 'overwriter'
|
||||
|
||||
|
||||
def test_get_pg_env_ignore_unknown():
|
||||
env = get_pg_env('tty=stuff', base_env={})
|
||||
|
||||
assert env == {}
|
||||
|
Loading…
Reference in New Issue
Block a user