mirror of
https://github.com/dbcli/pgcli.git
synced 2024-10-06 02:07:53 +03:00
Fix some join-condition issues
When self-joining a table with an FK to or from some other table, we got a false FK-join suggestion for that column. There was also a problem with quoted tables not being quoted in the join condition. And there were a couple of problems when trying to join a non-existent table or using a non-existent qualifier (`SELECT * FROM Foo JOIN Bar ON Meow.`). I also rewrote get_join_condition_matches a bit in the process, hopefully making it a bit simpler.
This commit is contained in:
parent
9e98896bb3
commit
5b20e107b8
@ -315,7 +315,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
|
||||
suggest.append(Function(schema=schema, filter='for_from_clause'))
|
||||
|
||||
if (token_v.endswith('join') and token.is_keyword
|
||||
and _allow_join_suggestion(parsed_statement)):
|
||||
and _allow_join(parsed_statement)):
|
||||
tables = extract_tables(text_before_cursor)
|
||||
suggest.append(Join(tables=tables, schema=schema))
|
||||
|
||||
@ -346,15 +346,15 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text,
|
||||
View(schema=parent),
|
||||
Function(schema=parent)]
|
||||
last_token = parsed_statement
|
||||
if _allow_join_condition_suggestion(parsed_statement):
|
||||
if filteredtables and _allow_join_condition(parsed_statement):
|
||||
sugs.append(JoinCondition(tables=tables,
|
||||
parent=filteredtables[-1]))
|
||||
return tuple(sugs)
|
||||
else:
|
||||
# ON <suggestion>
|
||||
# Use table alias if there is one, otherwise the table name
|
||||
aliases = tuple(t.alias or t.name for t in tables)
|
||||
if _allow_join_condition_suggestion(parsed_statement):
|
||||
aliases = tuple(t.ref for t in tables)
|
||||
if _allow_join_condition(parsed_statement):
|
||||
return (Alias(aliases=aliases), JoinCondition(
|
||||
tables=tables, parent=None))
|
||||
else:
|
||||
@ -397,7 +397,7 @@ def identifies(id, ref):
|
||||
ref.schema and (id == ref.schema + '.' + ref.name))
|
||||
|
||||
|
||||
def _allow_join_condition_suggestion(statement):
|
||||
def _allow_join_condition(statement):
|
||||
"""
|
||||
Tests if a join condition should be suggested
|
||||
|
||||
@ -417,7 +417,7 @@ def _allow_join_condition_suggestion(statement):
|
||||
return last_tok.value.lower() in ('on', 'and', 'or')
|
||||
|
||||
|
||||
def _allow_join_suggestion(statement):
|
||||
def _allow_join(statement):
|
||||
"""
|
||||
Tests if a join should be suggested
|
||||
|
||||
|
@ -439,56 +439,56 @@ class PGCompleter(Completer):
|
||||
priority_collection=prios, type_priority=100)
|
||||
|
||||
def get_join_condition_matches(self, suggestion, word_before_cursor):
|
||||
lefttable = suggestion.parent or suggestion.tables[-1]
|
||||
scoped_cols = self.populate_scoped_cols(suggestion.tables)
|
||||
col = namedtuple('col', 'schema tbl col')
|
||||
tbls = self.populate_scoped_cols(suggestion.tables).items
|
||||
cols = [(t, c) for t, cs in tbls() for c in cs]
|
||||
try:
|
||||
lref = (suggestion.parent or suggestion.tables[-1]).ref
|
||||
ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
|
||||
except IndexError: # The user typed an incorrect table qualifier
|
||||
return []
|
||||
conds, found_conds = [], set()
|
||||
|
||||
def make_cond(tbl1, tbl2, col1, col2):
|
||||
prefix = '' if suggestion.parent else tbl1 + '.'
|
||||
def add_cond(lcol, rcol, rref, meta, prio):
|
||||
prefix = '' if suggestion.parent else ltbl.ref + '.'
|
||||
case = self.case
|
||||
return prefix + case(col1) + ' = ' + tbl2 + '.' + case(col2)
|
||||
cond = prefix + case(lcol) + ' = ' + rref + '.' + case(rcol)
|
||||
if cond not in found_conds:
|
||||
found_conds.add(cond)
|
||||
conds.append((cond, meta, prio + ref_prio[rref]))
|
||||
|
||||
def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
|
||||
d = defaultdict(list)
|
||||
for pair in pairs:
|
||||
d[pair[0]].append(pair[1])
|
||||
return d
|
||||
|
||||
# Tables that are closer to the cursor get higher prio
|
||||
refprio = dict((tbl.ref, num) for num, tbl
|
||||
ref_prio = dict((tbl.ref, num) for num, tbl
|
||||
in enumerate(suggestion.tables))
|
||||
# Map (schema, tablename) to tables and ref to columns
|
||||
tbldict = defaultdict(list)
|
||||
for t in scoped_cols.keys():
|
||||
tbldict[(t.schema, t.name)].append(t)
|
||||
refcols = dict((t.ref, cs) for t, cs in scoped_cols.items())
|
||||
# Map (schema, table, col) to tables
|
||||
coldict = list_dict(((t.schema, t.name, c.name), t)
|
||||
for t, c in cols if t.ref != lref)
|
||||
# For each fk from the left table, generate a join condition if
|
||||
# the other table is also in the scope
|
||||
conds = []
|
||||
for lcol in refcols.get(lefttable.ref, []):
|
||||
for fk in lcol.foreignkeys:
|
||||
for rcol in ((fk.parentschema, fk.parenttable,
|
||||
fk.parentcolumn), (fk.childschema, fk.childtable,
|
||||
fk.childcolumn)):
|
||||
for rtbl in tbldict[(rcol[0], rcol[1])]:
|
||||
if rtbl and rtbl.ref != lefttable.ref:
|
||||
cond = make_cond(lefttable.ref, rtbl.ref,
|
||||
lcol.name, rcol[2])
|
||||
prio = 2000 + refprio[rtbl.ref]
|
||||
conds.append((cond, 'fk join', prio))
|
||||
fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
|
||||
for fk, lcol in fks:
|
||||
left = col(ltbl.schema, ltbl.name, lcol)
|
||||
child = col(fk.childschema, fk.childtable, fk.childcolumn)
|
||||
par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
|
||||
left, right = (child, par) if left == child else (par, child)
|
||||
for rtbl in coldict[right]:
|
||||
add_cond(left.col, right.col, rtbl.ref, 'fk join', 2000)
|
||||
# For name matching, use a {(colname, coltype): TableReference} dict
|
||||
col_table = defaultdict(lambda: [])
|
||||
for tbl, col in ((t, c) for t, cs in scoped_cols.items() for c in cs):
|
||||
col_table[(col.name, col.datatype)].append(tbl)
|
||||
coltyp = namedtuple('coltyp', 'name datatype')
|
||||
col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
|
||||
# Find all name-match join conditions
|
||||
found = set(cnd[0] for cnd in conds)
|
||||
for c in refcols.get(lefttable.ref, []):
|
||||
for rtbl in col_table[(c.name, c.datatype)]:
|
||||
if rtbl.ref != lefttable.ref:
|
||||
cond = make_cond(lefttable.ref, rtbl.ref, c.name, c.name)
|
||||
if cond not in found:
|
||||
prio = (1000 if c.datatype and c.datatype in(
|
||||
'integer', 'bigint', 'smallint')
|
||||
else 0 + refprio[rtbl.ref])
|
||||
conds.append((cond, 'name join', prio))
|
||||
for c in (coltyp(c.name, c.datatype) for c in lcols):
|
||||
for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
|
||||
add_cond(c.name, c.name, rtbl.ref, 'name join', 1000
|
||||
if c.datatype in ('integer', 'bigint', 'smallint') else 0)
|
||||
|
||||
if not conds:
|
||||
return []
|
||||
|
||||
conds, metas, prios = zip(*conds)
|
||||
conds, metas, prios = zip(*conds) if conds else ([], [], [])
|
||||
|
||||
return self.find_matches(word_before_cursor, conds,
|
||||
meta_collection=metas, type_priority=100, priority_collection=prios)
|
||||
|
@ -100,17 +100,18 @@ def test_schema_or_visible_table_completion(completer, complete_event):
|
||||
Completion(text='orders', start_position=0, display_meta='table')])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('text', [
|
||||
'SELECT FROM users',
|
||||
'SELECT FROM "users"',
|
||||
@pytest.mark.parametrize('table', [
|
||||
'users',
|
||||
'"users"',
|
||||
])
|
||||
def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event, text):
|
||||
def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event, table):
|
||||
"""
|
||||
Suggest column and function names when selecting from table
|
||||
:param completer:
|
||||
:param complete_event:
|
||||
:return:
|
||||
"""
|
||||
text = 'SELECT FROM ' + table
|
||||
position = len('SELECT ')
|
||||
result = set(completer.get_completions(
|
||||
Document(text=text, cursor_position=position),
|
||||
|
@ -342,6 +342,52 @@ def test_suggested_join_conditions(completer, complete_event, text):
|
||||
Completion(text='u2', start_position=0, display_meta='table alias'),
|
||||
Completion(text='u2.userid = u.id', start_position=0, display_meta='fk join')])
|
||||
|
||||
@pytest.mark.parametrize('text', [
|
||||
'''SELECT *
|
||||
FROM users
|
||||
CROSS JOIN "Users"
|
||||
NATURAL JOIN users u
|
||||
JOIN "Users" u2 ON
|
||||
'''
|
||||
])
|
||||
def test_suggested_join_conditions_with_same_table_twice(completer, complete_event, text):
|
||||
position = len(text)
|
||||
result = completer.get_completions(
|
||||
Document(text=text, cursor_position=position),
|
||||
complete_event)
|
||||
assert result == [
|
||||
Completion(text='u2.userid = u.id', start_position=0, display_meta='fk join'),
|
||||
Completion(text='u2.userid = users.id', start_position=0, display_meta='fk join'),
|
||||
Completion(text='u2.userid = "Users".userid', start_position=0, display_meta='name join'),
|
||||
Completion(text='u2.username = "Users".username', start_position=0, display_meta='name join'),
|
||||
Completion(text='"Users"', start_position=0, display_meta='table alias'),
|
||||
Completion(text='u', start_position=0, display_meta='table alias'),
|
||||
Completion(text='u2', start_position=0, display_meta='table alias'),
|
||||
Completion(text='users', start_position=0, display_meta='table alias')]
|
||||
|
||||
@pytest.mark.parametrize('text', [
|
||||
'SELECT * FROM users JOIN users u2 on foo.'
|
||||
])
|
||||
def test_suggested_join_conditions_with_invalid_qualifier(completer, complete_event, text):
|
||||
position = len(text)
|
||||
result = set(completer.get_completions(
|
||||
Document(text=text, cursor_position=position),
|
||||
complete_event))
|
||||
assert set(result) == set()
|
||||
|
||||
@pytest.mark.parametrize(('text', 'ref'), [
|
||||
('SELECT * FROM users JOIN NonTable on ', 'NonTable'),
|
||||
('SELECT * FROM users JOIN nontable nt on ', 'nt')
|
||||
])
|
||||
def test_suggested_join_conditions_with_invalid_table(completer, complete_event, text, ref):
|
||||
position = len(text)
|
||||
result = set(completer.get_completions(
|
||||
Document(text=text, cursor_position=position),
|
||||
complete_event))
|
||||
assert set(result) == set([
|
||||
Completion(text='users', start_position=0, display_meta='table alias'),
|
||||
Completion(text=ref, start_position=0, display_meta='table alias')])
|
||||
|
||||
@pytest.mark.parametrize('text', [
|
||||
'SELECT * FROM "Users" u JOIN u',
|
||||
'SELECT * FROM "Users" u JOIN uid',
|
||||
|
Loading…
Reference in New Issue
Block a user