1
1
mirror of https://github.com/dbcli/pgcli.git synced 2024-10-06 10:17:15 +03:00

Merge pull request #531 from dbcli/koljonen/generate_aliases

Suggest table aliases + add tests for casing
This commit is contained in:
darikg 2016-07-06 07:39:56 -04:00 committed by GitHub
commit 1605bf1cdb
8 changed files with 450 additions and 242 deletions

View File

@ -139,11 +139,12 @@ class PGCli(object):
# Initialize completer
smart_completion = c['main'].as_bool('smart_completion')
settings = {'casing_file': get_casing_file(c),
'generate_casing_file': c['main'].as_bool('generate_casing_file'),
'asterisk_column_order': c['main']['asterisk_column_order']}
self.settings = {'casing_file': get_casing_file(c),
'generate_casing_file': c['main'].as_bool('generate_casing_file'),
'generate_aliases': c['main'].as_bool('generate_aliases'),
'asterisk_column_order': c['main']['asterisk_column_order']}
completer = PGCompleter(smart_completion, pgspecial=self.pgspecial,
settings=settings)
settings=self.settings)
self.completer = completer
self._completer_lock = threading.Lock()
self.register_special_commands()
@ -602,12 +603,8 @@ class PGCli(object):
callback = functools.partial(self._on_completions_refreshed,
persist_priorities=persist_priorities)
c = self.config
settings = {'casing_file': get_casing_file(c),
'generate_casing_file': c['main'].as_bool('generate_casing_file'),
'asterisk_column_order': c['main']['asterisk_column_order']}
self.completion_refresher.refresh(self.pgexecute, self.pgspecial,
callback, history=history, settings=settings)
callback, history=history, settings=self.settings)
return [(None, None, None,
'Auto-completion refresh started in the background.')]

View File

@ -20,20 +20,27 @@ else:
Special = namedtuple('Special', [])
Database = namedtuple('Database', [])
Schema = namedtuple('Schema', [])
Table = namedtuple('Table', ['schema'])
# FromClauseItem is a table/view/function used in the FROM clause
# `tables` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple('FromClauseItem', 'schema tables')
Table = namedtuple('Table', ['schema', 'tables'])
View = namedtuple('View', ['schema', 'tables'])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple('JoinCondition', ['tables', 'parent'])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple('Join', ['tables', 'schema'])
Function = namedtuple('Function', ['schema', 'filter'])
Function = namedtuple('Function', ['schema', 'tables', 'filter'])
# For convenience, don't require the `filter` argument in Function constructor
Function.__new__.__defaults__ = (None, None)
Function.__new__.__defaults__ = (None, tuple(), None)
Table.__new__.__defaults__ = (None, tuple())
View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple())
Column = namedtuple('Column', ['tables', 'require_last_table'])
Column.__new__.__defaults__ = (None, None)
View = namedtuple('View', ['schema'])
Keyword = namedtuple('Keyword', [])
NamedQuery = namedtuple('NamedQuery', [])
Datatype = namedtuple('Datatype', ['schema'])
@ -320,26 +327,26 @@ def suggest_based_on_last_token(token, stmt):
('copy', 'from', 'update', 'into', 'describe', 'truncate')):
schema = stmt.get_identifier_schema()
tables = extract_tables(stmt.text_before_cursor)
is_join = token_v.endswith('join') and token.is_keyword
# Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified
suggest = [Table(schema=schema)]
suggest = []
if not schema:
# Suggest schemas
suggest.insert(0, Schema())
# Only tables can be TRUNCATED, otherwise suggest views
if token_v != 'truncate':
suggest.append(View(schema=schema))
# Suggest set-returning functions in the FROM clause
if token_v == 'from' or (token_v.endswith('join') and token.is_keyword):
suggest.append(Function(schema=schema, filter='for_from_clause'))
if token_v == 'from' or is_join:
suggest.append(FromClauseItem(schema=schema, tables=tables))
elif token_v == 'truncate':
suggest.append(Table(schema))
else:
suggest.extend((Table(schema), View(schema)))
if (token_v.endswith('join') and token.is_keyword
and _allow_join(stmt.parsed)):
tables = extract_tables(stmt.text_before_cursor)
if is_join and _allow_join(stmt.parsed):
suggest.append(Join(tables=tables, schema=schema))
return tuple(suggest)

View File

@ -15,6 +15,9 @@ wider_completion_menu = False
# lines. End of line (return) is considered as the end of the statement.
multi_line = False
# If set to True, table suggestions will include a table alias
generate_aliases = False
# log_file location.
# In Unix/Linux: ~/.config/pgcli/log
# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\log

View File

@ -1,21 +1,21 @@
from __future__ import print_function, unicode_literals
import logging
import re
import itertools
from itertools import count, repeat, chain
import operator
from collections import namedtuple, defaultdict
from pgspecial.namedqueries import NamedQueries
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.contrib.completers import PathCompleter
from prompt_toolkit.document import Document
from .packages.sqlcompletion import (
from .packages.sqlcompletion import (FromClauseItem,
suggest_type, Special, Database, Schema, Table, Function, Column, View,
Keyword, NamedQuery, Datatype, Alias, Path, JoinCondition, Join)
from .packages.function_metadata import ColumnMetadata, ForeignKey
from .packages.parseutils import last_word, TableReference
from .packages.pgliterals.main import get_literals
from .packages.prioritization import PrevalenceCounter
from .config import load_config, config_location, get_config
from .config import load_config, config_location
try:
from collections import OrderedDict
@ -32,6 +32,16 @@ Match = namedtuple('Match', ['completion', 'priority'])
normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'
def generate_alias(tbl, tbs):
""" Generate a table alias, consisting of all upper-case letters in
the table name, or, if there are no upper-case letters, the first letter +
all letters preceded by _
param tbl - unescaped name of the table to alias
param tbls - set TableReference objects for tables already in query
"""
return ''.join([l for l in tbl if l.isupper()] or
[l for l, prev in zip(tbl, '_' + tbl) if prev == '_' and l != '_'])
class PGCompleter(Completer):
keywords = get_literals('keywords')
functions = get_literals('functions')
@ -43,6 +53,7 @@ class PGCompleter(Completer):
self.pgspecial = pgspecial
self.prioritizer = PrevalenceCounter()
settings = settings or {}
self.generate_aliases = settings.get('generate_aliases')
self.casing_file = settings.get('casing_file')
self.generate_casing_file = settings.get('generate_casing_file')
self.asterisk_column_order = settings.get(
@ -279,9 +290,9 @@ class PGCompleter(Completer):
return -float('Infinity'), -match_point
# Fallback to meta param if meta_collection param is None
meta_collection = meta_collection or itertools.repeat(meta)
meta_collection = meta_collection or repeat(meta)
# Fallback to 0 if priority_collection param is None
priority_collection = priority_collection or itertools.repeat(0)
priority_collection = priority_collection or repeat(0)
collection = zip(collection, meta_collection, priority_collection)
@ -299,9 +310,13 @@ class PGCompleter(Completer):
# position. Since we use *higher* priority to mean "more
# important," we use -ord(c) to prioritize "aa" > "ab" and end
# with 1 to prioritize shorter strings (ie "user" > "users").
# We first do a case-insensitive sort and then a
# case-sensitive one as a tie breaker.
# We also use the unescape_name to make sure quoted names have
# the same priority as unquoted names.
lexical_priority = tuple(-ord(c) for c in self.unescape_name(item)) + (1,)
lexical_priority = (tuple(0 if c in(' _') else -ord(c)
for c in self.unescape_name(item.lower())) + (1,)
+ tuple(c for c in item))
item = self.case(item)
priority = type_priority, prio, sort_key, priority_func(item), lexical_priority
@ -346,12 +361,13 @@ class PGCompleter(Completer):
return [m.completion for m in matches]
def get_column_matches(self, suggestion, word_before_cursor):
tables = suggestion.tables
_logger.debug("Completion column scope: %r", tables)
scoped_cols = self.populate_scoped_cols(tables)
colit = scoped_cols.items
flat_cols = list(itertools.chain(*((c.name for c in cols)
flat_cols = list(chain(*((c.name for c in cols)
for t, cols in colit())))
if suggestion.require_last_table:
# require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
@ -385,15 +401,21 @@ class PGCompleter(Completer):
return self.find_matches(word_before_cursor, flat_cols, meta='column')
def generate_alias(self, tbl, tbls):
def alias(self, tbl, tbls):
""" Generate a unique table alias
tbl - name of the table to alias, quoted if it needs to be
tbls - set of table refs already in use, normalized with normalize_ref
tbls - TableReference iterable of tables already in query
"""
if tbl[0] == '"':
aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in itertools.count(2))
tbl = self.case(tbl)
tbls = set(normalize_ref(t.ref) for t in tbls)
if self.generate_aliases:
tbl = generate_alias(self.unescape_name(tbl), tbls)
if normalize_ref(tbl) not in tbls:
return tbl
elif tbl[0] == '"':
aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2))
else:
aliases = (self.case(tbl) + str(i) for i in itertools.count(2))
aliases = (tbl + str(i) for i in count(2))
return next(a for a in aliases if normalize_ref(a) not in tbls)
def get_join_matches(self, suggestion, word_before_cursor):
@ -418,8 +440,8 @@ class PGCompleter(Completer):
if suggestion.schema and left.schema != suggestion.schema:
continue
c = self.case
if normalize_ref(left.tbl) in refs:
lref = self.generate_alias(left.tbl, refs)
if self.generate_aliases or normalize_ref(left.tbl) in refs:
lref = self.alias(left.tbl, suggestion.tables)
join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref)
else:
@ -493,14 +515,19 @@ class PGCompleter(Completer):
return self.find_matches(word_before_cursor, conds,
meta_collection=metas, type_priority=100, priority_collection=prios)
def get_function_matches(self, suggestion, word_before_cursor):
def get_function_matches(self, suggestion, word_before_cursor, alias=False):
if suggestion.filter == 'for_from_clause':
# Only suggest functions allowed in FROM clause
filt = lambda f: not f.is_aggregate and not f.is_window
funcs = self.populate_functions(suggestion.schema, filt)
if alias:
funcs = [self.case(f) + '() ' + self.alias(f,
suggestion.tables) for f in funcs]
else:
funcs = [self.case(f) + '()' for f in funcs]
else:
funcs = self.populate_schema_objects(
suggestion.schema, 'functions')
funcs = [f + '()' for f in self.populate_schema_objects(
suggestion.schema, 'functions')]
# Function overloading means we way have multiple functions of the same
# name at this point, so keep unique names only
@ -527,7 +554,16 @@ class PGCompleter(Completer):
return self.find_matches(word_before_cursor, schema_names, meta='schema')
def get_table_matches(self, suggestion, word_before_cursor):
def get_from_clause_item_matches(self, suggestion, word_before_cursor):
alias = self.generate_aliases
t_sug = Table(*suggestion)
v_sug = View(*suggestion)
f_sug = Function(*suggestion, filter='for_from_clause')
return (self.get_table_matches(t_sug, word_before_cursor, alias)
+ self.get_view_matches(v_sug, word_before_cursor, alias)
+ self.get_function_matches(f_sug, word_before_cursor, alias))
def get_table_matches(self, suggestion, word_before_cursor, alias=False):
tables = self.populate_schema_objects(suggestion.schema, 'tables')
# Unless we're sure the user really wants them, don't suggest the
@ -535,16 +571,21 @@ class PGCompleter(Completer):
if not suggestion.schema and (
not word_before_cursor.startswith('pg_')):
tables = [t for t in tables if not t.startswith('pg_')]
if alias:
tables = [self.case(t) + ' ' + self.alias(t, suggestion.tables)
for t in tables]
return self.find_matches(word_before_cursor, tables, meta='table')
def get_view_matches(self, suggestion, word_before_cursor):
def get_view_matches(self, suggestion, word_before_cursor, alias=False):
views = self.populate_schema_objects(suggestion.schema, 'views')
if not suggestion.schema and (
not word_before_cursor.startswith('pg_')):
views = [v for v in views if not v.startswith('pg_')]
if alias:
views = [self.case(v) + ' ' + self.alias(v, suggestion.tables)
for v in views]
return self.find_matches(word_before_cursor, views, meta='view')
def get_alias_matches(self, suggestion, word_before_cursor):
@ -594,6 +635,7 @@ class PGCompleter(Completer):
word_before_cursor, NamedQueries.instance.list(), meta='named query')
suggestion_matchers = {
FromClauseItem: get_from_clause_item_matches,
JoinCondition: get_join_condition_matches,
Join: get_join_matches,
Column: get_column_matches,
@ -663,7 +705,7 @@ class PGCompleter(Completer):
objects = [obj for schema in schemas
for obj in metadata[schema].keys()]
return objects
return [self.case(o) for o in objects]
def populate_functions(self, schema, filter_func):
"""Returns a list of function names

View File

@ -56,7 +56,7 @@ class MetaData(object):
for x in self.metadata.get('views', {}).get(schema, [])]
def functions(self, schema='public', pos=0):
return [function(escape(x[0]), pos)
return [function(escape(x[0] + '()'), pos)
for x in self.metadata.get('functions', {}).get(schema, [])]
def schemas(self, pos=0):
@ -65,9 +65,12 @@ class MetaData(object):
@property
def completer(self):
return self.get_completer()
def get_completer(self, settings=None, casing=None):
metadata = self.metadata
import pgcli.pgcompleter as pgcompleter
comp = pgcompleter.PGCompleter(smart_completion=True)
from pgcli.pgcompleter import PGCompleter
comp = PGCompleter(smart_completion=True, settings=settings)
schemata, tables, tbl_cols, views, view_cols = [], [], [], [], []
@ -105,5 +108,6 @@ class MetaData(object):
comp.extend_datatypes(datatypes)
comp.extend_foreignkeys(foreignkeys)
comp.set_search_path(['public'])
comp.extend_casing(casing or [])
return comp

View File

@ -2,7 +2,7 @@ from __future__ import unicode_literals
import pytest
import itertools
from metadata import (MetaData, alias, name_join, fk_join, join,
function, wildcard_expansion)
schema, table, function, wildcard_expansion)
from prompt_toolkit.document import Document
from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey
@ -49,19 +49,25 @@ testdata = MetaData(metadata)
def completer():
return testdata.completer
casing = ('SELECT', 'Orders', 'User_Emails', 'CUSTOM', 'Func1')
@pytest.fixture
def completer_with_casing():
return testdata.get_completer(casing=casing)
@pytest.fixture
def completer_with_aliases():
return testdata.get_completer({'generate_aliases': True})
@pytest.fixture
def completer_aliases_casing(request):
return testdata.get_completer({'generate_aliases': True}, casing)
@pytest.fixture
def complete_event():
from mock import Mock
return Mock()
def test_schema_or_visible_table_completion(completer, complete_event):
text = 'SELECT * FROM '
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event)
assert set(result) == set(testdata.schemas() + testdata.functions() + testdata.tables())
@pytest.mark.parametrize('table', [
'users',
'"users"',
@ -289,8 +295,8 @@ def test_schema_qualified_function_name(completer, complete_event):
result = set(completer.get_completions(
Document(text=text, cursor_position=postion), complete_event))
assert result == set([
function('func3', -len('func')),
function('set_returning_func', -len('func'))])
function('func3()', -len('func')),
function('set_returning_func()', -len('func'))])
@pytest.mark.parametrize('text', [
@ -430,3 +436,51 @@ def test_suggest_columns_from_quoted_table(completer, complete_event, text):
result = completer.get_completions(Document(text=text, cursor_position=pos),
complete_event)
assert set(result) == set(testdata.columns('Users', 'custom'))
texts = ['SELECT * FROM ', 'SELECT * FROM public.Orders O CROSS JOIN ']
@pytest.mark.parametrize('text', texts)
def test_schema_or_visible_table_completion(completer, complete_event, text):
result = completer.get_completions(Document(text=text), complete_event)
assert set(result) == set(testdata.schemas()
+ testdata.views() + testdata.tables() + testdata.functions())
result = completer.get_completions(Document(text=text), complete_event)
@pytest.mark.parametrize('text', texts)
def test_table_aliases(completer_with_aliases, complete_event, text):
result = completer_with_aliases.get_completions(
Document(text=text), complete_event)
assert set(result) == set(testdata.schemas() + [
table('users u'),
table('orders o' if text == 'SELECT * FROM ' else 'orders o2'),
table('"select" s'),
function('func1() f'),
function('func2() f')])
@pytest.mark.parametrize('text', texts)
def test_aliases_with_casing(completer_aliases_casing, complete_event, text):
result = completer_aliases_casing.get_completions(
Document(text=text), complete_event)
assert set(result) == set([
schema('public'),
schema('CUSTOM'),
schema('"Custom"'),
table('users u'),
table('Orders O' if text == 'SELECT * FROM ' else 'Orders O2'),
table('"select" s'),
function('Func1() F'),
function('func2() f')])
@pytest.mark.parametrize('text', texts)
def test_table_casing(completer_with_casing, complete_event, text):
result = completer_with_casing.get_completions(
Document(text=text), complete_event)
assert set(result) == set([
schema('public'),
schema('CUSTOM'),
schema('"Custom"'),
table('users'),
table('Orders'),
table('"select"'),
function('Func1()'),
function('func2()')])

View File

@ -1,7 +1,7 @@
from __future__ import unicode_literals
import pytest
from metadata import (MetaData, alias, name_join, fk_join, join, keyword,
table, function, column, wildcard_expansion)
schema, table, view, function, column, wildcard_expansion)
from prompt_toolkit.document import Document
from pgcli.packages.function_metadata import FunctionMetadata, ForeignKey
@ -14,10 +14,10 @@ metadata = {
'views': {
'user_emails': ['id', 'email']},
'functions': [
['custom_func1', [''], [''], [''], '', False, False,
False],
['custom_func2', [''], [''], [''], '', False, False,
False],
['custom_fun', [''], [''], [''], '', False, False, False],
['_custom_fun', [''], [''], [''], '', False, False, False],
['custom_func1', [''], [''], [''], '', False, False, False],
['custom_func2', [''], [''], [''], '', False, False, False],
['set_returning_func', ['x', 'y'], ['integer', 'integer'],
['o', 'o'], '', False, False, True]],
'datatypes': ['custom_type1', 'custom_type2'],
@ -31,10 +31,45 @@ metadata = dict((k, {'public': v}) for k, v in metadata.items())
testdata = MetaData(metadata)
cased_users_cols = ['ID', 'PARENTID', 'Email', 'First_Name', 'last_name']
cased_users2_cols = ['UserID', 'UserName']
cased_funcs = ['Custom_Fun', '_custom_fun', 'Custom_Func1',
'custom_func2', 'set_returning_func']
cased_tbls = ['Users', 'Orders']
cased_views = ['User_Emails']
casing = (['SELECT', 'PUBLIC'] + cased_funcs + cased_tbls + cased_views
+ cased_users_cols + cased_users2_cols)
# Lists for use in assertions
cased_funcs = [function(f + '()') for f in cased_funcs]
cased_tbls = [table(t) for t in (cased_tbls + ['"Users"', '"select"'])]
cased_rels = [view(t) for t in cased_views] + cased_funcs + cased_tbls
cased_users_cols = [column(c) for c in cased_users_cols]
aliased_rels = [table(t) for t in ('users u', '"Users" U', 'orders o',
'"select" s')] + [view('user_emails ue')] + [function(f) for f in (
'_custom_fun() cf', 'custom_fun() cf', 'custom_func1() cf',
'custom_func2() cf', 'set_returning_func() srf')]
cased_aliased_rels = [table(t) for t in ('Users U', '"Users" U', 'Orders O',
'"select" s')] + [view('User_Emails UE')] + [function(f) for f in (
'_custom_fun() cf', 'Custom_Fun() CF', 'Custom_Func1() CF',
'custom_func2() cf', 'set_returning_func() srf')]
@pytest.fixture
def completer():
return testdata.completer
@pytest.fixture
def cased_completer():
return testdata.get_completer(casing=casing)
@pytest.fixture
def aliased_completer():
return testdata.get_completer({'generate_aliases': True})
@pytest.fixture
def cased_aliased_completer(request):
return testdata.get_completer({'generate_aliases': True}, casing)
@pytest.fixture
def complete_event():
from mock import Mock
@ -58,15 +93,6 @@ def test_select_keyword_completion(completer, complete_event):
assert set(result) == set([keyword('SELECT', -3)])
def test_schema_or_visible_table_completion(completer, complete_event):
text = 'SELECT * FROM '
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event)
assert set(result) == set(testdata.schemas()
+ testdata.views() + testdata.tables() + testdata.functions())
def test_builtin_function_name_completion(completer, complete_event):
text = 'SELECT MA'
position = len('SELECT MA')
@ -94,8 +120,10 @@ def test_user_function_name_completion(completer, complete_event):
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event)
assert set(result) == set([
function('custom_func1', -2),
function('custom_func2', -2),
function('custom_fun()', -2),
function('_custom_fun()', -2),
function('custom_func1()', -2),
function('custom_func2()', -2),
keyword('CURRENT', -2),
])
@ -107,8 +135,10 @@ def test_user_function_name_completion_matches_anywhere(completer,
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event)
assert set(result) == set([
function('custom_func1', -2),
function('custom_func2', -2)])
function('custom_fun()', -2),
function('_custom_fun()', -2),
function('custom_func1()', -2),
function('custom_func2()', -2)])
def test_suggested_column_names_from_visible_table(completer, complete_event):
@ -129,6 +159,22 @@ def test_suggested_column_names_from_visible_table(completer, complete_event):
)
def test_suggested_cased_column_names(cased_completer, complete_event):
"""
Suggest column and function names when selecting from table
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT from users'
position = len('SELECT ')
result = set(cased_completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set(cased_funcs + cased_users_cols
+ testdata.builtin_functions() + testdata.keywords())
def test_suggested_column_names_in_function(completer, complete_event):
"""
Suggest column and function names when selecting multiple
@ -205,6 +251,22 @@ def test_suggested_multiple_column_names_with_alias(completer, complete_event):
complete_event))
assert set(result) == set(testdata.columns('users'))
def test_suggested_cased_column_names_with_alias(cased_completer, complete_event):
"""
Suggest column names on table alias and dot
when selecting multiple columns from table
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT u.id, u. from users u'
position = len('SELECT u.id, u.')
result = set(cased_completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set(cased_users_cols)
def test_suggested_multiple_column_names_with_dot(completer, complete_event):
"""
Suggest column names on table names and dot
@ -231,29 +293,38 @@ def test_suggest_columns_after_three_way_join(completer, complete_event):
assert (column('id') in
set(result))
@pytest.mark.parametrize('text', [
'SELECT * FROM users u JOIN "Users" u2 ON ',
'SELECT * FROM users u INNER join "Users" u2 ON ',
'SELECT * FROM USERS u right JOIN "Users" u2 ON ',
'SELECT * FROM users u LEFT JOIN "Users" u2 ON ',
'SELECT * FROM Users u FULL JOIN "Users" u2 ON ',
'SELECT * FROM users u right outer join "Users" u2 ON ',
'SELECT * FROM Users u LEFT OUTER JOIN "Users" u2 ON ',
'SELECT * FROM users u FULL OUTER JOIN "Users" u2 ON ',
join_condition_texts = [
'SELECT * FROM users U JOIN "Users" U2 ON ',
'SELECT * FROM users U INNER join "Users" U2 ON ',
'SELECT * FROM USERS U right JOIN "Users" U2 ON ',
'SELECT * FROM users U LEFT JOIN "Users" U2 ON ',
'SELECT * FROM Users U FULL JOIN "Users" U2 ON ',
'SELECT * FROM users U right outer join "Users" U2 ON ',
'SELECT * FROM Users U LEFT OUTER JOIN "Users" U2 ON ',
'SELECT * FROM users U FULL OUTER JOIN "Users" U2 ON ',
'''SELECT *
FROM users u
FULL OUTER JOIN "Users" u2 ON
FROM users U
FULL OUTER JOIN "Users" U2 ON
'''
])
]
@pytest.mark.parametrize('text', join_condition_texts)
def test_suggested_join_conditions(completer, complete_event, text):
position = len(text)
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
Document(text=text,), complete_event))
assert set(result) == set([
alias('u'),
alias('u2'),
fk_join('u2.userid = u.id')])
alias('U'),
alias('U2'),
fk_join('U2.userid = U.id')])
@pytest.mark.parametrize('text', join_condition_texts)
def test_cased_join_conditions(cased_completer, complete_event, text):
result = set(cased_completer.get_completions(
Document(text=text), complete_event))
assert set(result) == set([
alias('U'),
alias('U2'),
fk_join('U2.UserID = U.ID')])
@pytest.mark.parametrize('text', [
'''SELECT *
@ -273,10 +344,11 @@ def test_suggested_join_conditions_with_same_table_twice(completer, complete_eve
fk_join('u2.userid = users.id'),
name_join('u2.userid = "Users".userid'),
name_join('u2.username = "Users".username'),
alias('"Users"'),
alias('u'),
alias('u2'),
alias('users')]
alias('users'),
alias('"Users"')
]
@pytest.mark.parametrize('text', [
'SELECT * FROM users JOIN users u2 on foo.'
@ -314,24 +386,44 @@ def test_suggested_joins_fuzzy(completer, complete_event, text):
expected = join('users ON users.id = u.userid', -len(last_word))
assert expected in result
@pytest.mark.parametrize('text', [
'SELECT * FROM users JOIN ',
join_texts = [
'SELECT * FROM Users JOIN ',
'''SELECT *
FROM users
FROM Users
INNER JOIN '''
])
]
@pytest.mark.parametrize('text', join_texts)
def test_suggested_joins(completer, complete_event, text):
position = len(text)
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
Document(text=text), complete_event))
assert set(result) == set(testdata.schemas() + testdata.tables()
+ testdata.views() + [
join('"Users" ON "Users".userid = users.id'),
join('users users2 ON users2.id = users.parentid'),
join('users users2 ON users2.parentid = users.id'),
join('"Users" ON "Users".userid = Users.id'),
join('users users2 ON users2.id = Users.parentid'),
join('users users2 ON users2.parentid = Users.id'),
] + testdata.functions())
@pytest.mark.parametrize('text', join_texts)
def test_cased_joins(cased_completer, complete_event, text):
result = set(cased_completer.get_completions(
Document(text=text), complete_event))
assert set(result) == set([schema('PUBLIC')] + cased_rels + [
join('"Users" ON "Users".UserID = Users.ID'),
join('Users Users2 ON Users2.ID = Users.PARENTID'),
join('Users Users2 ON Users2.PARENTID = Users.ID'),
])
@pytest.mark.parametrize('text', join_texts)
def test_aliased_joins(aliased_completer, complete_event, text):
result = set(aliased_completer.get_completions(
Document(text=text), complete_event))
assert set(result) == set(testdata.schemas() + aliased_rels + [
join('"Users" U ON U.userid = Users.id'),
join('users u ON u.id = Users.parentid'),
join('users u ON u.parentid = Users.id'),
])
@pytest.mark.parametrize('text', [
'SELECT * FROM public."Users" JOIN ',
'SELECT * FROM public."Users" RIGHT OUTER JOIN ',
@ -459,15 +551,17 @@ def test_table_names_after_from(completer, complete_event, text):
assert set(result) == set(testdata.schemas() + testdata.tables()
+ testdata.views() + testdata.functions())
assert [c.text for c in result] == [
'"Users"',
'custom_func1',
'custom_func2',
'_custom_fun()',
'custom_fun()',
'custom_func1()',
'custom_func2()',
'orders',
'public',
'"select"',
'set_returning_func',
'set_returning_func()',
'user_emails',
'users'
'users',
'"Users"',
]
def test_auto_escaped_col_names(completer, complete_event):
@ -681,3 +775,62 @@ def test_suggest_columns_from_quoted_table(completer, complete_event):
result = completer.get_completions(Document(text=text, cursor_position=pos),
complete_event)
assert set(result) == set(testdata.columns('Users'))
@pytest.mark.parametrize('text', ['SELECT * FROM ',
'SELECT * FROM Orders o CROSS JOIN '])
def test_schema_or_visible_table_completion(completer, complete_event, text):
result = completer.get_completions(Document(text=text), complete_event)
assert set(result) == set(testdata.schemas()
+ testdata.views() + testdata.tables() + testdata.functions())
@pytest.mark.parametrize('text', ['SELECT * FROM '])
def test_table_aliases(aliased_completer, complete_event, text):
result = aliased_completer.get_completions(
Document(text=text), complete_event)
assert set(result) == set(testdata.schemas() + aliased_rels)
@pytest.mark.parametrize('text', ['SELECT * FROM Orders o CROSS JOIN '])
def test_duplicate_table_aliases(aliased_completer, complete_event, text):
result = aliased_completer.get_completions(
Document(text=text), complete_event)
assert set(result) == set(testdata.schemas() + [
table('orders o2'),
table('users u'),
table('"Users" U'),
table('"select" s'),
view('user_emails ue'),
function('_custom_fun() cf'),
function('custom_fun() cf'),
function('custom_func1() cf'),
function('custom_func2() cf'),
function('set_returning_func() srf')])
@pytest.mark.parametrize('text', ['SELECT * FROM Orders o CROSS JOIN '])
def test_duplicate_aliases_with_casing(cased_aliased_completer,
complete_event, text):
result = cased_aliased_completer.get_completions(
Document(text=text), complete_event)
assert set(result) == set([
schema('PUBLIC'),
table('Orders O2'),
table('Users U'),
table('"Users" U'),
table('"select" s'),
view('User_Emails UE'),
function('_custom_fun() cf'),
function('Custom_Fun() CF'),
function('Custom_Func1() CF'),
function('custom_func2() cf'),
function('set_returning_func() srf')])
@pytest.mark.parametrize('text', ['SELECT * FROM '])
def test_aliases_with_casing(cased_aliased_completer, complete_event, text):
result = cased_aliased_completer.get_completions(
Document(text=text), complete_event)
assert set(result) == set([schema('PUBLIC')] + cased_aliased_rels)
@pytest.mark.parametrize('text', ['SELECT * FROM '])
def test_table_casing(cased_completer, complete_event, text):
result = cased_completer.get_completions(
Document(text=text), complete_event)
assert set(result) == set([schema('PUBLIC')] + cased_rels)

View File

@ -1,25 +1,23 @@
from pgcli.packages.sqlcompletion import (
suggest_type, Special, Database, Schema, Table, Column, View, Keyword,
Function, Datatype, Alias, JoinCondition, Join)
FromClauseItem, Function, Datatype, Alias, JoinCondition, Join)
import pytest
# Returns the expected select-clause suggestions for a single-table select
def cols_etc(table, schema=None, alias=None, is_function=False, parent=None):
return set([
Column(tables=((schema, table, alias, is_function),)),
Function(schema=parent),
Keyword()])
def test_select_suggests_cols_with_visible_table_scope():
suggestions = suggest_type('SELECT FROM tabl', 'SELECT ')
assert set(suggestions) == set([
Column(tables=((None, 'tabl', None, False),)),
Function(schema=None),
Keyword()
])
assert set(suggestions) == cols_etc('tabl')
def test_select_suggests_cols_with_qualified_table_scope():
suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ')
assert set(suggestions) == set([
Column(tables=(('sch', 'tabl', None, False),)),
Function(schema=None),
Keyword()
])
assert set(suggestions) == cols_etc('tabl', 'sch')
@pytest.mark.parametrize('expression', [
@ -27,11 +25,7 @@ def test_select_suggests_cols_with_qualified_table_scope():
])
def test_where_suggests_columns_functions_quoted_table(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([
Column(tables=((None, 'tabl', '"tabl"', False),)),
Function(schema=None),
Keyword(),
])
assert set(suggestions) == cols_etc('tabl', alias='"tabl"')
@pytest.mark.parametrize('expression', [
@ -48,11 +42,7 @@ def test_where_suggests_columns_functions_quoted_table(expression):
])
def test_where_suggests_columns_functions(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([
Column(tables=((None, 'tabl', None, False),)),
Function(schema=None),
Keyword(),
])
assert set(suggestions) == cols_etc('tabl')
@pytest.mark.parametrize('expression', [
@ -61,21 +51,13 @@ def test_where_suggests_columns_functions(expression):
])
def test_where_in_suggests_columns(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([
Column(tables=((None, 'tabl', None, False),)),
Function(schema=None),
Keyword(),
])
assert set(suggestions) == cols_etc('tabl')
def test_where_equals_any_suggests_columns_or_keywords():
text = 'SELECT * FROM tabl WHERE foo = ANY('
suggestions = suggest_type(text, text)
assert set(suggestions) == set([
Column(tables=((None, 'tabl', None, False),)),
Function(schema=None),
Keyword(),
])
assert set(suggestions) == cols_etc('tabl')
def test_lparen_suggests_cols():
@ -114,9 +96,7 @@ def test_suggests_tables_views_and_schemas(expression):
def test_suggest_tables_views_schemas_and_functions(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
FromClauseItem(schema=None),
Schema()
])
@ -127,11 +107,10 @@ def test_suggest_tables_views_schemas_and_functions(expression):
])
def test_suggest_after_join_with_two_tables(expression):
suggestions = suggest_type(expression, expression)
tables = tuple([(None, 'foo', None, False), (None, 'bar', None, False)])
assert set(suggestions) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
Join(((None, 'foo', None, False), (None, 'bar', None, False)), None),
FromClauseItem(schema=None, tables=tables),
Join(tables, None),
Schema(),
])
@ -142,10 +121,9 @@ def test_suggest_after_join_with_two_tables(expression):
])
def test_suggest_after_join_with_one_table(expression):
suggestions = suggest_type(expression, expression)
tables = ((None, 'foo', None, False),)
assert set(suggestions) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
FromClauseItem(schema=None, tables=tables),
Join(((None, 'foo', None, False),), None),
Schema(),
])
@ -154,7 +132,6 @@ def test_suggest_after_join_with_one_table(expression):
@pytest.mark.parametrize('expression', [
'INSERT INTO sch.',
'COPY sch.',
'UPDATE sch.',
'DESCRIBE sch.',
])
def test_suggest_qualified_tables_and_views(expression):
@ -165,6 +142,17 @@ def test_suggest_qualified_tables_and_views(expression):
])
@pytest.mark.parametrize('expression', [
'UPDATE sch.',
])
def test_suggest_qualified_aliasable_tables_and_views(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([
Table(schema='sch'),
View(schema='sch'),
])
@pytest.mark.parametrize('expression', [
'SELECT * FROM sch.',
'SELECT * FROM sch."',
@ -174,11 +162,7 @@ def test_suggest_qualified_tables_and_views(expression):
])
def test_suggest_qualified_tables_views_and_functions(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([
Table(schema='sch'),
View(schema='sch'),
Function(schema='sch', filter='for_from_clause'),
])
assert set(suggestions) == set([FromClauseItem(schema='sch')])
@pytest.mark.parametrize('expression', [
@ -186,11 +170,10 @@ def test_suggest_qualified_tables_views_and_functions(expression):
])
def test_suggest_qualified_tables_views_functions_and_joins(expression):
suggestions = suggest_type(expression, expression)
tbls = tuple([(None, 'foo', None, False)])
assert set(suggestions) == set([
Table(schema='sch'),
View(schema='sch'),
Function(schema='sch', filter='for_from_clause'),
Join(((None, 'foo', None, False),), 'sch'),
FromClauseItem(schema='sch', tables=tbls),
Join(tbls, 'sch'),
])
@ -225,9 +208,7 @@ def test_table_comma_suggests_tables_and_schemas():
suggestions = suggest_type('SELECT a, b FROM tbl1, ',
'SELECT a, b FROM tbl1, ')
assert set(suggestions) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
FromClauseItem(schema=None),
Schema(),
])
@ -259,11 +240,7 @@ def test_insert_into_lparen_comma_suggests_cols():
def test_partially_typed_col_name_suggests_col_names():
suggestions = suggest_type('SELECT * FROM tabl WHERE col_n',
'SELECT * FROM tabl WHERE col_n')
assert set(suggestions) == set([
Column(tables=((None, 'tabl', None, False),)),
Function(schema=None),
Keyword()
])
assert set(suggestions) == cols_etc('tabl')
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
@ -352,15 +329,24 @@ def test_outer_table_reference_in_exists_subquery_suggests_columns():
@pytest.mark.parametrize('expression', [
'SELECT * FROM (SELECT * FROM ',
'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ',
])
def test_sub_select_table_name_completion(expression):
suggestion = suggest_type(expression, expression)
assert set(suggestion) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
FromClauseItem(schema=None),
Schema(),
])
@pytest.mark.parametrize('expression', [
'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ',
])
def test_sub_select_table_name_completion_with_outer_table(expression):
suggestion = suggest_type(expression, expression)
tbls = tuple([(None, 'foo', None, False)])
assert set(suggestion) == set([
FromClauseItem(schema=None, tables=tbls),
Schema(),
])
@ -379,11 +365,7 @@ def test_sub_select_col_name_completion():
def test_sub_select_multiple_col_name_completion():
suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc',
'SELECT * FROM (SELECT a, ')
assert set(suggestions) == set([
Column(tables=((None, 'abc', None, False),)),
Function(schema=None),
Keyword(),
])
assert set(suggestions) == cols_etc('abc')
def test_sub_select_dot_col_name_completion():
@ -402,22 +384,22 @@ def test_sub_select_dot_col_name_completion():
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type)
suggestion = suggest_type(text, text)
tbls = tuple([(None, 'abc', tbl_alias or None, False)])
assert set(suggestion) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
FromClauseItem(schema=None, tables=tbls),
Schema(),
Join(((None, 'abc', tbl_alias if tbl_alias else None, False),), None),
Join(tbls, None),
])
def test_left_join_with_comma():
text = 'select * from foo f left join bar b,'
suggestions = suggest_type(text, text)
# tbls should also include (None, 'bar', 'b', False)
# but there's a bug with commas
tbls = tuple([(None, 'foo', 'f', False)])
assert set(suggestions) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
FromClauseItem(schema=None, tables=tbls),
Schema(),
])
@ -428,14 +410,13 @@ def test_left_join_with_comma():
])
def test_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, 'abc', 'a', False), (None, 'def', 'd', False))
assert set(suggestions) == set([
Column(tables=((None, 'abc', 'a', False),)),
Table(schema='a'),
View(schema='a'),
Function(schema='a'),
JoinCondition(tables=((None, 'abc', 'a', False),
(None, 'def', 'd', False)),
parent=(None, 'abc', 'a', False))
JoinCondition(tables=tables, parent=(None, 'abc', 'a', False))
])
@ -467,10 +448,8 @@ on ''',
])
def test_on_suggests_aliases_and_join_conditions(sql):
suggestions = suggest_type(sql, sql)
assert set(suggestions) == set((JoinCondition(
tables=((None, 'abc', 'a', False),
(None, 'bcd', 'b', False)),
parent=None),
tables = ((None, 'abc', 'a', False), (None, 'bcd', 'b', False))
assert set(suggestions) == set((JoinCondition(tables=tables, parent=None),
Alias(aliases=('a', 'b',)),))
@pytest.mark.parametrize('sql', [
@ -479,10 +458,8 @@ def test_on_suggests_aliases_and_join_conditions(sql):
])
def test_on_suggests_tables_and_join_conditions(sql):
suggestions = suggest_type(sql, sql)
assert set(suggestions) == set((JoinCondition(tables=(
(None, 'abc', None, False),
(None, 'bcd', None, False)),
parent=None),
tables = ((None, 'abc', None, False), (None, 'bcd', None, False))
assert set(suggestions) == set((JoinCondition(tables=tables, parent=None),
Alias(aliases=('abc', 'bcd',)),))
@ -501,23 +478,17 @@ def test_on_suggests_aliases_right_side(sql):
])
def test_on_suggests_tables_and_join_conditions_right_side(sql):
suggestions = suggest_type(sql, sql)
assert set(suggestions) == set((JoinCondition(
tables=(
(None, 'abc', None, False),
(None, 'bcd', None, False)),
parent=None
),
tables = ((None, 'abc', None, False), (None, 'bcd', None, False))
assert set(suggestions) == set((JoinCondition(tables=tables, parent=None),
Alias(aliases=('abc', 'bcd',)),))
@pytest.mark.parametrize('col_list', ('', 'col1, ',))
def test_join_using_suggests_common_columns(col_list):
text = 'select * from abc inner join def using (' + col_list
tables = ((None, 'abc', None, False), (None, 'def', None, False))
assert set(suggest_type(text, text)) == set([
Column(tables=((None, 'abc', None, False),
(None, 'def', None, False)),
require_last_table=True),
])
Column(tables=tables, require_last_table=True),])
def test_suggest_columns_after_multiple_joins():
@ -534,9 +505,7 @@ def test_2_statements_2nd_current():
suggestions = suggest_type('select * from a; select * from ',
'select * from a; select * from ')
assert set(suggestions) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
FromClauseItem(schema=None),
Schema(),
])
@ -552,9 +521,7 @@ def test_2_statements_2nd_current():
suggestions = suggest_type('select * from; select * from ',
'select * from; select * from ')
assert set(suggestions) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
FromClauseItem(schema=None),
Schema(),
])
@ -563,38 +530,26 @@ def test_2_statements_1st_current():
suggestions = suggest_type('select * from ; select * from b',
'select * from ')
assert set(suggestions) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
FromClauseItem(schema=None),
Schema(),
])
suggestions = suggest_type('select from a; select * from b',
'select ')
assert set(suggestions) == set([
Column(tables=((None, 'a', None, False),)),
Function(schema=None),
Keyword()
])
assert set(suggestions) == cols_etc('a')
def test_3_statements_2nd_current():
suggestions = suggest_type('select * from a; select * from ; select * from c',
'select * from a; select * from ')
assert set(suggestions) == set([
Table(schema=None),
View(schema=None),
Function(schema=None, filter='for_from_clause'),
FromClauseItem(schema=None),
Schema(),
])
suggestions = suggest_type('select * from a; select from b; select * from c',
'select * from a; select ')
assert set(suggestions) == set([
Column(tables=((None, 'b', None, False),)),
Function(schema=None),
Keyword()
])
assert set(suggestions) == cols_etc('b')
def test_create_db_with_template():
@ -709,11 +664,7 @@ def test_invalid_sql():
def test_suggest_where_keyword(text):
# https://github.com/dbcli/mycli/issues/135
suggestions = suggest_type(text, text)
assert set(suggestions) == set([
Column(tables=((None, 'foo', None, False),)),
Function(schema=None),
Keyword(),
])
assert set(suggestions) == cols_etc('foo')
@pytest.mark.parametrize('text, before, expected', [
@ -737,11 +688,8 @@ def test_named_query_completion(text, before, expected):
def test_select_suggests_fields_from_function():
suggestions = suggest_type('SELECT FROM func()', 'SELECT ')
assert set(suggestions) == set([
Column(tables=((None, 'func', None, True),)),
Function(schema=None),
Keyword()
])
assert set(suggestions) == cols_etc(
'func', is_function=True)
@pytest.mark.parametrize('sql', [
@ -750,7 +698,7 @@ def test_select_suggests_fields_from_function():
])
def test_ignore_leading_double_quotes(sql):
suggestions = suggest_type(sql, sql)
assert Table(schema=None) in set(suggestions)
assert FromClauseItem(schema=None) in set(suggestions)
@pytest.mark.parametrize('sql', [