1
1
mirror of https://github.com/yandex/pgmigrate.git synced 2024-10-05 16:17:14 +03:00

Allow reordering setting schema version and afterEach callback

This commit is contained in:
secwall 2023-12-31 20:21:58 +03:00
parent 609549c111
commit e74474056a
6 changed files with 70 additions and 34 deletions

View File

@ -287,7 +287,7 @@ max-attributes=7
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
max-branches=15
# Maximum number of locals for function / method body.
max-locals=15

View File

@ -75,6 +75,29 @@ Feature: Getting info from config
| 6 | After each |
| 7 | After all |
Scenario: Reordering setting schema version and afterEach callback works
Given migration dir
And migrations
| file | code |
| V1__Single_migration.sql | SELECT 1; |
| V2__Another_migration.sql | SELECT 1; |
And config
"""
set_version_info_after_callbacks: true
"""
And config callbacks
| type | file | code |
| beforeAll | before_all.sql | CREATE TABLE mycooltable (seq SERIAL PRIMARY KEY, count int); |
| afterEach | after_each.sql | INSERT INTO mycooltable (count) SELECT count(*) FROM schema_version; |
And database and connection
When we run pgmigrate with "-t 2 migrate"
Then pgmigrate command "succeeded"
And database contains schema_version
And query "SELECT * from mycooltable order by seq;" equals
| seq | count |
| 1 | 0 |
| 2 | 1 |
Scenario: Callbacks from config are executed from dir
Given migration dir
And migrations

View File

@ -9,5 +9,5 @@ from pgmigrate import _get_info
def step_impl(context, baseline, schema='public'):
cur = context.conn.cursor()
info = _get_info(context.migr_dir, 0, 1, schema, cur)
assert list(info.values())[0]['version'] == int(baseline)
assert list(info.values())[0]['description'] == 'Forced baseline'
assert list(info.values())[0].meta['version'] == int(baseline)
assert list(info.values())[0].meta['description'] == 'Forced baseline'

View File

@ -7,5 +7,5 @@ from pgmigrate import _get_info
def step_impl(context, schema='public'):
cur = context.conn.cursor()
info = _get_info(context.migr_dir, 0, 1, schema, cur)
assert list(info.values())[0]['version'] == 1
assert list(info.values())[0]['description'] == 'Single migration'
assert list(info.values())[0].meta['version'] == 1
assert list(info.values())[0].meta['description'] == 'Single migration'

View File

@ -225,11 +225,12 @@ MigrationInfo = namedtuple('MigrationInfo', ('meta', 'file_path'))
Callbacks = namedtuple('Callbacks',
('beforeAll', 'beforeEach', 'afterEach', 'afterAll'))
Config = namedtuple('Config',
('target', 'baseline', 'cursor', 'dryrun', 'callbacks',
'user', 'base_dir', 'conn', 'session', 'conn_instance',
'terminator_instance', 'termination_interval', 'schema',
'disable_schema_check', 'check_serial_versions'))
Config = namedtuple(
'Config',
('target', 'baseline', 'cursor', 'dryrun', 'callbacks', 'user', 'base_dir',
'conn', 'session', 'conn_instance', 'terminator_instance',
'termination_interval', 'schema', 'disable_schema_check',
'check_serial_versions', 'set_version_info_after_callbacks'))
CONFIG_IGNORE = ['cursor', 'conn_instance', 'terminator_instance']
@ -294,7 +295,7 @@ def _get_migrations_info(base_dir, baseline_v, target_v):
for version, ret in _get_migrations_info_from_dir(base_dir).items():
if baseline_v < version <= target:
migrations[version] = ret.meta
migrations[version] = ret
else:
LOG.info(
'Ignore migration %r cause baseline: %r or target: %r',
@ -324,13 +325,13 @@ def _get_info(base_dir, baseline_v, target_v, schema, cursor):
version['version'] = int(version['version'])
transactional = 'NONTRANSACTIONAL' not in version['description']
version['transactional'] = transactional
ret[version['version']] = version
ret[version['version']] = MigrationInfo(meta=version, file_path='')
baseline_v = max(baseline_v, sorted(ret.keys())[-1])
migrations_info = _get_migrations_info(base_dir, baseline_v, target_v)
for version in migrations_info:
num = migrations_info[version]['version']
num = migrations_info[version].meta['version']
if num not in ret:
ret[num] = migrations_info[version]
@ -461,15 +462,16 @@ def _apply_file(file_path, cursor):
raise exc
def _apply_version(version, base_dir, user, schema, cursor):
def _apply_version(version_info, cursor):
"""
Execute all statements in migration version
"""
all_versions = _get_migrations_info_from_dir(base_dir)
version_info = all_versions[version]
LOG.info('Try apply version %r', version_info)
_apply_file(version_info.file_path, cursor)
def _set_schema_version(version, version_info, user, schema, cursor):
cursor.execute(
SQL('INSERT INTO {schema}.schema_version '
'(version, description, installed_by) '
@ -532,7 +534,8 @@ def _get_callbacks(callbacks, base_dir=''):
return _parse_str_callbacks(callbacks, ret, base_dir)
def _migrate_step(state, callbacks, base_dir, user, schema, cursor):
def _migrate_step(state, callbacks, user, schema,
set_version_info_after_callbacks, cursor):
"""
Apply one version with callbacks
"""
@ -543,7 +546,7 @@ def _migrate_step(state, callbacks, base_dir, user, schema, cursor):
_init_schema(schema, cursor)
for version in sorted(state.keys()):
LOG.debug('has version %r', version)
if state[version]['installed_on'] is None:
if state[version].meta['installed_on'] is None:
should_migrate = True
if not before_all_executed and callbacks.beforeAll:
LOG.info('Executing beforeAll callbacks:')
@ -559,7 +562,11 @@ def _migrate_step(state, callbacks, base_dir, user, schema, cursor):
LOG.info(callback)
_apply_file(callback, cursor)
_apply_version(version, base_dir, user, schema, cursor)
_apply_version(state[version], cursor)
if not set_version_info_after_callbacks:
_set_schema_version(version, state[version], user, schema,
cursor)
if callbacks.afterEach:
LOG.info('Executing afterEach callbacks:')
@ -567,6 +574,10 @@ def _migrate_step(state, callbacks, base_dir, user, schema, cursor):
LOG.info(callback)
_apply_file(callback, cursor)
if set_version_info_after_callbacks:
_set_schema_version(version, state[version], user, schema,
cursor)
if should_migrate and callbacks.afterAll:
LOG.info('Executing afterAll callbacks:')
for callback in callbacks.afterAll:
@ -593,7 +604,7 @@ def info(config, stdout=True):
if stdout:
out_state = OrderedDict()
for version in sorted(state, key=int):
out_state[version] = state[version]
out_state[version] = state[version].meta
sys.stdout.write(
json.dumps(out_state, indent=4, separators=(',', ': ')) + '\n')
@ -635,7 +646,7 @@ def _prepare_nontransactional_steps(state, callbacks):
steps = []
i = {'state': {}, 'cbs': _get_callbacks('')}
for version in sorted(state):
if not state[version]['transactional']:
if not state[version].meta['transactional']:
if i['state']:
steps.append(i)
i = {'state': {}, 'cbs': _get_callbacks('')}
@ -658,7 +669,7 @@ def _prepare_nontransactional_steps(state, callbacks):
transactional = []
for (num, step) in enumerate(steps):
if list(step['state'].values())[0]['transactional']:
if list(step['state'].values())[0].meta['transactional']:
transactional.append(num)
if len(transactional) > 1:
@ -679,13 +690,13 @@ def _execute_mixed_steps(config, steps, nt_conn):
if commit_req:
config.cursor.execute('commit')
commit_req = False
if not list(step['state'].values())[0]['transactional']:
if not list(step['state'].values())[0].meta['transactional']:
cur = _init_cursor(nt_conn, config.session)
else:
cur = config.cursor
commit_req = True
_migrate_step(step['state'], step['cbs'], config.base_dir, config.user,
config.schema, cur)
_migrate_step(step['state'], step['cbs'], config.user, config.schema,
config.set_version_info_after_callbacks, cur)
def _schema_check(schema, cursor):
@ -713,7 +724,7 @@ def _check_serial_versions(state, not_applied):
"""
Check that there are no gaps in migration versions
"""
applied = [x for x in state if state[x]['installed_on'] is not None]
applied = [x for x in state if state[x].meta['installed_on'] is not None]
sorted_versions = sorted(not_applied)
if applied:
sorted_versions.insert(0, max(applied))
@ -738,8 +749,8 @@ def migrate(config):
state = _get_state(config.base_dir, config.baseline, config.target,
config.schema, config.cursor)
not_applied = [x for x in state if state[x]['installed_on'] is None]
non_trans = [x for x in not_applied if not state[x]['transactional']]
not_applied = [x for x in state if state[x].meta['installed_on'] is None]
non_trans = [x for x in not_applied if not state[x].meta['transactional']]
if not_applied and config.check_serial_versions:
_check_serial_versions(state, not_applied)
@ -763,8 +774,9 @@ def migrate(config):
with closing(_create_connection(config)) as nt_conn:
nt_conn.autocommit = True
cursor = _init_cursor(nt_conn, config.session)
_migrate_step(state, _get_callbacks(''), config.base_dir,
config.user, config.schema, cursor)
_migrate_step(state, _get_callbacks(''), config.user,
config.schema,
config.set_version_info_after_callbacks, cursor)
if config.terminator_instance:
config.terminator_instance.remove_conn(nt_conn)
else:
@ -778,8 +790,8 @@ def migrate(config):
if config.terminator_instance:
config.terminator_instance.remove_conn(nt_conn)
else:
_migrate_step(state, config.callbacks, config.base_dir, config.user,
config.schema, config.cursor)
_migrate_step(state, config.callbacks, config.user, config.schema,
config.set_version_info_after_callbacks, config.cursor)
if not config.disable_schema_check:
_schema_check(config.schema, config.cursor)
@ -808,7 +820,8 @@ CONFIG_DEFAULTS = Config(target=None,
termination_interval=None,
schema=None,
disable_schema_check=False,
check_serial_versions=False)
check_serial_versions=False,
set_version_info_after_callbacks=False)
def get_config(base_dir, args=None):

View File

@ -46,7 +46,7 @@ deps = flake8==5.0.4
flake8-copyright
flake8-pep3101
pylint
yapf==0.32.0
yapf==0.40.2
[flake8]
copyright-check = True