1
1
mirror of https://github.com/dbcli/pgcli.git synced 2024-10-06 02:07:53 +03:00

Join conditions: alias tables already included in query

If we have the input `SELECT * FROM Foo JOIN `, we now suggest `Foo Foo2 ON Foo2.ParentID = Foo.ID` (given the appropriate casing file and FK).
There were also some problems with quoted tables and with the casing of table aliases, which are now fixed.
I also made a few cosmetic changes to get_join_matches (pgcompleter.py) just to make it a bit easier to work with.
This commit is contained in:
koljonen 2016-06-11 15:40:51 +02:00
parent 351a58554b
commit 6cb8c38628
No known key found for this signature in database
GPG Key ID: AF0327B5131CD164
6 changed files with 117 additions and 41 deletions

View File

@ -66,7 +66,9 @@ def last_word(text, include='alphanum_underscore'):
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
'is_function'])
TableReference.ref = property(lambda self: self.alias or self.name)
TableReference.ref = property(lambda self: self.alias or (
self.name if self.name.islower() or self.name[0] == '"'
else '"' + self.name + '"'))
# This code is borrowed from sqlparse example script.
@ -140,7 +142,10 @@ def extract_table_identifiers(token_stream, allow_functions=True):
schema_name = schema_name.lower()
quote_count = item.value.count('"')
name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
if name and not name_quoted and name != name.lower():
alias_quoted = alias and item.value[-1] == '"'
if alias_quoted or name_quoted and not alias and name.islower():
alias = '"' + (alias or name) + '"'
if name and not name_quoted and not name.islower():
if not alias:
alias = name
name = name.lower()
@ -198,7 +203,8 @@ def extract_tables(sql):
identifiers = extract_table_identifiers(stream,
allow_functions=not insert_stmt)
# In the case 'sche.<cursor>', we get an empty TableReference; remove that
return tuple(i for i in identifiers if i.ref)
return tuple(i for i in identifiers if i.name)
def find_prev_keyword(sql):

View File

@ -30,6 +30,7 @@ NamedQueries.instance = NamedQueries.from_config(
Match = namedtuple('Match', ['completion', 'priority'])
normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'
class PGCompleter(Completer):
keywords = get_literals('keywords')
@ -384,38 +385,50 @@ class PGCompleter(Completer):
return self.find_matches(word_before_cursor, flat_cols, meta='column')
def generate_alias(self, tbl, tbls):
if tbl[0] == '"':
aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in itertools.count(2))
else:
aliases = (self.case(tbl) + str(i) for i in itertools.count(2))
return (a for a in aliases if normalize_ref(a) not in tbls).next()
def get_join_matches(self, suggestion, word_before_cursor):
scoped_cols = self.populate_scoped_cols(suggestion.tables)
tbls = suggestion.tables
cols = self.populate_scoped_cols(tbls)
# Set up some data structures for efficient access
qualified = dict((t.ref, t.schema) for t in suggestion.tables)
tbls = set((t.schema, t.name) for t in scoped_cols.keys())
tblprio = dict((t.ref, n) for n, t in enumerate(suggestion.tables))
qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
refs = set(normalize_ref(t.ref) for t in tbls)
other_tbls = set((t.schema, t.name) for t in cols.keys()[:-1])
joins, prios = [], []
# Iterate over FKs in existing tables to find potential joins
fks = ((fk, rtbl, rcol) for rtbl, rcols in scoped_cols.items()
fks = ((fk, rtbl, rcol) for rtbl, rcols in cols.items()
for rcol in rcols for fk in rcol.foreignkeys)
col = namedtuple('col', 'schema tbl col')
for fk, rtbl, rcol in fks:
if (fk.childschema, fk.childtable, fk.childcolumn) == (
rtbl.schema, rtbl.name, rcol.name):
lsch = fk.parentschema
ltbl = fk.parenttable
lcol = fk.parentcolumn
else:
lsch = fk.childschema
ltbl = fk.childtable
lcol = fk.childcolumn
if suggestion.schema and lsch != suggestion.schema:
right = col(rtbl.schema, rtbl.name, rcol.name)
child = col(fk.childschema, fk.childtable, fk.childcolumn)
parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
left = child if parent == right else parent
if suggestion.schema and left.schema != suggestion.schema:
continue
rsch, rtbl, rcol = rtbl.schema, rtbl.ref, rcol.name
join = '{0} ON {0}.{1} = {2}.{3}'.format(
self.case(ltbl), self.case(lcol), rtbl, self.case(rcol))
c = self.case
if normalize_ref(left.tbl) in refs:
lref = self.generate_alias(left.tbl, refs)
join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref)
else:
join = '{0} ON {0}.{1} = {2}.{3}'.format(
c(left.tbl), c(left.col), rtbl.ref, c(right.col))
# Schema-qualify if (1) new table in same schema as old, and old
# is schema-qualified, or (2) new in other schema, except public
if not suggestion.schema and (qualified[rtbl] and lsch == rsch
or lsch not in(rsch, 'public')):
join = lsch + '.' + join
if not suggestion.schema and (qualified[normalize_ref(rtbl.ref)]
and left.schema == right.schema
or left.schema not in(right.schema, 'public')):
join = left.schema + '.' + join
joins.append(join)
prios.append(tblprio[rtbl] * 2 + 0 if (lsch, ltbl) in tbls else 1)
prios.append(ref_prio[normalize_ref(rtbl.ref)] * 2 + (
0 if (left.schema, left.tbl) in other_tbls else 1))
return self.find_matches(word_before_cursor, joins, meta='join',
priority_collection=prios, type_priority=100)

View File

@ -12,11 +12,18 @@ def test_simple_select_single_table():
@pytest.mark.parametrize('sql', [
'select * from abc.def',
'select * from "abc".def',
'select * from "abc"."def"',
'select * from abc."def"',
])
def test_simple_select_single_table_schema_qualified_quoted_table(sql):
tables = extract_tables(sql)
assert tables == (('abc', 'def', '"def"', False),)
@pytest.mark.parametrize('sql', [
'select * from abc.def',
'select * from "abc".def',
])
def test_simple_select_single_table_schema_qualified(sql):
tables = extract_tables(sql)
assert tables == (('abc', 'def', None, False),)

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals
import pytest
import itertools
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey
@ -158,19 +159,21 @@ def test_suggested_join_conditions(completer, complete_event, text):
Completion(text='shipments.id = users.id', start_position=0, display_meta='name join'),
Completion(text='shipments.user_id = users.id', start_position=0, display_meta='fk join')])
@pytest.mark.parametrize('text', [
'SELECT * FROM public.users RIGHT OUTER JOIN ',
@pytest.mark.parametrize(('query', 'tbl'), itertools.product((
'SELECT * FROM public.{0} RIGHT OUTER JOIN ',
'''SELECT *
FROM users
FROM {0}
JOIN '''
])
def test_suggested_joins(completer, complete_event, text):
), ('users', '"users"', 'Users')))
def test_suggested_joins(completer, complete_event, query, tbl):
text = query.format(tbl)
position = len(text)
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
join = 'custom.shipments ON shipments.user_id = {0}.id'.format(tbl)
assert set(result) == set([
Completion(text='custom.shipments ON shipments.user_id = users.id', start_position=0, display_meta='join'),
Completion(text=join, start_position=0, display_meta='join'),
Completion(text='public', start_position=0, display_meta='schema'),
Completion(text='custom', start_position=0, display_meta='schema'),
Completion(text='"Custom"', start_position=0, display_meta='schema'),

View File

@ -6,7 +6,7 @@ from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey
metadata = {
'tables': {
'users': ['id', 'email', 'first_name', 'last_name'],
'users': ['id', 'parentid', 'email', 'first_name', 'last_name'],
'Users': ['userid', 'username'],
'orders': ['id', 'ordered_date', 'status', 'email'],
'select': ['id', 'insert', 'ABC']},
@ -21,6 +21,7 @@ metadata = {
['o', 'o'], '', False, False, True]],
'datatypes': ['custom_type1', 'custom_type2'],
'foreignkeys': [
('public', 'users', 'id', 'public', 'users', 'parentid'),
('public', 'users', 'id', 'public', 'Users', 'userid')
],
}
@ -170,6 +171,7 @@ def test_suggested_column_names_from_visible_table(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column'),
@ -196,6 +198,7 @@ def test_suggested_column_names_in_function(completer, complete_event):
complete_event)
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
@ -214,6 +217,7 @@ def test_suggested_column_names_with_table_dot(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
@ -232,6 +236,7 @@ def test_suggested_column_names_with_alias(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
@ -251,6 +256,7 @@ def test_suggested_multiple_column_names(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column'),
@ -276,6 +282,7 @@ def test_suggested_multiple_column_names_with_alias(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
@ -295,6 +302,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event):
complete_event))
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])
@ -321,7 +329,8 @@ def test_suggest_columns_after_three_way_join(completer, complete_event):
'SELECT * FROM users u FULL OUTER JOIN "Users" u2 ON ',
'''SELECT *
FROM users u
FULL OUTER JOIN "Users" u2 ON '''
FULL OUTER JOIN "Users" u2 ON
'''
])
def test_suggested_join_conditions(completer, complete_event, text):
position = len(text)
@ -346,6 +355,32 @@ def test_suggested_joins(completer, complete_event, text):
complete_event))
assert set(result) == set([
Completion(text='"Users" ON "Users".userid = users.id', start_position=0, display_meta='join'),
Completion(text='users users2 ON users2.id = users.parentid', start_position=0, display_meta='join'),
Completion(text='users users2 ON users2.parentid = users.id', start_position=0, display_meta='join'),
Completion(text='public', start_position=0, display_meta='schema'),
Completion(text='"Users"', start_position=0, display_meta='table'),
Completion(text='"select"', start_position=0, display_meta='table'),
Completion(text='orders', start_position=0, display_meta='table'),
Completion(text='users', start_position=0, display_meta='table'),
Completion(text='user_emails', start_position=0, display_meta='view'),
Completion(text='custom_func2', start_position=0, display_meta='function'),
Completion(text='set_returning_func', start_position=0, display_meta='function'),
Completion(text='custom_func1', start_position=0, display_meta='function')])
@pytest.mark.parametrize('text', [
'SELECT * FROM public."Users" JOIN ',
'SELECT * FROM public."Users" RIGHT OUTER JOIN ',
'''SELECT *
FROM public."Users"
LEFT JOIN '''
])
def test_suggested_joins_quoted_schema_qualified_table(completer, complete_event, text):
position = len(text)
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='public.users ON users.id = "Users".userid', start_position=0, display_meta='join'),
Completion(text='public', start_position=0, display_meta='schema'),
Completion(text='"Users"', start_position=0, display_meta='table'),
Completion(text='"select"', start_position=0, display_meta='table'),
@ -642,7 +677,7 @@ def test_wildcard_column_expansion(completer, complete_event):
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, email, first_name, last_name'
col_list = 'id, parentid, email, first_name, last_name'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
@ -656,7 +691,7 @@ def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_even
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, u.email, u.first_name, u.last_name'
col_list = 'id, u.parentid, u.email, u.first_name, u.last_name'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
@ -664,9 +699,9 @@ def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_even
@pytest.mark.parametrize('text,expected', [
('SELECT users.* FROM users',
'id, users.email, users.first_name, users.last_name'),
'id, users.parentid, users.email, users.first_name, users.last_name'),
('SELECT Users.* FROM Users',
'id, Users.email, Users.first_name, Users.last_name'),
'id, Users.parentid, Users.email, Users.first_name, Users.last_name'),
])
def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event, text, expected):
pos = len('SELECT users.*')
@ -687,7 +722,7 @@ def test_wildcard_column_expansion_with_two_tables(completer, complete_event):
Document(text=sql, cursor_position=pos), complete_event)
cols = ('"select".id, "select"."insert", "select"."ABC", '
'u.id, u.email, u.first_name, u.last_name')
'u.id, u.parentid, u.email, u.first_name, u.last_name')
expected = [Completion(text=cols, start_position=-1,
display='*', display_meta='columns')]
assert completions == expected
@ -718,6 +753,7 @@ def test_suggest_columns_from_unquoted_table(completer, complete_event, text):
complete_event)
assert set(result) == set([
Completion(text='id', start_position=0, display_meta='column'),
Completion(text='parentid', start_position=0, display_meta='column'),
Completion(text='email', start_position=0, display_meta='column'),
Completion(text='first_name', start_position=0, display_meta='column'),
Completion(text='last_name', start_position=0, display_meta='column')])

View File

@ -23,8 +23,19 @@ def test_select_suggests_cols_with_qualified_table_scope():
@pytest.mark.parametrize('expression', [
'SELECT * FROM tabl WHERE ',
'SELECT * FROM "tabl" WHERE ',
])
def test_where_suggests_columns_functions_quoted_table(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([
Column(tables=((None, 'tabl', '"tabl"', False),)),
Function(schema=None),
Keyword(),
])
@pytest.mark.parametrize('expression', [
'SELECT * FROM tabl WHERE ',
'SELECT * FROM tabl WHERE (',
'SELECT * FROM tabl WHERE foo = ',
'SELECT * FROM tabl WHERE bar OR ',