1
1
mirror of https://github.com/yandex/pgmigrate.git synced 2024-09-19 16:17:24 +03:00
pgmigrate/pgmigrate.py

817 lines
26 KiB
Python
Raw Normal View History

#!/usr/bin/env python
"""
PGmigrate - PostgreSQL migrations made easy
"""
# -*- coding: utf-8 -*-
#
2017-02-06 13:00:04 +03:00
# Copyright (c) 2016-2017 Yandex LLC <https://github.com/yandex>
# Copyright (c) 2016-2017 Other contributors as noted in the AUTHORS file.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose, without fee, and without a written
# agreement is hereby granted, provided that the above copyright notice
# and this paragraph and the following two paragraphs appear in all copies.
#
# IN NO EVENT SHALL YANDEX LLC BE LIABLE TO ANY PARTY FOR DIRECT,
# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST
# PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION,
# EVEN IF YANDEX LLC HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# YANDEX SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS"
# BASIS, AND YANDEX LLC HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE,
# SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
from __future__ import absolute_import, print_function, unicode_literals
import argparse
import codecs
import json
import logging
import os
import re
import sys
2017-12-03 01:39:30 +03:00
import threading
import time
from builtins import str as text
from collections import OrderedDict, namedtuple
2017-12-03 01:39:30 +03:00
from contextlib import closing
import psycopg2
import sqlparse
import yaml
from psycopg2.extras import LoggingConnection
# Module-level logger shared by all pgmigrate helpers
LOG = logging.getLogger(__name__)
class MigrateError(RuntimeError):
    """Base class for all migration-related errors."""
class MalformedStatement(MigrateError):
    """Raised when a statement inside a migration file is invalid."""
class MalformedMigration(MigrateError):
    """Raised when a migration file or migration plan is malformed."""
class MalformedSchema(MigrateError):
    """Raised when the schema_version table has unexpected structure."""
class ConfigurationError(MigrateError):
    """Raised on incorrect config file or cmdline arguments."""
class BaselineError(MigrateError):
    """Raised when a baseline cannot be applied."""
2017-02-14 20:26:09 +03:00
2017-12-03 01:39:30 +03:00
class ConflictTerminator(threading.Thread):
    """
    Background daemon thread that terminates backends blocking
    pgmigrate's own connections (relies on pg_blocking_pids, so it
    works only on postgresql >= 9.6).
    """
    def __init__(self, conn_str, interval):
        """
        conn_str: libpq connection string for the terminator's own conn
        interval: seconds to sleep between termination sweeps
        """
        threading.Thread.__init__(self, name='terminator')
        self.daemon = True
        self.log = logging.getLogger('terminator')
        self.conn_str = conn_str
        self.pids = set()
        self.interval = interval
        self.should_run = True
        self.conn = None

    def stop(self):
        """
        Stop iterations (run loop exits after the current sleep)
        """
        self.should_run = False

    def add_conn(self, conn):
        """
        Add conn backend pid to pgmigrate pids list
        """
        self.pids.add(conn.get_backend_pid())

    def remove_conn(self, conn):
        """
        Remove conn backend pid from pgmigrate pids list
        """
        # discard (not remove): safe even if called twice or the pid
        # was never registered
        self.pids.discard(conn.get_backend_pid())

    def run(self):
        """
        Periodically terminate all backends blocking pgmigrate pids
        """
        self.conn = _create_raw_connection(self.conn_str, self.log)
        self.conn.autocommit = True
        while self.should_run:
            with self.conn.cursor() as cursor:
                # Iterate over a snapshot: add_conn/remove_conn may be
                # called concurrently from the main thread, and mutating
                # a set while iterating it raises RuntimeError.
                for pid in tuple(self.pids):
                    cursor.execute(
                        'SELECT pid, pg_terminate_backend(pid) FROM '
                        'unnest(pg_blocking_pids(%s)) AS pid',
                        (pid,))
                    terminated = [x[0] for x in cursor.fetchall()]
                    for i in terminated:
                        self.log.info('Terminated conflicting pid: %s', i)
            time.sleep(self.interval)
# Expected column order of the public.schema_version table; used both to
# validate the table structure and to build SELECT queries against it.
REF_COLUMNS = ['version', 'description', 'type',
               'installed_by', 'installed_on']
2017-12-03 01:39:30 +03:00
def _create_raw_connection(conn_string, logger=LOG):
    """
    Open a psycopg2 connection that logs every query to logger
    """
    connection = psycopg2.connect(conn_string,
                                  connection_factory=LoggingConnection)
    connection.initialize(logger)
    return connection
def _create_connection(config):
    """
    Open a connection and register it with the terminator (if enabled)
    """
    connection = _create_raw_connection(config.conn)
    terminator = config.terminator_instance
    if terminator:
        terminator.add_conn(connection)
    return connection
2017-04-01 22:11:47 +03:00
def _init_cursor(conn, session):
"""
2017-04-01 22:11:47 +03:00
Get cursor initialized with session commands
"""
2017-04-01 22:11:47 +03:00
cursor = conn.cursor()
for query in session:
cursor.execute(query)
LOG.info(cursor.statusmessage)
return cursor
def _is_initialized(cursor):
"""
Check that database is initialized
"""
cursor.execute('SELECT EXISTS(SELECT 1 FROM '
'information_schema.tables '
'WHERE table_schema = %s '
'AND table_name = %s)',
('public', 'schema_version'))
table_exists = cursor.fetchone()[0]
if not table_exists:
return False
cursor.execute('SELECT * from public.schema_version limit 1')
colnames = [desc[0] for desc in cursor.description]
if colnames != REF_COLUMNS:
raise MalformedSchema(
'Table schema_version has unexpected '
'structure: {struct}'.format(struct='|'.join(colnames)))
return True
2017-02-14 20:26:09 +03:00
# Migration files must be named V<version>__<description>.sql
MIGRATION_FILE_RE = re.compile(
    r'V(?P<version>\d+)__(?P<description>.+)\.sql$',
)

# Migration metadata dict plus path of the file it was parsed from
MigrationInfo = namedtuple('MigrationInfo', ('meta', 'file_path'))

# Lists of callback files to run at each stage of a migration run
Callbacks = namedtuple('Callbacks', ('beforeAll', 'beforeEach',
                                     'afterEach', 'afterAll'))

# Runtime configuration assembled from defaults, migrations.yml and args
Config = namedtuple('Config', ('target', 'baseline', 'cursor', 'dryrun',
                               'callbacks', 'user', 'base_dir', 'conn',
                               'session', 'conn_instance',
                               'terminator_instance', 'termination_interval'))

# Config fields that are runtime-only and must never be taken from
# the config file or cmdline args
CONFIG_IGNORE = ['cursor', 'conn_instance', 'terminator_instance']
def _get_files_from_dir(path):
"""
Get all files in all subdirs in path
"""
for root, _, files in os.walk(path):
for fname in files:
2018-04-03 15:06:38 +03:00
yield os.path.basename(fname), os.path.join(root, fname)
def _get_migrations_info_from_dir(base_dir):
    """
    Get all migrations from the migrations subdir of base dir
    """
    path = os.path.join(base_dir, 'migrations')
    if not (os.path.exists(path) and os.path.isdir(path)):
        raise ConfigurationError(
            'Migrations dir not found (expected to be {path})'.format(
                path=path))
    migrations = {}
    for fname, file_path in _get_files_from_dir(path):
        match = MIGRATION_FILE_RE.match(fname)
        if match is None:
            LOG.warning(
                'File %s does not match by pattern %s. Skipping it.',
                file_path, MIGRATION_FILE_RE.pattern)
            continue
        version = int(match.group('version'))
        description = match.group('description').replace('_', ' ')
        meta = dict(
            version=version,
            type='auto',
            installed_by=None,
            installed_on=None,
            description=description,
        )
        # NONTRANSACTIONAL in the name marks a migration that must run
        # outside of a transaction
        meta['transactional'] = 'NONTRANSACTIONAL' not in description
        if version in migrations:
            raise MalformedMigration((
                'Found migrations with same version: {version} '
                '\nfirst : {first_path}'
                '\nsecond: {second_path}').format(
                    version=version,
                    first_path=file_path,
                    second_path=migrations[version].file_path))
        migrations[version] = MigrationInfo(meta, file_path)
    return migrations
def _get_migrations_info(base_dir, baseline_v, target_v):
    """
    Get migrations from baseline (exclusive) to target (inclusive)
    """
    target = float('inf') if target_v is None else target_v
    migrations = {}
    for version, info in _get_migrations_info_from_dir(base_dir).items():
        if baseline_v < version <= target:
            migrations[version] = info.meta
        else:
            LOG.info(
                'Ignore migration %r cause baseline: %r or target: %r',
                info, baseline_v, target,
            )
    return migrations
def _get_info(base_dir, baseline_v, target_v, cursor):
    """
    Get migrations info from database and base dir

    Rows already recorded in schema_version take precedence over
    migrations found on disk with the same version.
    """
    ret = {}
    cursor.execute(
        'SELECT {columns} FROM public.schema_version'.format(
            columns=', '.join(REF_COLUMNS)))
    for row in cursor.fetchall():
        version = {}
        for num, column in enumerate(REF_COLUMNS):
            if column == 'installed_on':
                # Render timestamps in a stable text form
                version[column] = row[num].strftime('%F %H:%M:%S')
            else:
                version[column] = row[num]
        version['version'] = int(version['version'])
        transactional = 'NONTRANSACTIONAL' not in version['description']
        version['transactional'] = transactional
        ret[version['version']] = version
    if ret:
        # Everything at or below the latest applied version is baseline.
        # Guard against an initialized but empty schema_version table
        # (the old sorted(...)[-1] raised IndexError there).
        baseline_v = max(baseline_v, max(ret))
    migrations_info = _get_migrations_info(base_dir, baseline_v, target_v)
    for version in migrations_info:
        num = migrations_info[version]['version']
        if num not in ret:
            ret[num] = migrations_info[version]
    return ret
def _get_database_user(cursor):
cursor.execute('SELECT CURRENT_USER')
return cursor.fetchone()[0]
def _get_state(base_dir, baseline_v, target, cursor):
    """
    Get info wrapper (able to handle noninitialized database)
    """
    if not _is_initialized(cursor):
        return _get_migrations_info(base_dir, baseline_v, target)
    return _get_info(base_dir, baseline_v, target, cursor)
def _set_baseline(baseline_v, user, cursor):
    """
    Cleanup schema_version and set baseline
    """
    cursor.execute('SELECT EXISTS(SELECT 1 FROM public'
                   '.schema_version WHERE version >= %s::bigint)',
                   (baseline_v,))
    if cursor.fetchone()[0]:
        # Refuse to baseline below/at an already applied version
        raise BaselineError(
            'Unable to baseline, version '
            '{version} already applied'.format(version=text(baseline_v)))
    LOG.info('cleaning up table schema_version')
    cursor.execute('DELETE FROM public.schema_version')
    LOG.info(cursor.statusmessage)
    LOG.info('setting baseline')
    cursor.execute('INSERT INTO public.schema_version '
                   '(version, type, description, installed_by) '
                   'VALUES (%s::bigint, %s, %s, %s)',
                   (text(baseline_v), 'manual',
                    'Forced baseline', user))
    LOG.info(cursor.statusmessage)
def _init_schema(cursor):
    """
    Create schema_version table

    Also creates the schema_version_type enum ('auto' for migrations
    applied by pgmigrate, 'manual' for forced baselines).
    """
    LOG.info('creating type schema_version_type')
    cursor.execute('CREATE TYPE public.schema_version_type '
                   'AS ENUM (%s, %s)', ('auto', 'manual'))
    LOG.info(cursor.statusmessage)
    LOG.info('creating table schema_version')
    cursor.execute('CREATE TABLE public.schema_version ('
                   'version BIGINT NOT NULL PRIMARY KEY, '
                   'description TEXT NOT NULL, '
                   'type public.schema_version_type NOT NULL '
                   'DEFAULT %s, '
                   'installed_by TEXT NOT NULL, '
                   'installed_on TIMESTAMP WITHOUT time ZONE '
                   'DEFAULT now() NOT NULL)', ('auto',))
    LOG.info(cursor.statusmessage)
def _get_statements(path):
    """
    Yield nonempty utf-8 encoded statements from a migration file
    """
    with codecs.open(path, encoding='utf-8') as handle:
        data = handle.read()
    if u'/* pgmigrate-encoding: utf-8 */' not in data:
        # Without an explicit encoding marker only ascii is allowed
        try:
            data.encode('ascii')
        except UnicodeError as exc:
            raise MalformedStatement(
                'Non ascii symbols in file: {0}, {1}'.format(
                    path, text(exc)))
    data = sqlparse.format(data, strip_comments=True)
    for statement in sqlparse.parsestream(data, encoding='utf-8'):
        stripped = text(statement).strip().encode('utf-8')
        if stripped:
            yield stripped
def _apply_statement(statement, cursor):
    """
    Execute a single statement, logging and wrapping any database error
    """
    try:
        cursor.execute(statement)
    except psycopg2.Error as exc:
        # Log the failing statement line by line before giving up
        LOG.error('Error executing statement:')
        for failed_line in statement.splitlines():
            LOG.error(failed_line)
        LOG.error(exc)
        raise MigrateError('Unable to apply statement')
def _apply_file(file_path, cursor):
    """
    Execute all statements in file

    MalformedStatement is logged and re-raised; MigrateError from
    failed statements propagates unchanged.
    """
    try:
        for statement in _get_statements(file_path):
            _apply_statement(statement, cursor)
    except MalformedStatement as exc:
        LOG.error(exc)
        # bare raise keeps the original traceback intact
        # (raise exc restarted the traceback from here)
        raise
def _apply_version(version, base_dir, user, cursor):
    """
    Execute all statements in migration version and record it
    """
    version_info = _get_migrations_info_from_dir(base_dir)[version]
    LOG.info('Try apply version %r', version_info)
    _apply_file(version_info.file_path, cursor)
    cursor.execute('INSERT INTO public.schema_version '
                   '(version, description, installed_by) '
                   'VALUES (%s::bigint, %s, %s)',
                   (text(version),
                    version_info.meta['description'],
                    user))
def _parse_str_callbacks(callbacks, ret, base_dir):
callbacks = callbacks.split(',')
for callback in callbacks:
if not callback:
continue
tokens = callback.split(':')
if tokens[0] not in ret._fields:
raise ConfigurationError(
'Unexpected callback '
'type: {type}'.format(type=text(tokens[0])))
path = os.path.join(base_dir, tokens[1])
if not os.path.exists(path):
raise ConfigurationError(
'Path unavailable: {path}'.format(path=text(path)))
if os.path.isdir(path):
for fname in sorted(os.listdir(path)):
getattr(ret, tokens[0]).append(os.path.join(path, fname))
else:
getattr(ret, tokens[0]).append(path)
return ret
def _parse_dict_callbacks(callbacks, ret, base_dir):
for i in callbacks:
if i in ret._fields:
for j in callbacks[i]:
path = os.path.join(base_dir, j)
if not os.path.exists(path):
raise ConfigurationError(
'Path unavailable: {path}'.format(path=text(path)))
if os.path.isdir(path):
for fname in sorted(os.listdir(path)):
getattr(ret, i).append(os.path.join(path, fname))
else:
getattr(ret, i).append(path)
else:
raise ConfigurationError(
'Unexpected callback type: {type}'.format(type=text(i)))
return ret
def _get_callbacks(callbacks, base_dir=''):
    """
    Parse cmdline/config callbacks (str or dict form)
    """
    empty = Callbacks(beforeAll=[], beforeEach=[], afterEach=[], afterAll=[])
    if isinstance(callbacks, dict):
        parser = _parse_dict_callbacks
    else:
        parser = _parse_str_callbacks
    return parser(callbacks, empty, base_dir)
def _migrate_step(state, callbacks, base_dir, user, cursor):
    """
    Apply one version with callbacks

    Applies every version in state with installed_on is None, in
    ascending version order. beforeAll runs once before the first
    applied version; afterAll runs once after the loop if anything
    was applied; beforeEach/afterEach wrap every applied version.
    """
    before_all_executed = False
    should_migrate = False
    if not _is_initialized(cursor):
        LOG.info('schema not initialized')
        _init_schema(cursor)
    for version in sorted(state.keys()):
        LOG.debug('has version %r', version)
        # installed_on is None means the version is not applied yet
        if state[version]['installed_on'] is None:
            should_migrate = True
            if not before_all_executed and callbacks.beforeAll:
                LOG.info('Executing beforeAll callbacks:')
                for callback in callbacks.beforeAll:
                    _apply_file(callback, cursor)
                    LOG.info(callback)
                before_all_executed = True
            LOG.info('Migrating to version %d', version)
            if callbacks.beforeEach:
                LOG.info('Executing beforeEach callbacks:')
                for callback in callbacks.beforeEach:
                    LOG.info(callback)
                    _apply_file(callback, cursor)
            _apply_version(version, base_dir, user, cursor)
            if callbacks.afterEach:
                LOG.info('Executing afterEach callbacks:')
                for callback in callbacks.afterEach:
                    LOG.info(callback)
                    _apply_file(callback, cursor)
    if should_migrate and callbacks.afterAll:
        LOG.info('Executing afterAll callbacks:')
        for callback in callbacks.afterAll:
            LOG.info(callback)
            _apply_file(callback, cursor)
def _finish(config):
if config.dryrun:
config.cursor.execute('rollback')
else:
config.cursor.execute('commit')
2017-12-03 01:39:30 +03:00
if config.terminator_instance:
config.terminator_instance.stop()
config.conn_instance.close()
def info(config, stdout=True):
    """
    Info cmdline wrapper: dump migration state (optionally as json)
    """
    state = _get_state(config.base_dir, config.baseline,
                       config.target, config.cursor)
    if stdout:
        ordered = OrderedDict(
            (version, state[version]) for version in sorted(state, key=int))
        sys.stdout.write(
            json.dumps(ordered, indent=4, separators=(',', ': ')) + '\n')
    _finish(config)
    return state
def clean(config):
    """
    Drop schema_version table and its enum type
    """
    cursor = config.cursor
    if _is_initialized(cursor):
        LOG.info('dropping schema_version')
        cursor.execute('DROP TABLE public.schema_version')
        LOG.info(cursor.statusmessage)
        LOG.info('dropping schema_version_type')
        cursor.execute('DROP TYPE public.schema_version_type')
        LOG.info(cursor.statusmessage)
    _finish(config)
def baseline(config):
    """
    Baseline cmdline wrapper: force schema_version to a given version
    """
    cursor = config.cursor
    if not _is_initialized(cursor):
        _init_schema(cursor)
    _set_baseline(config.baseline, config.user, cursor)
    _finish(config)
def _prepare_nontransactional_steps(state, callbacks):
    """
    Split state into a plan of alternating steps

    Each step is {'state': {version: meta}, 'cbs': Callbacks}: runs of
    transactional migrations are grouped into one step with the real
    callbacks, every nontransactional migration becomes its own step
    with empty callbacks. beforeAll is kept only on the first
    transactional step, afterAll only on the last one.
    """
    steps = []
    # accumulator for the current run of transactional migrations
    i = {'state': {},
         'cbs': _get_callbacks('')}
    for version in sorted(state):
        if not state[version]['transactional']:
            if i['state']:
                # flush the accumulated transactional run first
                steps.append(i)
                i = {'state': {},
                     'cbs': _get_callbacks('')}
            elif not steps:
                LOG.error('First migration MUST be transactional')
                raise MalformedMigration('First migration MUST '
                                         'be transactional')
            # nontransactional migration always forms its own step
            steps.append({'state': {version: state[version]},
                          'cbs': _get_callbacks('')})
        else:
            i['state'][version] = state[version]
            i['cbs'] = callbacks
    if i['state']:
        steps.append(i)
    # collect indexes of the transactional steps
    transactional = []
    for (num, step) in enumerate(steps):
        if list(step['state'].values())[0]['transactional']:
            transactional.append(num)
    if len(transactional) > 1:
        # beforeAll only on the first, afterAll only on the last
        for num in transactional[1:]:
            steps[num]['cbs'] = steps[num]['cbs']._replace(beforeAll=[])
        for num in transactional[:-1]:
            steps[num]['cbs'] = steps[num]['cbs']._replace(afterAll=[])
    LOG.info('Initialization plan result:\n %s',
             json.dumps(steps, indent=4, separators=(',', ': ')))
    return steps
2017-12-03 01:39:30 +03:00
def _execute_mixed_steps(config, steps, nt_conn):
    """
    Run a mixed plan: transactional steps on config.cursor (committed
    before the next nontransactional step), other steps on nt_conn
    """
    commit_req = False
    for step in steps:
        if commit_req:
            config.cursor.execute('commit')
            commit_req = False
        is_transactional = list(step['state'].values())[0]['transactional']
        if is_transactional:
            cur = config.cursor
            commit_req = True
        else:
            cur = _init_cursor(nt_conn, config.session)
        _migrate_step(step['state'], step['cbs'],
                      config.base_dir, config.user, cur)
def migrate(config):
    """
    Migrate cmdline wrapper

    Dispatches between three execution modes: all-transactional (single
    transaction on config.cursor), all-nontransactional remainder
    (separate autocommit connection) and mixed (step plan).
    """
    if config.target is None:
        LOG.error('Unknown target (you could use "latest" to '
                  'use latest available version)')
        raise MigrateError('Unknown target')
    state = _get_state(config.base_dir, config.baseline,
                       config.target, config.cursor)
    not_applied = [x for x in state if state[x]['installed_on'] is None]
    non_trans = [x for x in not_applied if not state[x]['transactional']]
    if non_trans:
        # Nontransactional migrations cannot be rolled back,
        # so a dry run makes no sense
        if config.dryrun:
            LOG.error('Dry run for nontransactional migrations '
                      'is nonsence')
            raise MigrateError('Dry run for nontransactional migrations '
                               'is nonsence')
        if len(state) != len(not_applied):
            # Some versions are already applied: the remaining ones must
            # be purely nontransactional, mixing is not supported here
            if len(not_applied) != len(non_trans):
                LOG.error('Unable to mix transactional and '
                          'nontransactional migrations')
                raise MigrateError('Unable to mix transactional and '
                                   'nontransactional migrations')
            config.cursor.execute('rollback')
            # Run on a dedicated autocommit connection
            with closing(_create_connection(config)) as nt_conn:
                nt_conn.autocommit = True
                cursor = _init_cursor(nt_conn, config.session)
                _migrate_step(state, _get_callbacks(''),
                              config.base_dir, config.user, cursor)
                if config.terminator_instance:
                    config.terminator_instance.remove_conn(nt_conn)
        else:
            # Nothing applied yet: build an alternating step plan
            steps = _prepare_nontransactional_steps(state, config.callbacks)
            with closing(_create_connection(config)) as nt_conn:
                nt_conn.autocommit = True
                _execute_mixed_steps(config, steps, nt_conn)
                if config.terminator_instance:
                    config.terminator_instance.remove_conn(nt_conn)
    else:
        # Purely transactional: single transaction on the main cursor
        _migrate_step(state, config.callbacks, config.base_dir,
                      config.user, config.cursor)
    _finish(config)
2017-02-14 20:26:09 +03:00
# Mapping of cmdline command name -> implementation
COMMANDS = {
    'info': info,
    'clean': clean,
    'baseline': baseline,
    'migrate': migrate,
}

# Defaults used when neither migrations.yml nor cmdline args set a value
CONFIG_DEFAULTS = Config(target=None, baseline=0, cursor=None, dryrun=False,
                         callbacks='', base_dir='', user=None,
                         session=['SET lock_timeout = 0'],
                         conn='dbname=postgres user=postgres '
                              'connect_timeout=1',
                         conn_instance=None,
                         terminator_instance=None,
                         termination_interval=None)
def get_config(base_dir, args=None):
    """
    Load configuration from yml in base dir with respect of args

    Precedence (lowest to highest): CONFIG_DEFAULTS, migrations.yml,
    cmdline args. Also opens the database connection, cursor and
    (optionally) the conflict terminator.
    """
    path = os.path.join(base_dir, 'migrations.yml')
    try:
        with codecs.open(path, encoding='utf-8') as i:
            # safe_load: a config file must not be able to construct
            # arbitrary python objects (yaml.load without a Loader is
            # unsafe and deprecated); empty file yields None -> {}
            base = yaml.safe_load(i.read()) or {}
    except IOError:
        LOG.info('Unable to load %s. Using defaults', path)
        base = {}
    conf = CONFIG_DEFAULTS
    for i in [j for j in CONFIG_DEFAULTS._fields if j not in CONFIG_IGNORE]:
        if i in base:
            conf = conf._replace(**{i: base[i]})
        if args is not None:
            if i in args.__dict__ and args.__dict__[i] is not None:
                conf = conf._replace(**{i: args.__dict__[i]})
    if conf.target is not None:
        # 'latest' means: no upper bound on version
        if conf.target == 'latest':
            conf = conf._replace(target=float('inf'))
        else:
            conf = conf._replace(target=int(conf.target))
    # Terminator is pointless on dry run (nothing holds long locks)
    if conf.termination_interval and not conf.dryrun:
        conf = conf._replace(
            terminator_instance=ConflictTerminator(
                conf.conn, conf.termination_interval))
        conf.terminator_instance.start()
    conf = conf._replace(conn_instance=_create_connection(conf))
    conf = conf._replace(cursor=_init_cursor(conf.conn_instance, conf.session))
    conf = conf._replace(callbacks=_get_callbacks(conf.callbacks,
                                                  conf.base_dir))
    if conf.user is None:
        conf = conf._replace(user=_get_database_user(conf.cursor))
    elif not conf.user:
        raise ConfigurationError('Empty user name')
    return conf
def _main():
    """
    Parse cmdline args, load config and execute the chosen command
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('cmd',
                        choices=COMMANDS.keys(),
                        type=str,
                        help='Operation')
    parser.add_argument('-t', '--target',
                        type=str,
                        help='Target version')
    parser.add_argument('-c', '--conn',
                        type=str,
                        help='Postgresql connection string')
    parser.add_argument('-d', '--base_dir',
                        type=str,
                        default='',
                        help='Migrations base dir')
    parser.add_argument('-u', '--user',
                        type=str,
                        help='Override database user in migration info')
    parser.add_argument('-b', '--baseline',
                        type=int,
                        help='Baseline version')
    parser.add_argument('-a', '--callbacks',
                        type=str,
                        help='Comma-separated list of callbacks '
                             '(type:dir/file)')
    parser.add_argument('-s', '--session',
                        action='append',
                        help='Session setup (e.g. isolation level)')
    parser.add_argument('-n', '--dryrun',
                        action='store_true',
                        help='Say "rollback" in the end instead of "commit"')
    parser.add_argument('-l', '--termination_interval',
                        type=float,
                        # fixed typo in user-facing help ('Inverval')
                        help='Interval for terminating blocking pids')
    parser.add_argument('-v', '--verbose',
                        default=0,
                        action='count',
                        help='Be verbose')
    args = parser.parse_args()
    # Each -v raises verbosity one level (ERROR -> WARNING -> INFO -> DEBUG)
    logging.basicConfig(
        level=(logging.ERROR - 10*(min(3, args.verbose))),
        format='%(asctime)s %(levelname)-8s: %(message)s')
    config = get_config(args.base_dir, args)
    COMMANDS[args.cmd](config)
2017-02-14 20:26:09 +03:00
# Script entry point
if __name__ == '__main__':
    _main()