mirror of
https://github.com/dbcli/pgcli.git
synced 2024-10-06 02:07:53 +03:00
Some changes after code review
Move *-expansion logic from find_matches to get_column_matches Make the expansion always qualify column names if there's more than one table. For this I changed populate_scoped_columns to return a dict{TableReference:[column names]}; I took this code from my other branch joinconditions. Add a bunch of tests
This commit is contained in:
parent
16efe77433
commit
b793ed5c00
@ -66,6 +66,7 @@ def last_word(text, include='alphanum_underscore'):
|
||||
|
||||
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
|
||||
'is_function'])
|
||||
TableReference.ref = property(lambda self: self.alias or self.name)
|
||||
|
||||
|
||||
# This code is borrowed from sqlparse example script.
|
||||
|
@ -3,7 +3,7 @@ import logging
|
||||
import re
|
||||
import itertools
|
||||
import operator
|
||||
from collections import namedtuple
|
||||
from collections import namedtuple, defaultdict
|
||||
from pgspecial.namedqueries import NamedQueries
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
from prompt_toolkit.contrib.completers import PathCompleter
|
||||
@ -11,7 +11,7 @@ from prompt_toolkit.document import Document
|
||||
from .packages.sqlcompletion import (
|
||||
suggest_type, Special, Database, Schema, Table, Function, Column, View,
|
||||
Keyword, NamedQuery, Datatype, Alias, Path)
|
||||
from .packages.parseutils import last_word
|
||||
from .packages.parseutils import last_word, TableReference
|
||||
from .packages.pgliterals.main import get_literals
|
||||
from .packages.prioritization import PrevalenceCounter
|
||||
from .config import load_config, config_location
|
||||
@ -182,7 +182,7 @@ class PGCompleter(Completer):
|
||||
self.all_completions = set(self.keywords + self.functions)
|
||||
|
||||
def find_matches(self, text, collection, mode='fuzzy',
|
||||
meta=None, meta_collection=None, parent=None):
|
||||
meta=None, meta_collection=None):
|
||||
"""Find completion matches for the given text.
|
||||
|
||||
Given the user's input text and a collection of available
|
||||
@ -220,9 +220,8 @@ class PGCompleter(Completer):
|
||||
# or None if the item doesn't match
|
||||
# Note: higher priority values mean more important, so use negative
|
||||
# signs to flip the direction of the tuple
|
||||
searchtext = '' if text == '*' and meta == 'column' else text
|
||||
if fuzzy:
|
||||
regex = '.*?'.join(map(re.escape, searchtext))
|
||||
regex = '.*?'.join(map(re.escape, text))
|
||||
pat = re.compile('(%s)' % regex)
|
||||
|
||||
def _match(item):
|
||||
@ -230,10 +229,10 @@ class PGCompleter(Completer):
|
||||
if r:
|
||||
return -len(r.group()), -r.start()
|
||||
else:
|
||||
match_end_limit = len(searchtext)
|
||||
match_end_limit = len(text)
|
||||
|
||||
def _match(item):
|
||||
match_point = item.lower().find(searchtext, 0, match_end_limit)
|
||||
match_point = item.lower().find(text, 0, match_end_limit)
|
||||
if match_point >= 0:
|
||||
# Use negative infinity to force keywords to sort after all
|
||||
# fuzzy matches
|
||||
@ -270,12 +269,6 @@ class PGCompleter(Completer):
|
||||
matches.append(Match(
|
||||
completion=Completion(item, -text_len, display_meta=meta),
|
||||
priority=priority))
|
||||
if text == '*' and meta == 'column' and matches:
|
||||
sep = ', ' + parent + '.' if parent else ', '
|
||||
cols = (m.completion.text for m in matches)
|
||||
collist = (sep).join([c for c in cols if c != '*'])
|
||||
matches = [Match(completion=Completion(collist, -text_len,
|
||||
display_meta='columns', display='*'), priority=(1,1,1))]
|
||||
return matches
|
||||
|
||||
def get_completions(self, document, complete_event, smart_completion=None):
|
||||
@ -314,15 +307,32 @@ class PGCompleter(Completer):
|
||||
tables = suggestion.tables
|
||||
_logger.debug("Completion column scope: %r", tables)
|
||||
scoped_cols = self.populate_scoped_cols(tables)
|
||||
|
||||
colit = scoped_cols.items()
|
||||
flat_cols = []
|
||||
for cols in scoped_cols.values():
|
||||
flat_cols.extend(cols)
|
||||
if suggestion.drop_unique:
|
||||
# drop_unique is used for 'tb11 JOIN tbl2 USING (...' which should
|
||||
# suggest only columns that appear in more than one table
|
||||
scoped_cols = [col for (col, count)
|
||||
in Counter(scoped_cols).items()
|
||||
flat_cols = [col for (col, count)
|
||||
in Counter(flat_cols).items()
|
||||
if count > 1 and col != '*']
|
||||
|
||||
return self.find_matches(word_before_cursor, scoped_cols, meta='column', parent=suggestion.parent)
|
||||
lastword = last_word(word_before_cursor, include='most_punctuations')
|
||||
if lastword == '*':
|
||||
if suggestion.parent:
|
||||
# User typed x.*; replicate "x." for all columns
|
||||
sep = ', ' + self.escape_name(suggestion.parent) + '.'
|
||||
collist = (sep).join([c for c in flat_cols if c != '*'])
|
||||
elif len(scoped_cols) > 1:
|
||||
# Multiple tables; qualify all columns
|
||||
collist = (', ').join([t.ref + '.' + c for t, cs in colit
|
||||
for c in cs if c != '*'])
|
||||
else:
|
||||
# Plain columns
|
||||
collist = (', ').join([c for c in flat_cols if c != '*'])
|
||||
return [Match(completion=Completion(collist, -1,
|
||||
display_meta='columns', display='*'), priority=(1,1,1))]
|
||||
return self.find_matches(word_before_cursor, flat_cols, meta='column')
|
||||
|
||||
def get_function_matches(self, suggestion, word_before_cursor):
|
||||
if suggestion.filter == 'is_set_returning':
|
||||
@ -442,76 +452,33 @@ class PGCompleter(Completer):
|
||||
def populate_scoped_cols(self, scoped_tbls):
|
||||
""" Find all columns in a set of scoped_tables
|
||||
:param scoped_tbls: list of TableReference namedtuples
|
||||
:return: list of column names
|
||||
:return: dict {TableReference:[list of column names]}
|
||||
"""
|
||||
|
||||
columns = []
|
||||
columns = defaultdict(lambda: [])
|
||||
meta = self.dbmetadata
|
||||
|
||||
def addcols(schema, rel, alias, reltype, cols):
|
||||
tbl = TableReference(schema, rel, alias, reltype == 'functions')
|
||||
columns[tbl].extend(cols)
|
||||
for tbl in scoped_tbls:
|
||||
if tbl.schema:
|
||||
# A fully qualified schema.relname reference
|
||||
schema = self.escape_name(tbl.schema)
|
||||
schemas = [tbl.schema] if tbl.schema else self.search_path
|
||||
for schema in schemas:
|
||||
relname = self.escape_name(tbl.name)
|
||||
|
||||
schema = self.escape_name(schema)
|
||||
if tbl.is_function:
|
||||
# Return column names from a set-returning function
|
||||
try:
|
||||
# Get an array of FunctionMetadata objects
|
||||
functions = meta['functions'][schema][relname]
|
||||
except KeyError:
|
||||
# No such function name
|
||||
continue
|
||||
|
||||
for func in functions:
|
||||
# Return column names from a set-returning function
|
||||
# Get an array of FunctionMetadata objects
|
||||
functions = meta['functions'].get(schema, {}).get(relname)
|
||||
for func in (functions or []):
|
||||
# func is a FunctionMetadata object
|
||||
columns.extend(func.fieldnames())
|
||||
cols = func.fieldnames()
|
||||
addcols(schema, relname, tbl.alias, 'functions', cols)
|
||||
else:
|
||||
# We don't know if schema.relname is a table or view. Since
|
||||
# tables and views cannot share the same name, we can check
|
||||
# one at a time
|
||||
try:
|
||||
columns.extend(meta['tables'][schema][relname])
|
||||
|
||||
# Table exists, so don't bother checking for a view
|
||||
continue
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
columns.extend(meta['views'][schema][relname])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
else:
|
||||
# Schema not specified, so traverse the search path looking for
|
||||
# a table or view that matches. Note that in order to get proper
|
||||
# shadowing behavior, we need to check both views and tables for
|
||||
# each schema before checking the next schema
|
||||
for schema in self.search_path:
|
||||
relname = self.escape_name(tbl.name)
|
||||
|
||||
if tbl.is_function:
|
||||
try:
|
||||
functions = meta['functions'][schema][relname]
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
for func in functions:
|
||||
# func is a FunctionMetadata object
|
||||
columns.extend(func.fieldnames())
|
||||
else:
|
||||
try:
|
||||
columns.extend(meta['tables'][schema][relname])
|
||||
for reltype in ('tables', 'views'):
|
||||
cols = meta[reltype].get(schema, {}).get(relname)
|
||||
if cols:
|
||||
addcols(schema, relname, tbl.alias, reltype, cols)
|
||||
break
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
columns.extend(meta['views'][schema][relname])
|
||||
break
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return columns
|
||||
|
||||
|
@ -322,4 +322,73 @@ def test_suggest_columns_from_aliased_set_returning_function(completer, complete
|
||||
result = completer.get_completions(Document(text=sql, cursor_position=pos),
|
||||
complete_event)
|
||||
assert set(result) == set([
|
||||
Completion(text='x', start_position=0, display_meta='column')])
|
||||
Completion(text='x', start_position=0, display_meta='column')])
|
||||
|
||||
|
||||
def test_wildcard_column_expansion(completer, complete_event):
|
||||
sql = 'SELECT * FROM custom.set_returning_func()'
|
||||
pos = len('SELECT *')
|
||||
|
||||
completions = completer.get_completions(
|
||||
Document(text=sql, cursor_position=pos), complete_event)
|
||||
|
||||
col_list = 'x'
|
||||
expected = [Completion(text=col_list, start_position=-1,
|
||||
display='*', display_meta='columns')]
|
||||
|
||||
assert expected == completions
|
||||
|
||||
|
||||
def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_event):
|
||||
sql = 'SELECT p.* FROM custom.products p'
|
||||
pos = len('SELECT p.*')
|
||||
|
||||
completions = completer.get_completions(
|
||||
Document(text=sql, cursor_position=pos), complete_event)
|
||||
|
||||
col_list = 'id, p.product_name, p.price'
|
||||
expected = [Completion(text=col_list, start_position=-1,
|
||||
display='*', display_meta='columns')]
|
||||
|
||||
assert expected == completions
|
||||
|
||||
|
||||
def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event):
|
||||
sql = 'SELECT "select".* FROM public."select"'
|
||||
pos = len('SELECT "select".*')
|
||||
|
||||
completions = completer.get_completions(
|
||||
Document(text=sql, cursor_position=pos), complete_event)
|
||||
|
||||
col_list = 'id, "select"."insert", "select"."ABC"'
|
||||
expected = [Completion(text=col_list, start_position=-1,
|
||||
display='*', display_meta='columns')]
|
||||
|
||||
assert expected == completions
|
||||
|
||||
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
|
||||
sql = 'SELECT * FROM public."select" JOIN custom.users u ON true'
|
||||
pos = len('SELECT *')
|
||||
|
||||
completions = completer.get_completions(
|
||||
Document(text=sql, cursor_position=pos), complete_event)
|
||||
|
||||
col_list = '"select".id, "select"."insert", "select"."ABC", u.id, u.phone_number'
|
||||
expected = [Completion(text=col_list, start_position=-1,
|
||||
display='*', display_meta='columns')]
|
||||
|
||||
assert expected == completions
|
||||
|
||||
|
||||
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
|
||||
sql = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true'
|
||||
pos = len('SELECT "select".*')
|
||||
|
||||
completions = completer.get_completions(
|
||||
Document(text=sql, cursor_position=pos), complete_event)
|
||||
|
||||
col_list = 'id, "select"."insert", "select"."ABC"'
|
||||
expected = [Completion(text=col_list, start_position=-1,
|
||||
display='*', display_meta='columns')]
|
||||
|
||||
assert expected == completions
|
||||
|
@ -553,3 +553,70 @@ def test_columns_before_keywords(completer, complete_event):
|
||||
assert completions.index(column) < completions.index(keyword)
|
||||
|
||||
|
||||
def test_wildcard_column_expansion(completer, complete_event):
|
||||
sql = 'SELECT * FROM users'
|
||||
pos = len('SELECT *')
|
||||
|
||||
completions = completer.get_completions(
|
||||
Document(text=sql, cursor_position=pos), complete_event)
|
||||
|
||||
col_list = 'id, email, first_name, last_name'
|
||||
expected = [Completion(text=col_list, start_position=-1,
|
||||
display='*', display_meta='columns')]
|
||||
|
||||
assert expected == completions
|
||||
|
||||
|
||||
def test_wildcard_column_expansion_with_alias_qualifier(completer, complete_event):
|
||||
sql = 'SELECT u.* FROM users u'
|
||||
pos = len('SELECT u.*')
|
||||
|
||||
completions = completer.get_completions(
|
||||
Document(text=sql, cursor_position=pos), complete_event)
|
||||
|
||||
col_list = 'id, u.email, u.first_name, u.last_name'
|
||||
expected = [Completion(text=col_list, start_position=-1,
|
||||
display='*', display_meta='columns')]
|
||||
|
||||
assert expected == completions
|
||||
|
||||
|
||||
def test_wildcard_column_expansion_with_table_qualifier(completer, complete_event):
|
||||
sql = 'SELECT users.* FROM users'
|
||||
pos = len('SELECT users.*')
|
||||
|
||||
completions = completer.get_completions(
|
||||
Document(text=sql, cursor_position=pos), complete_event)
|
||||
|
||||
col_list = 'id, users.email, users.first_name, users.last_name'
|
||||
expected = [Completion(text=col_list, start_position=-1,
|
||||
display='*', display_meta='columns')]
|
||||
|
||||
assert expected == completions
|
||||
|
||||
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
|
||||
sql = 'SELECT * FROM "select" JOIN users u ON true'
|
||||
pos = len('SELECT *')
|
||||
|
||||
completions = completer.get_completions(
|
||||
Document(text=sql, cursor_position=pos), complete_event)
|
||||
|
||||
col_list = '"select".id, "select"."insert", "select"."ABC", u.id, u.phone_number'
|
||||
expected = [Completion(text=col_list, start_position=-1,
|
||||
display='*', display_meta='columns')]
|
||||
|
||||
assert expected == completions
|
||||
|
||||
|
||||
def test_wildcard_column_expansion_with_two_tables_and_parent(completer, complete_event):
|
||||
sql = 'SELECT "select".* FROM "select" JOIN users u ON true'
|
||||
pos = len('SELECT "select".*')
|
||||
|
||||
completions = completer.get_completions(
|
||||
Document(text=sql, cursor_position=pos), complete_event)
|
||||
|
||||
col_list = 'id, "select"."insert", "select"."ABC"'
|
||||
expected = [Completion(text=col_list, start_position=-1,
|
||||
display='*', display_meta='columns')]
|
||||
|
||||
assert expected == completions
|
||||
|
Loading…
Reference in New Issue
Block a user