diff --git a/pgcli/packages/parseutils.py b/pgcli/packages/parseutils.py index ff332d86..fd78eec3 100644 --- a/pgcli/packages/parseutils.py +++ b/pgcli/packages/parseutils.py @@ -66,6 +66,7 @@ def last_word(text, include='alphanum_underscore'): TableReference = namedtuple('TableReference', ['schema', 'name', 'alias', 'is_function']) +TableReference.ref = property(lambda self: self.alias or self.name) # This code is borrowed from sqlparse example script. diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index e6d9c298..1b7873a7 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -3,7 +3,7 @@ import logging import re import itertools import operator -from collections import namedtuple +from collections import namedtuple, defaultdict from pgspecial.namedqueries import NamedQueries from prompt_toolkit.completion import Completer, Completion from prompt_toolkit.contrib.completers import PathCompleter @@ -11,7 +11,7 @@ from prompt_toolkit.document import Document from .packages.sqlcompletion import ( suggest_type, Special, Database, Schema, Table, Function, Column, View, Keyword, NamedQuery, Datatype, Alias, Path) -from .packages.parseutils import last_word +from .packages.parseutils import last_word, TableReference from .packages.pgliterals.main import get_literals from .packages.prioritization import PrevalenceCounter from .config import load_config, config_location @@ -182,7 +182,7 @@ class PGCompleter(Completer): self.all_completions = set(self.keywords + self.functions) def find_matches(self, text, collection, mode='fuzzy', - meta=None, meta_collection=None, parent=None): + meta=None, meta_collection=None): """Find completion matches for the given text. Given the user's input text and a collection of available @@ -220,9 +220,8 @@ class PGCompleter(Completer): # or None if the item doesn't match # Note: higher priority values mean more important, so use negative # signs to flip the direction of the tuple - searchtext = '' if text == '*' and meta == 'column' else text if fuzzy: - regex = '.*?'.join(map(re.escape, searchtext)) + regex = '.*?'.join(map(re.escape, text)) pat = re.compile('(%s)' % regex) def _match(item): @@ -230,10 +229,10 @@ class PGCompleter(Completer): if r: return -len(r.group()), -r.start() else: - match_end_limit = len(searchtext) + match_end_limit = len(text) def _match(item): - match_point = item.lower().find(searchtext, 0, match_end_limit) + match_point = item.lower().find(text, 0, match_end_limit) if match_point >= 0: # Use negative infinity to force keywords to sort after all # fuzzy matches @@ -270,12 +269,6 @@ class PGCompleter(Completer): matches.append(Match( completion=Completion(item, -text_len, display_meta=meta), priority=priority)) - if text == '*' and meta == 'column' and matches: - sep = ', ' + parent + '.' if parent else ', ' - cols = (m.completion.text for m in matches) - collist = (sep).join([c for c in cols if c != '*']) - matches = [Match(completion=Completion(collist, -text_len, - display_meta='columns', display='*'), priority=(1,1,1))] return matches def get_completions(self, document, complete_event, smart_completion=None): @@ -314,15 +307,32 @@ class PGCompleter(Completer): tables = suggestion.tables _logger.debug("Completion column scope: %r", tables) scoped_cols = self.populate_scoped_cols(tables) - + colit = scoped_cols.items() + flat_cols = [] + for cols in scoped_cols.values(): + flat_cols.extend(cols) if suggestion.drop_unique: # drop_unique is used for 'tb11 JOIN tbl2 USING (...' which should # suggest only columns that appear in more than one table - scoped_cols = [col for (col, count) - in Counter(scoped_cols).items() + flat_cols = [col for (col, count) + in Counter(flat_cols).items() if count > 1 and col != '*'] - - return self.find_matches(word_before_cursor, scoped_cols, meta='column', parent=suggestion.parent) + lastword = last_word(word_before_cursor, include='most_punctuations') + if lastword == '*': + if suggestion.parent: + # User typed x.*; replicate "x." for all columns + sep = ', ' + self.escape_name(suggestion.parent) + '.' + collist = (sep).join([c for c in flat_cols if c != '*']) + elif len(scoped_cols) > 1: + # Multiple tables; qualify all columns + collist = (', ').join([t.ref + '.' + c for t, cs in colit + for c in cs if c != '*']) + else: + # Plain columns + collist = (', ').join([c for c in flat_cols if c != '*']) + return [Match(completion=Completion(collist, -1, + display_meta='columns', display='*'), priority=(1,1,1))] + return self.find_matches(word_before_cursor, flat_cols, meta='column') def get_function_matches(self, suggestion, word_before_cursor): if suggestion.filter == 'is_set_returning': @@ -442,76 +452,33 @@ class PGCompleter(Completer): def populate_scoped_cols(self, scoped_tbls): """ Find all columns in a set of scoped_tables :param scoped_tbls: list of TableReference namedtuples - :return: list of column names + :return: dict {TableReference:[list of column names]} """ - columns = [] + columns = defaultdict(lambda: []) meta = self.dbmetadata - + def addcols(schema, rel, alias, reltype, cols): + tbl = TableReference(schema, rel, alias, reltype == 'functions') + columns[tbl].extend(cols) for tbl in scoped_tbls: - if tbl.schema: - # A fully qualified schema.relname reference - schema = self.escape_name(tbl.schema) + schemas = [tbl.schema] if tbl.schema else self.search_path + for schema in schemas: relname = self.escape_name(tbl.name) - + schema = self.escape_name(schema) if tbl.is_function: - # Return column names from a set-returning function - try: - # Get an array of FunctionMetadata objects - functions = meta['functions'][schema][relname] - except KeyError: - # No such function name - continue - - for func in functions: + # Return column names from a set-returning function + # Get an array of FunctionMetadata objects + functions = meta['functions'].get(schema, {}).get(relname) + for func in (functions or []): # func is a FunctionMetadata object - columns.extend(func.fieldnames()) + cols = func.fieldnames() + addcols(schema, relname, tbl.alias, 'functions', cols) else: - # We don't know if schema.relname is a table or view. Since - # tables and views cannot share the same name, we can check - # one at a time - try: - columns.extend(meta['tables'][schema][relname]) - - # Table exists, so don't bother checking for a view - continue - except KeyError: - pass - - try: - columns.extend(meta['views'][schema][relname]) - except KeyError: - pass - - else: - # Schema not specified, so traverse the search path looking for - # a table or view that matches. Note that in order to get proper - # shadowing behavior, we need to check both views and tables for - # each schema before checking the next schema - for schema in self.search_path: - relname = self.escape_name(tbl.name) - - if tbl.is_function: - try: - functions = meta['functions'][schema][relname] - except KeyError: - continue - - for func in functions: - # func is a FunctionMetadata object - columns.extend(func.fieldnames()) - else: - try: - columns.extend(meta['tables'][schema][relname]) + for reltype in ('tables', 'views'): + cols = meta[reltype].get(schema, {}).get(relname) + if cols: + addcols(schema, relname, tbl.alias, reltype, cols) break - except KeyError: - pass - - try: - columns.extend(meta['views'][schema][relname]) - break - except KeyError: - pass return columns diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index 1881110a..868a69ca 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -322,4 +322,73 @@ def test_suggest_columns_from_aliased_set_returning_function(completer, complete result = completer.get_completions(Document(text=sql, cursor_position=pos), complete_event) assert set(result) == set([ - Completion(text='x', start_position=0, display_meta='column')]) \ No newline at end of file + Completion(text='x', start_position=0, display_meta='column')]) + + +def test_wildcard_column_expansion(completer, complete_event): + sql = 'SELECT * FROM custom.set_returning_func()' + pos = len('SELECT *') + + completions = completer.get_completions( + Document(text=sql, cursor_position=pos), complete_event) + + col_list = 'x' + expected = [Completion(text=col_list, start_position=-1, + display='*', display_meta='columns')] + + assert expected == completions + + +def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_event): + sql = 'SELECT p.* FROM custom.products p' + pos = len('SELECT p.*') + + completions = completer.get_completions( + Document(text=sql, cursor_position=pos), complete_event) + + col_list = 'id, p.product_name, p.price' + expected = [Completion(text=col_list, start_position=-1, + display='*', display_meta='columns')] + + assert expected == completions + + +def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event): + sql = 'SELECT "select".* FROM public."select"' + pos = len('SELECT "select".*') + + completions = completer.get_completions( + Document(text=sql, cursor_position=pos), complete_event) + + col_list = 'id, "select"."insert", "select"."ABC"' + expected = [Completion(text=col_list, start_position=-1, + display='*', display_meta='columns')] + + assert expected == completions + +def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event): + sql = 'SELECT * FROM public."select" JOIN custom.users u ON true' + pos = len('SELECT *') + + completions = completer.get_completions( + Document(text=sql, cursor_position=pos), complete_event) + + col_list = '"select".id, "select"."insert", "select"."ABC", u.id, u.phone_number' + expected = [Completion(text=col_list, start_position=-1, + display='*', display_meta='columns')] + + assert expected == completions + + +def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event): + sql = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true' + pos = len('SELECT "select".*') + + completions = completer.get_completions( + Document(text=sql, cursor_position=pos), complete_event) + + col_list = 'id, "select"."insert", "select"."ABC"' + expected = [Completion(text=col_list, start_position=-1, + display='*', display_meta='columns')] + + assert expected == completions diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index 36ca6203..bab462c2 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -553,3 +553,70 @@ def test_columns_before_keywords(completer, complete_event): assert completions.index(column) < completions.index(keyword) +def test_wildcard_column_expansion(completer, complete_event): + sql = 'SELECT * FROM users' + pos = len('SELECT *') + + completions = completer.get_completions( + Document(text=sql, cursor_position=pos), complete_event) + + col_list = 'id, email, first_name, last_name' + expected = [Completion(text=col_list, start_position=-1, + display='*', display_meta='columns')] + + assert expected == completions + + +def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_event): + sql = 'SELECT u.* FROM users u' + pos = len('SELECT u.*') + + completions = completer.get_completions( + Document(text=sql, cursor_position=pos), complete_event) + + col_list = 'id, u.email, u.first_name, u.last_name' + expected = [Completion(text=col_list, start_position=-1, + display='*', display_meta='columns')] + + assert expected == completions + + +def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event): + sql = 'SELECT users.* FROM users' + pos = len('SELECT users.*') + + completions = completer.get_completions( + Document(text=sql, cursor_position=pos), complete_event) + + col_list = 'id, users.email, users.first_name, users.last_name' + expected = [Completion(text=col_list, start_position=-1, + display='*', display_meta='columns')] + + assert expected == completions + +def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event): + sql = 'SELECT * FROM "select" JOIN users u ON true' + pos = len('SELECT *') + + completions = completer.get_completions( + Document(text=sql, cursor_position=pos), complete_event) + + col_list = '"select".id, "select"."insert", "select"."ABC", u.id, u.phone_number' + expected = [Completion(text=col_list, start_position=-1, + display='*', display_meta='columns')] + + assert expected == completions + + +def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event): + sql = 'SELECT "select".* FROM "select" JOIN users u ON true' + pos = len('SELECT "select".*') + + completions = completer.get_completions( + Document(text=sql, cursor_position=pos), complete_event) + + col_list = 'id, "select"."insert", "select"."ABC"' + expected = [Completion(text=col_list, start_position=-1, + display='*', display_meta='columns')] + + assert expected == completions