1
1
mirror of https://github.com/yandex/pgmigrate.git synced 2024-10-05 16:17:14 +03:00

Allow reordering setting schema version and afterEach callback

This commit is contained in:
secwall 2023-12-31 20:21:58 +03:00
parent 609549c111
commit e74474056a
6 changed files with 70 additions and 34 deletions

View File

@ -287,7 +287,7 @@ max-attributes=7
max-bool-expr=5 max-bool-expr=5
# Maximum number of branch for function / method body. # Maximum number of branch for function / method body.
max-branches=12 max-branches=15
# Maximum number of locals for function / method body. # Maximum number of locals for function / method body.
max-locals=15 max-locals=15

View File

@ -75,6 +75,29 @@ Feature: Getting info from config
| 6 | After each | | 6 | After each |
| 7 | After all | | 7 | After all |
Scenario: Reordering setting schema version and afterEach callback works
Given migration dir
And migrations
| file | code |
| V1__Single_migration.sql | SELECT 1; |
| V2__Another_migration.sql | SELECT 1; |
And config
"""
set_version_info_after_callbacks: true
"""
And config callbacks
| type | file | code |
| beforeAll | before_all.sql | CREATE TABLE mycooltable (seq SERIAL PRIMARY KEY, count int); |
| afterEach | after_each.sql | INSERT INTO mycooltable (count) SELECT count(*) FROM schema_version; |
And database and connection
When we run pgmigrate with "-t 2 migrate"
Then pgmigrate command "succeeded"
And database contains schema_version
And query "SELECT * from mycooltable order by seq;" equals
| seq | count |
| 1 | 0 |
| 2 | 1 |
Scenario: Callbacks from config are executed from dir Scenario: Callbacks from config are executed from dir
Given migration dir Given migration dir
And migrations And migrations

View File

@ -9,5 +9,5 @@ from pgmigrate import _get_info
def step_impl(context, baseline, schema='public'): def step_impl(context, baseline, schema='public'):
cur = context.conn.cursor() cur = context.conn.cursor()
info = _get_info(context.migr_dir, 0, 1, schema, cur) info = _get_info(context.migr_dir, 0, 1, schema, cur)
assert list(info.values())[0]['version'] == int(baseline) assert list(info.values())[0].meta['version'] == int(baseline)
assert list(info.values())[0]['description'] == 'Forced baseline' assert list(info.values())[0].meta['description'] == 'Forced baseline'

View File

@ -7,5 +7,5 @@ from pgmigrate import _get_info
def step_impl(context, schema='public'): def step_impl(context, schema='public'):
cur = context.conn.cursor() cur = context.conn.cursor()
info = _get_info(context.migr_dir, 0, 1, schema, cur) info = _get_info(context.migr_dir, 0, 1, schema, cur)
assert list(info.values())[0]['version'] == 1 assert list(info.values())[0].meta['version'] == 1
assert list(info.values())[0]['description'] == 'Single migration' assert list(info.values())[0].meta['description'] == 'Single migration'

View File

@ -225,11 +225,12 @@ MigrationInfo = namedtuple('MigrationInfo', ('meta', 'file_path'))
Callbacks = namedtuple('Callbacks', Callbacks = namedtuple('Callbacks',
('beforeAll', 'beforeEach', 'afterEach', 'afterAll')) ('beforeAll', 'beforeEach', 'afterEach', 'afterAll'))
Config = namedtuple('Config', Config = namedtuple(
('target', 'baseline', 'cursor', 'dryrun', 'callbacks', 'Config',
'user', 'base_dir', 'conn', 'session', 'conn_instance', ('target', 'baseline', 'cursor', 'dryrun', 'callbacks', 'user', 'base_dir',
'terminator_instance', 'termination_interval', 'schema', 'conn', 'session', 'conn_instance', 'terminator_instance',
'disable_schema_check', 'check_serial_versions')) 'termination_interval', 'schema', 'disable_schema_check',
'check_serial_versions', 'set_version_info_after_callbacks'))
CONFIG_IGNORE = ['cursor', 'conn_instance', 'terminator_instance'] CONFIG_IGNORE = ['cursor', 'conn_instance', 'terminator_instance']
@ -294,7 +295,7 @@ def _get_migrations_info(base_dir, baseline_v, target_v):
for version, ret in _get_migrations_info_from_dir(base_dir).items(): for version, ret in _get_migrations_info_from_dir(base_dir).items():
if baseline_v < version <= target: if baseline_v < version <= target:
migrations[version] = ret.meta migrations[version] = ret
else: else:
LOG.info( LOG.info(
'Ignore migration %r cause baseline: %r or target: %r', 'Ignore migration %r cause baseline: %r or target: %r',
@ -324,13 +325,13 @@ def _get_info(base_dir, baseline_v, target_v, schema, cursor):
version['version'] = int(version['version']) version['version'] = int(version['version'])
transactional = 'NONTRANSACTIONAL' not in version['description'] transactional = 'NONTRANSACTIONAL' not in version['description']
version['transactional'] = transactional version['transactional'] = transactional
ret[version['version']] = version ret[version['version']] = MigrationInfo(meta=version, file_path='')
baseline_v = max(baseline_v, sorted(ret.keys())[-1]) baseline_v = max(baseline_v, sorted(ret.keys())[-1])
migrations_info = _get_migrations_info(base_dir, baseline_v, target_v) migrations_info = _get_migrations_info(base_dir, baseline_v, target_v)
for version in migrations_info: for version in migrations_info:
num = migrations_info[version]['version'] num = migrations_info[version].meta['version']
if num not in ret: if num not in ret:
ret[num] = migrations_info[version] ret[num] = migrations_info[version]
@ -461,15 +462,16 @@ def _apply_file(file_path, cursor):
raise exc raise exc
def _apply_version(version, base_dir, user, schema, cursor): def _apply_version(version_info, cursor):
""" """
Execute all statements in migration version Execute all statements in migration version
""" """
all_versions = _get_migrations_info_from_dir(base_dir)
version_info = all_versions[version]
LOG.info('Try apply version %r', version_info) LOG.info('Try apply version %r', version_info)
_apply_file(version_info.file_path, cursor) _apply_file(version_info.file_path, cursor)
def _set_schema_version(version, version_info, user, schema, cursor):
cursor.execute( cursor.execute(
SQL('INSERT INTO {schema}.schema_version ' SQL('INSERT INTO {schema}.schema_version '
'(version, description, installed_by) ' '(version, description, installed_by) '
@ -532,7 +534,8 @@ def _get_callbacks(callbacks, base_dir=''):
return _parse_str_callbacks(callbacks, ret, base_dir) return _parse_str_callbacks(callbacks, ret, base_dir)
def _migrate_step(state, callbacks, base_dir, user, schema, cursor): def _migrate_step(state, callbacks, user, schema,
set_version_info_after_callbacks, cursor):
""" """
Apply one version with callbacks Apply one version with callbacks
""" """
@ -543,7 +546,7 @@ def _migrate_step(state, callbacks, base_dir, user, schema, cursor):
_init_schema(schema, cursor) _init_schema(schema, cursor)
for version in sorted(state.keys()): for version in sorted(state.keys()):
LOG.debug('has version %r', version) LOG.debug('has version %r', version)
if state[version]['installed_on'] is None: if state[version].meta['installed_on'] is None:
should_migrate = True should_migrate = True
if not before_all_executed and callbacks.beforeAll: if not before_all_executed and callbacks.beforeAll:
LOG.info('Executing beforeAll callbacks:') LOG.info('Executing beforeAll callbacks:')
@ -559,7 +562,11 @@ def _migrate_step(state, callbacks, base_dir, user, schema, cursor):
LOG.info(callback) LOG.info(callback)
_apply_file(callback, cursor) _apply_file(callback, cursor)
_apply_version(version, base_dir, user, schema, cursor) _apply_version(state[version], cursor)
if not set_version_info_after_callbacks:
_set_schema_version(version, state[version], user, schema,
cursor)
if callbacks.afterEach: if callbacks.afterEach:
LOG.info('Executing afterEach callbacks:') LOG.info('Executing afterEach callbacks:')
@ -567,6 +574,10 @@ def _migrate_step(state, callbacks, base_dir, user, schema, cursor):
LOG.info(callback) LOG.info(callback)
_apply_file(callback, cursor) _apply_file(callback, cursor)
if set_version_info_after_callbacks:
_set_schema_version(version, state[version], user, schema,
cursor)
if should_migrate and callbacks.afterAll: if should_migrate and callbacks.afterAll:
LOG.info('Executing afterAll callbacks:') LOG.info('Executing afterAll callbacks:')
for callback in callbacks.afterAll: for callback in callbacks.afterAll:
@ -593,7 +604,7 @@ def info(config, stdout=True):
if stdout: if stdout:
out_state = OrderedDict() out_state = OrderedDict()
for version in sorted(state, key=int): for version in sorted(state, key=int):
out_state[version] = state[version] out_state[version] = state[version].meta
sys.stdout.write( sys.stdout.write(
json.dumps(out_state, indent=4, separators=(',', ': ')) + '\n') json.dumps(out_state, indent=4, separators=(',', ': ')) + '\n')
@ -635,7 +646,7 @@ def _prepare_nontransactional_steps(state, callbacks):
steps = [] steps = []
i = {'state': {}, 'cbs': _get_callbacks('')} i = {'state': {}, 'cbs': _get_callbacks('')}
for version in sorted(state): for version in sorted(state):
if not state[version]['transactional']: if not state[version].meta['transactional']:
if i['state']: if i['state']:
steps.append(i) steps.append(i)
i = {'state': {}, 'cbs': _get_callbacks('')} i = {'state': {}, 'cbs': _get_callbacks('')}
@ -658,7 +669,7 @@ def _prepare_nontransactional_steps(state, callbacks):
transactional = [] transactional = []
for (num, step) in enumerate(steps): for (num, step) in enumerate(steps):
if list(step['state'].values())[0]['transactional']: if list(step['state'].values())[0].meta['transactional']:
transactional.append(num) transactional.append(num)
if len(transactional) > 1: if len(transactional) > 1:
@ -679,13 +690,13 @@ def _execute_mixed_steps(config, steps, nt_conn):
if commit_req: if commit_req:
config.cursor.execute('commit') config.cursor.execute('commit')
commit_req = False commit_req = False
if not list(step['state'].values())[0]['transactional']: if not list(step['state'].values())[0].meta['transactional']:
cur = _init_cursor(nt_conn, config.session) cur = _init_cursor(nt_conn, config.session)
else: else:
cur = config.cursor cur = config.cursor
commit_req = True commit_req = True
_migrate_step(step['state'], step['cbs'], config.base_dir, config.user, _migrate_step(step['state'], step['cbs'], config.user, config.schema,
config.schema, cur) config.set_version_info_after_callbacks, cur)
def _schema_check(schema, cursor): def _schema_check(schema, cursor):
@ -713,7 +724,7 @@ def _check_serial_versions(state, not_applied):
""" """
Check that there are no gaps in migration versions Check that there are no gaps in migration versions
""" """
applied = [x for x in state if state[x]['installed_on'] is not None] applied = [x for x in state if state[x].meta['installed_on'] is not None]
sorted_versions = sorted(not_applied) sorted_versions = sorted(not_applied)
if applied: if applied:
sorted_versions.insert(0, max(applied)) sorted_versions.insert(0, max(applied))
@ -738,8 +749,8 @@ def migrate(config):
state = _get_state(config.base_dir, config.baseline, config.target, state = _get_state(config.base_dir, config.baseline, config.target,
config.schema, config.cursor) config.schema, config.cursor)
not_applied = [x for x in state if state[x]['installed_on'] is None] not_applied = [x for x in state if state[x].meta['installed_on'] is None]
non_trans = [x for x in not_applied if not state[x]['transactional']] non_trans = [x for x in not_applied if not state[x].meta['transactional']]
if not_applied and config.check_serial_versions: if not_applied and config.check_serial_versions:
_check_serial_versions(state, not_applied) _check_serial_versions(state, not_applied)
@ -763,8 +774,9 @@ def migrate(config):
with closing(_create_connection(config)) as nt_conn: with closing(_create_connection(config)) as nt_conn:
nt_conn.autocommit = True nt_conn.autocommit = True
cursor = _init_cursor(nt_conn, config.session) cursor = _init_cursor(nt_conn, config.session)
_migrate_step(state, _get_callbacks(''), config.base_dir, _migrate_step(state, _get_callbacks(''), config.user,
config.user, config.schema, cursor) config.schema,
config.set_version_info_after_callbacks, cursor)
if config.terminator_instance: if config.terminator_instance:
config.terminator_instance.remove_conn(nt_conn) config.terminator_instance.remove_conn(nt_conn)
else: else:
@ -778,8 +790,8 @@ def migrate(config):
if config.terminator_instance: if config.terminator_instance:
config.terminator_instance.remove_conn(nt_conn) config.terminator_instance.remove_conn(nt_conn)
else: else:
_migrate_step(state, config.callbacks, config.base_dir, config.user, _migrate_step(state, config.callbacks, config.user, config.schema,
config.schema, config.cursor) config.set_version_info_after_callbacks, config.cursor)
if not config.disable_schema_check: if not config.disable_schema_check:
_schema_check(config.schema, config.cursor) _schema_check(config.schema, config.cursor)
@ -808,7 +820,8 @@ CONFIG_DEFAULTS = Config(target=None,
termination_interval=None, termination_interval=None,
schema=None, schema=None,
disable_schema_check=False, disable_schema_check=False,
check_serial_versions=False) check_serial_versions=False,
set_version_info_after_callbacks=False)
def get_config(base_dir, args=None): def get_config(base_dir, args=None):

View File

@ -46,7 +46,7 @@ deps = flake8==5.0.4
flake8-copyright flake8-copyright
flake8-pep3101 flake8-pep3101
pylint pylint
yapf==0.32.0 yapf==0.40.2
[flake8] [flake8]
copyright-check = True copyright-check = True