Mirror of https://github.com/dbcli/pgcli.git

black all the things. (#1049)

* added black to develop guide

* no need for pep8radius.

* changelog.

* Add pre-commit checkbox.

* Add pre-commit to dev reqs.

* Add pyproject.toml for black.

* Pre-commit config.

* Add black to travis and dev reqs.

* Install and run black in travis.

* Remove black from dev reqs.

* Lower black target version.

* Re-format with black.
Irina Truong 2019-05-25 13:08:56 -07:00 committed by GitHub
parent a5e607b6fc
commit 8cb7009bcd
61 changed files with 4540 additions and 3689 deletions


@ -7,3 +7,5 @@
<!--- We appreciate your help and want to give you credit. Please take a moment to put an `x` in the boxes below as you complete them. -->
- [ ] I've added this contribution to the `changelog.rst`.
- [ ] I've added my name to the `AUTHORS` file (or it's already there).
<!-- We would appreciate if you comply with our code style guidelines. -->
- [ ] I installed pre-commit hooks (`pip install pre-commit && pre-commit install`), and ran `black` on my code.

.pre-commit-config.yaml (new file, +7 lines)

@ -0,0 +1,7 @@
repos:
- repo: https://github.com/ambv/black
rev: stable
hooks:
- id: black
language_version: python3.6
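
With this config at the repo root, the hook takes effect once a contributor registers it, as the PR-template checkbox above describes. A minimal sketch of the standard pre-commit workflow (these commands are ordinary pre-commit usage, not part of this diff)::

    $ pip install pre-commit
    $ pre-commit install             # register the git hook in this clone
    $ pre-commit run --all-files     # optionally run black over the whole tree once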


@ -29,8 +29,8 @@ script:
- cd ..
# check for changelog ReST compliance
- rst2html.py --halt=warning changelog.rst >/dev/null
# check for pep8 errors, only looking at branch vs master. If there are errors, show diff and return an error code.
- pep8radius master --docformatter --diff --error-status
# check for black code compliance, 3.6 only
- if [[ "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then pip install black && black --check . ; else echo "Skipping black for $TRAVIS_PYTHON_VERSION"; fi
- set +e
after_success:


@ -172,17 +172,7 @@ Troubleshooting the integration tests
- Check `this issue <https://github.com/dbcli/pgcli/issues/945>`_ for relevant information.
- Contact us on `gitter <https://gitter.im/dbcli/pgcli/>`_ or `file an issue <https://github.com/dbcli/pgcli/issues/new>`_.
PEP8 checks (lint)
------------------
Coding Style
------------
When you submit a PR, the changeset is checked for pep8 compliance using
`pep8radius <https://github.com/hayd/pep8radius>`_. If you see a build failing because
of these checks, install pep8radius and apply style fixes:
::
$ pip install pep8radius
$ pep8radius master --docformatter --diff # view a diff of proposed fixes
$ pep8radius master --docformatter --in-place # apply the fixes
Then commit and push the fixes.
``pgcli`` uses `black <https://github.com/ambv/black>`_ to format the source code. Make sure to install black.
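
In practice the new section reduces to two commands. A minimal sketch of the intended workflow, mirroring the ``black --check .`` step added to Travis above (illustrative commands, not quoted from ``DEVELOP.rst``)::

    $ pip install black
    $ black .            # re-format the source tree in place
    $ black --check .    # verify formatting without rewriting, as CI does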


@ -16,6 +16,7 @@ Internal:
---------
* Add python 3.7 to travis build matrix. (Thanks: `Irina Truong`_)
* Apply `black` to code. (Thanks: `Irina Truong`_)
2.1.0
=====


@ -1 +1 @@
__version__ = '2.1.0'
__version__ = "2.1.0"


@ -14,8 +14,7 @@ class CompletionRefresher(object):
self._completer_thread = None
self._restart_refresh = threading.Event()
def refresh(self, executor, special, callbacks, history=None,
settings=None):
def refresh(self, executor, special, callbacks, history=None, settings=None):
"""
Creates a PGCompleter object and populates it with the relevant
completion suggestions in a background thread.
@ -30,27 +29,29 @@ class CompletionRefresher(object):
"""
if self.is_refreshing():
self._restart_refresh.set()
return [(None, None, None, 'Auto-completion refresh restarted.')]
return [(None, None, None, "Auto-completion refresh restarted.")]
else:
self._completer_thread = threading.Thread(
target=self._bg_refresh,
args=(executor, special, callbacks, history, settings),
name='completion_refresh')
name="completion_refresh",
)
self._completer_thread.setDaemon(True)
self._completer_thread.start()
return [(None, None, None,
'Auto-completion refresh started in the background.')]
return [
(None, None, None, "Auto-completion refresh started in the background.")
]
def is_refreshing(self):
return self._completer_thread and self._completer_thread.is_alive()
def _bg_refresh(self, pgexecute, special, callbacks, history=None,
settings=None):
def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None):
settings = settings or {}
completer = PGCompleter(smart_completion=True, pgspecial=special,
settings=settings)
completer = PGCompleter(
smart_completion=True, pgspecial=special, settings=settings
)
if settings.get('single_connection'):
if settings.get("single_connection"):
executor = pgexecute
else:
# Create a new pgexecute method to populate the completions.
@ -88,55 +89,58 @@ def refresher(name, refreshers=CompletionRefresher.refreshers):
"""Decorator to populate the dictionary of refreshers with the current
function.
"""
def wrapper(wrapped):
refreshers[name] = wrapped
return wrapped
return wrapper
@refresher('schemata')
@refresher("schemata")
def refresh_schemata(completer, executor):
completer.set_search_path(executor.search_path())
completer.extend_schemata(executor.schemata())
@refresher('tables')
@refresher("tables")
def refresh_tables(completer, executor):
completer.extend_relations(executor.tables(), kind='tables')
completer.extend_columns(executor.table_columns(), kind='tables')
completer.extend_relations(executor.tables(), kind="tables")
completer.extend_columns(executor.table_columns(), kind="tables")
completer.extend_foreignkeys(executor.foreignkeys())
@refresher('views')
@refresher("views")
def refresh_views(completer, executor):
completer.extend_relations(executor.views(), kind='views')
completer.extend_columns(executor.view_columns(), kind='views')
completer.extend_relations(executor.views(), kind="views")
completer.extend_columns(executor.view_columns(), kind="views")
@refresher('types')
@refresher("types")
def refresh_types(completer, executor):
completer.extend_datatypes(executor.datatypes())
@refresher('databases')
@refresher("databases")
def refresh_databases(completer, executor):
completer.extend_database_names(executor.databases())
@refresher('casing')
@refresher("casing")
def refresh_casing(completer, executor):
casing_file = completer.casing_file
if not casing_file:
return
generate_casing_file = completer.generate_casing_file
if generate_casing_file and not os.path.isfile(casing_file):
casing_prefs = '\n'.join(executor.casing())
with open(casing_file, 'w') as f:
casing_prefs = "\n".join(executor.casing())
with open(casing_file, "w") as f:
f.write(casing_prefs)
if os.path.isfile(casing_file):
with open(casing_file, 'r') as f:
with open(casing_file, "r") as f:
completer.extend_casing([line.strip() for line in f])
@refresher('functions')
@refresher("functions")
def refresh_functions(completer, executor):
completer.extend_functions(executor.functions())


@ -7,18 +7,18 @@ from configobj import ConfigObj
def config_location():
if 'XDG_CONFIG_HOME' in os.environ:
return '%s/pgcli/' % expanduser(os.environ['XDG_CONFIG_HOME'])
elif platform.system() == 'Windows':
return os.getenv('USERPROFILE') + '\\AppData\\Local\\dbcli\\pgcli\\'
if "XDG_CONFIG_HOME" in os.environ:
return "%s/pgcli/" % expanduser(os.environ["XDG_CONFIG_HOME"])
elif platform.system() == "Windows":
return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\pgcli\\"
else:
return expanduser('~/.config/pgcli/')
return expanduser("~/.config/pgcli/")
def load_config(usr_cfg, def_cfg=None):
cfg = ConfigObj()
cfg.merge(ConfigObj(def_cfg, interpolation=False))
cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding='utf-8'))
cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
cfg.filename = expanduser(usr_cfg)
return cfg
@ -51,18 +51,19 @@ def upgrade_config(config, def_config):
def get_config(pgclirc_file=None):
from pgcli import __file__ as package_root
package_root = os.path.dirname(package_root)
pgclirc_file = pgclirc_file or '%sconfig' % config_location()
pgclirc_file = pgclirc_file or "%sconfig" % config_location()
default_config = os.path.join(package_root, 'pgclirc')
default_config = os.path.join(package_root, "pgclirc")
write_default_config(default_config, pgclirc_file)
return load_config(pgclirc_file, default_config)
def get_casing_file(config):
casing_file = config['main']['casing_file']
if casing_file == 'default':
casing_file = config_location() + 'casing'
casing_file = config["main"]["casing_file"]
if casing_file == "default":
casing_file = config_location() + "casing"
return casing_file


@ -13,7 +13,7 @@ def unicode2utf8(arg):
"""
if PY2 and isinstance(arg, unicode):
return arg.encode('utf-8')
return arg.encode("utf-8")
return arg
@ -24,5 +24,5 @@ def utf8tounicode(arg):
"""
if PY2 and isinstance(arg, str):
return arg.decode('utf-8')
return arg.decode("utf-8")
return arg


@ -12,32 +12,32 @@ def pgcli_bindings(pgcli):
"""Custom key bindings for pgcli."""
kb = KeyBindings()
tab_insert_text = ' ' * 4
tab_insert_text = " " * 4
@kb.add('f2')
@kb.add("f2")
def _(event):
"""Enable/Disable SmartCompletion Mode."""
_logger.debug('Detected F2 key.')
_logger.debug("Detected F2 key.")
pgcli.completer.smart_completion = not pgcli.completer.smart_completion
@kb.add('f3')
@kb.add("f3")
def _(event):
"""Enable/Disable Multiline Mode."""
_logger.debug('Detected F3 key.')
_logger.debug("Detected F3 key.")
pgcli.multi_line = not pgcli.multi_line
@kb.add('f4')
@kb.add("f4")
def _(event):
"""Toggle between Vi and Emacs mode."""
_logger.debug('Detected F4 key.')
_logger.debug("Detected F4 key.")
pgcli.vi_mode = not pgcli.vi_mode
event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS
@kb.add('tab')
@kb.add("tab")
def _(event):
"""Force autocompletion at cursor on non-empty lines."""
_logger.debug('Detected <Tab> key.')
_logger.debug("Detected <Tab> key.")
buff = event.app.current_buffer
doc = buff.document
@ -50,17 +50,15 @@ def pgcli_bindings(pgcli):
else:
buff.insert_text(tab_insert_text, fire_event=False)
@kb.add('escape', filter=has_completions)
@kb.add("escape", filter=has_completions)
def _(event):
"""Force closing of autocompletion."""
_logger.debug('Detected <Esc> key.')
_logger.debug("Detected <Esc> key.")
event.current_buffer.complete_state = None
event.app.current_buffer.complete_state = None
@kb.add('c-space')
@kb.add("c-space")
def _(event):
"""
Initialize autocompletion at cursor.
@ -70,7 +68,7 @@ def pgcli_bindings(pgcli):
If the menu is showing, select the next completion.
"""
_logger.debug('Detected <C-Space> key.')
_logger.debug("Detected <C-Space> key.")
b = event.app.current_buffer
if b.complete_state:
@ -78,7 +76,7 @@ def pgcli_bindings(pgcli):
else:
b.start_completion(select_first=False)
@kb.add('enter', filter=completion_is_selected)
@kb.add("enter", filter=completion_is_selected)
def _(event):
"""Makes the enter key work as the tab key only when showing the menu.
@ -87,7 +85,7 @@ def pgcli_bindings(pgcli):
(accept current selection).
"""
_logger.debug('Detected enter key.')
_logger.debug("Detected enter key.")
event.current_buffer.complete_state = None
event.app.current_buffer.complete_state = None


@ -10,40 +10,40 @@ def load_ipython_extension(ipython):
"""This is called via the ipython command '%load_ext pgcli.magic'"""
# first, load the sql magic if it isn't already loaded
if not ipython.find_line_magic('sql'):
ipython.run_line_magic('load_ext', 'sql')
if not ipython.find_line_magic("sql"):
ipython.run_line_magic("load_ext", "sql")
# register our own magic
ipython.register_magic_function(pgcli_line_magic, 'line', 'pgcli')
ipython.register_magic_function(pgcli_line_magic, "line", "pgcli")
def pgcli_line_magic(line):
_logger.debug('pgcli magic called: %r', line)
_logger.debug("pgcli magic called: %r", line)
parsed = sql.parse.parse(line, {})
# "get" was renamed to "set" in ipython-sql:
# https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43
if hasattr(sql.connection.Connection, 'get'):
conn = sql.connection.Connection.get(parsed['connection'])
if hasattr(sql.connection.Connection, "get"):
conn = sql.connection.Connection.get(parsed["connection"])
else:
conn = sql.connection.Connection.set(parsed['connection'])
conn = sql.connection.Connection.set(parsed["connection"])
try:
# A corresponding pgcli object already exists
pgcli = conn._pgcli
_logger.debug('Reusing existing pgcli')
_logger.debug("Reusing existing pgcli")
except AttributeError:
# I can't figure out how to get the underlying psycopg2 connection
# from the sqlalchemy connection, so just grab the url and make a
# new connection
pgcli = PGCli()
u = conn.session.engine.url
_logger.debug('New pgcli: %r', str(u))
_logger.debug("New pgcli: %r", str(u))
pgcli.connect(u.database, u.host, u.username, u.port, u.password)
conn._pgcli = pgcli
# For convenience, print the connection alias
print('Connected: {}'.format(conn.name))
print ("Connected: {}".format(conn.name))
try:
pgcli.run_cli()
@ -56,12 +56,12 @@ def pgcli_line_magic(line):
q = pgcli.query_history[-1]
if not q.successful:
_logger.debug('Unsuccessful query - ignoring')
_logger.debug("Unsuccessful query - ignoring")
return
if q.meta_changed or q.db_changed or q.path_changed:
_logger.debug('Dangerous query detected -- ignoring')
_logger.debug("Dangerous query detected -- ignoring")
return
ipython = get_ipython()
return ipython.run_cell_magic('sql', line, q.query)
return ipython.run_cell_magic("sql", line, q.query)

File diff suppressed because it is too large.


@ -1,11 +1,13 @@
import sqlparse
def query_starts_with(query, prefixes):
"""Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes]
formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
def queries_start_with(queries, prefixes):
"""Check if any queries start with any item from *prefixes*."""
for query in sqlparse.split(queries):
@ -13,7 +15,8 @@ def queries_start_with(queries, prefixes):
return True
return False
def is_destructive(queries):
"""Returns if any of the queries in *queries* is destructive."""
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
return queries_start_with(queries, keywords)


@ -12,7 +12,7 @@ from .meta import TableMetadata, ColumnMetadata
# columns: list of column names
# start: index into the original string of the left parens starting the CTE
# stop: index into the original string of the right parens ending the CTE
TableExpression = namedtuple('TableExpression', 'name columns start stop')
TableExpression = namedtuple("TableExpression", "name columns start stop")
def isolate_query_ctes(full_text, text_before_cursor):
@ -32,8 +32,8 @@ def isolate_query_ctes(full_text, text_before_cursor):
for cte in ctes:
if cte.start < current_position < cte.stop:
# Currently editing a cte - treat its body as the current full_text
text_before_cursor = full_text[cte.start:current_position]
full_text = full_text[cte.start:cte.stop]
text_before_cursor = full_text[cte.start : current_position]
full_text = full_text[cte.start : cte.stop]
return full_text, text_before_cursor, meta
# Append this cte to the list of available table metadata
@ -41,8 +41,8 @@ def isolate_query_ctes(full_text, text_before_cursor):
meta.append(TableMetadata(cte.name, cols))
# Editing past the last cte (ie the main body of the query)
full_text = full_text[ctes[-1].stop:]
text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]
full_text = full_text[ctes[-1].stop :]
text_before_cursor = text_before_cursor[ctes[-1].stop : current_position]
return full_text, text_before_cursor, tuple(meta)
@ -68,7 +68,7 @@ def extract_ctes(sql):
# Get the next (meaningful) token, which should be the first CTE
idx, tok = p.token_next(idx)
if not tok:
return ([], '')
return ([], "")
start_pos = token_start_pos(p.tokens, idx)
ctes = []
@ -89,7 +89,7 @@ def extract_ctes(sql):
idx = p.token_index(tok) + 1
# Collapse everything after the ctes into a remainder query
remainder = u''.join(str(tok) for tok in p.tokens[idx:])
remainder = "".join(str(tok) for tok in p.tokens[idx:])
return ctes, remainder
@ -118,10 +118,10 @@ def extract_column_names(parsed):
idx, tok = parsed.token_next_by(t=DML)
tok_val = tok and tok.value.lower()
if tok_val in ('insert', 'update', 'delete'):
if tok_val in ("insert", "update", "delete"):
# Jump ahead to the RETURNING clause where the list of column names is
idx, tok = parsed.token_next_by(idx, (Keyword, 'returning'))
elif not tok_val == 'select':
idx, tok = parsed.token_next_by(idx, (Keyword, "returning"))
elif not tok_val == "select":
# Must be invalid CTE
return ()


@ -2,22 +2,26 @@ from __future__ import unicode_literals
from collections import namedtuple
_ColumnMetadata = namedtuple(
'ColumnMetadata',
['name', 'datatype', 'foreignkeys', 'default', 'has_default']
"ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"]
)
def ColumnMetadata(
name, datatype, foreignkeys=None, default=None, has_default=False
):
return _ColumnMetadata(
name, datatype, foreignkeys or [], default, has_default
)
def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False):
return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default)
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable',
'parentcolumn', 'childschema', 'childtable', 'childcolumn'])
TableMetadata = namedtuple('TableMetadata', 'name columns')
ForeignKey = namedtuple(
"ForeignKey",
[
"parentschema",
"parenttable",
"parentcolumn",
"childschema",
"childtable",
"childcolumn",
],
)
TableMetadata = namedtuple("TableMetadata", "name columns")
def parse_defaults(defaults_string):
@ -25,34 +29,42 @@ def parse_defaults(defaults_string):
pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
if not defaults_string:
return
current = ''
current = ""
in_quote = None
for char in defaults_string:
if current == '' and char == ' ':
if current == "" and char == " ":
# Skip space after comma separating default expressions
continue
if char == '"' or char == '\'':
if char == '"' or char == "'":
if in_quote and char == in_quote:
# End quote
in_quote = None
elif not in_quote:
# Begin quote
in_quote = char
elif char == ',' and not in_quote:
elif char == "," and not in_quote:
# End of expression
yield current
current = ''
current = ""
continue
current += char
yield current
class FunctionMetadata(object):
def __init__(
self, schema_name, func_name, arg_names, arg_types, arg_modes,
return_type, is_aggregate, is_window, is_set_returning, is_extension,
arg_defaults
self,
schema_name,
func_name,
arg_names,
arg_types,
arg_modes,
return_type,
is_aggregate,
is_window,
is_set_returning,
is_extension,
arg_defaults,
):
"""Class for describing a postgresql function"""
@ -81,20 +93,27 @@ class FunctionMetadata(object):
self.is_window = is_window
self.is_set_returning = is_set_returning
self.is_extension = bool(is_extension)
self.is_public = (self.schema_name and self.schema_name == 'public')
self.is_public = self.schema_name and self.schema_name == "public"
def __eq__(self, other):
return (isinstance(other, self.__class__)
and self.__dict__ == other.__dict__)
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
def __ne__(self, other):
return not self.__eq__(other)
def _signature(self):
return (
self.schema_name, self.func_name, self.arg_names,
self.arg_types, self.arg_modes, self.return_type, self.is_aggregate,
self.is_window, self.is_set_returning, self.is_extension, self.arg_defaults
self.schema_name,
self.func_name,
self.arg_names,
self.arg_types,
self.arg_modes,
self.return_type,
self.is_aggregate,
self.is_window,
self.is_set_returning,
self.is_extension,
self.arg_defaults,
)
def __hash__(self):
@ -102,25 +121,23 @@ class FunctionMetadata(object):
def __repr__(self):
return (
(
'%s(schema_name=%r, func_name=%r, arg_names=%r, '
'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, '
'is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)'
) % ((self.__class__.__name__,) + self._signature())
)
"%s(schema_name=%r, func_name=%r, arg_names=%r, "
"arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
"is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)"
) % ((self.__class__.__name__,) + self._signature())
def has_variadic(self):
return self.arg_modes and any(arg_mode == 'v' for arg_mode in self.arg_modes)
return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes)
def args(self):
"""Returns a list of input-parameter ColumnMetadata namedtuples."""
if not self.arg_names:
return []
modes = self.arg_modes or ['i'] * len(self.arg_names)
modes = self.arg_modes or ["i"] * len(self.arg_names)
args = [
(name, typ)
for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC
if mode in ("i", "b", "v") # IN, INOUT, VARIADIC
]
def arg(name, typ, num):
@ -128,7 +145,8 @@ class FunctionMetadata(object):
num_defaults = len(self.arg_defaults)
has_default = num + num_defaults >= num_args
default = (
self.arg_defaults[num - num_args + num_defaults] if has_default
self.arg_defaults[num - num_args + num_defaults]
if has_default
else None
)
return ColumnMetadata(name, typ, [], default, has_default)
@ -138,7 +156,7 @@ class FunctionMetadata(object):
def fields(self):
"""Returns a list of output-field ColumnMetadata namedtuples"""
if self.return_type.lower() == 'void':
if self.return_type.lower() == "void":
return []
elif not self.arg_modes:
# For functions without output parameters, the function name
@ -146,7 +164,8 @@ class FunctionMetadata(object):
# E.g. 'SELECT unnest FROM unnest(...);'
return [ColumnMetadata(self.func_name, self.return_type, [])]
return [ColumnMetadata(name, typ, [])
for name, typ, mode in zip(
self.arg_names, self.arg_types, self.arg_modes)
if mode in ('o', 'b', 't')] # OUT, INOUT, TABLE
return [
ColumnMetadata(name, typ, [])
for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes)
if mode in ("o", "b", "t")
] # OUT, INOUT, TABLE


@ -5,11 +5,17 @@ from collections import namedtuple
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
'is_function'])
TableReference.ref = property(lambda self: self.alias or (
self.name if self.name.islower() or self.name[0] == '"'
else '"' + self.name + '"'))
TableReference = namedtuple(
"TableReference", ["schema", "name", "alias", "is_function"]
)
TableReference.ref = property(
lambda self: self.alias
or (
self.name
if self.name.islower() or self.name[0] == '"'
else '"' + self.name + '"'
)
)
# This code is borrowed from sqlparse example script.
@ -18,8 +24,13 @@ def is_subselect(parsed):
if not parsed.is_group:
return False
for item in parsed.tokens:
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
'UPDATE', 'CREATE', 'DELETE'):
if item.ttype is DML and item.value.upper() in (
"SELECT",
"INSERT",
"UPDATE",
"CREATE",
"DELETE",
):
return True
return False
@ -45,23 +56,29 @@ def extract_from_part(parsed, stop_at_punctuation=True):
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
# condition. So we need to ignore the keyword JOIN and its variants
# INNER JOIN, FULL OUTER JOIN, etc.
elif item.ttype is Keyword and (
not item.value.upper() == 'FROM') and (
not item.value.upper().endswith('JOIN')):
elif (
item.ttype is Keyword
and (not item.value.upper() == "FROM")
and (not item.value.upper().endswith("JOIN"))
):
tbl_prefix_seen = False
else:
yield item
elif item.ttype is Keyword or item.ttype is Keyword.DML:
item_val = item.value.upper()
if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
item_val.endswith('JOIN')):
if item_val in (
"COPY",
"FROM",
"INTO",
"UPDATE",
"TABLE",
) or item_val.endswith("JOIN"):
tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary.
elif isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
if (identifier.ttype is Keyword and
identifier.value.upper() == 'FROM'):
if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
tbl_prefix_seen = True
break
@ -102,13 +119,15 @@ def extract_table_identifiers(token_stream, allow_functions=True):
try:
schema_name = identifier.get_parent_name()
real_name = identifier.get_real_name()
is_function = (allow_functions and
_identifier_is_function(identifier))
is_function = allow_functions and _identifier_is_function(
identifier
)
except AttributeError:
continue
if real_name:
yield TableReference(schema_name, real_name,
identifier.get_alias(), is_function)
yield TableReference(
schema_name, real_name, identifier.get_alias(), is_function
)
elif isinstance(item, Identifier):
schema_name, real_name, alias = parse_identifier(item)
is_function = allow_functions and _identifier_is_function(item)
@ -136,7 +155,7 @@ def extract_tables(sql):
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
# abc is the table name, but if we don't stop at the first lparen, then
# we'll identify abc, col1 and col2 as table names.
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
insert_stmt = parsed[0].token_first().value.lower() == "insert"
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
# Kludge: sqlparse mistakenly identifies insert statements as
@ -144,7 +163,6 @@ def extract_tables(sql):
# "insert into foo (bar, baz)" as a function call to foo with arguments
# (bar, baz). So don't allow any identifiers in insert statements
# to have is_function=True
identifiers = extract_table_identifiers(stream,
allow_functions=not insert_stmt)
identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt)
# In the case 'sche.<cursor>', we get an empty TableReference; remove that
return tuple(i for i in identifiers if i.name)


@ -6,17 +6,17 @@ from sqlparse.tokens import Token, Error
cleanup_regex = {
# This matches only alphanumerics and underscores.
'alphanum_underscore': re.compile(r'(\w+)$'),
"alphanum_underscore": re.compile(r"(\w+)$"),
# This matches everything except spaces, parens, colon, and comma
'many_punctuations': re.compile(r'([^():,\s]+)$'),
"many_punctuations": re.compile(r"([^():,\s]+)$"),
# This matches everything except spaces, parens, colon, comma, and period
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
"most_punctuations": re.compile(r"([^\.():,\s]+)$"),
# This matches everything except a space.
'all_punctuations': re.compile(r'([^\s]+)$'),
"all_punctuations": re.compile(r"([^\s]+)$"),
}
def last_word(text, include='alphanum_underscore'):
def last_word(text, include="alphanum_underscore"):
r"""
Find the last word in a sentence.
@ -50,18 +50,18 @@ def last_word(text, include='alphanum_underscore'):
'"foo*bar'
"""
if not text: # Empty string
return ''
if not text: # Empty string
return ""
if text[-1].isspace():
return ''
return ""
else:
regex = cleanup_regex[include]
matches = regex.search(text)
if matches:
return matches.group(0)
else:
return ''
return ""
def find_prev_keyword(sql, n_skip=0):
@ -71,17 +71,18 @@ def find_prev_keyword(sql, n_skip=0):
everything after the last keyword stripped
"""
if not sql.strip():
return None, ''
return None, ""
parsed = sqlparse.parse(sql)[0]
flattened = list(parsed.flatten())
flattened = flattened[:len(flattened)-n_skip]
flattened = flattened[: len(flattened) - n_skip]
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
logical_operators = ("AND", "OR", "NOT", "BETWEEN")
for t in reversed(flattened):
if t.value == '(' or (t.is_keyword and (
t.value.upper() not in logical_operators)):
if t.value == "(" or (
t.is_keyword and (t.value.upper() not in logical_operators)
):
# Find the location of token t in the original parsed statement
# We can't use parsed.token_index(t) because t may be a child token
# inside a TokenList, in which case token_index throws an error
@ -94,14 +95,14 @@ def find_prev_keyword(sql, n_skip=0):
# Combine the string values of all tokens in the original list
# up to and including the target keyword token t, to produce a
# query string with everything after the keyword token removed
text = ''.join(tok.value for tok in flattened[:idx+1])
text = "".join(tok.value for tok in flattened[: idx + 1])
return t, text
return None, ''
return None, ""
# Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
dollar_quote_regex = re.compile(r"^\$[^$]*\$$")
def is_open_quote(sql):


@ -2,7 +2,7 @@ import os
import json
root = os.path.dirname(__file__)
literal_file = os.path.join(root, 'pgliterals.json')
literal_file = os.path.join(root, "pgliterals.json")
with open(literal_file) as f:
literals = json.load(f)


@ -7,17 +7,17 @@ from collections import defaultdict
from .pgliterals.main import get_literals
white_space_regex = re.compile('\\s+', re.MULTILINE)
white_space_regex = re.compile("\\s+", re.MULTILINE)
def _compile_regex(keyword):
# Surround the keyword with word boundaries and replace interior whitespace
# with whitespace wildcards
pattern = '\\b' + white_space_regex.sub(r'\\s+', keyword) + '\\b'
pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b"
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
keywords = get_literals('keywords')
keywords = get_literals("keywords")
keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)


@ -15,8 +15,9 @@ def confirm_destructive_query(queries):
* False if the query is destructive and the user doesn't want to proceed.
"""
prompt_text = ("You're about to run a destructive command.\n"
"Do you want to proceed? (y/n)")
prompt_text = (
"You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
)
if is_destructive(queries) and sys.stdin.isatty():
return prompt(prompt_text, type=bool)


@ -5,8 +5,7 @@ import re
import sqlparse
from collections import namedtuple
from sqlparse.sql import Comparison, Identifier, Where
from .parseutils.utils import (
last_word, find_prev_keyword, parse_partial_identifier)
from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier
from .parseutils.tables import extract_tables
from .parseutils.ctes import isolate_query_ctes
from pgspecial.main import parse_special_command
@ -20,23 +19,23 @@ else:
string_types = basestring
Special = namedtuple('Special', [])
Database = namedtuple('Database', [])
Schema = namedtuple('Schema', ['quoted'])
Special = namedtuple("Special", [])
Database = namedtuple("Database", [])
Schema = namedtuple("Schema", ["quoted"])
Schema.__new__.__defaults__ = (False,)
# FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables')
Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables'])
TableFormat = namedtuple('TableFormat', [])
View = namedtuple('View', ['schema', 'table_refs'])
FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
TableFormat = namedtuple("TableFormat", [])
View = namedtuple("View", ["schema", "table_refs"])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent'])
JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple('Join', ['table_refs', 'schema'])
Join = namedtuple("Join", ["table_refs", "schema"])
Function = namedtuple('Function', ['schema', 'table_refs', 'usage'])
Function = namedtuple("Function", ["schema", "table_refs", "usage"])
# For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None)
Table.__new__.__defaults__ = (None, tuple(), tuple())
@ -44,30 +43,32 @@ View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
Column = namedtuple(
'Column',
['table_refs', 'require_last_table', 'local_tables', 'qualifiable', 'context']
"Column",
["table_refs", "require_last_table", "local_tables", "qualifiable", "context"],
)
Column.__new__.__defaults__ = (None, None, tuple(), False, None)
Keyword = namedtuple('Keyword', ['last_token'])
Keyword = namedtuple("Keyword", ["last_token"])
Keyword.__new__.__defaults__ = (None,)
NamedQuery = namedtuple('NamedQuery', [])
Datatype = namedtuple('Datatype', ['schema'])
Alias = namedtuple('Alias', ['aliases'])
NamedQuery = namedtuple("NamedQuery", [])
Datatype = namedtuple("Datatype", ["schema"])
Alias = namedtuple("Alias", ["aliases"])
Path = namedtuple('Path', [])
Path = namedtuple("Path", [])
class SqlStatement(object):
def __init__(self, full_text, text_before_cursor):
self.identifier = None
self.word_before_cursor = word_before_cursor = last_word(
text_before_cursor, include='many_punctuations')
text_before_cursor, include="many_punctuations"
)
full_text = _strip_named_query(full_text)
text_before_cursor = _strip_named_query(text_before_cursor)
full_text, text_before_cursor, self.local_tables = \
isolate_query_ctes(full_text, text_before_cursor)
full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
full_text, text_before_cursor
)
self.text_before_cursor_including_last_word = text_before_cursor
@ -78,28 +79,29 @@ class SqlStatement(object):
# completion useless because it will always return the list of
# keywords as completion.
if self.word_before_cursor:
if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\":
parsed = sqlparse.parse(text_before_cursor)
else:
text_before_cursor = text_before_cursor[:-len(word_before_cursor)]
text_before_cursor = text_before_cursor[: -len(word_before_cursor)]
parsed = sqlparse.parse(text_before_cursor)
self.identifier = parse_partial_identifier(word_before_cursor)
else:
parsed = sqlparse.parse(text_before_cursor)
full_text, text_before_cursor, parsed = \
_split_multiple_statements(full_text, text_before_cursor, parsed)
full_text, text_before_cursor, parsed = _split_multiple_statements(
full_text, text_before_cursor, parsed
)
self.full_text = full_text
self.text_before_cursor = text_before_cursor
self.parsed = parsed
self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or ''
self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or ""
def is_insert(self):
return self.parsed.token_first().value.lower() == 'insert'
return self.parsed.token_first().value.lower() == "insert"
def get_tables(self, scope='full'):
def get_tables(self, scope="full"):
""" Gets the tables available in the statement.
param `scope:` possible values: 'full', 'insert', 'before'
If 'insert', only the first table is returned.
@ -107,8 +109,9 @@ class SqlStatement(object):
If not 'insert' and the stmt is an insert, the first table is skipped.
"""
tables = extract_tables(
self.full_text if scope == 'full' else self.text_before_cursor)
if scope == 'insert':
self.full_text if scope == "full" else self.text_before_cursor
)
if scope == "insert":
tables = tables[:1]
elif self.is_insert():
tables = tables[1:]
@ -126,8 +129,9 @@ class SqlStatement(object):
return schema
def reduce_to_prev_keyword(self, n_skip=0):
prev_keyword, self.text_before_cursor = \
find_prev_keyword(self.text_before_cursor, n_skip=n_skip)
prev_keyword, self.text_before_cursor = find_prev_keyword(
self.text_before_cursor, n_skip=n_skip
)
return prev_keyword
@ -139,7 +143,7 @@ def suggest_type(full_text, text_before_cursor):
A scope for a column category will be a list of tables.
"""
if full_text.startswith('\\i '):
if full_text.startswith("\\i "):
return (Path(),)
# This is a temporary hack; the exception handling
@ -154,14 +158,14 @@ def suggest_type(full_text, text_before_cursor):
# Be careful here because trivial whitespace is parsed as a
# statement, but the statement won't have a first token
tok1 = stmt.parsed.token_first()
if tok1 and tok1.value.startswith('\\'):
if tok1 and tok1.value.startswith("\\"):
text = stmt.text_before_cursor + stmt.word_before_cursor
return suggest_special(text)
return suggest_based_on_last_token(stmt.last_token, stmt)
named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+")
def _strip_named_query(txt):
@ -172,11 +176,11 @@ def _strip_named_query(txt):
"""
if named_query_regex.match(txt):
txt = named_query_regex.sub('', txt)
txt = named_query_regex.sub("", txt)
return txt
function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M)
function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)
def _find_function_body(text):
@ -222,12 +226,12 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed):
return full_text, text_before_cursor, None
token2 = None
if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'):
if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
token1 = statement.token_first()
if token1:
token1_idx = statement.token_index(token1)
token2 = statement.token_next(token1_idx)[1]
if token2 and token2.value.upper() == 'FUNCTION':
if token2 and token2.value.upper() == "FUNCTION":
full_text, text_before_cursor, statement = _statement_from_function(
full_text, text_before_cursor, statement
)
@ -235,11 +239,11 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed):
SPECIALS_SUGGESTION = {
'dT': Datatype,
'df': Function,
'dt': Table,
'dv': View,
'sf': Function,
"dT": Datatype,
"df": Function,
"dt": Table,
"dv": View,
"sf": Function,
}
@ -251,13 +255,13 @@ def suggest_special(text):
# Trying to complete the special command itself
return (Special(),)
if cmd in ('\\c', '\\connect'):
if cmd in ("\\c", "\\connect"):
return (Database(),)
if cmd == '\\T':
if cmd == "\\T":
return (TableFormat(),)
if cmd == '\\dn':
if cmd == "\\dn":
return (Schema(),)
if arg:
@ -272,27 +276,24 @@ def suggest_special(text):
else:
schema = None
if cmd[1:] == 'd':
if cmd[1:] == "d":
# \d can describe tables or views
if schema:
return (Table(schema=schema),
View(schema=schema),)
return (Table(schema=schema), View(schema=schema))
else:
return (Schema(),
Table(schema=None),
View(schema=None),)
return (Schema(), Table(schema=None), View(schema=None))
elif cmd[1:] in SPECIALS_SUGGESTION:
rel_type = SPECIALS_SUGGESTION[cmd[1:]]
if schema:
if rel_type == Function:
return (Function(schema=schema, usage='special'),)
return (Function(schema=schema, usage="special"),)
return (rel_type(schema=schema),)
else:
if rel_type == Function:
return (Schema(), Function(schema=None, usage='special'),)
return (Schema(), Function(schema=None, usage="special"))
return (Schema(), rel_type(schema=None))
if cmd in ['\\n', '\\ns', '\\nd']:
if cmd in ["\\n", "\\ns", "\\nd"]:
return (NamedQuery(),)
return (Keyword(), Special())
@ -327,9 +328,9 @@ def suggest_based_on_last_token(token, stmt):
# SELECT Identifier <CURSOR>
# SELECT foo FROM Identifier <CURSOR>
prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
if prev_keyword and prev_keyword.value == '(':
if prev_keyword and prev_keyword.value == "(":
# Suggest datatypes
return suggest_based_on_last_token('type', stmt)
return suggest_based_on_last_token("type", stmt)
else:
return (Keyword(),)
else:
@ -337,7 +338,7 @@ def suggest_based_on_last_token(token, stmt):
if not token:
return (Keyword(), Special())
elif token_v.endswith('('):
elif token_v.endswith("("):
p = sqlparse.parse(stmt.text_before_cursor)[0]
if p.tokens and isinstance(p.tokens[-1], Where):
@ -352,7 +353,7 @@ def suggest_based_on_last_token(token, stmt):
# Suggest columns/functions AND keywords. (If we wanted to be
# really fancy, we could suggest only array-typed columns)
column_suggestions = suggest_based_on_last_token('where', stmt)
column_suggestions = suggest_based_on_last_token("where", stmt)
# Check for a subquery expression (cases 3 & 4)
where = p.tokens[-1]
@ -363,7 +364,7 @@ def suggest_based_on_last_token(token, stmt):
prev_tok = prev_tok.tokens[-1]
prev_tok = prev_tok.value.lower()
if prev_tok == 'exists':
if prev_tok == "exists":
return (Keyword(),)
else:
return column_suggestions
@ -371,57 +372,71 @@ def suggest_based_on_last_token(token, stmt):
# Get the token before the parens
prev_tok = p.token_prev(len(p.tokens) - 1)[1]
if (prev_tok and prev_tok.value
and prev_tok.value.lower().split(' ')[-1] == 'using'):
if (
prev_tok
and prev_tok.value
and prev_tok.value.lower().split(" ")[-1] == "using"
):
# tbl1 INNER JOIN tbl2 USING (col1, col2)
tables = stmt.get_tables('before')
tables = stmt.get_tables("before")
# suggest columns that are present in more than one table
return (Column(table_refs=tables,
require_last_table=True,
local_tables=stmt.local_tables),)
return (
Column(
table_refs=tables,
require_last_table=True,
local_tables=stmt.local_tables,
),
)
elif p.token_first().value.lower() == 'select':
elif p.token_first().value.lower() == "select":
# If the lparen is preceded by a space chances are we're about to
# do a sub-select.
if last_word(stmt.text_before_cursor,
'all_punctuations').startswith('('):
if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("):
return (Keyword(),)
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
return (
Column(table_refs=stmt.get_tables('insert'), context='insert'),
)
if prev_prev_tok and prev_prev_tok.normalized == "INTO":
return (Column(table_refs=stmt.get_tables("insert"), context="insert"),)
# We're probably in a function argument list
return (Column(table_refs=extract_tables(stmt.full_text),
local_tables=stmt.local_tables, qualifiable=True),)
elif token_v == 'set':
return (Column(table_refs=stmt.get_tables(),
local_tables=stmt.local_tables),)
elif token_v in ('select', 'where', 'having', 'order by', 'distinct'):
return (
Column(
table_refs=extract_tables(stmt.full_text),
local_tables=stmt.local_tables,
qualifiable=True,
),
)
elif token_v == "set":
return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),)
elif token_v in ("select", "where", "having", "order by", "distinct"):
# Check for a table alias or schema qualification
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
tables = stmt.get_tables()
if parent:
tables = tuple(t for t in tables if identifies(parent, t))
return (Column(table_refs=tables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),)
return (
Column(table_refs=tables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),
)
else:
return (Column(table_refs=tables, local_tables=stmt.local_tables,
qualifiable=True),
Function(schema=None),
Keyword(token_v.upper()),)
elif token_v == 'as':
return (
Column(
table_refs=tables, local_tables=stmt.local_tables, qualifiable=True
),
Function(schema=None),
Keyword(token_v.upper()),
)
elif token_v == "as":
# Don't suggest anything for aliases
return ()
elif (token_v.endswith('join') and token.is_keyword) or (token_v in
('copy', 'from', 'update', 'into', 'describe', 'truncate')):
elif (token_v.endswith("join") and token.is_keyword) or (
token_v in ("copy", "from", "update", "into", "describe", "truncate")
):
schema = stmt.get_identifier_schema()
tables = extract_tables(stmt.text_before_cursor)
is_join = token_v.endswith('join') and token.is_keyword
is_join = token_v.endswith("join") and token.is_keyword
# Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified
@ -431,98 +446,101 @@ def suggest_based_on_last_token(token, stmt):
# Suggest schemas
suggest.insert(0, Schema())
if token_v == 'from' or is_join:
suggest.append(FromClauseItem(schema=schema,
table_refs=tables,
local_tables=stmt.local_tables))
elif token_v == 'truncate':
if token_v == "from" or is_join:
suggest.append(
FromClauseItem(
schema=schema, table_refs=tables, local_tables=stmt.local_tables
)
)
elif token_v == "truncate":
suggest.append(Table(schema))
else:
suggest.extend((Table(schema), View(schema)))
if is_join and _allow_join(stmt.parsed):
tables = stmt.get_tables('before')
tables = stmt.get_tables("before")
suggest.append(Join(table_refs=tables, schema=schema))
return tuple(suggest)
elif token_v == 'function':
elif token_v == "function":
schema = stmt.get_identifier_schema()
# stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:`
try:
prev = stmt.get_previous_token(token).value.lower()
if prev in('drop', 'alter', 'create', 'create or replace'):
return (Function(schema=schema, usage='signature'),)
if prev in ("drop", "alter", "create", "create or replace"):
return (Function(schema=schema, usage="signature"),)
except ValueError:
pass
return tuple()
elif token_v in ('table', 'view'):
elif token_v in ("table", "view"):
# E.g. 'ALTER TABLE <tablname>'
rel_type = {'table': Table, 'view': View, 'function': Function}[token_v]
rel_type = {"table": Table, "view": View, "function": Function}[token_v]
schema = stmt.get_identifier_schema()
if schema:
return (rel_type(schema=schema),)
else:
return (Schema(), rel_type(schema=schema))
elif token_v == 'column':
elif token_v == "column":
# E.g. 'ALTER TABLE foo ALTER COLUMN bar
return (Column(table_refs=stmt.get_tables()),)
elif token_v == 'on':
tables = stmt.get_tables('before')
elif token_v == "on":
tables = stmt.get_tables("before")
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None
if parent:
# "ON parent.<suggestion>"
# parent can be either a schema name or table alias
filteredtables = tuple(t for t in tables if identifies(parent, t))
sugs = [Column(table_refs=filteredtables,
local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent)]
sugs = [
Column(table_refs=filteredtables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),
]
if filteredtables and _allow_join_condition(stmt.parsed):
sugs.append(JoinCondition(table_refs=tables,
parent=filteredtables[-1]))
sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1]))
return tuple(sugs)
else:
# ON <suggestion>
# Use table alias if there is one, otherwise the table name
aliases = tuple(t.ref for t in tables)
if _allow_join_condition(stmt.parsed):
return (Alias(aliases=aliases), JoinCondition(
table_refs=tables, parent=None))
return (
Alias(aliases=aliases),
JoinCondition(table_refs=tables, parent=None),
)
else:
return (Alias(aliases=aliases),)
elif token_v in ('c', 'use', 'database', 'template'):
elif token_v in ("c", "use", "database", "template"):
# "\c <db", "use <db>", "DROP DATABASE <db>",
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
return (Database(),)
elif token_v == 'schema':
elif token_v == "schema":
# DROP SCHEMA schema_name, SET SCHEMA schema name
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
quoted = prev_keyword and prev_keyword.value.lower() == 'set'
quoted = prev_keyword and prev_keyword.value.lower() == "set"
return (Schema(quoted),)
elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
elif token_v.endswith(",") or token_v in ("=", "and", "or"):
prev_keyword = stmt.reduce_to_prev_keyword()
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, stmt)
else:
return ()
elif token_v in ('type', '::'):
elif token_v in ("type", "::"):
# ALTER TABLE foo SET DATA TYPE bar
# SELECT foo::bar
# Note that tables are a form of composite type in postgresql, so
# they're suggested here as well
schema = stmt.get_identifier_schema()
suggestions = [Datatype(schema=schema),
Table(schema=schema)]
suggestions = [Datatype(schema=schema), Table(schema=schema)]
if not schema:
suggestions.append(Schema())
return tuple(suggestions)
elif token_v in {'alter', 'create', 'drop'}:
elif token_v in {"alter", "create", "drop"}:
return (Keyword(token_v.upper()),)
elif token.is_keyword:
# token is a keyword we haven't implemented any special handling for
@ -539,8 +557,11 @@ def suggest_based_on_last_token(token, stmt):
def identifies(id, ref):
"""Returns true if string `id` matches TableReference `ref`"""
return id == ref.alias or id == ref.name or (
ref.schema and (id == ref.schema + '.' + ref.name))
return (
id == ref.alias
or id == ref.name
or (ref.schema and (id == ref.schema + "." + ref.name))
)
def _allow_join_condition(statement):
@ -560,7 +581,7 @@ def _allow_join_condition(statement):
return False
last_tok = statement.token_prev(len(statement.tokens))[1]
return last_tok.value.lower() in ('on', 'and', 'or')
return last_tok.value.lower() in ("on", "and", "or")
def _allow_join(statement):
@ -579,5 +600,7 @@ def _allow_join(statement):
return False
last_tok = statement.token_prev(len(statement.tokens))[1]
return (last_tok.value.lower().endswith('join')
and last_tok.value.lower() not in('cross join', 'natural join'))
return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in (
"cross join",
"natural join",
)


@ -13,10 +13,11 @@ def pg_is_multiline(pgcli):
if not pgcli.multi_line:
return False
if pgcli.multiline_mode == 'safe':
if pgcli.multiline_mode == "safe":
return True
else:
return not _multiline_exception(doc.text)
return cond
@ -24,17 +25,19 @@ def _is_complete(sql):
# A complete command is an sql statement that ends with a semicolon, unless
# there's an open quote surrounding it, as is common when writing a
# CREATE FUNCTION command
return sql.endswith(';') and not is_open_quote(sql)
return sql.endswith(";") and not is_open_quote(sql)
def _multiline_exception(text):
text = text.strip()
return (
text.startswith('\\') or # Special Command
text.endswith(r'\e') or # Ended with \e which should launch the editor
_is_complete(text) or # A complete SQL command
(text == 'exit') or # Exit doesn't need semi-colon
(text == 'quit') or # Quit doesn't need semi-colon
(text == ':q') or # To all the vim fans out there
(text == '') # Just a plain enter without any text
text.startswith("\\")
or text.endswith(r"\e") # Special Command
or _is_complete(text) # Ended with \e which should launch the editor
or (text == "exit") # A complete SQL command
or (text == "quit") # Exit doesn't need semi-colon
or (text == ":q") # Quit doesn't need semi-colon
or ( # To all the vim fans out there
text == ""
) # Just a plain enter without any text
)


@ -9,9 +9,24 @@ from pgspecial.namedqueries import NamedQueries
from prompt_toolkit.completion import Completer, Completion, PathCompleter
from prompt_toolkit.document import Document
from .packages.sqlcompletion import (
FromClauseItem, suggest_type, Special, Database, Schema, Table,
TableFormat, Function, Column, View, Keyword, NamedQuery,
Datatype, Alias, Path, JoinCondition, Join)
FromClauseItem,
suggest_type,
Special,
Database,
Schema,
Table,
TableFormat,
Function,
Column,
View,
Keyword,
NamedQuery,
Datatype,
Alias,
Path,
JoinCondition,
Join,
)
from .packages.parseutils.meta import ColumnMetadata, ForeignKey
from .packages.parseutils.utils import last_word
from .packages.parseutils.tables import TableReference
@ -21,34 +36,30 @@ from .config import load_config, config_location
_logger = logging.getLogger(__name__)
Match = namedtuple('Match', ['completion', 'priority'])
Match = namedtuple("Match", ["completion", "priority"])
_SchemaObject = namedtuple('SchemaObject', 'name schema meta')
_SchemaObject = namedtuple("SchemaObject", "name schema meta")
def SchemaObject(name, schema=None, meta=None):
return _SchemaObject(name, schema, meta)
_Candidate = namedtuple(
'Candidate', 'completion prio meta synonyms prio2 display'
)
_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display")
def Candidate(
completion, prio=None, meta=None, synonyms=None, prio2=None,
display=None
completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
):
return _Candidate(
completion, prio, meta, synonyms or [completion], prio2,
display or completion
completion, prio, meta, synonyms or [completion], prio2, display or completion
)
# Used to strip trailing '::some_type' from default-value expressions
arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$')
arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")
normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'
normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'
def generate_alias(tbl):
@ -57,18 +68,20 @@ def generate_alias(tbl):
all letters preceded by _
param tbl - unescaped name of the table to alias
"""
return ''.join([l for l in tbl if l.isupper()] or
[l for l, prev in zip(tbl, '_' + tbl) if prev == '_' and l != '_'])
return "".join(
[l for l in tbl if l.isupper()]
or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"]
)
class PGCompleter(Completer):
# keywords_tree: A dict mapping keywords to well known following keywords.
# e.g. 'CREATE': ['TABLE', 'USER', ...],
keywords_tree = get_literals('keywords', type_=dict)
keywords_tree = get_literals("keywords", type_=dict)
keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values())))
functions = get_literals('functions')
datatypes = get_literals('datatypes')
reserved_words = set(get_literals('reserved'))
functions = get_literals("functions")
datatypes = get_literals("datatypes")
reserved_words = set(get_literals("reserved"))
def __init__(self, smart_completion=True, pgspecial=None, settings=None):
super(PGCompleter, self).__init__()
@ -77,48 +90,49 @@ class PGCompleter(Completer):
self.prioritizer = PrevalenceCounter()
settings = settings or {}
self.signature_arg_style = settings.get(
'signature_arg_style', '{arg_name} {arg_type}'
"signature_arg_style", "{arg_name} {arg_type}"
)
self.call_arg_style = settings.get(
'call_arg_style', '{arg_name: <{max_arg_len}} := {arg_default}'
"call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}"
)
self.call_arg_display_style = settings.get(
'call_arg_display_style', '{arg_name}'
"call_arg_display_style", "{arg_name}"
)
self.call_arg_oneliner_max = settings.get('call_arg_oneliner_max', 2)
self.search_path_filter = settings.get('search_path_filter')
self.generate_aliases = settings.get('generate_aliases')
self.casing_file = settings.get('casing_file')
self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2)
self.search_path_filter = settings.get("search_path_filter")
self.generate_aliases = settings.get("generate_aliases")
self.casing_file = settings.get("casing_file")
self.insert_col_skip_patterns = [
re.compile(pattern) for pattern in settings.get(
'insert_col_skip_patterns',
[r'^now\(\)$', r'^nextval\(']
re.compile(pattern)
for pattern in settings.get(
"insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("]
)
]
self.generate_casing_file = settings.get('generate_casing_file')
self.qualify_columns = settings.get(
'qualify_columns', 'if_more_than_one_table')
self.generate_casing_file = settings.get("generate_casing_file")
self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table")
self.asterisk_column_order = settings.get(
'asterisk_column_order', 'table_order')
"asterisk_column_order", "table_order"
)
keyword_casing = settings.get('keyword_casing', 'upper').lower()
if keyword_casing not in ('upper', 'lower', 'auto'):
keyword_casing = 'upper'
keyword_casing = settings.get("keyword_casing", "upper").lower()
if keyword_casing not in ("upper", "lower", "auto"):
keyword_casing = "upper"
self.keyword_casing = keyword_casing
self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")
self.databases = []
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
'datatypes': {}}
self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
self.search_path = []
self.casing = {}
self.all_completions = set(self.keywords + self.functions)
def escape_name(self, name):
if name and ((not self.name_pattern.match(name))
or (name.upper() in self.reserved_words)
or (name.upper() in self.functions)):
if name and (
(not self.name_pattern.match(name))
or (name.upper() in self.reserved_words)
or (name.upper() in self.functions)
):
name = '"%s"' % name
return name
@ -147,7 +161,7 @@ class PGCompleter(Completer):
# schemata is a list of schema names
schemata = self.escaped_names(schemata)
metadata = self.dbmetadata['tables']
metadata = self.dbmetadata["tables"]
for schema in schemata:
metadata[schema] = {}
@ -185,8 +199,9 @@ class PGCompleter(Completer):
try:
metadata[schema][relname] = OrderedDict()
except KeyError:
_logger.error('%r %r listed in unrecognized schema %r',
kind, relname, schema)
_logger.error(
"%r %r listed in unrecognized schema %r", kind, relname, schema
)
self.all_completions.add(relname)
def extend_columns(self, column_data, kind):
@ -201,13 +216,12 @@ class PGCompleter(Completer):
"""
metadata = self.dbmetadata[kind]
for schema, relname, colname, datatype, has_default, default in column_data:
(schema, relname, colname) = self.escaped_names(
[schema, relname, colname])
(schema, relname, colname) = self.escaped_names([schema, relname, colname])
column = ColumnMetadata(
name=colname,
datatype=datatype,
has_default=has_default,
default=default
default=default,
)
metadata[schema][relname][colname] = column
self.all_completions.add(colname)
@ -218,7 +232,7 @@ class PGCompleter(Completer):
# dbmetadata['schema_name']['functions']['function_name'] should return
# the function metadata namedtuple for the corresponding function
metadata = self.dbmetadata['functions']
metadata = self.dbmetadata["functions"]
for f in func_data:
schema, func = self.escaped_names([f.schema_name, f.func_name])
@ -239,11 +253,11 @@ class PGCompleter(Completer):
self._arg_list_cache = {
usage: {
meta: self._arg_list(meta, usage)
for sch, funcs in self.dbmetadata['functions'].items()
for sch, funcs in self.dbmetadata["functions"].items()
for func, metas in funcs.items()
for meta in metas
}
for usage in ('call', 'call_display', 'signature')
for usage in ("call", "call_display", "signature")
}
def extend_foreignkeys(self, fk_data):
@ -254,17 +268,18 @@ class PGCompleter(Completer):
# These are added as a list of ForeignKey namedtuples to the
# ColumnMetadata namedtuple for both the child and parent
meta = self.dbmetadata['tables']
meta = self.dbmetadata["tables"]
for fk in fk_data:
e = self.escaped_names
parentschema, childschema = e([fk.parentschema, fk.childschema])
parenttable, childtable = e([fk.parenttable, fk.childtable])
childcol, parcol = e([fk.childcolumn, fk.parentcolumn])
childcolmeta = meta[childschema][childtable][childcol]
parcolmeta = meta[parentschema][parenttable][parcol]
fk = ForeignKey(parentschema, parenttable, parcol,
childschema, childtable, childcol)
childcolmeta = meta[childschema][childtable][childcol]
parcolmeta = meta[parentschema][parenttable][parcol]
fk = ForeignKey(
parentschema, parenttable, parcol, childschema, childtable, childcol
)
childcolmeta.foreignkeys.append((fk))
parcolmeta.foreignkeys.append((fk))
@ -273,7 +288,7 @@ class PGCompleter(Completer):
# dbmetadata['datatypes'][schema_name][type_name] should store type
# metadata, such as composite type field names. Currently, we're not
# storing any metadata beyond typename, so just store None
meta = self.dbmetadata['datatypes']
meta = self.dbmetadata["datatypes"]
for t in type_data:
schema, type_name = self.escaped_names(t)
@ -295,11 +310,10 @@ class PGCompleter(Completer):
self.databases = []
self.special_commands = []
self.search_path = []
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
'datatypes': {}}
self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
self.all_completions = set(self.keywords + self.functions)
def find_matches(self, text, collection, mode='fuzzy', meta=None):
def find_matches(self, text, collection, mode="fuzzy", meta=None):
"""Find completion matches for the given text.
Given the user's input text and a collection of available
@ -319,12 +333,22 @@ class PGCompleter(Completer):
if not collection:
return []
prio_order = [
'keyword', 'function', 'view', 'table', 'datatype', 'database',
'schema', 'column', 'table alias', 'join', 'name join', 'fk join',
'table format'
"keyword",
"function",
"view",
"table",
"datatype",
"database",
"schema",
"column",
"table alias",
"join",
"name join",
"fk join",
"table format",
]
type_priority = prio_order.index(meta) if meta in prio_order else -1
text = last_word(text, include='most_punctuations').lower()
text = last_word(text, include="most_punctuations").lower()
text_len = len(text)
if text and text[0] == '"':
@ -334,7 +358,7 @@ class PGCompleter(Completer):
# Completion.position value is correct
text = text[1:]
if mode == 'fuzzy':
if mode == "fuzzy":
fuzzy = True
priority_func = self.prioritizer.name_count
else:
@ -347,19 +371,20 @@ class PGCompleter(Completer):
# Note: higher priority values mean more important, so use negative
# signs to flip the direction of the tuple
if fuzzy:
regex = '.*?'.join(map(re.escape, text))
pat = re.compile('(%s)' % regex)
regex = ".*?".join(map(re.escape, text))
pat = re.compile("(%s)" % regex)
def _match(item):
if item.lower()[:len(text) + 1] in (text, text + ' '):
if item.lower()[: len(text) + 1] in (text, text + " "):
# Exact match of first word in suggestion
# This is to get exact alias matches to the top
# E.g. for input `e`, 'Entries E' should be on top
# (before e.g. `EndUsers EU`)
return float('Infinity'), -1
return float("Infinity"), -1
r = pat.search(self.unescape_name(item.lower()))
if r:
return -len(r.group()), -r.start()
else:
match_end_limit = len(text)
@ -368,7 +393,7 @@ class PGCompleter(Completer):
if match_point >= 0:
# Use negative infinity to force keywords to sort after all
# fuzzy matches
return -float('Infinity'), -match_point
return -float("Infinity"), -match_point
matches = []
for cand in collection:
@ -387,7 +412,7 @@ class PGCompleter(Completer):
if sort_key:
if display_meta and len(display_meta) > 50:
# Truncate meta-text to 50 characters, if necessary
display_meta = display_meta[:47] + u'...'
display_meta = display_meta[:47] + "..."
# Lexical order of items in the collection, used for
# tiebreaking items with the same match group length and start
@ -398,15 +423,24 @@ class PGCompleter(Completer):
# case-sensitive one as a tie breaker.
# We also use the unescape_name to make sure quoted names have
# the same priority as unquoted names.
lexical_priority = (tuple(0 if c in(' _') else -ord(c)
for c in self.unescape_name(item.lower())) + (1,)
+ tuple(c for c in item))
lexical_priority = (
tuple(
0 if c in (" _") else -ord(c)
for c in self.unescape_name(item.lower())
)
+ (1,)
+ tuple(c for c in item)
)
item = self.case(item)
display = self.case(display)
priority = (
sort_key, type_priority, prio, priority_func(item),
prio2, lexical_priority
sort_key,
type_priority,
prio,
priority_func(item),
prio2,
lexical_priority,
)
matches.append(
Match(
@ -414,9 +448,9 @@ class PGCompleter(Completer):
text=item,
start_position=-text_len,
display_meta=display_meta,
display=display
display=display,
),
priority=priority
priority=priority,
)
)
return matches
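The fuzzy scoring above packs a lot into a few lines; a minimal standalone sketch of the same idea (the names here are illustrative, not pgcli's API) behaves like this:

import re

def fuzzy_rank(text, candidates):
    # "sel" becomes the pattern "(s.*?e.*?l)": the typed characters
    # in order, with anything in between.
    pat = re.compile("(%s)" % ".*?".join(map(re.escape, text.lower())))
    scored = []
    for cand in candidates:
        m = pat.search(cand.lower())
        if m:
            # Shorter match groups and earlier starts rank higher, so
            # negate both and sort descending.
            scored.append(((-len(m.group()), -m.start()), cand))
    return [c for _, c in sorted(scored, reverse=True)]

print(fuzzy_rank("sel", ["serial", "setval", "select"]))
# ['select', 'setval', 'serial']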
@ -433,17 +467,18 @@ class PGCompleter(Completer):
# If smart_completion is off then match any word that starts with
# 'word_before_cursor'.
if not smart_completion:
matches = self.find_matches(word_before_cursor, self.all_completions,
mode='strict')
matches = self.find_matches(
word_before_cursor, self.all_completions, mode="strict"
)
completions = [m.completion for m in matches]
return sorted(completions, key=operator.attrgetter('text'))
return sorted(completions, key=operator.attrgetter("text"))
matches = []
suggestions = suggest_type(document.text, document.text_before_cursor)
for suggestion in suggestions:
suggestion_type = type(suggestion)
_logger.debug('Suggestion type: %r', suggestion_type)
_logger.debug("Suggestion type: %r", suggestion_type)
# Map suggestion type to method
# e.g. 'table' -> self.get_table_matches
@ -451,76 +486,91 @@ class PGCompleter(Completer):
matches.extend(matcher(self, suggestion, word_before_cursor))
# Sort matches so highest priorities are first
matches = sorted(matches, key=operator.attrgetter('priority'),
reverse=True)
matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True)
return [m.completion for m in matches]
def get_column_matches(self, suggestion, word_before_cursor):
tables = suggestion.table_refs
do_qualify = suggestion.qualifiable and {'always': True, 'never': False,
'if_more_than_one_table': len(tables) > 1}[self.qualify_columns]
do_qualify = (
suggestion.qualifiable
and {
"always": True,
"never": False,
"if_more_than_one_table": len(tables) > 1,
}[self.qualify_columns]
)
qualify = lambda col, tbl: (
(tbl + '.' + self.case(col)) if do_qualify else self.case(col))
(tbl + "." + self.case(col)) if do_qualify else self.case(col)
)
_logger.debug("Completion column scope: %r", tables)
scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables)
def make_cand(name, ref):
synonyms = (name, generate_alias(self.case(name)))
return Candidate(qualify(name, ref), 0, 'column', synonyms)
return Candidate(qualify(name, ref), 0, "column", synonyms)
def flat_cols():
return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items() for c in cols]
return [
make_cand(c.name, t.ref)
for t, cols in scoped_cols.items()
for c in cols
]
if suggestion.require_last_table:
# require_last_table is used for 'tbl1 JOIN tbl2 USING (...' which should
# suggest only columns that appear in the last table and at least one other
ltbl = tables[-1].ref
other_tbl_cols = set(
c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs)
c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
)
scoped_cols = {
t: [col for col in cols if col.name in other_tbl_cols]
for t, cols in scoped_cols.items()
if t.ref == ltbl
}
lastword = last_word(word_before_cursor, include='most_punctuations')
if lastword == '*':
if suggestion.context == 'insert':
lastword = last_word(word_before_cursor, include="most_punctuations")
if lastword == "*":
if suggestion.context == "insert":
def filter(col):
if not col.has_default:
return True
return not any(
p.match(col.default)
for p in self.insert_col_skip_patterns
p.match(col.default) for p in self.insert_col_skip_patterns
)
scoped_cols = {
t: [col for col in cols if filter(col)] for t, cols in scoped_cols.items()
t: [col for col in cols if filter(col)]
for t, cols in scoped_cols.items()
}
if self.asterisk_column_order == 'alphabetic':
if self.asterisk_column_order == "alphabetic":
for cols in scoped_cols.values():
cols.sort(key=operator.attrgetter('name'))
if (lastword != word_before_cursor and len(tables) == 1
and word_before_cursor[-len(lastword) - 1] == '.'):
cols.sort(key=operator.attrgetter("name"))
if (
lastword != word_before_cursor
and len(tables) == 1
and word_before_cursor[-len(lastword) - 1] == "."
):
# User typed x.*; replicate "x." for all columns except the
# first, which gets the original (as we only replace the "*")
sep = ', ' + word_before_cursor[:-1]
collist = sep.join(self.case(c.completion)
for c in flat_cols())
sep = ", " + word_before_cursor[:-1]
collist = sep.join(self.case(c.completion) for c in flat_cols())
else:
collist = ', '.join(qualify(c.name, t.ref)
for t, cs in scoped_cols.items() for c in cs)
collist = ", ".join(
qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs
)
return [Match(
completion=Completion(
collist,
-1,
display_meta='columns',
display='*'
),
priority=(1, 1, 1)
)]
return [
Match(
completion=Completion(
collist, -1, display_meta="columns", display="*"
),
priority=(1, 1, 1),
)
]
return self.find_matches(word_before_cursor, flat_cols(),
meta='column')
return self.find_matches(word_before_cursor, flat_cols(), meta="column")
def alias(self, tbl, tbls):
""" Generate a unique table alias
@ -546,13 +596,16 @@ class PGCompleter(Completer):
qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
refs = set(normalize_ref(t.ref) for t in tbls)
other_tbls = set((t.schema, t.name)
for t in list(cols)[:-1])
other_tbls = set((t.schema, t.name) for t in list(cols)[:-1])
joins = []
# Iterate over FKs in existing tables to find potential joins
fks = ((fk, rtbl, rcol) for rtbl, rcols in cols.items()
for rcol in rcols for fk in rcol.foreignkeys)
col = namedtuple('col', 'schema tbl col')
fks = (
(fk, rtbl, rcol)
for rtbl, rcols in cols.items()
for rcol in rcols
for fk in rcol.foreignkeys
)
col = namedtuple("col", "schema tbl col")
for fk, rtbl, rcol in fks:
right = col(rtbl.schema, rtbl.name, rcol.name)
child = col(fk.childschema, fk.childtable, fk.childcolumn)
@ -563,57 +616,66 @@ class PGCompleter(Completer):
c = self.case
if self.generate_aliases or normalize_ref(left.tbl) in refs:
lref = self.alias(left.tbl, suggestion.table_refs)
join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref)
join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref
)
else:
join = '{0} ON {0}.{1} = {2}.{3}'.format(
c(left.tbl), c(left.col), rtbl.ref, c(right.col))
join = "{0} ON {0}.{1} = {2}.{3}".format(
c(left.tbl), c(left.col), rtbl.ref, c(right.col)
)
alias = generate_alias(self.case(left.tbl))
synonyms = [join, '{0} ON {0}.{1} = {2}.{3}'.format(
alias, c(left.col), rtbl.ref, c(right.col))]
synonyms = [
join,
"{0} ON {0}.{1} = {2}.{3}".format(
alias, c(left.col), rtbl.ref, c(right.col)
),
]
# Schema-qualify if (1) new table in same schema as old, and old
# is schema-qualified, or (2) new in other schema, except public
if not suggestion.schema and (qualified[normalize_ref(rtbl.ref)]
if not suggestion.schema and (
qualified[normalize_ref(rtbl.ref)]
and left.schema == right.schema
or left.schema not in(right.schema, 'public')):
join = left.schema + '.' + join
or left.schema not in (right.schema, "public")
):
join = left.schema + "." + join
prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
0 if (left.schema, left.tbl) in other_tbls else 1)
joins.append(Candidate(join, prio, 'join', synonyms=synonyms))
0 if (left.schema, left.tbl) in other_tbls else 1
)
joins.append(Candidate(join, prio, "join", synonyms=synonyms))
return self.find_matches(word_before_cursor, joins, meta='join')
return self.find_matches(word_before_cursor, joins, meta="join")
def get_join_condition_matches(self, suggestion, word_before_cursor):
col = namedtuple('col', 'schema tbl col')
col = namedtuple("col", "schema tbl col")
tbls = self.populate_scoped_cols(suggestion.table_refs).items
cols = [(t, c) for t, cs in tbls() for c in cs]
try:
lref = (suggestion.parent or suggestion.table_refs[-1]).ref
ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
except IndexError: # The user typed an incorrect table qualifier
except IndexError: # The user typed an incorrect table qualifier
return []
conds, found_conds = [], set()
def add_cond(lcol, rcol, rref, prio, meta):
prefix = '' if suggestion.parent else ltbl.ref + '.'
prefix = "" if suggestion.parent else ltbl.ref + "."
case = self.case
cond = prefix + case(lcol) + ' = ' + rref + '.' + case(rcol)
cond = prefix + case(lcol) + " = " + rref + "." + case(rcol)
if cond not in found_conds:
found_conds.add(cond)
conds.append(Candidate(cond, prio + ref_prio[rref], meta))
def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
d = defaultdict(list)
for pair in pairs:
d[pair[0]].append(pair[1])
return d
# Tables that are closer to the cursor get higher prio
ref_prio = dict((tbl.ref, num) for num, tbl
in enumerate(suggestion.table_refs))
ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs))
# Map (schema, table, col) to tables
coldict = list_dict(((t.schema, t.name, c.name), t)
for t, c in cols if t.ref != lref)
coldict = list_dict(
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
)
# For each fk from the left table, generate a join condition if
# the other table is also in the scope
fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
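The list_dict helper above is easy to sanity-check in isolation; for example:

from collections import defaultdict

def list_dict(pairs):  # Turns [(a, b), (a, c)] into {a: [b, c]}
    d = defaultdict(list)
    for pair in pairs:
        d[pair[0]].append(pair[1])
    return d

grouped = list_dict([("id", "users"), ("id", "orders"), ("name", "users")])
print(dict(grouped))  # {'id': ['users', 'orders'], 'name': ['users']}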
@ -623,81 +685,80 @@ class PGCompleter(Completer):
par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
left, right = (child, par) if left == child else (par, child)
for rtbl in coldict[right]:
add_cond(left.col, right.col, rtbl.ref, 2000, 'fk join')
add_cond(left.col, right.col, rtbl.ref, 2000, "fk join")
# For name matching, use a {(colname, coltype): TableReference} dict
coltyp = namedtuple('coltyp', 'name datatype')
coltyp = namedtuple("coltyp", "name datatype")
col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
# Find all name-match join conditions
for c in (coltyp(c.name, c.datatype) for c in lcols):
for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
prio = 1000 if c.datatype in (
'integer', 'bigint', 'smallint') else 0
add_cond(c.name, c.name, rtbl.ref, prio, 'name join')
prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0
add_cond(c.name, c.name, rtbl.ref, prio, "name join")
return self.find_matches(word_before_cursor, conds, meta='join')
return self.find_matches(word_before_cursor, conds, meta="join")
def get_function_matches(self, suggestion, word_before_cursor, alias=False):
if suggestion.usage == 'from':
if suggestion.usage == "from":
# Only suggest functions allowed in FROM clause
def filt(f):
return (not f.is_aggregate and
not f.is_window and
not f.is_extension and
(f.is_public or f.schema_name == suggestion.schema))
return (
not f.is_aggregate
and not f.is_window
and not f.is_extension
and (f.is_public or f.schema_name == suggestion.schema)
)
else:
alias = False
def filt(f):
return (not f.is_extension and
(f.is_public or f.schema_name == suggestion.schema))
return not f.is_extension and (
f.is_public or f.schema_name == suggestion.schema
)
arg_mode = {
'signature': 'signature',
'special': None,
}.get(suggestion.usage, 'call')
arg_mode = {"signature": "signature", "special": None}.get(
suggestion.usage, "call"
)
# Function overloading means we may have multiple functions of the same
# name at this point, so keep unique names only
all_functions = self.populate_functions(suggestion.schema, filt)
funcs = set(
self._make_cand(f, alias, suggestion, arg_mode)
for f in all_functions
self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions
)
matches = self.find_matches(word_before_cursor, funcs, meta='function')
matches = self.find_matches(word_before_cursor, funcs, meta="function")
if not suggestion.schema and not suggestion.usage:
# also suggest hardcoded functions using startswith matching
predefined_funcs = self.find_matches(
word_before_cursor, self.functions, mode='strict',
meta='function')
word_before_cursor, self.functions, mode="strict", meta="function"
)
matches.extend(predefined_funcs)
return matches
def get_schema_matches(self, suggestion, word_before_cursor):
schema_names = self.dbmetadata['tables'].keys()
schema_names = self.dbmetadata["tables"].keys()
# Unless we're sure the user really wants them, hide schema names
# starting with pg_, which are mostly temporary schemas
if not word_before_cursor.startswith('pg_'):
schema_names = [s
for s in schema_names
if not s.startswith('pg_')]
if not word_before_cursor.startswith("pg_"):
schema_names = [s for s in schema_names if not s.startswith("pg_")]
if suggestion.quoted:
schema_names = [self.escape_schema(s) for s in schema_names]
return self.find_matches(word_before_cursor, schema_names, meta='schema')
return self.find_matches(word_before_cursor, schema_names, meta="schema")
def get_from_clause_item_matches(self, suggestion, word_before_cursor):
alias = self.generate_aliases
s = suggestion
t_sug = Table(s.schema, s.table_refs, s.local_tables)
v_sug = View(s.schema, s.table_refs)
f_sug = Function(s.schema, s.table_refs, usage='from')
f_sug = Function(s.schema, s.table_refs, usage="from")
return (
self.get_table_matches(t_sug, word_before_cursor, alias)
+ self.get_view_matches(v_sug, word_before_cursor, alias)
@ -712,43 +773,43 @@ class PGCompleter(Completer):
"""
template = {
'call': self.call_arg_style,
'call_display': self.call_arg_display_style,
'signature': self.signature_arg_style
"call": self.call_arg_style,
"call_display": self.call_arg_display_style,
"signature": self.signature_arg_style,
}[usage]
args = func.args()
if not template:
return '()'
elif usage == 'call' and len(args) < 2:
return '()'
elif usage == 'call' and func.has_variadic():
return '()'
multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max
return "()"
elif usage == "call" and len(args) < 2:
return "()"
elif usage == "call" and func.has_variadic():
return "()"
multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
max_arg_len = max(len(a.name) for a in args) if multiline else 0
args = (
self._format_arg(template, arg, arg_num + 1, max_arg_len)
for arg_num, arg in enumerate(args)
)
if multiline:
return '(' + ','.join('\n ' + a for a in args if a) + '\n)'
return "(" + ",".join("\n " + a for a in args if a) + "\n)"
else:
return '(' + ', '.join(a for a in args if a) + ')'
return "(" + ", ".join(a for a in args if a) + ")"
def _format_arg(self, template, arg, arg_num, max_arg_len):
if not template:
return None
if arg.has_default:
arg_default = 'NULL' if arg.default is None else arg.default
arg_default = "NULL" if arg.default is None else arg.default
# Remove trailing ::(schema.)type
arg_default = arg_default_type_strip_regex.sub('', arg_default)
arg_default = arg_default_type_strip_regex.sub("", arg_default)
else:
arg_default = ''
arg_default = ""
return template.format(
max_arg_len=max_arg_len,
arg_name=self.case(arg.name),
arg_num=arg_num,
arg_type=arg.datatype,
arg_default=arg_default
arg_default=arg_default,
)
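To see how the configurable templates above come together, here is a rough sketch of rendering a multiline call-style argument list (the argument names and defaults are made up):

call_arg_style = "{arg_name: <{max_arg_len}} := {arg_default}"
args = [("schema_name", "'public'"), ("pattern", "NULL")]
max_arg_len = max(len(name) for name, _ in args)
body = ",".join(
    "\n    " + call_arg_style.format(
        max_arg_len=max_arg_len, arg_name=name, arg_default=default
    )
    for name, default in args
)
print("(" + body + "\n)")
# (
#     schema_name := 'public',
#     pattern     := NULL
# )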
def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
@ -763,53 +824,49 @@ class PGCompleter(Completer):
if do_alias:
alias = self.alias(cased_tbl, suggestion.table_refs)
synonyms = (cased_tbl, generate_alias(cased_tbl))
maybe_alias = (' ' + alias) if do_alias else ''
maybe_schema = (self.case(tbl.schema) + '.') if tbl.schema else ''
suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ''
if arg_mode == 'call':
display_suffix = self._arg_list_cache['call_display'][tbl.meta]
elif arg_mode == 'signature':
display_suffix = self._arg_list_cache['signature'][tbl.meta]
maybe_alias = (" " + alias) if do_alias else ""
maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else ""
suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ""
if arg_mode == "call":
display_suffix = self._arg_list_cache["call_display"][tbl.meta]
elif arg_mode == "signature":
display_suffix = self._arg_list_cache["signature"][tbl.meta]
else:
display_suffix = ''
display_suffix = ""
item = maybe_schema + cased_tbl + suffix + maybe_alias
display = maybe_schema + cased_tbl + display_suffix + maybe_alias
prio2 = 0 if tbl.schema else 1
return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
def get_table_matches(self, suggestion, word_before_cursor, alias=False):
tables = self.populate_schema_objects(suggestion.schema, 'tables')
tables = self.populate_schema_objects(suggestion.schema, "tables")
tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables)
# Unless we're sure the user really wants them, don't suggest the
# pg_catalog tables that are implicitly on the search path
if not suggestion.schema and (
not word_before_cursor.startswith('pg_')):
tables = [t for t in tables if not t.name.startswith('pg_')]
if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
tables = [t for t in tables if not t.name.startswith("pg_")]
tables = [self._make_cand(t, alias, suggestion) for t in tables]
return self.find_matches(word_before_cursor, tables, meta='table')
return self.find_matches(word_before_cursor, tables, meta="table")
def get_table_formats(self, _, word_before_cursor):
formats = TabularOutputFormatter().supported_formats
return self.find_matches(word_before_cursor, formats, meta='table format')
return self.find_matches(word_before_cursor, formats, meta="table format")
def get_view_matches(self, suggestion, word_before_cursor, alias=False):
views = self.populate_schema_objects(suggestion.schema, 'views')
views = self.populate_schema_objects(suggestion.schema, "views")
if not suggestion.schema and (
not word_before_cursor.startswith('pg_')):
views = [v for v in views if not v.name.startswith('pg_')]
if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
views = [v for v in views if not v.name.startswith("pg_")]
views = [self._make_cand(v, alias, suggestion) for v in views]
return self.find_matches(word_before_cursor, views, meta='view')
return self.find_matches(word_before_cursor, views, meta="view")
def get_alias_matches(self, suggestion, word_before_cursor):
aliases = suggestion.aliases
return self.find_matches(word_before_cursor, aliases,
meta='table alias')
return self.find_matches(word_before_cursor, aliases, meta="table alias")
def get_database_matches(self, _, word_before_cursor):
return self.find_matches(word_before_cursor, self.databases,
meta='database')
return self.find_matches(word_before_cursor, self.databases, meta="database")
def get_keyword_matches(self, suggestion, word_before_cursor):
keywords = self.keywords_tree.keys()
@ -820,24 +877,26 @@ class PGCompleter(Completer):
keywords = next_keywords
casing = self.keyword_casing
if casing == 'auto':
if casing == "auto":
if word_before_cursor and word_before_cursor[-1].islower():
casing = 'lower'
casing = "lower"
else:
casing = 'upper'
casing = "upper"
if casing == 'upper':
if casing == "upper":
keywords = [k.upper() for k in keywords]
else:
keywords = [k.lower() for k in keywords]
return self.find_matches(word_before_cursor, keywords,
mode='strict', meta='keyword')
return self.find_matches(
word_before_cursor, keywords, mode="strict", meta="keyword"
)
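The "auto" casing rule above is worth a tiny illustration (the function name is invented for the sketch; the logic mirrors the method):

def apply_keyword_casing(keywords, word_before_cursor, casing="auto"):
    # "auto" mirrors the user's typing: a trailing lowercase character
    # selects lowercase keywords, anything else uppercase.
    if casing == "auto":
        if word_before_cursor and word_before_cursor[-1].islower():
            casing = "lower"
        else:
            casing = "upper"
    if casing == "upper":
        return [k.upper() for k in keywords]
    return [k.lower() for k in keywords]

print(apply_keyword_casing(["select", "from"], "sel"))  # ['select', 'from']
print(apply_keyword_casing(["select", "from"], "SEL"))  # ['SELECT', 'FROM']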
def get_path_matches(self, _, word_before_cursor):
completer = PathCompleter(expanduser=True)
document = Document(text=word_before_cursor,
cursor_position=len(word_before_cursor))
document = Document(
text=word_before_cursor, cursor_position=len(word_before_cursor)
)
for c in completer.get_completions(document, None):
yield Match(completion=c, priority=(0,))
@ -848,24 +907,28 @@ class PGCompleter(Completer):
commands = self.pgspecial.commands
cmds = commands.keys()
cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds]
return self.find_matches(word_before_cursor, cmds, mode='strict')
return self.find_matches(word_before_cursor, cmds, mode="strict")
def get_datatype_matches(self, suggestion, word_before_cursor):
# suggest custom datatypes
types = self.populate_schema_objects(suggestion.schema, 'datatypes')
types = self.populate_schema_objects(suggestion.schema, "datatypes")
types = [self._make_cand(t, False, suggestion) for t in types]
matches = self.find_matches(word_before_cursor, types, meta='datatype')
matches = self.find_matches(word_before_cursor, types, meta="datatype")
if not suggestion.schema:
# Also suggest hardcoded types
matches.extend(self.find_matches(word_before_cursor, self.datatypes,
mode='strict', meta='datatype'))
matches.extend(
self.find_matches(
word_before_cursor, self.datatypes, mode="strict", meta="datatype"
)
)
return matches
def get_namedquery_matches(self, _, word_before_cursor):
return self.find_matches(
word_before_cursor, NamedQueries.instance.list(), meta='named query')
word_before_cursor, NamedQueries.instance.list(), meta="named query"
)
suggestion_matchers = {
FromClauseItem: get_from_clause_item_matches,
@ -899,7 +962,7 @@ class PGCompleter(Completer):
meta = self.dbmetadata
def addcols(schema, rel, alias, reltype, cols):
tbl = TableReference(schema, rel, alias, reltype == 'functions')
tbl = TableReference(schema, rel, alias, reltype == "functions")
if tbl not in columns:
columns[tbl] = []
columns[tbl].extend(cols)
@ -908,22 +971,22 @@ class PGCompleter(Completer):
# Local tables should shadow database tables
if tbl.schema is None and normalize_ref(tbl.name) in ctes:
cols = ctes[normalize_ref(tbl.name)]
addcols(None, tbl.name, tbl.alias, 'CTE', cols)
addcols(None, tbl.name, tbl.alias, "CTE", cols)
continue
schemas = [tbl.schema] if tbl.schema else self.search_path
for schema in schemas:
relname = self.escape_name(tbl.name)
schema = self.escape_name(schema)
if tbl.is_function:
# Return column names from a set-returning function
# Get an array of FunctionMetadata objects
functions = meta['functions'].get(schema, {}).get(relname)
for func in (functions or []):
# Return column names from a set-returning function
# Get an array of FunctionMetadata objects
functions = meta["functions"].get(schema, {}).get(relname)
for func in functions or []:
# func is a FunctionMetadata object
cols = func.fields()
addcols(schema, relname, tbl.alias, 'functions', cols)
addcols(schema, relname, tbl.alias, "functions", cols)
else:
for reltype in ('tables', 'views'):
for reltype in ("tables", "views"):
cols = meta[reltype].get(schema, {}).get(relname)
if cols:
cols = cols.values()
@ -956,8 +1019,7 @@ class PGCompleter(Completer):
return [
SchemaObject(
name=obj,
schema=(self._maybe_schema(schema=sch, parent=schema))
name=obj, schema=(self._maybe_schema(schema=sch, parent=schema))
)
for sch in self._get_schemas(obj_type, schema)
for obj in self.dbmetadata[obj_type][sch].keys()
@ -979,10 +1041,10 @@ class PGCompleter(Completer):
SchemaObject(
name=func,
schema=(self._maybe_schema(schema=sch, parent=schema)),
meta=meta
meta=meta,
)
for sch in self._get_schemas('functions', schema)
for (func, metas) in self.dbmetadata['functions'][sch].items()
for sch in self._get_schemas("functions", schema)
for (func, metas) in self.dbmetadata["functions"][sch].items()
for meta in metas
if filter_func(meta)
]

View File

@ -24,7 +24,7 @@ ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE))
# Cast bytea fields to text. By default, this will render as hex strings with
# Postgres 9+ and as escaped binary in earlier versions.
ext.register_type(ext.new_type((17,), 'BYTEA_TEXT', psycopg2.STRING))
ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
@ -55,6 +55,7 @@ def _wait_select(conn):
if errno != 4:
raise
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
@ -66,17 +67,19 @@ def register_date_typecasters(connection):
Casts date and timestamp values to string, resolving issues with
out-of-range dates (e.g. BC) that psycopg2 can't handle
"""
def cast_date(value, cursor):
return value
cursor = connection.cursor()
cursor.execute('SELECT NULL::date')
cursor.execute("SELECT NULL::date")
date_oid = cursor.description[0][1]
cursor.execute('SELECT NULL::timestamp')
cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
cursor.execute('SELECT NULL::timestamp with time zone')
cursor.execute("SELECT NULL::timestamp with time zone")
timestamptz_oid = cursor.description[0][1]
oids = (date_oid, timestamp_oid, timestamptz_oid)
new_type = psycopg2.extensions.new_type(oids, 'DATE', cast_date)
new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date)
psycopg2.extensions.register_type(new_type)
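The same pattern works standalone; a minimal sketch, assuming a reachable database (the connection string here is hypothetical):

import psycopg2
import psycopg2.extensions as ext

conn = psycopg2.connect("dbname=test")  # hypothetical connection

def cast_date(value, cursor):
    # Return the raw string instead of building datetime.date, which
    # sidesteps out-of-range values such as BC dates.
    return value

with conn.cursor() as cur:
    cur.execute("SELECT NULL::date")
    date_oid = cur.description[0][1]

ext.register_type(ext.new_type((date_oid,), "DATE", cast_date))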
@ -97,7 +100,7 @@ def register_json_typecasters(conn, loads_fn):
"""
available = set()
for name in ['json', 'jsonb']:
for name in ["json", "jsonb"]:
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
@ -117,7 +120,8 @@ def register_hstore_typecaster(conn):
with conn.cursor() as cur:
try:
cur.execute(
"select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined")
"select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined"
)
oid = cur.fetchone()[0]
ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE))
except Exception:
@ -128,29 +132,29 @@ class PGExecute(object):
# The boolean argument to the current_schemas function indicates whether
# implicit schemas, e.g. pg_catalog, are included in the result
search_path_query = '''
SELECT * FROM unnest(current_schemas(true))'''
search_path_query = """
SELECT * FROM unnest(current_schemas(true))"""
schemata_query = '''
schemata_query = """
SELECT nspname
FROM pg_catalog.pg_namespace
ORDER BY 1 '''
ORDER BY 1 """
tables_query = '''
tables_query = """
SELECT n.nspname schema_name,
c.relname table_name
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n
ON n.oid = c.relnamespace
WHERE c.relkind = ANY(%s)
ORDER BY 1,2;'''
ORDER BY 1,2;"""
databases_query = '''
databases_query = """
SELECT d.datname
FROM pg_catalog.pg_database d
ORDER BY 1'''
ORDER BY 1"""
full_databases_query = '''
full_databases_query = """
SELECT d.datname as "Name",
pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding",
@ -158,15 +162,15 @@ class PGExecute(object):
d.datctype as "Ctype",
pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges"
FROM pg_catalog.pg_database d
ORDER BY 1'''
ORDER BY 1"""
socket_directory_query = '''
socket_directory_query = """
SELECT setting
FROM pg_settings
WHERE name = 'unix_socket_directories'
'''
"""
view_definition_query = '''
view_definition_query = """
WITH v AS (SELECT %s::pg_catalog.regclass::pg_catalog.oid AS v_oid)
SELECT nspname, relname, relkind,
pg_catalog.pg_get_viewdef(c.oid, true),
@ -179,18 +183,26 @@ class PGExecute(object):
END AS checkoption
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid)
JOIN v ON (c.oid = v.v_oid)'''
JOIN v ON (c.oid = v.v_oid)"""
function_definition_query = '''
function_definition_query = """
WITH f AS
(SELECT %s::pg_catalog.regproc::pg_catalog.oid AS f_oid)
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
FROM f'''
FROM f"""
version_query = "SELECT version();"
def __init__(self, database=None, user=None, password=None, host=None,
port=None, dsn=None, **kwargs):
def __init__(
self,
database=None,
user=None,
password=None,
host=None,
port=None,
dsn=None,
**kwargs
):
self._conn_params = {}
self.conn = None
self.dbname = None
@ -207,10 +219,10 @@ class PGExecute(object):
return self.__class__(**self._conn_params)
def get_server_version(self, cursor):
_logger.debug('Version Query. sql: %r', self.version_query)
_logger.debug("Version Query. sql: %r", self.version_query)
cursor.execute(self.version_query)
result = cursor.fetchone()
server_version = ''
server_version = ""
if result:
# full version string looks like this:
# PostgreSQL 10.3 on x86_64-apple-darwin17.3.0, compiled by Apple LLVM version 9.0.0 (clang-900.0.39.2), 64-bit # noqa
@ -219,38 +231,42 @@ class PGExecute(object):
server_version = version_parts[1]
return server_version
def connect(self, database=None, user=None, password=None, host=None,
port=None, dsn=None, **kwargs):
def connect(
self,
database=None,
user=None,
password=None,
host=None,
port=None,
dsn=None,
**kwargs
):
conn_params = self._conn_params.copy()
new_params = {
'database': database,
'user': user,
'password': password,
'host': host,
'port': port,
'dsn': dsn,
"database": database,
"user": user,
"password": password,
"host": host,
"port": port,
"dsn": dsn,
}
new_params.update(kwargs)
if new_params['dsn']:
new_params = {
'dsn': new_params['dsn'],
'password': new_params['password']
}
if new_params["dsn"]:
new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}
if new_params['password']:
new_params['dsn'] = make_dsn(
new_params['dsn'], password=new_params.pop('password'))
if new_params["password"]:
new_params["dsn"] = make_dsn(
new_params["dsn"], password=new_params.pop("password")
)
conn_params.update({
k: unicode2utf8(v) for k, v in new_params.items() if v
})
conn_params.update({k: unicode2utf8(v) for k, v in new_params.items() if v})
conn = psycopg2.connect(**conn_params)
cursor = conn.cursor()
conn.set_client_encoding('utf8')
conn.set_client_encoding("utf8")
self._conn_params = conn_params
if self.conn:
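In isolation, the make_dsn call above folds a password into a user-supplied DSN string; roughly (psycopg2 >= 2.7):

from psycopg2.extensions import make_dsn

dsn = make_dsn("dbname=test host=localhost", password="s3cret")
print(dsn)  # e.g. dbname=test host=localhost password=s3cret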
@ -264,11 +280,11 @@ class PGExecute(object):
# TODO: use actual connection info from psycopg2.extensions.Connection.info as psycopg>2.8 is available and required dependency # noqa
dsn_parameters = conn.get_dsn_parameters()
self.dbname = dsn_parameters.get('dbname')
self.user = dsn_parameters.get('user')
self.dbname = dsn_parameters.get("dbname")
self.user = dsn_parameters.get("user")
self.password = password
self.host = dsn_parameters.get('host')
self.port = dsn_parameters.get('port')
self.host = dsn_parameters.get("host")
self.port = dsn_parameters.get("port")
self.extra_args = kwargs
if not self.host:
@ -277,9 +293,9 @@ class PGExecute(object):
cursor.execute("SHOW ALL")
db_parameters = dict(name_val_desc[:2] for name_val_desc in cursor.fetchall())
pid = self._select_one(cursor, 'select pg_backend_pid()')[0]
pid = self._select_one(cursor, "select pg_backend_pid()")[0]
self.pid = pid
self.superuser = db_parameters.get('is_superuser') == '1'
self.superuser = db_parameters.get("is_superuser") == "1"
self.server_version = self.get_server_version(cursor)
@ -289,11 +305,11 @@ class PGExecute(object):
@property
def short_host(self):
if ',' in self.host:
host, _, _ = self.host.partition(',')
if "," in self.host:
host, _, _ = self.host.partition(",")
else:
host = self.host
short_host, _, _ = host.partition('.')
short_host, _, _ = host.partition(".")
return short_host
def _select_one(self, cur, sql):
@ -326,11 +342,14 @@ class PGExecute(object):
def valid_transaction(self):
status = self.conn.get_transaction_status()
return (status == ext.TRANSACTION_STATUS_ACTIVE or
status == ext.TRANSACTION_STATUS_INTRANS)
return (
status == ext.TRANSACTION_STATUS_ACTIVE
or status == ext.TRANSACTION_STATUS_INTRANS
)
def run(self, statement, pgspecial=None, exception_formatter=None,
on_error_resume=False):
def run(
self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False
):
"""Execute the sql in the database and return the results.
:param statement: A string containing one or more sql statements
@ -355,12 +374,12 @@ class PGExecute(object):
# Split the sql into separate queries and run each one.
for sql in sqlparse.split(statement):
# Remove spaces, eol and semi-colons.
sql = sql.rstrip(';')
sql = sql.rstrip(";")
try:
if pgspecial:
# First try to run each query as special
_logger.debug('Trying a pgspecial command. sql: %r', sql)
_logger.debug("Trying a pgspecial command. sql: %r", sql)
try:
cur = self.conn.cursor()
except psycopg2.InterfaceError:
@ -385,8 +404,7 @@ class PGExecute(object):
_logger.error("sql: %r, error: %r", sql, e)
_logger.error("traceback: %r", traceback.format_exc())
if (self._must_raise(e)
or not exception_formatter):
if self._must_raise(e) or not exception_formatter:
raise
yield None, None, None, exception_formatter(e), sql, False, False
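For reference, sqlparse.split is what turns a multi-statement buffer into individual queries; a quick example:

import sqlparse

statement = "SELECT 1; SELECT 2;"
for sql in sqlparse.split(statement):
    print(repr(sql.rstrip(";")))
# 'SELECT 1'
# 'SELECT 2'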
@ -410,12 +428,12 @@ class PGExecute(object):
def execute_normal_sql(self, split_sql):
"""Returns tuple (title, rows, headers, status)"""
_logger.debug('Regular sql statement. sql: %r', split_sql)
_logger.debug("Regular sql statement. sql: %r", split_sql)
cur = self.conn.cursor()
cur.execute(split_sql)
# conn.notices persist between queries; we use pop to clear out the list
title = ''
title = ""
while len(self.conn.notices) > 0:
title = utf8tounicode(self.conn.notices.pop()) + title
@ -425,7 +443,7 @@ class PGExecute(object):
headers = [x[0] for x in cur.description]
return title, cur, headers, cur.statusmessage
else:
_logger.debug('No rows in result.')
_logger.debug("No rows in result.")
return title, None, None, cur.statusmessage
def search_path(self):
@ -433,33 +451,32 @@ class PGExecute(object):
try:
with self.conn.cursor() as cur:
_logger.debug('Search path query. sql: %r', self.search_path_query)
_logger.debug("Search path query. sql: %r", self.search_path_query)
cur.execute(self.search_path_query)
return [x[0] for x in cur.fetchall()]
except psycopg2.ProgrammingError:
fallback = 'SELECT * FROM current_schemas(true)'
fallback = "SELECT * FROM current_schemas(true)"
with self.conn.cursor() as cur:
_logger.debug('Search path query. sql: %r', fallback)
_logger.debug("Search path query. sql: %r", fallback)
cur.execute(fallback)
return cur.fetchone()[0]
def view_definition(self, spec):
"""Returns the SQL defining views described by `spec`"""
template = 'CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}'
template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}"
# 2: relkind, v or m (materialized)
# 4: reloptions, null
# 5: checkoption: local or cascaded
with self.conn.cursor() as cur:
sql = self.view_definition_query
_logger.debug('View Definition Query. sql: %r\nspec: %r',
sql, spec)
_logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec)
try:
cur.execute(sql, (spec, ))
cur.execute(sql, (spec,))
except psycopg2.ProgrammingError:
raise RuntimeError('View {} does not exist.'.format(spec))
raise RuntimeError("View {} does not exist.".format(spec))
result = cur.fetchone()
view_type = 'MATERIALIZED' if result[2] == 'm' else ''
view_type = "MATERIALIZED" if result[2] == "m" else ""
return template.format(*result + (view_type,))
def function_definition(self, spec):
@ -467,24 +484,23 @@ class PGExecute(object):
with self.conn.cursor() as cur:
sql = self.function_definition_query
_logger.debug('Function Definition Query. sql: %r\nspec: %r',
sql, spec)
_logger.debug("Function Definition Query. sql: %r\nspec: %r", sql, spec)
try:
cur.execute(sql, (spec,))
result = cur.fetchone()
return result[0]
except psycopg2.ProgrammingError:
raise RuntimeError('Function {} does not exist.'.format(spec))
raise RuntimeError("Function {} does not exist.".format(spec))
def schemata(self):
"""Returns a list of schema names in the database"""
with self.conn.cursor() as cur:
_logger.debug('Schemata Query. sql: %r', self.schemata_query)
_logger.debug("Schemata Query. sql: %r", self.schemata_query)
cur.execute(self.schemata_query)
return [x[0] for x in cur.fetchall()]
def _relations(self, kinds=('r', 'v', 'm')):
def _relations(self, kinds=("r", "v", "m")):
"""Get table or view name metadata
:param kinds: list of postgres relkind filters:
@ -496,14 +512,14 @@ class PGExecute(object):
with self.conn.cursor() as cur:
sql = cur.mogrify(self.tables_query, [kinds])
_logger.debug('Tables Query. sql: %r', sql)
_logger.debug("Tables Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
def tables(self):
"""Yields (schema_name, table_name) tuples"""
for row in self._relations(kinds=['r']):
for row in self._relations(kinds=["r"]):
yield row
def views(self):
@ -511,10 +527,10 @@ class PGExecute(object):
Includes both views and materialized views
"""
for row in self._relations(kinds=['v', 'm']):
for row in self._relations(kinds=["v", "m"]):
yield row
def _columns(self, kinds=('r', 'v', 'm')):
def _columns(self, kinds=("r", "v", "m")):
"""Get column metadata for tables and views
:param kinds: kinds: list of postgres relkind filters:
@ -525,7 +541,7 @@ class PGExecute(object):
"""
if self.conn.server_version >= 80400:
columns_query = '''
columns_query = """
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
@ -543,9 +559,9 @@ class PGExecute(object):
WHERE cls.relkind = ANY(%s)
AND NOT att.attisdropped
AND att.attnum > 0
ORDER BY 1, 2, att.attnum'''
ORDER BY 1, 2, att.attnum"""
else:
columns_query = '''
columns_query = """
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
@ -562,44 +578,44 @@ class PGExecute(object):
WHERE cls.relkind = ANY(%s)
AND NOT att.attisdropped
AND att.attnum > 0
ORDER BY 1, 2, att.attnum'''
ORDER BY 1, 2, att.attnum"""
with self.conn.cursor() as cur:
sql = cur.mogrify(columns_query, [kinds])
_logger.debug('Columns Query. sql: %r', sql)
_logger.debug("Columns Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
def table_columns(self):
for row in self._columns(kinds=['r']):
for row in self._columns(kinds=["r"]):
yield row
def view_columns(self):
for row in self._columns(kinds=['v', 'm']):
for row in self._columns(kinds=["v", "m"]):
yield row
def databases(self):
with self.conn.cursor() as cur:
_logger.debug('Databases Query. sql: %r', self.databases_query)
_logger.debug("Databases Query. sql: %r", self.databases_query)
cur.execute(self.databases_query)
return [x[0] for x in cur.fetchall()]
def full_databases(self):
with self.conn.cursor() as cur:
_logger.debug('Databases Query. sql: %r',
self.full_databases_query)
_logger.debug("Databases Query. sql: %r", self.full_databases_query)
cur.execute(self.full_databases_query)
headers = [x[0] for x in cur.description]
return cur.fetchall(), headers, cur.statusmessage
def get_socket_directory(self):
with self.conn.cursor() as cur:
_logger.debug('Socket directory Query. sql: %r',
self.socket_directory_query)
_logger.debug(
"Socket directory Query. sql: %r", self.socket_directory_query
)
cur.execute(self.socket_directory_query)
result = cur.fetchone()
return result[0] if result else ''
return result[0] if result else ""
def foreignkeys(self):
"""Yields ForeignKey named tuples"""
@ -608,7 +624,7 @@ class PGExecute(object):
return
with self.conn.cursor() as cur:
query = '''
query = """
SELECT s_p.nspname AS parentschema,
t_p.relname AS parenttable,
unnest((
@ -635,8 +651,8 @@ class PGExecute(object):
JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
WHERE fk.contype = 'f';
'''
_logger.debug('Functions Query. sql: %r', query)
"""
_logger.debug("Functions Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield ForeignKey(*row)
@ -645,7 +661,7 @@ class PGExecute(object):
"""Yields FunctionMetadata named tuples"""
if self.conn.server_version >= 110000:
query = '''
query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
@ -663,9 +679,9 @@ class PGExecute(object):
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''
"""
elif self.conn.server_version > 90000:
query = '''
query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
@ -683,9 +699,9 @@ class PGExecute(object):
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''
"""
elif self.conn.server_version >= 80400:
query = '''
query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
@ -703,9 +719,9 @@ class PGExecute(object):
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''
"""
else:
query = '''
query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
@ -723,20 +739,20 @@ class PGExecute(object):
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''
"""
with self.conn.cursor() as cur:
_logger.debug('Functions Query. sql: %r', query)
_logger.debug("Functions Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield FunctionMetadata(*row)
yield FunctionMetadata(*row)
def datatypes(self):
"""Yields tuples of (schema_name, type_name)"""
with self.conn.cursor() as cur:
if self.conn.server_version > 90000:
query = '''
query = """
SELECT n.nspname schema_name,
t.typname type_name
FROM pg_catalog.pg_type t
@ -757,9 +773,9 @@ class PGExecute(object):
AND n.nspname <> 'pg_catalog'
AND n.nspname <> 'information_schema'
ORDER BY 1, 2;
'''
"""
else:
query = '''
query = """
SELECT n.nspname schema_name,
pg_catalog.format_type(t.oid, NULL) type_name
FROM pg_catalog.pg_type t
@ -770,8 +786,8 @@ class PGExecute(object):
AND n.nspname <> 'information_schema'
AND pg_catalog.pg_type_is_visible(t.oid)
ORDER BY 1, 2;
'''
_logger.debug('Datatypes Query. sql: %r', query)
"""
_logger.debug("Datatypes Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield row
@ -779,7 +795,7 @@ class PGExecute(object):
def casing(self):
"""Yields the most common casing for names used in db functions"""
with self.conn.cursor() as cur:
query = r'''
query = r"""
WITH Words AS (
SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1)
FROM pg_catalog.pg_proc P
@ -819,8 +835,8 @@ class PGExecute(object):
FROM OrderWords
WHERE LOWER(Word) IN (SELECT Name FROM Names)
AND Row_Number = 1;
'''
_logger.debug('Casing Query. sql: %r', query)
"""
_logger.debug("Casing Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield row[0]

View File

@ -13,35 +13,33 @@ logger = logging.getLogger(__name__)
# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
TOKEN_TO_PROMPT_STYLE = {
Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current',
Token.Menu.Completions.Completion: 'completion-menu.completion',
Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current',
Token.Menu.Completions.Meta: 'completion-menu.meta.completion',
Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta',
Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess
Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess
Token.SelectedText: 'selected',
Token.SearchMatch: 'search',
Token.SearchMatch.Current: 'search.current',
Token.Toolbar: 'bottom-toolbar',
Token.Toolbar.Off: 'bottom-toolbar.off',
Token.Toolbar.On: 'bottom-toolbar.on',
Token.Toolbar.Search: 'search-toolbar',
Token.Toolbar.Search.Text: 'search-toolbar.text',
Token.Toolbar.System: 'system-toolbar',
Token.Toolbar.Arg: 'arg-toolbar',
Token.Toolbar.Arg.Text: 'arg-toolbar.text',
Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid',
Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed',
Token.Output.Header: 'output.header',
Token.Output.OddRow: 'output.odd-row',
Token.Output.EvenRow: 'output.even-row',
Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
Token.Menu.Completions.Completion: "completion-menu.completion",
Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
Token.Menu.Completions.Meta: "completion-menu.meta.completion",
Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta",
Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess
Token.Menu.Completions.ProgressBar: "scrollbar", # best guess
Token.SelectedText: "selected",
Token.SearchMatch: "search",
Token.SearchMatch.Current: "search.current",
Token.Toolbar: "bottom-toolbar",
Token.Toolbar.Off: "bottom-toolbar.off",
Token.Toolbar.On: "bottom-toolbar.on",
Token.Toolbar.Search: "search-toolbar",
Token.Toolbar.Search.Text: "search-toolbar.text",
Token.Toolbar.System: "system-toolbar",
Token.Toolbar.Arg: "arg-toolbar",
Token.Toolbar.Arg.Text: "arg-toolbar.text",
Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid",
Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed",
Token.Output.Header: "output.header",
Token.Output.OddRow: "output.odd-row",
Token.Output.EvenRow: "output.even-row",
}
# reverse dict for cli_helpers, because they still expect Pygments tokens.
PROMPT_STYLE_TO_TOKEN = {
v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()
}
PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
def parse_pygments_style(token_name, style_object, style_dict):
@ -64,52 +62,48 @@ def style_factory(name, cli_style):
try:
style = pygments.styles.get_style_by_name(name)
except ClassNotFound:
style = pygments.styles.get_style_by_name('native')
style = pygments.styles.get_style_by_name("native")
prompt_styles = []
# prompt-toolkit used pygments tokens for styling before switching to style
# names in 2.0. Convert old token types to new style names for backwards compatibility.
for token in cli_style:
if token.startswith('Token.'):
if token.startswith("Token."):
# treat as pygments token (1.0)
token_type, style_value = parse_pygments_style(
token, style, cli_style)
token_type, style_value = parse_pygments_style(token, style, cli_style)
if token_type in TOKEN_TO_PROMPT_STYLE:
prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
prompt_styles.append((prompt_style, style_value))
else:
# we don't want to support tokens anymore
logger.error('Unhandled style / class name: %s', token)
logger.error("Unhandled style / class name: %s", token)
else:
# treat as prompt style name (2.0). See default style names here:
# https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
prompt_styles.append((token, cli_style[token]))
override_style = Style([('bottom-toolbar', 'noreverse')])
return merge_styles([
style_from_pygments_cls(style),
override_style,
Style(prompt_styles)
])
override_style = Style([("bottom-toolbar", "noreverse")])
return merge_styles(
[style_from_pygments_cls(style), override_style, Style(prompt_styles)]
)
def style_factory_output(name, cli_style):
try:
style = pygments.styles.get_style_by_name(name).styles
except ClassNotFound:
style = pygments.styles.get_style_by_name('native').styles
style = pygments.styles.get_style_by_name("native").styles
for token in cli_style:
if token.startswith('Token.'):
token_type, style_value = parse_pygments_style(
token, style, cli_style)
if token.startswith("Token."):
token_type, style_value = parse_pygments_style(token, style, cli_style)
style.update({token_type: style_value})
elif token in PROMPT_STYLE_TO_TOKEN:
token_type = PROMPT_STYLE_TO_TOKEN[token]
style.update({token_type: cli_style[token]})
else:
# TODO: cli helpers will have to switch to ptk.Style
logger.error('Unhandled style / class name: %s', token)
logger.error("Unhandled style / class name: %s", token)
class OutputStyle(PygmentsStyle):
default_style = ""

View File

@ -6,57 +6,58 @@ from prompt_toolkit.application import get_app
def _get_vi_mode():
return {
InputMode.INSERT: 'I',
InputMode.NAVIGATION: 'N',
InputMode.REPLACE: 'R',
InputMode.INSERT_MULTIPLE: 'M',
InputMode.INSERT: "I",
InputMode.NAVIGATION: "N",
InputMode.REPLACE: "R",
InputMode.INSERT_MULTIPLE: "M",
}[get_app().vi_state.input_mode]
def create_toolbar_tokens_func(pgcli):
"""Return a function that generates the toolbar tokens."""
def get_toolbar_tokens():
result = []
result.append(('class:bottom-toolbar', ' '))
result.append(("class:bottom-toolbar", " "))
if pgcli.completer.smart_completion:
result.append(('class:bottom-toolbar.on',
'[F2] Smart Completion: ON '))
result.append(("class:bottom-toolbar.on", "[F2] Smart Completion: ON "))
else:
result.append(('class:bottom-toolbar.off',
'[F2] Smart Completion: OFF '))
result.append(("class:bottom-toolbar.off", "[F2] Smart Completion: OFF "))
if pgcli.multi_line:
result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON '))
result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON "))
else:
result.append(('class:bottom-toolbar.off',
'[F3] Multiline: OFF '))
result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF "))
if pgcli.multi_line:
if pgcli.multiline_mode == 'safe':
result.append(
('class:bottom-toolbar', ' ([Esc] [Enter] to execute) '))
if pgcli.multiline_mode == "safe":
result.append(("class:bottom-toolbar", " ([Esc] [Enter] to execute]) "))
else:
result.append(
('class:bottom-toolbar', ' (Semi-colon [;] will end the line) '))
("class:bottom-toolbar", " (Semi-colon [;] will end the line) ")
)
if pgcli.vi_mode:
result.append(
('class:bottom-toolbar', '[F4] Vi-mode (' + _get_vi_mode() + ')'))
("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ")")
)
else:
result.append(('class:bottom-toolbar', '[F4] Emacs-mode'))
result.append(("class:bottom-toolbar", "[F4] Emacs-mode"))
if pgcli.pgexecute.failed_transaction():
result.append(('class:bottom-toolbar.transaction.failed',
' Failed transaction'))
result.append(
("class:bottom-toolbar.transaction.failed", " Failed transaction")
)
if pgcli.pgexecute.valid_transaction():
result.append(
('class:bottom-toolbar.transaction.valid', ' Transaction'))
("class:bottom-toolbar.transaction.valid", " Transaction")
)
if pgcli.completion_refresher.is_refreshing():
result.append(
('class:bottom-toolbar', ' Refreshing completions...'))
result.append(("class:bottom-toolbar", " Refreshing completions..."))
return result
return get_toolbar_tokens

22
pyproject.toml Normal file
View File

@ -0,0 +1,22 @@
[tool.black]
line-length = 88
target-version = ['py27']
include = '\.pyi?$'
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| \.cache
| \.pytest_cache
| _build
| buck-out
| build
| dist
| tests/data
)/
'''
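With this file at the project root, black picks the settings up automatically; a quick scripted check might look like this (a sketch; assumes black is installed and on PATH):

import subprocess

# Exits non-zero if any file would be reformatted.
subprocess.check_call(["black", "--check", "."])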

View File

@ -23,7 +23,7 @@ def skip_step():
global CONFIRM_STEPS
if CONFIRM_STEPS:
return not click.confirm('--- Run this step?', default=True)
return not click.confirm("--- Run this step?", default=True)
return False
@ -36,90 +36,100 @@ def run_step(*args):
global DRY_RUN
cmd = args
print(' '.join(cmd))
print (" ".join(cmd))
if skip_step():
print('--- Skipping...')
print ("--- Skipping...")
elif DRY_RUN:
print('--- Pretending to run...')
print ("--- Pretending to run...")
else:
subprocess.check_output(cmd)
def version(version_file):
_version_re = re.compile(
r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)')
r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)'
)
with io.open(version_file, encoding='utf-8') as f:
ver = _version_re.search(f.read()).group('version')
with io.open(version_file, encoding="utf-8") as f:
ver = _version_re.search(f.read()).group("version")
return ver
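The version regex is easy to exercise against an in-memory string (the sample source line is illustrative):

import re

_version_re = re.compile(
    r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)'
)

src = '__version__ = "2.1.0"\n'
print(_version_re.search(src).group("version"))  # 2.1.0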
def commit_for_release(version_file, ver):
run_step('git', 'reset')
run_step('git', 'add', version_file)
run_step('git', 'commit', '--message',
'Releasing version {}'.format(ver))
run_step("git", "reset")
run_step("git", "add", version_file)
run_step("git", "commit", "--message", "Releasing version {}".format(ver))
def create_git_tag(tag_name):
run_step('git', 'tag', tag_name)
run_step("git", "tag", tag_name)
def create_distribution_files():
run_step('python', 'setup.py', 'clean', '--all', 'sdist', 'bdist_wheel')
run_step("python", "setup.py", "clean", "--all", "sdist", "bdist_wheel")
def upload_distribution_files():
run_step('twine', 'upload', 'dist/*')
run_step("twine", "upload", "dist/*")
def push_to_github():
run_step('git', 'push', 'origin', 'master')
run_step("git", "push", "origin", "master")
def push_tags_to_github():
run_step('git', 'push', '--tags', 'origin')
run_step("git", "push", "--tags", "origin")
def checklist(questions):
for question in questions:
if not click.confirm('--- {}'.format(question), default=False):
if not click.confirm("--- {}".format(question), default=False):
sys.exit(1)
if __name__ == '__main__':
if __name__ == "__main__":
if DEBUG:
subprocess.check_output = lambda x: x
checks = ['Have you updated the AUTHORS file?',
'Have you updated the `Usage` section of the README?',
]
checks = [
"Have you updated the AUTHORS file?",
"Have you updated the `Usage` section of the README?",
]
checklist(checks)
ver = version('pgcli/__init__.py')
print('Releasing Version:', ver)
ver = version("pgcli/__init__.py")
print ("Releasing Version:", ver)
parser = OptionParser()
parser.add_option(
"-c", "--confirm-steps", action="store_true", dest="confirm_steps",
default=False, help=("Confirm every step. If the step is not "
"confirmed, it will be skipped.")
"-c",
"--confirm-steps",
action="store_true",
dest="confirm_steps",
default=False,
help=(
"Confirm every step. If the step is not " "confirmed, it will be skipped."
),
)
parser.add_option(
"-d", "--dry-run", action="store_true", dest="dry_run",
default=False, help="Print out, but not actually run any steps."
"-d",
"--dry-run",
action="store_true",
dest="dry_run",
default=False,
help="Print out, but not actually run any steps.",
)
popts, pargs = parser.parse_args()
CONFIRM_STEPS = popts.confirm_steps
DRY_RUN = popts.dry_run
if not click.confirm('Are you sure?', default=False):
if not click.confirm("Are you sure?", default=False):
sys.exit(1)
commit_for_release('pgcli/__init__.py', ver)
create_git_tag('v{}'.format(ver))
commit_for_release("pgcli/__init__.py", ver)
create_git_tag("v{}".format(ver))
create_distribution_files()
push_to_github()
push_tags_to_github()
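
The version regex above leans on a named backreference; a quick self-contained run (the sample string is a stand-in for pgcli/__init__.py):

    import re

    _version_re = re.compile(
        r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)'
    )
    # (?P=quote) forces the closing quote to match the opening one.
    sample = '__version__ = "2.1.0"'
    print(_version_re.search(sample).group("version"))  # -> 2.1.0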

View File

@ -4,11 +4,10 @@ mock>=1.0.1
tox>=1.9.2
behave>=1.2.4
pexpect==3.3
pre-commit>=1.16.0
coverage==4.3.4
codecov>=1.5.1
docutils>=0.13.1
autopep8==1.3.3
# we want the latest possible version of pep8radius
git+https://github.com/hayd/pep8radius.git
click==6.7
twine==1.11.0

View File

@ -3,24 +3,25 @@ import ast
import platform
from setuptools import setup, find_packages
_version_re = re.compile(r'__version__\s+=\s+(.*)')
_version_re = re.compile(r"__version__\s+=\s+(.*)")
with open('pgcli/__init__.py', 'rb') as f:
version = str(ast.literal_eval(_version_re.search(
f.read().decode('utf-8')).group(1)))
with open("pgcli/__init__.py", "rb") as f:
version = str(
ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))
)
description = 'CLI for Postgres Database. With auto-completion and syntax highlighting.'
description = "CLI for Postgres Database. With auto-completion and syntax highlighting."
install_requirements = [
'pgspecial>=1.11.5',
'click >= 4.1',
'Pygments >= 2.0', # Pygments has to be Capitalcased. WTF?
'prompt_toolkit>=2.0.6,<2.1.0',
'psycopg2 >= 2.7.4,<2.8',
'sqlparse >=0.3.0,<0.4',
'configobj >= 5.0.6',
'humanize >= 0.5.1',
'cli_helpers[styles] >= 1.2.0',
"pgspecial>=1.11.5",
"click >= 4.1",
"Pygments >= 2.0", # Pygments has to be Capitalcased. WTF?
"prompt_toolkit>=2.0.6,<2.1.0",
"psycopg2 >= 2.7.4,<2.8",
"sqlparse >=0.3.0,<0.4",
"configobj >= 5.0.6",
"humanize >= 0.5.1",
"cli_helpers[styles] >= 1.2.0",
]
@ -28,45 +29,42 @@ install_requirements = [
# But this is not necessary in Windows since the password is never shown in the
# task manager. Also setproctitle is a hard dependency to install in Windows,
# so we'll only install it if we're not in Windows.
if platform.system() != 'Windows' and not platform.system().startswith("CYGWIN"):
install_requirements.append('setproctitle >= 1.1.9')
if platform.system() != "Windows" and not platform.system().startswith("CYGWIN"):
install_requirements.append("setproctitle >= 1.1.9")
setup(
name='pgcli',
author='Pgcli Core Team',
author_email='pgcli-dev@googlegroups.com',
name="pgcli",
author="Pgcli Core Team",
author_email="pgcli-dev@googlegroups.com",
version=version,
license='BSD',
url='http://pgcli.com',
license="BSD",
url="http://pgcli.com",
packages=find_packages(),
package_data={'pgcli': ['pgclirc',
'packages/pgliterals/pgliterals.json']},
package_data={"pgcli": ["pgclirc", "packages/pgliterals/pgliterals.json"]},
description=description,
long_description=open('README.rst').read(),
long_description=open("README.rst").read(),
install_requires=install_requirements,
extras_require={
'keyring': ['keyring >= 12.2.0'],
},
entry_points='''
extras_require={"keyring": ["keyring >= 12.2.0"]},
entry_points="""
[console_scripts]
pgcli=pgcli.main:cli
''',
""",
classifiers=[
'Intended Audience :: Developers',
'License :: OSI Approved :: BSD License',
'Operating System :: Unix',
'Programming Language :: Python',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: SQL',
'Topic :: Database',
'Topic :: Database :: Front-Ends',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries :: Python Modules',
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: Unix",
"Programming Language :: Python",
"Programming Language :: Python :: 2",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: SQL",
"Topic :: Database",
"Topic :: Database :: Front-Ends",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)

View File

@ -2,15 +2,22 @@ from __future__ import print_function
import os
import pytest
from utils import (POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD, create_db, db_connection,
drop_tables)
from utils import (
POSTGRES_HOST,
POSTGRES_PORT,
POSTGRES_USER,
POSTGRES_PASSWORD,
create_db,
db_connection,
drop_tables,
)
import pgcli.pgexecute
@pytest.yield_fixture(scope="function")
def connection():
create_db('_test_db')
connection = db_connection('_test_db')
create_db("_test_db")
connection = db_connection("_test_db")
yield connection
drop_tables(connection)
@ -25,8 +32,14 @@ def cursor(connection):
@pytest.fixture
def executor(connection):
return pgcli.pgexecute.PGExecute(database='_test_db', user=POSTGRES_USER, host=POSTGRES_HOST,
password=POSTGRES_PASSWORD, port=POSTGRES_PORT, dsn=None)
return pgcli.pgexecute.PGExecute(
database="_test_db",
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
port=POSTGRES_PORT,
dsn=None,
)
@pytest.fixture
@ -38,4 +51,4 @@ def exception_formatter():
def temp_config(tmpdir_factory):
# this function runs on start of test session.
# use temporary directory for config home so user config will not be used
os.environ['XDG_CONFIG_HOME'] = str(tmpdir_factory.mktemp('data'))
os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data"))

View File

@ -6,7 +6,9 @@ from psycopg2 import connect
from psycopg2.extensions import AsIs
def create_db(hostname='localhost', username=None, password=None, dbname=None, port=None):
def create_db(
hostname="localhost", username=None, password=None, dbname=None, port=None
):
"""Create test database.
:param hostname: string
@ -17,15 +19,15 @@ def create_db(hostname='localhost', username=None, password=None, dbname=None, p
:return:
"""
cn = create_cn(hostname, password, username, 'postgres', port)
cn = create_cn(hostname, password, username, "postgres", port)
# ISOLATION_LEVEL_AUTOCOMMIT = 0
# Needed for DB creation.
cn.set_isolation_level(0)
with cn.cursor() as cr:
cr.execute('drop database if exists %s', (AsIs(dbname),))
cr.execute('create database %s', (AsIs(dbname),))
cr.execute("drop database if exists %s", (AsIs(dbname),))
cr.execute("create database %s", (AsIs(dbname),))
cn.close()
@ -42,15 +44,15 @@ def create_cn(hostname, password, username, dbname, port):
:param dbname: string
:return: psycopg2.connection
"""
cn = connect(host=hostname, user=username, database=dbname,
password=password, port=port)
cn = connect(
host=hostname, user=username, database=dbname, password=password, port=port
)
print('Created connection: {0}.'.format(cn.dsn))
print ("Created connection: {0}.".format(cn.dsn))
return cn
def drop_db(hostname='localhost', username=None, password=None,
dbname=None, port=None):
def drop_db(hostname="localhost", username=None, password=None, dbname=None, port=None):
"""
Drop database.
:param hostname: string
@ -58,14 +60,14 @@ def drop_db(hostname='localhost', username=None, password=None,
:param password: string
:param dbname: string
"""
cn = create_cn(hostname, password, username, 'postgres', port)
cn = create_cn(hostname, password, username, "postgres", port)
# ISOLATION_LEVEL_AUTOCOMMIT = 0
# Needed for DB drop.
cn.set_isolation_level(0)
with cn.cursor() as cr:
cr.execute('drop database if exists %s', (AsIs(dbname),))
cr.execute("drop database if exists %s", (AsIs(dbname),))
close_cn(cn)
@ -77,4 +79,4 @@ def close_cn(cn=None):
"""
if cn:
cn.close()
print('Closed connection: {0}.'.format(cn.dsn))
print ("Closed connection: {0}.".format(cn.dsn))

View File

@ -18,113 +18,119 @@ from steps import wrappers
def before_all(context):
"""Set env parameters."""
env_old = copy.deepcopy(dict(os.environ))
os.environ['LINES'] = "100"
os.environ['COLUMNS'] = "100"
os.environ['PAGER'] = 'cat'
os.environ['EDITOR'] = 'ex'
os.environ['VISUAL'] = 'ex'
os.environ["LINES"] = "100"
os.environ["COLUMNS"] = "100"
os.environ["PAGER"] = "cat"
os.environ["EDITOR"] = "ex"
os.environ["VISUAL"] = "ex"
context.package_root = os.path.abspath(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
fixture_dir = os.path.join(
context.package_root, 'tests/features/fixture_data')
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
)
fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data")
print('package root:', context.package_root)
print('fixture dir:', fixture_dir)
print ("package root:", context.package_root)
print ("fixture dir:", fixture_dir)
os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root,
'.coveragerc')
os.environ["COVERAGE_PROCESS_START"] = os.path.join(
context.package_root, ".coveragerc"
)
context.exit_sent = False
vi = '_'.join([str(x) for x in sys.version_info[:3]])
db_name = context.config.userdata.get('pg_test_db', 'pgcli_behave_tests')
db_name_full = '{0}_{1}'.format(db_name, vi)
vi = "_".join([str(x) for x in sys.version_info[:3]])
db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests")
db_name_full = "{0}_{1}".format(db_name, vi)
# Store get params from config.
context.conf = {
'host': context.config.userdata.get(
'pg_test_host',
os.getenv('PGHOST', 'localhost')
"host": context.config.userdata.get(
"pg_test_host", os.getenv("PGHOST", "localhost")
),
'user': context.config.userdata.get(
'pg_test_user',
os.getenv('PGUSER', 'postgres')
"user": context.config.userdata.get(
"pg_test_user", os.getenv("PGUSER", "postgres")
),
'pass': context.config.userdata.get(
'pg_test_pass',
os.getenv('PGPASSWORD', None)
"pass": context.config.userdata.get(
"pg_test_pass", os.getenv("PGPASSWORD", None)
),
'port': context.config.userdata.get(
'pg_test_port',
os.getenv('PGPORT', '5432')
"port": context.config.userdata.get(
"pg_test_port", os.getenv("PGPORT", "5432")
),
'cli_command': (
context.config.userdata.get('pg_cli_command', None) or
'{python} -c "{startup}"'.format(
"cli_command": (
context.config.userdata.get("pg_cli_command", None)
or '{python} -c "{startup}"'.format(
python=sys.executable,
startup='; '.join([
"import coverage",
"coverage.process_startup()",
"import pgcli.main",
"pgcli.main.cli()"]))),
'dbname': db_name_full,
'dbname_tmp': db_name_full + '_tmp',
'vi': vi,
'pager_boundary': '---boundary---',
startup="; ".join(
[
"import coverage",
"coverage.process_startup()",
"import pgcli.main",
"pgcli.main.cli()",
]
),
)
),
"dbname": db_name_full,
"dbname_tmp": db_name_full + "_tmp",
"vi": vi,
"pager_boundary": "---boundary---",
}
os.environ['PAGER'] = "{0} {1} {2}".format(
os.environ["PAGER"] = "{0} {1} {2}".format(
sys.executable,
os.path.join(context.package_root, "tests/features/wrappager.py"),
context.conf['pager_boundary'])
context.conf["pager_boundary"],
)
# Store old env vars.
context.pgenv = {
'PGDATABASE': os.environ.get('PGDATABASE', None),
'PGUSER': os.environ.get('PGUSER', None),
'PGHOST': os.environ.get('PGHOST', None),
'PGPASSWORD': os.environ.get('PGPASSWORD', None),
'PGPORT': os.environ.get('PGPORT', None),
'XDG_CONFIG_HOME': os.environ.get('XDG_CONFIG_HOME', None),
'PGSERVICEFILE': os.environ.get('PGSERVICEFILE', None),
"PGDATABASE": os.environ.get("PGDATABASE", None),
"PGUSER": os.environ.get("PGUSER", None),
"PGHOST": os.environ.get("PGHOST", None),
"PGPASSWORD": os.environ.get("PGPASSWORD", None),
"PGPORT": os.environ.get("PGPORT", None),
"XDG_CONFIG_HOME": os.environ.get("XDG_CONFIG_HOME", None),
"PGSERVICEFILE": os.environ.get("PGSERVICEFILE", None),
}
# Set new env vars.
os.environ['PGDATABASE'] = context.conf['dbname']
os.environ['PGUSER'] = context.conf['user']
os.environ['PGHOST'] = context.conf['host']
os.environ['PGPORT'] = context.conf['port']
os.environ['PGSERVICEFILE'] = os.path.join(
fixture_dir, 'mock_pg_service.conf')
os.environ["PGDATABASE"] = context.conf["dbname"]
os.environ["PGUSER"] = context.conf["user"]
os.environ["PGHOST"] = context.conf["host"]
os.environ["PGPORT"] = context.conf["port"]
os.environ["PGSERVICEFILE"] = os.path.join(fixture_dir, "mock_pg_service.conf")
if context.conf['pass']:
os.environ['PGPASSWORD'] = context.conf['pass']
if context.conf["pass"]:
os.environ["PGPASSWORD"] = context.conf["pass"]
else:
if 'PGPASSWORD' in os.environ:
del os.environ['PGPASSWORD']
if "PGPASSWORD" in os.environ:
del os.environ["PGPASSWORD"]
context.cn = dbutils.create_db(context.conf['host'], context.conf['user'],
context.conf['pass'], context.conf['dbname'],
context.conf['port'])
context.cn = dbutils.create_db(
context.conf["host"],
context.conf["user"],
context.conf["pass"],
context.conf["dbname"],
context.conf["port"],
)
context.fixture_data = fixutils.read_fixture_files()
# use temporary directory as config home
context.env_config_home = tempfile.mkdtemp(prefix='pgcli_home_')
os.environ['XDG_CONFIG_HOME'] = context.env_config_home
context.env_config_home = tempfile.mkdtemp(prefix="pgcli_home_")
os.environ["XDG_CONFIG_HOME"] = context.env_config_home
show_env_changes(env_old, dict(os.environ))
def show_env_changes(env_old, env_new):
"""Print out all test-specific env values."""
print('--- os.environ changed values: ---')
print ("--- os.environ changed values: ---")
all_keys = set(list(env_old.keys()) + list(env_new.keys()))
for k in sorted(all_keys):
old_value = env_old.get(k, '')
new_value = env_new.get(k, '')
old_value = env_old.get(k, "")
new_value = env_new.get(k, "")
if new_value and old_value != new_value:
print('{}="{}"'.format(k, new_value))
print('-' * 20)
print('{}="{}"'.format(k, new_value))
print("-" * 20)
def after_all(context):
@ -132,9 +138,13 @@ def after_all(context):
Unset env parameters.
"""
dbutils.close_cn(context.cn)
dbutils.drop_db(context.conf['host'], context.conf['user'],
context.conf['pass'], context.conf['dbname'],
context.conf['port'])
dbutils.drop_db(
context.conf["host"],
context.conf["user"],
context.conf["pass"],
context.conf["dbname"],
context.conf["port"],
)
# Remove temp config directory
shutil.rmtree(context.env_config_home)
@ -152,7 +162,7 @@ def before_step(context, _):
def before_scenario(context, scenario):
if scenario.name == 'list databases':
if scenario.name == "list databases":
# not using the cli for that
return
wrappers.run_cli(context)
@ -161,19 +171,19 @@ def before_scenario(context, scenario):
def after_scenario(context, scenario):
"""Cleans up after each scenario completes."""
if hasattr(context, 'cli') and context.cli and not context.exit_sent:
if hasattr(context, "cli") and context.cli and not context.exit_sent:
# Quit nicely.
if not context.atprompt:
dbname = context.currentdb
context.cli.expect_exact('{0}> '.format(dbname), timeout=15)
context.cli.sendcontrol('c')
context.cli.sendcontrol('d')
context.cli.expect_exact("{0}> ".format(dbname), timeout=15)
context.cli.sendcontrol("c")
context.cli.sendcontrol("d")
try:
context.cli.expect_exact(pexpect.EOF, timeout=15)
except pexpect.TIMEOUT:
print('--- after_scenario {}: kill cli'.format(scenario.name))
print ("--- after_scenario {}: kill cli".format(scenario.name))
context.cli.kill(signal.SIGKILL)
if hasattr(context, 'tmpfile_sql_help') and context.tmpfile_sql_help:
if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help:
context.tmpfile_sql_help.close()
context.tmpfile_sql_help = None
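
The snapshot/diff/restore handling of os.environ is the backbone of this file; reduced to a sketch, the pattern is:

    import copy
    import os

    env_old = copy.deepcopy(dict(os.environ))  # snapshot before mutating
    os.environ["PAGER"] = "cat"
    changed = {k: v for k, v in os.environ.items() if env_old.get(k) != v}
    print(changed)  # at least {'PAGER': 'cat'}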

View File

@ -13,7 +13,7 @@ def read_fixture_lines(filename):
:return: list of strings
"""
lines = []
for line in codecs.open(filename, 'rb', encoding='utf-8'):
for line in codecs.open(filename, "rb", encoding="utf-8"):
lines.append(line.strip())
return lines
@ -21,11 +21,11 @@ def read_fixture_lines(filename):
def read_fixture_files():
"""Read all files inside fixture_data directory."""
current_dir = os.path.dirname(__file__)
fixture_dir = os.path.join(current_dir, 'fixture_data/')
print('reading fixture data: {}'.format(fixture_dir))
fixture_dir = os.path.join(current_dir, "fixture_data/")
print ("reading fixture data: {}".format(fixture_dir))
fixture_dict = {}
for filename in os.listdir(fixture_dir):
if filename not in ['.', '..']:
if filename not in [".", ".."]:
fullname = os.path.join(fixture_dir, filename)
fixture_dict[filename] = read_fixture_lines(fullname)

View File

@ -6,37 +6,45 @@ from behave import then, when
import wrappers
@when('we run dbcli with {arg}')
@when("we run dbcli with {arg}")
def step_run_cli_with_arg(context, arg):
wrappers.run_cli(context, run_args=arg.split('='))
wrappers.run_cli(context, run_args=arg.split("="))
@when('we execute a small query')
@when("we execute a small query")
def step_execute_small_query(context):
context.cli.sendline('select 1')
context.cli.sendline("select 1")
@when('we execute a large query')
@when("we execute a large query")
def step_execute_large_query(context):
context.cli.sendline(
'select {}'.format(','.join([str(n) for n in range(1, 50)])))
context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)])))
@then('we see small results in horizontal format')
@then("we see small results in horizontal format")
def step_see_small_results(context):
wrappers.expect_pager(context, dedent("""\
wrappers.expect_pager(
context,
dedent(
"""\
+------------+\r
| ?column? |\r
|------------|\r
| 1 |\r
+------------+\r
SELECT 1\r
"""), timeout=5)
"""
),
timeout=5,
)
@then('we see large results in vertical format')
@then("we see large results in vertical format")
def step_see_large_results(context):
wrappers.expect_pager(context, dedent("""\
wrappers.expect_pager(
context,
dedent(
"""\
-[ RECORD 1 ]-------------------------\r
?column? | 1\r
?column? | 2\r
@ -88,4 +96,7 @@ def step_see_large_results(context):
?column? | 48\r
?column? | 49\r
SELECT 1\r
"""), timeout=5)
"""
),
timeout=5,
)

View File

@ -15,47 +15,48 @@ from textwrap import dedent
import wrappers
@when('we list databases')
@when("we list databases")
def step_list_databases(context):
cmd = ['pgcli', '--list']
cmd = ["pgcli", "--list"]
context.cmd_output = subprocess.check_output(cmd, cwd=context.package_root)
@then('we see list of databases')
@then("we see list of databases")
def step_see_list_databases(context):
assert b'List of databases' in context.cmd_output
assert b'postgres' in context.cmd_output
assert b"List of databases" in context.cmd_output
assert b"postgres" in context.cmd_output
context.cmd_output = None
@when('we run dbcli')
@when("we run dbcli")
def step_run_cli(context):
wrappers.run_cli(context)
@when('we launch dbcli using {arg}')
@when("we launch dbcli using {arg}")
def step_run_cli_using_arg(context, arg):
prompt_check = False
currentdb = None
if arg == '--username':
arg = '--username={}'.format(context.conf['user'])
if arg == '--user':
arg = '--user={}'.format(context.conf['user'])
if arg == '--port':
arg = '--port={}'.format(context.conf['port'])
if arg == '--password':
arg = '--password'
if arg == "--username":
arg = "--username={}".format(context.conf["user"])
if arg == "--user":
arg = "--user={}".format(context.conf["user"])
if arg == "--port":
arg = "--port={}".format(context.conf["port"])
if arg == "--password":
arg = "--password"
prompt_check = False
# This uses the mock_pg_service.conf file in fixtures folder.
if arg == 'dsn_password':
arg = 'service=mock_postgres --password'
if arg == "dsn_password":
arg = "service=mock_postgres --password"
prompt_check = False
currentdb = "postgres"
wrappers.run_cli(context, run_args=[
arg], prompt_check=prompt_check, currentdb=currentdb)
wrappers.run_cli(
context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb
)
@when('we wait for prompt')
@when("we wait for prompt")
def step_wait_prompt(context):
wrappers.wait_prompt(context)
@ -66,9 +67,9 @@ def step_ctrl_d(context):
Send Ctrl + D to hopefully exit.
"""
# turn off pager before exiting
context.cli.sendline('\pset pager off')
context.cli.sendline("\pset pager off")
wrappers.wait_prompt(context)
context.cli.sendcontrol('d')
context.cli.sendcontrol("d")
context.cli.expect_exact(pexpect.EOF, timeout=15)
context.exit_sent = True
@ -78,51 +79,58 @@ def step_send_help(context):
"""
Send \? to see help.
"""
context.cli.sendline('\?')
context.cli.sendline("\?")
@when(u'we send source command')
@when("we send source command")
def step_send_source_command(context):
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix='pgcli_')
context.tmpfile_sql_help.write(b'\?')
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
context.tmpfile_sql_help.write(b"\?")
context.tmpfile_sql_help.flush()
context.cli.sendline('\i {0}'.format(context.tmpfile_sql_help.name))
wrappers.expect_exact(
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
context.cli.sendline("\i {0}".format(context.tmpfile_sql_help.name))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
@when(u'we run query to check application_name')
@when("we run query to check application_name")
def step_check_application_name(context):
context.cli.sendline(
"SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;"
)
@then(u'we see found')
@then("we see found")
def step_see_found(context):
wrappers.expect_exact(
context,
context.conf['pager_boundary'] + '\r' + dedent('''
context.conf["pager_boundary"]
+ "\r"
+ dedent(
"""
+------------+\r
| ?column? |\r
|------------|\r
| found |\r
+------------+\r
SELECT 1\r
''') + context.conf['pager_boundary'],
timeout=5
"""
)
+ context.conf["pager_boundary"],
timeout=5,
)
@then(u'we confirm the destructive warning')
@then("we confirm the destructive warning")
def step_confirm_destructive_command(context):
"""Confirm destructive command."""
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
context.cli.sendline('y')
context,
"You're about to run a destructive command.\r\nDo you want to proceed? (y/n):",
timeout=2,
)
context.cli.sendline("y")
@then(u'we send password')
@then("we send password")
def step_send_password(context):
wrappers.expect_exact(context, 'Password for', timeout=5)
context.cli.sendline(context.conf['pass'] or 'DOES NOT MATTER')
wrappers.expect_exact(context, "Password for", timeout=5)
context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER")
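
For readers new to behave: step functions bind to step text via decorators, {arg} placeholders are parsed into keyword arguments, and state travels between steps on the shared context object. A minimal, hypothetical pair of steps:

    from behave import then, when

    @when("we run a step with {arg}")
    def step_run_with_arg(context, arg):
        context.last_arg = arg  # stash state for later steps

    @then("we see the argument remembered")
    def step_see_arg(context):
        assert context.last_arg is not None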

View File

@ -12,47 +12,43 @@ from behave import when, then
import wrappers
@when('we create database')
@when("we create database")
def step_db_create(context):
"""
Send create database.
"""
context.cli.sendline('create database {0};'.format(
context.conf['dbname_tmp']))
context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
context.response = {
'database_name': context.conf['dbname_tmp']
}
context.response = {"database_name": context.conf["dbname_tmp"]}
@when('we drop database')
@when("we drop database")
def step_db_drop(context):
"""
Send drop database.
"""
context.cli.sendline('drop database {0};'.format(
context.conf['dbname_tmp']))
context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
@when('we connect to test database')
@when("we connect to test database")
def step_db_connect_test(context):
"""
Send connect to database.
"""
db_name = context.conf['dbname']
context.cli.sendline('\\connect {0}'.format(db_name))
db_name = context.conf["dbname"]
context.cli.sendline("\\connect {0}".format(db_name))
@when('we connect to dbserver')
@when("we connect to dbserver")
def step_db_connect_dbserver(context):
"""
Send connect to database.
"""
context.cli.sendline('\\connect postgres')
context.currentdb = 'postgres'
context.cli.sendline("\\connect postgres")
context.currentdb = "postgres"
@then('dbcli exits')
@then("dbcli exits")
def step_wait_exit(context):
"""
Make sure the cli exits.
@ -60,41 +56,41 @@ def step_wait_exit(context):
wrappers.expect_exact(context, pexpect.EOF, timeout=5)
@then('we see dbcli prompt')
@then("we see dbcli prompt")
def step_see_prompt(context):
"""
Wait to see the prompt.
"""
db_name = getattr(context, "currentdb", context.conf['dbname'])
wrappers.expect_exact(context, '{0}> '.format(db_name), timeout=5)
db_name = getattr(context, "currentdb", context.conf["dbname"])
wrappers.expect_exact(context, "{0}> ".format(db_name), timeout=5)
context.atprompt = True
@then('we see help output')
@then("we see help output")
def step_see_help(context):
for expected_line in context.fixture_data['help_commands.txt']:
for expected_line in context.fixture_data["help_commands.txt"]:
wrappers.expect_exact(context, expected_line, timeout=2)
@then('we see database created')
@then("we see database created")
def step_see_db_created(context):
"""
Wait to see create database output.
"""
wrappers.expect_pager(context, 'CREATE DATABASE\r\n', timeout=5)
wrappers.expect_pager(context, "CREATE DATABASE\r\n", timeout=5)
@then('we see database dropped')
@then("we see database dropped")
def step_see_db_dropped(context):
"""
Wait to see drop database output.
"""
wrappers.expect_pager(context, 'DROP DATABASE\r\n', timeout=2)
wrappers.expect_pager(context, "DROP DATABASE\r\n", timeout=2)
@then('we see database connected')
@then("we see database connected")
def step_see_db_connected(context):
"""
Wait to see connect output.
"""
wrappers.expect_exact(context, 'You are now connected to database', timeout=2)
wrappers.expect_exact(context, "You are now connected to database", timeout=2)

View File

@ -11,107 +11,110 @@ from textwrap import dedent
import wrappers
@when('we create table')
@when("we create table")
def step_create_table(context):
"""
Send create table.
"""
context.cli.sendline('create table a(x text);')
context.cli.sendline("create table a(x text);")
@when('we insert into table')
@when("we insert into table")
def step_insert_into_table(context):
"""
Send insert into table.
"""
context.cli.sendline('''insert into a(x) values('xxx');''')
context.cli.sendline("""insert into a(x) values('xxx');""")
@when('we update table')
@when("we update table")
def step_update_table(context):
"""
Send update table.
"""
context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''')
context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""")
@when('we select from table')
@when("we select from table")
def step_select_from_table(context):
"""
Send select from table.
"""
context.cli.sendline('select * from a;')
context.cli.sendline("select * from a;")
@when('we delete from table')
@when("we delete from table")
def step_delete_from_table(context):
"""
Send delete from table.
"""
context.cli.sendline('''delete from a where x = 'yyy';''')
context.cli.sendline("""delete from a where x = 'yyy';""")
@when('we drop table')
@when("we drop table")
def step_drop_table(context):
"""
Send drop table.
"""
context.cli.sendline('drop table a;')
context.cli.sendline("drop table a;")
@then('we see table created')
@then("we see table created")
def step_see_table_created(context):
"""
Wait to see create table output.
"""
wrappers.expect_pager(context, 'CREATE TABLE\r\n', timeout=2)
wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2)
@then('we see record inserted')
@then("we see record inserted")
def step_see_record_inserted(context):
"""
Wait to see insert output.
"""
wrappers.expect_pager(context, 'INSERT 0 1\r\n', timeout=2)
wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2)
@then('we see record updated')
@then("we see record updated")
def step_see_record_updated(context):
"""
Wait to see update output.
"""
wrappers.expect_pager(context, 'UPDATE 1\r\n', timeout=2)
wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2)
@then('we see data selected')
@then("we see data selected")
def step_see_data_selected(context):
"""
Wait to see select output.
"""
wrappers.expect_pager(
context,
dedent('''\
dedent(
"""\
+-----+\r
| x |\r
|-----|\r
| yyy |\r
+-----+\r
SELECT 1\r
'''),
timeout=1)
"""
),
timeout=1,
)
@then('we see record deleted')
@then("we see record deleted")
def step_see_data_deleted(context):
"""
Wait to see delete output.
"""
wrappers.expect_pager(context, 'DELETE 1\r\n', timeout=2)
wrappers.expect_pager(context, "DELETE 1\r\n", timeout=2)
@then('we see table dropped')
@then("we see table dropped")
def step_see_table_dropped(context):
"""
Wait to see drop output.
"""
wrappers.expect_pager(context, 'DROP TABLE\r\n', timeout=2)
wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2)

View File

@ -12,53 +12,61 @@ from textwrap import dedent
import wrappers
@when('we prepare the test data')
@when("we prepare the test data")
def step_prepare_data(context):
"""Create table, insert a record."""
context.cli.sendline('drop table if exists a;')
context.cli.sendline("drop table if exists a;")
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
context.cli.sendline('y')
context,
"You're about to run a destructive command.\r\nDo you want to proceed? (y/n):",
timeout=2,
)
context.cli.sendline("y")
wrappers.wait_prompt(context)
context.cli.sendline(
'create table a(x integer, y real, z numeric(10, 4));')
wrappers.expect_pager(context, 'CREATE TABLE\r\n', timeout=2)
context.cli.sendline('''insert into a(x, y, z) values(1, 1.0, 1.0);''')
wrappers.expect_pager(context, 'INSERT 0 1\r\n', timeout=2)
context.cli.sendline("create table a(x integer, y real, z numeric(10, 4));")
wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2)
context.cli.sendline("""insert into a(x, y, z) values(1, 1.0, 1.0);""")
wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2)
@when('we set expanded {mode}')
@when("we set expanded {mode}")
def step_set_expanded(context, mode):
"""Set expanded to mode."""
context.cli.sendline('\\' + 'x {}'.format(mode))
wrappers.expect_exact(context, 'Expanded display is', timeout=2)
context.cli.sendline("\\" + "x {}".format(mode))
wrappers.expect_exact(context, "Expanded display is", timeout=2)
wrappers.wait_prompt(context)
@then('we see {which} data selected')
@then("we see {which} data selected")
def step_see_data(context, which):
"""Select data from expanded test table."""
if which == 'expanded':
if which == "expanded":
wrappers.expect_pager(
context,
dedent('''\
dedent(
"""\
-[ RECORD 1 ]-------------------------\r
x | 1\r
y | 1.0\r
z | 1.0000\r
SELECT 1\r
'''),
timeout=1)
"""
),
timeout=1,
)
else:
wrappers.expect_pager(
context,
dedent('''\
dedent(
"""\
+-----+-----+--------+\r
| x | y | z |\r
|-----+-----+--------|\r
| 1 | 1.0 | 1.0000 |\r
+-----+-----+--------+\r
SELECT 1\r
'''),
timeout=1)
"""
),
timeout=1,
)

View File

@ -7,77 +7,76 @@ from behave import when, then
import wrappers
@when('we start external editor providing a file name')
@when("we start external editor providing a file name")
def step_edit_file(context):
"""Edit file with external editor."""
context.editor_file_name = os.path.join(
context.package_root, 'test_file_{0}.sql'.format(context.conf['vi']))
context.package_root, "test_file_{0}.sql".format(context.conf["vi"])
)
if os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.cli.sendline('\e {0}'.format(
os.path.basename(context.editor_file_name)))
context.cli.sendline("\e {0}".format(os.path.basename(context.editor_file_name)))
wrappers.expect_exact(
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
wrappers.expect_exact(context, ':', timeout=2)
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2
)
wrappers.expect_exact(context, ":", timeout=2)
@when('we type sql in the editor')
@when("we type sql in the editor")
def step_edit_type_sql(context):
context.cli.sendline('i')
context.cli.sendline('select * from abc')
context.cli.sendline('.')
wrappers.expect_exact(context, ':', timeout=2)
context.cli.sendline("i")
context.cli.sendline("select * from abc")
context.cli.sendline(".")
wrappers.expect_exact(context, ":", timeout=2)
@when('we exit the editor')
@when("we exit the editor")
def step_edit_quit(context):
context.cli.sendline('x')
context.cli.sendline("x")
wrappers.expect_exact(context, "written", timeout=2)
@then('we see the sql in prompt')
@then("we see the sql in prompt")
def step_edit_done_sql(context):
for match in 'select * from abc'.split(' '):
for match in "select * from abc".split(" "):
wrappers.expect_exact(context, match, timeout=1)
# Cleanup the command line.
context.cli.sendcontrol('c')
context.cli.sendcontrol("c")
# Cleanup the edited file.
if context.editor_file_name and os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.atprompt = True
@when(u'we tee output')
@when("we tee output")
def step_tee_ouptut(context):
context.tee_file_name = os.path.join(
context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi']))
context.package_root, "tee_file_{0}.sql".format(context.conf["vi"])
)
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.cli.sendline('\o {0}'.format(
os.path.basename(context.tee_file_name)))
wrappers.expect_exact(
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
context.cli.sendline("\o {0}".format(os.path.basename(context.tee_file_name)))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
wrappers.expect_exact(context, "Writing to file", timeout=5)
wrappers.expect_exact(
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
wrappers.expect_exact(context, "Time", timeout=5)
@when(u'we query "select 123456"')
@when('we query "select 123456"')
def step_query_select_123456(context):
context.cli.sendline('select 123456')
context.cli.sendline("select 123456")
@when(u'we stop teeing output')
@when("we stop teeing output")
def step_notee_output(context):
context.cli.sendline('\o')
context.cli.sendline("\o")
wrappers.expect_exact(context, "Time", timeout=5)
@then(u'we see 123456 in tee output')
@then("we see 123456 in tee output")
def step_see_123456_in_ouput(context):
with open(context.tee_file_name) as f:
assert '123456' in f.read()
assert "123456" in f.read()
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.atprompt = True

View File

@ -10,50 +10,50 @@ from behave import when, then
import wrappers
@when('we save a named query')
@when("we save a named query")
def step_save_named_query(context):
"""
Send \ns command
"""
context.cli.sendline('\\ns foo SELECT 12345')
context.cli.sendline("\\ns foo SELECT 12345")
@when('we use a named query')
@when("we use a named query")
def step_use_named_query(context):
"""
Send \n command
"""
context.cli.sendline('\\n foo')
context.cli.sendline("\\n foo")
@when('we delete a named query')
@when("we delete a named query")
def step_delete_named_query(context):
"""
Send \nd command
"""
context.cli.sendline('\\nd foo')
context.cli.sendline("\\nd foo")
@then('we see the named query saved')
@then("we see the named query saved")
def step_see_named_query_saved(context):
"""
Wait to see query saved.
"""
wrappers.expect_exact(context, 'Saved.', timeout=2)
wrappers.expect_exact(context, "Saved.", timeout=2)
@then('we see the named query executed')
@then("we see the named query executed")
def step_see_named_query_executed(context):
"""
Wait to see select output.
"""
wrappers.expect_exact(context, '12345', timeout=1)
wrappers.expect_exact(context, 'SELECT 1', timeout=1)
wrappers.expect_exact(context, "12345", timeout=1)
wrappers.expect_exact(context, "SELECT 1", timeout=1)
@then('we see the named query deleted')
@then("we see the named query deleted")
def step_see_named_query_deleted(context):
"""
Wait to see query deleted.
"""
wrappers.expect_pager(context, 'foo: Deleted\r\n', timeout=1)
wrappers.expect_pager(context, "foo: Deleted\r\n", timeout=1)

View File

@ -10,18 +10,19 @@ from behave import when, then
import wrappers
@when('we refresh completions')
@when("we refresh completions")
def step_refresh_completions(context):
"""
Send refresh command.
"""
context.cli.sendline('\\refresh')
context.cli.sendline("\\refresh")
@then('we see completions refresh started')
@then("we see completions refresh started")
def step_see_refresh_started(context):
"""
Wait to see refresh output.
"""
wrappers.expect_pager(
context, 'Auto-completion refresh started in the background.\r\n', timeout=2)
context, "Auto-completion refresh started in the background.\r\n", timeout=2
)

View File

@ -20,10 +20,10 @@ def expect_exact(context, expected, timeout):
timedout = True
if timedout:
# Strip color codes out of the output.
actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?',
'', context.cli.before)
actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before)
raise Exception(
textwrap.dedent('''\
textwrap.dedent(
"""\
Expected:
---
{0!r}
@ -36,35 +36,35 @@ def expect_exact(context, expected, timeout):
---
{2!r}
---
''').format(
expected,
actual,
context.logfile.getvalue()
)
"""
).format(expected, actual, context.logfile.getvalue())
)
def expect_pager(context, expected, timeout):
expect_exact(context, "{0}\r\n{1}{0}\r\n".format(
context.conf['pager_boundary'], expected), timeout=timeout)
expect_exact(
context,
"{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected),
timeout=timeout,
)
def run_cli(context, run_args=None, prompt_check=True, currentdb=None):
"""Run the process using pexpect."""
run_args = run_args or []
cli_cmd = context.conf.get('cli_command')
cli_cmd = context.conf.get("cli_command")
cmd_parts = [cli_cmd] + run_args
cmd = ' '.join(cmd_parts)
cmd = " ".join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO()
context.cli.logfile = context.logfile
context.exit_sent = False
context.currentdb = currentdb or context.conf['dbname']
context.cli.sendline('\pset pager always')
context.currentdb = currentdb or context.conf["dbname"]
context.cli.sendline("\pset pager always")
if prompt_check:
wait_prompt(context)
def wait_prompt(context):
"""Make sure prompt is displayed."""
expect_exact(context, '{0}> '.format(context.conf['dbname']), timeout=5)
expect_exact(context, "{0}> ".format(context.conf["dbname"]), timeout=5)
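
These helpers drive pgcli through pexpect; the same send/expect pattern against a plain Python REPL, as a self-contained sketch (POSIX only, since pexpect allocates a pty; the REPL here is just a stand-in for the configured cli_command):

    import pexpect

    child = pexpect.spawnu("python -i -q")
    child.expect_exact(">>> ", timeout=5)  # like wait_prompt()
    child.sendline("1 + 1")
    child.expect_exact("2", timeout=5)
    child.sendcontrol("d")                 # EOF, as after_scenario quits the cli
    child.expect_exact(pexpect.EOF, timeout=5)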

View File

@ -3,13 +3,13 @@ import sys
def wrappager(boundary):
print(boundary)
print(boundary)
while 1:
buf = sys.stdin.read(2048)
if not buf:
break
sys.stdout.write(buf)
print(boundary)
print(boundary)
if __name__ == "__main__":

View File

@ -11,12 +11,12 @@ import pytest
parametrize = pytest.mark.parametrize
qual = ['if_more_than_one_table', 'always']
no_qual = ['if_more_than_one_table', 'never']
qual = ["if_more_than_one_table", "always"]
no_qual = ["if_more_than_one_table", "never"]
def escape(name):
if not name.islower() or name in ('select', 'localtimestamp'):
if not name.islower() or name in ("select", "localtimestamp"):
return '"' + name + '"'
return name
@ -27,10 +27,7 @@ def completion(display_meta, text, pos=0):
def function(text, pos=0, display=None):
return Completion(
text,
display=display or text,
start_position=pos,
display_meta='function'
text, display=display or text, start_position=pos, display_meta="function"
)
@ -49,21 +46,20 @@ def result_set(completer, text, position=None):
# def schema(text, pos=0):
# return completion('schema', text, pos)
# and so on
schema = partial(completion, 'schema')
table = partial(completion, 'table')
view = partial(completion, 'view')
column = partial(completion, 'column')
keyword = partial(completion, 'keyword')
datatype = partial(completion, 'datatype')
alias = partial(completion, 'table alias')
name_join = partial(completion, 'name join')
fk_join = partial(completion, 'fk join')
join = partial(completion, 'join')
schema = partial(completion, "schema")
table = partial(completion, "table")
view = partial(completion, "view")
column = partial(completion, "column")
keyword = partial(completion, "keyword")
datatype = partial(completion, "datatype")
alias = partial(completion, "table alias")
name_join = partial(completion, "name join")
fk_join = partial(completion, "fk join")
join = partial(completion, "join")
def wildcard_expansion(cols, pos=-1):
return Completion(
cols, start_position=pos, display_meta='columns', display='*')
return Completion(cols, start_position=pos, display_meta="columns", display="*")
class MetaData(object):
@ -80,78 +76,88 @@ class MetaData(object):
return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()]
def specials(self, pos=0):
return [Completion(text=k, start_position=pos, display_meta=v.description) for k, v in iteritems(self.completer.pgspecial.commands)]
return [
Completion(text=k, start_position=pos, display_meta=v.description)
for k, v in iteritems(self.completer.pgspecial.commands)
]
def columns(self, tbl, parent='public', typ='tables', pos=0):
if typ == 'functions':
def columns(self, tbl, parent="public", typ="tables", pos=0):
if typ == "functions":
fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0]
cols = fun[1]
else:
cols = self.metadata[typ][parent][tbl]
return [column(escape(col), pos) for col in cols]
def datatypes(self, parent='public', pos=0):
def datatypes(self, parent="public", pos=0):
return [
datatype(escape(x), pos)
for x in self.metadata.get('datatypes', {}).get(parent, [])]
for x in self.metadata.get("datatypes", {}).get(parent, [])
]
def tables(self, parent='public', pos=0):
def tables(self, parent="public", pos=0):
return [
table(escape(x), pos)
for x in self.metadata.get('tables', {}).get(parent, [])]
for x in self.metadata.get("tables", {}).get(parent, [])
]
def views(self, parent='public', pos=0):
def views(self, parent="public", pos=0):
return [
view(escape(x), pos)
for x in self.metadata.get('views', {}).get(parent, [])]
view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, [])
]
def functions(self, parent='public', pos=0):
def functions(self, parent="public", pos=0):
return [
function(
escape(x[0]) + '(' + ', '.join(
arg_name + ' := '
escape(x[0])
+ "("
+ ", ".join(
arg_name + " := "
for (arg_name, arg_mode) in zip(x[1], x[3])
if arg_mode in ('b', 'i')
) + ')',
if arg_mode in ("b", "i")
)
+ ")",
pos,
escape(x[0]) + '(' + ', '.join(
escape(x[0])
+ "("
+ ", ".join(
arg_name
for (arg_name, arg_mode) in zip(x[1], x[3])
if arg_mode in ('b', 'i')
) + ')'
if arg_mode in ("b", "i")
)
+ ")",
)
for x in self.metadata.get('functions', {}).get(parent, [])
for x in self.metadata.get("functions", {}).get(parent, [])
]
def schemas(self, pos=0):
schemas = set(sch for schs in self.metadata.values() for sch in schs)
return [schema(escape(s), pos=pos) for s in schemas]
def functions_and_keywords(self, parent='public', pos=0):
def functions_and_keywords(self, parent="public", pos=0):
return (
self.functions(parent, pos) + self.builtin_functions(pos) +
self.keywords(pos)
self.functions(parent, pos)
+ self.builtin_functions(pos)
+ self.keywords(pos)
)
# Note that the filtering parameters here only apply to the columns
def columns_functions_and_keywords(
self, tbl, parent='public', typ='tables', pos=0
):
return (
self.functions_and_keywords(pos=pos) +
self.columns(tbl, parent, typ, pos)
def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0):
return self.functions_and_keywords(pos=pos) + self.columns(
tbl, parent, typ, pos
)
def from_clause_items(self, parent='public', pos=0):
def from_clause_items(self, parent="public", pos=0):
return (
self.functions(parent, pos) + self.views(parent, pos) +
self.tables(parent, pos)
self.functions(parent, pos)
+ self.views(parent, pos)
+ self.tables(parent, pos)
)
def schemas_and_from_clause_items(self, parent='public', pos=0):
def schemas_and_from_clause_items(self, parent="public", pos=0):
return self.from_clause_items(parent, pos) + self.schemas(pos)
def types(self, parent='public', pos=0):
def types(self, parent="public", pos=0):
return self.datatypes(parent, pos) + self.tables(parent, pos)
@property
@ -170,85 +176,83 @@ class MetaData(object):
casing, without `search_path` filtering of objects, with table
aliasing, and without column qualification.
"""
def _cfg(_casing, filtr, aliasing, qualify):
cfg = {'settings': {}}
cfg = {"settings": {}}
if _casing:
cfg['casing'] = casing
cfg['settings']['search_path_filter'] = filtr
cfg['settings']['generate_aliases'] = aliasing
cfg['settings']['qualify_columns'] = qualify
cfg["casing"] = casing
cfg["settings"]["search_path_filter"] = filtr
cfg["settings"]["generate_aliases"] = aliasing
cfg["settings"]["qualify_columns"] = qualify
return cfg
def _cfgs(casing, filtr, aliasing, qualify):
casings = [True, False] if casing is None else [casing]
filtrs = [True, False] if filtr is None else [filtr]
aliases = [True, False] if aliasing is None else [aliasing]
qualifys = qualify or ['always', 'if_more_than_one_table', 'never']
return [
_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)
]
qualifys = qualify or ["always", "if_more_than_one_table", "never"]
return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)]
def completers(casing=None, filtr=None, aliasing=None, qualify=None):
get_comp = self.get_completer
return [
get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)
]
return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)]
return completers
def _make_col(self, sch, tbl, col):
defaults = self.metadata.get('defaults', {}).get(sch, {})
return (sch, tbl, col, 'text', (tbl, col) in defaults, defaults.get((tbl, col)))
defaults = self.metadata.get("defaults", {}).get(sch, {})
return (sch, tbl, col, "text", (tbl, col) in defaults, defaults.get((tbl, col)))
def get_completer(self, settings=None, casing=None):
metadata = self.metadata
from pgcli.pgcompleter import PGCompleter
from pgspecial import PGSpecial
comp = PGCompleter(smart_completion=True,
settings=settings, pgspecial=PGSpecial())
comp = PGCompleter(
smart_completion=True, settings=settings, pgspecial=PGSpecial()
)
schemata, tables, tbl_cols, views, view_cols = [], [], [], [], []
for sch, tbls in metadata['tables'].items():
for sch, tbls in metadata["tables"].items():
schemata.append(sch)
for tbl, cols in tbls.items():
tables.append((sch, tbl))
# Let all columns be text columns
tbl_cols.extend([self._make_col(sch, tbl, col)
for col in cols])
tbl_cols.extend([self._make_col(sch, tbl, col) for col in cols])
for sch, tbls in metadata.get('views', {}).items():
for sch, tbls in metadata.get("views", {}).items():
for tbl, cols in tbls.items():
views.append((sch, tbl))
# Let all columns be text columns
view_cols.extend([self._make_col(sch, tbl, col)
for col in cols])
view_cols.extend([self._make_col(sch, tbl, col) for col in cols])
functions = [
FunctionMetadata(sch, *func_meta, arg_defaults=None)
for sch, funcs in metadata['functions'].items()
for func_meta in funcs]
for sch, funcs in metadata["functions"].items()
for func_meta in funcs
]
datatypes = [
(sch, typ)
for sch, datatypes in metadata['datatypes'].items()
for typ in datatypes]
for sch, datatypes in metadata["datatypes"].items()
for typ in datatypes
]
foreignkeys = [
ForeignKey(*fk) for fks in metadata['foreignkeys'].values()
for fk in fks]
ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks
]
comp.extend_schemata(schemata)
comp.extend_relations(tables, kind='tables')
comp.extend_relations(views, kind='views')
comp.extend_columns(tbl_cols, kind='tables')
comp.extend_columns(view_cols, kind='views')
comp.extend_relations(tables, kind="tables")
comp.extend_relations(views, kind="views")
comp.extend_columns(tbl_cols, kind="tables")
comp.extend_columns(view_cols, kind="views")
comp.extend_functions(functions)
comp.extend_datatypes(datatypes)
comp.extend_foreignkeys(foreignkeys)
comp.set_search_path(['public'])
comp.set_search_path(["public"])
comp.extend_casing(casing or [])
return comp
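
The schema/table/view helpers above are plain functools.partial applications that pre-bind display_meta; stripped down:

    from functools import partial

    def completion(display_meta, text, pos=0):
        return (text, pos, display_meta)  # simplified stand-in for Completion

    table = partial(completion, "table")  # table(...) == completion("table", ...)
    print(table("abc"))      # ('abc', 0, 'table')
    print(table("abc", -3))  # ('abc', -3, 'table')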

View File

@ -1,8 +1,10 @@
import pytest
from sqlparse import parse
from pgcli.packages.parseutils.ctes import (
token_start_pos, extract_ctes,
extract_column_names as _extract_column_names)
token_start_pos,
extract_ctes,
extract_column_names as _extract_column_names,
)
def extract_column_names(sql):
@ -11,109 +13,125 @@ def extract_column_names(sql):
def test_token_str_pos():
sql = 'SELECT * FROM xxx'
sql = "SELECT * FROM xxx"
p = parse(sql)[0]
idx = p.token_index(p.tokens[-1])
assert token_start_pos(p.tokens, idx) == len('SELECT * FROM ')
assert token_start_pos(p.tokens, idx) == len("SELECT * FROM ")
sql = 'SELECT * FROM \nxxx'
sql = "SELECT * FROM \nxxx"
p = parse(sql)[0]
idx = p.token_index(p.tokens[-1])
assert token_start_pos(p.tokens, idx) == len('SELECT * FROM \n')
assert token_start_pos(p.tokens, idx) == len("SELECT * FROM \n")
def test_single_column_name_extraction():
sql = 'SELECT abc FROM xxx'
assert extract_column_names(sql) == ('abc',)
sql = "SELECT abc FROM xxx"
assert extract_column_names(sql) == ("abc",)
def test_aliased_single_column_name_extraction():
sql = 'SELECT abc def FROM xxx'
assert extract_column_names(sql) == ('def',)
sql = "SELECT abc def FROM xxx"
assert extract_column_names(sql) == ("def",)
def test_aliased_expression_name_extraction():
sql = 'SELECT 99 abc FROM xxx'
assert extract_column_names(sql) == ('abc',)
sql = "SELECT 99 abc FROM xxx"
assert extract_column_names(sql) == ("abc",)
def test_multiple_column_name_extraction():
sql = 'SELECT abc, def FROM xxx'
assert extract_column_names(sql) == ('abc', 'def')
sql = "SELECT abc, def FROM xxx"
assert extract_column_names(sql) == ("abc", "def")
def test_missing_column_name_handled_gracefully():
sql = 'SELECT abc, 99 FROM xxx'
assert extract_column_names(sql) == ('abc',)
sql = "SELECT abc, 99 FROM xxx"
assert extract_column_names(sql) == ("abc",)
sql = 'SELECT abc, 99, def FROM xxx'
assert extract_column_names(sql) == ('abc', 'def')
sql = "SELECT abc, 99, def FROM xxx"
assert extract_column_names(sql) == ("abc", "def")
def test_aliased_multiple_column_name_extraction():
sql = 'SELECT abc def, ghi jkl FROM xxx'
assert extract_column_names(sql) == ('def', 'jkl')
sql = "SELECT abc def, ghi jkl FROM xxx"
assert extract_column_names(sql) == ("def", "jkl")
def test_table_qualified_column_name_extraction():
sql = 'SELECT abc.def, ghi.jkl FROM xxx'
assert extract_column_names(sql) == ('def', 'jkl')
sql = "SELECT abc.def, ghi.jkl FROM xxx"
assert extract_column_names(sql) == ("def", "jkl")
@pytest.mark.parametrize('sql', [
'INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y',
'DELETE FROM foo WHERE x > y RETURNING x, y',
'UPDATE foo SET x = 9 RETURNING x, y',
])
@pytest.mark.parametrize(
"sql",
[
"INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y",
"DELETE FROM foo WHERE x > y RETURNING x, y",
"UPDATE foo SET x = 9 RETURNING x, y",
],
)
def test_extract_column_names_from_returning_clause(sql):
assert extract_column_names(sql) == ('x', 'y')
assert extract_column_names(sql) == ("x", "y")
def test_simple_cte_extraction():
sql = 'WITH a AS (SELECT abc FROM xxx) SELECT * FROM a'
start_pos = len('WITH a AS ')
stop_pos = len('WITH a AS (SELECT abc FROM xxx)')
sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a"
start_pos = len("WITH a AS ")
stop_pos = len("WITH a AS (SELECT abc FROM xxx)")
ctes, remainder = extract_ctes(sql)
assert tuple(ctes) == (('a', ('abc',), start_pos, stop_pos),)
assert remainder.strip() == 'SELECT * FROM a'
assert tuple(ctes) == (("a", ("abc",), start_pos, stop_pos),)
assert remainder.strip() == "SELECT * FROM a"
def test_cte_extraction_around_comments():
sql = '''--blah blah blah
sql = """--blah blah blah
WITH a AS (SELECT abc def FROM x)
SELECT * FROM a'''
start_pos = len('''--blah blah blah
WITH a AS ''')
stop_pos = len('''--blah blah blah
WITH a AS (SELECT abc def FROM x)''')
SELECT * FROM a"""
start_pos = len(
"""--blah blah blah
WITH a AS """
)
stop_pos = len(
"""--blah blah blah
WITH a AS (SELECT abc def FROM x)"""
)
ctes, remainder = extract_ctes(sql)
assert tuple(ctes) == (('a', ('def',), start_pos, stop_pos),)
assert remainder.strip() == 'SELECT * FROM a'
assert tuple(ctes) == (("a", ("def",), start_pos, stop_pos),)
assert remainder.strip() == "SELECT * FROM a"
def test_multiple_cte_extraction():
sql = '''WITH
sql = """WITH
x AS (SELECT abc, def FROM x),
y AS (SELECT ghi, jkl FROM y)
SELECT * FROM a, b'''
SELECT * FROM a, b"""
start1 = len('''WITH
x AS ''')
start1 = len(
"""WITH
x AS """
)
stop1 = len('''WITH
x AS (SELECT abc, def FROM x)''')
stop1 = len(
"""WITH
x AS (SELECT abc, def FROM x)"""
)
start2 = len('''WITH
start2 = len(
"""WITH
x AS (SELECT abc, def FROM x),
y AS ''')
y AS """
)
stop2 = len('''WITH
stop2 = len(
"""WITH
x AS (SELECT abc, def FROM x),
y AS (SELECT ghi, jkl FROM y)''')
y AS (SELECT ghi, jkl FROM y)"""
)
ctes, remainder = extract_ctes(sql)
assert tuple(ctes) == (
('x', ('abc', 'def'), start1, stop1),
('y', ('ghi', 'jkl'), start2, stop2))
("x", ("abc", "def"), start1, stop1),
("y", ("ghi", "jkl"), start2, stop2),
)
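
Tying the first of these tests together: extract_ctes returns each CTE's name, column tuple, and character span, plus the trailing query. The expected values below come straight from the assertions above:

    from pgcli.packages.parseutils.ctes import extract_ctes

    sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a"
    ctes, remainder = extract_ctes(sql)
    print(tuple(ctes))        # (('a', ('abc',), 10, 31),)
    print(remainder.strip())  # SELECT * FROM a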

View File

@ -3,16 +3,13 @@ from pgcli.packages.parseutils.meta import FunctionMetadata
def test_function_metadata_eq():
f1 = FunctionMetadata(
's', 'f', ['x'], ['integer'], [
], 'int', False, False, False, False, None
"s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None
)
f2 = FunctionMetadata(
's', 'f', ['x'], ['integer'], [
], 'int', False, False, False, False, None
"s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None
)
f3 = FunctionMetadata(
's', 'g', ['x'], ['integer'], [
], 'int', False, False, False, False, None
"s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None
)
assert f1 == f2
assert f1 != f3

View File

@ -4,104 +4,98 @@ from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote
def test_empty_string():
tables = extract_tables('')
tables = extract_tables("")
assert tables == ()
def test_simple_select_single_table():
tables = extract_tables('select * from abc')
assert tables == ((None, 'abc', None, False),)
tables = extract_tables("select * from abc")
assert tables == ((None, "abc", None, False),)
@pytest.mark.parametrize('sql', [
'select * from "abc"."def"',
'select * from abc."def"',
])
@pytest.mark.parametrize(
"sql", ['select * from "abc"."def"', 'select * from abc."def"']
)
def test_simple_select_single_table_schema_qualified_quoted_table(sql):
tables = extract_tables(sql)
assert tables == (('abc', 'def', '"def"', False),)
assert tables == (("abc", "def", '"def"', False),)
@pytest.mark.parametrize('sql', [
'select * from abc.def',
'select * from "abc".def',
])
@pytest.mark.parametrize("sql", ["select * from abc.def", 'select * from "abc".def'])
def test_simple_select_single_table_schema_qualified(sql):
tables = extract_tables(sql)
assert tables == (('abc', 'def', None, False),)
assert tables == (("abc", "def", None, False),)
def test_simple_select_single_table_double_quoted():
tables = extract_tables('select * from "Abc"')
assert tables == ((None, 'Abc', None, False),)
assert tables == ((None, "Abc", None, False),)
def test_simple_select_multiple_tables():
tables = extract_tables('select * from abc, def')
assert set(tables) == set([(None, 'abc', None, False),
(None, 'def', None, False)])
tables = extract_tables("select * from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
def test_simple_select_multiple_tables_double_quoted():
tables = extract_tables('select * from "Abc", "Def"')
assert set(tables) == set([(None, 'Abc', None, False),
(None, 'Def', None, False)])
assert set(tables) == set([(None, "Abc", None, False), (None, "Def", None, False)])
def test_simple_select_single_table_deouble_quoted_aliased():
tables = extract_tables('select * from "Abc" a')
assert tables == ((None, 'Abc', 'a', False),)
assert tables == ((None, "Abc", "a", False),)
def test_simple_select_multiple_tables_deouble_quoted_aliased():
tables = extract_tables('select * from "Abc" a, "Def" d')
assert set(tables) == set([(None, 'Abc', 'a', False),
(None, 'Def', 'd', False)])
assert set(tables) == set([(None, "Abc", "a", False), (None, "Def", "d", False)])
def test_simple_select_multiple_tables_schema_qualified():
tables = extract_tables('select * from abc.def, ghi.jkl')
assert set(tables) == set([('abc', 'def', None, False),
('ghi', 'jkl', None, False)])
tables = extract_tables("select * from abc.def, ghi.jkl")
assert set(tables) == set(
[("abc", "def", None, False), ("ghi", "jkl", None, False)]
)
def test_simple_select_with_cols_single_table():
tables = extract_tables('select a,b from abc')
assert tables == ((None, 'abc', None, False),)
tables = extract_tables("select a,b from abc")
assert tables == ((None, "abc", None, False),)
def test_simple_select_with_cols_single_table_schema_qualified():
tables = extract_tables('select a,b from abc.def')
assert tables == (('abc', 'def', None, False),)
tables = extract_tables("select a,b from abc.def")
assert tables == (("abc", "def", None, False),)
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables('select a,b from abc, def')
assert set(tables) == set([(None, 'abc', None, False),
(None, 'def', None, False)])
tables = extract_tables("select a,b from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
def test_simple_select_with_cols_multiple_qualified_tables():
tables = extract_tables('select a,b from abc.def, def.ghi')
assert set(tables) == set([('abc', 'def', None, False),
('def', 'ghi', None, False)])
tables = extract_tables("select a,b from abc.def, def.ghi")
assert set(tables) == set(
[("abc", "def", None, False), ("def", "ghi", None, False)]
)
def test_select_with_hanging_comma_single_table():
tables = extract_tables('select a, from abc')
assert tables == ((None, 'abc', None, False),)
tables = extract_tables("select a, from abc")
assert tables == ((None, "abc", None, False),)
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables('select a, from abc, def')
assert set(tables) == set([(None, 'abc', None, False),
(None, 'def', None, False)])
tables = extract_tables("select a, from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
def test_select_with_hanging_period_multiple_tables():
tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2')
assert set(tables) == set([(None, 'tabl1', 't1', False),
(None, 'tabl2', 't2', False)])
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
assert set(tables) == set(
[(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)]
)
def test_simple_insert_single_table():
@ -111,157 +105,165 @@ def test_simple_insert_single_table():
# AND mistakenly identifies the field list as
# assert tables == ((None, 'abc', 'abc', False),)
assert tables == ((None, 'abc', 'abc', False),)
assert tables == ((None, "abc", "abc", False),)
@pytest.mark.xfail
def test_simple_insert_single_table_schema_qualified():
tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
assert tables == (('abc', 'def', None, False),)
assert tables == (("abc", "def", None, False),)
def test_simple_update_table_no_schema():
tables = extract_tables('update abc set id = 1')
assert tables == ((None, 'abc', None, False),)
tables = extract_tables("update abc set id = 1")
assert tables == ((None, "abc", None, False),)
def test_simple_update_table_with_schema():
tables = extract_tables('update abc.def set id = 1')
assert tables == (('abc', 'def', None, False),)
tables = extract_tables("update abc.def set id = 1")
assert tables == (("abc", "def", None, False),)
@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER'])
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
def test_join_table(join_type):
sql = 'SELECT * FROM abc a {0} JOIN def d ON a.id = d.num'.format(join_type)
sql = "SELECT * FROM abc a {0} JOIN def d ON a.id = d.num".format(join_type)
tables = extract_tables(sql)
assert set(tables) == set([(None, 'abc', 'a', False),
(None, 'def', 'd', False)])
assert set(tables) == set([(None, "abc", "a", False), (None, "def", "d", False)])
def test_join_table_schema_qualified():
tables = extract_tables('SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
assert set(tables) == set([('abc', 'def', 'x', False),
('ghi', 'jkl', 'y', False)])
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
assert set(tables) == set([("abc", "def", "x", False), ("ghi", "jkl", "y", False)])
def test_incomplete_join_clause():
sql = '''select a.x, b.y
sql = """select a.x, b.y
from abc a join bcd b
on a.id = '''
on a.id = """
tables = extract_tables(sql)
assert tables == ((None, 'abc', 'a', False),
(None, 'bcd', 'b', False))
assert tables == ((None, "abc", "a", False), (None, "bcd", "b", False))
def test_join_as_table():
tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5')
assert tables == ((None, 'my_table', 'm', False),)
tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
assert tables == ((None, "my_table", "m", False),)
def test_multiple_joins():
sql = '''select * from t1
sql = """select * from t1
inner join t2 ON
t1.id = t2.t1_id
inner join t3 ON
t2.id = t3.'''
t2.id = t3."""
tables = extract_tables(sql)
assert tables == (
(None, 't1', None, False),
(None, 't2', None, False),
(None, 't3', None, False))
(None, "t1", None, False),
(None, "t2", None, False),
(None, "t3", None, False),
)
def test_subselect_tables():
sql = 'SELECT * FROM (SELECT FROM abc'
sql = "SELECT * FROM (SELECT FROM abc"
tables = extract_tables(sql)
assert tables == ((None, 'abc', None, False),)
assert tables == ((None, "abc", None, False),)
@pytest.mark.parametrize('text', ['SELECT * FROM foo.', 'SELECT 123 AS foo'])
@pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"])
def test_extract_no_tables(text):
tables = extract_tables(text)
assert tables == tuple()
@pytest.mark.parametrize('arg_list', ['', 'arg1', 'arg1, arg2, arg3'])
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_function_as_table(arg_list):
tables = extract_tables('SELECT * FROM foo({0})'.format(arg_list))
assert tables == ((None, 'foo', None, True),)
tables = extract_tables("SELECT * FROM foo({0})".format(arg_list))
assert tables == ((None, "foo", None, True),)
@pytest.mark.parametrize('arg_list', ['', 'arg1', 'arg1, arg2, arg3'])
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_schema_qualified_function_as_table(arg_list):
tables = extract_tables('SELECT * FROM foo.bar({0})'.format(arg_list))
assert tables == (('foo', 'bar', None, True),)
tables = extract_tables("SELECT * FROM foo.bar({0})".format(arg_list))
assert tables == (("foo", "bar", None, True),)
@pytest.mark.parametrize('arg_list', ['', 'arg1', 'arg1, arg2, arg3'])
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_aliased_function_as_table(arg_list):
tables = extract_tables('SELECT * FROM foo({0}) bar'.format(arg_list))
assert tables == ((None, 'foo', 'bar', True),)
tables = extract_tables("SELECT * FROM foo({0}) bar".format(arg_list))
assert tables == ((None, "foo", "bar", True),)
def test_simple_table_and_function():
tables = extract_tables('SELECT * FROM foo JOIN bar()')
assert set(tables) == set([(None, 'foo', None, False),
(None, 'bar', None, True)])
tables = extract_tables("SELECT * FROM foo JOIN bar()")
assert set(tables) == set([(None, "foo", None, False), (None, "bar", None, True)])
def test_complex_table_and_function():
tables = extract_tables('''SELECT * FROM foo.bar baz
JOIN bar.qux(x, y, z) quux''')
assert set(tables) == set([('foo', 'bar', 'baz', False),
('bar', 'qux', 'quux', True)])
tables = extract_tables(
"""SELECT * FROM foo.bar baz
JOIN bar.qux(x, y, z) quux"""
)
assert set(tables) == set(
[("foo", "bar", "baz", False), ("bar", "qux", "quux", True)]
)
def test_find_prev_keyword_using():
q = 'select * from tbl1 inner join tbl2 using (col1, '
q = "select * from tbl1 inner join tbl2 using (col1, "
kw, q2 = find_prev_keyword(q)
assert kw.value == '(' and q2 == 'select * from tbl1 inner join tbl2 using ('
assert kw.value == "(" and q2 == "select * from tbl1 inner join tbl2 using ("
@pytest.mark.parametrize('sql', [
'select * from foo where bar',
'select * from foo where bar = 1 and baz or ',
'select * from foo where bar = 1 and baz between qux and ',
])
@pytest.mark.parametrize(
"sql",
[
"select * from foo where bar",
"select * from foo where bar = 1 and baz or ",
"select * from foo where bar = 1 and baz between qux and ",
],
)
def test_find_prev_keyword_where(sql):
kw, stripped = find_prev_keyword(sql)
assert kw.value == 'where' and stripped == 'select * from foo where'
assert kw.value == "where" and stripped == "select * from foo where"
@pytest.mark.parametrize('sql', [
'create table foo (bar int, baz ',
'select * from foo() as bar (baz '
])
@pytest.mark.parametrize(
"sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "]
)
def test_find_prev_keyword_open_parens(sql):
kw, _ = find_prev_keyword(sql)
assert kw.value == '('
assert kw.value == "("
@pytest.mark.parametrize('sql', [
'',
'$$ foo $$',
"$$ 'foo' $$",
'$$ "foo" $$',
'$$ $a$ $$',
'$a$ $$ $a$',
'foo bar $$ baz $$',
])
@pytest.mark.parametrize(
"sql",
[
"",
"$$ foo $$",
"$$ 'foo' $$",
'$$ "foo" $$',
"$$ $a$ $$",
"$a$ $$ $a$",
"foo bar $$ baz $$",
],
)
def test_is_open_quote__closed(sql):
assert not is_open_quote(sql)
@pytest.mark.parametrize('sql', [
'$$',
';;;$$',
'foo $$ bar $$; foo $$',
'$$ foo $a$',
"foo 'bar baz",
"$a$ foo ",
'$$ "foo" ',
'$$ $a$ ',
'foo bar $$ baz',
])
@pytest.mark.parametrize(
"sql",
[
"$$",
";;;$$",
"foo $$ bar $$; foo $$",
"$$ foo $a$",
"foo 'bar baz",
"$a$ foo ",
'$$ "foo" ',
"$$ $a$ ",
"foo bar $$ baz",
],
)
def test_is_open_quote__open(sql):
assert is_open_quote(sql)
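The two parameter lists above encode PostgreSQL's quoting rules: anonymous $$...$$ pairs and tagged $tag$...$tag$ pairs must be closed with the same tag, and a single quote opened in ordinary text likewise leaves the statement open. A couple of checks taken straight from the data above:

>>> from pgcli.packages.parseutils.utils import is_open_quote
>>> is_open_quote("$$ foo $$")     # balanced anonymous dollar quote
False
>>> is_open_quote("$a$ foo ")      # tagged quote never closed
True
>>> is_open_quote("foo 'bar baz")  # unterminated single quote
True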

View File

@ -6,6 +6,7 @@ from mock import Mock, patch
@pytest.fixture
def refresher():
from pgcli.completion_refresher import CompletionRefresher
return CompletionRefresher()
@ -17,8 +18,15 @@ def test_ctor(refresher):
"""
assert len(refresher.refreshers) > 0
actual_handlers = list(refresher.refreshers.keys())
expected_handlers = ['schemata', 'tables', 'views',
'types', 'databases', 'casing', 'functions']
expected_handlers = [
"schemata",
"tables",
"views",
"types",
"databases",
"casing",
"functions",
]
assert expected_handlers == actual_handlers
@ -32,14 +40,13 @@ def test_refresh_called_once(refresher):
pgexecute = Mock()
special = Mock()
with patch.object(refresher, '_bg_refresh') as bg_refresh:
with patch.object(refresher, "_bg_refresh") as bg_refresh:
actual = refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual) == 1
assert len(actual[0]) == 4
assert actual[0][3] == 'Auto-completion refresh started in the background.'
bg_refresh.assert_called_with(pgexecute, special, callbacks, None,
None)
assert actual[0][3] == "Auto-completion refresh started in the background."
bg_refresh.assert_called_with(pgexecute, special, callbacks, None, None)
def test_refresh_called_twice(refresher):
@ -62,13 +69,13 @@ def test_refresh_called_twice(refresher):
time.sleep(1) # Wait for the thread to work.
assert len(actual1) == 1
assert len(actual1[0]) == 4
assert actual1[0][3] == 'Auto-completion refresh started in the background.'
assert actual1[0][3] == "Auto-completion refresh started in the background."
actual2 = refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual2) == 1
assert len(actual2[0]) == 4
assert actual2[0][3] == 'Auto-completion refresh restarted.'
assert actual2[0][3] == "Auto-completion refresh restarted."
def test_refresh_with_callbacks(refresher):
@ -82,9 +89,9 @@ def test_refresh_with_callbacks(refresher):
pgexecute.extra_args = {}
special = Mock()
with patch('pgcli.completion_refresher.PGExecute', pgexecute_class):
with patch("pgcli.completion_refresher.PGExecute", pgexecute_class):
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert (callbacks[0].call_count == 1)
assert callbacks[0].call_count == 1

View File

@ -5,6 +5,7 @@ import pytest
@pytest.fixture
def completer():
import pgcli.pgcompleter as pgcompleter
return pgcompleter.PGCompleter()
@ -22,8 +23,8 @@ def test_ranking_ignores_identifier_quotes(completer):
"""
text = 'user'
collection = ['user_action', '"user"']
text = "user"
collection = ["user_action", '"user"']
matches = completer.find_matches(text, collection)
assert len(matches) == 2
@ -42,18 +43,17 @@ def test_ranking_based_on_shortest_match(completer):
"""
text = 'user'
collection = ['api_user', 'user_group']
text = "user"
collection = ["api_user", "user_group"]
matches = completer.find_matches(text, collection)
assert matches[1].priority > matches[0].priority
@pytest.mark.parametrize('collection', [
['user_action', 'user'],
['user_group', 'user'],
['user_group', 'user_action'],
])
@pytest.mark.parametrize(
"collection",
[["user_action", "user"], ["user_group", "user"], ["user_group", "user_action"]],
)
def test_should_break_ties_using_lexical_order(completer, collection):
"""Fuzzy result rank should use lexical order to break ties.
@ -68,7 +68,7 @@ def test_should_break_ties_using_lexical_order(completer, collection):
"""
text = 'user'
text = "user"
matches = completer.find_matches(text, collection)
assert matches[1].priority > matches[0].priority
@ -81,8 +81,8 @@ def test_matching_should_be_case_insensitive(completer):
are still matched.
"""
text = 'foo'
collection = ['Foo', 'FOO', 'fOO']
text = "foo"
collection = ["Foo", "FOO", "fOO"]
matches = completer.find_matches(text, collection)
assert len(matches) == 3

View File

@ -5,24 +5,27 @@ import platform
import mock
import pytest
try:
import setproctitle
except ImportError:
setproctitle = None
from pgcli.main import (
obfuscate_process_password, format_output, PGCli, OutputSettings, COLOR_CODE_REGEX
obfuscate_process_password,
format_output,
PGCli,
OutputSettings,
COLOR_CODE_REGEX,
)
from pgcli.pgexecute import PGExecute
from pgspecial.main import (PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS)
from pgspecial.main import PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS
from utils import dbtest, run
from collections import namedtuple
@pytest.mark.skipif(platform.system() == 'Windows',
reason='Not applicable in windows')
@pytest.mark.skipif(not setproctitle,
reason='setproctitle not available')
@pytest.mark.skipif(platform.system() == "Windows", reason="Not applicable in windows")
@pytest.mark.skipif(not setproctitle, reason="setproctitle not available")
def test_obfuscate_process_password():
original_title = setproctitle.getproctitle()
@ -54,24 +57,25 @@ def test_obfuscate_process_password():
def test_format_output():
settings = OutputSettings(table_format='psql', dcmlfmt='d', floatfmt='g')
results = format_output('Title', [('abc', 'def')], ['head1', 'head2'],
'test status', settings)
settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g")
results = format_output(
"Title", [("abc", "def")], ["head1", "head2"], "test status", settings
)
expected = [
'Title',
'+---------+---------+',
'| head1 | head2 |',
'|---------+---------|',
'| abc | def |',
'+---------+---------+',
'test status'
"Title",
"+---------+---------+",
"| head1 | head2 |",
"|---------+---------|",
"| abc | def |",
"+---------+---------+",
"test status",
]
assert list(results) == expected
@dbtest
def test_format_array_output(executor):
statement = u"""
statement = """
SELECT
array[1, 2, 3]::bigint[] as bigint_array,
'{{1,2},{3,4}}'::numeric[] as nested_numeric_array,
@ -81,20 +85,20 @@ def test_format_array_output(executor):
"""
results = run(executor, statement)
expected = [
'+----------------+------------------------+--------------+',
'| bigint_array | nested_numeric_array | 配列 |',
'|----------------+------------------------+--------------|',
'| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |',
'| {} | <null> | {<null>} |',
'+----------------+------------------------+--------------+',
'SELECT 2'
"+----------------+------------------------+--------------+",
"| bigint_array | nested_numeric_array | 配列 |",
"|----------------+------------------------+--------------|",
"| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |",
"| {} | <null> | {<null>} |",
"+----------------+------------------------+--------------+",
"SELECT 2",
]
assert list(results) == expected
@dbtest
def test_format_array_output_expanded(executor):
statement = u"""
statement = """
SELECT
array[1, 2, 3]::bigint[] as bigint_array,
'{{1,2},{3,4}}'::numeric[] as nested_numeric_array,
@ -104,110 +108,111 @@ def test_format_array_output_expanded(executor):
"""
results = run(executor, statement, expanded=True)
expected = [
'-[ RECORD 1 ]-------------------------',
'bigint_array | {1,2,3}',
'nested_numeric_array | {{1,2},{3,4}}',
'配列 | {å,魚,текст}',
'-[ RECORD 2 ]-------------------------',
'bigint_array | {}',
'nested_numeric_array | <null>',
'配列 | {<null>}',
'SELECT 2'
"-[ RECORD 1 ]-------------------------",
"bigint_array | {1,2,3}",
"nested_numeric_array | {{1,2},{3,4}}",
"配列 | {å,魚,текст}",
"-[ RECORD 2 ]-------------------------",
"bigint_array | {}",
"nested_numeric_array | <null>",
"配列 | {<null>}",
"SELECT 2",
]
assert '\n'.join(results) == '\n'.join(expected)
assert "\n".join(results) == "\n".join(expected)
def test_format_output_auto_expand():
settings = OutputSettings(
table_format='psql', dcmlfmt='d', floatfmt='g', max_width=100)
table_results = format_output('Title', [('abc', 'def')],
['head1', 'head2'], 'test status', settings)
table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100
)
table_results = format_output(
"Title", [("abc", "def")], ["head1", "head2"], "test status", settings
)
table = [
'Title',
'+---------+---------+',
'| head1 | head2 |',
'|---------+---------|',
'| abc | def |',
'+---------+---------+',
'test status'
"Title",
"+---------+---------+",
"| head1 | head2 |",
"|---------+---------|",
"| abc | def |",
"+---------+---------+",
"test status",
]
assert list(table_results) == table
expanded_results = format_output(
'Title',
[('abc', 'def')],
['head1', 'head2'],
'test status',
settings._replace(max_width=1)
"Title",
[("abc", "def")],
["head1", "head2"],
"test status",
settings._replace(max_width=1),
)
expanded = [
'Title',
'-[ RECORD 1 ]-------------------------',
'head1 | abc',
'head2 | def',
'test status'
"Title",
"-[ RECORD 1 ]-------------------------",
"head1 | abc",
"head2 | def",
"test status",
]
assert '\n'.join(expanded_results) == '\n'.join(expanded)
assert "\n".join(expanded_results) == "\n".join(expanded)
termsize = namedtuple('termsize', ['rows', 'columns'])
test_line = '-' * 10
termsize = namedtuple("termsize", ["rows", "columns"])
test_line = "-" * 10
test_data = [
(10, 10, '\n'.join([test_line] * 7)),
(10, 10, '\n'.join([test_line] * 6)),
(10, 10, '\n'.join([test_line] * 5)),
(10, 10, '-' * 11),
(10, 10, '-' * 10),
(10, 10, '-' * 9),
(10, 10, "\n".join([test_line] * 7)),
(10, 10, "\n".join([test_line] * 6)),
(10, 10, "\n".join([test_line] * 5)),
(10, 10, "-" * 11),
(10, 10, "-" * 10),
(10, 10, "-" * 9),
]
# 4 lines are reserved at the bottom of the terminal for pgcli's prompt
use_pager_when_on = [True,
True,
False,
True,
False,
False]
use_pager_when_on = [True, True, False, True, False, False]
# Can be replaced with pytest.param once we can upgrade pytest after Python 3.4 goes EOL
test_ids = ["Output longer than terminal height",
"Output equal to terminal height",
"Output shorter than terminal height",
"Output longer than terminal width",
"Output equal to terminal width",
"Output shorter than terminal width"]
test_ids = [
"Output longer than terminal height",
"Output equal to terminal height",
"Output shorter than terminal height",
"Output longer than terminal width",
"Output equal to terminal width",
"Output shorter than terminal width",
]
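Read together with the reserved-lines comment above, test_data and use_pager_when_on imply the decision rule exercised by the PAGER_LONG_OUTPUT tests below. A sketch inferred from these fixtures (not pgcli's actual implementation; should_use_pager is a hypothetical name):

def should_use_pager(text, term_rows, term_cols, reserved=4):
    # Page when the output plus pgcli's 4 reserved prompt lines no longer
    # fits the terminal height, or when any line overflows its width.
    lines = text.split("\n")
    too_tall = len(lines) + reserved >= term_rows  # 7 and 6 lines of 10 page; 5 do not
    too_wide = any(len(line) > term_cols for line in lines)  # width 11 of 10 pages
    return too_tall or too_wide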
@pytest.fixture
def pset_pager_mocks():
cli = PGCli()
cli.watch_command = None
with mock.patch('pgcli.main.click.echo') as mock_echo, \
mock.patch('pgcli.main.click.echo_via_pager') as mock_echo_via_pager, \
mock.patch.object(cli, 'prompt_app') as mock_app:
with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch(
"pgcli.main.click.echo_via_pager"
) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app:
yield cli, mock_echo, mock_echo_via_pager, mock_app
@pytest.mark.parametrize('term_height,term_width,text', test_data, ids=test_ids)
@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
mock_cli.output.get_size.return_value = termsize(
rows=term_height, columns=term_width)
rows=term_height, columns=term_width
)
with mock.patch.object(cli.pgspecial, 'pager_config', PAGER_OFF):
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF):
cli.echo_via_pager(text)
mock_echo.assert_called()
mock_echo_via_pager.assert_not_called()
@pytest.mark.parametrize('term_height,term_width,text', test_data, ids=test_ids)
@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
mock_cli.output.get_size.return_value = termsize(
rows=term_height, columns=term_width)
rows=term_height, columns=term_width
)
with mock.patch.object(cli.pgspecial, 'pager_config', PAGER_ALWAYS):
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS):
cli.echo_via_pager(text)
mock_echo.assert_not_called()
@ -217,13 +222,16 @@ def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks):
pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)]
@pytest.mark.parametrize('term_height,term_width,text,use_pager', pager_on_test_data, ids=test_ids)
@pytest.mark.parametrize(
"term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids
)
def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
mock_cli.output.get_size.return_value = termsize(
rows=term_height, columns=term_width)
rows=term_height, columns=term_width
)
with mock.patch.object(cli.pgspecial, 'pager_config', PAGER_LONG_OUTPUT):
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT):
cli.echo_via_pager(text)
if use_pager:
@ -234,24 +242,28 @@ def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mock
mock_echo.assert_called()
@pytest.mark.parametrize('text,expected_length', [
(u"22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s", 78),
(u"=\u001b[m=", 2),
(u"-\u001b]23\u0007-", 2),
])
@pytest.mark.parametrize(
"text,expected_length",
[
(
"22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s",
78,
),
("=\u001b[m=", 2),
("-\u001b]23\u0007-", 2),
],
)
def test_color_pattern(text, expected_length, pset_pager_mocks):
cli = pset_pager_mocks[0]
assert len(COLOR_CODE_REGEX.sub('', text)) == expected_length
assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length
@dbtest
def test_i_works(tmpdir, executor):
sqlfile = tmpdir.join("test.sql")
sqlfile.write("SELECT NOW()")
rcfile = str(tmpdir.join("rcfile"))
cli = PGCli(
pgexecute=executor,
pgclirc_file=rcfile,
)
cli = PGCli(pgexecute=executor, pgclirc_file=rcfile)
statement = r"\i {0}".format(sqlfile)
run(executor, statement, pgspecial=cli.pgspecial)
@ -264,58 +276,62 @@ def test_missing_rc_dir(tmpdir):
def test_quoted_db_uri(tmpdir):
with mock.patch.object(PGCli, 'connect') as mock_connect:
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri('postgres://bar%5E:%5Dfoo@baz.com/testdb%5B')
mock_connect.assert_called_with(database='testdb[',
host='baz.com',
user='bar^',
passwd=']foo')
cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B")
mock_connect.assert_called_with(
database="testdb[", host="baz.com", user="bar^", passwd="]foo"
)
def test_ssl_db_uri(tmpdir):
with mock.patch.object(PGCli, 'connect') as mock_connect:
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri(
'postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?'
'sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem')
mock_connect.assert_called_with(database='testdb[',
host='baz.com',
user='bar^',
passwd=']foo',
sslmode='verify-full',
sslcert='my.pem',
sslkey='my-key.pem',
sslrootcert='ca.pem')
"postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?"
"sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem"
)
mock_connect.assert_called_with(
database="testdb[",
host="baz.com",
user="bar^",
passwd="]foo",
sslmode="verify-full",
sslcert="my.pem",
sslkey="my-key.pem",
sslrootcert="ca.pem",
)
def test_port_db_uri(tmpdir):
with mock.patch.object(PGCli, 'connect') as mock_connect:
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri('postgres://bar:foo@baz.com:2543/testdb')
mock_connect.assert_called_with(database='testdb',
host='baz.com',
user='bar',
passwd='foo',
port='2543')
cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb")
mock_connect.assert_called_with(
database="testdb", host="baz.com", user="bar", passwd="foo", port="2543"
)
def test_multihost_db_uri(tmpdir):
with mock.patch.object(PGCli, 'connect') as mock_connect:
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri(
'postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb')
mock_connect.assert_called_with(database='testdb',
host='baz1.com,baz2.com,baz3.com',
user='bar',
passwd='foo',
port='2543,2543,2543')
"postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb"
)
mock_connect.assert_called_with(
database="testdb",
host="baz1.com,baz2.com,baz3.com",
user="bar",
passwd="foo",
port="2543,2543,2543",
)
def test_application_name_db_uri(tmpdir):
with mock.patch.object(PGExecute, '__init__') as mock_pgexecute:
with mock.patch.object(PGExecute, "__init__") as mock_pgexecute:
mock_pgexecute.return_value = None
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri('postgres://bar@baz.com/?application_name=cow')
mock_pgexecute.assert_called_with('bar', 'bar', '', 'baz.com', '', '',
application_name='cow')
cli.connect_uri("postgres://bar@baz.com/?application_name=cow")
mock_pgexecute.assert_called_with(
"bar", "bar", "", "baz.com", "", "", application_name="cow"
)

View File

@ -8,77 +8,97 @@ from utils import completions_to_set
@pytest.fixture
def completer():
import pgcli.pgcompleter as pgcompleter
return pgcompleter.PGCompleter(smart_completion=False)
@pytest.fixture
def complete_event():
from mock import Mock
return Mock()
def test_empty_string_completion(completer, complete_event):
text = ''
text = ""
position = 0
result = completions_to_set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == completions_to_set(
map(Completion, completer.all_completions))
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set(map(Completion, completer.all_completions))
def test_select_keyword_completion(completer, complete_event):
text = 'SEL'
position = len('SEL')
result = completions_to_set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == completions_to_set(
[Completion(text='SELECT', start_position=-3)])
text = "SEL"
position = len("SEL")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set([Completion(text="SELECT", start_position=-3)])
def test_function_name_completion(completer, complete_event):
text = 'SELECT MA'
position = len('SELECT MA')
result = completions_to_set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == completions_to_set([
Completion(text='MATERIALIZED VIEW', start_position=-2),
Completion(text='MAX', start_position=-2),
Completion(text='MAXEXTENTS', start_position=-2)])
text = "SELECT MA"
position = len("SELECT MA")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set(
[
Completion(text="MATERIALIZED VIEW", start_position=-2),
Completion(text="MAX", start_position=-2),
Completion(text="MAXEXTENTS", start_position=-2),
]
)
def test_column_name_completion(completer, complete_event):
text = 'SELECT FROM users'
position = len('SELECT ')
result = completions_to_set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == completions_to_set(
map(Completion, completer.all_completions))
text = "SELECT FROM users"
position = len("SELECT ")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set(map(Completion, completer.all_completions))
def test_alter_well_known_keywords_completion(completer, complete_event):
text = 'ALTER '
text = "ALTER "
position = len(text)
result = completions_to_set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event,
smart_completion=True))
assert result > completions_to_set([
Completion(text="DATABASE", display_meta='keyword'),
Completion(text="TABLE", display_meta='keyword'),
Completion(text="SYSTEM", display_meta='keyword'),
])
assert completions_to_set(
[Completion(text="CREATE", display_meta="keyword")]) not in result
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position),
complete_event,
smart_completion=True,
)
)
assert result > completions_to_set(
[
Completion(text="DATABASE", display_meta="keyword"),
Completion(text="TABLE", display_meta="keyword"),
Completion(text="SYSTEM", display_meta="keyword"),
]
)
assert (
completions_to_set([Completion(text="CREATE", display_meta="keyword")])
not in result
)
def test_special_name_completion(completer, complete_event):
text = '\\'
position = len('\\')
result = completions_to_set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
text = "\\"
position = len("\\")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
# Special commands will NOT be suggested during naive completion mode.
assert result == completions_to_set([])

View File

@ -14,54 +14,77 @@ from pgcli.packages.parseutils.meta import FunctionMetadata
def function_meta_data(
func_name, schema_name='public', arg_names=None, arg_types=None,
arg_modes=None, return_type=None, is_aggregate=False, is_window=False,
is_set_returning=False, is_extension=False, arg_defaults=None
func_name,
schema_name="public",
arg_names=None,
arg_types=None,
arg_modes=None,
return_type=None,
is_aggregate=False,
is_window=False,
is_set_returning=False,
is_extension=False,
arg_defaults=None,
):
return FunctionMetadata(
schema_name, func_name, arg_names, arg_types, arg_modes, return_type,
is_aggregate, is_window, is_set_returning, is_extension, arg_defaults
schema_name,
func_name,
arg_names,
arg_types,
arg_modes,
return_type,
is_aggregate,
is_window,
is_set_returning,
is_extension,
arg_defaults,
)
@dbtest
def test_conn(executor):
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc')''')
assert run(executor, '''select * from test''', join=True) == dedent("""\
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc')""")
assert run(executor, """select * from test""", join=True) == dedent(
"""\
+-----+
| a |
|-----|
| abc |
+-----+
SELECT 1""")
SELECT 1"""
)
@dbtest
def test_copy(executor):
executor_copy = executor.copy()
run(executor_copy, '''create table test(a text)''')
run(executor_copy, '''insert into test values('abc')''')
assert run(executor_copy, '''select * from test''', join=True) == dedent("""\
run(executor_copy, """create table test(a text)""")
run(executor_copy, """insert into test values('abc')""")
assert run(executor_copy, """select * from test""", join=True) == dedent(
"""\
+-----+
| a |
|-----|
| abc |
+-----+
SELECT 1""")
SELECT 1"""
)
@dbtest
def test_bools_are_treated_as_strings(executor):
run(executor, '''create table test(a boolean)''')
run(executor, '''insert into test values(True)''')
assert run(executor, '''select * from test''', join=True) == dedent("""\
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
assert run(executor, """select * from test""", join=True) == dedent(
"""\
+------+
| a |
|------|
| True |
+------+
SELECT 1""")
SELECT 1"""
)
@dbtest
@ -76,27 +99,31 @@ def test_schemata_table_views_and_columns_query(executor):
# schemata
# don't enforce all members of the schemas since they may include postgres
# temporary schemas
assert set(executor.schemata()) >= set([
'public', 'pg_catalog', 'information_schema', 'schema1', 'schema2'])
assert executor.search_path() == ['pg_catalog', 'public']
assert set(executor.schemata()) >= set(
["public", "pg_catalog", "information_schema", "schema1", "schema2"]
)
assert executor.search_path() == ["pg_catalog", "public"]
# tables
assert set(executor.tables()) >= set([
('public', 'a'), ('public', 'b'), ('schema1', 'c')])
assert set(executor.tables()) >= set(
[("public", "a"), ("public", "b"), ("schema1", "c")]
)
assert set(executor.table_columns()) >= set([
('public', 'a', 'x', 'text', False, None),
('public', 'a', 'y', 'text', False, None),
('public', 'b', 'z', 'text', False, None),
('schema1', 'c', 'w', 'text', True, "'meow'::text"),
])
assert set(executor.table_columns()) >= set(
[
("public", "a", "x", "text", False, None),
("public", "a", "y", "text", False, None),
("public", "b", "z", "text", False, None),
("schema1", "c", "w", "text", True, "'meow'::text"),
]
)
# views
assert set(executor.views()) >= set([
('public', 'd')])
assert set(executor.views()) >= set([("public", "d")])
assert set(executor.view_columns()) >= set([
('public', 'd', 'e', 'integer', False, None)])
assert set(executor.view_columns()) >= set(
[("public", "d", "e", "integer", False, None)]
)
@dbtest
@ -104,82 +131,95 @@ def test_foreign_key_query(executor):
run(executor, "create schema schema1")
run(executor, "create schema schema2")
run(executor, "create table schema1.parent(parentid int PRIMARY KEY)")
run(executor, "create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)")
run(
executor,
"create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
)
assert set(executor.foreignkeys()) >= set([
('schema1', 'parent', 'parentid', 'schema2', 'child', 'motherid')])
assert set(executor.foreignkeys()) >= set(
[("schema1", "parent", "parentid", "schema2", "child", "motherid")]
)
@dbtest
def test_functions_query(executor):
run(executor, '''create function func1() returns int
language sql as $$select 1$$''')
run(executor, 'create schema schema1')
run(executor, '''create function schema1.func2() returns int
language sql as $$select 2$$''')
run(
executor,
"""create function func1() returns int
language sql as $$select 1$$""",
)
run(executor, "create schema schema1")
run(
executor,
"""create function schema1.func2() returns int
language sql as $$select 2$$""",
)
run(executor, '''create function func3()
run(
executor,
"""create function func3()
returns table(x int, y int) language sql
as $$select 1, 2 from generate_series(1,5)$$;''')
as $$select 1, 2 from generate_series(1,5)$$;""",
)
run(executor, '''create function func4(x int) returns setof int language sql
as $$select generate_series(1,5)$$;''')
run(
executor,
"""create function func4(x int) returns setof int language sql
as $$select generate_series(1,5)$$;""",
)
funcs = set(executor.functions())
assert funcs >= set([
function_meta_data(
func_name='func1',
return_type='integer'
),
function_meta_data(
func_name='func3',
arg_names=['x', 'y'],
arg_types=['integer', 'integer'],
arg_modes=['t', 't'],
return_type='record',
is_set_returning=True
),
function_meta_data(
schema_name='public',
func_name='func4',
arg_names=('x',),
arg_types=('integer',),
return_type='integer',
is_set_returning=True
),
function_meta_data(
schema_name='schema1',
func_name='func2',
return_type='integer'
),
])
assert funcs >= set(
[
function_meta_data(func_name="func1", return_type="integer"),
function_meta_data(
func_name="func3",
arg_names=["x", "y"],
arg_types=["integer", "integer"],
arg_modes=["t", "t"],
return_type="record",
is_set_returning=True,
),
function_meta_data(
schema_name="public",
func_name="func4",
arg_names=("x",),
arg_types=("integer",),
return_type="integer",
is_set_returning=True,
),
function_meta_data(
schema_name="schema1", func_name="func2", return_type="integer"
),
]
)
@dbtest
def test_datatypes_query(executor):
run(executor, 'create type foo AS (a int, b text)')
run(executor, "create type foo AS (a int, b text)")
types = list(executor.datatypes())
assert types == [('public', 'foo')]
assert types == [("public", "foo")]
@dbtest
def test_database_list(executor):
databases = executor.databases()
assert '_test_db' in databases
assert "_test_db" in databases
@dbtest
def test_invalid_syntax(executor, exception_formatter):
result = run(executor, 'invalid syntax!',
exception_formatter=exception_formatter)
result = run(executor, "invalid syntax!", exception_formatter=exception_formatter)
assert 'syntax error at or near "invalid"' in result[0]
@dbtest
def test_invalid_column_name(executor, exception_formatter):
result = run(executor, 'select invalid command',
exception_formatter=exception_formatter)
result = run(
executor, "select invalid command", exception_formatter=exception_formatter
)
assert 'column "invalid" does not exist' in result[0]
@ -194,14 +234,15 @@ def test_unicode_support_in_output(executor, expanded):
run(executor, u"insert into unicodechars (t) values ('é')")
# See issue #24, this raises an exception without proper handling
assert u'é' in run(executor, "select * from unicodechars",
join=True, expanded=expanded)
assert u"é" in run(
executor, "select * from unicodechars", join=True, expanded=expanded
)
@dbtest
def test_not_is_special(executor, pgspecial):
"""is_special is set to false for database queries."""
query = 'select 1'
query = "select 1"
result = list(executor.run(query, pgspecial=pgspecial))
success, is_special = result[0][5:]
assert success == True
@ -213,22 +254,22 @@ def test_execute_from_file_no_arg(executor, pgspecial):
"""\i without a filename returns an error."""
result = list(executor.run("\i", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert 'missing required argument' in status
assert "missing required argument" in status
assert success == False
assert is_special == True
@dbtest
@patch('pgcli.main.os')
@patch("pgcli.main.os")
def test_execute_from_file_io_error(os, executor, pgspecial):
"""\i with an io_error returns an error."""
# Inject an IOError.
os.path.expanduser.side_effect = IOError('test')
os.path.expanduser.side_effect = IOError("test")
# Check the result.
result = list(executor.run("\i test", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert status == 'test'
assert status == "test"
assert success == False
assert is_special == True
@ -252,9 +293,12 @@ def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
@dbtest
def test_multiple_queries_same_line_syntaxerror(executor, exception_formatter):
result = run(executor, u"select 'fooé'; invalid syntax é",
exception_formatter=exception_formatter)
assert u'fooé' in result[3]
result = run(
executor,
u"select 'fooé'; invalid syntax é",
exception_formatter=exception_formatter,
)
assert u"fooé" in result[3]
assert 'syntax error at or near "invalid"' in result[-1]
@ -265,23 +309,22 @@ def pgspecial():
@dbtest
def test_special_command_help(executor, pgspecial):
result = run(executor, '\\?', pgspecial=pgspecial)[1].split('|')
assert u'Command' in result[1]
assert u'Description' in result[2]
result = run(executor, "\\?", pgspecial=pgspecial)[1].split("|")
assert u"Command" in result[1]
assert u"Description" in result[2]
@dbtest
def test_bytea_field_support_in_output(executor):
run(executor, "create table binarydata(c bytea)")
run(executor,
"insert into binarydata (c) values (decode('DEADBEEF', 'hex'))")
run(executor, "insert into binarydata (c) values (decode('DEADBEEF', 'hex'))")
assert u'\\xdeadbeef' in run(executor, "select * from binarydata", join=True)
assert u"\\xdeadbeef" in run(executor, "select * from binarydata", join=True)
@dbtest
def test_unicode_support_in_unknown_type(executor):
assert u'日本語' in run(executor, u"SELECT '日本語' AS japanese;", join=True)
assert u"日本語" in run(executor, u"SELECT '日本語' AS japanese;", join=True)
@dbtest
@ -289,15 +332,16 @@ def test_unicode_support_in_enum_type(executor):
run(executor, u"CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy', '日本語')")
run(executor, u"CREATE TABLE person (name TEXT, current_mood mood)")
run(executor, u"INSERT INTO person VALUES ('Moe', '日本語')")
assert u'日本語' in run(executor, u"SELECT * FROM person", join=True)
assert u"日本語" in run(executor, u"SELECT * FROM person", join=True)
@requires_json
def test_json_renders_without_u_prefix(executor, expanded):
run(executor, u"create table jsontest(d json)")
run(executor, u"""insert into jsontest (d) values ('{"name": "Éowyn"}')""")
result = run(executor, u"SELECT d FROM jsontest LIMIT 1",
join=True, expanded=expanded)
result = run(
executor, u"SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded
)
assert u'{"name": "Éowyn"}' in result
@ -306,8 +350,9 @@ def test_json_renders_without_u_prefix(executor, expanded):
def test_jsonb_renders_without_u_prefix(executor, expanded):
run(executor, "create table jsonbtest(d jsonb)")
run(executor, u"""insert into jsonbtest (d) values ('{"name": "Éowyn"}')""")
result = run(executor, "SELECT d FROM jsonbtest LIMIT 1",
join=True, expanded=expanded)
result = run(
executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded
)
assert u'{"name": "Éowyn"}' in result
@ -315,45 +360,65 @@ def test_jsonb_renders_without_u_prefix(executor, expanded):
@dbtest
def test_date_time_types(executor):
run(executor, "SET TIME ZONE UTC")
assert run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3] \
== "| 00:00:00 |"
assert run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split("\n")[3] \
assert (
run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3]
== "| 00:00:00 |"
)
assert (
run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split(
"\n"
)[3]
== "| 00:00:00+14:59 |"
assert run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[3] \
== "| 4713-01-01 BC |"
assert run(executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True).split("\n")[3] \
== "| 4713-01-01 00:00:00 BC |"
assert run(executor, "SELECT (CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))", join=True).split("\n")[3] \
== "| 4713-01-01 00:00:00+00 BC |"
assert run(executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True).split("\n")[3] \
== "| -123456789 days, 12:23:56 |"
)
assert (
run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[
3
]
== "| 4713-01-01 BC |"
)
assert (
run(
executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True
).split("\n")[3]
== "| 4713-01-01 00:00:00 BC |"
)
assert (
run(
executor,
"SELECT (CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))",
join=True,
).split("\n")[3]
== "| 4713-01-01 00:00:00+00 BC |"
)
assert (
run(
executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True
).split("\n")[3]
== "| -123456789 days, 12:23:56 |"
)
@dbtest
@pytest.mark.parametrize('value', ['10000000', '10000000.0', '10000000000000'])
@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"])
def test_large_numbers_render_directly(executor, value):
run(executor, "create table numbertest(a numeric)")
run(executor,
"insert into numbertest (a) values ({0})".format(value))
run(executor, "insert into numbertest (a) values ({0})".format(value))
assert value in run(executor, "select * from numbertest", join=True)
@dbtest
@pytest.mark.parametrize('command', ['di', 'dv', 'ds', 'df', 'dT'])
@pytest.mark.parametrize('verbose', ['', '+'])
@pytest.mark.parametrize('pattern', ['', 'x', '*.*', 'x.y', 'x.*', '*.y'])
@pytest.mark.parametrize("command", ["di", "dv", "ds", "df", "dT"])
@pytest.mark.parametrize("verbose", ["", "+"])
@pytest.mark.parametrize("pattern", ["", "x", "*.*", "x.y", "x.*", "*.y"])
def test_describe_special(executor, command, verbose, pattern, pgspecial):
# We don't have any tests for the output of any of the special commands,
# but we can at least make sure they run without error
sql = r'\{command}{verbose} {pattern}'.format(**locals())
sql = r"\{command}{verbose} {pattern}".format(**locals())
list(executor.run(sql, pgspecial=pgspecial))
@dbtest
@pytest.mark.parametrize('sql', [
'invalid sql',
'SELECT 1; select error;',
])
@pytest.mark.parametrize("sql", ["invalid sql", "SELECT 1; select error;"])
def test_raises_with_no_formatter(executor, sql):
with pytest.raises(psycopg2.ProgrammingError):
list(executor.run(sql))
@ -361,19 +426,24 @@ def test_raises_with_no_formatter(executor, sql):
@dbtest
def test_on_error_resume(executor, exception_formatter):
sql = 'select 1; error; select 1;'
result = list(executor.run(sql, on_error_resume=True,
exception_formatter=exception_formatter))
sql = "select 1; error; select 1;"
result = list(
executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter)
)
assert len(result) == 3
@dbtest
def test_on_error_stop(executor, exception_formatter):
sql = 'select 1; error; select 1;'
result = list(executor.run(sql, on_error_resume=False,
exception_formatter=exception_formatter))
sql = "select 1; error; select 1;"
result = list(
executor.run(
sql, on_error_resume=False, exception_formatter=exception_formatter
)
)
assert len(result) == 2
# @dbtest
# def test_unicode_notices(executor):
# sql = "DO language plpgsql $$ BEGIN RAISE NOTICE '有人更改'; END $$;"
@ -384,50 +454,55 @@ def test_on_error_stop(executor, exception_formatter):
@dbtest
def test_nonexistent_function_definition(executor):
with pytest.raises(RuntimeError):
result = executor.view_definition('there_is_no_such_function')
result = executor.view_definition("there_is_no_such_function")
@dbtest
def test_function_definition(executor):
run(executor, '''
run(
executor,
"""
CREATE OR REPLACE FUNCTION public.the_number_three()
RETURNS int
LANGUAGE sql
AS $function$
select 3;
$function$
''')
result = executor.function_definition('the_number_three')
""",
)
result = executor.function_definition("the_number_three")
@dbtest
def test_view_definition(executor):
run(executor, 'create table tbl1 (a text, b numeric)')
run(executor, 'create view vw1 AS SELECT * FROM tbl1')
run(executor, 'create materialized view mvw1 AS SELECT * FROM tbl1')
result = executor.view_definition('vw1')
assert 'FROM tbl1' in result
run(executor, "create table tbl1 (a text, b numeric)")
run(executor, "create view vw1 AS SELECT * FROM tbl1")
run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1")
result = executor.view_definition("vw1")
assert "FROM tbl1" in result
# import pytest; pytest.set_trace()
result = executor.view_definition('mvw1')
assert 'MATERIALIZED VIEW' in result
result = executor.view_definition("mvw1")
assert "MATERIALIZED VIEW" in result
@dbtest
def test_nonexistent_view_definition(executor):
with pytest.raises(RuntimeError):
result = executor.view_definition('there_is_no_such_view')
result = executor.view_definition("there_is_no_such_view")
with pytest.raises(RuntimeError):
result = executor.view_definition('mvw1')
result = executor.view_definition("mvw1")
@dbtest
def test_short_host(executor):
with patch.object(executor, 'host', 'localhost'):
assert executor.short_host == 'localhost'
with patch.object(executor, 'host', 'localhost.example.org'):
assert executor.short_host == 'localhost'
with patch.object(executor, 'host', 'localhost1.example.org,localhost2.example.org'):
assert executor.short_host == 'localhost1'
with patch.object(executor, "host", "localhost"):
assert executor.short_host == "localhost"
with patch.object(executor, "host", "localhost.example.org"):
assert executor.short_host == "localhost"
with patch.object(
executor, "host", "localhost1.example.org,localhost2.example.org"
):
assert executor.short_host == "localhost1"
class BrokenConnection(object):
@ -441,8 +516,15 @@ class BrokenConnection(object):
def test_exit_without_active_connection(executor):
quit_handler = MagicMock()
pgspecial = PGSpecial()
pgspecial.register(quit_handler, '\\q', '\\q', 'Quit pgcli.',
arg_type=NO_QUERY, case_sensitive=True, aliases=(':q',))
pgspecial.register(
quit_handler,
"\\q",
"\\q",
"Quit pgcli.",
arg_type=NO_QUERY,
case_sensitive=True,
aliases=(":q",),
)
with patch.object(executor, "conn", BrokenConnection()):
# we should be able to quit the app, even without active connection

View File

@ -1,92 +1,78 @@
import pytest
from pgcli.packages.sqlcompletion import (
suggest_type, Special, Database, Schema, Table, View, Function, Datatype)
suggest_type,
Special,
Database,
Schema,
Table,
View,
Function,
Datatype,
)
def test_slash_suggests_special():
suggestions = suggest_type('\\', '\\')
assert set(suggestions) == set(
[Special()])
suggestions = suggest_type("\\", "\\")
assert set(suggestions) == set([Special()])
def test_slash_d_suggests_special():
suggestions = suggest_type('\\d', '\\d')
assert set(suggestions) == set(
[Special()])
suggestions = suggest_type("\\d", "\\d")
assert set(suggestions) == set([Special()])
def test_dn_suggests_schemata():
suggestions = suggest_type('\\dn ', '\\dn ')
suggestions = suggest_type("\\dn ", "\\dn ")
assert suggestions == (Schema(),)
suggestions = suggest_type('\\dn xxx', '\\dn xxx')
suggestions = suggest_type("\\dn xxx", "\\dn xxx")
assert suggestions == (Schema(),)
def test_d_suggests_tables_views_and_schemas():
suggestions = suggest_type('\d ', '\d ')
assert set(suggestions) == set([
Schema(),
Table(schema=None),
View(schema=None),
])
suggestions = suggest_type("\d ", "\d ")
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
suggestions = suggest_type('\d xxx', '\d xxx')
assert set(suggestions) == set([
Schema(),
Table(schema=None),
View(schema=None),
])
suggestions = suggest_type("\d xxx", "\d xxx")
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
def test_d_dot_suggests_schema_qualified_tables_or_views():
suggestions = suggest_type('\d myschema.', '\d myschema.')
assert set(suggestions) == set([
Table(schema='myschema'),
View(schema='myschema'),
])
suggestions = suggest_type("\d myschema.", "\d myschema.")
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
suggestions = suggest_type('\d myschema.xxx', '\d myschema.xxx')
assert set(suggestions) == set([
Table(schema='myschema'),
View(schema='myschema'),
])
suggestions = suggest_type("\d myschema.xxx", "\d myschema.xxx")
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
def test_df_suggests_schema_or_function():
suggestions = suggest_type('\\df xxx', '\\df xxx')
assert set(suggestions) == set([
Function(schema=None, usage='special'),
Schema(),
])
suggestions = suggest_type("\\df xxx", "\\df xxx")
assert set(suggestions) == set([Function(schema=None, usage="special"), Schema()])
suggestions = suggest_type('\\df myschema.xxx', '\\df myschema.xxx')
assert suggestions == (Function(schema='myschema', usage='special'),)
suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx")
assert suggestions == (Function(schema="myschema", usage="special"),)
def test_leading_whitespace_ok():
cmd = '\\dn '
whitespace = ' '
cmd = "\\dn "
whitespace = " "
suggestions = suggest_type(whitespace + cmd, whitespace + cmd)
assert suggestions == suggest_type(cmd, cmd)
def test_dT_suggests_schema_or_datatypes():
text = '\\dT '
text = "\\dT "
suggestions = suggest_type(text, text)
assert set(suggestions) == set([
Schema(),
Datatype(schema=None),
])
assert set(suggestions) == set([Schema(), Datatype(schema=None)])
def test_schema_qualified_dT_suggests_datatypes():
text = '\\dT foo.'
text = "\\dT foo."
suggestions = suggest_type(text, text)
assert suggestions == (Datatype(schema='foo'),)
assert suggestions == (Datatype(schema="foo"),)
@pytest.mark.parametrize('command', ['\\c ', '\\connect '])
@pytest.mark.parametrize("command", ["\\c ", "\\connect "])
def test_c_suggests_databases(command):
suggestions = suggest_type(command, command)
assert suggestions == (Database(),)

View File

@ -3,18 +3,18 @@ from pgcli.packages.prioritization import PrevalenceCounter
def test_prevalence_counter():
counter = PrevalenceCounter()
sql = '''SELECT * FROM foo WHERE bar GROUP BY baz;
sql = """SELECT * FROM foo WHERE bar GROUP BY baz;
select * from foo;
SELECT * FROM foo WHERE bar GROUP
BY baz'''
BY baz"""
counter.update(sql)
keywords = ['SELECT', 'FROM', 'GROUP BY']
keywords = ["SELECT", "FROM", "GROUP BY"]
expected = [3, 3, 2]
kw_counts = [counter.keyword_count(x) for x in keywords]
assert kw_counts == expected
assert counter.keyword_count('NOSUCHKEYWORD') == 0
assert counter.keyword_count("NOSUCHKEYWORD") == 0
names = ['foo', 'bar', 'baz']
names = ["foo", "bar", "baz"]
name_counts = [counter.name_count(x) for x in names]
assert name_counts == [3, 2, 2]

View File

@ -7,7 +7,7 @@ from pgcli.packages.prompt_utils import confirm_destructive_query
def test_confirm_destructive_query_notty():
stdin = click.get_text_stream('stdin')
stdin = click.get_text_stream("stdin")
if not stdin.isatty():
sql = 'drop database foo;'
sql = "drop database foo;"
assert confirm_destructive_query(sql) is None

View File

@ -6,6 +6,7 @@ import pytest
# We need this fixture because we need the PGCli object to be created
# after test collection so it has config loaded from the temp directory
@pytest.fixture(scope="module")
def default_pgcli_obj():
return PGCli()
@ -24,9 +25,7 @@ def LIMIT(DEFAULT):
@pytest.fixture(scope="module")
def over_default(DEFAULT):
over_default_cursor = Mock()
over_default_cursor.configure_mock(
rowcount=DEFAULT + 10
)
over_default_cursor.configure_mock(rowcount=DEFAULT + 10)
return over_default_cursor

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -5,18 +5,20 @@ from pgcli.main import format_output, OutputSettings
from pgcli.pgexecute import register_json_typecasters
from os import getenv
POSTGRES_USER = getenv('PGUSER', 'postgres')
POSTGRES_HOST = getenv('PGHOST', 'localhost')
POSTGRES_PORT = getenv('PGPORT', 5432)
POSTGRES_PASSWORD = getenv('PGPASSWORD', '')
POSTGRES_USER = getenv("PGUSER", "postgres")
POSTGRES_HOST = getenv("PGHOST", "localhost")
POSTGRES_PORT = getenv("PGPORT", 5432)
POSTGRES_PASSWORD = getenv("PGPASSWORD", "")
def db_connection(dbname=None):
conn = psycopg2.connect(user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
port=POSTGRES_PORT,
database=dbname)
conn = psycopg2.connect(
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
port=POSTGRES_PORT,
database=dbname,
)
conn.autocommit = True
return conn
@ -26,8 +28,8 @@ try:
CAN_CONNECT_TO_DB = True
SERVER_VERSION = conn.server_version
json_types = register_json_typecasters(conn, lambda x: x)
JSON_AVAILABLE = 'json' in json_types
JSONB_AVAILABLE = 'jsonb' in json_types
JSON_AVAILABLE = "json" in json_types
JSONB_AVAILABLE = "jsonb" in json_types
except:
CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False
SERVER_VERSION = 0
@ -35,51 +37,59 @@ except:
dbtest = pytest.mark.skipif(
not CAN_CONNECT_TO_DB,
reason="Need a postgres instance at localhost accessible by user 'postgres'")
reason="Need a postgres instance at localhost accessible by user 'postgres'",
)
requires_json = pytest.mark.skipif(
not JSON_AVAILABLE,
reason='Postgres server unavailable or json type not defined')
not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined"
)
requires_jsonb = pytest.mark.skipif(
not JSONB_AVAILABLE,
reason='Postgres server unavailable or jsonb type not defined')
not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined"
)
def create_db(dbname):
with db_connection().cursor() as cur:
try:
cur.execute('''CREATE DATABASE _test_db''')
cur.execute("""CREATE DATABASE _test_db""")
except:
pass
def drop_tables(conn):
with conn.cursor() as cur:
cur.execute('''
cur.execute(
"""
DROP SCHEMA public CASCADE;
CREATE SCHEMA public;
DROP SCHEMA IF EXISTS schema1 CASCADE;
DROP SCHEMA IF EXISTS schema2 CASCADE''')
DROP SCHEMA IF EXISTS schema2 CASCADE"""
)
def run(executor, sql, join=False, expanded=False, pgspecial=None,
exception_formatter=None):
def run(
executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
):
" Return string output for the sql to be run "
results = executor.run(sql, pgspecial, exception_formatter)
formatted = []
settings = OutputSettings(table_format='psql', dcmlfmt='d', floatfmt='g',
expanded=expanded)
settings = OutputSettings(
table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded
)
for title, rows, headers, status, sql, success, is_special in results:
formatted.extend(format_output(title, rows, headers, status, settings))
if join:
formatted = '\n'.join(formatted)
formatted = "\n".join(formatted)
return formatted
def completions_to_set(completions):
return set((completion.display_text, completion.display_meta_text) for completion in completions)
return set(
(completion.display_text, completion.display_meta_text)
for completion in completions
)