1
1
mirror of https://github.com/dbcli/pgcli.git synced 2024-10-06 02:07:53 +03:00

Some changes after code review

Move *-expansion logic from find_matches to get_column_matches
Make the expansion always qualify column names if there's more than one table. For this I changed populate_scoped_columns to return a dict{TableReference:[column names]}; I took this code from my other branch joinconditions.
Add a bunch of tests
This commit is contained in:
koljonen 2016-05-14 18:19:55 +02:00
parent 16efe77433
commit b793ed5c00
No known key found for this signature in database
GPG Key ID: AF0327B5131CD164
4 changed files with 184 additions and 80 deletions

View File

@ -66,6 +66,7 @@ def last_word(text, include='alphanum_underscore'):
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias', TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
'is_function']) 'is_function'])
TableReference.ref = property(lambda self: self.alias or self.name)
# This code is borrowed from sqlparse example script. # This code is borrowed from sqlparse example script.

View File

@ -3,7 +3,7 @@ import logging
import re import re
import itertools import itertools
import operator import operator
from collections import namedtuple from collections import namedtuple, defaultdict
from pgspecial.namedqueries import NamedQueries from pgspecial.namedqueries import NamedQueries
from prompt_toolkit.completion import Completer, Completion from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.contrib.completers import PathCompleter from prompt_toolkit.contrib.completers import PathCompleter
@ -11,7 +11,7 @@ from prompt_toolkit.document import Document
from .packages.sqlcompletion import ( from .packages.sqlcompletion import (
suggest_type, Special, Database, Schema, Table, Function, Column, View, suggest_type, Special, Database, Schema, Table, Function, Column, View,
Keyword, NamedQuery, Datatype, Alias, Path) Keyword, NamedQuery, Datatype, Alias, Path)
from .packages.parseutils import last_word from .packages.parseutils import last_word, TableReference
from .packages.pgliterals.main import get_literals from .packages.pgliterals.main import get_literals
from .packages.prioritization import PrevalenceCounter from .packages.prioritization import PrevalenceCounter
from .config import load_config, config_location from .config import load_config, config_location
@ -182,7 +182,7 @@ class PGCompleter(Completer):
self.all_completions = set(self.keywords + self.functions) self.all_completions = set(self.keywords + self.functions)
def find_matches(self, text, collection, mode='fuzzy', def find_matches(self, text, collection, mode='fuzzy',
meta=None, meta_collection=None, parent=None): meta=None, meta_collection=None):
"""Find completion matches for the given text. """Find completion matches for the given text.
Given the user's input text and a collection of available Given the user's input text and a collection of available
@ -220,9 +220,8 @@ class PGCompleter(Completer):
# or None if the item doesn't match # or None if the item doesn't match
# Note: higher priority values mean more important, so use negative # Note: higher priority values mean more important, so use negative
# signs to flip the direction of the tuple # signs to flip the direction of the tuple
searchtext = '' if text == '*' and meta == 'column' else text
if fuzzy: if fuzzy:
regex = '.*?'.join(map(re.escape, searchtext)) regex = '.*?'.join(map(re.escape, text))
pat = re.compile('(%s)' % regex) pat = re.compile('(%s)' % regex)
def _match(item): def _match(item):
@ -230,10 +229,10 @@ class PGCompleter(Completer):
if r: if r:
return -len(r.group()), -r.start() return -len(r.group()), -r.start()
else: else:
match_end_limit = len(searchtext) match_end_limit = len(text)
def _match(item): def _match(item):
match_point = item.lower().find(searchtext, 0, match_end_limit) match_point = item.lower().find(text, 0, match_end_limit)
if match_point >= 0: if match_point >= 0:
# Use negative infinity to force keywords to sort after all # Use negative infinity to force keywords to sort after all
# fuzzy matches # fuzzy matches
@ -270,12 +269,6 @@ class PGCompleter(Completer):
matches.append(Match( matches.append(Match(
completion=Completion(item, -text_len, display_meta=meta), completion=Completion(item, -text_len, display_meta=meta),
priority=priority)) priority=priority))
if text == '*' and meta == 'column' and matches:
sep = ', ' + parent + '.' if parent else ', '
cols = (m.completion.text for m in matches)
collist = (sep).join([c for c in cols if c != '*'])
matches = [Match(completion=Completion(collist, -text_len,
display_meta='columns', display='*'), priority=(1,1,1))]
return matches return matches
def get_completions(self, document, complete_event, smart_completion=None): def get_completions(self, document, complete_event, smart_completion=None):
@ -314,15 +307,32 @@ class PGCompleter(Completer):
tables = suggestion.tables tables = suggestion.tables
_logger.debug("Completion column scope: %r", tables) _logger.debug("Completion column scope: %r", tables)
scoped_cols = self.populate_scoped_cols(tables) scoped_cols = self.populate_scoped_cols(tables)
colit = scoped_cols.items()
flat_cols = []
for cols in scoped_cols.values():
flat_cols.extend(cols)
if suggestion.drop_unique: if suggestion.drop_unique:
# drop_unique is used for 'tb11 JOIN tbl2 USING (...' which should # drop_unique is used for 'tb11 JOIN tbl2 USING (...' which should
# suggest only columns that appear in more than one table # suggest only columns that appear in more than one table
scoped_cols = [col for (col, count) flat_cols = [col for (col, count)
in Counter(scoped_cols).items() in Counter(flat_cols).items()
if count > 1 and col != '*'] if count > 1 and col != '*']
lastword = last_word(word_before_cursor, include='most_punctuations')
return self.find_matches(word_before_cursor, scoped_cols, meta='column', parent=suggestion.parent) if lastword == '*':
if suggestion.parent:
# User typed x.*; replicate "x." for all columns
sep = ', ' + self.escape_name(suggestion.parent) + '.'
collist = (sep).join([c for c in flat_cols if c != '*'])
elif len(scoped_cols) > 1:
# Multiple tables; qualify all columns
collist = (', ').join([t.ref + '.' + c for t, cs in colit
for c in cs if c != '*'])
else:
# Plain columns
collist = (', ').join([c for c in flat_cols if c != '*'])
return [Match(completion=Completion(collist, -1,
display_meta='columns', display='*'), priority=(1,1,1))]
return self.find_matches(word_before_cursor, flat_cols, meta='column')
def get_function_matches(self, suggestion, word_before_cursor): def get_function_matches(self, suggestion, word_before_cursor):
if suggestion.filter == 'is_set_returning': if suggestion.filter == 'is_set_returning':
@ -442,76 +452,33 @@ class PGCompleter(Completer):
def populate_scoped_cols(self, scoped_tbls): def populate_scoped_cols(self, scoped_tbls):
""" Find all columns in a set of scoped_tables """ Find all columns in a set of scoped_tables
:param scoped_tbls: list of TableReference namedtuples :param scoped_tbls: list of TableReference namedtuples
:return: list of column names :return: dict {TableReference:[list of column names]}
""" """
columns = [] columns = defaultdict(lambda: [])
meta = self.dbmetadata meta = self.dbmetadata
def addcols(schema, rel, alias, reltype, cols):
tbl = TableReference(schema, rel, alias, reltype == 'functions')
columns[tbl].extend(cols)
for tbl in scoped_tbls: for tbl in scoped_tbls:
if tbl.schema: schemas = [tbl.schema] if tbl.schema else self.search_path
# A fully qualified schema.relname reference for schema in schemas:
schema = self.escape_name(tbl.schema)
relname = self.escape_name(tbl.name) relname = self.escape_name(tbl.name)
schema = self.escape_name(schema)
if tbl.is_function: if tbl.is_function:
# Return column names from a set-returning function # Return column names from a set-returning function
try: # Get an array of FunctionMetadata objects
# Get an array of FunctionMetadata objects functions = meta['functions'].get(schema, {}).get(relname)
functions = meta['functions'][schema][relname] for func in (functions or []):
except KeyError:
# No such function name
continue
for func in functions:
# func is a FunctionMetadata object # func is a FunctionMetadata object
columns.extend(func.fieldnames()) cols = func.fieldnames()
addcols(schema, relname, tbl.alias, 'functions', cols)
else: else:
# We don't know if schema.relname is a table or view. Since for reltype in ('tables', 'views'):
# tables and views cannot share the same name, we can check cols = meta[reltype].get(schema, {}).get(relname)
# one at a time if cols:
try: addcols(schema, relname, tbl.alias, reltype, cols)
columns.extend(meta['tables'][schema][relname])
# Table exists, so don't bother checking for a view
continue
except KeyError:
pass
try:
columns.extend(meta['views'][schema][relname])
except KeyError:
pass
else:
# Schema not specified, so traverse the search path looking for
# a table or view that matches. Note that in order to get proper
# shadowing behavior, we need to check both views and tables for
# each schema before checking the next schema
for schema in self.search_path:
relname = self.escape_name(tbl.name)
if tbl.is_function:
try:
functions = meta['functions'][schema][relname]
except KeyError:
continue
for func in functions:
# func is a FunctionMetadata object
columns.extend(func.fieldnames())
else:
try:
columns.extend(meta['tables'][schema][relname])
break break
except KeyError:
pass
try:
columns.extend(meta['views'][schema][relname])
break
except KeyError:
pass
return columns return columns

View File

@ -322,4 +322,73 @@ def test_suggest_columns_from_aliased_set_returning_function(completer, complete
result = completer.get_completions(Document(text=sql, cursor_position=pos), result = completer.get_completions(Document(text=sql, cursor_position=pos),
complete_event) complete_event)
assert set(result) == set([ assert set(result) == set([
Completion(text='x', start_position=0, display_meta='column')]) Completion(text='x', start_position=0, display_meta='column')])
def test_wildcard_column_expansion(completer, complete_event):
sql = 'SELECT * FROM custom.set_returning_func()'
pos = len('SELECT *')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'x'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_event):
sql = 'SELECT p.* FROM custom.products p'
pos = len('SELECT p.*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, p.product_name, p.price'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event):
sql = 'SELECT "select".* FROM public."select"'
pos = len('SELECT "select".*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, "select"."insert", "select"."ABC"'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
sql = 'SELECT * FROM public."select" JOIN custom.users u ON true'
pos = len('SELECT *')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = '"select".id, "select"."insert", "select"."ABC", u.id, u.phone_number'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
sql = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true'
pos = len('SELECT "select".*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, "select"."insert", "select"."ABC"'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions

View File

@ -553,3 +553,70 @@ def test_columns_before_keywords(completer, complete_event):
assert completions.index(column) < completions.index(keyword) assert completions.index(column) < completions.index(keyword)
def test_wildcard_column_expansion(completer, complete_event):
sql = 'SELECT * FROM users'
pos = len('SELECT *')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, email, first_name, last_name'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_event):
sql = 'SELECT u.* FROM users u'
pos = len('SELECT u.*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, u.email, u.first_name, u.last_name'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event):
sql = 'SELECT users.* FROM users'
pos = len('SELECT users.*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, users.email, users.first_name, users.last_name'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
sql = 'SELECT * FROM "select" JOIN users u ON true'
pos = len('SELECT *')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = '"select".id, "select"."insert", "select"."ABC", u.id, u.phone_number'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
sql = 'SELECT "select".* FROM "select" JOIN users u ON true'
pos = len('SELECT "select".*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, "select"."insert", "select"."ABC"'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions