extensive refactor, now plays nicely with extensions

This commit is contained in:
Robert Lechte 2016-07-28 23:21:35 +12:00
parent 64f222ee1b
commit 885bcf4a20
11 changed files with 292 additions and 152 deletions

View File

@ -70,11 +70,11 @@ Documentation
Features and Limitations
-----------
------------------------
Migra will detect changes to tables, views, materialized views, indexes, constraints, sequences, and extensions.
Migra will detect changes to tables, views, materialized views, indexes, constraints, sequences, and which extensions are installed.
In terms of specific PostgreSQL feature limitations, `migra` is only confirmed to work with SQL/PLPGSQL functions so far. Doesn't track changes to function modifiers (IMMUTABLE/STABLE/VOLATILE, STRICT, RETURNS NULL ON NULL INPUT, etc).
In terms of specific PostgreSQL feature limitations, `migra` is only confirmed to work with SQL/PLPGSQL functions so far.
Installation

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
from .util import differences
from .statements import Statements
from functools import partial
from sqlbag import quoted_identifier
THINGS = [
'sequences',
@ -15,39 +15,25 @@ THINGS = [
]
def fully_quoted(name, schema):
return '{}.{}'.format(quoted_identifier(schema), quoted_identifier(name))
def get_changes(
i_from,
i_target,
thing,
def statements_for_changes(
things_from,
things_target,
creations_only=False,
drops_only=False,
specific_table=None):
drops_only=False):
added, removed, modified, unmodified = \
differences(getattr(i_from, thing), getattr(i_target, thing))
if specific_table:
added = {k: v for k, v in added.items()
if fully_quoted(v.table_name, v.schema) == specific_table}
removed = {k: v for k, v in removed.items()
if fully_quoted(v.table_name, v.schema) == specific_table}
modified = {k: v for k, v in modified.items()
if fully_quoted(v.table_name, v.schema) == specific_table}
added, removed, modified, _ = \
differences(things_from, things_target)
statements = Statements()
if not drops_only:
for k, v in added.items():
statements.append(v.create_statement)
if not creations_only:
for k, v in removed.items():
statements.append(v.drop_statement)
if not drops_only:
for k, v in added.items():
statements.append(v.create_statement)
for k, v in modified.items():
if not creations_only:
statements.append(v.drop_statement)
@ -57,73 +43,42 @@ def get_changes(
return statements
def get_schema_changes(i_from, i_target):
added, removed, modified, _ = differences(i_from.tables, i_target.tables)
def get_schema_changes(tables_from, tables_target):
added, removed, modified, _ = differences(tables_from, tables_target)
statements = Statements()
for t, v in removed.items():
statements.append(v.drop_statement)
indexes = Statements()
constraints = Statements()
for t, v in added.items():
statements.append(v.create_statement)
indexes += get_changes(i_from, i_target, 'indexes', specific_table=t)
for t, v in added.items():
statements += [i.create_statement for i in v.indexes.values()]
constraints += get_changes(i_from, i_target, 'constraints', specific_table=t)
statements += indexes
statements += constraints
for t, v in added.items():
statements += [c.create_statement for c in v.constraints.values()]
for t, v in modified.items():
before = i_from.tables[t]
before = tables_from[t]
c_added, c_removed, c_modified, _ = \
differences(before.columns, v.columns)
for k, c in c_added.items():
statements.append(
'alter table {t} add column {k} {dtype};'.format(
k=c.quoted_name, t=v.quoted_full_name, dtype=c.dbtypestr)
)
for k, c in c_removed.items():
statements.append(
'alter table {t} drop column {k};'.format(
k=c.quoted_name, t=t)
)
alter = v.alter_table_statement(c.drop_column_clause)
statements.append(alter)
for k, c in c_added.items():
alter = v.alter_table_statement(c.add_column_clause)
statements.append(alter)
for k, c in c_modified.items():
bc = before.columns[k]
if bc.not_null != c.not_null:
keyword = 'set' if c.not_null else 'drop'
statements += c.alter_table_statements(before.columns[k], t)
stmt = 'alter table {} alter column {} {} not null;'.format(
v.quoted_full_name,
c.quoted_name,
keyword
)
statements.append(stmt)
if bc.default != c.default:
if c.default:
stmt = 'alter table {} alter column {} set default {};'.format(
v.quoted_full_name,
c.quoted_name,
c.default
)
else:
stmt = 'alter table {} alter column {} drop default;'.format(
v.quoted_full_name,
c.quoted_name
)
statements.append(stmt)
statements += statements_for_changes(before.constraints, v.constraints)
statements += statements_for_changes(before.indexes, v.indexes)
indexes = get_changes(i_from, i_target, 'indexes', specific_table=t)
statements += indexes
constraints = get_changes(i_from, i_target, 'constraints', specific_table=t)
statements += constraints
return statements
@ -134,8 +89,14 @@ class Changes(object):
def __getattr__(self, name):
if name == 'schema':
return partial(get_schema_changes, self.i_from, self.i_target)
return partial(
get_schema_changes,
self.i_from.tables,
self.i_target.tables)
elif name in THINGS:
return partial(get_changes, self.i_from, self.i_target, name)
return partial(
statements_for_changes,
getattr(self.i_from, name),
getattr(self.i_target, name))
else:
raise AttributeError(name)

View File

@ -1,25 +1,40 @@
from __future__ import unicode_literals
from __future__ import unicode_literals, print_function
from sqlbag import S
import argparse
import sys
from .migra import Migration
from .statements import UnsafeMigrationException
def do_command():
def parse_args(args):
parser = argparse.ArgumentParser(
description='Generate a database migration.')
parser.add_argument('--unsafe', dest='unsafe', action='store_true')
parser.add_argument(
'--unsafe',
dest='unsafe',
action='store_true',
help='Prevent migra from erroring upon generation of drop statements.')
parser.add_argument('dburl_from',
help='The database you want to migrate.')
parser.add_argument(
'dburl_from',
help='The database you want to migrate.')
parser.add_argument('dburl_target',
help='The database you want to use as the target.')
parser.add_argument(
'dburl_target',
help='The database you want to use as the target.')
args = parser.parse_args()
return parser.parse_args(args)
def run(args, out=None, err=None):
if not out:
out = sys.stdout # pragma: no cover
if not err:
err = sys.stderr # pragma: no cover
with S(args.dburl_from) as s0, S(args.dburl_target) as s1:
m = Migration(s0, s1)
@ -27,8 +42,20 @@ def do_command():
if args.unsafe:
m.set_safety(False)
m.add_all_changes()
print(m.sql)
try:
print(m.sql, file=out)
except UnsafeMigrationException:
print('-- ERROR: destructive statements generated. Use the --unsafe flag to suppress this error.', file=err)
return 3
if not m.statements:
sys.exit(0)
return 0
else:
sys.exit(1)
return 2
def do_command(): # pragma: no cover
args = parse_args(sys.argv[1:])
status = run(args)
sys.exit(status)

View File

@ -42,12 +42,14 @@ class Migration(object):
self.statements.safe = safety_on
def add_all_changes(self):
self.add(self.changes.sequences())
self.add(self.changes.sequences(drops_only=True))
self.add(self.changes.extensions(drops_only=True))
self.add(self.changes.views(drops_only=True))
self.add(self.changes.functions(drops_only=True))
self.add(self.changes.schema())
self.add(self.changes.sequences(creations_only=True))
self.add(self.changes.extensions(creations_only=True))
self.add(self.changes.views(creations_only=True))
self.add(self.changes.functions(creations_only=True))

View File

@ -1,5 +1,7 @@
from __future__ import unicode_literals
from collections import OrderedDict as od
def differences(a, b):
a_keys = set(a.keys())
@ -9,9 +11,9 @@ def differences(a, b):
keys_removed = set(a_keys) - set(b_keys)
keys_common = set(a_keys) & set(b_keys)
added = {k: b[k] for k in keys_added}
removed = {k: a[k] for k in keys_removed}
modified = {k: b[k] for k in keys_common if not a[k] == b[k]}
unmodified = {k: b[k] for k in keys_common if a[k] == b[k]}
added = od((k, b[k]) for k in sorted(keys_added))
removed = od((k, a[k]) for k in sorted(keys_removed))
modified = od((k, b[k]) for k in sorted(keys_common) if a[k] != b[k])
unmodified = od((k, b[k]) for k in sorted(keys_common) if a[k] == b[k])
return added, removed, modified, unmodified

View File

@ -8,7 +8,7 @@ readme = io.open('README.rst').read()
setup(
name='migra',
version='0.1.1469411508',
version='0.1.1469704724',
url='https://github.com/djrobstep/migra',
description='Like diff but for PostgreSQL schemas',
long_description=readme,
@ -28,5 +28,6 @@ setup(
'console_scripts': [
'migra = migra:do_command',
],
}
},
extras_require={'pg': ['psycopg2']}
)

View File

@ -2,19 +2,34 @@ create extension pg_trgm;
CREATE TABLE products (
product_no integer PRIMARY KEY,
name text not null,
name varchar(10) not null unique,
price numeric,
x integer not null default 7,
unwantedcolumn date,
oldcolumn text,
constraint x check (price > 0)
);
create index on products(price);
create view vvv as select 1;
create view vvv as select * from products;
CREATE TABLE orders (
order_id integer PRIMARY KEY,
shipping_address text
);
CREATE TABLE unwanted (
id integer PRIMARY KEY,
name text not null
);
create or replace function public.changed(i integer, t text[])
returns TABLE(a text, c integer) as
$$
declare
BEGIN
select 'no', 1;
END;
$$
LANGUAGE PLPGSQL STABLE returns null on null input security definer;

View File

@ -1,4 +1,5 @@
create extension hstore;
create extension postgis;
CREATE TABLE products (
product_no integer PRIMARY KEY,
@ -18,8 +19,6 @@ CREATE TABLE orders (
shipping_address text
);
CREATE TABLE order_items (
product_no integer REFERENCES products ON DELETE RESTRICT,
order_id integer REFERENCES orders ON DELETE CASCADE,
@ -36,6 +35,17 @@ $$
END;
$$
LANGUAGE PLPGSQL;
LANGUAGE PLPGSQL volatile returns null on null input security definer;
create or replace function public.newfunc(i integer, t text[])
returns TABLE(a text, c integer) as
$$
declare
BEGIN
select 'no', 1;
END;
$$
LANGUAGE PLPGSQL STABLE returns null on null input security invoker;
create view vvv as select 2;

View File

@ -0,0 +1,83 @@
drop extension if exists "pg_trgm";
drop view if exists "public"."vvv" cascade;
drop function if exists "public"."changed"(i integer, t text[]) cascade;
drop table "public"."unwanted";
create table "public"."order_items" (
"product_no" integer not null,
"order_id" integer not null,
"quantity" integer
);
CREATE UNIQUE INDEX order_items_pkey ON order_items USING btree (product_no, order_id);
alter table "public"."order_items" add constraint "order_items_pkey" PRIMARY KEY using index "order_items_pkey";
alter table "public"."order_items" add constraint "order_items_product_no_fkey" FOREIGN KEY (product_no) REFERENCES products(product_no) ON DELETE RESTRICT;
alter table "public"."order_items" add constraint "order_items_order_id_fkey" FOREIGN KEY (order_id) REFERENCES orders(order_id) ON DELETE CASCADE;
alter table "public"."products" drop column "oldcolumn";
alter table "public"."products" add column "newcolumn" text;
alter table "public"."products" add column "newcolumn2" interval;
alter table "public"."products" alter column "name" drop not null;
alter table "public"."products" alter column "name" set data type text;
alter table "public"."products" alter column "price" set not null;
alter table "public"."products" alter column "price" set default 100;
alter table "public"."products" alter column "x" drop not null;
alter table "public"."products" alter column "x" drop default;
alter table "public"."products" drop constraint "products_name_key";
alter table "public"."products" add constraint "y" CHECK ((price > (0)::numeric));
alter table "public"."products" drop constraint "x";
alter table "public"."products" add constraint "x" CHECK ((price > (10)::numeric));
drop index if exists "public"."products_name_key";
drop index if exists "public"."products_price_idx";
CREATE INDEX products_name_idx ON products USING btree (name);
create extension "hstore" with schema "public" version '1.3';
create extension "postgis" with schema "public" version '2.2.1';
create view "public"."vvv" as SELECT 2;
create or replace function "public"."newfunc"(i integer, t text[])
returns TABLE(a text, c integer) as
$$
declare
BEGIN
select 'no', 1;
END;
$$
language PLPGSQL STABLE RETURNS NULL ON NULL INPUT SECURITY INVOKER;
create or replace function "public"."changed"(i integer, t text[])
returns TABLE(a text, c integer) as
$$
declare
BEGIN
select 'no', 1;
END;
$$
language PLPGSQL VOLATILE RETURNS NULL ON NULL INPUT SECURITY DEFINER;

89
tests/test_migra.py Normal file
View File

@ -0,0 +1,89 @@
from __future__ import unicode_literals
import io
from pytest import raises
from migra import Statements, UnsafeMigrationException, Migration
from migra.command import run
from sqlbag import temporary_database, S, load_sql_from_file
from migra.command import parse_args
SQL = """select 1;
select 2;
"""
DROP = 'drop table x;'
B = 'alter table "public"."products" add column "newcolumn" text;\n\n'
A = 'alter table "public"."products" drop column "oldcolumn";\n\n'
EXPECTED = io.open('tests/FIXTURES/expected.sql').read().strip()
EXPECTED2 = EXPECTED.replace(A + B, '')
def test_statements():
s1 = Statements(['select 1;'])
s2 = Statements(['select 2;'])
s3 = s1 + s2
assert type(s1) == type(s2) == type(s3)
s3 = s3 + Statements([DROP])
with raises(UnsafeMigrationException):
assert s3.sql == SQL
s3.safe = False
SQL_WITH_DROP = SQL + DROP + '\n\n'
assert s3.sql == SQL_WITH_DROP
def outs():
return io.StringIO(), io.StringIO()
def test_all():
with temporary_database() as d0, temporary_database() as d1:
with S(d0) as s0, S(d1) as s1:
load_sql_from_file(s0, 'tests/FIXTURES/a.sql')
load_sql_from_file(s1, 'tests/FIXTURES/b.sql')
args = parse_args([d0, d1])
assert not args.unsafe
out, err = outs()
assert run(args, out=out, err=err) == 3
assert out.getvalue() == ''
assert err.getvalue() == '-- ERROR: destructive statements generated. Use the --unsafe flag to suppress this error.\n'
args = parse_args(['--unsafe', d0, d1])
assert args.unsafe
out, err = outs()
assert run(args, out=out, err=err) == 2
assert err.getvalue() == ''
assert out.getvalue().strip() == EXPECTED
with S(d0) as s0, S(d1) as s1:
m = Migration(s0, s1)
with raises(AttributeError):
m.changes.nonexist
m.set_safety(False)
m.add_sql('alter table products rename column oldcolumn to newcolumn;')
m.apply()
m.add_all_changes()
assert m.sql.strip() == EXPECTED2 # sql generated OK
m.apply()
# check for changes again and make sure none are pending
m.add_all_changes()
assert not m.statements # no further statements to apply
out, err = outs()
assert run(args, out=out, err=err) == 0

View File

@ -1,50 +0,0 @@
from __future__ import unicode_literals
from pytest import raises
from migra import Statements, UnsafeMigrationException, Migration
from sqlbag import temporary_database, S, load_sql_from_file
SQL = """select 1;
select 2;
"""
DROP = 'drop table x;'
def test_statements():
s1 = Statements(['select 1;'])
s2 = Statements(['select 2;'])
s3 = s1 + s2
assert type(s1) == type(s2) == type(s3)
s3 = s3 + Statements([DROP])
with raises(UnsafeMigrationException):
assert s3.sql == SQL
s3.safe = False
SQL_WITH_DROP = SQL + DROP + '\n\n'
assert s3.sql == SQL_WITH_DROP
def test_changes():
with temporary_database() as d0, temporary_database() as d1:
with S(d0) as s0, S(d1) as s1:
load_sql_from_file(s0, 'tests/FIXTURES/a.sql')
load_sql_from_file(s1, 'tests/FIXTURES/b.sql')
with S(d0) as s0, S(d1) as s1:
m = Migration(s0, s1)
m.set_safety(False)
m.add_sql('alter table products rename column oldcolumn to newcolumn;')
m.apply()
m.add_all_changes()
m.sql # sql generated OK
m.apply()
m.add_all_changes()
assert not m.statements