1
1
mirror of https://github.com/dbcli/pgcli.git synced 2024-10-06 02:07:53 +03:00

Some changes after code review

Move *-expansion logic from find_matches to get_column_matches
Make the expansion always qualify column names if there's more than one table. For this I changed populate_scoped_columns to return a dict{TableReference:[column names]}; I took this code from my other branch joinconditions.
Add a bunch of tests
This commit is contained in:
koljonen 2016-05-14 18:19:55 +02:00
parent 16efe77433
commit b793ed5c00
No known key found for this signature in database
GPG Key ID: AF0327B5131CD164
4 changed files with 184 additions and 80 deletions

View File

@ -66,6 +66,7 @@ def last_word(text, include='alphanum_underscore'):
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
'is_function'])
TableReference.ref = property(lambda self: self.alias or self.name)
# This code is borrowed from sqlparse example script.

View File

@ -3,7 +3,7 @@ import logging
import re
import itertools
import operator
from collections import namedtuple
from collections import namedtuple, defaultdict
from pgspecial.namedqueries import NamedQueries
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.contrib.completers import PathCompleter
@ -11,7 +11,7 @@ from prompt_toolkit.document import Document
from .packages.sqlcompletion import (
suggest_type, Special, Database, Schema, Table, Function, Column, View,
Keyword, NamedQuery, Datatype, Alias, Path)
from .packages.parseutils import last_word
from .packages.parseutils import last_word, TableReference
from .packages.pgliterals.main import get_literals
from .packages.prioritization import PrevalenceCounter
from .config import load_config, config_location
@ -182,7 +182,7 @@ class PGCompleter(Completer):
self.all_completions = set(self.keywords + self.functions)
def find_matches(self, text, collection, mode='fuzzy',
meta=None, meta_collection=None, parent=None):
meta=None, meta_collection=None):
"""Find completion matches for the given text.
Given the user's input text and a collection of available
@ -220,9 +220,8 @@ class PGCompleter(Completer):
# or None if the item doesn't match
# Note: higher priority values mean more important, so use negative
# signs to flip the direction of the tuple
searchtext = '' if text == '*' and meta == 'column' else text
if fuzzy:
regex = '.*?'.join(map(re.escape, searchtext))
regex = '.*?'.join(map(re.escape, text))
pat = re.compile('(%s)' % regex)
def _match(item):
@ -230,10 +229,10 @@ class PGCompleter(Completer):
if r:
return -len(r.group()), -r.start()
else:
match_end_limit = len(searchtext)
match_end_limit = len(text)
def _match(item):
match_point = item.lower().find(searchtext, 0, match_end_limit)
match_point = item.lower().find(text, 0, match_end_limit)
if match_point >= 0:
# Use negative infinity to force keywords to sort after all
# fuzzy matches
@ -270,12 +269,6 @@ class PGCompleter(Completer):
matches.append(Match(
completion=Completion(item, -text_len, display_meta=meta),
priority=priority))
if text == '*' and meta == 'column' and matches:
sep = ', ' + parent + '.' if parent else ', '
cols = (m.completion.text for m in matches)
collist = (sep).join([c for c in cols if c != '*'])
matches = [Match(completion=Completion(collist, -text_len,
display_meta='columns', display='*'), priority=(1,1,1))]
return matches
def get_completions(self, document, complete_event, smart_completion=None):
@ -314,15 +307,32 @@ class PGCompleter(Completer):
tables = suggestion.tables
_logger.debug("Completion column scope: %r", tables)
scoped_cols = self.populate_scoped_cols(tables)
colit = scoped_cols.items()
flat_cols = []
for cols in scoped_cols.values():
flat_cols.extend(cols)
if suggestion.drop_unique:
# drop_unique is used for 'tb11 JOIN tbl2 USING (...' which should
# suggest only columns that appear in more than one table
scoped_cols = [col for (col, count)
in Counter(scoped_cols).items()
flat_cols = [col for (col, count)
in Counter(flat_cols).items()
if count > 1 and col != '*']
return self.find_matches(word_before_cursor, scoped_cols, meta='column', parent=suggestion.parent)
lastword = last_word(word_before_cursor, include='most_punctuations')
if lastword == '*':
if suggestion.parent:
# User typed x.*; replicate "x." for all columns
sep = ', ' + self.escape_name(suggestion.parent) + '.'
collist = (sep).join([c for c in flat_cols if c != '*'])
elif len(scoped_cols) > 1:
# Multiple tables; qualify all columns
collist = (', ').join([t.ref + '.' + c for t, cs in colit
for c in cs if c != '*'])
else:
# Plain columns
collist = (', ').join([c for c in flat_cols if c != '*'])
return [Match(completion=Completion(collist, -1,
display_meta='columns', display='*'), priority=(1,1,1))]
return self.find_matches(word_before_cursor, flat_cols, meta='column')
def get_function_matches(self, suggestion, word_before_cursor):
if suggestion.filter == 'is_set_returning':
@ -442,76 +452,33 @@ class PGCompleter(Completer):
def populate_scoped_cols(self, scoped_tbls):
""" Find all columns in a set of scoped_tables
:param scoped_tbls: list of TableReference namedtuples
:return: list of column names
:return: dict {TableReference:[list of column names]}
"""
columns = []
columns = defaultdict(lambda: [])
meta = self.dbmetadata
def addcols(schema, rel, alias, reltype, cols):
tbl = TableReference(schema, rel, alias, reltype == 'functions')
columns[tbl].extend(cols)
for tbl in scoped_tbls:
if tbl.schema:
# A fully qualified schema.relname reference
schema = self.escape_name(tbl.schema)
schemas = [tbl.schema] if tbl.schema else self.search_path
for schema in schemas:
relname = self.escape_name(tbl.name)
schema = self.escape_name(schema)
if tbl.is_function:
# Return column names from a set-returning function
try:
# Get an array of FunctionMetadata objects
functions = meta['functions'][schema][relname]
except KeyError:
# No such function name
continue
for func in functions:
# Return column names from a set-returning function
# Get an array of FunctionMetadata objects
functions = meta['functions'].get(schema, {}).get(relname)
for func in (functions or []):
# func is a FunctionMetadata object
columns.extend(func.fieldnames())
cols = func.fieldnames()
addcols(schema, relname, tbl.alias, 'functions', cols)
else:
# We don't know if schema.relname is a table or view. Since
# tables and views cannot share the same name, we can check
# one at a time
try:
columns.extend(meta['tables'][schema][relname])
# Table exists, so don't bother checking for a view
continue
except KeyError:
pass
try:
columns.extend(meta['views'][schema][relname])
except KeyError:
pass
else:
# Schema not specified, so traverse the search path looking for
# a table or view that matches. Note that in order to get proper
# shadowing behavior, we need to check both views and tables for
# each schema before checking the next schema
for schema in self.search_path:
relname = self.escape_name(tbl.name)
if tbl.is_function:
try:
functions = meta['functions'][schema][relname]
except KeyError:
continue
for func in functions:
# func is a FunctionMetadata object
columns.extend(func.fieldnames())
else:
try:
columns.extend(meta['tables'][schema][relname])
for reltype in ('tables', 'views'):
cols = meta[reltype].get(schema, {}).get(relname)
if cols:
addcols(schema, relname, tbl.alias, reltype, cols)
break
except KeyError:
pass
try:
columns.extend(meta['views'][schema][relname])
break
except KeyError:
pass
return columns

View File

@ -322,4 +322,73 @@ def test_suggest_columns_from_aliased_set_returning_function(completer, complete
result = completer.get_completions(Document(text=sql, cursor_position=pos),
complete_event)
assert set(result) == set([
Completion(text='x', start_position=0, display_meta='column')])
Completion(text='x', start_position=0, display_meta='column')])
def test_wildcard_column_expansion(completer, complete_event):
sql = 'SELECT * FROM custom.set_returning_func()'
pos = len('SELECT *')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'x'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_event):
sql = 'SELECT p.* FROM custom.products p'
pos = len('SELECT p.*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, p.product_name, p.price'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event):
sql = 'SELECT "select".* FROM public."select"'
pos = len('SELECT "select".*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, "select"."insert", "select"."ABC"'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
sql = 'SELECT * FROM public."select" JOIN custom.users u ON true'
pos = len('SELECT *')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = '"select".id, "select"."insert", "select"."ABC", u.id, u.phone_number'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
sql = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true'
pos = len('SELECT "select".*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, "select"."insert", "select"."ABC"'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions

View File

@ -553,3 +553,70 @@ def test_columns_before_keywords(completer, complete_event):
assert completions.index(column) < completions.index(keyword)
def test_wildcard_column_expansion(completer, complete_event):
sql = 'SELECT * FROM users'
pos = len('SELECT *')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, email, first_name, last_name'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_event):
sql = 'SELECT u.* FROM users u'
pos = len('SELECT u.*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, u.email, u.first_name, u.last_name'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event):
sql = 'SELECT users.* FROM users'
pos = len('SELECT users.*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, users.email, users.first_name, users.last_name'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
sql = 'SELECT * FROM "select" JOIN users u ON true'
pos = len('SELECT *')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = '"select".id, "select"."insert", "select"."ABC", u.id, u.phone_number'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
sql = 'SELECT "select".* FROM "select" JOIN users u ON true'
pos = len('SELECT "select".*')
completions = completer.get_completions(
Document(text=sql, cursor_position=pos), complete_event)
col_list = 'id, "select"."insert", "select"."ABC"'
expected = [Completion(text=col_list, start_position=-1,
display='*', display_meta='columns')]
assert expected == completions