1
1
mirror of https://github.com/dbcli/pgcli.git synced 2024-10-06 02:07:53 +03:00

Fix some join-condition issues

When self-joining a table with an FK to or from some other table, we got a false FK-join suggestion for that column.
There was also a problem with quoted tables not being quoted in the join condition.
There were also a couple of problems when trying to join a non-existent table or when using a non-existent qualifier (`SELECT * FROM Foo JOIN Bar ON Meow.`).

I also rewrote get_join_condition_matches a bit in the process, hopefully making it a bit simpler.
This commit is contained in:
koljonen 2016-06-09 23:38:33 +02:00
parent 9e98896bb3
commit 5b20e107b8
No known key found for this signature in database
GPG Key ID: AF0327B5131CD164
4 changed files with 97 additions and 50 deletions

View File

@@ -315,7 +315,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
suggest.append(Function(schema=schema, filter='for_from_clause'))
if (token_v.endswith('join') and token.is_keyword
and _allow_join_suggestion(parsed_statement)):
and _allow_join(parsed_statement)):
tables = extract_tables(text_before_cursor)
suggest.append(Join(tables=tables, schema=schema))
@@ -346,15 +346,15 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
View(schema=parent),
Function(schema=parent)]
last_token = parsed_statement
if _allow_join_condition_suggestion(parsed_statement):
if filteredtables and _allow_join_condition(parsed_statement):
sugs.append(JoinCondition(tables=tables,
parent=filteredtables[-1]))
return tuple(sugs)
else:
# ON <suggestion>
# Use table alias if there is one, otherwise the table name
aliases = tuple(t.alias or t.name for t in tables)
if _allow_join_condition_suggestion(parsed_statement):
aliases = tuple(t.ref for t in tables)
if _allow_join_condition(parsed_statement):
return (Alias(aliases=aliases), JoinCondition(
tables=tables, parent=None))
else:
@@ -397,7 +397,7 @@ def identifies(id, ref):
ref.schema and (id == ref.schema + '.' + ref.name))
def _allow_join_condition_suggestion(statement):
def _allow_join_condition(statement):
"""
Tests if a join condition should be suggested
@@ -417,7 +417,7 @@ def _allow_join_condition_suggestion(statement):
return last_tok.value.lower() in ('on', 'and', 'or')
def _allow_join_suggestion(statement):
def _allow_join(statement):
"""
Tests if a join should be suggested

View File

@@ -439,56 +439,56 @@ class PGCompleter(Completer):
priority_collection=prios, type_priority=100)
def get_join_condition_matches(self, suggestion, word_before_cursor):
lefttable = suggestion.parent or suggestion.tables[-1]
scoped_cols = self.populate_scoped_cols(suggestion.tables)
col = namedtuple('col', 'schema tbl col')
tbls = self.populate_scoped_cols(suggestion.tables).items
cols = [(t, c) for t, cs in tbls() for c in cs]
try:
lref = (suggestion.parent or suggestion.tables[-1]).ref
ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
except IndexError: # The user typed an incorrect table qualifier
return []
conds, found_conds = [], set()
def make_cond(tbl1, tbl2, col1, col2):
prefix = '' if suggestion.parent else tbl1 + '.'
def add_cond(lcol, rcol, rref, meta, prio):
prefix = '' if suggestion.parent else ltbl.ref + '.'
case = self.case
return prefix + case(col1) + ' = ' + tbl2 + '.' + case(col2)
cond = prefix + case(lcol) + ' = ' + rref + '.' + case(rcol)
if cond not in found_conds:
found_conds.add(cond)
conds.append((cond, meta, prio + ref_prio[rref]))
def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
d = defaultdict(list)
for pair in pairs:
d[pair[0]].append(pair[1])
return d
# Tables that are closer to the cursor get higher prio
refprio = dict((tbl.ref, num) for num, tbl
ref_prio = dict((tbl.ref, num) for num, tbl
in enumerate(suggestion.tables))
# Map (schema, tablename) to tables and ref to columns
tbldict = defaultdict(list)
for t in scoped_cols.keys():
tbldict[(t.schema, t.name)].append(t)
refcols = dict((t.ref, cs) for t, cs in scoped_cols.items())
# Map (schema, table, col) to tables
coldict = list_dict(((t.schema, t.name, c.name), t)
for t, c in cols if t.ref != lref)
# For each fk from the left table, generate a join condition if
# the other table is also in the scope
conds = []
for lcol in refcols.get(lefttable.ref, []):
for fk in lcol.foreignkeys:
for rcol in ((fk.parentschema, fk.parenttable,
fk.parentcolumn), (fk.childschema, fk.childtable,
fk.childcolumn)):
for rtbl in tbldict[(rcol[0], rcol[1])]:
if rtbl and rtbl.ref != lefttable.ref:
cond = make_cond(lefttable.ref, rtbl.ref,
lcol.name, rcol[2])
prio = 2000 + refprio[rtbl.ref]
conds.append((cond, 'fk join', prio))
fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
for fk, lcol in fks:
left = col(ltbl.schema, ltbl.name, lcol)
child = col(fk.childschema, fk.childtable, fk.childcolumn)
par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
left, right = (child, par) if left == child else (par, child)
for rtbl in coldict[right]:
add_cond(left.col, right.col, rtbl.ref, 'fk join', 2000)
# For name matching, use a {(colname, coltype): TableReference} dict
col_table = defaultdict(lambda: [])
for tbl, col in ((t, c) for t, cs in scoped_cols.items() for c in cs):
col_table[(col.name, col.datatype)].append(tbl)
coltyp = namedtuple('coltyp', 'name datatype')
col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
# Find all name-match join conditions
found = set(cnd[0] for cnd in conds)
for c in refcols.get(lefttable.ref, []):
for rtbl in col_table[(c.name, c.datatype)]:
if rtbl.ref != lefttable.ref:
cond = make_cond(lefttable.ref, rtbl.ref, c.name, c.name)
if cond not in found:
prio = (1000 if c.datatype and c.datatype in(
'integer', 'bigint', 'smallint')
else 0 + refprio[rtbl.ref])
conds.append((cond, 'name join', prio))
for c in (coltyp(c.name, c.datatype) for c in lcols):
for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
add_cond(c.name, c.name, rtbl.ref, 'name join', 1000
if c.datatype in ('integer', 'bigint', 'smallint') else 0)
if not conds:
return []
conds, metas, prios = zip(*conds)
conds, metas, prios = zip(*conds) if conds else ([], [], [])
return self.find_matches(word_before_cursor, conds,
meta_collection=metas, type_priority=100, priority_collection=prios)

View File

@@ -100,17 +100,18 @@ def test_schema_or_visible_table_completion(completer, complete_event):
Completion(text='orders', start_position=0, display_meta='table')])
@pytest.mark.parametrize('text', [
'SELECT FROM users',
'SELECT FROM "users"',
@pytest.mark.parametrize('table', [
'users',
'"users"',
])
def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event, text):
def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event, table):
"""
Suggest column and function names when selecting from table
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT FROM ' + table
position = len('SELECT ')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),

View File

@@ -342,6 +342,52 @@ def test_suggested_join_conditions(completer, complete_event, text):
Completion(text='u2', start_position=0, display_meta='table alias'),
Completion(text='u2.userid = u.id', start_position=0, display_meta='fk join')])
@pytest.mark.parametrize('text', [
'''SELECT *
FROM users
CROSS JOIN "Users"
NATURAL JOIN users u
JOIN "Users" u2 ON
'''
])
def test_suggested_join_conditions_with_same_table_twice(completer, complete_event, text):
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position),
complete_event)
assert result == [
Completion(text='u2.userid = u.id', start_position=0, display_meta='fk join'),
Completion(text='u2.userid = users.id', start_position=0, display_meta='fk join'),
Completion(text='u2.userid = "Users".userid', start_position=0, display_meta='name join'),
Completion(text='u2.username = "Users".username', start_position=0, display_meta='name join'),
Completion(text='"Users"', start_position=0, display_meta='table alias'),
Completion(text='u', start_position=0, display_meta='table alias'),
Completion(text='u2', start_position=0, display_meta='table alias'),
Completion(text='users', start_position=0, display_meta='table alias')]
@pytest.mark.parametrize('text', [
'SELECT * FROM users JOIN users u2 on foo.'
])
def test_suggested_join_conditions_with_invalid_qualifier(completer, complete_event, text):
position = len(text)
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set()
@pytest.mark.parametrize(('text', 'ref'), [
('SELECT * FROM users JOIN NonTable on ', 'NonTable'),
('SELECT * FROM users JOIN nontable nt on ', 'nt')
])
def test_suggested_join_conditions_with_invalid_table(completer, complete_event, text, ref):
position = len(text)
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='users', start_position=0, display_meta='table alias'),
Completion(text=ref, start_position=0, display_meta='table alias')])
@pytest.mark.parametrize('text', [
'SELECT * FROM "Users" u JOIN u',
'SELECT * FROM "Users" u JOIN uid',