From 8cb7009bcd0f0062942932c853706a36178f566c Mon Sep 17 00:00:00 2001 From: Irina Truong Date: Sat, 25 May 2019 13:08:56 -0700 Subject: [PATCH] black all the things. (#1049) * added black to develop guide * no need for pep8radius. * changelog. * Add pre-commit checkbox. * Add pre-commit to dev reqs. * Add pyproject.toml for black. * Pre-commit config. * Add black to travis and dev reqs. * Install and run black in travis. * Remove black from dev reqs. * Lower black target version. * Re-format with black. --- .github/PULL_REQUEST_TEMPLATE.md | 2 + .pre-commit-config.yaml | 7 + .travis.yml | 4 +- DEVELOP.rst | 16 +- changelog.rst | 1 + pgcli/__init__.py | 2 +- pgcli/completion_refresher.py | 54 +- pgcli/config.py | 23 +- pgcli/encodingutils.py | 4 +- pgcli/key_bindings.py | 32 +- pgcli/magic.py | 26 +- pgcli/main.py | 1011 ++++++++------ pgcli/packages/parseutils/__init__.py | 5 +- pgcli/packages/parseutils/ctes.py | 20 +- pgcli/packages/parseutils/meta.py | 101 +- pgcli/packages/parseutils/tables.py | 60 +- pgcli/packages/parseutils/utils.py | 35 +- pgcli/packages/pgliterals/main.py | 2 +- pgcli/packages/prioritization.py | 6 +- pgcli/packages/prompt_utils.py | 5 +- pgcli/packages/sqlcompletion.py | 281 ++-- pgcli/pgbuffer.py | 21 +- pgcli/pgcompleter.py | 534 +++---- pgcli/pgexecute.py | 254 ++-- pgcli/pgstyle.py | 78 +- pgcli/pgtoolbar.py | 47 +- pyproject.toml | 22 + release.py | 72 +- requirements-dev.txt | 3 +- setup.py | 88 +- tests/conftest.py | 27 +- tests/features/db_utils.py | 26 +- tests/features/environment.py | 166 +-- tests/features/fixture_utils.py | 8 +- tests/features/steps/auto_vertical.py | 37 +- tests/features/steps/basic_commands.py | 88 +- tests/features/steps/crud_database.py | 50 +- tests/features/steps/crud_table.py | 55 +- tests/features/steps/expanded.py | 48 +- tests/features/steps/iocommands.py | 59 +- tests/features/steps/named_queries.py | 26 +- tests/features/steps/specials.py | 9 +- tests/features/steps/wrappers.py | 30 +- 
tests/features/wrappager.py | 4 +- tests/metadata.py | 176 +-- tests/parseutils/test_ctes.py | 124 +- tests/parseutils/test_function_metadata.py | 9 +- tests/parseutils/test_parseutils.py | 240 ++-- tests/test_completion_refresher.py | 27 +- tests/test_fuzzy_completion.py | 24 +- tests/test_main.py | 282 ++-- tests/test_naive_completion.py | 112 +- tests/test_pgexecute.py | 382 ++++-- tests/test_pgspecial.py | 82 +- tests/test_prioritization.py | 10 +- tests/test_prompt_utils.py | 4 +- tests/test_rowlimit.py | 5 +- ...test_smart_completion_multiple_schemata.py | 845 +++++++----- ...est_smart_completion_public_schema_only.py | 1221 +++++++++-------- tests/test_sqlcompletion.py | 1177 ++++++++-------- tests/utils.py | 60 +- 61 files changed, 4540 insertions(+), 3689 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index d02bf3b6..f655bc70 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -7,3 +7,5 @@ - [ ] I've added this contribution to the `changelog.rst`. - [ ] I've added my name to the `AUTHORS` file (or it's already there). + +- [ ] I installed pre-commit hooks (`pip install pre-commit && pre-commit install`), and ran `black` on my code. diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..52ea4ebf --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: +- repo: https://github.com/ambv/black + rev: stable + hooks: + - id: black + language_version: python3.6 + diff --git a/.travis.yml b/.travis.yml index 2bb738e5..bc0e671d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -29,8 +29,8 @@ script: - cd .. # check for changelog ReST compliance - rst2html.py --halt=warning changelog.rst >/dev/null - # check for pep8 errors, only looking at branch vs master. If there are errors, show diff and return an error code. 
- - pep8radius master --docformatter --diff --error-status + # check for black code compliance, 3.6 only + - if [[ "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then pip install black && black --check . ; else echo "Skipping black for $TRAVIS_PYTHON_VERSION"; fi - set +e after_success: diff --git a/DEVELOP.rst b/DEVELOP.rst index 005c9f3f..774dd8d0 100644 --- a/DEVELOP.rst +++ b/DEVELOP.rst @@ -172,17 +172,7 @@ Troubleshooting the integration tests - Check `this issue `_ for relevant information. - Contact us on `gitter `_ or `file an issue `_. -PEP8 checks (lint) ------------------- +Coding Style +------------ -When you submit a PR, the changeset is checked for pep8 compliance using -`pep8radius `_. If you see a build failing because -of these checks, install pep8radius and apply style fixes: - -:: - - $ pip install pep8radius - $ pep8radius master --docformatter --diff # view a diff of proposed fixes - $ pep8radius master --docformatter --in-place # apply the fixes - -Then commit and push the fixes. +``pgcli`` uses `black `_ to format the source code. Make sure to install black. diff --git a/changelog.rst b/changelog.rst index ed69bcb9..1fb0b90a 100644 --- a/changelog.rst +++ b/changelog.rst @@ -16,6 +16,7 @@ Internal: --------- * Add python 3.7 to travis build matrix. (Thanks: `Irina Truong`_) +* Apply `black` to code. 
(Thanks: `Irina Truong`_) 2.1.0 ===== diff --git a/pgcli/__init__.py b/pgcli/__init__.py index a33997dd..9aa3f903 100644 --- a/pgcli/__init__.py +++ b/pgcli/__init__.py @@ -1 +1 @@ -__version__ = '2.1.0' +__version__ = "2.1.0" diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py index 388bb295..a70d2f2b 100644 --- a/pgcli/completion_refresher.py +++ b/pgcli/completion_refresher.py @@ -14,8 +14,7 @@ class CompletionRefresher(object): self._completer_thread = None self._restart_refresh = threading.Event() - def refresh(self, executor, special, callbacks, history=None, - settings=None): + def refresh(self, executor, special, callbacks, history=None, settings=None): """ Creates a PGCompleter object and populates it with the relevant completion suggestions in a background thread. @@ -30,27 +29,29 @@ class CompletionRefresher(object): """ if self.is_refreshing(): self._restart_refresh.set() - return [(None, None, None, 'Auto-completion refresh restarted.')] + return [(None, None, None, "Auto-completion refresh restarted.")] else: self._completer_thread = threading.Thread( target=self._bg_refresh, args=(executor, special, callbacks, history, settings), - name='completion_refresh') + name="completion_refresh", + ) self._completer_thread.setDaemon(True) self._completer_thread.start() - return [(None, None, None, - 'Auto-completion refresh started in the background.')] + return [ + (None, None, None, "Auto-completion refresh started in the background.") + ] def is_refreshing(self): return self._completer_thread and self._completer_thread.is_alive() - def _bg_refresh(self, pgexecute, special, callbacks, history=None, - settings=None): + def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None): settings = settings or {} - completer = PGCompleter(smart_completion=True, pgspecial=special, - settings=settings) + completer = PGCompleter( + smart_completion=True, pgspecial=special, settings=settings + ) - if 
settings.get('single_connection'): + if settings.get("single_connection"): executor = pgexecute else: # Create a new pgexecute method to populate the completions. @@ -88,55 +89,58 @@ def refresher(name, refreshers=CompletionRefresher.refreshers): """Decorator to populate the dictionary of refreshers with the current function. """ + def wrapper(wrapped): refreshers[name] = wrapped return wrapped + return wrapper -@refresher('schemata') +@refresher("schemata") def refresh_schemata(completer, executor): completer.set_search_path(executor.search_path()) completer.extend_schemata(executor.schemata()) -@refresher('tables') +@refresher("tables") def refresh_tables(completer, executor): - completer.extend_relations(executor.tables(), kind='tables') - completer.extend_columns(executor.table_columns(), kind='tables') + completer.extend_relations(executor.tables(), kind="tables") + completer.extend_columns(executor.table_columns(), kind="tables") completer.extend_foreignkeys(executor.foreignkeys()) -@refresher('views') +@refresher("views") def refresh_views(completer, executor): - completer.extend_relations(executor.views(), kind='views') - completer.extend_columns(executor.view_columns(), kind='views') + completer.extend_relations(executor.views(), kind="views") + completer.extend_columns(executor.view_columns(), kind="views") -@refresher('types') + +@refresher("types") def refresh_types(completer, executor): completer.extend_datatypes(executor.datatypes()) -@refresher('databases') +@refresher("databases") def refresh_databases(completer, executor): completer.extend_database_names(executor.databases()) -@refresher('casing') +@refresher("casing") def refresh_casing(completer, executor): casing_file = completer.casing_file if not casing_file: return generate_casing_file = completer.generate_casing_file if generate_casing_file and not os.path.isfile(casing_file): - casing_prefs = '\n'.join(executor.casing()) - with open(casing_file, 'w') as f: + casing_prefs = 
"\n".join(executor.casing()) + with open(casing_file, "w") as f: f.write(casing_prefs) if os.path.isfile(casing_file): - with open(casing_file, 'r') as f: + with open(casing_file, "r") as f: completer.extend_casing([line.strip() for line in f]) -@refresher('functions') +@refresher("functions") def refresh_functions(completer, executor): completer.extend_functions(executor.functions()) diff --git a/pgcli/config.py b/pgcli/config.py index 1303eed2..a2c0b0b2 100644 --- a/pgcli/config.py +++ b/pgcli/config.py @@ -7,18 +7,18 @@ from configobj import ConfigObj def config_location(): - if 'XDG_CONFIG_HOME' in os.environ: - return '%s/pgcli/' % expanduser(os.environ['XDG_CONFIG_HOME']) - elif platform.system() == 'Windows': - return os.getenv('USERPROFILE') + '\\AppData\\Local\\dbcli\\pgcli\\' + if "XDG_CONFIG_HOME" in os.environ: + return "%s/pgcli/" % expanduser(os.environ["XDG_CONFIG_HOME"]) + elif platform.system() == "Windows": + return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\pgcli\\" else: - return expanduser('~/.config/pgcli/') + return expanduser("~/.config/pgcli/") def load_config(usr_cfg, def_cfg=None): cfg = ConfigObj() cfg.merge(ConfigObj(def_cfg, interpolation=False)) - cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding='utf-8')) + cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) cfg.filename = expanduser(usr_cfg) return cfg @@ -51,18 +51,19 @@ def upgrade_config(config, def_config): def get_config(pgclirc_file=None): from pgcli import __file__ as package_root + package_root = os.path.dirname(package_root) - pgclirc_file = pgclirc_file or '%sconfig' % config_location() + pgclirc_file = pgclirc_file or "%sconfig" % config_location() - default_config = os.path.join(package_root, 'pgclirc') + default_config = os.path.join(package_root, "pgclirc") write_default_config(default_config, pgclirc_file) return load_config(pgclirc_file, default_config) def get_casing_file(config): - casing_file = 
config['main']['casing_file'] - if casing_file == 'default': - casing_file = config_location() + 'casing' + casing_file = config["main"]["casing_file"] + if casing_file == "default": + casing_file = config_location() + "casing" return casing_file diff --git a/pgcli/encodingutils.py b/pgcli/encodingutils.py index 5de97b7e..279da3a8 100644 --- a/pgcli/encodingutils.py +++ b/pgcli/encodingutils.py @@ -13,7 +13,7 @@ def unicode2utf8(arg): """ if PY2 and isinstance(arg, unicode): - return arg.encode('utf-8') + return arg.encode("utf-8") return arg @@ -24,5 +24,5 @@ def utf8tounicode(arg): """ if PY2 and isinstance(arg, str): - return arg.decode('utf-8') + return arg.decode("utf-8") return arg diff --git a/pgcli/key_bindings.py b/pgcli/key_bindings.py index d7750432..cc078977 100644 --- a/pgcli/key_bindings.py +++ b/pgcli/key_bindings.py @@ -12,32 +12,32 @@ def pgcli_bindings(pgcli): """Custom key bindings for pgcli.""" kb = KeyBindings() - tab_insert_text = ' ' * 4 + tab_insert_text = " " * 4 - @kb.add('f2') + @kb.add("f2") def _(event): """Enable/Disable SmartCompletion Mode.""" - _logger.debug('Detected F2 key.') + _logger.debug("Detected F2 key.") pgcli.completer.smart_completion = not pgcli.completer.smart_completion - @kb.add('f3') + @kb.add("f3") def _(event): """Enable/Disable Multiline Mode.""" - _logger.debug('Detected F3 key.') + _logger.debug("Detected F3 key.") pgcli.multi_line = not pgcli.multi_line - @kb.add('f4') + @kb.add("f4") def _(event): """Toggle between Vi and Emacs mode.""" - _logger.debug('Detected F4 key.') + _logger.debug("Detected F4 key.") pgcli.vi_mode = not pgcli.vi_mode event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS - @kb.add('tab') + @kb.add("tab") def _(event): """Force autocompletion at cursor on non-empty lines.""" - _logger.debug('Detected key.') + _logger.debug("Detected key.") buff = event.app.current_buffer doc = buff.document @@ -50,17 +50,15 @@ def pgcli_bindings(pgcli): else: 
buff.insert_text(tab_insert_text, fire_event=False) - @kb.add('escape', filter=has_completions) + @kb.add("escape", filter=has_completions) def _(event): """Force closing of autocompletion.""" - _logger.debug('Detected key.') + _logger.debug("Detected key.") event.current_buffer.complete_state = None event.app.current_buffer.complete_state = None - - - @kb.add('c-space') + @kb.add("c-space") def _(event): """ Initialize autocompletion at cursor. @@ -70,7 +68,7 @@ def pgcli_bindings(pgcli): If the menu is showing, select the next completion. """ - _logger.debug('Detected key.') + _logger.debug("Detected key.") b = event.app.current_buffer if b.complete_state: @@ -78,7 +76,7 @@ def pgcli_bindings(pgcli): else: b.start_completion(select_first=False) - @kb.add('enter', filter=completion_is_selected) + @kb.add("enter", filter=completion_is_selected) def _(event): """Makes the enter key work as the tab key only when showing the menu. @@ -87,7 +85,7 @@ def pgcli_bindings(pgcli): (accept current selection). 
""" - _logger.debug('Detected enter key.') + _logger.debug("Detected enter key.") event.current_buffer.complete_state = None event.app.current_buffer.complete_state = None diff --git a/pgcli/magic.py b/pgcli/magic.py index 899fe733..3e9a4021 100644 --- a/pgcli/magic.py +++ b/pgcli/magic.py @@ -10,40 +10,40 @@ def load_ipython_extension(ipython): """This is called via the ipython command '%load_ext pgcli.magic'""" # first, load the sql magic if it isn't already loaded - if not ipython.find_line_magic('sql'): - ipython.run_line_magic('load_ext', 'sql') + if not ipython.find_line_magic("sql"): + ipython.run_line_magic("load_ext", "sql") # register our own magic - ipython.register_magic_function(pgcli_line_magic, 'line', 'pgcli') + ipython.register_magic_function(pgcli_line_magic, "line", "pgcli") def pgcli_line_magic(line): - _logger.debug('pgcli magic called: %r', line) + _logger.debug("pgcli magic called: %r", line) parsed = sql.parse.parse(line, {}) # "get" was renamed to "set" in ipython-sql: # https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43 - if hasattr(sql.connection.Connection, 'get'): - conn = sql.connection.Connection.get(parsed['connection']) + if hasattr(sql.connection.Connection, "get"): + conn = sql.connection.Connection.get(parsed["connection"]) else: - conn = sql.connection.Connection.set(parsed['connection']) + conn = sql.connection.Connection.set(parsed["connection"]) try: # A corresponding pgcli object already exists pgcli = conn._pgcli - _logger.debug('Reusing existing pgcli') + _logger.debug("Reusing existing pgcli") except AttributeError: # I can't figure out how to get the underylying psycopg2 connection # from the sqlalchemy connection, so just grab the url and make a # new connection pgcli = PGCli() u = conn.session.engine.url - _logger.debug('New pgcli: %r', str(u)) + _logger.debug("New pgcli: %r", str(u)) pgcli.connect(u.database, u.host, u.username, u.port, u.password) conn._pgcli = pgcli # For 
convenience, print the connection alias - print('Connected: {}'.format(conn.name)) + print("Connected: {}".format(conn.name)) try: pgcli.run_cli() @@ -56,12 +56,12 @@ def pgcli_line_magic(line): q = pgcli.query_history[-1] if not q.successful: - _logger.debug('Unsuccessful query - ignoring') + _logger.debug("Unsuccessful query - ignoring") return if q.meta_changed or q.db_changed or q.path_changed: - _logger.debug('Dangerous query detected -- ignoring') + _logger.debug("Dangerous query detected -- ignoring") return ipython = get_ipython() - return ipython.run_cell_magic('sql', line, q.query) + return ipython.run_cell_magic("sql", line, q.query) diff --git a/pgcli/main.py b/pgcli/main.py index 2a000cd9..fdbbc15c 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -5,7 +5,7 @@ import warnings from pgspecial.namedqueries import NamedQueries -warnings.filterwarnings("ignore", category=UserWarning, module='psycopg2') +warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2") import os import re @@ -20,12 +20,13 @@ import datetime as dt import itertools from time import time, sleep from codecs import open + keyring = None # keyring will be loaded later from cli_helpers.tabular_output import TabularOutputFormatter -from cli_helpers.tabular_output.preprocessors import (align_decimals, - format_numbers) +from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers import click + try: import setproctitle except ImportError: @@ -36,14 +37,16 @@ from prompt_toolkit.shortcuts import PromptSession, CompleteStyle from prompt_toolkit.document import Document from prompt_toolkit.filters import HasFocus, IsDone from prompt_toolkit.lexers import PygmentsLexer -from prompt_toolkit.layout.processors import (ConditionalProcessor, - HighlightMatchingBracketProcessor, - TabsProcessor) +from prompt_toolkit.layout.processors import ( + ConditionalProcessor, + HighlightMatchingBracketProcessor, + TabsProcessor, +) from prompt_toolkit.history import 
FileHistory from prompt_toolkit.auto_suggest import AutoSuggestFromHistory from pygments.lexers.sql import PostgresLexer -from pgspecial.main import (PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT) +from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT import pgspecial as special from .pgcompleter import PGCompleter @@ -52,8 +55,13 @@ from .pgstyle import style_factory, style_factory_output from .pgexecute import PGExecute from .pgbuffer import pg_is_multiline from .completion_refresher import CompletionRefresher -from .config import (get_casing_file, - load_config, config_location, ensure_dir_exists, get_config) +from .config import ( + get_casing_file, + load_config, + config_location, + ensure_dir_exists, + get_config, +) from .key_bindings import pgcli_bindings from .encodingutils import utf8tounicode from .encodingutils import text_type @@ -76,30 +84,38 @@ from collections import namedtuple from textwrap import dedent # Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output -COLOR_CODE_REGEX = re.compile(r'\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))') +COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))") # Query tuples are used for maintaining history MetaQuery = namedtuple( - 'Query', + "Query", [ - 'query', # The entire text of the command - 'successful', # True If all subqueries were successful - 'total_time', # Time elapsed executing the query and formatting results - 'execution_time', # Time elapsed executing the query - 'meta_changed', # True if any subquery executed create/alter/drop - 'db_changed', # True if any subquery changed the database - 'path_changed', # True if any subquery changed the search path - 'mutated', # True if any subquery executed insert/update/delete - 'is_special', # True if the query is a special command - ]) -MetaQuery.__new__.__defaults__ = ('', False, 0, 0, False, False, False, False) + "query", # The entire text of the command + 
"successful", # True If all subqueries were successful + "total_time", # Time elapsed executing the query and formatting results + "execution_time", # Time elapsed executing the query + "meta_changed", # True if any subquery executed create/alter/drop + "db_changed", # True if any subquery changed the database + "path_changed", # True if any subquery changed the search path + "mutated", # True if any subquery executed insert/update/delete + "is_special", # True if the query is a special command + ], +) +MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False) OutputSettings = namedtuple( - 'OutputSettings', - 'table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output' + "OutputSettings", + "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output", ) OutputSettings.__new__.__defaults__ = ( - None, None, None, '', False, None, lambda x: x, None + None, + None, + None, + "", + False, + None, + lambda x: x, + None, ) @@ -109,34 +125,49 @@ class PgCliQuitError(Exception): class PGCli(object): - default_prompt = '\\u@\\h:\\d> ' + default_prompt = "\\u@\\h:\\d> " max_len_prompt = 30 def set_default_pager(self, config): - configured_pager = config['main'].get('pager') - os_environ_pager = os.environ.get('PAGER') + configured_pager = config["main"].get("pager") + os_environ_pager = os.environ.get("PAGER") if configured_pager: self.logger.info( - 'Default pager found in config file: "{}"'.format(configured_pager)) - os.environ['PAGER'] = configured_pager + 'Default pager found in config file: "{}"'.format(configured_pager) + ) + os.environ["PAGER"] = configured_pager elif os_environ_pager: - self.logger.info('Default pager found in PAGER environment variable: "{}"'.format( - os_environ_pager)) - os.environ['PAGER'] = os_environ_pager + self.logger.info( + 'Default pager found in PAGER environment variable: "{}"'.format( + os_environ_pager + ) + ) + os.environ["PAGER"] = os_environ_pager else: 
self.logger.info( - 'No default pager found in environment. Using os default pager') + "No default pager found in environment. Using os default pager" + ) # Set default set of less recommended options, if they are not already set. # They are ignored if pager is different than less. - if not os.environ.get('LESS'): - os.environ['LESS'] = '-SRXF' + if not os.environ.get("LESS"): + os.environ["LESS"] = "-SRXF" - def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False, - pgexecute=None, pgclirc_file=None, row_limit=None, - single_connection=False, less_chatty=None, prompt=None, prompt_dsn=None, - auto_vertical_output=False, warn=None): + def __init__( + self, + force_passwd_prompt=False, + never_passwd_prompt=False, + pgexecute=None, + pgclirc_file=None, + row_limit=None, + single_connection=False, + less_chatty=None, + prompt=None, + prompt_dsn=None, + auto_vertical_output=False, + warn=None, + ): self.force_passwd_prompt = force_passwd_prompt self.never_passwd_prompt = never_passwd_prompt @@ -156,40 +187,43 @@ class PGCli(object): self.output_file = None self.pgspecial = PGSpecial() - self.multi_line = c['main'].as_bool('multi_line') - self.multiline_mode = c['main'].get('multi_line_mode', 'psql') - self.vi_mode = c['main'].as_bool('vi') - self.auto_expand = auto_vertical_output or c['main'].as_bool( - 'auto_expand') - self.expanded_output = c['main'].as_bool('expand') - self.pgspecial.timing_enabled = c['main'].as_bool('timing') + self.multi_line = c["main"].as_bool("multi_line") + self.multiline_mode = c["main"].get("multi_line_mode", "psql") + self.vi_mode = c["main"].as_bool("vi") + self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand") + self.expanded_output = c["main"].as_bool("expand") + self.pgspecial.timing_enabled = c["main"].as_bool("timing") if row_limit is not None: self.row_limit = row_limit else: - self.row_limit = c['main'].as_int('row_limit') + self.row_limit = c["main"].as_int("row_limit") - 
self.min_num_menu_lines = c['main'].as_int('min_num_menu_lines') - self.multiline_continuation_char = c['main']['multiline_continuation_char'] - self.table_format = c['main']['table_format'] - self.syntax_style = c['main']['syntax_style'] - self.cli_style = c['colors'] - self.wider_completion_menu = c['main'].as_bool('wider_completion_menu') - c_dest_warning = c['main'].as_bool('destructive_warning') + self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines") + self.multiline_continuation_char = c["main"]["multiline_continuation_char"] + self.table_format = c["main"]["table_format"] + self.syntax_style = c["main"]["syntax_style"] + self.cli_style = c["colors"] + self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") + c_dest_warning = c["main"].as_bool("destructive_warning") self.destructive_warning = c_dest_warning if warn is None else warn - self.less_chatty = bool(less_chatty) or c['main'].as_bool('less_chatty') - self.null_string = c['main'].get('null_string', '') - self.prompt_format = prompt if prompt is not None else c['main'].get('prompt', self.default_prompt) + self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") + self.null_string = c["main"].get("null_string", "") + self.prompt_format = ( + prompt + if prompt is not None + else c["main"].get("prompt", self.default_prompt) + ) self.prompt_dsn_format = prompt_dsn - self.on_error = c['main']['on_error'].upper() - self.decimal_format = c['data_formats']['decimal'] - self.float_format = c['data_formats']['float'] + self.on_error = c["main"]["on_error"].upper() + self.decimal_format = c["data_formats"]["decimal"] + self.float_format = c["data_formats"]["float"] self.initialize_keyring() - self.pgspecial.pset_pager(self.config['main'].as_bool( - 'enable_pager') and "on" or "off") + self.pgspecial.pset_pager( + self.config["main"].as_bool("enable_pager") and "on" or "off" + ) - self.style_output = style_factory_output( - self.syntax_style, c['colors']) + 
self.style_output = style_factory_output(self.syntax_style, c["colors"]) self.now = dt.datetime.today() @@ -198,23 +232,24 @@ class PGCli(object): self.query_history = [] # Initialize completer - smart_completion = c['main'].as_bool('smart_completion') - keyword_casing = c['main']['keyword_casing'] + smart_completion = c["main"].as_bool("smart_completion") + keyword_casing = c["main"]["keyword_casing"] self.settings = { - 'casing_file': get_casing_file(c), - 'generate_casing_file': c['main'].as_bool('generate_casing_file'), - 'generate_aliases': c['main'].as_bool('generate_aliases'), - 'asterisk_column_order': c['main']['asterisk_column_order'], - 'qualify_columns': c['main']['qualify_columns'], - 'case_column_headers': c['main'].as_bool('case_column_headers'), - 'search_path_filter': c['main'].as_bool('search_path_filter'), - 'single_connection': single_connection, - 'less_chatty': less_chatty, - 'keyword_casing': keyword_casing, + "casing_file": get_casing_file(c), + "generate_casing_file": c["main"].as_bool("generate_casing_file"), + "generate_aliases": c["main"].as_bool("generate_aliases"), + "asterisk_column_order": c["main"]["asterisk_column_order"], + "qualify_columns": c["main"]["qualify_columns"], + "case_column_headers": c["main"].as_bool("case_column_headers"), + "search_path_filter": c["main"].as_bool("search_path_filter"), + "single_connection": single_connection, + "less_chatty": less_chatty, + "keyword_casing": keyword_casing, } - completer = PGCompleter(smart_completion, pgspecial=self.pgspecial, - settings=self.settings) + completer = PGCompleter( + smart_completion, pgspecial=self.pgspecial, settings=self.settings + ) self.completer = completer self._completer_lock = threading.Lock() self.register_special_commands() @@ -227,57 +262,93 @@ class PGCli(object): def register_special_commands(self): self.pgspecial.register( - self.change_db, '\\c', '\\c[onnect] database_name', - 'Change to a new database.', aliases=('use', '\\connect', 'USE')) + 
self.change_db, + "\\c", + "\\c[onnect] database_name", + "Change to a new database.", + aliases=("use", "\\connect", "USE"), + ) - refresh_callback = lambda: self.refresh_completions( - persist_priorities='all') + refresh_callback = lambda: self.refresh_completions(persist_priorities="all") - self.pgspecial.register(self.quit, '\\q', '\\q', - 'Quit pgcli.', arg_type=NO_QUERY, case_sensitive=True, - aliases=(':q',)) - self.pgspecial.register(self.quit, 'quit', 'quit', - 'Quit pgcli.', arg_type=NO_QUERY, case_sensitive=False, - aliases=('exit',)) - self.pgspecial.register(refresh_callback, '\\#', '\\#', - 'Refresh auto-completions.', arg_type=NO_QUERY) - self.pgspecial.register(refresh_callback, '\\refresh', '\\refresh', - 'Refresh auto-completions.', arg_type=NO_QUERY) - self.pgspecial.register(self.execute_from_file, '\\i', '\\i filename', - 'Execute commands from file.') - self.pgspecial.register(self.write_to_file, '\\o', '\\o [filename]', - 'Send all query results to file.') - self.pgspecial.register(self.info_connection, '\\conninfo', - '\\conninfo', 'Get connection details') - self.pgspecial.register(self.change_table_format, '\\T', '\\T [format]', - 'Change the table format used to output results') + self.pgspecial.register( + self.quit, + "\\q", + "\\q", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=True, + aliases=(":q",), + ) + self.pgspecial.register( + self.quit, + "quit", + "quit", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=False, + aliases=("exit",), + ) + self.pgspecial.register( + refresh_callback, + "\\#", + "\\#", + "Refresh auto-completions.", + arg_type=NO_QUERY, + ) + self.pgspecial.register( + refresh_callback, + "\\refresh", + "\\refresh", + "Refresh auto-completions.", + arg_type=NO_QUERY, + ) + self.pgspecial.register( + self.execute_from_file, "\\i", "\\i filename", "Execute commands from file." 
+ ) + self.pgspecial.register( + self.write_to_file, + "\\o", + "\\o [filename]", + "Send all query results to file.", + ) + self.pgspecial.register( + self.info_connection, "\\conninfo", "\\conninfo", "Get connection details" + ) + self.pgspecial.register( + self.change_table_format, + "\\T", + "\\T [format]", + "Change the table format used to output results", + ) def change_table_format(self, pattern, **_): try: if pattern not in TabularOutputFormatter().supported_formats: raise ValueError() self.table_format = pattern - yield (None, None, None, - 'Changed table format to {}'.format(pattern)) + yield (None, None, None, "Changed table format to {}".format(pattern)) except ValueError: - msg = 'Table format {} not recognized. Allowed formats:'.format( - pattern) + msg = "Table format {} not recognized. Allowed formats:".format(pattern) for table_type in TabularOutputFormatter().supported_formats: msg += "\n\t{}".format(table_type) - msg += '\nCurrently set to: %s' % self.table_format + msg += "\nCurrently set to: %s" % self.table_format yield (None, None, None, msg) def info_connection(self, **_): - if self.pgexecute.host.startswith('/'): + if self.pgexecute.host.startswith("/"): host = 'socket "%s"' % self.pgexecute.host else: host = 'host "%s"' % self.pgexecute.host - yield (None, None, None, 'You are connected to database "%s" as user ' - '"%s" on %s at port "%s".' % (self.pgexecute.dbname, - self.pgexecute.user, - host, - self.pgexecute.port)) + yield ( + None, + None, + None, + 'You are connected to database "%s" as user ' + '"%s" on %s at port "%s".' 
+ % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port), + ) def change_db(self, pattern, **_): if pattern: @@ -289,33 +360,42 @@ class PGCli(object): infos.extend([None] * (4 - len(infos))) db, user, host, port = infos try: - self.pgexecute.connect(database=db, user=user, host=host, - port=port, **self.pgexecute.extra_args) + self.pgexecute.connect( + database=db, + user=user, + host=host, + port=port, + **self.pgexecute.extra_args + ) except OperationalError as e: - click.secho(str(e), err=True, fg='red') + click.secho(str(e), err=True, fg="red") click.echo("Previous connection kept") else: self.pgexecute.connect() - yield (None, None, None, 'You are now connected to database "%s" as ' - 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user)) + yield ( + None, + None, + None, + 'You are now connected to database "%s" as ' + 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user), + ) def execute_from_file(self, pattern, **_): if not pattern: - message = '\\i: missing required argument' - return [(None, None, None, message, '', False, True)] + message = "\\i: missing required argument" + return [(None, None, None, message, "", False, True)] try: - with open(os.path.expanduser(pattern), encoding='utf-8') as f: + with open(os.path.expanduser(pattern), encoding="utf-8") as f: query = f.read() except IOError as e: - return [(None, None, None, str(e), '', False, True)] + return [(None, None, None, str(e), "", False, True)] - if (self.destructive_warning and - confirm_destructive_query(query) is False): - message = 'Wise choice. Command execution stopped.' + if self.destructive_warning and confirm_destructive_query(query) is False: + message = "Wise choice. Command execution stopped." 
return [(None, None, None, message)] - on_error_resume = (self.on_error == 'RESUME') + on_error_resume = self.on_error == "RESUME" return self.pgexecute.run( query, self.pgspecial, on_error_resume=on_error_resume ) @@ -323,59 +403,61 @@ class PGCli(object): def write_to_file(self, pattern, **_): if not pattern: self.output_file = None - message = 'File output disabled' - return [(None, None, None, message, '', True, True)] + message = "File output disabled" + return [(None, None, None, message, "", True, True)] filename = os.path.abspath(os.path.expanduser(pattern)) if not os.path.isfile(filename): try: - open(filename, 'w').close() + open(filename, "w").close() except IOError as e: self.output_file = None - message = str(e) + '\nFile output disabled' - return [(None, None, None, message, '', False, True)] + message = str(e) + "\nFile output disabled" + return [(None, None, None, message, "", False, True)] self.output_file = filename message = 'Writing to file "%s"' % self.output_file - return [(None, None, None, message, '', True, True)] + return [(None, None, None, message, "", True, True)] def initialize_logging(self): - log_file = self.config['main']['log_file'] - if log_file == 'default': - log_file = config_location() + 'log' + log_file = self.config["main"]["log_file"] + if log_file == "default": + log_file = config_location() + "log" ensure_dir_exists(log_file) - log_level = self.config['main']['log_level'] + log_level = self.config["main"]["log_level"] # Disable logging if value is NONE by switching to a no-op handler. # Set log level to a high value so it doesn't even waste cycles getting called. 
- if log_level.upper() == 'NONE': + if log_level.upper() == "NONE": handler = logging.NullHandler() else: handler = logging.FileHandler(os.path.expanduser(log_file)) - level_map = {'CRITICAL': logging.CRITICAL, - 'ERROR': logging.ERROR, - 'WARNING': logging.WARNING, - 'INFO': logging.INFO, - 'DEBUG': logging.DEBUG, - 'NONE': logging.CRITICAL - } + level_map = { + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + "NONE": logging.CRITICAL, + } log_level = level_map[log_level.upper()] formatter = logging.Formatter( - '%(asctime)s (%(process)d/%(threadName)s) ' - '%(name)s %(levelname)s - %(message)s') + "%(asctime)s (%(process)d/%(threadName)s) " + "%(name)s %(levelname)s - %(message)s" + ) handler.setFormatter(formatter) - root_logger = logging.getLogger('pgcli') + root_logger = logging.getLogger("pgcli") root_logger.addHandler(handler) root_logger.setLevel(log_level) - root_logger.debug('Initializing pgcli logging.') - root_logger.debug('Log file %r.', log_file) + root_logger.debug("Initializing pgcli logging.") + root_logger.debug("Log file %r.", log_file) - pgspecial_logger = logging.getLogger('pgspecial') + pgspecial_logger = logging.getLogger("pgspecial") pgspecial_logger.addHandler(handler) pgspecial_logger.setLevel(log_level) @@ -386,25 +468,24 @@ class PGCli(object): if keyring_enabled: # Try best to load keyring (issue #1041). 
import importlib + try: - keyring = importlib.import_module('keyring') + keyring = importlib.import_module("keyring") except Exception as e: # ImportError for Python 2, ModuleNotFoundError for Python 3 - self.logger.warning('import keyring failed: %r.', e) + self.logger.warning("import keyring failed: %r.", e) def connect_dsn(self, dsn, **kwargs): self.connect(dsn=dsn, **kwargs) def connect_uri(self, uri): kwargs = psycopg2.extensions.parse_dsn(uri) - remap = { - 'dbname': 'database', - 'password': 'passwd', - } + remap = {"dbname": "database", "password": "passwd"} kwargs = {remap.get(k, k): v for k, v in kwargs.items()} self.connect(**kwargs) - def connect(self, database='', host='', user='', port='', passwd='', - dsn='', **kwargs): + def connect( + self, database="", host="", user="", port="", passwd="", dsn="", **kwargs + ): # Connect to the database. if not user: @@ -413,37 +494,35 @@ class PGCli(object): if not database: database = user - kwargs.setdefault('application_name', 'pgcli') + kwargs.setdefault("application_name", "pgcli") # If password prompt is not forced but no password is provided, try # getting it from environment variable. 
if not self.force_passwd_prompt and not passwd: - passwd = os.environ.get('PGPASSWORD', '') + passwd = os.environ.get("PGPASSWORD", "") # Find password from store - key = '%s@%s' % (user, host) - keyring_error_message = dedent("""\ + key = "%s@%s" % (user, host) + keyring_error_message = dedent( + """\ {} {} To remove this message do one of the following: - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/ - uninstall keyring: pip uninstall keyring - - disable keyring in our configuration: add keyring = False to [main]""") + - disable keyring in our configuration: add keyring = False to [main]""" + ) if not passwd and keyring: try: - passwd = keyring.get_password('pgcli', key) - except ( - RuntimeError, - keyring.errors.InitError - ) as e: + passwd = keyring.get_password("pgcli", key) + except (RuntimeError, keyring.errors.InitError) as e: click.secho( keyring_error_message.format( - "Load your password from keyring returned:", - str(e) + "Load your password from keyring returned:", str(e) ), err=True, - fg='red' + fg="red", ) # Prompt for a password immediately if requested via the -W flag. This @@ -452,8 +531,9 @@ class PGCli(object): # If we successfully parsed a password from a URI, there's no need to # prompt for it, even with the -W flag if self.force_passwd_prompt and not passwd: - passwd = click.prompt('Password for %s' % user, hide_input=True, - show_default=False, type=str) + passwd = click.prompt( + "Password for %s" % user, hide_input=True, show_default=False, type=str + ) def should_ask_for_password(exc): # Prompt for a password after 1st attempt to connect @@ -473,37 +553,36 @@ class PGCli(object): # prompt for a password (no -w flag), prompt for a passwd and try again. 
try: try: - pgexecute = PGExecute(database, user, passwd, host, port, dsn, - **kwargs) + pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs) except (OperationalError, InterfaceError) as e: if should_ask_for_password(e): - passwd = click.prompt('Password for %s' % user, - hide_input=True, show_default=False, - type=str) - pgexecute = PGExecute(database, user, passwd, host, port, - dsn, **kwargs) + passwd = click.prompt( + "Password for %s" % user, + hide_input=True, + show_default=False, + type=str, + ) + pgexecute = PGExecute( + database, user, passwd, host, port, dsn, **kwargs + ) else: raise e if passwd and keyring: try: - keyring.set_password('pgcli', key, passwd) - except ( - RuntimeError, - keyring.errors.KeyringError, - ) as e: + keyring.set_password("pgcli", key, passwd) + except (RuntimeError, keyring.errors.KeyringError) as e: click.secho( keyring_error_message.format( - "Set password in keyring returned:", - str(e) + "Set password in keyring returned:", str(e) ), err=True, - fg='red' + fg="red", ) except Exception as e: # Connecting to a database could fail. 
- self.logger.debug('Database connection failed: %r.', e) + self.logger.debug("Database connection failed: %r.", e) self.logger.error("traceback: %r", traceback.format_exc()) - click.secho(str(e), err=True, fg='red') + click.secho(str(e), err=True, fg="red") exit(1) self.pgexecute = pgexecute @@ -522,19 +601,17 @@ class PGCli(object): """ editor_command = special.editor_command(text) while editor_command: - if editor_command == '\\e': + if editor_command == "\\e": filename = special.get_filename(text) - query = special.get_editor_query( - text) or self.get_last_query() + query = special.get_editor_query(text) or self.get_last_query() else: # \ev or \ef filename = None spec = text.split()[1] - if editor_command == '\\ev': + if editor_command == "\\ev": query = self.pgexecute.view_definition(spec) - elif editor_command == '\\ef': + elif editor_command == "\\ef": query = self.pgexecute.function_definition(spec) - sql, message = special.open_external_editor( - filename, sql=query) + sql, message = special.open_external_editor(filename, sql=query) if message: # Something went wrong. Raise an exception and bail. 
raise RuntimeError(message) @@ -554,21 +631,21 @@ class PGCli(object): query = MetaQuery(query=text, successful=False) try: - if (self.destructive_warning): + if self.destructive_warning: destroy = confirm = confirm_destructive_query(text) if destroy is False: - click.secho('Wise choice!') + click.secho("Wise choice!") raise KeyboardInterrupt elif destroy: - click.secho('Your call!') + click.secho("Your call!") output, query = self._evaluate_command(text) except KeyboardInterrupt: # Restart connection to the database self.pgexecute.connect() logger.debug("cancelled query, sql: %r", text) - click.secho("cancelled query", err=True, fg='red') + click.secho("cancelled query", err=True, fg="red") except NotImplementedError: - click.secho('Not Yet Implemented.', fg="yellow") + click.secho("Not Yet Implemented.", fg="yellow") except OperationalError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) @@ -578,68 +655,69 @@ class PGCli(object): except Exception as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - click.secho(str(e), err=True, fg='red') + click.secho(str(e), err=True, fg="red") else: try: - if self.output_file and not text.startswith(('\\o ', '\\? ')): + if self.output_file and not text.startswith(("\\o ", "\\? 
")): try: - with open(self.output_file, 'a', encoding='utf-8') as f: + with open(self.output_file, "a", encoding="utf-8") as f: click.echo(text, file=f) - click.echo('\n'.join(output), file=f) - click.echo('', file=f) # extra newline + click.echo("\n".join(output), file=f) + click.echo("", file=f) # extra newline except IOError as e: - click.secho(str(e), err=True, fg='red') + click.secho(str(e), err=True, fg="red") else: - self.echo_via_pager('\n'.join(output)) + self.echo_via_pager("\n".join(output)) except KeyboardInterrupt: pass if self.pgspecial.timing_enabled: # Only add humanized time display if > 1 second if query.total_time > 1: - print('Time: %0.03fs (%s), executed in: %0.03fs (%s)' % (query.total_time, - humanize.time.naturaldelta( - query.total_time), - query.execution_time, - humanize.time.naturaldelta(query.execution_time))) + print ( + "Time: %0.03fs (%s), executed in: %0.03fs (%s)" + % ( + query.total_time, + humanize.time.naturaldelta(query.total_time), + query.execution_time, + humanize.time.naturaldelta(query.execution_time), + ) + ) else: - print('Time: %0.03fs' % query.total_time) + print ("Time: %0.03fs" % query.total_time) # Check if we need to update completions, in order of most # to least drastic changes if query.db_changed: with self._completer_lock: self.completer.reset_completions() - self.refresh_completions(persist_priorities='keywords') + self.refresh_completions(persist_priorities="keywords") elif query.meta_changed: - self.refresh_completions(persist_priorities='all') + self.refresh_completions(persist_priorities="all") elif query.path_changed: - logger.debug('Refreshing search path') + logger.debug("Refreshing search path") with self._completer_lock: - self.completer.set_search_path( - self.pgexecute.search_path()) - logger.debug('Search path: %r', - self.completer.search_path) + self.completer.set_search_path(self.pgexecute.search_path()) + logger.debug("Search path: %r", self.completer.search_path) return query def 
run_cli(self): logger = self.logger - history_file = self.config['main']['history_file'] - if history_file == 'default': - history_file = config_location() + 'history' + history_file = self.config["main"]["history_file"] + if history_file == "default": + history_file = config_location() + "history" history = FileHistory(os.path.expanduser(history_file)) - self.refresh_completions(history=history, - persist_priorities='none') + self.refresh_completions(history=history, persist_priorities="none") self.prompt_app = self._build_cli(history) if not self.less_chatty: - print('Server: PostgreSQL', self.pgexecute.server_version) - print('Version:', __version__) - print('Chat: https://gitter.im/dbcli/pgcli') - print('Mail: https://groups.google.com/forum/#!forum/pgcli') - print('Home: http://pgcli.com') + print ("Server: PostgreSQL", self.pgexecute.server_version) + print ("Version:", __version__) + print ("Chat: https://gitter.im/dbcli/pgcli") + print ("Mail: https://groups.google.com/forum/#!forum/pgcli") + print ("Home: http://pgcli.com") try: while True: @@ -653,7 +731,7 @@ class PGCli(object): except RuntimeError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - click.secho(str(e), err=True, fg='red') + click.secho(str(e), err=True, fg="red") continue # Initialize default metaquery in case execution fails @@ -663,8 +741,10 @@ class PGCli(object): try: query = self.execute_command(self.watch_command) click.echo( - 'Waiting for {0} seconds before repeating' - .format(timing)) + "Waiting for {0} seconds before repeating".format( + timing + ) + ) sleep(timing) except KeyboardInterrupt: self.watch_command = None @@ -681,7 +761,7 @@ class PGCli(object): except (PgCliQuitError, EOFError): if not self.less_chatty: - print ('Goodbye!') + print ("Goodbye!") def _build_cli(self, history): key_bindings = pgcli_bindings(self) @@ -694,15 +774,17 @@ class PGCli(object): prompt = self.get_prompt(prompt_format) - if (prompt_format 
== self.default_prompt and - len(prompt) > self.max_len_prompt): - prompt = self.get_prompt('\\d> ') + if ( + prompt_format == self.default_prompt + and len(prompt) > self.max_len_prompt + ): + prompt = self.get_prompt("\\d> ") - return [('class:prompt', prompt)] + return [("class:prompt", prompt)] def get_continuation(width, line_number, is_soft_wrap): - continuation = self.multiline_continuation_char * (width - 1) + ' ' - return [('class:continuation', continuation)] + continuation = self.multiline_continuation_char * (width - 1) + " " + return [("class:continuation", continuation)] get_toolbar_tokens = create_toolbar_tokens_func(self) @@ -722,17 +804,17 @@ class PGCli(object): input_processors=[ # Highlight matching brackets while editing. ConditionalProcessor( - processor=HighlightMatchingBracketProcessor( - chars='[](){}'), - filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()), + processor=HighlightMatchingBracketProcessor(chars="[](){}"), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(), + ), # Render \t as 4 spaces instead of "^I" - TabsProcessor(char1=' ', char2=' ')], + TabsProcessor(char1=" ", char2=" "), + ], auto_suggest=AutoSuggestFromHistory(), - tempfile_suffix='.sql', + tempfile_suffix=".sql", multiline=pg_is_multiline(self), history=history, - completer=ThreadedCompleter( - DynamicCompleter(lambda: self.completer)), + completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)), complete_while_typing=True, style=style_factory(self.syntax_style, self.cli_style), include_default_pygments_style=False, @@ -741,7 +823,8 @@ class PGCli(object): enable_system_prompt=True, enable_suspend=True, editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS, - search_ignore_case=True) + search_ignore_case=True, + ) return prompt_app @@ -758,7 +841,7 @@ class PGCli(object): returns (results, MetaQuery) """ logger = self.logger - logger.debug('sql: %r', text) + logger.debug("sql: %r", text) all_success = True meta_changed = False # CREATE, ALTER, DROP, etc 
@@ -771,9 +854,10 @@ class PGCli(object): # Run the query. start = time() - on_error_resume = self.on_error == 'RESUME' - res = self.pgexecute.run(text, self.pgspecial, - exception_formatter, on_error_resume) + on_error_resume = self.on_error == "RESUME" + res = self.pgexecute.run( + text, self.pgspecial, exception_formatter, on_error_resume + ) for title, cur, headers, status, sql, success, is_special in res: logger.debug("headers: %r", headers) @@ -781,10 +865,11 @@ class PGCli(object): logger.debug("status: %r", status) threshold = self.row_limit if self._should_show_limit_prompt(status, cur): - click.secho('The result set has more than %s rows.' - % threshold, fg='red') - if not click.confirm('Do you want to continue?'): - click.secho("Aborted!", err=True, fg='red') + click.secho( + "The result set has more than %s rows." % threshold, fg="red" + ) + if not click.confirm("Do you want to continue?"): + click.secho("Aborted!", err=True, fg="red") break if self.pgspecial.auto_expand or self.auto_expand: @@ -801,10 +886,11 @@ class PGCli(object): expanded=expanded, max_width=max_width, case_function=( - self.completer.case if self.settings['case_column_headers'] + self.completer.case + if self.settings["case_column_headers"] else lambda x: x ), - style_output=self.style_output + style_output=self.style_output, ) execution = time() - start formatted = format_output(title, cur, headers, status, settings) @@ -822,23 +908,32 @@ class PGCli(object): else: all_success = False - meta_query = MetaQuery(text, all_success, total, execution, meta_changed, - db_changed, path_changed, mutated, is_special) + meta_query = MetaQuery( + text, + all_success, + total, + execution, + meta_changed, + db_changed, + path_changed, + mutated, + is_special, + ) return output, meta_query def _handle_server_closed_connection(self, text): """Used during CLI execution.""" try: - click.secho('Reconnecting...', fg='green') + click.secho("Reconnecting...", fg="green") self.pgexecute.connect() - 
click.secho('Reconnected!', fg='green') + click.secho("Reconnected!", fg="green") self.execute_command(text) except OperationalError as e: - click.secho('Reconnect Failed', fg='red') - click.secho(str(e), err=True, fg='red') + click.secho("Reconnect Failed", fg="red") + click.secho(str(e), err=True, fg="red") - def refresh_completions(self, history=None, persist_priorities='all'): + def refresh_completions(self, history=None, persist_priorities="all"): """ Refresh outdated completions :param history: A prompt_toolkit.history.FileHistory object. Used to @@ -847,12 +942,19 @@ class PGCli(object): :param persist_priorities: 'all' or 'keywords' """ - callback = functools.partial(self._on_completions_refreshed, - persist_priorities=persist_priorities) - self.completion_refresher.refresh(self.pgexecute, self.pgspecial, - callback, history=history, settings=self.settings) - return [(None, None, None, - 'Auto-completion refresh started in the background.')] + callback = functools.partial( + self._on_completions_refreshed, persist_priorities=persist_priorities + ) + self.completion_refresher.refresh( + self.pgexecute, + self.pgspecial, + callback, + history=history, + settings=self.settings, + ) + return [ + (None, None, None, "Auto-completion refresh started in the background.") + ] def _on_completions_refreshed(self, new_completer, persist_priorities): self._swap_completer_objects(new_completer, persist_priorities) @@ -881,15 +983,15 @@ class PGCli(object): old_completer = self.completer self.completer = new_completer - if persist_priorities == 'all': + if persist_priorities == "all": # Just swap over the entire prioritizer new_completer.prioritizer = old_completer.prioritizer - elif persist_priorities == 'keywords': + elif persist_priorities == "keywords": # Swap over the entire prioritizer, but clear name priorities, # leaving learned keyword priorities alone new_completer.prioritizer = old_completer.prioritizer new_completer.prioritizer.clear_names() - elif 
persist_priorities == 'none': + elif persist_priorities == "none": # Leave the new prioritizer as is pass self.completer = new_completer @@ -897,21 +999,24 @@ class PGCli(object): def get_completions(self, text, cursor_positition): with self._completer_lock: return self.completer.get_completions( - Document(text=text, cursor_position=cursor_positition), None) + Document(text=text, cursor_position=cursor_positition), None + ) def get_prompt(self, string): # should be before replacing \\d - string = string.replace('\\dsn_alias', self.dsn_alias or '') - string = string.replace('\\t', self.now.strftime('%x %X')) - string = string.replace('\\u', self.pgexecute.user or '(none)') - string = string.replace('\\H', self.pgexecute.host or '(none)') - string = string.replace('\\h', self.pgexecute.short_host or '(none)') - string = string.replace('\\d', self.pgexecute.dbname or '(none)') - string = string.replace('\\p', str( - self.pgexecute.port) if self.pgexecute.port is not None else '5432') - string = string.replace('\\i', str(self.pgexecute.pid) or '(none)') - string = string.replace('\\#', "#" if (self.pgexecute.superuser) else ">") - string = string.replace('\\n', "\n") + string = string.replace("\\dsn_alias", self.dsn_alias or "") + string = string.replace("\\t", self.now.strftime("%x %X")) + string = string.replace("\\u", self.pgexecute.user or "(none)") + string = string.replace("\\H", self.pgexecute.host or "(none)") + string = string.replace("\\h", self.pgexecute.short_host or "(none)") + string = string.replace("\\d", self.pgexecute.dbname or "(none)") + string = string.replace( + "\\p", + str(self.pgexecute.port) if self.pgexecute.port is not None else "5432", + ) + string = string.replace("\\i", str(self.pgexecute.pid) or "(none)") + string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">") + string = string.replace("\\n", "\n") return string def get_last_query(self): @@ -922,7 +1027,10 @@ class PGCli(object): """Will this line be too wide to fit 
into terminal?""" if not self.prompt_app: return False - return len(COLOR_CODE_REGEX.sub('', line)) > self.prompt_app.output.get_size().columns + return ( + len(COLOR_CODE_REGEX.sub("", line)) + > self.prompt_app.output.get_size().columns + ) def is_too_tall(self, lines): """Are there too many lines to fit into terminal?""" @@ -934,7 +1042,7 @@ class PGCli(object): if self.pgspecial.pager_config == PAGER_OFF or self.watch_command: click.echo(text, color=color) elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT: - lines = text.split('\n') + lines = text.split("\n") # The last 4 lines are reserved for the pgcli menu and padding if self.is_too_tall(lines) or any(self.is_too_wide(l) for l in lines): @@ -944,51 +1052,139 @@ class PGCli(object): else: click.echo_via_pager(text, color) + @click.command() # Default host is '' so psycopg2 can default to either localhost or unix socket -@click.option('-h', '--host', default='', envvar='PGHOST', - help='Host address of the postgres database.') -@click.option('-p', '--port', default=5432, help='Port number at which the ' - 'postgres instance is listening.', envvar='PGPORT', type=click.INT) -@click.option('-U', '--username', 'username_opt', help='Username to connect to the postgres database.') -@click.option('-u', '--user', 'username_opt', help='Username to connect to the postgres database.') -@click.option('-W', '--password', 'prompt_passwd', is_flag=True, default=False, - help='Force password prompt.') -@click.option('-w', '--no-password', 'never_prompt', is_flag=True, - default=False, help='Never prompt for password.') -@click.option('--single-connection', 'single_connection', is_flag=True, - default=False, - help='Do not use a separate connection for completions.') -@click.option('-v', '--version', is_flag=True, help='Version of pgcli.') -@click.option('-d', '--dbname', 'dbname_opt', help='database name to connect to.') -@click.option('--pgclirc', default=config_location() + 'config', - envvar='PGCLIRC', help='Location 
of pgclirc file.', type=click.Path(dir_okay=False)) -@click.option('-D', '--dsn', default='', envvar='DSN', - help='Use DSN configured into the [alias_dsn] section of pgclirc file.') -@click.option('--list-dsn', 'list_dsn', is_flag=True, - help='list of DSN configured into the [alias_dsn] section of pgclirc file.') -@click.option('--row-limit', default=None, envvar='PGROWLIMIT', type=click.INT, - help='Set threshold for row limit prompt. Use 0 to disable prompt.') -@click.option('--less-chatty', 'less_chatty', is_flag=True, - default=False, - help='Skip intro on startup and goodbye on exit.') -@click.option('--prompt', help='Prompt format (Default: "\\u@\\h:\\d> ").') -@click.option('--prompt-dsn', help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").') -@click.option('-l', '--list', 'list_databases', is_flag=True, help='list ' - 'available databases, then exit.') -@click.option('--auto-vertical-output', is_flag=True, - help='Automatically switch to vertical output mode if the result is wider than the terminal width.') -@click.option('--warn/--no-warn', default=None, - help='Warn before running a destructive query.') -@click.argument('dbname', default=lambda: None, envvar='PGDATABASE', nargs=1) -@click.argument('username', default=lambda: None, envvar='PGUSER', nargs=1) -def cli(dbname, username_opt, host, port, prompt_passwd, never_prompt, - single_connection, dbname_opt, username, version, pgclirc, dsn, row_limit, - less_chatty, prompt, prompt_dsn, list_databases, auto_vertical_output, - list_dsn, warn): +@click.option( + "-h", + "--host", + default="", + envvar="PGHOST", + help="Host address of the postgres database.", +) +@click.option( + "-p", + "--port", + default=5432, + help="Port number at which the " "postgres instance is listening.", + envvar="PGPORT", + type=click.INT, +) +@click.option( + "-U", + "--username", + "username_opt", + help="Username to connect to the postgres database.", +) +@click.option( + "-u", "--user", 
"username_opt", help="Username to connect to the postgres database." +) +@click.option( + "-W", + "--password", + "prompt_passwd", + is_flag=True, + default=False, + help="Force password prompt.", +) +@click.option( + "-w", + "--no-password", + "never_prompt", + is_flag=True, + default=False, + help="Never prompt for password.", +) +@click.option( + "--single-connection", + "single_connection", + is_flag=True, + default=False, + help="Do not use a separate connection for completions.", +) +@click.option("-v", "--version", is_flag=True, help="Version of pgcli.") +@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.") +@click.option( + "--pgclirc", + default=config_location() + "config", + envvar="PGCLIRC", + help="Location of pgclirc file.", + type=click.Path(dir_okay=False), +) +@click.option( + "-D", + "--dsn", + default="", + envvar="DSN", + help="Use DSN configured into the [alias_dsn] section of pgclirc file.", +) +@click.option( + "--list-dsn", + "list_dsn", + is_flag=True, + help="list of DSN configured into the [alias_dsn] section of pgclirc file.", +) +@click.option( + "--row-limit", + default=None, + envvar="PGROWLIMIT", + type=click.INT, + help="Set threshold for row limit prompt. Use 0 to disable prompt.", +) +@click.option( + "--less-chatty", + "less_chatty", + is_flag=True, + default=False, + help="Skip intro on startup and goodbye on exit.", +) +@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").') +@click.option( + "--prompt-dsn", + help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").', +) +@click.option( + "-l", + "--list", + "list_databases", + is_flag=True, + help="list " "available databases, then exit.", +) +@click.option( + "--auto-vertical-output", + is_flag=True, + help="Automatically switch to vertical output mode if the result is wider than the terminal width.", +) +@click.option( + "--warn/--no-warn", default=None, help="Warn before running a destructive query." 
+) +@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) +@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) +def cli( + dbname, + username_opt, + host, + port, + prompt_passwd, + never_prompt, + single_connection, + dbname_opt, + username, + version, + pgclirc, + dsn, + row_limit, + less_chatty, + prompt, + prompt_dsn, + list_databases, + auto_vertical_output, + list_dsn, + warn, +): if version: - print('Version:', __version__) + print ("Version:", __version__) sys.exit(0) config_dir = os.path.dirname(config_location()) @@ -996,82 +1192,96 @@ def cli(dbname, username_opt, host, port, prompt_passwd, never_prompt, os.makedirs(config_dir) # Migrate the config file from old location. - config_full_path = config_location() + 'config' - if os.path.exists(os.path.expanduser('~/.pgclirc')): + config_full_path = config_location() + "config" + if os.path.exists(os.path.expanduser("~/.pgclirc")): if not os.path.exists(config_full_path): - shutil.move(os.path.expanduser('~/.pgclirc'), config_full_path) - print ('Config file (~/.pgclirc) moved to new location', - config_full_path) + shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path) + print ("Config file (~/.pgclirc) moved to new location", config_full_path) else: - print ('Config file is now located at', config_full_path) - print ('Please move the existing config file ~/.pgclirc to', - config_full_path) + print ("Config file is now located at", config_full_path) + print ( + "Please move the existing config file ~/.pgclirc to", + config_full_path, + ) if list_dsn: try: cfg = load_config(pgclirc, config_full_path) - for alias in cfg['alias_dsn']: - click.secho(alias + " : " + cfg['alias_dsn'][alias]) + for alias in cfg["alias_dsn"]: + click.secho(alias + " : " + cfg["alias_dsn"][alias]) sys.exit(0) except Exception as err: - click.secho('Invalid DSNs found in the config file. 
' - 'Please check the "[alias_dsn]" section in pgclirc.', - err=True, fg='red') + click.secho( + "Invalid DSNs found in the config file. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) exit(1) - pgcli = PGCli(prompt_passwd, never_prompt, pgclirc_file=pgclirc, - row_limit=row_limit, single_connection=single_connection, - less_chatty=less_chatty, prompt=prompt, prompt_dsn=prompt_dsn, - auto_vertical_output=auto_vertical_output, warn=warn) + pgcli = PGCli( + prompt_passwd, + never_prompt, + pgclirc_file=pgclirc, + row_limit=row_limit, + single_connection=single_connection, + less_chatty=less_chatty, + prompt=prompt, + prompt_dsn=prompt_dsn, + auto_vertical_output=auto_vertical_output, + warn=warn, + ) # Choose which ever one has a valid value. if dbname_opt and dbname: # work as psql: when database is given as option and argument use the argument as user username = dbname - database = dbname_opt or dbname or '' + database = dbname_opt or dbname or "" user = username_opt or username # because option --list or -l are not supposed to have a db name if list_databases: - database = 'postgres' + database = "postgres" - if dsn is not '': + if dsn is not "": try: cfg = load_config(pgclirc, config_full_path) - dsn_config = cfg['alias_dsn'][dsn] + dsn_config = cfg["alias_dsn"][dsn] except: - click.secho('Invalid DSNs found in the config file. '\ + click.secho( + "Invalid DSNs found in the config file. 
" 'Please check the "[alias_dsn]" section in pgclirc.', - err=True, fg='red') + err=True, + fg="red", + ) exit(1) pgcli.connect_uri(dsn_config) pgcli.dsn_alias = dsn - elif '://' in database: + elif "://" in database: pgcli.connect_uri(database) elif "=" in database: pgcli.connect_dsn(database, user=user) - elif os.environ.get('PGSERVICE', None): - pgcli.connect_dsn('service={0}'.format(os.environ['PGSERVICE'])) + elif os.environ.get("PGSERVICE", None): + pgcli.connect_dsn("service={0}".format(os.environ["PGSERVICE"])) else: pgcli.connect(database, host, user, port) if list_databases: cur, headers, status = pgcli.pgexecute.full_databases() - title = 'List of databases' - settings = OutputSettings( - table_format='ascii', - missingval='' - ) + title = "List of databases" + settings = OutputSettings(table_format="ascii", missingval="") formatted = format_output(title, cur, headers, status, settings) - pgcli.echo_via_pager('\n'.join(formatted)) + pgcli.echo_via_pager("\n".join(formatted)) sys.exit(0) - pgcli.logger.debug('Launch Params: \n' - '\tdatabase: %r' - '\tuser: %r' - '\thost: %r' - '\tport: %r', database, user, host, port) + pgcli.logger.debug( + "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r", + database, + user, + host, + port, + ) if setproctitle: obfuscate_process_password() @@ -1081,10 +1291,12 @@ def cli(dbname, username_opt, host, port, prompt_passwd, never_prompt, def obfuscate_process_password(): process_title = setproctitle.getproctitle() - if '://' in process_title: + if "://" in process_title: process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title) elif "=" in process_title: - process_title = re.sub(r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title) + process_title = re.sub( + r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title + ) setproctitle.setproctitle(process_title) @@ -1094,7 +1306,7 @@ def has_meta_cmd(query): statement is an alter, create, drop, commit or 
rollback.""" try: first_token = query.split()[0] - if first_token.lower() in ('alter', 'create', 'drop', 'commit', 'rollback'): + if first_token.lower() in ("alter", "create", "drop", "commit", "rollback"): return True except Exception: return False @@ -1106,7 +1318,7 @@ def has_change_db_cmd(query): """Determines if the statement is a database switch such as 'use' or '\\c'""" try: first_token = query.split()[0] - if first_token.lower() in ('use', '\\c', '\\connect'): + if first_token.lower() in ("use", "\\c", "\\connect"): return True except Exception: return False @@ -1117,7 +1329,7 @@ def has_change_db_cmd(query): def has_change_path_cmd(sql): """Determines if the search_path should be refreshed by checking if the sql has 'set search_path'.""" - return 'set search_path' in sql.lower() + return "set search_path" in sql.lower() def is_mutating(status): @@ -1125,7 +1337,7 @@ def is_mutating(status): if not status: return False - mutating = set(['insert', 'update', 'delete']) + mutating = set(["insert", "update", "delete"]) return status.split(None, 1)[0].lower() in mutating @@ -1133,18 +1345,17 @@ def is_select(status): """Returns true if the first word in status is 'select'.""" if not status: return False - return status.split(None, 1)[0].lower() == 'select' + return status.split(None, 1)[0].lower() == "select" def exception_formatter(e): - return click.style(utf8tounicode(str(e)), fg='red') + return click.style(utf8tounicode(str(e)), fg="red") def format_output(title, cur, headers, status, settings): output = [] - expanded = (settings.expanded or settings.table_format == 'vertical') - table_format = ('vertical' if settings.expanded else - settings.table_format) + expanded = settings.expanded or settings.table_format == "vertical" + table_format = "vertical" if settings.expanded else settings.table_format max_width = settings.max_width case_function = settings.case_function formatter = TabularOutputFormatter(format_name=table_format) @@ -1154,32 +1365,31 @@ def 
format_output(title, cur, headers, status, settings): return settings.missingval if not isinstance(val, list): return val - return '{' + ','.join(text_type(format_array(e)) for e in val) + '}' + return "{" + ",".join(text_type(format_array(e)) for e in val) + "}" def format_arrays(data, headers, **_): data = list(data) for row in data: row[:] = [ - format_array(val) if isinstance(val, list) else val - for val in row + format_array(val) if isinstance(val, list) else val for val in row ] return data, headers output_kwargs = { - 'sep_title': 'RECORD {n}', - 'sep_character': '-', - 'sep_length': (1, 25), - 'missing_value': settings.missingval, - 'integer_format': settings.dcmlfmt, - 'float_format': settings.floatfmt, - 'preprocessors': (format_numbers, format_arrays), - 'disable_numparse': True, - 'preserve_whitespace': True, - 'style': settings.style_output + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": settings.missingval, + "integer_format": settings.dcmlfmt, + "float_format": settings.floatfmt, + "preprocessors": (format_numbers, format_arrays), + "disable_numparse": True, + "preserve_whitespace": True, + "style": settings.style_output, } if not settings.floatfmt: - output_kwargs['preprocessors'] = (align_decimals, ) + output_kwargs["preprocessors"] = (align_decimals,) if title: # Only print the title if it's not None. 
output.append(title) @@ -1189,14 +1399,18 @@ def format_output(title, cur, headers, status, settings): if max_width is not None: cur = list(cur) column_types = None - if hasattr(cur, 'description'): + if hasattr(cur, "description"): column_types = [] for d in cur.description: - if d[1] in psycopg2.extensions.DECIMAL.values or \ - d[1] in psycopg2.extensions.FLOAT.values: + if ( + d[1] in psycopg2.extensions.DECIMAL.values + or d[1] in psycopg2.extensions.FLOAT.values + ): column_types.append(float) - if d[1] == psycopg2.extensions.INTEGER.values or \ - d[1] in psycopg2.extensions.LONGINTEGER.values: + if ( + d[1] == psycopg2.extensions.INTEGER.values + or d[1] in psycopg2.extensions.LONGINTEGER.values + ): column_types.append(int) else: column_types.append(text_type) @@ -1208,7 +1422,8 @@ def format_output(title, cur, headers, status, settings): if not expanded and max_width and len(first_line) > max_width and headers: formatted = formatter.format_output( - cur, headers, format_name='vertical', column_types=None, **output_kwargs) + cur, headers, format_name="vertical", column_types=None, **output_kwargs + ) if isinstance(formatted, (text_type)): formatted = iter(formatted.splitlines()) diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py index 818af9a7..9bfff73b 100644 --- a/pgcli/packages/parseutils/__init__.py +++ b/pgcli/packages/parseutils/__init__.py @@ -1,11 +1,13 @@ import sqlparse + def query_starts_with(query, prefixes): """Check if the query starts with any item from *prefixes*.""" prefixes = [prefix.lower() for prefix in prefixes] formatted_sql = sqlparse.format(query.lower(), strip_comments=True) return bool(formatted_sql) and formatted_sql.split()[0] in prefixes + def queries_start_with(queries, prefixes): """Check if any queries start with any item from *prefixes*.""" for query in sqlparse.split(queries): @@ -13,7 +15,8 @@ def queries_start_with(queries, prefixes): return True return False + def 
is_destructive(queries): """Returns if any of the queries in *queries* is destructive.""" - keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter') + keywords = ("drop", "shutdown", "delete", "truncate", "alter") return queries_start_with(queries, keywords) diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py index 55088c2a..4b8786dc 100644 --- a/pgcli/packages/parseutils/ctes.py +++ b/pgcli/packages/parseutils/ctes.py @@ -12,7 +12,7 @@ from .meta import TableMetadata, ColumnMetadata # columns: list of column names # start: index into the original string of the left parens starting the CTE # stop: index into the original string of the right parens ending the CTE -TableExpression = namedtuple('TableExpression', 'name columns start stop') +TableExpression = namedtuple("TableExpression", "name columns start stop") def isolate_query_ctes(full_text, text_before_cursor): @@ -32,8 +32,8 @@ def isolate_query_ctes(full_text, text_before_cursor): for cte in ctes: if cte.start < current_position < cte.stop: # Currently editing a cte - treat its body as the current full_text - text_before_cursor = full_text[cte.start:current_position] - full_text = full_text[cte.start:cte.stop] + text_before_cursor = full_text[cte.start : current_position] + full_text = full_text[cte.start : cte.stop] return full_text, text_before_cursor, meta # Append this cte to the list of available table metadata @@ -41,8 +41,8 @@ def isolate_query_ctes(full_text, text_before_cursor): meta.append(TableMetadata(cte.name, cols)) # Editing past the last cte (ie the main body of the query) - full_text = full_text[ctes[-1].stop:] - text_before_cursor = text_before_cursor[ctes[-1].stop:current_position] + full_text = full_text[ctes[-1].stop :] + text_before_cursor = text_before_cursor[ctes[-1].stop : current_position] return full_text, text_before_cursor, tuple(meta) @@ -68,7 +68,7 @@ def extract_ctes(sql): # Get the next (meaningful) token, which should be the first CTE idx, 
tok = p.token_next(idx) if not tok: - return ([], '') + return ([], "") start_pos = token_start_pos(p.tokens, idx) ctes = [] @@ -89,7 +89,7 @@ def extract_ctes(sql): idx = p.token_index(tok) + 1 # Collapse everything after the ctes into a remainder query - remainder = u''.join(str(tok) for tok in p.tokens[idx:]) + remainder = "".join(str(tok) for tok in p.tokens[idx:]) return ctes, remainder @@ -118,10 +118,10 @@ def extract_column_names(parsed): idx, tok = parsed.token_next_by(t=DML) tok_val = tok and tok.value.lower() - if tok_val in ('insert', 'update', 'delete'): + if tok_val in ("insert", "update", "delete"): # Jump ahead to the RETURNING clause where the list of column names is - idx, tok = parsed.token_next_by(idx, (Keyword, 'returning')) - elif not tok_val == 'select': + idx, tok = parsed.token_next_by(idx, (Keyword, "returning")) + elif not tok_val == "select": # Must be invalid CTE return () diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py index 32dd0aff..a892b880 100644 --- a/pgcli/packages/parseutils/meta.py +++ b/pgcli/packages/parseutils/meta.py @@ -2,22 +2,26 @@ from __future__ import unicode_literals from collections import namedtuple _ColumnMetadata = namedtuple( - 'ColumnMetadata', - ['name', 'datatype', 'foreignkeys', 'default', 'has_default'] + "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"] ) -def ColumnMetadata( - name, datatype, foreignkeys=None, default=None, has_default=False -): - return _ColumnMetadata( - name, datatype, foreignkeys or [], default, has_default - ) +def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False): + return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default) -ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable', - 'parentcolumn', 'childschema', 'childtable', 'childcolumn']) -TableMetadata = namedtuple('TableMetadata', 'name columns') +ForeignKey = namedtuple( + "ForeignKey", + [ + 
"parentschema", + "parenttable", + "parentcolumn", + "childschema", + "childtable", + "childcolumn", + ], +) +TableMetadata = namedtuple("TableMetadata", "name columns") def parse_defaults(defaults_string): @@ -25,34 +29,42 @@ def parse_defaults(defaults_string): pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)""" if not defaults_string: return - current = '' + current = "" in_quote = None for char in defaults_string: - if current == '' and char == ' ': + if current == "" and char == " ": # Skip space after comma separating default expressions continue - if char == '"' or char == '\'': + if char == '"' or char == "'": if in_quote and char == in_quote: # End quote in_quote = None elif not in_quote: # Begin quote in_quote = char - elif char == ',' and not in_quote: + elif char == "," and not in_quote: # End of expression yield current - current = '' + current = "" continue current += char yield current class FunctionMetadata(object): - def __init__( - self, schema_name, func_name, arg_names, arg_types, arg_modes, - return_type, is_aggregate, is_window, is_set_returning, is_extension, - arg_defaults + self, + schema_name, + func_name, + arg_names, + arg_types, + arg_modes, + return_type, + is_aggregate, + is_window, + is_set_returning, + is_extension, + arg_defaults, ): """Class for describing a postgresql function""" @@ -81,20 +93,27 @@ class FunctionMetadata(object): self.is_window = is_window self.is_set_returning = is_set_returning self.is_extension = bool(is_extension) - self.is_public = (self.schema_name and self.schema_name == 'public') + self.is_public = self.schema_name and self.schema_name == "public" def __eq__(self, other): - return (isinstance(other, self.__class__) - and self.__dict__ == other.__dict__) + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) def _signature(self): return ( - self.schema_name, self.func_name, self.arg_names, - self.arg_types, self.arg_modes, 
self.return_type, self.is_aggregate, - self.is_window, self.is_set_returning, self.is_extension, self.arg_defaults + self.schema_name, + self.func_name, + self.arg_names, + self.arg_types, + self.arg_modes, + self.return_type, + self.is_aggregate, + self.is_window, + self.is_set_returning, + self.is_extension, + self.arg_defaults, ) def __hash__(self): @@ -102,25 +121,23 @@ class FunctionMetadata(object): def __repr__(self): return ( - ( - '%s(schema_name=%r, func_name=%r, arg_names=%r, ' - 'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, ' - 'is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)' - ) % ((self.__class__.__name__,) + self._signature()) - ) + "%s(schema_name=%r, func_name=%r, arg_names=%r, " + "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, " + "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)" + ) % ((self.__class__.__name__,) + self._signature()) def has_variadic(self): - return self.arg_modes and any(arg_mode == 'v' for arg_mode in self.arg_modes) + return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes) def args(self): """Returns a list of input-parameter ColumnMetadata namedtuples.""" if not self.arg_names: return [] - modes = self.arg_modes or ['i'] * len(self.arg_names) + modes = self.arg_modes or ["i"] * len(self.arg_names) args = [ (name, typ) for name, typ, mode in zip(self.arg_names, self.arg_types, modes) - if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC + if mode in ("i", "b", "v") # IN, INOUT, VARIADIC ] def arg(name, typ, num): @@ -128,7 +145,8 @@ class FunctionMetadata(object): num_defaults = len(self.arg_defaults) has_default = num + num_defaults >= num_args default = ( - self.arg_defaults[num - num_args + num_defaults] if has_default + self.arg_defaults[num - num_args + num_defaults] + if has_default else None ) return ColumnMetadata(name, typ, [], default, has_default) @@ -138,7 +156,7 @@ class FunctionMetadata(object): def fields(self): """Returns 
a list of output-field ColumnMetadata namedtuples""" - if self.return_type.lower() == 'void': + if self.return_type.lower() == "void": return [] elif not self.arg_modes: # For functions without output parameters, the function name @@ -146,7 +164,8 @@ class FunctionMetadata(object): # E.g. 'SELECT unnest FROM unnest(...);' return [ColumnMetadata(self.func_name, self.return_type, [])] - return [ColumnMetadata(name, typ, []) - for name, typ, mode in zip( - self.arg_names, self.arg_types, self.arg_modes) - if mode in ('o', 'b', 't')] # OUT, INOUT, TABLE + return [ + ColumnMetadata(name, typ, []) + for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes) + if mode in ("o", "b", "t") + ] # OUT, INOUT, TABLE diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py index 607b5009..ac4c1b9d 100644 --- a/pgcli/packages/parseutils/tables.py +++ b/pgcli/packages/parseutils/tables.py @@ -5,11 +5,17 @@ from collections import namedtuple from sqlparse.sql import IdentifierList, Identifier, Function from sqlparse.tokens import Keyword, DML, Punctuation -TableReference = namedtuple('TableReference', ['schema', 'name', 'alias', - 'is_function']) -TableReference.ref = property(lambda self: self.alias or ( - self.name if self.name.islower() or self.name[0] == '"' - else '"' + self.name + '"')) +TableReference = namedtuple( + "TableReference", ["schema", "name", "alias", "is_function"] +) +TableReference.ref = property( + lambda self: self.alias + or ( + self.name + if self.name.islower() or self.name[0] == '"' + else '"' + self.name + '"' + ) +) # This code is borrowed from sqlparse example script. 
@@ -18,8 +24,13 @@ def is_subselect(parsed): if not parsed.is_group: return False for item in parsed.tokens: - if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT', - 'UPDATE', 'CREATE', 'DELETE'): + if item.ttype is DML and item.value.upper() in ( + "SELECT", + "INSERT", + "UPDATE", + "CREATE", + "DELETE", + ): return True return False @@ -45,23 +56,29 @@ def extract_from_part(parsed, stop_at_punctuation=True): # Also 'SELECT * FROM abc JOIN def' will trigger this elif # condition. So we need to ignore the keyword JOIN and its variants # INNER JOIN, FULL OUTER JOIN, etc. - elif item.ttype is Keyword and ( - not item.value.upper() == 'FROM') and ( - not item.value.upper().endswith('JOIN')): + elif ( + item.ttype is Keyword + and (not item.value.upper() == "FROM") + and (not item.value.upper().endswith("JOIN")) + ): tbl_prefix_seen = False else: yield item elif item.ttype is Keyword or item.ttype is Keyword.DML: item_val = item.value.upper() - if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or - item_val.endswith('JOIN')): + if item_val in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + ) or item_val.endswith("JOIN"): tbl_prefix_seen = True # 'SELECT a, FROM abc' will detect FROM as part of the column list. # So this check here is necessary. 
elif isinstance(item, IdentifierList): for identifier in item.get_identifiers(): - if (identifier.ttype is Keyword and - identifier.value.upper() == 'FROM'): + if identifier.ttype is Keyword and identifier.value.upper() == "FROM": tbl_prefix_seen = True break @@ -102,13 +119,15 @@ def extract_table_identifiers(token_stream, allow_functions=True): try: schema_name = identifier.get_parent_name() real_name = identifier.get_real_name() - is_function = (allow_functions and - _identifier_is_function(identifier)) + is_function = allow_functions and _identifier_is_function( + identifier + ) except AttributeError: continue if real_name: - yield TableReference(schema_name, real_name, - identifier.get_alias(), is_function) + yield TableReference( + schema_name, real_name, identifier.get_alias(), is_function + ) elif isinstance(item, Identifier): schema_name, real_name, alias = parse_identifier(item) is_function = allow_functions and _identifier_is_function(item) @@ -136,7 +155,7 @@ def extract_tables(sql): # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) # abc is the table name, but if we don't stop at the first lparen, then # we'll identify abc, col1 and col2 as table names. - insert_stmt = parsed[0].token_first().value.lower() == 'insert' + insert_stmt = parsed[0].token_first().value.lower() == "insert" stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) # Kludge: sqlparse mistakenly identifies insert statements as @@ -144,7 +163,6 @@ def extract_tables(sql): # "insert into foo (bar, baz)" as a function call to foo with arguments # (bar, baz). 
So don't allow any identifiers in insert statements # to have is_function=True - identifiers = extract_table_identifiers(stream, - allow_functions=not insert_stmt) + identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt) # In the case 'sche.', we get an empty TableReference; remove that return tuple(i for i in identifiers if i.name) diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py index e1569bff..998f71a3 100644 --- a/pgcli/packages/parseutils/utils.py +++ b/pgcli/packages/parseutils/utils.py @@ -6,17 +6,17 @@ from sqlparse.tokens import Token, Error cleanup_regex = { # This matches only alphanumerics and underscores. - 'alphanum_underscore': re.compile(r'(\w+)$'), + "alphanum_underscore": re.compile(r"(\w+)$"), # This matches everything except spaces, parens, colon, and comma - 'many_punctuations': re.compile(r'([^():,\s]+)$'), + "many_punctuations": re.compile(r"([^():,\s]+)$"), # This matches everything except spaces, parens, colon, comma, and period - 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), + "most_punctuations": re.compile(r"([^\.():,\s]+)$"), # This matches everything except a space. - 'all_punctuations': re.compile(r'([^\s]+)$'), + "all_punctuations": re.compile(r"([^\s]+)$"), } -def last_word(text, include='alphanum_underscore'): +def last_word(text, include="alphanum_underscore"): r""" Find the last word in a sentence. 
@@ -50,18 +50,18 @@ def last_word(text, include='alphanum_underscore'): '"foo*bar' """ - if not text: # Empty string - return '' + if not text: # Empty string + return "" if text[-1].isspace(): - return '' + return "" else: regex = cleanup_regex[include] matches = regex.search(text) if matches: return matches.group(0) else: - return '' + return "" def find_prev_keyword(sql, n_skip=0): @@ -71,17 +71,18 @@ def find_prev_keyword(sql, n_skip=0): everything after the last keyword stripped """ if not sql.strip(): - return None, '' + return None, "" parsed = sqlparse.parse(sql)[0] flattened = list(parsed.flatten()) - flattened = flattened[:len(flattened)-n_skip] + flattened = flattened[: len(flattened) - n_skip] - logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN') + logical_operators = ("AND", "OR", "NOT", "BETWEEN") for t in reversed(flattened): - if t.value == '(' or (t.is_keyword and ( - t.value.upper() not in logical_operators)): + if t.value == "(" or ( + t.is_keyword and (t.value.upper() not in logical_operators) + ): # Find the location of token t in the original parsed statement # We can't use parsed.token_index(t) because t may be a child token # inside a TokenList, in which case token_index thows an error @@ -94,14 +95,14 @@ def find_prev_keyword(sql, n_skip=0): # Combine the string values of all tokens in the original list # up to and including the target keyword token t, to produce a # query string with everything after the keyword token removed - text = ''.join(tok.value for tok in flattened[:idx+1]) + text = "".join(tok.value for tok in flattened[: idx + 1]) return t, text - return None, '' + return None, "" # Postgresql dollar quote signs look like `$$` or `$tag$` -dollar_quote_regex = re.compile(r'^\$[^$]*\$$') +dollar_quote_regex = re.compile(r"^\$[^$]*\$$") def is_open_quote(sql): diff --git a/pgcli/packages/pgliterals/main.py b/pgcli/packages/pgliterals/main.py index 874153b4..5c39296d 100644 --- a/pgcli/packages/pgliterals/main.py +++ 
b/pgcli/packages/pgliterals/main.py @@ -2,7 +2,7 @@ import os import json root = os.path.dirname(__file__) -literal_file = os.path.join(root, 'pgliterals.json') +literal_file = os.path.join(root, "pgliterals.json") with open(literal_file) as f: literals = json.load(f) diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py index cfd142cd..b692b75a 100644 --- a/pgcli/packages/prioritization.py +++ b/pgcli/packages/prioritization.py @@ -7,17 +7,17 @@ from collections import defaultdict from .pgliterals.main import get_literals -white_space_regex = re.compile('\\s+', re.MULTILINE) +white_space_regex = re.compile("\\s+", re.MULTILINE) def _compile_regex(keyword): # Surround the keyword with word boundaries and replace interior whitespace # with whitespace wildcards - pattern = '\\b' + white_space_regex.sub(r'\\s+', keyword) + '\\b' + pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b" return re.compile(pattern, re.MULTILINE | re.IGNORECASE) -keywords = get_literals('keywords') +keywords = get_literals("keywords") keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords) diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py index 8c64297e..63b5e059 100644 --- a/pgcli/packages/prompt_utils.py +++ b/pgcli/packages/prompt_utils.py @@ -15,8 +15,9 @@ def confirm_destructive_query(queries): * False if the query is destructive and the user doesn't want to proceed. """ - prompt_text = ("You're about to run a destructive command.\n" - "Do you want to proceed? (y/n)") + prompt_text = ( + "You're about to run a destructive command.\n" "Do you want to proceed? 
(y/n)" + ) if is_destructive(queries) and sys.stdin.isatty(): return prompt(prompt_text, type=bool) diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index 8ba282e8..85b3bb10 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -5,8 +5,7 @@ import re import sqlparse from collections import namedtuple from sqlparse.sql import Comparison, Identifier, Where -from .parseutils.utils import ( - last_word, find_prev_keyword, parse_partial_identifier) +from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier from .parseutils.tables import extract_tables from .parseutils.ctes import isolate_query_ctes from pgspecial.main import parse_special_command @@ -20,23 +19,23 @@ else: string_types = basestring -Special = namedtuple('Special', []) -Database = namedtuple('Database', []) -Schema = namedtuple('Schema', ['quoted']) +Special = namedtuple("Special", []) +Database = namedtuple("Database", []) +Schema = namedtuple("Schema", ["quoted"]) Schema.__new__.__defaults__ = (False,) # FromClauseItem is a table/view/function used in the FROM clause # `table_refs` contains the list of tables/... already in the statement, # used to ensure that the alias we suggest is unique -FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables') -Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables']) -TableFormat = namedtuple('TableFormat', []) -View = namedtuple('View', ['schema', 'table_refs']) +FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables") +Table = namedtuple("Table", ["schema", "table_refs", "local_tables"]) +TableFormat = namedtuple("TableFormat", []) +View = namedtuple("View", ["schema", "table_refs"]) # JoinConditions are suggested after ON, e.g. 
'foo.barid = bar.barid' -JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent']) +JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"]) # Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid' -Join = namedtuple('Join', ['table_refs', 'schema']) +Join = namedtuple("Join", ["table_refs", "schema"]) -Function = namedtuple('Function', ['schema', 'table_refs', 'usage']) +Function = namedtuple("Function", ["schema", "table_refs", "usage"]) # For convenience, don't require the `usage` argument in Function constructor Function.__new__.__defaults__ = (None, tuple(), None) Table.__new__.__defaults__ = (None, tuple(), tuple()) @@ -44,30 +43,32 @@ View.__new__.__defaults__ = (None, tuple()) FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple()) Column = namedtuple( - 'Column', - ['table_refs', 'require_last_table', 'local_tables', 'qualifiable', 'context'] + "Column", + ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"], ) Column.__new__.__defaults__ = (None, None, tuple(), False, None) -Keyword = namedtuple('Keyword', ['last_token']) +Keyword = namedtuple("Keyword", ["last_token"]) Keyword.__new__.__defaults__ = (None,) -NamedQuery = namedtuple('NamedQuery', []) -Datatype = namedtuple('Datatype', ['schema']) -Alias = namedtuple('Alias', ['aliases']) +NamedQuery = namedtuple("NamedQuery", []) +Datatype = namedtuple("Datatype", ["schema"]) +Alias = namedtuple("Alias", ["aliases"]) -Path = namedtuple('Path', []) +Path = namedtuple("Path", []) class SqlStatement(object): def __init__(self, full_text, text_before_cursor): self.identifier = None self.word_before_cursor = word_before_cursor = last_word( - text_before_cursor, include='many_punctuations') + text_before_cursor, include="many_punctuations" + ) full_text = _strip_named_query(full_text) text_before_cursor = _strip_named_query(text_before_cursor) - full_text, text_before_cursor, self.local_tables = \ - isolate_query_ctes(full_text, 
text_before_cursor) + full_text, text_before_cursor, self.local_tables = isolate_query_ctes( + full_text, text_before_cursor + ) self.text_before_cursor_including_last_word = text_before_cursor @@ -78,28 +79,29 @@ class SqlStatement(object): # completion useless because it will always return the list of # keywords as completion. if self.word_before_cursor: - if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\': + if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\": parsed = sqlparse.parse(text_before_cursor) else: - text_before_cursor = text_before_cursor[:-len(word_before_cursor)] + text_before_cursor = text_before_cursor[: -len(word_before_cursor)] parsed = sqlparse.parse(text_before_cursor) self.identifier = parse_partial_identifier(word_before_cursor) else: parsed = sqlparse.parse(text_before_cursor) - full_text, text_before_cursor, parsed = \ - _split_multiple_statements(full_text, text_before_cursor, parsed) + full_text, text_before_cursor, parsed = _split_multiple_statements( + full_text, text_before_cursor, parsed + ) self.full_text = full_text self.text_before_cursor = text_before_cursor self.parsed = parsed - self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or '' + self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or "" def is_insert(self): - return self.parsed.token_first().value.lower() == 'insert' + return self.parsed.token_first().value.lower() == "insert" - def get_tables(self, scope='full'): + def get_tables(self, scope="full"): """ Gets the tables available in the statement. param `scope:` possible values: 'full', 'insert', 'before' If 'insert', only the first table is returned. @@ -107,8 +109,9 @@ class SqlStatement(object): If not 'insert' and the stmt is an insert, the first table is skipped. 
""" tables = extract_tables( - self.full_text if scope == 'full' else self.text_before_cursor) - if scope == 'insert': + self.full_text if scope == "full" else self.text_before_cursor + ) + if scope == "insert": tables = tables[:1] elif self.is_insert(): tables = tables[1:] @@ -126,8 +129,9 @@ class SqlStatement(object): return schema def reduce_to_prev_keyword(self, n_skip=0): - prev_keyword, self.text_before_cursor = \ - find_prev_keyword(self.text_before_cursor, n_skip=n_skip) + prev_keyword, self.text_before_cursor = find_prev_keyword( + self.text_before_cursor, n_skip=n_skip + ) return prev_keyword @@ -139,7 +143,7 @@ def suggest_type(full_text, text_before_cursor): A scope for a column category will be a list of tables. """ - if full_text.startswith('\\i '): + if full_text.startswith("\\i "): return (Path(),) # This is a temporary hack; the exception handling @@ -154,14 +158,14 @@ def suggest_type(full_text, text_before_cursor): # Be careful here because trivial whitespace is parsed as a # statement, but the statement won't have a first token tok1 = stmt.parsed.token_first() - if tok1 and tok1.value.startswith('\\'): + if tok1 and tok1.value.startswith("\\"): text = stmt.text_before_cursor + stmt.word_before_cursor return suggest_special(text) return suggest_based_on_last_token(stmt.last_token, stmt) -named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+') +named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+") def _strip_named_query(txt): @@ -172,11 +176,11 @@ def _strip_named_query(txt): """ if named_query_regex.match(txt): - txt = named_query_regex.sub('', txt) + txt = named_query_regex.sub("", txt) return txt -function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M) +function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M) def _find_function_body(text): @@ -222,12 +226,12 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed): return full_text, text_before_cursor, None token2 = None - if 
statement.get_type() in ('CREATE', 'CREATE OR REPLACE'): + if statement.get_type() in ("CREATE", "CREATE OR REPLACE"): token1 = statement.token_first() if token1: token1_idx = statement.token_index(token1) token2 = statement.token_next(token1_idx)[1] - if token2 and token2.value.upper() == 'FUNCTION': + if token2 and token2.value.upper() == "FUNCTION": full_text, text_before_cursor, statement = _statement_from_function( full_text, text_before_cursor, statement ) @@ -235,11 +239,11 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed): SPECIALS_SUGGESTION = { - 'dT': Datatype, - 'df': Function, - 'dt': Table, - 'dv': View, - 'sf': Function, + "dT": Datatype, + "df": Function, + "dt": Table, + "dv": View, + "sf": Function, } @@ -251,13 +255,13 @@ def suggest_special(text): # Trying to complete the special command itself return (Special(),) - if cmd in ('\\c', '\\connect'): + if cmd in ("\\c", "\\connect"): return (Database(),) - if cmd == '\\T': + if cmd == "\\T": return (TableFormat(),) - if cmd == '\\dn': + if cmd == "\\dn": return (Schema(),) if arg: @@ -272,27 +276,24 @@ def suggest_special(text): else: schema = None - if cmd[1:] == 'd': + if cmd[1:] == "d": # \d can describe tables or views if schema: - return (Table(schema=schema), - View(schema=schema),) + return (Table(schema=schema), View(schema=schema)) else: - return (Schema(), - Table(schema=None), - View(schema=None),) + return (Schema(), Table(schema=None), View(schema=None)) elif cmd[1:] in SPECIALS_SUGGESTION: rel_type = SPECIALS_SUGGESTION[cmd[1:]] if schema: if rel_type == Function: - return (Function(schema=schema, usage='special'),) + return (Function(schema=schema, usage="special"),) return (rel_type(schema=schema),) else: if rel_type == Function: - return (Schema(), Function(schema=None, usage='special'),) + return (Schema(), Function(schema=None, usage="special")) return (Schema(), rel_type(schema=None)) - if cmd in ['\\n', '\\ns', '\\nd']: + if cmd in ["\\n", "\\ns", 
"\\nd"]: return (NamedQuery(),) return (Keyword(), Special()) @@ -327,9 +328,9 @@ def suggest_based_on_last_token(token, stmt): # SELECT Identifier # SELECT foo FROM Identifier prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor) - if prev_keyword and prev_keyword.value == '(': + if prev_keyword and prev_keyword.value == "(": # Suggest datatypes - return suggest_based_on_last_token('type', stmt) + return suggest_based_on_last_token("type", stmt) else: return (Keyword(),) else: @@ -337,7 +338,7 @@ def suggest_based_on_last_token(token, stmt): if not token: return (Keyword(), Special()) - elif token_v.endswith('('): + elif token_v.endswith("("): p = sqlparse.parse(stmt.text_before_cursor)[0] if p.tokens and isinstance(p.tokens[-1], Where): @@ -352,7 +353,7 @@ def suggest_based_on_last_token(token, stmt): # Suggest columns/functions AND keywords. (If we wanted to be # really fancy, we could suggest only array-typed columns) - column_suggestions = suggest_based_on_last_token('where', stmt) + column_suggestions = suggest_based_on_last_token("where", stmt) # Check for a subquery expression (cases 3 & 4) where = p.tokens[-1] @@ -363,7 +364,7 @@ def suggest_based_on_last_token(token, stmt): prev_tok = prev_tok.tokens[-1] prev_tok = prev_tok.value.lower() - if prev_tok == 'exists': + if prev_tok == "exists": return (Keyword(),) else: return column_suggestions @@ -371,57 +372,71 @@ def suggest_based_on_last_token(token, stmt): # Get the token before the parens prev_tok = p.token_prev(len(p.tokens) - 1)[1] - if (prev_tok and prev_tok.value - and prev_tok.value.lower().split(' ')[-1] == 'using'): + if ( + prev_tok + and prev_tok.value + and prev_tok.value.lower().split(" ")[-1] == "using" + ): # tbl1 INNER JOIN tbl2 USING (col1, col2) - tables = stmt.get_tables('before') + tables = stmt.get_tables("before") # suggest columns that are present in more than one table - return (Column(table_refs=tables, - require_last_table=True, - local_tables=stmt.local_tables),) + 
return ( + Column( + table_refs=tables, + require_last_table=True, + local_tables=stmt.local_tables, + ), + ) - elif p.token_first().value.lower() == 'select': + elif p.token_first().value.lower() == "select": # If the lparen is preceeded by a space chances are we're about to # do a sub-select. - if last_word(stmt.text_before_cursor, - 'all_punctuations').startswith('('): + if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("): return (Keyword(),) prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1] - if prev_prev_tok and prev_prev_tok.normalized == 'INTO': - return ( - Column(table_refs=stmt.get_tables('insert'), context='insert'), - ) + if prev_prev_tok and prev_prev_tok.normalized == "INTO": + return (Column(table_refs=stmt.get_tables("insert"), context="insert"),) # We're probably in a function argument list - return (Column(table_refs=extract_tables(stmt.full_text), - local_tables=stmt.local_tables, qualifiable=True),) - elif token_v == 'set': - return (Column(table_refs=stmt.get_tables(), - local_tables=stmt.local_tables),) - elif token_v in ('select', 'where', 'having', 'order by', 'distinct'): + return ( + Column( + table_refs=extract_tables(stmt.full_text), + local_tables=stmt.local_tables, + qualifiable=True, + ), + ) + elif token_v == "set": + return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),) + elif token_v in ("select", "where", "having", "order by", "distinct"): # Check for a table alias or schema qualification parent = (stmt.identifier and stmt.identifier.get_parent_name()) or [] tables = stmt.get_tables() if parent: tables = tuple(t for t in tables if identifies(parent, t)) - return (Column(table_refs=tables, local_tables=stmt.local_tables), - Table(schema=parent), - View(schema=parent), - Function(schema=parent),) + return ( + Column(table_refs=tables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ) else: - return 
(Column(table_refs=tables, local_tables=stmt.local_tables, - qualifiable=True), - Function(schema=None), - Keyword(token_v.upper()),) - elif token_v == 'as': + return ( + Column( + table_refs=tables, local_tables=stmt.local_tables, qualifiable=True + ), + Function(schema=None), + Keyword(token_v.upper()), + ) + elif token_v == "as": # Don't suggest anything for aliases return () - elif (token_v.endswith('join') and token.is_keyword) or (token_v in - ('copy', 'from', 'update', 'into', 'describe', 'truncate')): + elif (token_v.endswith("join") and token.is_keyword) or ( + token_v in ("copy", "from", "update", "into", "describe", "truncate") + ): schema = stmt.get_identifier_schema() tables = extract_tables(stmt.text_before_cursor) - is_join = token_v.endswith('join') and token.is_keyword + is_join = token_v.endswith("join") and token.is_keyword # Suggest tables from either the currently-selected schema or the # public schema if no schema has been specified @@ -431,98 +446,101 @@ def suggest_based_on_last_token(token, stmt): # Suggest schemas suggest.insert(0, Schema()) - if token_v == 'from' or is_join: - suggest.append(FromClauseItem(schema=schema, - table_refs=tables, - local_tables=stmt.local_tables)) - elif token_v == 'truncate': + if token_v == "from" or is_join: + suggest.append( + FromClauseItem( + schema=schema, table_refs=tables, local_tables=stmt.local_tables + ) + ) + elif token_v == "truncate": suggest.append(Table(schema)) else: suggest.extend((Table(schema), View(schema))) if is_join and _allow_join(stmt.parsed): - tables = stmt.get_tables('before') + tables = stmt.get_tables("before") suggest.append(Join(table_refs=tables, schema=schema)) return tuple(suggest) - elif token_v == 'function': + elif token_v == "function": schema = stmt.get_identifier_schema() # stmt.get_previous_token will fail for e.g. 
`SELECT 1 FROM functions WHERE function:` try: prev = stmt.get_previous_token(token).value.lower() - if prev in('drop', 'alter', 'create', 'create or replace'): - return (Function(schema=schema, usage='signature'),) + if prev in ("drop", "alter", "create", "create or replace"): + return (Function(schema=schema, usage="signature"),) except ValueError: pass return tuple() - elif token_v in ('table', 'view'): + elif token_v in ("table", "view"): # E.g. 'ALTER TABLE ' - rel_type = {'table': Table, 'view': View, 'function': Function}[token_v] + rel_type = {"table": Table, "view": View, "function": Function}[token_v] schema = stmt.get_identifier_schema() if schema: return (rel_type(schema=schema),) else: return (Schema(), rel_type(schema=schema)) - elif token_v == 'column': + elif token_v == "column": # E.g. 'ALTER TABLE foo ALTER COLUMN bar return (Column(table_refs=stmt.get_tables()),) - elif token_v == 'on': - tables = stmt.get_tables('before') + elif token_v == "on": + tables = stmt.get_tables("before") parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None if parent: # "ON parent." 
# parent can be either a schema name or table alias filteredtables = tuple(t for t in tables if identifies(parent, t)) - sugs = [Column(table_refs=filteredtables, - local_tables=stmt.local_tables), - Table(schema=parent), - View(schema=parent), - Function(schema=parent)] + sugs = [ + Column(table_refs=filteredtables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ] if filteredtables and _allow_join_condition(stmt.parsed): - sugs.append(JoinCondition(table_refs=tables, - parent=filteredtables[-1])) + sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1])) return tuple(sugs) else: # ON # Use table alias if there is one, otherwise the table name aliases = tuple(t.ref for t in tables) if _allow_join_condition(stmt.parsed): - return (Alias(aliases=aliases), JoinCondition( - table_refs=tables, parent=None)) + return ( + Alias(aliases=aliases), + JoinCondition(table_refs=tables, parent=None), + ) else: return (Alias(aliases=aliases),) - elif token_v in ('c', 'use', 'database', 'template'): + elif token_v in ("c", "use", "database", "template"): # "\c ", "DROP DATABASE ", # "CREATE DATABASE WITH TEMPLATE " return (Database(),) - elif token_v == 'schema': + elif token_v == "schema": # DROP SCHEMA schema_name, SET SCHEMA schema name prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2) - quoted = prev_keyword and prev_keyword.value.lower() == 'set' + quoted = prev_keyword and prev_keyword.value.lower() == "set" return (Schema(quoted),) - elif token_v.endswith(',') or token_v in ('=', 'and', 'or'): + elif token_v.endswith(",") or token_v in ("=", "and", "or"): prev_keyword = stmt.reduce_to_prev_keyword() if prev_keyword: return suggest_based_on_last_token(prev_keyword, stmt) else: return () - elif token_v in ('type', '::'): + elif token_v in ("type", "::"): # ALTER TABLE foo SET DATA TYPE bar # SELECT foo::bar # Note that tables are a form of composite type in postgresql, so # they're suggested here 
as well schema = stmt.get_identifier_schema() - suggestions = [Datatype(schema=schema), - Table(schema=schema)] + suggestions = [Datatype(schema=schema), Table(schema=schema)] if not schema: suggestions.append(Schema()) return tuple(suggestions) - elif token_v in {'alter', 'create', 'drop'}: + elif token_v in {"alter", "create", "drop"}: return (Keyword(token_v.upper()),) elif token.is_keyword: # token is a keyword we haven't implemented any special handling for @@ -539,8 +557,11 @@ def suggest_based_on_last_token(token, stmt): def identifies(id, ref): """Returns true if string `id` matches TableReference `ref`""" - return id == ref.alias or id == ref.name or ( - ref.schema and (id == ref.schema + '.' + ref.name)) + return ( + id == ref.alias + or id == ref.name + or (ref.schema and (id == ref.schema + "." + ref.name)) + ) def _allow_join_condition(statement): @@ -560,7 +581,7 @@ def _allow_join_condition(statement): return False last_tok = statement.token_prev(len(statement.tokens))[1] - return last_tok.value.lower() in ('on', 'and', 'or') + return last_tok.value.lower() in ("on", "and", "or") def _allow_join(statement): @@ -579,5 +600,7 @@ def _allow_join(statement): return False last_tok = statement.token_prev(len(statement.tokens))[1] - return (last_tok.value.lower().endswith('join') - and last_tok.value.lower() not in('cross join', 'natural join')) + return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in ( + "cross join", + "natural join", + ) diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py index 432d1479..bf1a10eb 100644 --- a/pgcli/pgbuffer.py +++ b/pgcli/pgbuffer.py @@ -13,10 +13,11 @@ def pg_is_multiline(pgcli): if not pgcli.multi_line: return False - if pgcli.multiline_mode == 'safe': + if pgcli.multiline_mode == "safe": return True else: return not _multiline_exception(doc.text) + return cond @@ -24,17 +25,19 @@ def _is_complete(sql): # A complete command is an sql statement that ends with a semicolon, unless # there's an 
open quote surrounding it, as is common when writing a # CREATE FUNCTION command - return sql.endswith(';') and not is_open_quote(sql) + return sql.endswith(";") and not is_open_quote(sql) def _multiline_exception(text): text = text.strip() return ( - text.startswith('\\') or # Special Command - text.endswith(r'\e') or # Ended with \e which should launch the editor - _is_complete(text) or # A complete SQL command - (text == 'exit') or # Exit doesn't need semi-colon - (text == 'quit') or # Quit doesn't need semi-colon - (text == ':q') or # To all the vim fans out there - (text == '') # Just a plain enter without any text + text.startswith("\\") + or text.endswith(r"\e") # Special Command + or _is_complete(text) # Ended with \e which should launch the editor + or (text == "exit") # A complete SQL command + or (text == "quit") # Exit doesn't need semi-colon + or (text == ":q") # Quit doesn't need semi-colon + or ( # To all the vim fans out there + text == "" + ) # Just a plain enter without any text ) diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index a4c3724d..14ae4fa2 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -9,9 +9,24 @@ from pgspecial.namedqueries import NamedQueries from prompt_toolkit.completion import Completer, Completion, PathCompleter from prompt_toolkit.document import Document from .packages.sqlcompletion import ( - FromClauseItem, suggest_type, Special, Database, Schema, Table, - TableFormat, Function, Column, View, Keyword, NamedQuery, - Datatype, Alias, Path, JoinCondition, Join) + FromClauseItem, + suggest_type, + Special, + Database, + Schema, + Table, + TableFormat, + Function, + Column, + View, + Keyword, + NamedQuery, + Datatype, + Alias, + Path, + JoinCondition, + Join, +) from .packages.parseutils.meta import ColumnMetadata, ForeignKey from .packages.parseutils.utils import last_word from .packages.parseutils.tables import TableReference @@ -21,34 +36,30 @@ from .config import load_config, config_location 
_logger = logging.getLogger(__name__) -Match = namedtuple('Match', ['completion', 'priority']) +Match = namedtuple("Match", ["completion", "priority"]) -_SchemaObject = namedtuple('SchemaObject', 'name schema meta') +_SchemaObject = namedtuple("SchemaObject", "name schema meta") def SchemaObject(name, schema=None, meta=None): return _SchemaObject(name, schema, meta) -_Candidate = namedtuple( - 'Candidate', 'completion prio meta synonyms prio2 display' -) +_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display") def Candidate( - completion, prio=None, meta=None, synonyms=None, prio2=None, - display=None + completion, prio=None, meta=None, synonyms=None, prio2=None, display=None ): return _Candidate( - completion, prio, meta, synonyms or [completion], prio2, - display or completion + completion, prio, meta, synonyms or [completion], prio2, display or completion ) # Used to strip trailing '::some_type' from default-value expressions -arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$') +arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$") -normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' +normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' def generate_alias(tbl): @@ -57,18 +68,20 @@ def generate_alias(tbl): all letters preceded by _ param tbl - unescaped name of the table to alias """ - return ''.join([l for l in tbl if l.isupper()] or - [l for l, prev in zip(tbl, '_' + tbl) if prev == '_' and l != '_']) + return "".join( + [l for l in tbl if l.isupper()] + or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"] + ) class PGCompleter(Completer): # keywords_tree: A dict mapping keywords to well known following keywords. # e.g. 
'CREATE': ['TABLE', 'USER', ...], - keywords_tree = get_literals('keywords', type_=dict) + keywords_tree = get_literals("keywords", type_=dict) keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values()))) - functions = get_literals('functions') - datatypes = get_literals('datatypes') - reserved_words = set(get_literals('reserved')) + functions = get_literals("functions") + datatypes = get_literals("datatypes") + reserved_words = set(get_literals("reserved")) def __init__(self, smart_completion=True, pgspecial=None, settings=None): super(PGCompleter, self).__init__() @@ -77,48 +90,49 @@ class PGCompleter(Completer): self.prioritizer = PrevalenceCounter() settings = settings or {} self.signature_arg_style = settings.get( - 'signature_arg_style', '{arg_name} {arg_type}' + "signature_arg_style", "{arg_name} {arg_type}" ) self.call_arg_style = settings.get( - 'call_arg_style', '{arg_name: <{max_arg_len}} := {arg_default}' + "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}" ) self.call_arg_display_style = settings.get( - 'call_arg_display_style', '{arg_name}' + "call_arg_display_style", "{arg_name}" ) - self.call_arg_oneliner_max = settings.get('call_arg_oneliner_max', 2) - self.search_path_filter = settings.get('search_path_filter') - self.generate_aliases = settings.get('generate_aliases') - self.casing_file = settings.get('casing_file') + self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2) + self.search_path_filter = settings.get("search_path_filter") + self.generate_aliases = settings.get("generate_aliases") + self.casing_file = settings.get("casing_file") self.insert_col_skip_patterns = [ - re.compile(pattern) for pattern in settings.get( - 'insert_col_skip_patterns', - [r'^now\(\)$', r'^nextval\('] + re.compile(pattern) + for pattern in settings.get( + "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("] ) ] - self.generate_casing_file = settings.get('generate_casing_file') - self.qualify_columns = settings.get( - 
'qualify_columns', 'if_more_than_one_table') + self.generate_casing_file = settings.get("generate_casing_file") + self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table") self.asterisk_column_order = settings.get( - 'asterisk_column_order', 'table_order') + "asterisk_column_order", "table_order" + ) - keyword_casing = settings.get('keyword_casing', 'upper').lower() - if keyword_casing not in ('upper', 'lower', 'auto'): - keyword_casing = 'upper' + keyword_casing = settings.get("keyword_casing", "upper").lower() + if keyword_casing not in ("upper", "lower", "auto"): + keyword_casing = "upper" self.keyword_casing = keyword_casing self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$") self.databases = [] - self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}, - 'datatypes': {}} + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}} self.search_path = [] self.casing = {} self.all_completions = set(self.keywords + self.functions) def escape_name(self, name): - if name and ((not self.name_pattern.match(name)) - or (name.upper() in self.reserved_words) - or (name.upper() in self.functions)): + if name and ( + (not self.name_pattern.match(name)) + or (name.upper() in self.reserved_words) + or (name.upper() in self.functions) + ): name = '"%s"' % name return name @@ -147,7 +161,7 @@ class PGCompleter(Completer): # schemata is a list of schema names schemata = self.escaped_names(schemata) - metadata = self.dbmetadata['tables'] + metadata = self.dbmetadata["tables"] for schema in schemata: metadata[schema] = {} @@ -185,8 +199,9 @@ class PGCompleter(Completer): try: metadata[schema][relname] = OrderedDict() except KeyError: - _logger.error('%r %r listed in unrecognized schema %r', - kind, relname, schema) + _logger.error( + "%r %r listed in unrecognized schema %r", kind, relname, schema + ) self.all_completions.add(relname) def extend_columns(self, column_data, kind): @@ -201,13 +216,12 @@ class 
PGCompleter(Completer): """ metadata = self.dbmetadata[kind] for schema, relname, colname, datatype, has_default, default in column_data: - (schema, relname, colname) = self.escaped_names( - [schema, relname, colname]) + (schema, relname, colname) = self.escaped_names([schema, relname, colname]) column = ColumnMetadata( name=colname, datatype=datatype, has_default=has_default, - default=default + default=default, ) metadata[schema][relname][colname] = column self.all_completions.add(colname) @@ -218,7 +232,7 @@ class PGCompleter(Completer): # dbmetadata['schema_name']['functions']['function_name'] should return # the function metadata namedtuple for the corresponding function - metadata = self.dbmetadata['functions'] + metadata = self.dbmetadata["functions"] for f in func_data: schema, func = self.escaped_names([f.schema_name, f.func_name]) @@ -239,11 +253,11 @@ class PGCompleter(Completer): self._arg_list_cache = { usage: { meta: self._arg_list(meta, usage) - for sch, funcs in self.dbmetadata['functions'].items() + for sch, funcs in self.dbmetadata["functions"].items() for func, metas in funcs.items() for meta in metas } - for usage in ('call', 'call_display', 'signature') + for usage in ("call", "call_display", "signature") } def extend_foreignkeys(self, fk_data): @@ -254,17 +268,18 @@ class PGCompleter(Completer): # These are added as a list of ForeignKey namedtuples to the # ColumnMetadata namedtuple for both the child and parent - meta = self.dbmetadata['tables'] + meta = self.dbmetadata["tables"] for fk in fk_data: e = self.escaped_names parentschema, childschema = e([fk.parentschema, fk.childschema]) parenttable, childtable = e([fk.parenttable, fk.childtable]) childcol, parcol = e([fk.childcolumn, fk.parentcolumn]) - childcolmeta = meta[childschema][childtable][childcol] - parcolmeta = meta[parentschema][parenttable][parcol] - fk = ForeignKey(parentschema, parenttable, parcol, - childschema, childtable, childcol) + childcolmeta = 
meta[childschema][childtable][childcol] + parcolmeta = meta[parentschema][parenttable][parcol] + fk = ForeignKey( + parentschema, parenttable, parcol, childschema, childtable, childcol + ) childcolmeta.foreignkeys.append((fk)) parcolmeta.foreignkeys.append((fk)) @@ -273,7 +288,7 @@ class PGCompleter(Completer): # dbmetadata['datatypes'][schema_name][type_name] should store type # metadata, such as composite type field names. Currently, we're not # storing any metadata beyond typename, so just store None - meta = self.dbmetadata['datatypes'] + meta = self.dbmetadata["datatypes"] for t in type_data: schema, type_name = self.escaped_names(t) @@ -295,11 +310,10 @@ class PGCompleter(Completer): self.databases = [] self.special_commands = [] self.search_path = [] - self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}, - 'datatypes': {}} + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}} self.all_completions = set(self.keywords + self.functions) - def find_matches(self, text, collection, mode='fuzzy', meta=None): + def find_matches(self, text, collection, mode="fuzzy", meta=None): """Find completion matches for the given text. 
Given the user's input text and a collection of available @@ -319,12 +333,22 @@ class PGCompleter(Completer): if not collection: return [] prio_order = [ - 'keyword', 'function', 'view', 'table', 'datatype', 'database', - 'schema', 'column', 'table alias', 'join', 'name join', 'fk join', - 'table format' + "keyword", + "function", + "view", + "table", + "datatype", + "database", + "schema", + "column", + "table alias", + "join", + "name join", + "fk join", + "table format", ] type_priority = prio_order.index(meta) if meta in prio_order else -1 - text = last_word(text, include='most_punctuations').lower() + text = last_word(text, include="most_punctuations").lower() text_len = len(text) if text and text[0] == '"': @@ -334,7 +358,7 @@ class PGCompleter(Completer): # Completion.position value is correct text = text[1:] - if mode == 'fuzzy': + if mode == "fuzzy": fuzzy = True priority_func = self.prioritizer.name_count else: @@ -347,19 +371,20 @@ class PGCompleter(Completer): # Note: higher priority values mean more important, so use negative # signs to flip the direction of the tuple if fuzzy: - regex = '.*?'.join(map(re.escape, text)) - pat = re.compile('(%s)' % regex) + regex = ".*?".join(map(re.escape, text)) + pat = re.compile("(%s)" % regex) def _match(item): - if item.lower()[:len(text) + 1] in (text, text + ' '): + if item.lower()[: len(text) + 1] in (text, text + " "): # Exact match of first word in suggestion # This is to get exact alias matches to the top # E.g. for input `e`, 'Entries E' should be on top # (before e.g. 
`EndUsers EU`) - return float('Infinity'), -1 + return float("Infinity"), -1 r = pat.search(self.unescape_name(item.lower())) if r: return -len(r.group()), -r.start() + else: match_end_limit = len(text) @@ -368,7 +393,7 @@ class PGCompleter(Completer): if match_point >= 0: # Use negative infinity to force keywords to sort after all # fuzzy matches - return -float('Infinity'), -match_point + return -float("Infinity"), -match_point matches = [] for cand in collection: @@ -387,7 +412,7 @@ class PGCompleter(Completer): if sort_key: if display_meta and len(display_meta) > 50: # Truncate meta-text to 50 characters, if necessary - display_meta = display_meta[:47] + u'...' + display_meta = display_meta[:47] + "..." # Lexical order of items in the collection, used for # tiebreaking items with the same match group length and start @@ -398,15 +423,24 @@ class PGCompleter(Completer): # case-sensitive one as a tie breaker. # We also use the unescape_name to make sure quoted names have # the same priority as unquoted names. - lexical_priority = (tuple(0 if c in(' _') else -ord(c) - for c in self.unescape_name(item.lower())) + (1,) - + tuple(c for c in item)) + lexical_priority = ( + tuple( + 0 if c in (" _") else -ord(c) + for c in self.unescape_name(item.lower()) + ) + + (1,) + + tuple(c for c in item) + ) item = self.case(item) display = self.case(display) priority = ( - sort_key, type_priority, prio, priority_func(item), - prio2, lexical_priority + sort_key, + type_priority, + prio, + priority_func(item), + prio2, + lexical_priority, ) matches.append( Match( @@ -414,9 +448,9 @@ class PGCompleter(Completer): text=item, start_position=-text_len, display_meta=display_meta, - display=display + display=display, ), - priority=priority + priority=priority, ) ) return matches @@ -433,17 +467,18 @@ class PGCompleter(Completer): # If smart_completion is off then match any word that starts with # 'word_before_cursor'. 
if not smart_completion: - matches = self.find_matches(word_before_cursor, self.all_completions, - mode='strict') + matches = self.find_matches( + word_before_cursor, self.all_completions, mode="strict" + ) completions = [m.completion for m in matches] - return sorted(completions, key=operator.attrgetter('text')) + return sorted(completions, key=operator.attrgetter("text")) matches = [] suggestions = suggest_type(document.text, document.text_before_cursor) for suggestion in suggestions: suggestion_type = type(suggestion) - _logger.debug('Suggestion type: %r', suggestion_type) + _logger.debug("Suggestion type: %r", suggestion_type) # Map suggestion type to method # e.g. 'table' -> self.get_table_matches @@ -451,76 +486,91 @@ class PGCompleter(Completer): matches.extend(matcher(self, suggestion, word_before_cursor)) # Sort matches so highest priorities are first - matches = sorted(matches, key=operator.attrgetter('priority'), - reverse=True) + matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True) return [m.completion for m in matches] def get_column_matches(self, suggestion, word_before_cursor): tables = suggestion.table_refs - do_qualify = suggestion.qualifiable and {'always': True, 'never': False, - 'if_more_than_one_table': len(tables) > 1}[self.qualify_columns] + do_qualify = ( + suggestion.qualifiable + and { + "always": True, + "never": False, + "if_more_than_one_table": len(tables) > 1, + }[self.qualify_columns] + ) qualify = lambda col, tbl: ( - (tbl + '.' + self.case(col)) if do_qualify else self.case(col)) + (tbl + "." 
+ self.case(col)) if do_qualify else self.case(col) + ) _logger.debug("Completion column scope: %r", tables) scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables) def make_cand(name, ref): synonyms = (name, generate_alias(self.case(name))) - return Candidate(qualify(name, ref), 0, 'column', synonyms) + return Candidate(qualify(name, ref), 0, "column", synonyms) def flat_cols(): - return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items() for c in cols] + return [ + make_cand(c.name, t.ref) + for t, cols in scoped_cols.items() + for c in cols + ] + if suggestion.require_last_table: # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should # suggest only columns that appear in the last table and one more ltbl = tables[-1].ref other_tbl_cols = set( - c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs) + c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs + ) scoped_cols = { t: [col for col in cols if col.name in other_tbl_cols] for t, cols in scoped_cols.items() if t.ref == ltbl } - lastword = last_word(word_before_cursor, include='most_punctuations') - if lastword == '*': - if suggestion.context == 'insert': + lastword = last_word(word_before_cursor, include="most_punctuations") + if lastword == "*": + if suggestion.context == "insert": + def filter(col): if not col.has_default: return True return not any( - p.match(col.default) - for p in self.insert_col_skip_patterns + p.match(col.default) for p in self.insert_col_skip_patterns ) + scoped_cols = { - t: [col for col in cols if filter(col)] for t, cols in scoped_cols.items() + t: [col for col in cols if filter(col)] + for t, cols in scoped_cols.items() } - if self.asterisk_column_order == 'alphabetic': + if self.asterisk_column_order == "alphabetic": for cols in scoped_cols.values(): - cols.sort(key=operator.attrgetter('name')) - if (lastword != word_before_cursor and len(tables) == 1 - and word_before_cursor[-len(lastword) - 1] == '.'): + 
cols.sort(key=operator.attrgetter("name")) + if ( + lastword != word_before_cursor + and len(tables) == 1 + and word_before_cursor[-len(lastword) - 1] == "." + ): # User typed x.*; replicate "x." for all columns except the # first, which gets the original (as we only replace the "*"") - sep = ', ' + word_before_cursor[:-1] - collist = sep.join(self.case(c.completion) - for c in flat_cols()) + sep = ", " + word_before_cursor[:-1] + collist = sep.join(self.case(c.completion) for c in flat_cols()) else: - collist = ', '.join(qualify(c.name, t.ref) - for t, cs in scoped_cols.items() for c in cs) + collist = ", ".join( + qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs + ) - return [Match( - completion=Completion( - collist, - -1, - display_meta='columns', - display='*' - ), - priority=(1, 1, 1) - )] + return [ + Match( + completion=Completion( + collist, -1, display_meta="columns", display="*" + ), + priority=(1, 1, 1), + ) + ] - return self.find_matches(word_before_cursor, flat_cols(), - meta='column') + return self.find_matches(word_before_cursor, flat_cols(), meta="column") def alias(self, tbl, tbls): """ Generate a unique table alias @@ -546,13 +596,16 @@ class PGCompleter(Completer): qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls) ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls)) refs = set(normalize_ref(t.ref) for t in tbls) - other_tbls = set((t.schema, t.name) - for t in list(cols)[:-1]) + other_tbls = set((t.schema, t.name) for t in list(cols)[:-1]) joins = [] # Iterate over FKs in existing tables to find potential joins - fks = ((fk, rtbl, rcol) for rtbl, rcols in cols.items() - for rcol in rcols for fk in rcol.foreignkeys) - col = namedtuple('col', 'schema tbl col') + fks = ( + (fk, rtbl, rcol) + for rtbl, rcols in cols.items() + for rcol in rcols + for fk in rcol.foreignkeys + ) + col = namedtuple("col", "schema tbl col") for fk, rtbl, rcol in fks: right = col(rtbl.schema, rtbl.name, rcol.name) child 
= col(fk.childschema, fk.childtable, fk.childcolumn) @@ -563,57 +616,66 @@ class PGCompleter(Completer): c = self.case if self.generate_aliases or normalize_ref(left.tbl) in refs: lref = self.alias(left.tbl, suggestion.table_refs) - join = '{0} {4} ON {4}.{1} = {2}.{3}'.format( - c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref) + join = "{0} {4} ON {4}.{1} = {2}.{3}".format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref + ) else: - join = '{0} ON {0}.{1} = {2}.{3}'.format( - c(left.tbl), c(left.col), rtbl.ref, c(right.col)) + join = "{0} ON {0}.{1} = {2}.{3}".format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col) + ) alias = generate_alias(self.case(left.tbl)) - synonyms = [join, '{0} ON {0}.{1} = {2}.{3}'.format( - alias, c(left.col), rtbl.ref, c(right.col))] + synonyms = [ + join, + "{0} ON {0}.{1} = {2}.{3}".format( + alias, c(left.col), rtbl.ref, c(right.col) + ), + ] # Schema-qualify if (1) new table in same schema as old, and old # is schema-qualified, or (2) new in other schema, except public - if not suggestion.schema and (qualified[normalize_ref(rtbl.ref)] + if not suggestion.schema and ( + qualified[normalize_ref(rtbl.ref)] and left.schema == right.schema - or left.schema not in(right.schema, 'public')): - join = left.schema + '.' + join + or left.schema not in (right.schema, "public") + ): + join = left.schema + "." 
+ join prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + ( - 0 if (left.schema, left.tbl) in other_tbls else 1) - joins.append(Candidate(join, prio, 'join', synonyms=synonyms)) + 0 if (left.schema, left.tbl) in other_tbls else 1 + ) + joins.append(Candidate(join, prio, "join", synonyms=synonyms)) - return self.find_matches(word_before_cursor, joins, meta='join') + return self.find_matches(word_before_cursor, joins, meta="join") def get_join_condition_matches(self, suggestion, word_before_cursor): - col = namedtuple('col', 'schema tbl col') + col = namedtuple("col", "schema tbl col") tbls = self.populate_scoped_cols(suggestion.table_refs).items cols = [(t, c) for t, cs in tbls() for c in cs] try: lref = (suggestion.parent or suggestion.table_refs[-1]).ref ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1] - except IndexError: # The user typed an incorrect table qualifier + except IndexError: # The user typed an incorrect table qualifier return [] conds, found_conds = [], set() def add_cond(lcol, rcol, rref, prio, meta): - prefix = '' if suggestion.parent else ltbl.ref + '.' + prefix = "" if suggestion.parent else ltbl.ref + "." case = self.case - cond = prefix + case(lcol) + ' = ' + rref + '.' + case(rcol) + cond = prefix + case(lcol) + " = " + rref + "." 
+ case(rcol) if cond not in found_conds: found_conds.add(cond) conds.append(Candidate(cond, prio + ref_prio[rref], meta)) - def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]} + def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]} d = defaultdict(list) for pair in pairs: d[pair[0]].append(pair[1]) return d # Tables that are closer to the cursor get higher prio - ref_prio = dict((tbl.ref, num) for num, tbl - in enumerate(suggestion.table_refs)) + ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs)) # Map (schema, table, col) to tables - coldict = list_dict(((t.schema, t.name, c.name), t) - for t, c in cols if t.ref != lref) + coldict = list_dict( + ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref + ) # For each fk from the left table, generate a join condition if # the other table is also in the scope fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys) @@ -623,81 +685,80 @@ class PGCompleter(Completer): par = col(fk.parentschema, fk.parenttable, fk.parentcolumn) left, right = (child, par) if left == child else (par, child) for rtbl in coldict[right]: - add_cond(left.col, right.col, rtbl.ref, 2000, 'fk join') + add_cond(left.col, right.col, rtbl.ref, 2000, "fk join") # For name matching, use a {(colname, coltype): TableReference} dict - coltyp = namedtuple('coltyp', 'name datatype') + coltyp = namedtuple("coltyp", "name datatype") col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols) # Find all name-match join conditions for c in (coltyp(c.name, c.datatype) for c in lcols): for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref): - prio = 1000 if c.datatype in ( - 'integer', 'bigint', 'smallint') else 0 - add_cond(c.name, c.name, rtbl.ref, prio, 'name join') + prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0 + add_cond(c.name, c.name, rtbl.ref, prio, "name join") - return self.find_matches(word_before_cursor, conds, meta='join') + return 
self.find_matches(word_before_cursor, conds, meta="join") def get_function_matches(self, suggestion, word_before_cursor, alias=False): - if suggestion.usage == 'from': + if suggestion.usage == "from": # Only suggest functions allowed in FROM clause def filt(f): - return (not f.is_aggregate and - not f.is_window and - not f.is_extension and - (f.is_public or f.schema_name == suggestion.schema)) + return ( + not f.is_aggregate + and not f.is_window + and not f.is_extension + and (f.is_public or f.schema_name == suggestion.schema) + ) + else: alias = False def filt(f): - return (not f.is_extension and - (f.is_public or f.schema_name == suggestion.schema)) + return not f.is_extension and ( + f.is_public or f.schema_name == suggestion.schema + ) - arg_mode = { - 'signature': 'signature', - 'special': None, - }.get(suggestion.usage, 'call') + arg_mode = {"signature": "signature", "special": None}.get( + suggestion.usage, "call" + ) # Function overloading means we way have multiple functions of the same # name at this point, so keep unique names only all_functions = self.populate_functions(suggestion.schema, filt) funcs = set( - self._make_cand(f, alias, suggestion, arg_mode) - for f in all_functions + self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions ) - matches = self.find_matches(word_before_cursor, funcs, meta='function') + matches = self.find_matches(word_before_cursor, funcs, meta="function") if not suggestion.schema and not suggestion.usage: # also suggest hardcoded functions using startswith matching predefined_funcs = self.find_matches( - word_before_cursor, self.functions, mode='strict', - meta='function') + word_before_cursor, self.functions, mode="strict", meta="function" + ) matches.extend(predefined_funcs) return matches def get_schema_matches(self, suggestion, word_before_cursor): - schema_names = self.dbmetadata['tables'].keys() + schema_names = self.dbmetadata["tables"].keys() # Unless we're sure the user really wants them, hide 
schema names # starting with pg_, which are mostly temporary schemas - if not word_before_cursor.startswith('pg_'): - schema_names = [s - for s in schema_names - if not s.startswith('pg_')] + if not word_before_cursor.startswith("pg_"): + schema_names = [s for s in schema_names if not s.startswith("pg_")] if suggestion.quoted: schema_names = [self.escape_schema(s) for s in schema_names] - return self.find_matches(word_before_cursor, schema_names, meta='schema') + return self.find_matches(word_before_cursor, schema_names, meta="schema") def get_from_clause_item_matches(self, suggestion, word_before_cursor): alias = self.generate_aliases s = suggestion t_sug = Table(s.schema, s.table_refs, s.local_tables) v_sug = View(s.schema, s.table_refs) - f_sug = Function(s.schema, s.table_refs, usage='from') + f_sug = Function(s.schema, s.table_refs, usage="from") return ( self.get_table_matches(t_sug, word_before_cursor, alias) + self.get_view_matches(v_sug, word_before_cursor, alias) @@ -712,43 +773,43 @@ class PGCompleter(Completer): """ template = { - 'call': self.call_arg_style, - 'call_display': self.call_arg_display_style, - 'signature': self.signature_arg_style + "call": self.call_arg_style, + "call_display": self.call_arg_display_style, + "signature": self.signature_arg_style, }[usage] args = func.args() if not template: - return '()' - elif usage == 'call' and len(args) < 2: - return '()' - elif usage == 'call' and func.has_variadic(): - return '()' - multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max + return "()" + elif usage == "call" and len(args) < 2: + return "()" + elif usage == "call" and func.has_variadic(): + return "()" + multiline = usage == "call" and len(args) > self.call_arg_oneliner_max max_arg_len = max(len(a.name) for a in args) if multiline else 0 args = ( self._format_arg(template, arg, arg_num + 1, max_arg_len) for arg_num, arg in enumerate(args) ) if multiline: - return '(' + ','.join('\n ' + a for a in args if a) + '\n)' + 
return "(" + ",".join("\n " + a for a in args if a) + "\n)" else: - return '(' + ', '.join(a for a in args if a) + ')' + return "(" + ", ".join(a for a in args if a) + ")" def _format_arg(self, template, arg, arg_num, max_arg_len): if not template: return None if arg.has_default: - arg_default = 'NULL' if arg.default is None else arg.default + arg_default = "NULL" if arg.default is None else arg.default # Remove trailing ::(schema.)type - arg_default = arg_default_type_strip_regex.sub('', arg_default) + arg_default = arg_default_type_strip_regex.sub("", arg_default) else: - arg_default = '' + arg_default = "" return template.format( max_arg_len=max_arg_len, arg_name=self.case(arg.name), arg_num=arg_num, arg_type=arg.datatype, - arg_default=arg_default + arg_default=arg_default, ) def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None): @@ -763,53 +824,49 @@ class PGCompleter(Completer): if do_alias: alias = self.alias(cased_tbl, suggestion.table_refs) synonyms = (cased_tbl, generate_alias(cased_tbl)) - maybe_alias = (' ' + alias) if do_alias else '' - maybe_schema = (self.case(tbl.schema) + '.') if tbl.schema else '' - suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else '' - if arg_mode == 'call': - display_suffix = self._arg_list_cache['call_display'][tbl.meta] - elif arg_mode == 'signature': - display_suffix = self._arg_list_cache['signature'][tbl.meta] + maybe_alias = (" " + alias) if do_alias else "" + maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else "" + suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else "" + if arg_mode == "call": + display_suffix = self._arg_list_cache["call_display"][tbl.meta] + elif arg_mode == "signature": + display_suffix = self._arg_list_cache["signature"][tbl.meta] else: - display_suffix = '' + display_suffix = "" item = maybe_schema + cased_tbl + suffix + maybe_alias display = maybe_schema + cased_tbl + display_suffix + maybe_alias prio2 = 0 if tbl.schema else 1 return 
Candidate(item, synonyms=synonyms, prio2=prio2, display=display) def get_table_matches(self, suggestion, word_before_cursor, alias=False): - tables = self.populate_schema_objects(suggestion.schema, 'tables') + tables = self.populate_schema_objects(suggestion.schema, "tables") tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables) # Unless we're sure the user really wants them, don't suggest the # pg_catalog tables that are implicitly on the search path - if not suggestion.schema and ( - not word_before_cursor.startswith('pg_')): - tables = [t for t in tables if not t.name.startswith('pg_')] + if not suggestion.schema and (not word_before_cursor.startswith("pg_")): + tables = [t for t in tables if not t.name.startswith("pg_")] tables = [self._make_cand(t, alias, suggestion) for t in tables] - return self.find_matches(word_before_cursor, tables, meta='table') + return self.find_matches(word_before_cursor, tables, meta="table") def get_table_formats(self, _, word_before_cursor): formats = TabularOutputFormatter().supported_formats - return self.find_matches(word_before_cursor, formats, meta='table format') + return self.find_matches(word_before_cursor, formats, meta="table format") def get_view_matches(self, suggestion, word_before_cursor, alias=False): - views = self.populate_schema_objects(suggestion.schema, 'views') + views = self.populate_schema_objects(suggestion.schema, "views") - if not suggestion.schema and ( - not word_before_cursor.startswith('pg_')): - views = [v for v in views if not v.name.startswith('pg_')] + if not suggestion.schema and (not word_before_cursor.startswith("pg_")): + views = [v for v in views if not v.name.startswith("pg_")] views = [self._make_cand(v, alias, suggestion) for v in views] - return self.find_matches(word_before_cursor, views, meta='view') + return self.find_matches(word_before_cursor, views, meta="view") def get_alias_matches(self, suggestion, word_before_cursor): aliases = suggestion.aliases - return 
self.find_matches(word_before_cursor, aliases, - meta='table alias') + return self.find_matches(word_before_cursor, aliases, meta="table alias") def get_database_matches(self, _, word_before_cursor): - return self.find_matches(word_before_cursor, self.databases, - meta='database') + return self.find_matches(word_before_cursor, self.databases, meta="database") def get_keyword_matches(self, suggestion, word_before_cursor): keywords = self.keywords_tree.keys() @@ -820,24 +877,26 @@ class PGCompleter(Completer): keywords = next_keywords casing = self.keyword_casing - if casing == 'auto': + if casing == "auto": if word_before_cursor and word_before_cursor[-1].islower(): - casing = 'lower' + casing = "lower" else: - casing = 'upper' + casing = "upper" - if casing == 'upper': + if casing == "upper": keywords = [k.upper() for k in keywords] else: keywords = [k.lower() for k in keywords] - return self.find_matches(word_before_cursor, keywords, - mode='strict', meta='keyword') + return self.find_matches( + word_before_cursor, keywords, mode="strict", meta="keyword" + ) def get_path_matches(self, _, word_before_cursor): completer = PathCompleter(expanduser=True) - document = Document(text=word_before_cursor, - cursor_position=len(word_before_cursor)) + document = Document( + text=word_before_cursor, cursor_position=len(word_before_cursor) + ) for c in completer.get_completions(document, None): yield Match(completion=c, priority=(0,)) @@ -848,24 +907,28 @@ class PGCompleter(Completer): commands = self.pgspecial.commands cmds = commands.keys() cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds] - return self.find_matches(word_before_cursor, cmds, mode='strict') + return self.find_matches(word_before_cursor, cmds, mode="strict") def get_datatype_matches(self, suggestion, word_before_cursor): # suggest custom datatypes - types = self.populate_schema_objects(suggestion.schema, 'datatypes') + types = self.populate_schema_objects(suggestion.schema, "datatypes") 
types = [self._make_cand(t, False, suggestion) for t in types] - matches = self.find_matches(word_before_cursor, types, meta='datatype') + matches = self.find_matches(word_before_cursor, types, meta="datatype") if not suggestion.schema: # Also suggest hardcoded types - matches.extend(self.find_matches(word_before_cursor, self.datatypes, - mode='strict', meta='datatype')) + matches.extend( + self.find_matches( + word_before_cursor, self.datatypes, mode="strict", meta="datatype" + ) + ) return matches def get_namedquery_matches(self, _, word_before_cursor): return self.find_matches( - word_before_cursor, NamedQueries.instance.list(), meta='named query') + word_before_cursor, NamedQueries.instance.list(), meta="named query" + ) suggestion_matchers = { FromClauseItem: get_from_clause_item_matches, @@ -899,7 +962,7 @@ class PGCompleter(Completer): meta = self.dbmetadata def addcols(schema, rel, alias, reltype, cols): - tbl = TableReference(schema, rel, alias, reltype == 'functions') + tbl = TableReference(schema, rel, alias, reltype == "functions") if tbl not in columns: columns[tbl] = [] columns[tbl].extend(cols) @@ -908,22 +971,22 @@ class PGCompleter(Completer): # Local tables should shadow database tables if tbl.schema is None and normalize_ref(tbl.name) in ctes: cols = ctes[normalize_ref(tbl.name)] - addcols(None, tbl.name, 'CTE', tbl.alias, cols) + addcols(None, tbl.name, "CTE", tbl.alias, cols) continue schemas = [tbl.schema] if tbl.schema else self.search_path for schema in schemas: relname = self.escape_name(tbl.name) schema = self.escape_name(schema) if tbl.is_function: - # Return column names from a set-returning function - # Get an array of FunctionMetadata objects - functions = meta['functions'].get(schema, {}).get(relname) - for func in (functions or []): + # Return column names from a set-returning function + # Get an array of FunctionMetadata objects + functions = meta["functions"].get(schema, {}).get(relname) + for func in functions or []: # func is a 
FunctionMetadata object cols = func.fields() - addcols(schema, relname, tbl.alias, 'functions', cols) + addcols(schema, relname, tbl.alias, "functions", cols) else: - for reltype in ('tables', 'views'): + for reltype in ("tables", "views"): cols = meta[reltype].get(schema, {}).get(relname) if cols: cols = cols.values() @@ -956,8 +1019,7 @@ class PGCompleter(Completer): return [ SchemaObject( - name=obj, - schema=(self._maybe_schema(schema=sch, parent=schema)) + name=obj, schema=(self._maybe_schema(schema=sch, parent=schema)) ) for sch in self._get_schemas(obj_type, schema) for obj in self.dbmetadata[obj_type][sch].keys() @@ -979,10 +1041,10 @@ class PGCompleter(Completer): SchemaObject( name=func, schema=(self._maybe_schema(schema=sch, parent=schema)), - meta=meta + meta=meta, ) - for sch in self._get_schemas('functions', schema) - for (func, metas) in self.dbmetadata['functions'][sch].items() + for sch in self._get_schemas("functions", schema) + for (func, metas) in self.dbmetadata["functions"][sch].items() for meta in metas if filter_func(meta) ] diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 337be145..ad5ed4a3 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -24,7 +24,7 @@ ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE)) # Cast bytea fields to text. By default, this will render as hex strings with # Postgres 9+ and as escaped binary in earlier versions. -ext.register_type(ext.new_type((17,), 'BYTEA_TEXT', psycopg2.STRING)) +ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING)) # TODO: Get default timeout from pgclirc? 
_WAIT_SELECT_TIMEOUT = 1 @@ -55,6 +55,7 @@ def _wait_select(conn): if errno != 4: raise + # When running a query, make pressing CTRL+C raise a KeyboardInterrupt # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ # See also https://github.com/psycopg/psycopg2/issues/468 @@ -66,17 +67,19 @@ def register_date_typecasters(connection): Casts date and timestamp values to string, resolves issues with out of range dates (e.g. BC) which psycopg2 can't handle """ + def cast_date(value, cursor): return value + cursor = connection.cursor() - cursor.execute('SELECT NULL::date') + cursor.execute("SELECT NULL::date") date_oid = cursor.description[0][1] - cursor.execute('SELECT NULL::timestamp') + cursor.execute("SELECT NULL::timestamp") timestamp_oid = cursor.description[0][1] - cursor.execute('SELECT NULL::timestamp with time zone') + cursor.execute("SELECT NULL::timestamp with time zone") timestamptz_oid = cursor.description[0][1] oids = (date_oid, timestamp_oid, timestamptz_oid) - new_type = psycopg2.extensions.new_type(oids, 'DATE', cast_date) + new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date) psycopg2.extensions.register_type(new_type) @@ -97,7 +100,7 @@ def register_json_typecasters(conn, loads_fn): """ available = set() - for name in ['json', 'jsonb']: + for name in ["json", "jsonb"]: try: psycopg2.extras.register_json(conn, loads=loads_fn, name=name) available.add(name) @@ -117,7 +120,8 @@ def register_hstore_typecaster(conn): with conn.cursor() as cur: try: cur.execute( - "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined") + "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined" + ) oid = cur.fetchone()[0] ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE)) except Exception: @@ -128,29 +132,29 @@ class PGExecute(object): # The boolean argument to the current_schemas function indicates whether # implicit schemas, e.g. 
pg_catalog - search_path_query = ''' - SELECT * FROM unnest(current_schemas(true))''' + search_path_query = """ + SELECT * FROM unnest(current_schemas(true))""" - schemata_query = ''' + schemata_query = """ SELECT nspname FROM pg_catalog.pg_namespace - ORDER BY 1 ''' + ORDER BY 1 """ - tables_query = ''' + tables_query = """ SELECT n.nspname schema_name, c.relname table_name FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE c.relkind = ANY(%s) - ORDER BY 1,2;''' + ORDER BY 1,2;""" - databases_query = ''' + databases_query = """ SELECT d.datname FROM pg_catalog.pg_database d - ORDER BY 1''' + ORDER BY 1""" - full_databases_query = ''' + full_databases_query = """ SELECT d.datname as "Name", pg_catalog.pg_get_userbyid(d.datdba) as "Owner", pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding", @@ -158,15 +162,15 @@ class PGExecute(object): d.datctype as "Ctype", pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges" FROM pg_catalog.pg_database d - ORDER BY 1''' + ORDER BY 1""" - socket_directory_query = ''' + socket_directory_query = """ SELECT setting FROM pg_settings WHERE name = 'unix_socket_directories' - ''' + """ - view_definition_query = ''' + view_definition_query = """ WITH v AS (SELECT %s::pg_catalog.regclass::pg_catalog.oid AS v_oid) SELECT nspname, relname, relkind, pg_catalog.pg_get_viewdef(c.oid, true), @@ -179,18 +183,26 @@ class PGExecute(object): END AS checkoption FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid) - JOIN v ON (c.oid = v.v_oid)''' + JOIN v ON (c.oid = v.v_oid)""" - function_definition_query = ''' + function_definition_query = """ WITH f AS (SELECT %s::pg_catalog.regproc::pg_catalog.oid AS f_oid) SELECT pg_catalog.pg_get_functiondef(f.f_oid) - FROM f''' + FROM f""" version_query = "SELECT version();" - def __init__(self, database=None, user=None, password=None, host=None, - port=None, dsn=None, **kwargs): + def __init__( + self, + 
database=None, + user=None, + password=None, + host=None, + port=None, + dsn=None, + **kwargs + ): self._conn_params = {} self.conn = None self.dbname = None @@ -207,10 +219,10 @@ class PGExecute(object): return self.__class__(**self._conn_params) def get_server_version(self, cursor): - _logger.debug('Version Query. sql: %r', self.version_query) + _logger.debug("Version Query. sql: %r", self.version_query) cursor.execute(self.version_query) result = cursor.fetchone() - server_version = '' + server_version = "" if result: # full version string looks like this: # PostgreSQL 10.3 on x86_64-apple-darwin17.3.0, compiled by Apple LLVM version 9.0.0 (clang-900.0.39.2), 64-bit # noqa @@ -219,38 +231,42 @@ class PGExecute(object): server_version = version_parts[1] return server_version - def connect(self, database=None, user=None, password=None, host=None, - port=None, dsn=None, **kwargs): + def connect( + self, + database=None, + user=None, + password=None, + host=None, + port=None, + dsn=None, + **kwargs + ): conn_params = self._conn_params.copy() new_params = { - 'database': database, - 'user': user, - 'password': password, - 'host': host, - 'port': port, - 'dsn': dsn, + "database": database, + "user": user, + "password": password, + "host": host, + "port": port, + "dsn": dsn, } new_params.update(kwargs) - if new_params['dsn']: - new_params = { - 'dsn': new_params['dsn'], - 'password': new_params['password'] - } + if new_params["dsn"]: + new_params = {"dsn": new_params["dsn"], "password": new_params["password"]} - if new_params['password']: - new_params['dsn'] = make_dsn( - new_params['dsn'], password=new_params.pop('password')) + if new_params["password"]: + new_params["dsn"] = make_dsn( + new_params["dsn"], password=new_params.pop("password") + ) - conn_params.update({ - k: unicode2utf8(v) for k, v in new_params.items() if v - }) + conn_params.update({k: unicode2utf8(v) for k, v in new_params.items() if v}) conn = psycopg2.connect(**conn_params) cursor = conn.cursor() 
- conn.set_client_encoding('utf8') + conn.set_client_encoding("utf8") self._conn_params = conn_params if self.conn: @@ -264,11 +280,11 @@ class PGExecute(object): # TODO: use actual connection info from psycopg2.extensions.Connection.info as psycopg>2.8 is available and required dependency # noqa dsn_parameters = conn.get_dsn_parameters() - self.dbname = dsn_parameters.get('dbname') - self.user = dsn_parameters.get('user') + self.dbname = dsn_parameters.get("dbname") + self.user = dsn_parameters.get("user") self.password = password - self.host = dsn_parameters.get('host') - self.port = dsn_parameters.get('port') + self.host = dsn_parameters.get("host") + self.port = dsn_parameters.get("port") self.extra_args = kwargs if not self.host: @@ -277,9 +293,9 @@ class PGExecute(object): cursor.execute("SHOW ALL") db_parameters = dict(name_val_desc[:2] for name_val_desc in cursor.fetchall()) - pid = self._select_one(cursor, 'select pg_backend_pid()')[0] + pid = self._select_one(cursor, "select pg_backend_pid()")[0] self.pid = pid - self.superuser = db_parameters.get('is_superuser') == '1' + self.superuser = db_parameters.get("is_superuser") == "1" self.server_version = self.get_server_version(cursor) @@ -289,11 +305,11 @@ class PGExecute(object): @property def short_host(self): - if ',' in self.host: - host, _, _ = self.host.partition(',') + if "," in self.host: + host, _, _ = self.host.partition(",") else: host = self.host - short_host, _, _ = host.partition('.') + short_host, _, _ = host.partition(".") return short_host def _select_one(self, cur, sql): @@ -326,11 +342,14 @@ class PGExecute(object): def valid_transaction(self): status = self.conn.get_transaction_status() - return (status == ext.TRANSACTION_STATUS_ACTIVE or - status == ext.TRANSACTION_STATUS_INTRANS) + return ( + status == ext.TRANSACTION_STATUS_ACTIVE + or status == ext.TRANSACTION_STATUS_INTRANS + ) - def run(self, statement, pgspecial=None, exception_formatter=None, - on_error_resume=False): + def run( + 
self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False + ): """Execute the sql in the database and return the results. :param statement: A string containing one or more sql statements @@ -355,12 +374,12 @@ class PGExecute(object): # Split the sql into separate queries and run each one. for sql in sqlparse.split(statement): # Remove spaces, eol and semi-colons. - sql = sql.rstrip(';') + sql = sql.rstrip(";") try: if pgspecial: # First try to run each query as special - _logger.debug('Trying a pgspecial command. sql: %r', sql) + _logger.debug("Trying a pgspecial command. sql: %r", sql) try: cur = self.conn.cursor() except psycopg2.InterfaceError: @@ -385,8 +404,7 @@ class PGExecute(object): _logger.error("sql: %r, error: %r", sql, e) _logger.error("traceback: %r", traceback.format_exc()) - if (self._must_raise(e) - or not exception_formatter): + if self._must_raise(e) or not exception_formatter: raise yield None, None, None, exception_formatter(e), sql, False, False @@ -410,12 +428,12 @@ class PGExecute(object): def execute_normal_sql(self, split_sql): """Returns tuple (title, rows, headers, status)""" - _logger.debug('Regular sql statement. sql: %r', split_sql) + _logger.debug("Regular sql statement. sql: %r", split_sql) cur = self.conn.cursor() cur.execute(split_sql) # conn.notices persist between queies, we use pop to clear out the list - title = '' + title = "" while len(self.conn.notices) > 0: title = utf8tounicode(self.conn.notices.pop()) + title @@ -425,7 +443,7 @@ class PGExecute(object): headers = [x[0] for x in cur.description] return title, cur, headers, cur.statusmessage else: - _logger.debug('No rows in result.') + _logger.debug("No rows in result.") return title, None, None, cur.statusmessage def search_path(self): @@ -433,33 +451,32 @@ class PGExecute(object): try: with self.conn.cursor() as cur: - _logger.debug('Search path query. sql: %r', self.search_path_query) + _logger.debug("Search path query. 
sql: %r", self.search_path_query) cur.execute(self.search_path_query) return [x[0] for x in cur.fetchall()] except psycopg2.ProgrammingError: - fallback = 'SELECT * FROM current_schemas(true)' + fallback = "SELECT * FROM current_schemas(true)" with self.conn.cursor() as cur: - _logger.debug('Search path query. sql: %r', fallback) + _logger.debug("Search path query. sql: %r", fallback) cur.execute(fallback) return cur.fetchone()[0] def view_definition(self, spec): """Returns the SQL defining views described by `spec`""" - template = 'CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}' + template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}" # 2: relkind, v or m (materialized) # 4: reloptions, null # 5: checkoption: local or cascaded with self.conn.cursor() as cur: sql = self.view_definition_query - _logger.debug('View Definition Query. sql: %r\nspec: %r', - sql, spec) + _logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec) try: - cur.execute(sql, (spec, )) + cur.execute(sql, (spec,)) except psycopg2.ProgrammingError: - raise RuntimeError('View {} does not exist.'.format(spec)) + raise RuntimeError("View {} does not exist.".format(spec)) result = cur.fetchone() - view_type = 'MATERIALIZED' if result[2] == 'm' else '' + view_type = "MATERIALIZED" if result[2] == "m" else "" return template.format(*result + (view_type,)) def function_definition(self, spec): @@ -467,24 +484,23 @@ class PGExecute(object): with self.conn.cursor() as cur: sql = self.function_definition_query - _logger.debug('Function Definition Query. sql: %r\nspec: %r', - sql, spec) + _logger.debug("Function Definition Query. 
sql: %r\nspec: %r", sql, spec) try: cur.execute(sql, (spec,)) result = cur.fetchone() return result[0] except psycopg2.ProgrammingError: - raise RuntimeError('Function {} does not exist.'.format(spec)) + raise RuntimeError("Function {} does not exist.".format(spec)) def schemata(self): """Returns a list of schema names in the database""" with self.conn.cursor() as cur: - _logger.debug('Schemata Query. sql: %r', self.schemata_query) + _logger.debug("Schemata Query. sql: %r", self.schemata_query) cur.execute(self.schemata_query) return [x[0] for x in cur.fetchall()] - def _relations(self, kinds=('r', 'v', 'm')): + def _relations(self, kinds=("r", "v", "m")): """Get table or view name metadata :param kinds: list of postgres relkind filters: @@ -496,14 +512,14 @@ class PGExecute(object): with self.conn.cursor() as cur: sql = cur.mogrify(self.tables_query, [kinds]) - _logger.debug('Tables Query. sql: %r', sql) + _logger.debug("Tables Query. sql: %r", sql) cur.execute(sql) for row in cur: yield row def tables(self): """Yields (schema_name, table_name) tuples""" - for row in self._relations(kinds=['r']): + for row in self._relations(kinds=["r"]): yield row def views(self): @@ -511,10 +527,10 @@ class PGExecute(object): Includes both views and and materialized views """ - for row in self._relations(kinds=['v', 'm']): + for row in self._relations(kinds=["v", "m"]): yield row - def _columns(self, kinds=('r', 'v', 'm')): + def _columns(self, kinds=("r", "v", "m")): """Get column metadata for tables and views :param kinds: kinds: list of postgres relkind filters: @@ -525,7 +541,7 @@ class PGExecute(object): """ if self.conn.server_version >= 80400: - columns_query = ''' + columns_query = """ SELECT nsp.nspname schema_name, cls.relname table_name, att.attname column_name, @@ -543,9 +559,9 @@ class PGExecute(object): WHERE cls.relkind = ANY(%s) AND NOT att.attisdropped AND att.attnum > 0 - ORDER BY 1, 2, att.attnum''' + ORDER BY 1, 2, att.attnum""" else: - columns_query = ''' + 
columns_query = """ SELECT nsp.nspname schema_name, cls.relname table_name, att.attname column_name, @@ -562,44 +578,44 @@ class PGExecute(object): WHERE cls.relkind = ANY(%s) AND NOT att.attisdropped AND att.attnum > 0 - ORDER BY 1, 2, att.attnum''' + ORDER BY 1, 2, att.attnum""" with self.conn.cursor() as cur: sql = cur.mogrify(columns_query, [kinds]) - _logger.debug('Columns Query. sql: %r', sql) + _logger.debug("Columns Query. sql: %r", sql) cur.execute(sql) for row in cur: yield row def table_columns(self): - for row in self._columns(kinds=['r']): + for row in self._columns(kinds=["r"]): yield row def view_columns(self): - for row in self._columns(kinds=['v', 'm']): + for row in self._columns(kinds=["v", "m"]): yield row def databases(self): with self.conn.cursor() as cur: - _logger.debug('Databases Query. sql: %r', self.databases_query) + _logger.debug("Databases Query. sql: %r", self.databases_query) cur.execute(self.databases_query) return [x[0] for x in cur.fetchall()] def full_databases(self): with self.conn.cursor() as cur: - _logger.debug('Databases Query. sql: %r', - self.full_databases_query) + _logger.debug("Databases Query. sql: %r", self.full_databases_query) cur.execute(self.full_databases_query) headers = [x[0] for x in cur.description] return cur.fetchall(), headers, cur.statusmessage def get_socket_directory(self): with self.conn.cursor() as cur: - _logger.debug('Socket directory Query. sql: %r', - self.socket_directory_query) + _logger.debug( + "Socket directory Query. 
sql: %r", self.socket_directory_query + ) cur.execute(self.socket_directory_query) result = cur.fetchone() - return result[0] if result else '' + return result[0] if result else "" def foreignkeys(self): """Yields ForeignKey named tuples""" @@ -608,7 +624,7 @@ class PGExecute(object): return with self.conn.cursor() as cur: - query = ''' + query = """ SELECT s_p.nspname AS parentschema, t_p.relname AS parenttable, unnest(( @@ -635,8 +651,8 @@ class PGExecute(object): JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace WHERE fk.contype = 'f'; - ''' - _logger.debug('Functions Query. sql: %r', query) + """ + _logger.debug("Functions Query. sql: %r", query) cur.execute(query) for row in cur: yield ForeignKey(*row) @@ -645,7 +661,7 @@ class PGExecute(object): """Yields FunctionMetadata named tuples""" if self.conn.server_version >= 110000: - query = ''' + query = """ SELECT n.nspname schema_name, p.proname func_name, p.proargnames, @@ -663,9 +679,9 @@ class PGExecute(object): LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 - ''' + """ elif self.conn.server_version > 90000: - query = ''' + query = """ SELECT n.nspname schema_name, p.proname func_name, p.proargnames, @@ -683,9 +699,9 @@ class PGExecute(object): LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 - ''' + """ elif self.conn.server_version >= 80400: - query = ''' + query = """ SELECT n.nspname schema_name, p.proname func_name, p.proargnames, @@ -703,9 +719,9 @@ class PGExecute(object): LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 - ''' + """ else: - query = ''' + query = """ SELECT n.nspname schema_name, p.proname func_name, p.proargnames, @@ -723,20 +739,20 @@ class PGExecute(object): LEFT JOIN pg_depend d ON d.objid = 
p.oid and d.deptype = 'e' WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 - ''' + """ with self.conn.cursor() as cur: - _logger.debug('Functions Query. sql: %r', query) + _logger.debug("Functions Query. sql: %r", query) cur.execute(query) for row in cur: - yield FunctionMetadata(*row) + yield FunctionMetadata(*row) def datatypes(self): """Yields tuples of (schema_name, type_name)""" with self.conn.cursor() as cur: if self.conn.server_version > 90000: - query = ''' + query = """ SELECT n.nspname schema_name, t.typname type_name FROM pg_catalog.pg_type t @@ -757,9 +773,9 @@ class PGExecute(object): AND n.nspname <> 'pg_catalog' AND n.nspname <> 'information_schema' ORDER BY 1, 2; - ''' + """ else: - query = ''' + query = """ SELECT n.nspname schema_name, pg_catalog.format_type(t.oid, NULL) type_name FROM pg_catalog.pg_type t @@ -770,8 +786,8 @@ class PGExecute(object): AND n.nspname <> 'information_schema' AND pg_catalog.pg_type_is_visible(t.oid) ORDER BY 1, 2; - ''' - _logger.debug('Datatypes Query. sql: %r', query) + """ + _logger.debug("Datatypes Query. sql: %r", query) cur.execute(query) for row in cur: yield row @@ -779,7 +795,7 @@ class PGExecute(object): def casing(self): """Yields the most common casing for names used in db functions""" with self.conn.cursor() as cur: - query = r''' + query = r""" WITH Words AS ( SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1) FROM pg_catalog.pg_proc P @@ -819,8 +835,8 @@ class PGExecute(object): FROM OrderWords WHERE LOWER(Word) IN (SELECT Name FROM Names) AND Row_Number = 1; - ''' - _logger.debug('Casing Query. sql: %r', query) + """ + _logger.debug("Casing Query. sql: %r", query) cur.execute(query) for row in cur: yield row[0] diff --git a/pgcli/pgstyle.py b/pgcli/pgstyle.py index bde464c1..0355599b 100644 --- a/pgcli/pgstyle.py +++ b/pgcli/pgstyle.py @@ -13,35 +13,33 @@ logger = logging.getLogger(__name__) # map Pygments tokens (ptk 1.0) to class names (ptk 2.0). 
TOKEN_TO_PROMPT_STYLE = { - Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current', - Token.Menu.Completions.Completion: 'completion-menu.completion', - Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current', - Token.Menu.Completions.Meta: 'completion-menu.meta.completion', - Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta', - Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess - Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess - Token.SelectedText: 'selected', - Token.SearchMatch: 'search', - Token.SearchMatch.Current: 'search.current', - Token.Toolbar: 'bottom-toolbar', - Token.Toolbar.Off: 'bottom-toolbar.off', - Token.Toolbar.On: 'bottom-toolbar.on', - Token.Toolbar.Search: 'search-toolbar', - Token.Toolbar.Search.Text: 'search-toolbar.text', - Token.Toolbar.System: 'system-toolbar', - Token.Toolbar.Arg: 'arg-toolbar', - Token.Toolbar.Arg.Text: 'arg-toolbar.text', - Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid', - Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed', - Token.Output.Header: 'output.header', - Token.Output.OddRow: 'output.odd-row', - Token.Output.EvenRow: 'output.even-row', + Token.Menu.Completions.Completion.Current: "completion-menu.completion.current", + Token.Menu.Completions.Completion: "completion-menu.completion", + Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current", + Token.Menu.Completions.Meta: "completion-menu.meta.completion", + Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta", + Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess + Token.Menu.Completions.ProgressBar: "scrollbar", # best guess + Token.SelectedText: "selected", + Token.SearchMatch: "search", + Token.SearchMatch.Current: "search.current", + Token.Toolbar: "bottom-toolbar", + Token.Toolbar.Off: "bottom-toolbar.off", + Token.Toolbar.On: 
"bottom-toolbar.on", + Token.Toolbar.Search: "search-toolbar", + Token.Toolbar.Search.Text: "search-toolbar.text", + Token.Toolbar.System: "system-toolbar", + Token.Toolbar.Arg: "arg-toolbar", + Token.Toolbar.Arg.Text: "arg-toolbar.text", + Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid", + Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed", + Token.Output.Header: "output.header", + Token.Output.OddRow: "output.odd-row", + Token.Output.EvenRow: "output.even-row", } # reverse dict for cli_helpers, because they still expect Pygments tokens. -PROMPT_STYLE_TO_TOKEN = { - v: k for k, v in TOKEN_TO_PROMPT_STYLE.items() -} +PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} def parse_pygments_style(token_name, style_object, style_dict): @@ -64,52 +62,48 @@ def style_factory(name, cli_style): try: style = pygments.styles.get_style_by_name(name) except ClassNotFound: - style = pygments.styles.get_style_by_name('native') + style = pygments.styles.get_style_by_name("native") prompt_styles = [] # prompt-toolkit used pygments tokens for styling before, switched to style # names in 2.0. Convert old token types to new style names, for backwards compatibility. for token in cli_style: - if token.startswith('Token.'): + if token.startswith("Token."): # treat as pygments token (1.0) - token_type, style_value = parse_pygments_style( - token, style, cli_style) + token_type, style_value = parse_pygments_style(token, style, cli_style) if token_type in TOKEN_TO_PROMPT_STYLE: prompt_style = TOKEN_TO_PROMPT_STYLE[token_type] prompt_styles.append((prompt_style, style_value)) else: # we don't want to support tokens anymore - logger.error('Unhandled style / class name: %s', token) + logger.error("Unhandled style / class name: %s", token) else: # treat as prompt style name (2.0). 
See default style names here: # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py prompt_styles.append((token, cli_style[token])) - override_style = Style([('bottom-toolbar', 'noreverse')]) - return merge_styles([ - style_from_pygments_cls(style), - override_style, - Style(prompt_styles) - ]) + override_style = Style([("bottom-toolbar", "noreverse")]) + return merge_styles( + [style_from_pygments_cls(style), override_style, Style(prompt_styles)] + ) def style_factory_output(name, cli_style): try: style = pygments.styles.get_style_by_name(name).styles except ClassNotFound: - style = pygments.styles.get_style_by_name('native').styles + style = pygments.styles.get_style_by_name("native").styles for token in cli_style: - if token.startswith('Token.'): - token_type, style_value = parse_pygments_style( - token, style, cli_style) + if token.startswith("Token."): + token_type, style_value = parse_pygments_style(token, style, cli_style) style.update({token_type: style_value}) elif token in PROMPT_STYLE_TO_TOKEN: token_type = PROMPT_STYLE_TO_TOKEN[token] style.update({token_type: cli_style[token]}) else: # TODO: cli helpers will have to switch to ptk.Style - logger.error('Unhandled style / class name: %s', token) + logger.error("Unhandled style / class name: %s", token) class OutputStyle(PygmentsStyle): default_style = "" diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py index 2fd978c1..fc3bb3b9 100644 --- a/pgcli/pgtoolbar.py +++ b/pgcli/pgtoolbar.py @@ -6,57 +6,58 @@ from prompt_toolkit.application import get_app def _get_vi_mode(): return { - InputMode.INSERT: 'I', - InputMode.NAVIGATION: 'N', - InputMode.REPLACE: 'R', - InputMode.INSERT_MULTIPLE: 'M', + InputMode.INSERT: "I", + InputMode.NAVIGATION: "N", + InputMode.REPLACE: "R", + InputMode.INSERT_MULTIPLE: "M", }[get_app().vi_state.input_mode] def create_toolbar_tokens_func(pgcli): """Return a function that generates the toolbar tokens.""" + def 
get_toolbar_tokens(): result = [] - result.append(('class:bottom-toolbar', ' ')) + result.append(("class:bottom-toolbar", " ")) if pgcli.completer.smart_completion: - result.append(('class:bottom-toolbar.on', - '[F2] Smart Completion: ON ')) + result.append(("class:bottom-toolbar.on", "[F2] Smart Completion: ON ")) else: - result.append(('class:bottom-toolbar.off', - '[F2] Smart Completion: OFF ')) + result.append(("class:bottom-toolbar.off", "[F2] Smart Completion: OFF ")) if pgcli.multi_line: - result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON ')) + result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON ")) else: - result.append(('class:bottom-toolbar.off', - '[F3] Multiline: OFF ')) + result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF ")) if pgcli.multi_line: - if pgcli.multiline_mode == 'safe': - result.append( - ('class:bottom-toolbar', ' ([Esc] [Enter] to execute]) ')) + if pgcli.multiline_mode == "safe": + result.append(("class:bottom-toolbar", " ([Esc] [Enter] to execute]) ")) else: result.append( - ('class:bottom-toolbar', ' (Semi-colon [;] will end the line) ')) + ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ") + ) if pgcli.vi_mode: result.append( - ('class:bottom-toolbar', '[F4] Vi-mode (' + _get_vi_mode() + ')')) + ("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ")") + ) else: - result.append(('class:bottom-toolbar', '[F4] Emacs-mode')) + result.append(("class:bottom-toolbar", "[F4] Emacs-mode")) if pgcli.pgexecute.failed_transaction(): - result.append(('class:bottom-toolbar.transaction.failed', - ' Failed transaction')) + result.append( + ("class:bottom-toolbar.transaction.failed", " Failed transaction") + ) if pgcli.pgexecute.valid_transaction(): result.append( - ('class:bottom-toolbar.transaction.valid', ' Transaction')) + ("class:bottom-toolbar.transaction.valid", " Transaction") + ) if pgcli.completion_refresher.is_refreshing(): - result.append( - ('class:bottom-toolbar', ' Refreshing 
completions...')) + result.append(("class:bottom-toolbar", " Refreshing completions...")) return result + return get_toolbar_tokens diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..4b11c266 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,22 @@ +[tool.black] +line-length = 88 +target-version = ['py27'] +include = '\.pyi?$' +exclude = ''' +/( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | \.cache + | \.pytest_cache + | _build + | buck-out + | build + | dist + | tests/data +)/ +''' + diff --git a/release.py b/release.py index 6526128d..9ac18d24 100644 --- a/release.py +++ b/release.py @@ -23,7 +23,7 @@ def skip_step(): global CONFIRM_STEPS if CONFIRM_STEPS: - return not click.confirm('--- Run this step?', default=True) + return not click.confirm("--- Run this step?", default=True) return False @@ -36,90 +36,100 @@ def run_step(*args): global DRY_RUN cmd = args - print(' '.join(cmd)) + print (" ".join(cmd)) if skip_step(): - print('--- Skipping...') + print ("--- Skipping...") elif DRY_RUN: - print('--- Pretending to run...') + print ("--- Pretending to run...") else: subprocess.check_output(cmd) def version(version_file): _version_re = re.compile( - r'__version__\s+=\s+(?P[\'"])(?P.*)(?P=quote)') + r'__version__\s+=\s+(?P[\'"])(?P.*)(?P=quote)' + ) - with io.open(version_file, encoding='utf-8') as f: - ver = _version_re.search(f.read()).group('version') + with io.open(version_file, encoding="utf-8") as f: + ver = _version_re.search(f.read()).group("version") return ver def commit_for_release(version_file, ver): - run_step('git', 'reset') - run_step('git', 'add', version_file) - run_step('git', 'commit', '--message', - 'Releasing version {}'.format(ver)) + run_step("git", "reset") + run_step("git", "add", version_file) + run_step("git", "commit", "--message", "Releasing version {}".format(ver)) def create_git_tag(tag_name): - run_step('git', 'tag', tag_name) + run_step("git", "tag", tag_name) def 
create_distribution_files(): - run_step('python', 'setup.py', 'clean', '--all', 'sdist', 'bdist_wheel') + run_step("python", "setup.py", "clean", "--all", "sdist", "bdist_wheel") def upload_distribution_files(): - run_step('twine', 'upload', 'dist/*') + run_step("twine", "upload", "dist/*") def push_to_github(): - run_step('git', 'push', 'origin', 'master') + run_step("git", "push", "origin", "master") def push_tags_to_github(): - run_step('git', 'push', '--tags', 'origin') + run_step("git", "push", "--tags", "origin") def checklist(questions): for question in questions: - if not click.confirm('--- {}'.format(question), default=False): + if not click.confirm("--- {}".format(question), default=False): sys.exit(1) -if __name__ == '__main__': +if __name__ == "__main__": if DEBUG: subprocess.check_output = lambda x: x - checks = ['Have you updated the AUTHORS file?', - 'Have you updated the `Usage` section of the README?', - ] + checks = [ + "Have you updated the AUTHORS file?", + "Have you updated the `Usage` section of the README?", + ] checklist(checks) - ver = version('pgcli/__init__.py') - print('Releasing Version:', ver) + ver = version("pgcli/__init__.py") + print ("Releasing Version:", ver) parser = OptionParser() parser.add_option( - "-c", "--confirm-steps", action="store_true", dest="confirm_steps", - default=False, help=("Confirm every step. If the step is not " - "confirmed, it will be skipped.") + "-c", + "--confirm-steps", + action="store_true", + dest="confirm_steps", + default=False, + help=( + "Confirm every step. If the step is not " "confirmed, it will be skipped." + ), ) parser.add_option( - "-d", "--dry-run", action="store_true", dest="dry_run", - default=False, help="Print out, but not actually run any steps." 
+ "-d", + "--dry-run", + action="store_true", + dest="dry_run", + default=False, + help="Print out, but not actually run any steps.", ) popts, pargs = parser.parse_args() CONFIRM_STEPS = popts.confirm_steps DRY_RUN = popts.dry_run - if not click.confirm('Are you sure?', default=False): + if not click.confirm("Are you sure?", default=False): sys.exit(1) - commit_for_release('pgcli/__init__.py', ver) - create_git_tag('v{}'.format(ver)) + commit_for_release("pgcli/__init__.py", ver) + create_git_tag("v{}".format(ver)) create_distribution_files() push_to_github() push_tags_to_github() diff --git a/requirements-dev.txt b/requirements-dev.txt index 575a7741..89b171d7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,11 +4,10 @@ mock>=1.0.1 tox>=1.9.2 behave>=1.2.4 pexpect==3.3 +pre-commit>=1.16.0 coverage==4.3.4 codecov>=1.5.1 docutils>=0.13.1 autopep8==1.3.3 -# we want the latest possible version of pep8radius -git+https://github.com/hayd/pep8radius.git click==6.7 twine==1.11.0 diff --git a/setup.py b/setup.py index 775c46ef..57d791ae 100644 --- a/setup.py +++ b/setup.py @@ -3,24 +3,25 @@ import ast import platform from setuptools import setup, find_packages -_version_re = re.compile(r'__version__\s+=\s+(.*)') +_version_re = re.compile(r"__version__\s+=\s+(.*)") -with open('pgcli/__init__.py', 'rb') as f: - version = str(ast.literal_eval(_version_re.search( - f.read().decode('utf-8')).group(1))) +with open("pgcli/__init__.py", "rb") as f: + version = str( + ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) + ) -description = 'CLI for Postgres Database. With auto-completion and syntax highlighting.' +description = "CLI for Postgres Database. With auto-completion and syntax highlighting." install_requirements = [ - 'pgspecial>=1.11.5', - 'click >= 4.1', - 'Pygments >= 2.0', # Pygments has to be Capitalcased. WTF? 
- 'prompt_toolkit>=2.0.6,<2.1.0', - 'psycopg2 >= 2.7.4,<2.8', - 'sqlparse >=0.3.0,<0.4', - 'configobj >= 5.0.6', - 'humanize >= 0.5.1', - 'cli_helpers[styles] >= 1.2.0', + "pgspecial>=1.11.5", + "click >= 4.1", + "Pygments >= 2.0", # Pygments has to be Capitalcased. WTF? + "prompt_toolkit>=2.0.6,<2.1.0", + "psycopg2 >= 2.7.4,<2.8", + "sqlparse >=0.3.0,<0.4", + "configobj >= 5.0.6", + "humanize >= 0.5.1", + "cli_helpers[styles] >= 1.2.0", ] @@ -28,45 +29,42 @@ install_requirements = [ # But this is not necessary in Windows since the password is never shown in the # task manager. Also setproctitle is a hard dependency to install in Windows, # so we'll only install it if we're not in Windows. -if platform.system() != 'Windows' and not platform.system().startswith("CYGWIN"): - install_requirements.append('setproctitle >= 1.1.9') +if platform.system() != "Windows" and not platform.system().startswith("CYGWIN"): + install_requirements.append("setproctitle >= 1.1.9") setup( - name='pgcli', - author='Pgcli Core Team', - author_email='pgcli-dev@googlegroups.com', + name="pgcli", + author="Pgcli Core Team", + author_email="pgcli-dev@googlegroups.com", version=version, - license='BSD', - url='http://pgcli.com', + license="BSD", + url="http://pgcli.com", packages=find_packages(), - package_data={'pgcli': ['pgclirc', - 'packages/pgliterals/pgliterals.json']}, + package_data={"pgcli": ["pgclirc", "packages/pgliterals/pgliterals.json"]}, description=description, - long_description=open('README.rst').read(), + long_description=open("README.rst").read(), install_requires=install_requirements, - extras_require={ - 'keyring': ['keyring >= 12.2.0'], - }, - entry_points=''' + extras_require={"keyring": ["keyring >= 12.2.0"]}, + entry_points=""" [console_scripts] pgcli=pgcli.main:cli - ''', + """, classifiers=[ - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: Unix', - 'Programming Language :: Python', - 'Programming Language :: 
Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: SQL', - 'Topic :: Database', - 'Topic :: Database :: Front-Ends', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries :: Python Modules", ], ) diff --git a/tests/conftest.py b/tests/conftest.py index 26a40816..315e3de8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,15 +2,22 @@ from __future__ import print_function import os import pytest -from utils import (POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD, create_db, db_connection, - drop_tables) +from utils import ( + POSTGRES_HOST, + POSTGRES_PORT, + POSTGRES_USER, + POSTGRES_PASSWORD, + create_db, + db_connection, + drop_tables, +) import pgcli.pgexecute @pytest.yield_fixture(scope="function") def connection(): - create_db('_test_db') - connection = db_connection('_test_db') + create_db("_test_db") + connection = db_connection("_test_db") yield connection drop_tables(connection) @@ -25,8 +32,14 @@ def cursor(connection): @pytest.fixture def executor(connection): - return pgcli.pgexecute.PGExecute(database='_test_db', user=POSTGRES_USER, 
host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, port=POSTGRES_PORT, dsn=None) + return pgcli.pgexecute.PGExecute( + database="_test_db", + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + port=POSTGRES_PORT, + dsn=None, + ) @pytest.fixture @@ -38,4 +51,4 @@ def exception_formatter(): def temp_config(tmpdir_factory): # this function runs on start of test session. # use temporary directory for config home so user config will not be used - os.environ['XDG_CONFIG_HOME'] = str(tmpdir_factory.mktemp('data')) + os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data")) diff --git a/tests/features/db_utils.py b/tests/features/db_utils.py index e0b2035d..4eba6063 100644 --- a/tests/features/db_utils.py +++ b/tests/features/db_utils.py @@ -6,7 +6,9 @@ from psycopg2 import connect from psycopg2.extensions import AsIs -def create_db(hostname='localhost', username=None, password=None, dbname=None, port=None): +def create_db( + hostname="localhost", username=None, password=None, dbname=None, port=None +): """Create test database. :param hostname: string @@ -17,15 +19,15 @@ def create_db(hostname='localhost', username=None, password=None, dbname=None, p :return: """ - cn = create_cn(hostname, password, username, 'postgres', port) + cn = create_cn(hostname, password, username, "postgres", port) # ISOLATION_LEVEL_AUTOCOMMIT = 0 # Needed for DB creation. 
cn.set_isolation_level(0) with cn.cursor() as cr: - cr.execute('drop database if exists %s', (AsIs(dbname),)) - cr.execute('create database %s', (AsIs(dbname),)) + cr.execute("drop database if exists %s", (AsIs(dbname),)) + cr.execute("create database %s", (AsIs(dbname),)) cn.close() @@ -42,15 +44,15 @@ def create_cn(hostname, password, username, dbname, port): :param dbname: string :return: psycopg2.connection """ - cn = connect(host=hostname, user=username, database=dbname, - password=password, port=port) + cn = connect( + host=hostname, user=username, database=dbname, password=password, port=port + ) - print('Created connection: {0}.'.format(cn.dsn)) + print ("Created connection: {0}.".format(cn.dsn)) return cn -def drop_db(hostname='localhost', username=None, password=None, - dbname=None, port=None): +def drop_db(hostname="localhost", username=None, password=None, dbname=None, port=None): """ Drop database. :param hostname: string @@ -58,14 +60,14 @@ def drop_db(hostname='localhost', username=None, password=None, :param password: string :param dbname: string """ - cn = create_cn(hostname, password, username, 'postgres', port) + cn = create_cn(hostname, password, username, "postgres", port) # ISOLATION_LEVEL_AUTOCOMMIT = 0 # Needed for DB drop. 
cn.set_isolation_level(0) with cn.cursor() as cr: - cr.execute('drop database if exists %s', (AsIs(dbname),)) + cr.execute("drop database if exists %s", (AsIs(dbname),)) close_cn(cn) @@ -77,4 +79,4 @@ def close_cn(cn=None): """ if cn: cn.close() - print('Closed connection: {0}.'.format(cn.dsn)) + print ("Closed connection: {0}.".format(cn.dsn)) diff --git a/tests/features/environment.py b/tests/features/environment.py index 6b4d4241..0133ab01 100644 --- a/tests/features/environment.py +++ b/tests/features/environment.py @@ -18,113 +18,119 @@ from steps import wrappers def before_all(context): """Set env parameters.""" env_old = copy.deepcopy(dict(os.environ)) - os.environ['LINES'] = "100" - os.environ['COLUMNS'] = "100" - os.environ['PAGER'] = 'cat' - os.environ['EDITOR'] = 'ex' - os.environ['VISUAL'] = 'ex' + os.environ["LINES"] = "100" + os.environ["COLUMNS"] = "100" + os.environ["PAGER"] = "cat" + os.environ["EDITOR"] = "ex" + os.environ["VISUAL"] = "ex" context.package_root = os.path.abspath( - os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - fixture_dir = os.path.join( - context.package_root, 'tests/features/fixture_data') + os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + ) + fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data") - print('package root:', context.package_root) - print('fixture dir:', fixture_dir) + print ("package root:", context.package_root) + print ("fixture dir:", fixture_dir) - os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, - '.coveragerc') + os.environ["COVERAGE_PROCESS_START"] = os.path.join( + context.package_root, ".coveragerc" + ) context.exit_sent = False - vi = '_'.join([str(x) for x in sys.version_info[:3]]) - db_name = context.config.userdata.get('pg_test_db', 'pgcli_behave_tests') - db_name_full = '{0}_{1}'.format(db_name, vi) + vi = "_".join([str(x) for x in sys.version_info[:3]]) + db_name = context.config.userdata.get("pg_test_db", 
"pgcli_behave_tests") + db_name_full = "{0}_{1}".format(db_name, vi) # Store get params from config. context.conf = { - 'host': context.config.userdata.get( - 'pg_test_host', - os.getenv('PGHOST', 'localhost') + "host": context.config.userdata.get( + "pg_test_host", os.getenv("PGHOST", "localhost") ), - 'user': context.config.userdata.get( - 'pg_test_user', - os.getenv('PGUSER', 'postgres') + "user": context.config.userdata.get( + "pg_test_user", os.getenv("PGUSER", "postgres") ), - 'pass': context.config.userdata.get( - 'pg_test_pass', - os.getenv('PGPASSWORD', None) + "pass": context.config.userdata.get( + "pg_test_pass", os.getenv("PGPASSWORD", None) ), - 'port': context.config.userdata.get( - 'pg_test_port', - os.getenv('PGPORT', '5432') + "port": context.config.userdata.get( + "pg_test_port", os.getenv("PGPORT", "5432") ), - 'cli_command': ( - context.config.userdata.get('pg_cli_command', None) or - '{python} -c "{startup}"'.format( + "cli_command": ( + context.config.userdata.get("pg_cli_command", None) + or '{python} -c "{startup}"'.format( python=sys.executable, - startup='; '.join([ - "import coverage", - "coverage.process_startup()", - "import pgcli.main", - "pgcli.main.cli()"]))), - 'dbname': db_name_full, - 'dbname_tmp': db_name_full + '_tmp', - 'vi': vi, - 'pager_boundary': '---boundary---', + startup="; ".join( + [ + "import coverage", + "coverage.process_startup()", + "import pgcli.main", + "pgcli.main.cli()", + ] + ), + ) + ), + "dbname": db_name_full, + "dbname_tmp": db_name_full + "_tmp", + "vi": vi, + "pager_boundary": "---boundary---", } - os.environ['PAGER'] = "{0} {1} {2}".format( + os.environ["PAGER"] = "{0} {1} {2}".format( sys.executable, os.path.join(context.package_root, "tests/features/wrappager.py"), - context.conf['pager_boundary']) + context.conf["pager_boundary"], + ) # Store old env vars. 
context.pgenv = { - 'PGDATABASE': os.environ.get('PGDATABASE', None), - 'PGUSER': os.environ.get('PGUSER', None), - 'PGHOST': os.environ.get('PGHOST', None), - 'PGPASSWORD': os.environ.get('PGPASSWORD', None), - 'PGPORT': os.environ.get('PGPORT', None), - 'XDG_CONFIG_HOME': os.environ.get('XDG_CONFIG_HOME', None), - 'PGSERVICEFILE': os.environ.get('PGSERVICEFILE', None), + "PGDATABASE": os.environ.get("PGDATABASE", None), + "PGUSER": os.environ.get("PGUSER", None), + "PGHOST": os.environ.get("PGHOST", None), + "PGPASSWORD": os.environ.get("PGPASSWORD", None), + "PGPORT": os.environ.get("PGPORT", None), + "XDG_CONFIG_HOME": os.environ.get("XDG_CONFIG_HOME", None), + "PGSERVICEFILE": os.environ.get("PGSERVICEFILE", None), } # Set new env vars. - os.environ['PGDATABASE'] = context.conf['dbname'] - os.environ['PGUSER'] = context.conf['user'] - os.environ['PGHOST'] = context.conf['host'] - os.environ['PGPORT'] = context.conf['port'] - os.environ['PGSERVICEFILE'] = os.path.join( - fixture_dir, 'mock_pg_service.conf') + os.environ["PGDATABASE"] = context.conf["dbname"] + os.environ["PGUSER"] = context.conf["user"] + os.environ["PGHOST"] = context.conf["host"] + os.environ["PGPORT"] = context.conf["port"] + os.environ["PGSERVICEFILE"] = os.path.join(fixture_dir, "mock_pg_service.conf") - if context.conf['pass']: - os.environ['PGPASSWORD'] = context.conf['pass'] + if context.conf["pass"]: + os.environ["PGPASSWORD"] = context.conf["pass"] else: - if 'PGPASSWORD' in os.environ: - del os.environ['PGPASSWORD'] + if "PGPASSWORD" in os.environ: + del os.environ["PGPASSWORD"] - context.cn = dbutils.create_db(context.conf['host'], context.conf['user'], - context.conf['pass'], context.conf['dbname'], - context.conf['port']) + context.cn = dbutils.create_db( + context.conf["host"], + context.conf["user"], + context.conf["pass"], + context.conf["dbname"], + context.conf["port"], + ) context.fixture_data = fixutils.read_fixture_files() # use temporary directory as config home - 
context.env_config_home = tempfile.mkdtemp(prefix='pgcli_home_') - os.environ['XDG_CONFIG_HOME'] = context.env_config_home + context.env_config_home = tempfile.mkdtemp(prefix="pgcli_home_") + os.environ["XDG_CONFIG_HOME"] = context.env_config_home show_env_changes(env_old, dict(os.environ)) def show_env_changes(env_old, env_new): """Print out all test-specific env values.""" - print('--- os.environ changed values: ---') + print ("--- os.environ changed values: ---") all_keys = set(list(env_old.keys()) + list(env_new.keys())) for k in sorted(all_keys): - old_value = env_old.get(k, '') - new_value = env_new.get(k, '') + old_value = env_old.get(k, "") + new_value = env_new.get(k, "") if new_value and old_value != new_value: - print('{}="{}"'.format(k, new_value)) - print('-' * 20) + print ('{}="{}"'.format(k, new_value)) + print ("-" * 20) def after_all(context): @@ -132,9 +138,13 @@ def after_all(context): Unset env parameters. """ dbutils.close_cn(context.cn) - dbutils.drop_db(context.conf['host'], context.conf['user'], - context.conf['pass'], context.conf['dbname'], - context.conf['port']) + dbutils.drop_db( + context.conf["host"], + context.conf["user"], + context.conf["pass"], + context.conf["dbname"], + context.conf["port"], + ) # Remove temp config direcotry shutil.rmtree(context.env_config_home) @@ -152,7 +162,7 @@ def before_step(context, _): def before_scenario(context, scenario): - if scenario.name == 'list databases': + if scenario.name == "list databases": # not using the cli for that return wrappers.run_cli(context) @@ -161,19 +171,19 @@ def before_scenario(context, scenario): def after_scenario(context, scenario): """Cleans up after each scenario completes.""" - if hasattr(context, 'cli') and context.cli and not context.exit_sent: + if hasattr(context, "cli") and context.cli and not context.exit_sent: # Quit nicely. 
if not context.atprompt: dbname = context.currentdb - context.cli.expect_exact('{0}> '.format(dbname), timeout=15) - context.cli.sendcontrol('c') - context.cli.sendcontrol('d') + context.cli.expect_exact("{0}> ".format(dbname), timeout=15) + context.cli.sendcontrol("c") + context.cli.sendcontrol("d") try: context.cli.expect_exact(pexpect.EOF, timeout=15) except pexpect.TIMEOUT: - print('--- after_scenario {}: kill cli'.format(scenario.name)) + print ("--- after_scenario {}: kill cli".format(scenario.name)) context.cli.kill(signal.SIGKILL) - if hasattr(context, 'tmpfile_sql_help') and context.tmpfile_sql_help: + if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help: context.tmpfile_sql_help.close() context.tmpfile_sql_help = None diff --git a/tests/features/fixture_utils.py b/tests/features/fixture_utils.py index af10596a..6cb74b2d 100644 --- a/tests/features/fixture_utils.py +++ b/tests/features/fixture_utils.py @@ -13,7 +13,7 @@ def read_fixture_lines(filename): :return: list of strings """ lines = [] - for line in codecs.open(filename, 'rb', encoding='utf-8'): + for line in codecs.open(filename, "rb", encoding="utf-8"): lines.append(line.strip()) return lines @@ -21,11 +21,11 @@ def read_fixture_lines(filename): def read_fixture_files(): """Read all files inside fixture_data directory.""" current_dir = os.path.dirname(__file__) - fixture_dir = os.path.join(current_dir, 'fixture_data/') - print('reading fixture data: {}'.format(fixture_dir)) + fixture_dir = os.path.join(current_dir, "fixture_data/") + print ("reading fixture data: {}".format(fixture_dir)) fixture_dict = {} for filename in os.listdir(fixture_dir): - if filename not in ['.', '..']: + if filename not in [".", ".."]: fullname = os.path.join(fixture_dir, filename) fixture_dict[filename] = read_fixture_lines(fullname) diff --git a/tests/features/steps/auto_vertical.py b/tests/features/steps/auto_vertical.py index 18a51faf..2bb89870 100644 --- a/tests/features/steps/auto_vertical.py +++ 
b/tests/features/steps/auto_vertical.py @@ -6,37 +6,45 @@ from behave import then, when import wrappers -@when('we run dbcli with {arg}') +@when("we run dbcli with {arg}") def step_run_cli_with_arg(context, arg): - wrappers.run_cli(context, run_args=arg.split('=')) + wrappers.run_cli(context, run_args=arg.split("=")) -@when('we execute a small query') +@when("we execute a small query") def step_execute_small_query(context): - context.cli.sendline('select 1') + context.cli.sendline("select 1") -@when('we execute a large query') +@when("we execute a large query") def step_execute_large_query(context): - context.cli.sendline( - 'select {}'.format(','.join([str(n) for n in range(1, 50)]))) + context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)]))) -@then('we see small results in horizontal format') +@then("we see small results in horizontal format") def step_see_small_results(context): - wrappers.expect_pager(context, dedent("""\ + wrappers.expect_pager( + context, + dedent( + """\ +------------+\r | ?column? |\r |------------|\r | 1 |\r +------------+\r SELECT 1\r - """), timeout=5) + """ + ), + timeout=5, + ) -@then('we see large results in vertical format') +@then("we see large results in vertical format") def step_see_large_results(context): - wrappers.expect_pager(context, dedent("""\ + wrappers.expect_pager( + context, + dedent( + """\ -[ RECORD 1 ]-------------------------\r ?column? | 1\r ?column? | 2\r @@ -88,4 +96,7 @@ def step_see_large_results(context): ?column? | 48\r ?column? 
| 49\r SELECT 1\r - """), timeout=5) + """ + ), + timeout=5, + ) diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py index 71b591a0..b9ab0046 100644 --- a/tests/features/steps/basic_commands.py +++ b/tests/features/steps/basic_commands.py @@ -15,47 +15,48 @@ from textwrap import dedent import wrappers -@when('we list databases') +@when("we list databases") def step_list_databases(context): - cmd = ['pgcli', '--list'] + cmd = ["pgcli", "--list"] context.cmd_output = subprocess.check_output(cmd, cwd=context.package_root) -@then('we see list of databases') +@then("we see list of databases") def step_see_list_databases(context): - assert b'List of databases' in context.cmd_output - assert b'postgres' in context.cmd_output + assert b"List of databases" in context.cmd_output + assert b"postgres" in context.cmd_output context.cmd_output = None -@when('we run dbcli') +@when("we run dbcli") def step_run_cli(context): wrappers.run_cli(context) -@when('we launch dbcli using {arg}') +@when("we launch dbcli using {arg}") def step_run_cli_using_arg(context, arg): prompt_check = False currentdb = None - if arg == '--username': - arg = '--username={}'.format(context.conf['user']) - if arg == '--user': - arg = '--user={}'.format(context.conf['user']) - if arg == '--port': - arg = '--port={}'.format(context.conf['port']) - if arg == '--password': - arg = '--password' + if arg == "--username": + arg = "--username={}".format(context.conf["user"]) + if arg == "--user": + arg = "--user={}".format(context.conf["user"]) + if arg == "--port": + arg = "--port={}".format(context.conf["port"]) + if arg == "--password": + arg = "--password" prompt_check = False # This uses the mock_pg_service.conf file in fixtures folder. 
- if arg == 'dsn_password': - arg = 'service=mock_postgres --password' + if arg == "dsn_password": + arg = "service=mock_postgres --password" prompt_check = False currentdb = "postgres" - wrappers.run_cli(context, run_args=[ - arg], prompt_check=prompt_check, currentdb=currentdb) + wrappers.run_cli( + context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb + ) -@when('we wait for prompt') +@when("we wait for prompt") def step_wait_prompt(context): wrappers.wait_prompt(context) @@ -66,9 +67,9 @@ def step_ctrl_d(context): Send Ctrl + D to hopefully exit. """ # turn off pager before exiting - context.cli.sendline('\pset pager off') + context.cli.sendline("\pset pager off") wrappers.wait_prompt(context) - context.cli.sendcontrol('d') + context.cli.sendcontrol("d") context.cli.expect_exact(pexpect.EOF, timeout=15) context.exit_sent = True @@ -78,51 +79,58 @@ def step_send_help(context): """ Send \? to see help. """ - context.cli.sendline('\?') + context.cli.sendline("\?") -@when(u'we send source command') +@when("we send source command") def step_send_source_command(context): - context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix='pgcli_') - context.tmpfile_sql_help.write(b'\?') + context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_") + context.tmpfile_sql_help.write(b"\?") context.tmpfile_sql_help.flush() - context.cli.sendline('\i {0}'.format(context.tmpfile_sql_help.name)) - wrappers.expect_exact( - context, context.conf['pager_boundary'] + '\r\n', timeout=5) + context.cli.sendline("\i {0}".format(context.tmpfile_sql_help.name)) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) -@when(u'we run query to check application_name') +@when("we run query to check application_name") def step_check_application_name(context): context.cli.sendline( "SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;" ) -@then(u'we see found') +@then("we see found") def 
step_see_found(context): wrappers.expect_exact( context, - context.conf['pager_boundary'] + '\r' + dedent(''' + context.conf["pager_boundary"] + + "\r" + + dedent( + """ +------------+\r | ?column? |\r |------------|\r | found |\r +------------+\r SELECT 1\r - ''') + context.conf['pager_boundary'], - timeout=5 + """ + ) + + context.conf["pager_boundary"], + timeout=5, ) -@then(u'we confirm the destructive warning') +@then("we confirm the destructive warning") def step_confirm_destructive_command(context): """Confirm destructive command.""" wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) - context.cli.sendline('y') + context, + "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", + timeout=2, + ) + context.cli.sendline("y") -@then(u'we send password') +@then("we send password") def step_send_password(context): - wrappers.expect_exact(context, 'Password for', timeout=5) - context.cli.sendline(context.conf['pass'] or 'DOES NOT MATTER') + wrappers.expect_exact(context, "Password for", timeout=5) + context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER") diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py index 3ec27c96..9eab4f45 100644 --- a/tests/features/steps/crud_database.py +++ b/tests/features/steps/crud_database.py @@ -12,47 +12,43 @@ from behave import when, then import wrappers -@when('we create database') +@when("we create database") def step_db_create(context): """ Send create database. """ - context.cli.sendline('create database {0};'.format( - context.conf['dbname_tmp'])) + context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"])) - context.response = { - 'database_name': context.conf['dbname_tmp'] - } + context.response = {"database_name": context.conf["dbname_tmp"]} -@when('we drop database') +@when("we drop database") def step_db_drop(context): """ Send drop database. 
""" - context.cli.sendline('drop database {0};'.format( - context.conf['dbname_tmp'])) + context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"])) -@when('we connect to test database') +@when("we connect to test database") def step_db_connect_test(context): """ Send connect to database. """ - db_name = context.conf['dbname'] - context.cli.sendline('\\connect {0}'.format(db_name)) + db_name = context.conf["dbname"] + context.cli.sendline("\\connect {0}".format(db_name)) -@when('we connect to dbserver') +@when("we connect to dbserver") def step_db_connect_dbserver(context): """ Send connect to database. """ - context.cli.sendline('\\connect postgres') - context.currentdb = 'postgres' + context.cli.sendline("\\connect postgres") + context.currentdb = "postgres" -@then('dbcli exits') +@then("dbcli exits") def step_wait_exit(context): """ Make sure the cli exits. @@ -60,41 +56,41 @@ def step_wait_exit(context): wrappers.expect_exact(context, pexpect.EOF, timeout=5) -@then('we see dbcli prompt') +@then("we see dbcli prompt") def step_see_prompt(context): """ Wait to see the prompt. """ - db_name = getattr(context, "currentdb", context.conf['dbname']) - wrappers.expect_exact(context, '{0}> '.format(db_name), timeout=5) + db_name = getattr(context, "currentdb", context.conf["dbname"]) + wrappers.expect_exact(context, "{0}> ".format(db_name), timeout=5) context.atprompt = True -@then('we see help output') +@then("we see help output") def step_see_help(context): - for expected_line in context.fixture_data['help_commands.txt']: + for expected_line in context.fixture_data["help_commands.txt"]: wrappers.expect_exact(context, expected_line, timeout=2) -@then('we see database created') +@then("we see database created") def step_see_db_created(context): """ Wait to see create database output. 
""" - wrappers.expect_pager(context, 'CREATE DATABASE\r\n', timeout=5) + wrappers.expect_pager(context, "CREATE DATABASE\r\n", timeout=5) -@then('we see database dropped') +@then("we see database dropped") def step_see_db_dropped(context): """ Wait to see drop database output. """ - wrappers.expect_pager(context, 'DROP DATABASE\r\n', timeout=2) + wrappers.expect_pager(context, "DROP DATABASE\r\n", timeout=2) -@then('we see database connected') +@then("we see database connected") def step_see_db_connected(context): """ Wait to see drop database output. """ - wrappers.expect_exact(context, 'You are now connected to database', timeout=2) + wrappers.expect_exact(context, "You are now connected to database", timeout=2) diff --git a/tests/features/steps/crud_table.py b/tests/features/steps/crud_table.py index b8ec0d59..6d848aba 100644 --- a/tests/features/steps/crud_table.py +++ b/tests/features/steps/crud_table.py @@ -11,107 +11,110 @@ from textwrap import dedent import wrappers -@when('we create table') +@when("we create table") def step_create_table(context): """ Send create table. """ - context.cli.sendline('create table a(x text);') + context.cli.sendline("create table a(x text);") -@when('we insert into table') +@when("we insert into table") def step_insert_into_table(context): """ Send insert into table. """ - context.cli.sendline('''insert into a(x) values('xxx');''') + context.cli.sendline("""insert into a(x) values('xxx');""") -@when('we update table') +@when("we update table") def step_update_table(context): """ Send insert into table. """ - context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''') + context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""") -@when('we select from table') +@when("we select from table") def step_select_from_table(context): """ Send select from table. 
""" - context.cli.sendline('select * from a;') + context.cli.sendline("select * from a;") -@when('we delete from table') +@when("we delete from table") def step_delete_from_table(context): """ Send deete from table. """ - context.cli.sendline('''delete from a where x = 'yyy';''') + context.cli.sendline("""delete from a where x = 'yyy';""") -@when('we drop table') +@when("we drop table") def step_drop_table(context): """ Send drop table. """ - context.cli.sendline('drop table a;') + context.cli.sendline("drop table a;") -@then('we see table created') +@then("we see table created") def step_see_table_created(context): """ Wait to see create table output. """ - wrappers.expect_pager(context, 'CREATE TABLE\r\n', timeout=2) + wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2) -@then('we see record inserted') +@then("we see record inserted") def step_see_record_inserted(context): """ Wait to see insert output. """ - wrappers.expect_pager(context, 'INSERT 0 1\r\n', timeout=2) + wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2) -@then('we see record updated') +@then("we see record updated") def step_see_record_updated(context): """ Wait to see update output. """ - wrappers.expect_pager(context, 'UPDATE 1\r\n', timeout=2) + wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2) -@then('we see data selected') +@then("we see data selected") def step_see_data_selected(context): """ Wait to see select output. """ wrappers.expect_pager( context, - dedent('''\ + dedent( + """\ +-----+\r | x |\r |-----|\r | yyy |\r +-----+\r SELECT 1\r - '''), - timeout=1) + """ + ), + timeout=1, + ) -@then('we see record deleted') +@then("we see record deleted") def step_see_data_deleted(context): """ Wait to see delete output. 
""" - wrappers.expect_pager(context, 'DELETE 1\r\n', timeout=2) + wrappers.expect_pager(context, "DELETE 1\r\n", timeout=2) -@then('we see table dropped') +@then("we see table dropped") def step_see_table_dropped(context): """ Wait to see drop output. """ - wrappers.expect_pager(context, 'DROP TABLE\r\n', timeout=2) + wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2) diff --git a/tests/features/steps/expanded.py b/tests/features/steps/expanded.py index cee34814..f79f913b 100644 --- a/tests/features/steps/expanded.py +++ b/tests/features/steps/expanded.py @@ -12,53 +12,61 @@ from textwrap import dedent import wrappers -@when('we prepare the test data') +@when("we prepare the test data") def step_prepare_data(context): """Create table, insert a record.""" - context.cli.sendline('drop table if exists a;') + context.cli.sendline("drop table if exists a;") wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) - context.cli.sendline('y') + context, + "You're about to run a destructive command.\r\nDo you want to proceed? 
(y/n):", + timeout=2, + ) + context.cli.sendline("y") wrappers.wait_prompt(context) - context.cli.sendline( - 'create table a(x integer, y real, z numeric(10, 4));') - wrappers.expect_pager(context, 'CREATE TABLE\r\n', timeout=2) - context.cli.sendline('''insert into a(x, y, z) values(1, 1.0, 1.0);''') - wrappers.expect_pager(context, 'INSERT 0 1\r\n', timeout=2) + context.cli.sendline("create table a(x integer, y real, z numeric(10, 4));") + wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2) + context.cli.sendline("""insert into a(x, y, z) values(1, 1.0, 1.0);""") + wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2) -@when('we set expanded {mode}') +@when("we set expanded {mode}") def step_set_expanded(context, mode): """Set expanded to mode.""" - context.cli.sendline('\\' + 'x {}'.format(mode)) - wrappers.expect_exact(context, 'Expanded display is', timeout=2) + context.cli.sendline("\\" + "x {}".format(mode)) + wrappers.expect_exact(context, "Expanded display is", timeout=2) wrappers.wait_prompt(context) -@then('we see {which} data selected') +@then("we see {which} data selected") def step_see_data(context, which): """Select data from expanded test table.""" - if which == 'expanded': + if which == "expanded": wrappers.expect_pager( context, - dedent('''\ + dedent( + """\ -[ RECORD 1 ]-------------------------\r x | 1\r y | 1.0\r z | 1.0000\r SELECT 1\r - '''), - timeout=1) + """ + ), + timeout=1, + ) else: wrappers.expect_pager( context, - dedent('''\ + dedent( + """\ +-----+-----+--------+\r | x | y | z |\r |-----+-----+--------|\r | 1 | 1.0 | 1.0000 |\r +-----+-----+--------+\r SELECT 1\r - '''), - timeout=1) + """ + ), + timeout=1, + ) diff --git a/tests/features/steps/iocommands.py b/tests/features/steps/iocommands.py index 1a1a4c5c..416eac59 100644 --- a/tests/features/steps/iocommands.py +++ b/tests/features/steps/iocommands.py @@ -7,77 +7,76 @@ from behave import when, then import wrappers -@when('we start external editor providing a 
file name') +@when("we start external editor providing a file name") def step_edit_file(context): """Edit file with external editor.""" context.editor_file_name = os.path.join( - context.package_root, 'test_file_{0}.sql'.format(context.conf['vi'])) + context.package_root, "test_file_{0}.sql".format(context.conf["vi"]) + ) if os.path.exists(context.editor_file_name): os.remove(context.editor_file_name) - context.cli.sendline('\e {0}'.format( - os.path.basename(context.editor_file_name))) + context.cli.sendline("\e {0}".format(os.path.basename(context.editor_file_name))) wrappers.expect_exact( - context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2) - wrappers.expect_exact(context, ':', timeout=2) + context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2 + ) + wrappers.expect_exact(context, ":", timeout=2) -@when('we type sql in the editor') +@when("we type sql in the editor") def step_edit_type_sql(context): - context.cli.sendline('i') - context.cli.sendline('select * from abc') - context.cli.sendline('.') - wrappers.expect_exact(context, ':', timeout=2) + context.cli.sendline("i") + context.cli.sendline("select * from abc") + context.cli.sendline(".") + wrappers.expect_exact(context, ":", timeout=2) -@when('we exit the editor') +@when("we exit the editor") def step_edit_quit(context): - context.cli.sendline('x') + context.cli.sendline("x") wrappers.expect_exact(context, "written", timeout=2) -@then('we see the sql in prompt') +@then("we see the sql in prompt") def step_edit_done_sql(context): - for match in 'select * from abc'.split(' '): + for match in "select * from abc".split(" "): wrappers.expect_exact(context, match, timeout=1) # Cleanup the command line. - context.cli.sendcontrol('c') + context.cli.sendcontrol("c") # Cleanup the edited file. 
if context.editor_file_name and os.path.exists(context.editor_file_name): os.remove(context.editor_file_name) context.atprompt = True -@when(u'we tee output') +@when("we tee output") def step_tee_ouptut(context): context.tee_file_name = os.path.join( - context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi'])) + context.package_root, "tee_file_{0}.sql".format(context.conf["vi"]) + ) if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) - context.cli.sendline('\o {0}'.format( - os.path.basename(context.tee_file_name))) - wrappers.expect_exact( - context, context.conf['pager_boundary'] + '\r\n', timeout=5) + context.cli.sendline("\o {0}".format(os.path.basename(context.tee_file_name))) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) wrappers.expect_exact(context, "Writing to file", timeout=5) - wrappers.expect_exact( - context, context.conf['pager_boundary'] + '\r\n', timeout=5) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) wrappers.expect_exact(context, "Time", timeout=5) -@when(u'we query "select 123456"') +@when('we query "select 123456"') def step_query_select_123456(context): - context.cli.sendline('select 123456') + context.cli.sendline("select 123456") -@when(u'we stop teeing output') +@when("we stop teeing output") def step_notee_output(context): - context.cli.sendline('\o') + context.cli.sendline("\o") wrappers.expect_exact(context, "Time", timeout=5) -@then(u'we see 123456 in tee output') +@then("we see 123456 in tee output") def step_see_123456_in_ouput(context): with open(context.tee_file_name) as f: - assert '123456' in f.read() + assert "123456" in f.read() if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) context.atprompt = True diff --git a/tests/features/steps/named_queries.py b/tests/features/steps/named_queries.py index 84afb11c..289ec639 100644 --- a/tests/features/steps/named_queries.py +++ 
b/tests/features/steps/named_queries.py @@ -10,50 +10,50 @@ from behave import when, then import wrappers -@when('we save a named query') +@when("we save a named query") def step_save_named_query(context): """ Send \ns command """ - context.cli.sendline('\\ns foo SELECT 12345') + context.cli.sendline("\\ns foo SELECT 12345") -@when('we use a named query') +@when("we use a named query") def step_use_named_query(context): """ Send \n command """ - context.cli.sendline('\\n foo') + context.cli.sendline("\\n foo") -@when('we delete a named query') +@when("we delete a named query") def step_delete_named_query(context): """ Send \nd command """ - context.cli.sendline('\\nd foo') + context.cli.sendline("\\nd foo") -@then('we see the named query saved') +@then("we see the named query saved") def step_see_named_query_saved(context): """ Wait to see query saved. """ - wrappers.expect_exact(context, 'Saved.', timeout=2) + wrappers.expect_exact(context, "Saved.", timeout=2) -@then('we see the named query executed') +@then("we see the named query executed") def step_see_named_query_executed(context): """ Wait to see select output. """ - wrappers.expect_exact(context, '12345', timeout=1) - wrappers.expect_exact(context, 'SELECT 1', timeout=1) + wrappers.expect_exact(context, "12345", timeout=1) + wrappers.expect_exact(context, "SELECT 1", timeout=1) -@then('we see the named query deleted') +@then("we see the named query deleted") def step_see_named_query_deleted(context): """ Wait to see query deleted. 
""" - wrappers.expect_pager(context, 'foo: Deleted\r\n', timeout=1) + wrappers.expect_pager(context, "foo: Deleted\r\n", timeout=1) diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py index 26450514..2c77a3b1 100644 --- a/tests/features/steps/specials.py +++ b/tests/features/steps/specials.py @@ -10,18 +10,19 @@ from behave import when, then import wrappers -@when('we refresh completions') +@when("we refresh completions") def step_refresh_completions(context): """ Send refresh command. """ - context.cli.sendline('\\refresh') + context.cli.sendline("\\refresh") -@then('we see completions refresh started') +@then("we see completions refresh started") def step_see_refresh_started(context): """ Wait to see refresh output. """ wrappers.expect_pager( - context, 'Auto-completion refresh started in the background.\r\n', timeout=2) + context, "Auto-completion refresh started in the background.\r\n", timeout=2 + ) diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py index 23f9cc2b..9a2db982 100644 --- a/tests/features/steps/wrappers.py +++ b/tests/features/steps/wrappers.py @@ -20,10 +20,10 @@ def expect_exact(context, expected, timeout): timedout = True if timedout: # Strip color codes out of the output. 
- actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?', - '', context.cli.before) + actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before) raise Exception( - textwrap.dedent('''\ + textwrap.dedent( + """\ Expected: --- {0!r} @@ -36,35 +36,35 @@ def expect_exact(context, expected, timeout): --- {2!r} --- - ''').format( - expected, - actual, - context.logfile.getvalue() - ) + """ + ).format(expected, actual, context.logfile.getvalue()) ) def expect_pager(context, expected, timeout): - expect_exact(context, "{0}\r\n{1}{0}\r\n".format( - context.conf['pager_boundary'], expected), timeout=timeout) + expect_exact( + context, + "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected), + timeout=timeout, + ) def run_cli(context, run_args=None, prompt_check=True, currentdb=None): """Run the process using pexpect.""" run_args = run_args or [] - cli_cmd = context.conf.get('cli_command') + cli_cmd = context.conf.get("cli_command") cmd_parts = [cli_cmd] + run_args - cmd = ' '.join(cmd_parts) + cmd = " ".join(cmd_parts) context.cli = pexpect.spawnu(cmd, cwd=context.package_root) context.logfile = StringIO() context.cli.logfile = context.logfile context.exit_sent = False - context.currentdb = currentdb or context.conf['dbname'] - context.cli.sendline('\pset pager always') + context.currentdb = currentdb or context.conf["dbname"] + context.cli.sendline("\pset pager always") if prompt_check: wait_prompt(context) def wait_prompt(context): """Make sure prompt is displayed.""" - expect_exact(context, '{0}> '.format(context.conf['dbname']), timeout=5) + expect_exact(context, "{0}> ".format(context.conf["dbname"]), timeout=5) diff --git a/tests/features/wrappager.py b/tests/features/wrappager.py index 51d49095..e98ea979 100755 --- a/tests/features/wrappager.py +++ b/tests/features/wrappager.py @@ -3,13 +3,13 @@ import sys def wrappager(boundary): - print(boundary) + print (boundary) while 1: buf = sys.stdin.read(2048) if not buf: break sys.stdout.write(buf) - 
print(boundary) + print (boundary) if __name__ == "__main__": diff --git a/tests/metadata.py b/tests/metadata.py index 782d58d5..5cf9456d 100644 --- a/tests/metadata.py +++ b/tests/metadata.py @@ -11,12 +11,12 @@ import pytest parametrize = pytest.mark.parametrize -qual = ['if_more_than_one_table', 'always'] -no_qual = ['if_more_than_one_table', 'never'] +qual = ["if_more_than_one_table", "always"] +no_qual = ["if_more_than_one_table", "never"] def escape(name): - if not name.islower() or name in ('select', 'localtimestamp'): + if not name.islower() or name in ("select", "localtimestamp"): return '"' + name + '"' return name @@ -27,10 +27,7 @@ def completion(display_meta, text, pos=0): def function(text, pos=0, display=None): return Completion( - text, - display=display or text, - start_position=pos, - display_meta='function' + text, display=display or text, start_position=pos, display_meta="function" ) @@ -49,21 +46,20 @@ def result_set(completer, text, position=None): # def schema(text, pos=0): # return completion('schema', text, pos) # and so on -schema = partial(completion, 'schema') -table = partial(completion, 'table') -view = partial(completion, 'view') -column = partial(completion, 'column') -keyword = partial(completion, 'keyword') -datatype = partial(completion, 'datatype') -alias = partial(completion, 'table alias') -name_join = partial(completion, 'name join') -fk_join = partial(completion, 'fk join') -join = partial(completion, 'join') +schema = partial(completion, "schema") +table = partial(completion, "table") +view = partial(completion, "view") +column = partial(completion, "column") +keyword = partial(completion, "keyword") +datatype = partial(completion, "datatype") +alias = partial(completion, "table alias") +name_join = partial(completion, "name join") +fk_join = partial(completion, "fk join") +join = partial(completion, "join") def wildcard_expansion(cols, pos=-1): - return Completion( - cols, start_position=pos, display_meta='columns', 
display='*') + return Completion(cols, start_position=pos, display_meta="columns", display="*") class MetaData(object): @@ -80,78 +76,88 @@ class MetaData(object): return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()] def specials(self, pos=0): - return [Completion(text=k, start_position=pos, display_meta=v.description) for k, v in iteritems(self.completer.pgspecial.commands)] + return [ + Completion(text=k, start_position=pos, display_meta=v.description) + for k, v in iteritems(self.completer.pgspecial.commands) + ] - def columns(self, tbl, parent='public', typ='tables', pos=0): - if typ == 'functions': + def columns(self, tbl, parent="public", typ="tables", pos=0): + if typ == "functions": fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0] cols = fun[1] else: cols = self.metadata[typ][parent][tbl] return [column(escape(col), pos) for col in cols] - def datatypes(self, parent='public', pos=0): + def datatypes(self, parent="public", pos=0): return [ datatype(escape(x), pos) - for x in self.metadata.get('datatypes', {}).get(parent, [])] + for x in self.metadata.get("datatypes", {}).get(parent, []) + ] - def tables(self, parent='public', pos=0): + def tables(self, parent="public", pos=0): return [ table(escape(x), pos) - for x in self.metadata.get('tables', {}).get(parent, [])] + for x in self.metadata.get("tables", {}).get(parent, []) + ] - def views(self, parent='public', pos=0): + def views(self, parent="public", pos=0): return [ - view(escape(x), pos) - for x in self.metadata.get('views', {}).get(parent, [])] + view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, []) + ] - def functions(self, parent='public', pos=0): + def functions(self, parent="public", pos=0): return [ function( - escape(x[0]) + '(' + ', '.join( - arg_name + ' := ' + escape(x[0]) + + "(" + + ", ".join( + arg_name + " := " for (arg_name, arg_mode) in zip(x[1], x[3]) - if arg_mode in ('b', 'i') - ) + ')', + if arg_mode in ("b", "i") + ) + + ")", 
pos, - escape(x[0]) + '(' + ', '.join( + escape(x[0]) + + "(" + + ", ".join( arg_name for (arg_name, arg_mode) in zip(x[1], x[3]) - if arg_mode in ('b', 'i') - ) + ')' + if arg_mode in ("b", "i") + ) + + ")", ) - for x in self.metadata.get('functions', {}).get(parent, []) + for x in self.metadata.get("functions", {}).get(parent, []) ] def schemas(self, pos=0): schemas = set(sch for schs in self.metadata.values() for sch in schs) return [schema(escape(s), pos=pos) for s in schemas] - def functions_and_keywords(self, parent='public', pos=0): + def functions_and_keywords(self, parent="public", pos=0): return ( - self.functions(parent, pos) + self.builtin_functions(pos) + - self.keywords(pos) + self.functions(parent, pos) + + self.builtin_functions(pos) + + self.keywords(pos) ) # Note that the filtering parameters here only apply to the columns - def columns_functions_and_keywords( - self, tbl, parent='public', typ='tables', pos=0 - ): - return ( - self.functions_and_keywords(pos=pos) + - self.columns(tbl, parent, typ, pos) + def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0): + return self.functions_and_keywords(pos=pos) + self.columns( + tbl, parent, typ, pos ) - def from_clause_items(self, parent='public', pos=0): + def from_clause_items(self, parent="public", pos=0): return ( - self.functions(parent, pos) + self.views(parent, pos) + - self.tables(parent, pos) + self.functions(parent, pos) + + self.views(parent, pos) + + self.tables(parent, pos) ) - def schemas_and_from_clause_items(self, parent='public', pos=0): + def schemas_and_from_clause_items(self, parent="public", pos=0): return self.from_clause_items(parent, pos) + self.schemas(pos) - def types(self, parent='public', pos=0): + def types(self, parent="public", pos=0): return self.datatypes(parent, pos) + self.tables(parent, pos) @property @@ -170,85 +176,83 @@ class MetaData(object): casing, without `search_path` filtering of objects, with table aliasing, and without column 
qualification. """ + def _cfg(_casing, filtr, aliasing, qualify): - cfg = {'settings': {}} + cfg = {"settings": {}} if _casing: - cfg['casing'] = casing - cfg['settings']['search_path_filter'] = filtr - cfg['settings']['generate_aliases'] = aliasing - cfg['settings']['qualify_columns'] = qualify + cfg["casing"] = casing + cfg["settings"]["search_path_filter"] = filtr + cfg["settings"]["generate_aliases"] = aliasing + cfg["settings"]["qualify_columns"] = qualify return cfg def _cfgs(casing, filtr, aliasing, qualify): casings = [True, False] if casing is None else [casing] filtrs = [True, False] if filtr is None else [filtr] aliases = [True, False] if aliasing is None else [aliasing] - qualifys = qualify or ['always', 'if_more_than_one_table', 'never'] - return [ - _cfg(*p) for p in product(casings, filtrs, aliases, qualifys) - ] + qualifys = qualify or ["always", "if_more_than_one_table", "never"] + return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)] def completers(casing=None, filtr=None, aliasing=None, qualify=None): get_comp = self.get_completer - return [ - get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify) - ] + return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)] return completers def _make_col(self, sch, tbl, col): - defaults = self.metadata.get('defaults', {}).get(sch, {}) - return (sch, tbl, col, 'text', (tbl, col) in defaults, defaults.get((tbl, col))) - + defaults = self.metadata.get("defaults", {}).get(sch, {}) + return (sch, tbl, col, "text", (tbl, col) in defaults, defaults.get((tbl, col))) def get_completer(self, settings=None, casing=None): metadata = self.metadata from pgcli.pgcompleter import PGCompleter from pgspecial import PGSpecial - comp = PGCompleter(smart_completion=True, - settings=settings, pgspecial=PGSpecial()) + + comp = PGCompleter( + smart_completion=True, settings=settings, pgspecial=PGSpecial() + ) schemata, tables, tbl_cols, views, view_cols = [], [], [], [], [] - for sch, tbls in 
metadata['tables'].items(): + for sch, tbls in metadata["tables"].items(): schemata.append(sch) for tbl, cols in tbls.items(): tables.append((sch, tbl)) # Let all columns be text columns - tbl_cols.extend([self._make_col(sch, tbl, col) - for col in cols]) + tbl_cols.extend([self._make_col(sch, tbl, col) for col in cols]) - for sch, tbls in metadata.get('views', {}).items(): + for sch, tbls in metadata.get("views", {}).items(): for tbl, cols in tbls.items(): views.append((sch, tbl)) # Let all columns be text columns - view_cols.extend([self._make_col(sch, tbl, col) - for col in cols]) + view_cols.extend([self._make_col(sch, tbl, col) for col in cols]) functions = [ FunctionMetadata(sch, *func_meta, arg_defaults=None) - for sch, funcs in metadata['functions'].items() - for func_meta in funcs] + for sch, funcs in metadata["functions"].items() + for func_meta in funcs + ] datatypes = [ (sch, typ) - for sch, datatypes in metadata['datatypes'].items() - for typ in datatypes] + for sch, datatypes in metadata["datatypes"].items() + for typ in datatypes + ] foreignkeys = [ - ForeignKey(*fk) for fks in metadata['foreignkeys'].values() - for fk in fks] + ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks + ] comp.extend_schemata(schemata) - comp.extend_relations(tables, kind='tables') - comp.extend_relations(views, kind='views') - comp.extend_columns(tbl_cols, kind='tables') - comp.extend_columns(view_cols, kind='views') + comp.extend_relations(tables, kind="tables") + comp.extend_relations(views, kind="views") + comp.extend_columns(tbl_cols, kind="tables") + comp.extend_columns(view_cols, kind="views") comp.extend_functions(functions) comp.extend_datatypes(datatypes) comp.extend_foreignkeys(foreignkeys) - comp.set_search_path(['public']) + comp.set_search_path(["public"]) comp.extend_casing(casing or []) return comp diff --git a/tests/parseutils/test_ctes.py b/tests/parseutils/test_ctes.py index 4d2d050f..3e89ccaf 100644 --- 
a/tests/parseutils/test_ctes.py +++ b/tests/parseutils/test_ctes.py @@ -1,8 +1,10 @@ import pytest from sqlparse import parse from pgcli.packages.parseutils.ctes import ( - token_start_pos, extract_ctes, - extract_column_names as _extract_column_names) + token_start_pos, + extract_ctes, + extract_column_names as _extract_column_names, +) def extract_column_names(sql): @@ -11,109 +13,125 @@ def extract_column_names(sql): def test_token_str_pos(): - sql = 'SELECT * FROM xxx' + sql = "SELECT * FROM xxx" p = parse(sql)[0] idx = p.token_index(p.tokens[-1]) - assert token_start_pos(p.tokens, idx) == len('SELECT * FROM ') + assert token_start_pos(p.tokens, idx) == len("SELECT * FROM ") - sql = 'SELECT * FROM \nxxx' + sql = "SELECT * FROM \nxxx" p = parse(sql)[0] idx = p.token_index(p.tokens[-1]) - assert token_start_pos(p.tokens, idx) == len('SELECT * FROM \n') + assert token_start_pos(p.tokens, idx) == len("SELECT * FROM \n") def test_single_column_name_extraction(): - sql = 'SELECT abc FROM xxx' - assert extract_column_names(sql) == ('abc',) + sql = "SELECT abc FROM xxx" + assert extract_column_names(sql) == ("abc",) def test_aliased_single_column_name_extraction(): - sql = 'SELECT abc def FROM xxx' - assert extract_column_names(sql) == ('def',) + sql = "SELECT abc def FROM xxx" + assert extract_column_names(sql) == ("def",) def test_aliased_expression_name_extraction(): - sql = 'SELECT 99 abc FROM xxx' - assert extract_column_names(sql) == ('abc',) + sql = "SELECT 99 abc FROM xxx" + assert extract_column_names(sql) == ("abc",) def test_multiple_column_name_extraction(): - sql = 'SELECT abc, def FROM xxx' - assert extract_column_names(sql) == ('abc', 'def') + sql = "SELECT abc, def FROM xxx" + assert extract_column_names(sql) == ("abc", "def") def test_missing_column_name_handled_gracefully(): - sql = 'SELECT abc, 99 FROM xxx' - assert extract_column_names(sql) == ('abc',) + sql = "SELECT abc, 99 FROM xxx" + assert extract_column_names(sql) == ("abc",) - sql = 'SELECT 
abc, 99, def FROM xxx' - assert extract_column_names(sql) == ('abc', 'def') + sql = "SELECT abc, 99, def FROM xxx" + assert extract_column_names(sql) == ("abc", "def") def test_aliased_multiple_column_name_extraction(): - sql = 'SELECT abc def, ghi jkl FROM xxx' - assert extract_column_names(sql) == ('def', 'jkl') + sql = "SELECT abc def, ghi jkl FROM xxx" + assert extract_column_names(sql) == ("def", "jkl") def test_table_qualified_column_name_extraction(): - sql = 'SELECT abc.def, ghi.jkl FROM xxx' - assert extract_column_names(sql) == ('def', 'jkl') + sql = "SELECT abc.def, ghi.jkl FROM xxx" + assert extract_column_names(sql) == ("def", "jkl") -@pytest.mark.parametrize('sql', [ - 'INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y', - 'DELETE FROM foo WHERE x > y RETURNING x, y', - 'UPDATE foo SET x = 9 RETURNING x, y', -]) +@pytest.mark.parametrize( + "sql", + [ + "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y", + "DELETE FROM foo WHERE x > y RETURNING x, y", + "UPDATE foo SET x = 9 RETURNING x, y", + ], +) def test_extract_column_names_from_returning_clause(sql): - assert extract_column_names(sql) == ('x', 'y') + assert extract_column_names(sql) == ("x", "y") def test_simple_cte_extraction(): - sql = 'WITH a AS (SELECT abc FROM xxx) SELECT * FROM a' - start_pos = len('WITH a AS ') - stop_pos = len('WITH a AS (SELECT abc FROM xxx)') + sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a" + start_pos = len("WITH a AS ") + stop_pos = len("WITH a AS (SELECT abc FROM xxx)") ctes, remainder = extract_ctes(sql) - assert tuple(ctes) == (('a', ('abc',), start_pos, stop_pos),) - assert remainder.strip() == 'SELECT * FROM a' + assert tuple(ctes) == (("a", ("abc",), start_pos, stop_pos),) + assert remainder.strip() == "SELECT * FROM a" def test_cte_extraction_around_comments(): - sql = '''--blah blah blah + sql = """--blah blah blah WITH a AS (SELECT abc def FROM x) - SELECT * FROM a''' - start_pos = len('''--blah blah blah - WITH a AS ''') - stop_pos 
= len('''--blah blah blah - WITH a AS (SELECT abc def FROM x)''') + SELECT * FROM a""" + start_pos = len( + """--blah blah blah + WITH a AS """ + ) + stop_pos = len( + """--blah blah blah + WITH a AS (SELECT abc def FROM x)""" + ) ctes, remainder = extract_ctes(sql) - assert tuple(ctes) == (('a', ('def',), start_pos, stop_pos),) - assert remainder.strip() == 'SELECT * FROM a' + assert tuple(ctes) == (("a", ("def",), start_pos, stop_pos),) + assert remainder.strip() == "SELECT * FROM a" def test_multiple_cte_extraction(): - sql = '''WITH + sql = """WITH x AS (SELECT abc, def FROM x), y AS (SELECT ghi, jkl FROM y) - SELECT * FROM a, b''' + SELECT * FROM a, b""" - start1 = len('''WITH - x AS ''') + start1 = len( + """WITH + x AS """ + ) - stop1 = len('''WITH - x AS (SELECT abc, def FROM x)''') + stop1 = len( + """WITH + x AS (SELECT abc, def FROM x)""" + ) - start2 = len('''WITH + start2 = len( + """WITH x AS (SELECT abc, def FROM x), - y AS ''') + y AS """ + ) - stop2 = len('''WITH + stop2 = len( + """WITH x AS (SELECT abc, def FROM x), - y AS (SELECT ghi, jkl FROM y)''') + y AS (SELECT ghi, jkl FROM y)""" + ) ctes, remainder = extract_ctes(sql) assert tuple(ctes) == ( - ('x', ('abc', 'def'), start1, stop1), - ('y', ('ghi', 'jkl'), start2, stop2)) + ("x", ("abc", "def"), start1, stop1), + ("y", ("ghi", "jkl"), start2, stop2), + ) diff --git a/tests/parseutils/test_function_metadata.py b/tests/parseutils/test_function_metadata.py index 1d0b72b7..0350e2a1 100644 --- a/tests/parseutils/test_function_metadata.py +++ b/tests/parseutils/test_function_metadata.py @@ -3,16 +3,13 @@ from pgcli.packages.parseutils.meta import FunctionMetadata def test_function_metadata_eq(): f1 = FunctionMetadata( - 's', 'f', ['x'], ['integer'], [ - ], 'int', False, False, False, False, None + "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None ) f2 = FunctionMetadata( - 's', 'f', ['x'], ['integer'], [ - ], 'int', False, False, False, False, None + "s", "f", ["x"], 
["integer"], [], "int", False, False, False, False, None ) f3 = FunctionMetadata( - 's', 'g', ['x'], ['integer'], [ - ], 'int', False, False, False, False, None + "s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None ) assert f1 == f2 assert f1 != f3 diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py index 4ffc7ee3..50bc8896 100644 --- a/tests/parseutils/test_parseutils.py +++ b/tests/parseutils/test_parseutils.py @@ -4,104 +4,98 @@ from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote def test_empty_string(): - tables = extract_tables('') + tables = extract_tables("") assert tables == () def test_simple_select_single_table(): - tables = extract_tables('select * from abc') - assert tables == ((None, 'abc', None, False),) + tables = extract_tables("select * from abc") + assert tables == ((None, "abc", None, False),) -@pytest.mark.parametrize('sql', [ - 'select * from "abc"."def"', - 'select * from abc."def"', -]) +@pytest.mark.parametrize( + "sql", ['select * from "abc"."def"', 'select * from abc."def"'] +) def test_simple_select_single_table_schema_qualified_quoted_table(sql): tables = extract_tables(sql) - assert tables == (('abc', 'def', '"def"', False),) + assert tables == (("abc", "def", '"def"', False),) -@pytest.mark.parametrize('sql', [ - 'select * from abc.def', - 'select * from "abc".def', -]) +@pytest.mark.parametrize("sql", ["select * from abc.def", 'select * from "abc".def']) def test_simple_select_single_table_schema_qualified(sql): tables = extract_tables(sql) - assert tables == (('abc', 'def', None, False),) + assert tables == (("abc", "def", None, False),) def test_simple_select_single_table_double_quoted(): tables = extract_tables('select * from "Abc"') - assert tables == ((None, 'Abc', None, False),) + assert tables == ((None, "Abc", None, False),) def test_simple_select_multiple_tables(): - tables = extract_tables('select * from abc, def') - assert set(tables) == 
set([(None, 'abc', None, False), - (None, 'def', None, False)]) + tables = extract_tables("select * from abc, def") + assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)]) def test_simple_select_multiple_tables_double_quoted(): tables = extract_tables('select * from "Abc", "Def"') - assert set(tables) == set([(None, 'Abc', None, False), - (None, 'Def', None, False)]) + assert set(tables) == set([(None, "Abc", None, False), (None, "Def", None, False)]) def test_simple_select_single_table_deouble_quoted_aliased(): tables = extract_tables('select * from "Abc" a') - assert tables == ((None, 'Abc', 'a', False),) + assert tables == ((None, "Abc", "a", False),) def test_simple_select_multiple_tables_deouble_quoted_aliased(): tables = extract_tables('select * from "Abc" a, "Def" d') - assert set(tables) == set([(None, 'Abc', 'a', False), - (None, 'Def', 'd', False)]) + assert set(tables) == set([(None, "Abc", "a", False), (None, "Def", "d", False)]) def test_simple_select_multiple_tables_schema_qualified(): - tables = extract_tables('select * from abc.def, ghi.jkl') - assert set(tables) == set([('abc', 'def', None, False), - ('ghi', 'jkl', None, False)]) + tables = extract_tables("select * from abc.def, ghi.jkl") + assert set(tables) == set( + [("abc", "def", None, False), ("ghi", "jkl", None, False)] + ) def test_simple_select_with_cols_single_table(): - tables = extract_tables('select a,b from abc') - assert tables == ((None, 'abc', None, False),) + tables = extract_tables("select a,b from abc") + assert tables == ((None, "abc", None, False),) def test_simple_select_with_cols_single_table_schema_qualified(): - tables = extract_tables('select a,b from abc.def') - assert tables == (('abc', 'def', None, False),) + tables = extract_tables("select a,b from abc.def") + assert tables == (("abc", "def", None, False),) def test_simple_select_with_cols_multiple_tables(): - tables = extract_tables('select a,b from abc, def') - assert set(tables) == 
set([(None, 'abc', None, False), - (None, 'def', None, False)]) + tables = extract_tables("select a,b from abc, def") + assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)]) def test_simple_select_with_cols_multiple_qualified_tables(): - tables = extract_tables('select a,b from abc.def, def.ghi') - assert set(tables) == set([('abc', 'def', None, False), - ('def', 'ghi', None, False)]) + tables = extract_tables("select a,b from abc.def, def.ghi") + assert set(tables) == set( + [("abc", "def", None, False), ("def", "ghi", None, False)] + ) def test_select_with_hanging_comma_single_table(): - tables = extract_tables('select a, from abc') - assert tables == ((None, 'abc', None, False),) + tables = extract_tables("select a, from abc") + assert tables == ((None, "abc", None, False),) def test_select_with_hanging_comma_multiple_tables(): - tables = extract_tables('select a, from abc, def') - assert set(tables) == set([(None, 'abc', None, False), - (None, 'def', None, False)]) + tables = extract_tables("select a, from abc, def") + assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)]) def test_select_with_hanging_period_multiple_tables(): - tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2') - assert set(tables) == set([(None, 'tabl1', 't1', False), - (None, 'tabl2', 't2', False)]) + tables = extract_tables("SELECT t1. 
FROM tabl1 t1, tabl2 t2") + assert set(tables) == set( + [(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)] + ) def test_simple_insert_single_table(): @@ -111,157 +105,165 @@ def test_simple_insert_single_table(): # AND mistakenly identifies the field list as # assert tables == ((None, 'abc', 'abc', False),) - assert tables == ((None, 'abc', 'abc', False),) + assert tables == ((None, "abc", "abc", False),) @pytest.mark.xfail def test_simple_insert_single_table_schema_qualified(): tables = extract_tables('insert into abc.def (id, name) values (1, "def")') - assert tables == (('abc', 'def', None, False),) + assert tables == (("abc", "def", None, False),) def test_simple_update_table_no_schema(): - tables = extract_tables('update abc set id = 1') - assert tables == ((None, 'abc', None, False),) + tables = extract_tables("update abc set id = 1") + assert tables == ((None, "abc", None, False),) def test_simple_update_table_with_schema(): - tables = extract_tables('update abc.def set id = 1') - assert tables == (('abc', 'def', None, False),) + tables = extract_tables("update abc.def set id = 1") + assert tables == (("abc", "def", None, False),) -@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER']) +@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"]) def test_join_table(join_type): - sql = 'SELECT * FROM abc a {0} JOIN def d ON a.id = d.num'.format(join_type) + sql = "SELECT * FROM abc a {0} JOIN def d ON a.id = d.num".format(join_type) tables = extract_tables(sql) - assert set(tables) == set([(None, 'abc', 'a', False), - (None, 'def', 'd', False)]) + assert set(tables) == set([(None, "abc", "a", False), (None, "def", "d", False)]) def test_join_table_schema_qualified(): - tables = extract_tables('SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') - assert set(tables) == set([('abc', 'def', 'x', False), - ('ghi', 'jkl', 'y', False)]) + tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = 
y.num") + assert set(tables) == set([("abc", "def", "x", False), ("ghi", "jkl", "y", False)]) def test_incomplete_join_clause(): - sql = '''select a.x, b.y + sql = """select a.x, b.y from abc a join bcd b - on a.id = ''' + on a.id = """ tables = extract_tables(sql) - assert tables == ((None, 'abc', 'a', False), - (None, 'bcd', 'b', False)) + assert tables == ((None, "abc", "a", False), (None, "bcd", "b", False)) def test_join_as_table(): - tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5') - assert tables == ((None, 'my_table', 'm', False),) + tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5") + assert tables == ((None, "my_table", "m", False),) def test_multiple_joins(): - sql = '''select * from t1 + sql = """select * from t1 inner join t2 ON t1.id = t2.t1_id inner join t3 ON - t2.id = t3.''' + t2.id = t3.""" tables = extract_tables(sql) assert tables == ( - (None, 't1', None, False), - (None, 't2', None, False), - (None, 't3', None, False)) + (None, "t1", None, False), + (None, "t2", None, False), + (None, "t3", None, False), + ) def test_subselect_tables(): - sql = 'SELECT * FROM (SELECT FROM abc' + sql = "SELECT * FROM (SELECT FROM abc" tables = extract_tables(sql) - assert tables == ((None, 'abc', None, False),) + assert tables == ((None, "abc", None, False),) -@pytest.mark.parametrize('text', ['SELECT * FROM foo.', 'SELECT 123 AS foo']) +@pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"]) def test_extract_no_tables(text): tables = extract_tables(text) assert tables == tuple() -@pytest.mark.parametrize('arg_list', ['', 'arg1', 'arg1, arg2, arg3']) +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) def test_simple_function_as_table(arg_list): - tables = extract_tables('SELECT * FROM foo({0})'.format(arg_list)) - assert tables == ((None, 'foo', None, True),) + tables = extract_tables("SELECT * FROM foo({0})".format(arg_list)) + assert tables == ((None, "foo", None, True),) 
-@pytest.mark.parametrize('arg_list', ['', 'arg1', 'arg1, arg2, arg3']) +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) def test_simple_schema_qualified_function_as_table(arg_list): - tables = extract_tables('SELECT * FROM foo.bar({0})'.format(arg_list)) - assert tables == (('foo', 'bar', None, True),) + tables = extract_tables("SELECT * FROM foo.bar({0})".format(arg_list)) + assert tables == (("foo", "bar", None, True),) -@pytest.mark.parametrize('arg_list', ['', 'arg1', 'arg1, arg2, arg3']) +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) def test_simple_aliased_function_as_table(arg_list): - tables = extract_tables('SELECT * FROM foo({0}) bar'.format(arg_list)) - assert tables == ((None, 'foo', 'bar', True),) + tables = extract_tables("SELECT * FROM foo({0}) bar".format(arg_list)) + assert tables == ((None, "foo", "bar", True),) def test_simple_table_and_function(): - tables = extract_tables('SELECT * FROM foo JOIN bar()') - assert set(tables) == set([(None, 'foo', None, False), - (None, 'bar', None, True)]) + tables = extract_tables("SELECT * FROM foo JOIN bar()") + assert set(tables) == set([(None, "foo", None, False), (None, "bar", None, True)]) def test_complex_table_and_function(): - tables = extract_tables('''SELECT * FROM foo.bar baz - JOIN bar.qux(x, y, z) quux''') - assert set(tables) == set([('foo', 'bar', 'baz', False), - ('bar', 'qux', 'quux', True)]) + tables = extract_tables( + """SELECT * FROM foo.bar baz + JOIN bar.qux(x, y, z) quux""" + ) + assert set(tables) == set( + [("foo", "bar", "baz", False), ("bar", "qux", "quux", True)] + ) def test_find_prev_keyword_using(): - q = 'select * from tbl1 inner join tbl2 using (col1, ' + q = "select * from tbl1 inner join tbl2 using (col1, " kw, q2 = find_prev_keyword(q) - assert kw.value == '(' and q2 == 'select * from tbl1 inner join tbl2 using (' + assert kw.value == "(" and q2 == "select * from tbl1 inner join tbl2 using (" -@pytest.mark.parametrize('sql', 
[ - 'select * from foo where bar', - 'select * from foo where bar = 1 and baz or ', - 'select * from foo where bar = 1 and baz between qux and ', -]) +@pytest.mark.parametrize( + "sql", + [ + "select * from foo where bar", + "select * from foo where bar = 1 and baz or ", + "select * from foo where bar = 1 and baz between qux and ", + ], +) def test_find_prev_keyword_where(sql): kw, stripped = find_prev_keyword(sql) - assert kw.value == 'where' and stripped == 'select * from foo where' + assert kw.value == "where" and stripped == "select * from foo where" -@pytest.mark.parametrize('sql', [ - 'create table foo (bar int, baz ', - 'select * from foo() as bar (baz ' -]) +@pytest.mark.parametrize( + "sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "] +) def test_find_prev_keyword_open_parens(sql): kw, _ = find_prev_keyword(sql) - assert kw.value == '(' + assert kw.value == "(" -@pytest.mark.parametrize('sql', [ - '', - '$$ foo $$', - "$$ 'foo' $$", - '$$ "foo" $$', - '$$ $a$ $$', - '$a$ $$ $a$', - 'foo bar $$ baz $$', -]) +@pytest.mark.parametrize( + "sql", + [ + "", + "$$ foo $$", + "$$ 'foo' $$", + '$$ "foo" $$', + "$$ $a$ $$", + "$a$ $$ $a$", + "foo bar $$ baz $$", + ], +) def test_is_open_quote__closed(sql): assert not is_open_quote(sql) -@pytest.mark.parametrize('sql', [ - '$$', - ';;;$$', - 'foo $$ bar $$; foo $$', - '$$ foo $a$', - "foo 'bar baz", - "$a$ foo ", - '$$ "foo" ', - '$$ $a$ ', - 'foo bar $$ baz', -]) +@pytest.mark.parametrize( + "sql", + [ + "$$", + ";;;$$", + "foo $$ bar $$; foo $$", + "$$ foo $a$", + "foo 'bar baz", + "$a$ foo ", + '$$ "foo" ', + "$$ $a$ ", + "foo bar $$ baz", + ], +) def test_is_open_quote__open(sql): assert is_open_quote(sql) diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py index e6a39ba4..6a916a81 100644 --- a/tests/test_completion_refresher.py +++ b/tests/test_completion_refresher.py @@ -6,6 +6,7 @@ from mock import Mock, patch @pytest.fixture def refresher(): from 
pgcli.completion_refresher import CompletionRefresher + return CompletionRefresher() @@ -17,8 +18,15 @@ def test_ctor(refresher): """ assert len(refresher.refreshers) > 0 actual_handlers = list(refresher.refreshers.keys()) - expected_handlers = ['schemata', 'tables', 'views', - 'types', 'databases', 'casing', 'functions'] + expected_handlers = [ + "schemata", + "tables", + "views", + "types", + "databases", + "casing", + "functions", + ] assert expected_handlers == actual_handlers @@ -32,14 +40,13 @@ def test_refresh_called_once(refresher): pgexecute = Mock() special = Mock() - with patch.object(refresher, '_bg_refresh') as bg_refresh: + with patch.object(refresher, "_bg_refresh") as bg_refresh: actual = refresher.refresh(pgexecute, special, callbacks) time.sleep(1) # Wait for the thread to work. assert len(actual) == 1 assert len(actual[0]) == 4 - assert actual[0][3] == 'Auto-completion refresh started in the background.' - bg_refresh.assert_called_with(pgexecute, special, callbacks, None, - None) + assert actual[0][3] == "Auto-completion refresh started in the background." + bg_refresh.assert_called_with(pgexecute, special, callbacks, None, None) def test_refresh_called_twice(refresher): @@ -62,13 +69,13 @@ def test_refresh_called_twice(refresher): time.sleep(1) # Wait for the thread to work. assert len(actual1) == 1 assert len(actual1[0]) == 4 - assert actual1[0][3] == 'Auto-completion refresh started in the background.' + assert actual1[0][3] == "Auto-completion refresh started in the background." actual2 = refresher.refresh(pgexecute, special, callbacks) time.sleep(1) # Wait for the thread to work. assert len(actual2) == 1 assert len(actual2[0]) == 4 - assert actual2[0][3] == 'Auto-completion refresh restarted.' + assert actual2[0][3] == "Auto-completion refresh restarted." 
def test_refresh_with_callbacks(refresher): @@ -82,9 +89,9 @@ def test_refresh_with_callbacks(refresher): pgexecute.extra_args = {} special = Mock() - with patch('pgcli.completion_refresher.PGExecute', pgexecute_class): + with patch("pgcli.completion_refresher.PGExecute", pgexecute_class): # Set refreshers to 0: we're not testing refresh logic here refresher.refreshers = {} refresher.refresh(pgexecute, special, callbacks) time.sleep(1) # Wait for the thread to work. - assert (callbacks[0].call_count == 1) + assert callbacks[0].call_count == 1 diff --git a/tests/test_fuzzy_completion.py b/tests/test_fuzzy_completion.py index 67f904ba..30f9de2c 100644 --- a/tests/test_fuzzy_completion.py +++ b/tests/test_fuzzy_completion.py @@ -5,6 +5,7 @@ import pytest @pytest.fixture def completer(): import pgcli.pgcompleter as pgcompleter + return pgcompleter.PGCompleter() @@ -22,8 +23,8 @@ def test_ranking_ignores_identifier_quotes(completer): """ - text = 'user' - collection = ['user_action', '"user"'] + text = "user" + collection = ["user_action", '"user"'] matches = completer.find_matches(text, collection) assert len(matches) == 2 @@ -42,18 +43,17 @@ def test_ranking_based_on_shortest_match(completer): """ - text = 'user' - collection = ['api_user', 'user_group'] + text = "user" + collection = ["api_user", "user_group"] matches = completer.find_matches(text, collection) assert matches[1].priority > matches[0].priority -@pytest.mark.parametrize('collection', [ - ['user_action', 'user'], - ['user_group', 'user'], - ['user_group', 'user_action'], -]) +@pytest.mark.parametrize( + "collection", + [["user_action", "user"], ["user_group", "user"], ["user_group", "user_action"]], +) def test_should_break_ties_using_lexical_order(completer, collection): """Fuzzy result rank should use lexical order to break ties. 
@@ -68,7 +68,7 @@ def test_should_break_ties_using_lexical_order(completer, collection): """ - text = 'user' + text = "user" matches = completer.find_matches(text, collection) assert matches[1].priority > matches[0].priority @@ -81,8 +81,8 @@ def test_matching_should_be_case_insensitive(completer): are still matched. """ - text = 'foo' - collection = ['Foo', 'FOO', 'fOO'] + text = "foo" + collection = ["Foo", "FOO", "fOO"] matches = completer.find_matches(text, collection) assert len(matches) == 3 diff --git a/tests/test_main.py b/tests/test_main.py index e8e801d5..c55944b6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,24 +5,27 @@ import platform import mock import pytest + try: import setproctitle except ImportError: setproctitle = None from pgcli.main import ( - obfuscate_process_password, format_output, PGCli, OutputSettings, COLOR_CODE_REGEX + obfuscate_process_password, + format_output, + PGCli, + OutputSettings, + COLOR_CODE_REGEX, ) from pgcli.pgexecute import PGExecute -from pgspecial.main import (PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS) +from pgspecial.main import PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS from utils import dbtest, run from collections import namedtuple -@pytest.mark.skipif(platform.system() == 'Windows', - reason='Not applicable in windows') -@pytest.mark.skipif(not setproctitle, - reason='setproctitle not available') +@pytest.mark.skipif(platform.system() == "Windows", reason="Not applicable in windows") +@pytest.mark.skipif(not setproctitle, reason="setproctitle not available") def test_obfuscate_process_password(): original_title = setproctitle.getproctitle() @@ -54,24 +57,25 @@ def test_obfuscate_process_password(): def test_format_output(): - settings = OutputSettings(table_format='psql', dcmlfmt='d', floatfmt='g') - results = format_output('Title', [('abc', 'def')], ['head1', 'head2'], - 'test status', settings) + settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g") + results = format_output( 
+ "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + ) expected = [ - 'Title', - '+---------+---------+', - '| head1 | head2 |', - '|---------+---------|', - '| abc | def |', - '+---------+---------+', - 'test status' + "Title", + "+---------+---------+", + "| head1 | head2 |", + "|---------+---------|", + "| abc | def |", + "+---------+---------+", + "test status", ] assert list(results) == expected @dbtest def test_format_array_output(executor): - statement = u""" + statement = """ SELECT array[1, 2, 3]::bigint[] as bigint_array, '{{1,2},{3,4}}'::numeric[] as nested_numeric_array, @@ -81,20 +85,20 @@ def test_format_array_output(executor): """ results = run(executor, statement) expected = [ - '+----------------+------------------------+--------------+', - '| bigint_array | nested_numeric_array | 配列 |', - '|----------------+------------------------+--------------|', - '| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |', - '| {} | | {} |', - '+----------------+------------------------+--------------+', - 'SELECT 2' + "+----------------+------------------------+--------------+", + "| bigint_array | nested_numeric_array | 配列 |", + "|----------------+------------------------+--------------|", + "| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |", + "| {} | | {} |", + "+----------------+------------------------+--------------+", + "SELECT 2", ] assert list(results) == expected @dbtest def test_format_array_output_expanded(executor): - statement = u""" + statement = """ SELECT array[1, 2, 3]::bigint[] as bigint_array, '{{1,2},{3,4}}'::numeric[] as nested_numeric_array, @@ -104,110 +108,111 @@ def test_format_array_output_expanded(executor): """ results = run(executor, statement, expanded=True) expected = [ - '-[ RECORD 1 ]-------------------------', - 'bigint_array | {1,2,3}', - 'nested_numeric_array | {{1,2},{3,4}}', - '配列 | {å,魚,текст}', - '-[ RECORD 2 ]-------------------------', - 'bigint_array | {}', - 'nested_numeric_array | ', - '配列 | {}', - 'SELECT 2' + 
"-[ RECORD 1 ]-------------------------", + "bigint_array | {1,2,3}", + "nested_numeric_array | {{1,2},{3,4}}", + "配列 | {å,魚,текст}", + "-[ RECORD 2 ]-------------------------", + "bigint_array | {}", + "nested_numeric_array | ", + "配列 | {}", + "SELECT 2", ] - assert '\n'.join(results) == '\n'.join(expected) + assert "\n".join(results) == "\n".join(expected) def test_format_output_auto_expand(): settings = OutputSettings( - table_format='psql', dcmlfmt='d', floatfmt='g', max_width=100) - table_results = format_output('Title', [('abc', 'def')], - ['head1', 'head2'], 'test status', settings) + table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100 + ) + table_results = format_output( + "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + ) table = [ - 'Title', - '+---------+---------+', - '| head1 | head2 |', - '|---------+---------|', - '| abc | def |', - '+---------+---------+', - 'test status' + "Title", + "+---------+---------+", + "| head1 | head2 |", + "|---------+---------|", + "| abc | def |", + "+---------+---------+", + "test status", ] assert list(table_results) == table expanded_results = format_output( - 'Title', - [('abc', 'def')], - ['head1', 'head2'], - 'test status', - settings._replace(max_width=1) + "Title", + [("abc", "def")], + ["head1", "head2"], + "test status", + settings._replace(max_width=1), ) expanded = [ - 'Title', - '-[ RECORD 1 ]-------------------------', - 'head1 | abc', - 'head2 | def', - 'test status' + "Title", + "-[ RECORD 1 ]-------------------------", + "head1 | abc", + "head2 | def", + "test status", ] - assert '\n'.join(expanded_results) == '\n'.join(expanded) + assert "\n".join(expanded_results) == "\n".join(expanded) -termsize = namedtuple('termsize', ['rows', 'columns']) -test_line = '-' * 10 +termsize = namedtuple("termsize", ["rows", "columns"]) +test_line = "-" * 10 test_data = [ - (10, 10, '\n'.join([test_line] * 7)), - (10, 10, '\n'.join([test_line] * 6)), - (10, 10, '\n'.join([test_line] * 
5)), - (10, 10, '-' * 11), - (10, 10, '-' * 10), - (10, 10, '-' * 9), + (10, 10, "\n".join([test_line] * 7)), + (10, 10, "\n".join([test_line] * 6)), + (10, 10, "\n".join([test_line] * 5)), + (10, 10, "-" * 11), + (10, 10, "-" * 10), + (10, 10, "-" * 9), ] # 4 lines are reserved at the bottom of the terminal for pgcli's prompt -use_pager_when_on = [True, - True, - False, - True, - False, - False] +use_pager_when_on = [True, True, False, True, False, False] # Can be replaced with pytest.param once we can upgrade pytest after Python 3.4 goes EOL -test_ids = ["Output longer than terminal height", - "Output equal to terminal height", - "Output shorter than terminal height", - "Output longer than terminal width", - "Output equal to terminal width", - "Output shorter than terminal width"] +test_ids = [ + "Output longer than terminal height", + "Output equal to terminal height", + "Output shorter than terminal height", + "Output longer than terminal width", + "Output equal to terminal width", + "Output shorter than terminal width", +] @pytest.fixture def pset_pager_mocks(): cli = PGCli() cli.watch_command = None - with mock.patch('pgcli.main.click.echo') as mock_echo, \ - mock.patch('pgcli.main.click.echo_via_pager') as mock_echo_via_pager, \ - mock.patch.object(cli, 'prompt_app') as mock_app: + with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch( + "pgcli.main.click.echo_via_pager" + ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app: yield cli, mock_echo, mock_echo_via_pager, mock_app -@pytest.mark.parametrize('term_height,term_width,text', test_data, ids=test_ids) +@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks): cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks mock_cli.output.get_size.return_value = termsize( - rows=term_height, columns=term_width) + rows=term_height, columns=term_width + ) - with 
mock.patch.object(cli.pgspecial, 'pager_config', PAGER_OFF): + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF): cli.echo_via_pager(text) mock_echo.assert_called() mock_echo_via_pager.assert_not_called() -@pytest.mark.parametrize('term_height,term_width,text', test_data, ids=test_ids) +@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks): cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks mock_cli.output.get_size.return_value = termsize( - rows=term_height, columns=term_width) + rows=term_height, columns=term_width + ) - with mock.patch.object(cli.pgspecial, 'pager_config', PAGER_ALWAYS): + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS): cli.echo_via_pager(text) mock_echo.assert_not_called() @@ -217,13 +222,16 @@ def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks): pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)] -@pytest.mark.parametrize('term_height,term_width,text,use_pager', pager_on_test_data, ids=test_ids) +@pytest.mark.parametrize( + "term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids +) def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks): cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks mock_cli.output.get_size.return_value = termsize( - rows=term_height, columns=term_width) + rows=term_height, columns=term_width + ) - with mock.patch.object(cli.pgspecial, 'pager_config', PAGER_LONG_OUTPUT): + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT): cli.echo_via_pager(text) if use_pager: @@ -234,24 +242,28 @@ def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mock mock_echo.assert_called() -@pytest.mark.parametrize('text,expected_length', [ - (u"22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 
.........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s", 78), - (u"=\u001b[m=", 2), - (u"-\u001b]23\u0007-", 2), -]) +@pytest.mark.parametrize( + "text,expected_length", + [ + ( + "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s", + 78, + ), + ("=\u001b[m=", 2), + ("-\u001b]23\u0007-", 2), + ], +) def test_color_pattern(text, expected_length, pset_pager_mocks): cli = pset_pager_mocks[0] - assert len(COLOR_CODE_REGEX.sub('', text)) == expected_length + assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length + @dbtest def test_i_works(tmpdir, executor): sqlfile = tmpdir.join("test.sql") sqlfile.write("SELECT NOW()") rcfile = str(tmpdir.join("rcfile")) - cli = PGCli( - pgexecute=executor, - pgclirc_file=rcfile, - ) + cli = PGCli(pgexecute=executor, pgclirc_file=rcfile) statement = r"\i {0}".format(sqlfile) run(executor, statement, pgspecial=cli.pgspecial) @@ -264,58 +276,62 @@ def test_missing_rc_dir(tmpdir): def test_quoted_db_uri(tmpdir): - with mock.patch.object(PGCli, 'connect') as mock_connect: + with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) - cli.connect_uri('postgres://bar%5E:%5Dfoo@baz.com/testdb%5B') - mock_connect.assert_called_with(database='testdb[', - host='baz.com', - user='bar^', - passwd=']foo') + cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B") + mock_connect.assert_called_with( + database="testdb[", host="baz.com", user="bar^", passwd="]foo" + ) def test_ssl_db_uri(tmpdir): - with mock.patch.object(PGCli, 'connect') as mock_connect: + with mock.patch.object(PGCli, "connect") as mock_connect: cli = 
PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) cli.connect_uri( - 'postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?' - 'sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem') - mock_connect.assert_called_with(database='testdb[', - host='baz.com', - user='bar^', - passwd=']foo', - sslmode='verify-full', - sslcert='my.pem', - sslkey='my-key.pem', - sslrootcert='ca.pem') + "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?" + "sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem" + ) + mock_connect.assert_called_with( + database="testdb[", + host="baz.com", + user="bar^", + passwd="]foo", + sslmode="verify-full", + sslcert="my.pem", + sslkey="my-key.pem", + sslrootcert="ca.pem", + ) def test_port_db_uri(tmpdir): - with mock.patch.object(PGCli, 'connect') as mock_connect: + with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) - cli.connect_uri('postgres://bar:foo@baz.com:2543/testdb') - mock_connect.assert_called_with(database='testdb', - host='baz.com', - user='bar', - passwd='foo', - port='2543') + cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb") + mock_connect.assert_called_with( + database="testdb", host="baz.com", user="bar", passwd="foo", port="2543" + ) def test_multihost_db_uri(tmpdir): - with mock.patch.object(PGCli, 'connect') as mock_connect: + with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) cli.connect_uri( - 'postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb') - mock_connect.assert_called_with(database='testdb', - host='baz1.com,baz2.com,baz3.com', - user='bar', - passwd='foo', - port='2543,2543,2543') + "postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb" + ) + mock_connect.assert_called_with( + database="testdb", + host="baz1.com,baz2.com,baz3.com", + user="bar", + passwd="foo", + port="2543,2543,2543", + ) def test_application_name_db_uri(tmpdir): - with 
mock.patch.object(PGExecute, '__init__') as mock_pgexecute: + with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: mock_pgexecute.return_value = None cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) - cli.connect_uri('postgres://bar@baz.com/?application_name=cow') - mock_pgexecute.assert_called_with('bar', 'bar', '', 'baz.com', '', '', - application_name='cow') + cli.connect_uri("postgres://bar@baz.com/?application_name=cow") + mock_pgexecute.assert_called_with( + "bar", "bar", "", "baz.com", "", "", application_name="cow" + ) diff --git a/tests/test_naive_completion.py b/tests/test_naive_completion.py index 6560f82e..cf2824e4 100644 --- a/tests/test_naive_completion.py +++ b/tests/test_naive_completion.py @@ -8,77 +8,97 @@ from utils import completions_to_set @pytest.fixture def completer(): import pgcli.pgcompleter as pgcompleter + return pgcompleter.PGCompleter(smart_completion=False) @pytest.fixture def complete_event(): from mock import Mock + return Mock() + def test_empty_string_completion(completer, complete_event): - text = '' + text = "" position = 0 - result = completions_to_set(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == completions_to_set( - map(Completion, completer.all_completions)) + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set(map(Completion, completer.all_completions)) def test_select_keyword_completion(completer, complete_event): - text = 'SEL' - position = len('SEL') - result = completions_to_set(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == completions_to_set( - [Completion(text='SELECT', start_position=-3)]) + text = "SEL" + position = len("SEL") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) 
+ ) + assert result == completions_to_set([Completion(text="SELECT", start_position=-3)]) def test_function_name_completion(completer, complete_event): - text = 'SELECT MA' - position = len('SELECT MA') - result = completions_to_set(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == completions_to_set([ - Completion(text='MATERIALIZED VIEW', start_position=-2), - Completion(text='MAX', start_position=-2), - Completion(text='MAXEXTENTS', start_position=-2)]) + text = "SELECT MA" + position = len("SELECT MA") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set( + [ + Completion(text="MATERIALIZED VIEW", start_position=-2), + Completion(text="MAX", start_position=-2), + Completion(text="MAXEXTENTS", start_position=-2), + ] + ) def test_column_name_completion(completer, complete_event): - text = 'SELECT FROM users' - position = len('SELECT ') - result = completions_to_set(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == completions_to_set( - map(Completion, completer.all_completions)) + text = "SELECT FROM users" + position = len("SELECT ") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set(map(Completion, completer.all_completions)) def test_alter_well_known_keywords_completion(completer, complete_event): - text = 'ALTER ' + text = "ALTER " position = len(text) - result = completions_to_set(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event, - smart_completion=True)) - assert result > completions_to_set([ - Completion(text="DATABASE", display_meta='keyword'), - Completion(text="TABLE", display_meta='keyword'), - Completion(text="SYSTEM", display_meta='keyword'), - ]) - 
assert completions_to_set( - [Completion(text="CREATE", display_meta="keyword")]) not in result + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), + complete_event, + smart_completion=True, + ) + ) + assert result > completions_to_set( + [ + Completion(text="DATABASE", display_meta="keyword"), + Completion(text="TABLE", display_meta="keyword"), + Completion(text="SYSTEM", display_meta="keyword"), + ] + ) + assert ( + completions_to_set([Completion(text="CREATE", display_meta="keyword")]) + not in result + ) def test_special_name_completion(completer, complete_event): - text = '\\' - position = len('\\') - result = completions_to_set(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) + text = "\\" + position = len("\\") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) # Special commands will NOT be suggested during naive completion mode. 
assert result == completions_to_set([]) diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index f49cb6bc..55d7f548 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -14,54 +14,77 @@ from pgcli.packages.parseutils.meta import FunctionMetadata def function_meta_data( - func_name, schema_name='public', arg_names=None, arg_types=None, - arg_modes=None, return_type=None, is_aggregate=False, is_window=False, - is_set_returning=False, is_extension=False, arg_defaults=None + func_name, + schema_name="public", + arg_names=None, + arg_types=None, + arg_modes=None, + return_type=None, + is_aggregate=False, + is_window=False, + is_set_returning=False, + is_extension=False, + arg_defaults=None, ): return FunctionMetadata( - schema_name, func_name, arg_names, arg_types, arg_modes, return_type, - is_aggregate, is_window, is_set_returning, is_extension, arg_defaults + schema_name, + func_name, + arg_names, + arg_types, + arg_modes, + return_type, + is_aggregate, + is_window, + is_set_returning, + is_extension, + arg_defaults, ) + @dbtest def test_conn(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc')''') - assert run(executor, '''select * from test''', join=True) == dedent("""\ + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") + assert run(executor, """select * from test""", join=True) == dedent( + """\ +-----+ | a | |-----| | abc | +-----+ - SELECT 1""") + SELECT 1""" + ) @dbtest def test_copy(executor): executor_copy = executor.copy() - run(executor_copy, '''create table test(a text)''') - run(executor_copy, '''insert into test values('abc')''') - assert run(executor_copy, '''select * from test''', join=True) == dedent("""\ + run(executor_copy, """create table test(a text)""") + run(executor_copy, """insert into test values('abc')""") + assert run(executor_copy, """select * from test""", join=True) == dedent( + """\ +-----+ | a | 
|-----| | abc | +-----+ - SELECT 1""") - + SELECT 1""" + ) @dbtest def test_bools_are_treated_as_strings(executor): - run(executor, '''create table test(a boolean)''') - run(executor, '''insert into test values(True)''') - assert run(executor, '''select * from test''', join=True) == dedent("""\ + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(True)""") + assert run(executor, """select * from test""", join=True) == dedent( + """\ +------+ | a | |------| | True | +------+ - SELECT 1""") + SELECT 1""" + ) @dbtest @@ -76,27 +99,31 @@ def test_schemata_table_views_and_columns_query(executor): # schemata # don't enforce all members of the schemas since they may include postgres # temporary schemas - assert set(executor.schemata()) >= set([ - 'public', 'pg_catalog', 'information_schema', 'schema1', 'schema2']) - assert executor.search_path() == ['pg_catalog', 'public'] + assert set(executor.schemata()) >= set( + ["public", "pg_catalog", "information_schema", "schema1", "schema2"] + ) + assert executor.search_path() == ["pg_catalog", "public"] # tables - assert set(executor.tables()) >= set([ - ('public', 'a'), ('public', 'b'), ('schema1', 'c')]) + assert set(executor.tables()) >= set( + [("public", "a"), ("public", "b"), ("schema1", "c")] + ) - assert set(executor.table_columns()) >= set([ - ('public', 'a', 'x', 'text', False, None), - ('public', 'a', 'y', 'text', False, None), - ('public', 'b', 'z', 'text', False, None), - ('schema1', 'c', 'w', 'text', True, "'meow'::text"), - ]) + assert set(executor.table_columns()) >= set( + [ + ("public", "a", "x", "text", False, None), + ("public", "a", "y", "text", False, None), + ("public", "b", "z", "text", False, None), + ("schema1", "c", "w", "text", True, "'meow'::text"), + ] + ) # views - assert set(executor.views()) >= set([ - ('public', 'd')]) + assert set(executor.views()) >= set([("public", "d")]) - assert set(executor.view_columns()) >= set([ - ('public', 'd', 'e', 'integer', 
False, None)]) + assert set(executor.view_columns()) >= set( + [("public", "d", "e", "integer", False, None)] + ) @dbtest @@ -104,82 +131,95 @@ def test_foreign_key_query(executor): run(executor, "create schema schema1") run(executor, "create schema schema2") run(executor, "create table schema1.parent(parentid int PRIMARY KEY)") - run(executor, "create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)") + run( + executor, + "create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)", + ) - assert set(executor.foreignkeys()) >= set([ - ('schema1', 'parent', 'parentid', 'schema2', 'child', 'motherid')]) + assert set(executor.foreignkeys()) >= set( + [("schema1", "parent", "parentid", "schema2", "child", "motherid")] + ) @dbtest def test_functions_query(executor): - run(executor, '''create function func1() returns int - language sql as $$select 1$$''') - run(executor, 'create schema schema1') - run(executor, '''create function schema1.func2() returns int - language sql as $$select 2$$''') + run( + executor, + """create function func1() returns int + language sql as $$select 1$$""", + ) + run(executor, "create schema schema1") + run( + executor, + """create function schema1.func2() returns int + language sql as $$select 2$$""", + ) - run(executor, '''create function func3() + run( + executor, + """create function func3() returns table(x int, y int) language sql - as $$select 1, 2 from generate_series(1,5)$$;''') + as $$select 1, 2 from generate_series(1,5)$$;""", + ) - run(executor, '''create function func4(x int) returns setof int language sql - as $$select generate_series(1,5)$$;''') + run( + executor, + """create function func4(x int) returns setof int language sql + as $$select generate_series(1,5)$$;""", + ) funcs = set(executor.functions()) - assert funcs >= set([ - function_meta_data( - func_name='func1', - return_type='integer' - ), - function_meta_data( - func_name='func3', - arg_names=['x', 
'y'], - arg_types=['integer', 'integer'], - arg_modes=['t', 't'], - return_type='record', - is_set_returning=True - ), - function_meta_data( - schema_name='public', - func_name='func4', - arg_names=('x',), - arg_types=('integer',), - return_type='integer', - is_set_returning=True - ), - function_meta_data( - schema_name='schema1', - func_name='func2', - return_type='integer' - ), - ]) + assert funcs >= set( + [ + function_meta_data(func_name="func1", return_type="integer"), + function_meta_data( + func_name="func3", + arg_names=["x", "y"], + arg_types=["integer", "integer"], + arg_modes=["t", "t"], + return_type="record", + is_set_returning=True, + ), + function_meta_data( + schema_name="public", + func_name="func4", + arg_names=("x",), + arg_types=("integer",), + return_type="integer", + is_set_returning=True, + ), + function_meta_data( + schema_name="schema1", func_name="func2", return_type="integer" + ), + ] + ) @dbtest def test_datatypes_query(executor): - run(executor, 'create type foo AS (a int, b text)') + run(executor, "create type foo AS (a int, b text)") types = list(executor.datatypes()) - assert types == [('public', 'foo')] + assert types == [("public", "foo")] @dbtest def test_database_list(executor): databases = executor.databases() - assert '_test_db' in databases + assert "_test_db" in databases @dbtest def test_invalid_syntax(executor, exception_formatter): - result = run(executor, 'invalid syntax!', - exception_formatter=exception_formatter) + result = run(executor, "invalid syntax!", exception_formatter=exception_formatter) assert 'syntax error at or near "invalid"' in result[0] @dbtest def test_invalid_column_name(executor, exception_formatter): - result = run(executor, 'select invalid command', - exception_formatter=exception_formatter) + result = run( + executor, "select invalid command", exception_formatter=exception_formatter + ) assert 'column "invalid" does not exist' in result[0] @@ -194,14 +234,15 @@ def 
test_unicode_support_in_output(executor, expanded): run(executor, u"insert into unicodechars (t) values ('é')") # See issue #24, this raises an exception without proper handling - assert u'é' in run(executor, "select * from unicodechars", - join=True, expanded=expanded) + assert u"é" in run( + executor, "select * from unicodechars", join=True, expanded=expanded + ) @dbtest def test_not_is_special(executor, pgspecial): """is_special is set to false for database queries.""" - query = 'select 1' + query = "select 1" result = list(executor.run(query, pgspecial=pgspecial)) success, is_special = result[0][5:] assert success == True @@ -213,22 +254,22 @@ def test_execute_from_file_no_arg(executor, pgspecial): """\i without a filename returns an error.""" result = list(executor.run("\i", pgspecial=pgspecial)) status, sql, success, is_special = result[0][3:] - assert 'missing required argument' in status + assert "missing required argument" in status assert success == False assert is_special == True @dbtest -@patch('pgcli.main.os') +@patch("pgcli.main.os") def test_execute_from_file_io_error(os, executor, pgspecial): """\i with an io_error returns an error.""" # Inject an IOError. - os.path.expanduser.side_effect = IOError('test') + os.path.expanduser.side_effect = IOError("test") # Check the result. 
result = list(executor.run("\i test", pgspecial=pgspecial)) status, sql, success, is_special = result[0][3:] - assert status == 'test' + assert status == "test" assert success == False assert is_special == True @@ -252,9 +293,12 @@ def test_multiple_queries_with_special_command_same_line(executor, pgspecial): @dbtest def test_multiple_queries_same_line_syntaxerror(executor, exception_formatter): - result = run(executor, u"select 'fooé'; invalid syntax é", - exception_formatter=exception_formatter) - assert u'fooé' in result[3] + result = run( + executor, + u"select 'fooé'; invalid syntax é", + exception_formatter=exception_formatter, + ) + assert u"fooé" in result[3] assert 'syntax error at or near "invalid"' in result[-1] @@ -265,23 +309,22 @@ def pgspecial(): @dbtest def test_special_command_help(executor, pgspecial): - result = run(executor, '\\?', pgspecial=pgspecial)[1].split('|') - assert u'Command' in result[1] - assert u'Description' in result[2] + result = run(executor, "\\?", pgspecial=pgspecial)[1].split("|") + assert u"Command" in result[1] + assert u"Description" in result[2] @dbtest def test_bytea_field_support_in_output(executor): run(executor, "create table binarydata(c bytea)") - run(executor, - "insert into binarydata (c) values (decode('DEADBEEF', 'hex'))") + run(executor, "insert into binarydata (c) values (decode('DEADBEEF', 'hex'))") - assert u'\\xdeadbeef' in run(executor, "select * from binarydata", join=True) + assert u"\\xdeadbeef" in run(executor, "select * from binarydata", join=True) @dbtest def test_unicode_support_in_unknown_type(executor): - assert u'日本語' in run(executor, u"SELECT '日本語' AS japanese;", join=True) + assert u"日本語" in run(executor, u"SELECT '日本語' AS japanese;", join=True) @dbtest @@ -289,15 +332,16 @@ def test_unicode_support_in_enum_type(executor): run(executor, u"CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy', '日本語')") run(executor, u"CREATE TABLE person (name TEXT, current_mood mood)") run(executor, u"INSERT INTO 
person VALUES ('Moe', '日本語')") - assert u'日本語' in run(executor, u"SELECT * FROM person", join=True) + assert u"日本語" in run(executor, u"SELECT * FROM person", join=True) @requires_json def test_json_renders_without_u_prefix(executor, expanded): run(executor, u"create table jsontest(d json)") run(executor, u"""insert into jsontest (d) values ('{"name": "Éowyn"}')""") - result = run(executor, u"SELECT d FROM jsontest LIMIT 1", - join=True, expanded=expanded) + result = run( + executor, u"SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded + ) assert u'{"name": "Éowyn"}' in result @@ -306,8 +350,9 @@ def test_json_renders_without_u_prefix(executor, expanded): def test_jsonb_renders_without_u_prefix(executor, expanded): run(executor, "create table jsonbtest(d jsonb)") run(executor, u"""insert into jsonbtest (d) values ('{"name": "Éowyn"}')""") - result = run(executor, "SELECT d FROM jsonbtest LIMIT 1", - join=True, expanded=expanded) + result = run( + executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded + ) assert u'{"name": "Éowyn"}' in result @@ -315,45 +360,65 @@ def test_jsonb_renders_without_u_prefix(executor, expanded): @dbtest def test_date_time_types(executor): run(executor, "SET TIME ZONE UTC") - assert run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3] \ - == "| 00:00:00 |" - assert run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split("\n")[3] \ + assert ( + run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3] + == "| 00:00:00 |" + ) + assert ( + run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split( + "\n" + )[3] == "| 00:00:00+14:59 |" - assert run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[3] \ - == "| 4713-01-01 BC |" - assert run(executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True).split("\n")[3] \ - == "| 4713-01-01 00:00:00 BC |" - assert run(executor, "SELECT 
(CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))", join=True).split("\n")[3] \ - == "| 4713-01-01 00:00:00+00 BC |" - assert run(executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True).split("\n")[3] \ - == "| -123456789 days, 12:23:56 |" + ) + assert ( + run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[ + 3 + ] + == "| 4713-01-01 BC |" + ) + assert ( + run( + executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True + ).split("\n")[3] + == "| 4713-01-01 00:00:00 BC |" + ) + assert ( + run( + executor, + "SELECT (CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))", + join=True, + ).split("\n")[3] + == "| 4713-01-01 00:00:00+00 BC |" + ) + assert ( + run( + executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True + ).split("\n")[3] + == "| -123456789 days, 12:23:56 |" + ) @dbtest -@pytest.mark.parametrize('value', ['10000000', '10000000.0', '10000000000000']) +@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"]) def test_large_numbers_render_directly(executor, value): run(executor, "create table numbertest(a numeric)") - run(executor, - "insert into numbertest (a) values ({0})".format(value)) + run(executor, "insert into numbertest (a) values ({0})".format(value)) assert value in run(executor, "select * from numbertest", join=True) @dbtest -@pytest.mark.parametrize('command', ['di', 'dv', 'ds', 'df', 'dT']) -@pytest.mark.parametrize('verbose', ['', '+']) -@pytest.mark.parametrize('pattern', ['', 'x', '*.*', 'x.y', 'x.*', '*.y']) +@pytest.mark.parametrize("command", ["di", "dv", "ds", "df", "dT"]) +@pytest.mark.parametrize("verbose", ["", "+"]) +@pytest.mark.parametrize("pattern", ["", "x", "*.*", "x.y", "x.*", "*.y"]) def test_describe_special(executor, command, verbose, pattern, pgspecial): # We don't have any tests for the output of any of the special commands, # but we can at least make sure they run without error - sql = r'\{command}{verbose} 
{pattern}'.format(**locals()) + sql = r"\{command}{verbose} {pattern}".format(**locals()) list(executor.run(sql, pgspecial=pgspecial)) @dbtest -@pytest.mark.parametrize('sql', [ - 'invalid sql', - 'SELECT 1; select error;', -]) +@pytest.mark.parametrize("sql", ["invalid sql", "SELECT 1; select error;"]) def test_raises_with_no_formatter(executor, sql): with pytest.raises(psycopg2.ProgrammingError): list(executor.run(sql)) @@ -361,19 +426,24 @@ def test_raises_with_no_formatter(executor, sql): @dbtest def test_on_error_resume(executor, exception_formatter): - sql = 'select 1; error; select 1;' - result = list(executor.run(sql, on_error_resume=True, - exception_formatter=exception_formatter)) + sql = "select 1; error; select 1;" + result = list( + executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter) + ) assert len(result) == 3 @dbtest def test_on_error_stop(executor, exception_formatter): - sql = 'select 1; error; select 1;' - result = list(executor.run(sql, on_error_resume=False, - exception_formatter=exception_formatter)) + sql = "select 1; error; select 1;" + result = list( + executor.run( + sql, on_error_resume=False, exception_formatter=exception_formatter + ) + ) assert len(result) == 2 + # @dbtest # def test_unicode_notices(executor): # sql = "DO language plpgsql $$ BEGIN RAISE NOTICE '有人更改'; END $$;" @@ -384,50 +454,55 @@ def test_on_error_stop(executor, exception_formatter): @dbtest def test_nonexistent_function_definition(executor): with pytest.raises(RuntimeError): - result = executor.view_definition('there_is_no_such_function') + result = executor.view_definition("there_is_no_such_function") @dbtest def test_function_definition(executor): - run(executor, ''' + run( + executor, + """ CREATE OR REPLACE FUNCTION public.the_number_three() RETURNS int LANGUAGE sql AS $function$ select 3; $function$ - ''') - result = executor.function_definition('the_number_three') + """, + ) + result = 
executor.function_definition("the_number_three") @dbtest def test_view_definition(executor): - run(executor, 'create table tbl1 (a text, b numeric)') - run(executor, 'create view vw1 AS SELECT * FROM tbl1') - run(executor, 'create materialized view mvw1 AS SELECT * FROM tbl1') - result = executor.view_definition('vw1') - assert 'FROM tbl1' in result + run(executor, "create table tbl1 (a text, b numeric)") + run(executor, "create view vw1 AS SELECT * FROM tbl1") + run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1") + result = executor.view_definition("vw1") + assert "FROM tbl1" in result # import pytest; pytest.set_trace() - result = executor.view_definition('mvw1') - assert 'MATERIALIZED VIEW' in result + result = executor.view_definition("mvw1") + assert "MATERIALIZED VIEW" in result @dbtest def test_nonexistent_view_definition(executor): with pytest.raises(RuntimeError): - result = executor.view_definition('there_is_no_such_view') + result = executor.view_definition("there_is_no_such_view") with pytest.raises(RuntimeError): - result = executor.view_definition('mvw1') + result = executor.view_definition("mvw1") @dbtest def test_short_host(executor): - with patch.object(executor, 'host', 'localhost'): - assert executor.short_host == 'localhost' - with patch.object(executor, 'host', 'localhost.example.org'): - assert executor.short_host == 'localhost' - with patch.object(executor, 'host', 'localhost1.example.org,localhost2.example.org'): - assert executor.short_host == 'localhost1' + with patch.object(executor, "host", "localhost"): + assert executor.short_host == "localhost" + with patch.object(executor, "host", "localhost.example.org"): + assert executor.short_host == "localhost" + with patch.object( + executor, "host", "localhost1.example.org,localhost2.example.org" + ): + assert executor.short_host == "localhost1" class BrokenConnection(object): @@ -441,8 +516,15 @@ class BrokenConnection(object): def 
test_exit_without_active_connection(executor): quit_handler = MagicMock() pgspecial = PGSpecial() - pgspecial.register(quit_handler, '\\q', '\\q', 'Quit pgcli.', - arg_type=NO_QUERY, case_sensitive=True, aliases=(':q',)) + pgspecial.register( + quit_handler, + "\\q", + "\\q", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=True, + aliases=(":q",), + ) with patch.object(executor, "conn", BrokenConnection()): # we should be able to quit the app, even without active connection diff --git a/tests/test_pgspecial.py b/tests/test_pgspecial.py index b08fa029..eaeaf12f 100644 --- a/tests/test_pgspecial.py +++ b/tests/test_pgspecial.py @@ -1,92 +1,78 @@ import pytest from pgcli.packages.sqlcompletion import ( - suggest_type, Special, Database, Schema, Table, View, Function, Datatype) + suggest_type, + Special, + Database, + Schema, + Table, + View, + Function, + Datatype, +) def test_slash_suggests_special(): - suggestions = suggest_type('\\', '\\') - assert set(suggestions) == set( - [Special()]) + suggestions = suggest_type("\\", "\\") + assert set(suggestions) == set([Special()]) def test_slash_d_suggests_special(): - suggestions = suggest_type('\\d', '\\d') - assert set(suggestions) == set( - [Special()]) + suggestions = suggest_type("\\d", "\\d") + assert set(suggestions) == set([Special()]) def test_dn_suggests_schemata(): - suggestions = suggest_type('\\dn ', '\\dn ') + suggestions = suggest_type("\\dn ", "\\dn ") assert suggestions == (Schema(),) - suggestions = suggest_type('\\dn xxx', '\\dn xxx') + suggestions = suggest_type("\\dn xxx", "\\dn xxx") assert suggestions == (Schema(),) def test_d_suggests_tables_views_and_schemas(): - suggestions = suggest_type('\d ', '\d ') - assert set(suggestions) == set([ - Schema(), - Table(schema=None), - View(schema=None), - ]) + suggestions = suggest_type("\d ", "\d ") + assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)]) - suggestions = suggest_type('\d xxx', '\d xxx') - assert 
set(suggestions) == set([ - Schema(), - Table(schema=None), - View(schema=None), - ]) + suggestions = suggest_type("\d xxx", "\d xxx") + assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)]) def test_d_dot_suggests_schema_qualified_tables_or_views(): - suggestions = suggest_type('\d myschema.', '\d myschema.') - assert set(suggestions) == set([ - Table(schema='myschema'), - View(schema='myschema'), - ]) + suggestions = suggest_type("\d myschema.", "\d myschema.") + assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")]) - suggestions = suggest_type('\d myschema.xxx', '\d myschema.xxx') - assert set(suggestions) == set([ - Table(schema='myschema'), - View(schema='myschema'), - ]) + suggestions = suggest_type("\d myschema.xxx", "\d myschema.xxx") + assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")]) def test_df_suggests_schema_or_function(): - suggestions = suggest_type('\\df xxx', '\\df xxx') - assert set(suggestions) == set([ - Function(schema=None, usage='special'), - Schema(), - ]) + suggestions = suggest_type("\\df xxx", "\\df xxx") + assert set(suggestions) == set([Function(schema=None, usage="special"), Schema()]) - suggestions = suggest_type('\\df myschema.xxx', '\\df myschema.xxx') - assert suggestions == (Function(schema='myschema', usage='special'),) + suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx") + assert suggestions == (Function(schema="myschema", usage="special"),) def test_leading_whitespace_ok(): - cmd = '\\dn ' - whitespace = ' ' + cmd = "\\dn " + whitespace = " " suggestions = suggest_type(whitespace + cmd, whitespace + cmd) assert suggestions == suggest_type(cmd, cmd) def test_dT_suggests_schema_or_datatypes(): - text = '\\dT ' + text = "\\dT " suggestions = suggest_type(text, text) - assert set(suggestions) == set([ - Schema(), - Datatype(schema=None), - ]) + assert set(suggestions) == set([Schema(), Datatype(schema=None)]) def 
test_schema_qualified_dT_suggests_datatypes(): - text = '\\dT foo.' + text = "\\dT foo." suggestions = suggest_type(text, text) - assert suggestions == (Datatype(schema='foo'),) + assert suggestions == (Datatype(schema="foo"),) -@pytest.mark.parametrize('command', ['\\c ', '\\connect ']) +@pytest.mark.parametrize("command", ["\\c ", "\\connect "]) def test_c_suggests_databases(command): suggestions = suggest_type(command, command) assert suggestions == (Database(),) diff --git a/tests/test_prioritization.py b/tests/test_prioritization.py index 3046456e..f5b67006 100644 --- a/tests/test_prioritization.py +++ b/tests/test_prioritization.py @@ -3,18 +3,18 @@ from pgcli.packages.prioritization import PrevalenceCounter def test_prevalence_counter(): counter = PrevalenceCounter() - sql = '''SELECT * FROM foo WHERE bar GROUP BY baz; + sql = """SELECT * FROM foo WHERE bar GROUP BY baz; select * from foo; SELECT * FROM foo WHERE bar GROUP - BY baz''' + BY baz""" counter.update(sql) - keywords = ['SELECT', 'FROM', 'GROUP BY'] + keywords = ["SELECT", "FROM", "GROUP BY"] expected = [3, 3, 2] kw_counts = [counter.keyword_count(x) for x in keywords] assert kw_counts == expected - assert counter.keyword_count('NOSUCHKEYWORD') == 0 + assert counter.keyword_count("NOSUCHKEYWORD") == 0 - names = ['foo', 'bar', 'baz'] + names = ["foo", "bar", "baz"] name_counts = [counter.name_count(x) for x in names] assert name_counts == [3, 2, 2] diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py index 7fb7482e..4986dc57 100644 --- a/tests/test_prompt_utils.py +++ b/tests/test_prompt_utils.py @@ -7,7 +7,7 @@ from pgcli.packages.prompt_utils import confirm_destructive_query def test_confirm_destructive_query_notty(): - stdin = click.get_text_stream('stdin') + stdin = click.get_text_stream("stdin") if not stdin.isatty(): - sql = 'drop database foo;' + sql = "drop database foo;" assert confirm_destructive_query(sql) is None diff --git a/tests/test_rowlimit.py 
b/tests/test_rowlimit.py index 91389a15..bc66f1e7 100644 --- a/tests/test_rowlimit.py +++ b/tests/test_rowlimit.py @@ -6,6 +6,7 @@ import pytest # We need this fixtures beacause we need PGCli object to be created # after test collection so it has config loaded from temp directory + @pytest.fixture(scope="module") def default_pgcli_obj(): return PGCli() @@ -24,9 +25,7 @@ def LIMIT(DEFAULT): @pytest.fixture(scope="module") def over_default(DEFAULT): over_default_cursor = Mock() - over_default_cursor.configure_mock( - rowcount=DEFAULT + 10 - ) + over_default_cursor.configure_mock(rowcount=DEFAULT + 10) return over_default_cursor diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index bfadd1d0..9125f124 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -1,180 +1,244 @@ from __future__ import unicode_literals, print_function import itertools -from metadata import (MetaData, alias, name_join, fk_join, join, - schema, table, function, wildcard_expansion, column, - get_result, result_set, qual, no_qual, parametrize) +from metadata import ( + MetaData, + alias, + name_join, + fk_join, + join, + schema, + table, + function, + wildcard_expansion, + column, + get_result, + result_set, + qual, + no_qual, + parametrize, +) from utils import completions_to_set metadata = { - 'tables': { - 'public': { - 'users': ['id', 'email', 'first_name', 'last_name'], - 'orders': ['id', 'ordered_date', 'status', 'datestamp'], - 'select': ['id', 'localtime', 'ABC'] + "tables": { + "public": { + "users": ["id", "email", "first_name", "last_name"], + "orders": ["id", "ordered_date", "status", "datestamp"], + "select": ["id", "localtime", "ABC"], }, - 'custom': { - 'users': ['id', 'phone_number'], - 'Users': ['userid', 'username'], - 'products': ['id', 'product_name', 'price'], - 'shipments': ['id', 'address', 'user_id'] + "custom": { + "users": ["id", 
"phone_number"], + "Users": ["userid", "username"], + "products": ["id", "product_name", "price"], + "shipments": ["id", "address", "user_id"], }, - 'Custom': { - 'projects': ['projectid', 'name'] + "Custom": {"projects": ["projectid", "name"]}, + "blog": { + "entries": ["entryid", "entrytitle", "entrytext"], + "tags": ["tagid", "name"], + "entrytags": ["entryid", "tagid"], + "entacclog": ["entryid", "username", "datestamp"], }, - 'blog': { - 'entries': ['entryid', 'entrytitle', 'entrytext'], - 'tags': ['tagid', 'name'], - 'entrytags': ['entryid', 'tagid'], - 'entacclog': ['entryid', 'username', 'datestamp'], - } }, - 'functions': { - 'public': [ - ['func1', [], [], [], '', False, False, False, False], - ['func2', [], [], [], '', False, False, False, False]], - 'custom': [ - ['func3', [], [], [], '', False, False, False, False], - ['set_returning_func', ['x'], ['integer'], ['o'], - 'integer', False, False, True, False]], - 'Custom': [ - ['func4', [], [], [], '', False, False, False, False]], - 'blog': [ - ['extract_entry_symbols', ['_entryid', 'symbol'], - ['integer', 'text'], ['i', 'o'], '', False, False, True, False], - ['enter_entry', ['_title', '_text', 'entryid'], - ['text', 'text', 'integer'], ['i', 'i', 'o'], - '', False, False, False, False]], - }, - 'datatypes': { - 'public': ['typ1', 'typ2'], - 'custom': ['typ3', 'typ4'], - }, - 'foreignkeys': { - 'custom': [ - ('public', 'users', 'id', 'custom', 'shipments', 'user_id') + "functions": { + "public": [ + ["func1", [], [], [], "", False, False, False, False], + ["func2", [], [], [], "", False, False, False, False], ], - 'blog': [ - ('blog', 'entries', 'entryid', 'blog', 'entacclog', 'entryid'), - ('blog', 'entries', 'entryid', 'blog', 'entrytags', 'entryid'), - ('blog', 'tags', 'tagid', 'blog', 'entrytags', 'tagid'), + "custom": [ + ["func3", [], [], [], "", False, False, False, False], + [ + "set_returning_func", + ["x"], + ["integer"], + ["o"], + "integer", + False, + False, + True, + False, + ], + ], + 
"Custom": [["func4", [], [], [], "", False, False, False, False]], + "blog": [ + [ + "extract_entry_symbols", + ["_entryid", "symbol"], + ["integer", "text"], + ["i", "o"], + "", + False, + False, + True, + False, + ], + [ + "enter_entry", + ["_title", "_text", "entryid"], + ["text", "text", "integer"], + ["i", "i", "o"], + "", + False, + False, + False, + False, + ], ], }, - 'defaults': { - 'public': { - ('orders', 'id'): "nextval('orders_id_seq'::regclass)", - ('orders', 'datestamp'): "now()", - ('orders', 'status'): "'PENDING'::text", + "datatypes": {"public": ["typ1", "typ2"], "custom": ["typ3", "typ4"]}, + "foreignkeys": { + "custom": [("public", "users", "id", "custom", "shipments", "user_id")], + "blog": [ + ("blog", "entries", "entryid", "blog", "entacclog", "entryid"), + ("blog", "entries", "entryid", "blog", "entrytags", "entryid"), + ("blog", "tags", "tagid", "blog", "entrytags", "tagid"), + ], + }, + "defaults": { + "public": { + ("orders", "id"): "nextval('orders_id_seq'::regclass)", + ("orders", "datestamp"): "now()", + ("orders", "status"): "'PENDING'::text", } }, } testdata = MetaData(metadata) -cased_schemas = [schema(x) for x in ('public', 'blog', 'CUSTOM', '"Custom"')] -casing = ('SELECT', 'Orders', 'User_Emails', 'CUSTOM', 'Func1', 'Entries', - 'Tags', 'EntryTags', 'EntAccLog', - 'EntryID', 'EntryTitle', 'EntryText') +cased_schemas = [schema(x) for x in ("public", "blog", "CUSTOM", '"Custom"')] +casing = ( + "SELECT", + "Orders", + "User_Emails", + "CUSTOM", + "Func1", + "Entries", + "Tags", + "EntryTags", + "EntAccLog", + "EntryID", + "EntryTitle", + "EntryText", +) completers = testdata.get_completers(casing) -@parametrize('completer', completers(filtr=True, casing=False, qualify=no_qual)) -@parametrize('table', ['users', '"users"']) -def test_suggested_column_names_from_shadowed_visible_table(completer, table) : - result = get_result(completer, 'SELECT FROM ' + table, len('SELECT ')) +@parametrize("completer", completers(filtr=True, 
casing=False, qualify=no_qual)) +@parametrize("table", ["users", '"users"']) +def test_suggested_column_names_from_shadowed_visible_table(completer, table): + result = get_result(completer, "SELECT FROM " + table, len("SELECT ")) assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords('users')) + testdata.columns_functions_and_keywords("users") + ) -@parametrize('completer', completers(filtr=True, casing=False, qualify=no_qual)) -@parametrize('text', [ - 'SELECT from custom.users', - 'WITH users as (SELECT 1 AS foo) SELECT from custom.users', -]) +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT from custom.users", + "WITH users as (SELECT 1 AS foo) SELECT from custom.users", + ], +) def test_suggested_column_names_from_qualified_shadowed_table(completer, text): - result = get_result(completer, text, position=text.find(' ') + 1) - assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords( - 'users', 'custom' - )) + result = get_result(completer, text, position=text.find(" ") + 1) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users", "custom") + ) -@parametrize('completer', completers(filtr=True, casing=False, qualify=no_qual)) -@parametrize('text', ['WITH users as (SELECT 1 AS foo) SELECT from users',]) +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize("text", ["WITH users as (SELECT 1 AS foo) SELECT from users"]) def test_suggested_column_names_from_cte(completer, text): - result = completions_to_set(get_result( - completer, text, text.find(' ') + 1)) + result = completions_to_set(get_result(completer, text, text.find(" ") + 1)) assert result == completions_to_set( - [column('foo')] + testdata.functions_and_keywords()) + [column("foo")] + testdata.functions_and_keywords() + ) -@parametrize('completer', 
completers(casing=False)) -@parametrize('text', [ - 'SELECT * FROM users JOIN custom.shipments ON ', - '''SELECT * +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users JOIN custom.shipments ON ", + """SELECT * FROM public.users - JOIN custom.shipments ON ''' -]) + JOIN custom.shipments ON """, + ], +) def test_suggested_join_conditions(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set([ - alias('users'), - alias('shipments'), - name_join('shipments.id = users.id'), - fk_join('shipments.user_id = users.id')]) + assert completions_to_set(result) == completions_to_set( + [ + alias("users"), + alias("shipments"), + name_join("shipments.id = users.id"), + fk_join("shipments.user_id = users.id"), + ] + ) -@parametrize('completer', completers(filtr=True, casing=False, aliasing=False)) -@parametrize(('query', 'tbl'), itertools.product(( - 'SELECT * FROM public.{0} RIGHT OUTER JOIN ', - '''SELECT * +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +@parametrize( + ("query", "tbl"), + itertools.product( + ( + "SELECT * FROM public.{0} RIGHT OUTER JOIN ", + """SELECT * FROM {0} - JOIN ''' -), ('users', '"users"', 'Users'))) + JOIN """, + ), + ("users", '"users"', "Users"), + ), +) def test_suggested_joins(completer, query, tbl): result = get_result(completer, query.format(tbl)) assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items() + - [join('custom.shipments ON shipments.user_id = {0}.id'.format(tbl))] + testdata.schemas_and_from_clause_items() + + [join("custom.shipments ON shipments.user_id = {0}.id".format(tbl))] ) -@parametrize('completer', completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) def test_suggested_column_names_from_schema_qualifed_table(completer): - result = get_result( - completer, 
'SELECT from custom.products', len('SELECT ') + result = get_result(completer, "SELECT from custom.products", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") ) - assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords( - 'products', 'custom' - )) -@parametrize('text', [ - 'INSERT INTO orders(', - 'INSERT INTO orders (', - 'INSERT INTO public.orders(', - 'INSERT INTO public.orders (' -]) -@parametrize('completer', completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "INSERT INTO orders(", + "INSERT INTO orders (", + "INSERT INTO public.orders(", + "INSERT INTO public.orders (", + ], +) +@parametrize("completer", completers(filtr=True, casing=False)) def test_suggested_columns_with_insert(completer, text): assert completions_to_set(get_result(completer, text)) == completions_to_set( - testdata.columns('orders')) + testdata.columns("orders") + ) -@parametrize('completer', completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) def test_suggested_column_names_in_function(completer): result = get_result( - completer, 'SELECT MAX( from custom.products', len('SELECT MAX(') + completer, "SELECT MAX( from custom.products", len("SELECT MAX(") ) assert completions_to_set(result) == completions_to_set( - testdata.columns('products', 'custom')) + testdata.columns("products", "custom") + ) -@parametrize('completer', completers(casing=False, aliasing=False)) -@parametrize('text', [ - 'SELECT * FROM Custom.', - 'SELECT * FROM custom.', - 'SELECT * FROM "custom".', -]) -@parametrize('use_leading_double_quote', [False, True]) +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + ["SELECT * FROM Custom.", "SELECT * FROM custom.", 'SELECT * FROM "custom".'], +) +@parametrize("use_leading_double_quote", [False, True]) 
def test_suggested_table_names_with_schema_dot( completer, text, use_leading_double_quote ): @@ -186,14 +250,13 @@ def test_suggested_table_names_with_schema_dot( result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.from_clause_items('custom', start_position)) + testdata.from_clause_items("custom", start_position) + ) -@parametrize('completer', completers(casing=False, aliasing=False)) -@parametrize('text', [ - 'SELECT * FROM "Custom".', -]) -@parametrize('use_leading_double_quote', [False, True]) +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", ['SELECT * FROM "Custom".']) +@parametrize("use_leading_double_quote", [False, True]) def test_suggested_table_names_with_schema_dot2( completer, text, use_leading_double_quote ): @@ -205,137 +268,153 @@ def test_suggested_table_names_with_schema_dot2( result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.from_clause_items('Custom', start_position)) - - -@parametrize('completer', completers(filtr=True, casing=False)) -def test_suggested_column_names_with_qualified_alias(completer): - result = get_result( - completer, 'SELECT p. from custom.products p', len('SELECT p.') + testdata.from_clause_items("Custom", start_position) ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggested_column_names_with_qualified_alias(completer): + result = get_result(completer, "SELECT p. 
from custom.products p", len("SELECT p.")) assert completions_to_set(result) == completions_to_set( - testdata.columns('products', 'custom')) + testdata.columns("products", "custom") + ) -@parametrize('completer', completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) def test_suggested_multiple_column_names(completer): result = get_result( - completer, 'SELECT id, from custom.products', len('SELECT id, ') + completer, "SELECT id, from custom.products", len("SELECT id, ") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") ) - assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords( - 'products', 'custom' - )) -@parametrize('completer', completers(filtr=True, casing=False)) +@parametrize("completer", completers(filtr=True, casing=False)) def test_suggested_multiple_column_names_with_alias(completer): result = get_result( - completer, - 'SELECT p.id, p. from custom.products p', - len('SELECT u.id, u.') + completer, "SELECT p.id, p. 
from custom.products p", len("SELECT u.id, u.") ) assert completions_to_set(result) == completions_to_set( - testdata.columns('products', 'custom')) + testdata.columns("products", "custom") + ) -@parametrize('completer', completers(filtr=True, casing=False)) -@parametrize('text', [ - 'SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ', - 'SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON JOIN public.orders z ON z.id > y.id' -]) +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ", + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON JOIN public.orders z ON z.id > y.id", + ], +) def test_suggestions_after_on(completer, text): - position = len('SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ') + position = len( + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON " + ) result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set([ - alias('x'), - alias('y'), - name_join('y.price = x.price'), - name_join('y.product_name = x.product_name'), - name_join('y.id = x.id')]) + assert completions_to_set(result) == completions_to_set( + [ + alias("x"), + alias("y"), + name_join("y.price = x.price"), + name_join("y.product_name = x.product_name"), + name_join("y.id = x.id"), + ] + ) -@parametrize('completer', completers()) +@parametrize("completer", completers()) def test_suggested_aliases_after_on_right_side(completer): - text = 'SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON x.id = ' + text = "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON x.id = " result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [alias('x'), alias('y')]) + assert completions_to_set(result) == 
completions_to_set([alias("x"), alias("y")]) -@parametrize('completer', completers(filtr=True, casing=False, aliasing=False)) +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) def test_table_names_after_from(completer): - text = 'SELECT * FROM ' + text = "SELECT * FROM " result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items()) + testdata.schemas_and_from_clause_items() + ) -@parametrize('completer', completers(filtr=True, casing=False)) +@parametrize("completer", completers(filtr=True, casing=False)) def test_schema_qualified_function_name(completer): - text = 'SELECT custom.func' + text = "SELECT custom.func" result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set([ - function('func3()', -len('func')), - function('set_returning_func()', -len('func'))]) + assert completions_to_set(result) == completions_to_set( + [ + function("func3()", -len("func")), + function("set_returning_func()", -len("func")), + ] + ) -@parametrize('completer', completers(filtr=True, casing=False)) -@parametrize('text', [ - 'SELECT 1::custom.', - 'CREATE TABLE foo (bar custom.', - 'CREATE FUNCTION foo (bar INT, baz custom.', - 'ALTER TABLE foo ALTER COLUMN bar TYPE custom.', -]) +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT 1::custom.", + "CREATE TABLE foo (bar custom.", + "CREATE FUNCTION foo (bar INT, baz custom.", + "ALTER TABLE foo ALTER COLUMN bar TYPE custom.", + ], +) def test_schema_qualified_type_name(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - testdata.types('custom')) + assert completions_to_set(result) == completions_to_set(testdata.types("custom")) -@parametrize('completer', completers(filtr=True, casing=False)) +@parametrize("completer", completers(filtr=True, casing=False)) def 
test_suggest_columns_from_aliased_set_returning_function(completer): result = get_result( - completer, - 'select f. from custom.set_returning_func() f', - len('select f.') + completer, "select f. from custom.set_returning_func() f", len("select f.") ) assert completions_to_set(result) == completions_to_set( - testdata.columns('set_returning_func', 'custom', 'functions')) + testdata.columns("set_returning_func", "custom", "functions") + ) -@parametrize('completer',completers(filtr=True, casing=False, qualify=no_qual)) -@parametrize('text', [ - 'SELECT * FROM custom.set_returning_func()', - 'SELECT * FROM Custom.set_returning_func()', - 'SELECT * FROM Custom.Set_Returning_Func()' -]) +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT * FROM custom.set_returning_func()", + "SELECT * FROM Custom.set_returning_func()", + "SELECT * FROM Custom.Set_Returning_Func()", + ], +) def test_wildcard_column_expansion_with_function(completer, text): - position = len('SELECT *') + position = len("SELECT *") completions = get_result(completer, text, position) - col_list = 'x' + col_list = "x" expected = [wildcard_expansion(col_list)] assert expected == completions -@parametrize('completer', completers(filtr=True, casing=False)) +@parametrize("completer", completers(filtr=True, casing=False)) def test_wildcard_column_expansion_with_alias_qualifier(completer): - text = 'SELECT p.* FROM custom.products p' - position = len('SELECT p.*') + text = "SELECT p.* FROM custom.products p" + position = len("SELECT p.*") completions = get_result(completer, text, position) - col_list = 'id, p.product_name, p.price' + col_list = "id, p.product_name, p.price" expected = [wildcard_expansion(col_list)] assert expected == completions -@parametrize('completer', completers(filtr=True, casing=False)) -@parametrize('text', [ - ''' +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + """ SELECT 
count(1) FROM users; CREATE FUNCTION foo(custom.products _products) returns custom.shipments LANGUAGE SQL @@ -345,33 +424,34 @@ def test_wildcard_column_expansion_with_alias_qualifier(completer): SELECT 2 FROM custom.users; $foo$; SELECT count(1) FROM custom.shipments; - ''', - 'INSERT INTO public.orders(*', - 'INSERT INTO public.Orders(*', - 'INSERT INTO public.orders (*', - 'INSERT INTO public.Orders (*', - 'INSERT INTO orders(*', - 'INSERT INTO Orders(*', - 'INSERT INTO orders (*', - 'INSERT INTO Orders (*', - 'INSERT INTO public.orders(*)', - 'INSERT INTO public.Orders(*)', - 'INSERT INTO public.orders (*)', - 'INSERT INTO public.Orders (*)', - 'INSERT INTO orders(*)', - 'INSERT INTO Orders(*)', - 'INSERT INTO orders (*)', - 'INSERT INTO Orders (*)' -]) + """, + "INSERT INTO public.orders(*", + "INSERT INTO public.Orders(*", + "INSERT INTO public.orders (*", + "INSERT INTO public.Orders (*", + "INSERT INTO orders(*", + "INSERT INTO Orders(*", + "INSERT INTO orders (*", + "INSERT INTO Orders (*", + "INSERT INTO public.orders(*)", + "INSERT INTO public.Orders(*)", + "INSERT INTO public.orders (*)", + "INSERT INTO public.Orders (*)", + "INSERT INTO orders(*)", + "INSERT INTO Orders(*)", + "INSERT INTO orders (*)", + "INSERT INTO Orders (*)", + ], +) def test_wildcard_column_expansion_with_insert(completer, text): - position = text.index('*') + 1 + position = text.index("*") + 1 completions = get_result(completer, text, position) - expected = [wildcard_expansion('ordered_date, status')] + expected = [wildcard_expansion("ordered_date, status")] assert expected == completions -@parametrize('completer', completers(filtr=True, casing=False)) +@parametrize("completer", completers(filtr=True, casing=False)) def test_wildcard_column_expansion_with_table_qualifier(completer): text = 'SELECT "select".* FROM public."select"' position = len('SELECT "select".*') @@ -384,20 +464,22 @@ def test_wildcard_column_expansion_with_table_qualifier(completer): assert expected == 
completions -@parametrize('completer',completers(filtr=True, casing=False, qualify=qual)) +@parametrize("completer", completers(filtr=True, casing=False, qualify=qual)) def test_wildcard_column_expansion_with_two_tables(completer): text = 'SELECT * FROM public."select" JOIN custom.users ON true' - position = len('SELECT *') + position = len("SELECT *") completions = get_result(completer, text, position) - cols = ('"select".id, "select"."localtime", "select"."ABC", ' - 'users.id, users.phone_number') + cols = ( + '"select".id, "select"."localtime", "select"."ABC", ' + "users.id, users.phone_number" + ) expected = [wildcard_expansion(cols)] assert completions == expected -@parametrize('completer', completers(filtr=True, casing=False)) +@parametrize("completer", completers(filtr=True, casing=False)) def test_wildcard_column_expansion_with_two_tables_and_parent(completer): text = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true' position = len('SELECT "select".*') @@ -410,221 +492,238 @@ def test_wildcard_column_expansion_with_two_tables_and_parent(completer): assert expected == completions -@parametrize('completer', completers(filtr=True, casing=False)) -@parametrize('text', [ - 'SELECT U. FROM custom.Users U', - 'SELECT U. FROM custom.USERS U', - 'SELECT U. FROM custom.users U', - 'SELECT U. FROM "custom".Users U', - 'SELECT U. FROM "custom".USERS U', - 'SELECT U. FROM "custom".users U' -]) +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT U. FROM custom.Users U", + "SELECT U. FROM custom.USERS U", + "SELECT U. FROM custom.users U", + 'SELECT U. FROM "custom".Users U', + 'SELECT U. FROM "custom".USERS U', + 'SELECT U. 
FROM "custom".users U', + ], +) def test_suggest_columns_from_unquoted_table(completer, text): - position = len('SELECT U.') + position = len("SELECT U.") result = get_result(completer, text, position) assert completions_to_set(result) == completions_to_set( - testdata.columns('users', 'custom')) + testdata.columns("users", "custom") + ) -@parametrize('completer', completers(filtr=True, casing=False)) -@parametrize('text', [ - 'SELECT U. FROM custom."Users" U', - 'SELECT U. FROM "custom"."Users" U' -]) +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U'] +) def test_suggest_columns_from_quoted_table(completer, text): - position = len('SELECT U.') + position = len("SELECT U.") result = get_result(completer, text, position) assert completions_to_set(result) == completions_to_set( - testdata.columns('Users', 'custom')) - -texts = ['SELECT * FROM ', 'SELECT * FROM public.Orders O CROSS JOIN '] + testdata.columns("Users", "custom") + ) -@parametrize('completer', completers(filtr=True, casing=False, aliasing=False)) -@parametrize('text', texts) +texts = ["SELECT * FROM ", "SELECT * FROM public.Orders O CROSS JOIN "] + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +@parametrize("text", texts) def test_schema_or_visible_table_completion(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items()) + testdata.schemas_and_from_clause_items() + ) -@parametrize('completer', completers(aliasing=True, casing=False, filtr=True)) -@parametrize('text', texts) +@parametrize("completer", completers(aliasing=True, casing=False, filtr=True)) +@parametrize("text", texts) def test_table_aliases(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set(testdata.schemas() + [ - table('users u'), - 
table('orders o' if text == 'SELECT * FROM ' else 'orders o2'), - table('"select" s'), - function('func1() f'), - function('func2() f')]) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + [ + table("users u"), + table("orders o" if text == "SELECT * FROM " else "orders o2"), + table('"select" s'), + function("func1() f"), + function("func2() f"), + ] + ) -@parametrize('completer', completers(aliasing=True, casing=True, filtr=True)) -@parametrize('text', texts) +@parametrize("completer", completers(aliasing=True, casing=True, filtr=True)) +@parametrize("text", texts) def test_aliases_with_casing(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set(cased_schemas + [ - table('users u'), - table('Orders O' if text == 'SELECT * FROM ' else 'Orders O2'), - table('"select" s'), - function('Func1() F'), - function('func2() f')]) + assert completions_to_set(result) == completions_to_set( + cased_schemas + + [ + table("users u"), + table("Orders O" if text == "SELECT * FROM " else "Orders O2"), + table('"select" s'), + function("Func1() F"), + function("func2() f"), + ] + ) -@parametrize('completer', completers(aliasing=False, casing=True, filtr=True)) -@parametrize('text', texts) +@parametrize("completer", completers(aliasing=False, casing=True, filtr=True)) +@parametrize("text", texts) def test_table_casing(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set(cased_schemas + [ - table('users'), - table('Orders'), - table('"select"'), - function('Func1()'), - function('func2()')]) + assert completions_to_set(result) == completions_to_set( + cased_schemas + + [ + table("users"), + table("Orders"), + table('"select"'), + function("Func1()"), + function("func2()"), + ] + ) -@parametrize('completer', completers(aliasing=False, casing=True)) +@parametrize("completer", completers(aliasing=False, casing=True)) def 
test_alias_search_without_aliases2(completer): - text = 'SELECT * FROM blog.et' + text = "SELECT * FROM blog.et" result = get_result(completer, text) - assert result[0] == table('EntryTags', -2) + assert result[0] == table("EntryTags", -2) -@parametrize('completer', completers(aliasing=False, casing=True)) +@parametrize("completer", completers(aliasing=False, casing=True)) def test_alias_search_without_aliases1(completer): - text = 'SELECT * FROM blog.e' + text = "SELECT * FROM blog.e" result = get_result(completer, text) - assert result[0] == table('Entries', -1) + assert result[0] == table("Entries", -1) -@parametrize('completer', completers(aliasing=True, casing=True)) +@parametrize("completer", completers(aliasing=True, casing=True)) def test_alias_search_with_aliases2(completer): - text = 'SELECT * FROM blog.et' + text = "SELECT * FROM blog.et" result = get_result(completer, text) - assert result[0] == table('EntryTags ET', -2) + assert result[0] == table("EntryTags ET", -2) -@parametrize('completer', completers(aliasing=True, casing=True)) +@parametrize("completer", completers(aliasing=True, casing=True)) def test_alias_search_with_aliases1(completer): - text = 'SELECT * FROM blog.e' + text = "SELECT * FROM blog.e" result = get_result(completer, text) - assert result[0] == table('Entries E', -1) + assert result[0] == table("Entries E", -1) -@parametrize('completer', completers(aliasing=True, casing=True)) +@parametrize("completer", completers(aliasing=True, casing=True)) def test_join_alias_search_with_aliases1(completer): - text = 'SELECT * FROM blog.Entries E JOIN blog.e' + text = "SELECT * FROM blog.Entries E JOIN blog.e" result = get_result(completer, text) - assert result[:2] == [table('Entries E2', -1), join( - 'EntAccLog EAL ON EAL.EntryID = E.EntryID', -1)] - - -@parametrize('completer', completers(aliasing=False, casing=True)) -def test_join_alias_search_without_aliases1(completer): - text = 'SELECT * FROM blog.Entries JOIN blog.e' - result = 
get_result(completer, text) - assert result[:2] == [table('Entries', -1), join( - 'EntAccLog ON EntAccLog.EntryID = Entries.EntryID', -1)] - - -@parametrize('completer', completers(aliasing=True, casing=True)) -def test_join_alias_search_with_aliases2(completer): - text = 'SELECT * FROM blog.Entries E JOIN blog.et' - result = get_result(completer, text) - assert result[0] == join('EntryTags ET ON ET.EntryID = E.EntryID', -2) - - -@parametrize('completer', completers(aliasing=False, casing=True)) -def test_join_alias_search_without_aliases2(completer): - text = 'SELECT * FROM blog.Entries JOIN blog.et' - result = get_result(completer, text) - assert result[0] == join( - 'EntryTags ON EntryTags.EntryID = Entries.EntryID', -2) - - -@parametrize('completer', completers()) -def test_function_alias_search_without_aliases(completer): - text = 'SELECT blog.ees' - result = get_result(completer, text) - first = result[0] - assert first.start_position == -3 - assert first.text == 'extract_entry_symbols()' - assert first.display_text == 'extract_entry_symbols(_entryid)' - - -@parametrize('completer', completers()) -def test_function_alias_search_with_aliases(completer): - text = 'SELECT blog.ee' - result = get_result(completer, text) - first = result[0] - assert first.start_position == -2 - assert first.text == 'enter_entry(_title := , _text := )' - assert first.display_text == 'enter_entry(_title, _text)' - - -@parametrize('completer',completers(filtr=True, casing=True, qualify=no_qual)) -def test_column_alias_search(completer): - result = get_result( - completer, 'SELECT et FROM blog.Entries E', len('SELECT et') - ) - cols = ('EntryText', 'EntryTitle', 'EntryID') - assert result[:3] == [column(c, -2) for c in cols] - - -@parametrize('completer', completers(casing=True)) -def test_column_alias_search_qualified(completer): - result = get_result( - completer, 'SELECT E.ei FROM blog.Entries E', len('SELECT E.ei') - ) - cols = ('EntryID', 'EntryTitle') - assert result[:3] == 
[column(c, -2) for c in cols] - - -@parametrize('completer', completers(casing=False, filtr=False, aliasing=False)) -def test_schema_object_order(completer): - result = get_result(completer, 'SELECT * FROM u') - assert result[:3] == [ - table(t, pos=-1) for t in ('users', 'custom."Users"', 'custom.users') + assert result[:2] == [ + table("Entries E2", -1), + join("EntAccLog EAL ON EAL.EntryID = E.EntryID", -1), ] -@parametrize('completer', completers(casing=False, filtr=False, aliasing=False)) +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_join_alias_search_without_aliases1(completer): + text = "SELECT * FROM blog.Entries JOIN blog.e" + result = get_result(completer, text) + assert result[:2] == [ + table("Entries", -1), + join("EntAccLog ON EntAccLog.EntryID = Entries.EntryID", -1), + ] + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_join_alias_search_with_aliases2(completer): + text = "SELECT * FROM blog.Entries E JOIN blog.et" + result = get_result(completer, text) + assert result[0] == join("EntryTags ET ON ET.EntryID = E.EntryID", -2) + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_join_alias_search_without_aliases2(completer): + text = "SELECT * FROM blog.Entries JOIN blog.et" + result = get_result(completer, text) + assert result[0] == join("EntryTags ON EntryTags.EntryID = Entries.EntryID", -2) + + +@parametrize("completer", completers()) +def test_function_alias_search_without_aliases(completer): + text = "SELECT blog.ees" + result = get_result(completer, text) + first = result[0] + assert first.start_position == -3 + assert first.text == "extract_entry_symbols()" + assert first.display_text == "extract_entry_symbols(_entryid)" + + +@parametrize("completer", completers()) +def test_function_alias_search_with_aliases(completer): + text = "SELECT blog.ee" + result = get_result(completer, text) + first = result[0] + assert first.start_position == -2 + assert 
first.text == "enter_entry(_title := , _text := )" + assert first.display_text == "enter_entry(_title, _text)" + + +@parametrize("completer", completers(filtr=True, casing=True, qualify=no_qual)) +def test_column_alias_search(completer): + result = get_result(completer, "SELECT et FROM blog.Entries E", len("SELECT et")) + cols = ("EntryText", "EntryTitle", "EntryID") + assert result[:3] == [column(c, -2) for c in cols] + + +@parametrize("completer", completers(casing=True)) +def test_column_alias_search_qualified(completer): + result = get_result( + completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei") + ) + cols = ("EntryID", "EntryTitle") + assert result[:3] == [column(c, -2) for c in cols] + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=False)) +def test_schema_object_order(completer): + result = get_result(completer, "SELECT * FROM u") + assert result[:3] == [ + table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users") + ] + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=False)) def test_all_schema_objects(completer): - text = ('SELECT * FROM ') + text = "SELECT * FROM " result = get_result(completer, text) assert completions_to_set(result) >= completions_to_set( - [table(x) for x in ('orders', '"select"', 'custom.shipments')] - + [function(x + '()') for x in ('func2',)] + [table(x) for x in ("orders", '"select"', "custom.shipments")] + + [function(x + "()") for x in ("func2",)] ) -@parametrize('completer', completers(filtr=False, aliasing=False, casing=True)) +@parametrize("completer", completers(filtr=False, aliasing=False, casing=True)) def test_all_schema_objects_with_casing(completer): - text = 'SELECT * FROM ' + text = "SELECT * FROM " result = get_result(completer, text) assert completions_to_set(result) >= completions_to_set( - [table(x) for x in ('Orders', '"select"', 'CUSTOM.shipments')] - + [function(x + '()') for x in ('func2',)] + [table(x) for x in ("Orders", 
'"select"', "CUSTOM.shipments")] + + [function(x + "()") for x in ("func2",)] ) -@parametrize('completer', completers(casing=False, filtr=False, aliasing=True)) +@parametrize("completer", completers(casing=False, filtr=False, aliasing=True)) def test_all_schema_objects_with_aliases(completer): - text = ('SELECT * FROM ') + text = "SELECT * FROM " result = get_result(completer, text) assert completions_to_set(result) >= completions_to_set( - [table(x) for x in ('orders o', '"select" s', 'custom.shipments s')] - + [function(x) for x in ('func2() f',)] + [table(x) for x in ("orders o", '"select" s', "custom.shipments s")] + + [function(x) for x in ("func2() f",)] ) -@parametrize('completer', completers(casing=False, filtr=False, aliasing=True)) +@parametrize("completer", completers(casing=False, filtr=False, aliasing=True)) def test_set_schema(completer): - text = ('SET SCHEMA ') + text = "SET SCHEMA " result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set([ - schema(u"'blog'"), - schema(u"'Custom'"), - schema(u"'custom'"), - schema(u"'public'")]) + assert completions_to_set(result) == completions_to_set( + [schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")] + ) diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index 9d503245..63f1e80e 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -1,304 +1,354 @@ from __future__ import unicode_literals, print_function -from metadata import (MetaData, alias, name_join, fk_join, join, keyword, - schema, table, view, function, column, wildcard_expansion, - get_result, result_set, qual, no_qual, parametrize) +from metadata import ( + MetaData, + alias, + name_join, + fk_join, + join, + keyword, + schema, + table, + view, + function, + column, + wildcard_expansion, + get_result, + result_set, + qual, + no_qual, + parametrize, +) from 
prompt_toolkit.completion import Completion from utils import completions_to_set - metadata = { - 'tables': { - 'users': ['id', 'parentid', 'email', 'first_name', 'last_name'], - 'Users': ['userid', 'username'], - 'orders': ['id', 'ordered_date', 'status', 'email'], - 'select': ['id', 'insert', 'ABC']}, - 'views': { - 'user_emails': ['id', 'email'], - 'functions': ['function'], + "tables": { + "users": ["id", "parentid", "email", "first_name", "last_name"], + "Users": ["userid", "username"], + "orders": ["id", "ordered_date", "status", "email"], + "select": ["id", "insert", "ABC"], }, - 'functions': [ - ['custom_fun', [], [], [], '', False, False, False, False], - ['_custom_fun', [], [], [], '', False, False, False, False], - ['custom_func1', [], [], [], '', False, False, False, False], - ['custom_func2', [], [], [], '', False, False, False, False], - ['set_returning_func', ['x', 'y'], ['integer', 'integer'], - ['b', 'b'], '', False, False, True, False]], - 'datatypes': ['custom_type1', 'custom_type2'], - 'foreignkeys': [ - ('public', 'users', 'id', 'public', 'users', 'parentid'), - ('public', 'users', 'id', 'public', 'Users', 'userid') + "views": {"user_emails": ["id", "email"], "functions": ["function"]}, + "functions": [ + ["custom_fun", [], [], [], "", False, False, False, False], + ["_custom_fun", [], [], [], "", False, False, False, False], + ["custom_func1", [], [], [], "", False, False, False, False], + ["custom_func2", [], [], [], "", False, False, False, False], + [ + "set_returning_func", + ["x", "y"], + ["integer", "integer"], + ["b", "b"], + "", + False, + False, + True, + False, + ], + ], + "datatypes": ["custom_type1", "custom_type2"], + "foreignkeys": [ + ("public", "users", "id", "public", "users", "parentid"), + ("public", "users", "id", "public", "Users", "userid"), ], } -metadata = dict((k, {'public': v}) for k, v in metadata.items()) +metadata = dict((k, {"public": v}) for k, v in metadata.items()) testdata = MetaData(metadata) 
-cased_users_col_names = ['ID', 'PARENTID', 'Email', 'First_Name', 'last_name'] -cased_users2_col_names = ['UserID', 'UserName'] +cased_users_col_names = ["ID", "PARENTID", "Email", "First_Name", "last_name"] +cased_users2_col_names = ["UserID", "UserName"] cased_func_names = [ - 'Custom_Fun', '_custom_fun', 'Custom_Func1', 'custom_func2', 'set_returning_func' + "Custom_Fun", + "_custom_fun", + "Custom_Func1", + "custom_func2", + "set_returning_func", ] -cased_tbls = ['Users', 'Orders'] -cased_views = ['User_Emails', 'Functions'] +cased_tbls = ["Users", "Orders"] +cased_views = ["User_Emails", "Functions"] casing = ( - ['SELECT', 'PUBLIC'] + cased_func_names + cased_tbls + cased_views - + cased_users_col_names + cased_users2_col_names + ["SELECT", "PUBLIC"] + + cased_func_names + + cased_tbls + + cased_views + + cased_users_col_names + + cased_users2_col_names ) # Lists for use in assertions cased_funcs = [ - function(f) for f in ('Custom_Fun()', '_custom_fun()', 'Custom_Func1()', 'custom_func2()') -] + [function('set_returning_func(x := , y := )', display='set_returning_func(x, y)')] + function(f) + for f in ("Custom_Fun()", "_custom_fun()", "Custom_Func1()", "custom_func2()") +] + [function("set_returning_func(x := , y := )", display="set_returning_func(x, y)")] cased_tbls = [table(t) for t in (cased_tbls + ['"Users"', '"select"'])] cased_rels = [view(t) for t in cased_views] + cased_funcs + cased_tbls cased_users_cols = [column(c) for c in cased_users_col_names] -aliased_rels = [ - table(t) for t in ('users u', '"Users" U', 'orders o', '"select" s') -] + [view('user_emails ue'), view('functions f')] + [ - function(f) for f in ( - '_custom_fun() cf', 'custom_fun() cf', 'custom_func1() cf', - 'custom_func2() cf' - ) -] + [function( - 'set_returning_func(x := , y := ) srf', - display='set_returning_func(x, y) srf' -)] -cased_aliased_rels = [ - table(t) for t in ('Users U', '"Users" U', 'Orders O', '"select" s') -] + [view('User_Emails UE'), view('Functions F')] + [ 
- function(f) for f in ( - '_custom_fun() cf', 'Custom_Fun() CF', 'Custom_Func1() CF', 'custom_func2() cf' - ) -] + [function( - 'set_returning_func(x := , y := ) srf', - display='set_returning_func(x, y) srf' -)] +aliased_rels = ( + [table(t) for t in ("users u", '"Users" U', "orders o", '"select" s')] + + [view("user_emails ue"), view("functions f")] + + [ + function(f) + for f in ( + "_custom_fun() cf", + "custom_fun() cf", + "custom_func1() cf", + "custom_func2() cf", + ) + ] + + [ + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ) + ] +) +cased_aliased_rels = ( + [table(t) for t in ("Users U", '"Users" U', "Orders O", '"select" s')] + + [view("User_Emails UE"), view("Functions F")] + + [ + function(f) + for f in ( + "_custom_fun() cf", + "Custom_Fun() CF", + "Custom_Func1() CF", + "custom_func2() cf", + ) + ] + + [ + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ) + ] +) completers = testdata.get_completers(casing) # Just to make sure that this doesn't crash -@parametrize('completer', completers()) +@parametrize("completer", completers()) def test_function_column_name(completer): for l in range( - len('SELECT * FROM Functions WHERE function:'), - len('SELECT * FROM Functions WHERE function:text') + 1 + len("SELECT * FROM Functions WHERE function:"), + len("SELECT * FROM Functions WHERE function:text") + 1, ): assert [] == get_result( - completer, 'SELECT * FROM Functions WHERE function:text'[:l] + completer, "SELECT * FROM Functions WHERE function:text"[:l] ) -@parametrize('action', ['ALTER', 'DROP', 'CREATE', 'CREATE OR REPLACE']) -@parametrize('completer', completers()) +@parametrize("action", ["ALTER", "DROP", "CREATE", "CREATE OR REPLACE"]) +@parametrize("completer", completers()) def test_drop_alter_function(completer, action): - assert get_result(completer, action + ' FUNCTION set_ret') == [ - function('set_returning_func(x integer, y integer)', 
-len('set_ret')) + assert get_result(completer, action + " FUNCTION set_ret") == [ + function("set_returning_func(x integer, y integer)", -len("set_ret")) ] -@parametrize('completer', completers()) +@parametrize("completer", completers()) def test_empty_string_completion(completer): - result = get_result(completer, '') + result = get_result(completer, "") assert completions_to_set( - testdata.keywords() + testdata.specials()) == completions_to_set(result) + testdata.keywords() + testdata.specials() + ) == completions_to_set(result) -@parametrize('completer', completers()) +@parametrize("completer", completers()) def test_select_keyword_completion(completer): - result = get_result(completer, 'SEL') - assert completions_to_set(result) == completions_to_set( - [keyword('SELECT', -3)]) + result = get_result(completer, "SEL") + assert completions_to_set(result) == completions_to_set([keyword("SELECT", -3)]) -@parametrize('completer', completers()) +@parametrize("completer", completers()) def test_builtin_function_name_completion(completer): - result = get_result(completer, 'SELECT MA') - assert completions_to_set(result) == completions_to_set([ - function('MAX', -2), - keyword('MAXEXTENTS', -2), keyword('MATERIALIZED VIEW', -2) - ]) - - -@parametrize('completer', completers()) -def test_builtin_function_matches_only_at_start(completer): - text = 'SELECT IN' - - result = [c.text for c in get_result(completer, text)] - - assert 'MIN' not in result - - -@parametrize('completer', completers(casing=False, aliasing=False)) -def test_user_function_name_completion(completer): - result = get_result(completer, 'SELECT cu') - assert completions_to_set(result) == completions_to_set([ - function('custom_fun()', -2), - function('_custom_fun()', -2), - function('custom_func1()', -2), - function('custom_func2()', -2), - keyword('CURRENT', -2), - ]) - - -@parametrize('completer', completers(casing=False, aliasing=False)) -def 
test_user_function_name_completion_matches_anywhere(completer): - result = get_result(completer, 'SELECT om') - assert completions_to_set(result) == completions_to_set([ - function('custom_fun()', -2), - function('_custom_fun()', -2), - function('custom_func1()', -2), - function('custom_func2()', -2)]) - - -@parametrize('completer', completers(casing=True)) -def test_list_functions_for_special(completer): - result = get_result(completer, r'\df ') + result = get_result(completer, "SELECT MA") assert completions_to_set(result) == completions_to_set( - [schema('PUBLIC')] + [function(f) for f in cased_func_names] + [ + function("MAX", -2), + keyword("MAXEXTENTS", -2), + keyword("MATERIALIZED VIEW", -2), + ] ) -@parametrize('completer', completers(casing=False, qualify=no_qual)) -def test_suggested_column_names_from_visible_table(completer): - result = get_result(completer, 'SELECT from users', len('SELECT ')) +@parametrize("completer", completers()) +def test_builtin_function_matches_only_at_start(completer): + text = "SELECT IN" + + result = [c.text for c in get_result(completer, text)] + + assert "MIN" not in result + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_user_function_name_completion(completer): + result = get_result(completer, "SELECT cu") assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords('users')) + [ + function("custom_fun()", -2), + function("_custom_fun()", -2), + function("custom_func1()", -2), + function("custom_func2()", -2), + keyword("CURRENT", -2), + ] + ) -@parametrize('completer', completers(casing=True, qualify=no_qual)) +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_user_function_name_completion_matches_anywhere(completer): + result = get_result(completer, "SELECT om") + assert completions_to_set(result) == completions_to_set( + [ + function("custom_fun()", -2), + function("_custom_fun()", -2), + function("custom_func1()", -2), + 
function("custom_func2()", -2), + ] + ) + + +@parametrize("completer", completers(casing=True)) +def test_list_functions_for_special(completer): + result = get_result(completer, r"\df ") + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + [function(f) for f in cased_func_names] + ) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggested_column_names_from_visible_table(completer): + result = get_result(completer, "SELECT from users", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(casing=True, qualify=no_qual)) def test_suggested_cased_column_names(completer): - result = get_result(completer, 'SELECT from users', len('SELECT ')) - assert completions_to_set(result) == completions_to_set(cased_funcs + cased_users_cols - + testdata.builtin_functions() + testdata.keywords()) + result = get_result(completer, "SELECT from users", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + cased_funcs + + cased_users_cols + + testdata.builtin_functions() + + testdata.keywords() + ) -@parametrize('completer', completers(casing=False, qualify=no_qual)) -@parametrize('text', [ - 'SELECT from users', - 'INSERT INTO Orders SELECT from users', -]) +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize("text", ["SELECT from users", "INSERT INTO Orders SELECT from users"]) def test_suggested_auto_qualified_column_names(text, completer): - position = text.index(' ') + 1 + position = text.index(" ") + 1 cols = [column(c.lower()) for c in cased_users_col_names] result = get_result(completer, text, position) assert completions_to_set(result) == completions_to_set( - cols + testdata.functions_and_keywords()) + cols + testdata.functions_and_keywords() + ) -@parametrize('completer', completers(casing=False, qualify=qual)) -@parametrize('text', [ - 
'SELECT from users U NATURAL JOIN "Users"', - 'INSERT INTO Orders SELECT from users U NATURAL JOIN "Users"', -]) +@parametrize("completer", completers(casing=False, qualify=qual)) +@parametrize( + "text", + [ + 'SELECT from users U NATURAL JOIN "Users"', + 'INSERT INTO Orders SELECT from users U NATURAL JOIN "Users"', + ], +) def test_suggested_auto_qualified_column_names_two_tables(text, completer): - position = text.index(' ') + 1 - cols = [column('U.' + c.lower()) for c in cased_users_col_names] + position = text.index(" ") + 1 + cols = [column("U." + c.lower()) for c in cased_users_col_names] cols += [column('"Users".' + c.lower()) for c in cased_users2_col_names] result = get_result(completer, text, position) assert completions_to_set(result) == completions_to_set( - cols + testdata.functions_and_keywords()) + cols + testdata.functions_and_keywords() + ) -@parametrize('completer', completers(casing=True, qualify=['always'])) -@parametrize('text', [ - 'UPDATE users SET ', - 'INSERT INTO users(', -]) +@parametrize("completer", completers(casing=True, qualify=["always"])) +@parametrize("text", ["UPDATE users SET ", "INSERT INTO users("]) def test_no_column_qualification(text, completer): cols = [column(c) for c in cased_users_col_names] result = get_result(completer, text) assert completions_to_set(result) == completions_to_set(cols) -@parametrize('completer', completers(casing=True, qualify=['always'])) -def test_suggested_cased_always_qualified_column_names( - completer -): - text = 'SELECT from users' - position = len('SELECT ') - cols = [column('users.' + c) for c in cased_users_col_names] +@parametrize("completer", completers(casing=True, qualify=["always"])) +def test_suggested_cased_always_qualified_column_names(completer): + text = "SELECT from users" + position = len("SELECT ") + cols = [column("users." 
+ c) for c in cased_users_col_names] result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set(cased_funcs + cols - + testdata.builtin_functions() + testdata.keywords()) + assert completions_to_set(result) == completions_to_set( + cased_funcs + cols + testdata.builtin_functions() + testdata.keywords() + ) -@parametrize('completer', completers(casing=False, qualify=no_qual)) +@parametrize("completer", completers(casing=False, qualify=no_qual)) def test_suggested_column_names_in_function(completer): - result = get_result( - completer, 'SELECT MAX( from users', len('SELECT MAX(') - ) - assert completions_to_set(result) == completions_to_set( - testdata.columns('users')) + result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX(")) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_suggested_column_names_with_table_dot(completer): - result = get_result( - completer, 'SELECT users. from users', len('SELECT users.') - ) - assert completions_to_set(result) == completions_to_set( - testdata.columns('users')) + result = get_result(completer, "SELECT users. from users", len("SELECT users.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_suggested_column_names_with_alias(completer): - result = get_result(completer, 'SELECT u. from users u', len('SELECT u.')) - assert completions_to_set(result) == completions_to_set( - testdata.columns('users')) + result = get_result(completer, "SELECT u. 
from users u", len("SELECT u.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) -@parametrize('completer', completers(casing=False, qualify=no_qual)) +@parametrize("completer", completers(casing=False, qualify=no_qual)) def test_suggested_multiple_column_names(completer): - result = get_result( - completer, 'SELECT id, from users u', len('SELECT id, ') - ) + result = get_result(completer, "SELECT id, from users u", len("SELECT id, ")) assert completions_to_set(result) == completions_to_set( - (testdata.columns_functions_and_keywords('users'))) + (testdata.columns_functions_and_keywords("users")) + ) -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_suggested_multiple_column_names_with_alias(completer): result = get_result( - completer, 'SELECT u.id, u. from users u', len('SELECT u.id, u.') + completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.") ) - assert completions_to_set(result) == completions_to_set( - testdata.columns('users')) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) -@parametrize('completer', completers(casing=True)) +@parametrize("completer", completers(casing=True)) def test_suggested_cased_column_names_with_alias(completer): result = get_result( - completer, 'SELECT u.id, u. from users u', len('SELECT u.id, u.') + completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.") ) assert completions_to_set(result) == completions_to_set(cased_users_cols) -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_suggested_multiple_column_names_with_dot(completer): result = get_result( completer, - 'SELECT users.id, users. from users u', - len('SELECT users.id, users.') + "SELECT users.id, users. 
from users u", + len("SELECT users.id, users."), ) - assert completions_to_set(result) == completions_to_set( - testdata.columns('users')) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_suggest_columns_after_three_way_join(completer): - text = '''SELECT * FROM users u1 + text = """SELECT * FROM users u1 INNER JOIN users u2 ON u1.id = u2.id - INNER JOIN users u3 ON u2.id = u3.''' + INNER JOIN users u3 ON u2.id = u3.""" result = get_result(completer, text) - assert (column('id') in result) + assert column("id") in result join_condition_texts = [ 'INSERT INTO orders SELECT * FROM users U JOIN "Users" U2 ON ', - '''INSERT INTO public.orders(orderid) - SELECT * FROM users U JOIN "Users" U2 ON ''', + """INSERT INTO public.orders(orderid) + SELECT * FROM users U JOIN "Users" U2 ON """, 'SELECT * FROM users U JOIN "Users" U2 ON ', 'SELECT * FROM users U INNER join "Users" U2 ON ', 'SELECT * FROM USERS U right JOIN "Users" U2 ON ', @@ -307,282 +357,336 @@ join_condition_texts = [ 'SELECT * FROM users U right outer join "Users" U2 ON ', 'SELECT * FROM Users U LEFT OUTER JOIN "Users" U2 ON ', 'SELECT * FROM users U FULL OUTER JOIN "Users" U2 ON ', - '''SELECT * + """SELECT * FROM users U FULL OUTER JOIN "Users" U2 ON - ''' + """, ] -@parametrize('completer', completers(casing=False)) -@parametrize('text', join_condition_texts) +@parametrize("completer", completers(casing=False)) +@parametrize("text", join_condition_texts) def test_suggested_join_conditions(completer, text): - result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set([ - alias('U'), alias('U2'), fk_join('U2.userid = U.id') - ]) - - -@parametrize('completer', completers(casing=True)) -@parametrize('text', join_condition_texts) -def test_cased_join_conditions(completer, text): result = get_result(completer, text) assert 
completions_to_set(result) == completions_to_set( - [alias('U'), alias('U2'), fk_join('U2.UserID = U.ID')] + [alias("U"), alias("U2"), fk_join("U2.userid = U.id")] ) -@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - '''SELECT * +@parametrize("completer", completers(casing=True)) +@parametrize("text", join_condition_texts) +def test_cased_join_conditions(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [alias("U"), alias("U2"), fk_join("U2.UserID = U.ID")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + """SELECT * FROM users CROSS JOIN "Users" NATURAL JOIN users u JOIN "Users" u2 ON - ''' -]) + """ + ], +) def test_suggested_join_conditions_with_same_table_twice(completer, text): result = get_result(completer, text) assert result == [ - fk_join('u2.userid = u.id'), - fk_join('u2.userid = users.id'), + fk_join("u2.userid = u.id"), + fk_join("u2.userid = users.id"), name_join('u2.userid = "Users".userid'), name_join('u2.username = "Users".username'), - alias('u'), - alias('u2'), - alias('users'), - alias('"Users"') + alias("u"), + alias("u2"), + alias("users"), + alias('"Users"'), ] -@parametrize('completer', completers()) -@parametrize('text', [ - 'SELECT * FROM users JOIN users u2 on foo.' 
-]) +@parametrize("completer", completers()) +@parametrize("text", ["SELECT * FROM users JOIN users u2 on foo."]) def test_suggested_join_conditions_with_invalid_qualifier(completer, text): result = get_result(completer, text) assert result == [] -@parametrize('completer', completers(casing=False)) -@parametrize(('text', 'ref'), [ - ('SELECT * FROM users JOIN NonTable on ', 'NonTable'), - ('SELECT * FROM users JOIN nontable nt on ', 'nt') -]) +@parametrize("completer", completers(casing=False)) +@parametrize( + ("text", "ref"), + [ + ("SELECT * FROM users JOIN NonTable on ", "NonTable"), + ("SELECT * FROM users JOIN nontable nt on ", "nt"), + ], +) def test_suggested_join_conditions_with_invalid_table(completer, text, ref): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - [alias('users'), alias(ref)]) + [alias("users"), alias(ref)] + ) -@parametrize('completer', completers(casing=False, aliasing=False)) -@parametrize('text', [ - 'SELECT * FROM "Users" u JOIN u', - 'SELECT * FROM "Users" u JOIN uid', - 'SELECT * FROM "Users" u JOIN userid', - 'SELECT * FROM "Users" u JOIN id', -]) +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + 'SELECT * FROM "Users" u JOIN u', + 'SELECT * FROM "Users" u JOIN uid', + 'SELECT * FROM "Users" u JOIN userid', + 'SELECT * FROM "Users" u JOIN id', + ], +) def test_suggested_joins_fuzzy(completer, text): result = get_result(completer, text) last_word = text.split()[-1] - expected = join('users ON users.id = u.userid', -len(last_word)) + expected = join("users ON users.id = u.userid", -len(last_word)) assert expected in result join_texts = [ - 'SELECT * FROM Users JOIN ', - '''INSERT INTO "Users" + "SELECT * FROM Users JOIN ", + """INSERT INTO "Users" SELECT * FROM Users - INNER JOIN ''', - '''INSERT INTO public."Users"(username) + INNER JOIN """, + """INSERT INTO public."Users"(username) SELECT * FROM Users - INNER JOIN ''', - '''SELECT * + 
INNER JOIN """, + """SELECT * FROM Users - INNER JOIN ''' + INNER JOIN """, ] -@parametrize('completer', completers(casing=False, aliasing=False)) -@parametrize('text', join_texts) +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", join_texts) def test_suggested_joins(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items() + [ - join('"Users" ON "Users".userid = Users.id'), - join('users users2 ON users2.id = Users.parentid'), - join('users users2 ON users2.parentid = Users.id'), + testdata.schemas_and_from_clause_items() + + [ + join('"Users" ON "Users".userid = Users.id'), + join("users users2 ON users2.id = Users.parentid"), + join("users users2 ON users2.parentid = Users.id"), ] ) -@parametrize('completer', completers(casing=True, aliasing=False)) -@parametrize('text', join_texts) +@parametrize("completer", completers(casing=True, aliasing=False)) +@parametrize("text", join_texts) def test_cased_joins(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set([schema('PUBLIC')] + cased_rels + [ - join('"Users" ON "Users".UserID = Users.ID'), - join('Users Users2 ON Users2.ID = Users.PARENTID'), - join('Users Users2 ON Users2.PARENTID = Users.ID'), - ]) + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + + cased_rels + + [ + join('"Users" ON "Users".UserID = Users.ID'), + join("Users Users2 ON Users2.ID = Users.PARENTID"), + join("Users Users2 ON Users2.PARENTID = Users.ID"), + ] + ) -@parametrize('completer', completers(casing=False, aliasing=True)) -@parametrize('text', join_texts) +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", join_texts) def test_aliased_joins(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set(testdata.schemas() + 
aliased_rels + [ - join('"Users" U ON U.userid = Users.id'), - join('users u ON u.id = Users.parentid'), - join('users u ON u.parentid = Users.id'), - ]) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + aliased_rels + + [ + join('"Users" U ON U.userid = Users.id'), + join("users u ON u.id = Users.parentid"), + join("users u ON u.parentid = Users.id"), + ] + ) -@parametrize('completer', completers(casing=False, aliasing=False)) -@parametrize('text', [ - 'SELECT * FROM public."Users" JOIN ', - 'SELECT * FROM public."Users" RIGHT OUTER JOIN ', - '''SELECT * +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + 'SELECT * FROM public."Users" JOIN ', + 'SELECT * FROM public."Users" RIGHT OUTER JOIN ', + """SELECT * FROM public."Users" - LEFT JOIN ''' -]) + LEFT JOIN """, + ], +) def test_suggested_joins_quoted_schema_qualified_table(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items() + - [join('public.users ON users.id = "Users".userid')] + testdata.schemas_and_from_clause_items() + + [join('public.users ON users.id = "Users".userid')] ) -@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - 'SELECT u.name, o.id FROM users u JOIN orders o ON ', - 'SELECT u.name, o.id FROM users u JOIN orders o ON JOIN orders o2 ON' -]) +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT u.name, o.id FROM users u JOIN orders o ON ", + "SELECT u.name, o.id FROM users u JOIN orders o ON JOIN orders o2 ON", + ], +) def test_suggested_aliases_after_on(completer, text): - position = len('SELECT u.name, o.id FROM users u JOIN orders o ON ') + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ") result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set([ - alias('u'), - name_join('o.id = 
u.id'), - name_join('o.email = u.email'), - alias('o')]) + assert completions_to_set(result) == completions_to_set( + [ + alias("u"), + name_join("o.id = u.id"), + name_join("o.email = u.email"), + alias("o"), + ] + ) -@parametrize('completer', completers()) -@parametrize('text', [ - 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ', - 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = JOIN orders o2 ON' -]) +@parametrize("completer", completers()) +@parametrize( + "text", + [ + "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ", + "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = JOIN orders o2 ON", + ], +) def test_suggested_aliases_after_on_right_side(completer, text): + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set([alias("u"), alias("o")]) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT users.name, orders.id FROM users JOIN orders ON ", + "SELECT users.name, orders.id FROM users JOIN orders ON JOIN orders orders2 ON", + ], +) +def test_suggested_tables_after_on(completer, text): + position = len("SELECT users.name, orders.id FROM users JOIN orders ON ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [ + name_join("orders.id = users.id"), + name_join("orders.email = users.email"), + alias("users"), + alias("orders"), + ] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = JOIN orders orders2 ON", + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ", + ], +) +def test_suggested_tables_after_on_right_side(completer, text): position = len( - 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ' + "SELECT users.name, 
orders.id FROM users JOIN orders ON orders.user_id = " ) result = get_result(completer, text, position) assert completions_to_set(result) == completions_to_set( - [alias('u'), alias('o')]) + [alias("users"), alias("orders")] + ) -@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - 'SELECT users.name, orders.id FROM users JOIN orders ON ', - 'SELECT users.name, orders.id FROM users JOIN orders ON JOIN orders orders2 ON' -]) -def test_suggested_tables_after_on(completer, text): - position = len('SELECT users.name, orders.id FROM users JOIN orders ON ') - result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set([ - name_join('orders.id = users.id'), - name_join('orders.email = users.email'), - alias('users'), - alias('orders') - ]) - - -@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = JOIN orders orders2 ON', - 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ' -]) -def test_suggested_tables_after_on_right_side(completer, text): - position = len('SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ') - result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - [alias('users'), alias('orders')]) - - -@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - 'SELECT * FROM users INNER JOIN orders USING (', - 'SELECT * FROM users INNER JOIN orders USING(', -]) +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users INNER JOIN orders USING (", + "SELECT * FROM users INNER JOIN orders USING(", + ], +) def test_join_using_suggests_common_columns(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - [column('id'), column('email')]) + [column("id"), column("email")] + ) 
-@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - 'SELECT * FROM users u1 JOIN users u2 USING (email) JOIN user_emails ue USING()', - 'SELECT * FROM users u1 JOIN users u2 USING(email) JOIN user_emails ue USING ()', - 'SELECT * FROM users u1 JOIN user_emails ue USING () JOIN users u2 ue USING(first_name, last_name)', - 'SELECT * FROM users u1 JOIN user_emails ue USING() JOIN users u2 ue USING (first_name, last_name)', -]) +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users u1 JOIN users u2 USING (email) JOIN user_emails ue USING()", + "SELECT * FROM users u1 JOIN users u2 USING(email) JOIN user_emails ue USING ()", + "SELECT * FROM users u1 JOIN user_emails ue USING () JOIN users u2 ue USING(first_name, last_name)", + "SELECT * FROM users u1 JOIN user_emails ue USING() JOIN users u2 ue USING (first_name, last_name)", + ], +) def test_join_using_suggests_from_last_table(completer, text): - position = text.index('()') + 1 + position = text.index("()") + 1 result = get_result(completer, text, position) assert completions_to_set(result) == completions_to_set( - [column('id'), column('email')]) + [column("id"), column("email")] + ) -@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - 'SELECT * FROM users INNER JOIN orders USING (id,', - 'SELECT * FROM users INNER JOIN orders USING(id,', -]) +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users INNER JOIN orders USING (id,", + "SELECT * FROM users INNER JOIN orders USING(id,", + ], +) def test_join_using_suggests_columns_after_first_column(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - [column('id'), column('email')]) + [column("id"), column("email")] + ) -@parametrize('completer', completers(casing=False, aliasing=False)) -@parametrize('text', [ - 'SELECT * FROM ', - 'SELECT * FROM users CROSS JOIN 
', - 'SELECT * FROM users natural join ' -]) +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + "SELECT * FROM ", + "SELECT * FROM users CROSS JOIN ", + "SELECT * FROM users natural join ", + ], +) def test_table_names_after_from(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items()) + testdata.schemas_and_from_clause_items() + ) assert [c.text for c in result] == [ - 'public', - 'orders', + "public", + "orders", '"select"', - 'users', + "users", '"Users"', - 'functions', - 'user_emails', - '_custom_fun()', - 'custom_fun()', - 'custom_func1()', - 'custom_func2()', - 'set_returning_func(x := , y := )', + "functions", + "user_emails", + "_custom_fun()", + "custom_fun()", + "custom_func1()", + "custom_func2()", + "set_returning_func(x := , y := )", ] -@parametrize('completer', completers(casing=False, qualify=no_qual)) +@parametrize("completer", completers(casing=False, qualify=no_qual)) def test_auto_escaped_col_names(completer): - result = get_result(completer, 'SELECT from "select"', len('SELECT ')) + result = get_result(completer, 'SELECT from "select"', len("SELECT ")) assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords('select')) + testdata.columns_functions_and_keywords("select") + ) -@parametrize('completer', completers(aliasing=False)) +@parametrize("completer", completers(aliasing=False)) def test_allow_leading_double_quote_in_last_word(completer): result = get_result(completer, 'SELECT * from "sele') @@ -591,13 +695,16 @@ def test_allow_leading_double_quote_in_last_word(completer): assert expected in result -@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - 'SELECT 1::', - 'CREATE TABLE foo (bar ', - 'CREATE FUNCTION foo (bar INT, baz ', - 'ALTER TABLE foo ALTER COLUMN bar TYPE ', -]) +@parametrize("completer", completers(casing=False)) 
+@parametrize( + "text", + [ + "SELECT 1::", + "CREATE TABLE foo (bar ", + "CREATE FUNCTION foo (bar INT, baz ", + "ALTER TABLE foo ALTER COLUMN bar TYPE ", + ], +) def test_suggest_datatype(text, completer): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( @@ -605,141 +712,149 @@ def test_suggest_datatype(text, completer): ) -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_suggest_columns_from_escaped_table_alias(completer): result = get_result(completer, 'select * from "select" s where s.') - assert completions_to_set(result) == completions_to_set( - testdata.columns('select')) + assert completions_to_set(result) == completions_to_set(testdata.columns("select")) -@parametrize('completer', completers(casing=False, qualify=no_qual)) +@parametrize("completer", completers(casing=False, qualify=no_qual)) def test_suggest_columns_from_set_returning_function(completer): - result = get_result( - completer, 'select from set_returning_func()', len('select ') + result = get_result(completer, "select from set_returning_func()", len("select ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("set_returning_func", typ="functions") ) - assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords( - 'set_returning_func', typ='functions' - )) -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_suggest_columns_from_aliased_set_returning_function(completer): result = get_result( - completer, 'select f. from set_returning_func() f', len('select f.') + completer, "select f. 
from set_returning_func() f", len("select f.") ) assert completions_to_set(result) == completions_to_set( - testdata.columns('set_returning_func', typ='functions') + testdata.columns("set_returning_func", typ="functions") ) -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_join_functions_using_suggests_common_columns(completer): - text = '''SELECT * FROM set_returning_func() f1 - INNER JOIN set_returning_func() f2 USING (''' + text = """SELECT * FROM set_returning_func() f1 + INNER JOIN set_returning_func() f2 USING (""" result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.columns('set_returning_func', typ='functions') + testdata.columns("set_returning_func", typ="functions") ) -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_join_functions_on_suggests_columns_and_join_conditions(completer): - text = '''SELECT * FROM set_returning_func() f1 - INNER JOIN set_returning_func() f2 ON f1.''' + text = """SELECT * FROM set_returning_func() f1 + INNER JOIN set_returning_func() f2 ON f1.""" result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - [name_join('y = f2.y'), name_join('x = f2.x')] + - testdata.columns('set_returning_func', typ='functions') + [name_join("y = f2.y"), name_join("x = f2.x")] + + testdata.columns("set_returning_func", typ="functions") ) -@parametrize('completer', completers()) +@parametrize("completer", completers()) def test_learn_keywords(completer): - history = 'CREATE VIEW v AS SELECT 1' + history = "CREATE VIEW v AS SELECT 1" completer.extend_query_history(history) # Now that we've used `VIEW` once, it should be suggested ahead of other # keywords starting with v. 
- text = 'create v' + text = "create v" completions = get_result(completer, text) - assert completions[0].text == 'VIEW' + assert completions[0].text == "VIEW" -@parametrize('completer', completers(casing=False, aliasing=False)) +@parametrize("completer", completers(casing=False, aliasing=False)) def test_learn_table_names(completer): - history = 'SELECT * FROM users; SELECT * FROM orders; SELECT * FROM users' + history = "SELECT * FROM users; SELECT * FROM orders; SELECT * FROM users" completer.extend_query_history(history) - text = 'SELECT * FROM ' + text = "SELECT * FROM " completions = get_result(completer, text) # `users` should be higher priority than `orders` (used more often) - users = table('users') - orders = table('orders') + users = table("users") + orders = table("orders") assert completions.index(users) < completions.index(orders) -@parametrize('completer', completers(casing=False, qualify=no_qual)) +@parametrize("completer", completers(casing=False, qualify=no_qual)) def test_columns_before_keywords(completer): - text = 'SELECT * FROM orders WHERE s' + text = "SELECT * FROM orders WHERE s" completions = get_result(completer, text) - col = column('status', -1) - kw = keyword('SELECT', -1) + col = column("status", -1) + kw = keyword("SELECT", -1) assert completions.index(col) < completions.index(kw) -@parametrize('completer', completers(casing=False, qualify=no_qual)) -@parametrize('text', [ - 'SELECT * FROM users', - 'INSERT INTO users SELECT * FROM users u', - '''INSERT INTO users(id, parentid, email, first_name, last_name) +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT * FROM users", + "INSERT INTO users SELECT * FROM users u", + """INSERT INTO users(id, parentid, email, first_name, last_name) SELECT * - FROM users u''', - ]) + FROM users u""", + ], +) def test_wildcard_column_expansion(completer, text): - position = text.find('*') + 1 + position = text.find("*") + 1 completions = 
get_result(completer, text, position) - col_list = 'id, parentid, email, first_name, last_name' + col_list = "id, parentid, email, first_name, last_name" expected = [wildcard_expansion(col_list)] assert expected == completions -@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - 'SELECT u.* FROM users u', - 'INSERT INTO public.users SELECT u.* FROM users u', - '''INSERT INTO users(id, parentid, email, first_name, last_name) +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT u.* FROM users u", + "INSERT INTO public.users SELECT u.* FROM users u", + """INSERT INTO users(id, parentid, email, first_name, last_name) SELECT u.* - FROM users u''', - ]) + FROM users u""", + ], +) def test_wildcard_column_expansion_with_alias(completer, text): - position = text.find('*') + 1 + position = text.find("*") + 1 completions = get_result(completer, text, position) - col_list = 'id, u.parentid, u.email, u.first_name, u.last_name' + col_list = "id, u.parentid, u.email, u.first_name, u.last_name" expected = [wildcard_expansion(col_list)] assert expected == completions -@parametrize('completer', completers(casing=False)) -@parametrize('text,expected', [ - ('SELECT users.* FROM users', - 'id, users.parentid, users.email, users.first_name, users.last_name'), - ('SELECT Users.* FROM Users', - 'id, Users.parentid, Users.email, Users.first_name, Users.last_name'), -]) -def test_wildcard_column_expansion_with_table_qualifier( - completer, text, expected -): - position = len('SELECT users.*') +@parametrize("completer", completers(casing=False)) +@parametrize( + "text,expected", + [ + ( + "SELECT users.* FROM users", + "id, users.parentid, users.email, users.first_name, users.last_name", + ), + ( + "SELECT Users.* FROM Users", + "id, Users.parentid, Users.email, Users.first_name, Users.last_name", + ), + ], +) +def test_wildcard_column_expansion_with_table_qualifier(completer, text, expected): + position = len("SELECT 
users.*") completions = get_result(completer, text, position) @@ -748,20 +863,22 @@ def test_wildcard_column_expansion_with_table_qualifier( assert expected == completions -@parametrize('completer', completers(casing=False, qualify=qual)) +@parametrize("completer", completers(casing=False, qualify=qual)) def test_wildcard_column_expansion_with_two_tables(completer): text = 'SELECT * FROM "select" JOIN users u ON true' - position = len('SELECT *') + position = len("SELECT *") completions = get_result(completer, text, position) - cols = ('"select".id, "select".insert, "select"."ABC", ' - 'u.id, u.parentid, u.email, u.first_name, u.last_name') + cols = ( + '"select".id, "select".insert, "select"."ABC", ' + "u.id, u.parentid, u.email, u.first_name, u.last_name" + ) expected = [wildcard_expansion(cols)] assert completions == expected -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_wildcard_column_expansion_with_two_tables_and_parent(completer): text = 'SELECT "select".* FROM "select" JOIN users u ON true' position = len('SELECT "select".*') @@ -774,194 +891,212 @@ def test_wildcard_column_expansion_with_two_tables_and_parent(completer): assert expected == completions -@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - 'SELECT U. FROM Users U', - 'SELECT U. FROM USERS U', - 'SELECT U. FROM users U' -]) +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + ["SELECT U. FROM Users U", "SELECT U. FROM USERS U", "SELECT U. 
FROM users U"], +) def test_suggest_columns_from_unquoted_table(completer, text): - position = len('SELECT U.') + position = len("SELECT U.") result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - testdata.columns('users')) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) -@parametrize('completer', completers(casing=False)) +@parametrize("completer", completers(casing=False)) def test_suggest_columns_from_quoted_table(completer): - result = get_result( - completer, 'SELECT U. FROM "Users" U', len('SELECT U.') - ) - assert completions_to_set(result) == completions_to_set( - testdata.columns('Users')) + result = get_result(completer, 'SELECT U. FROM "Users" U', len("SELECT U.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("Users")) -@parametrize('completer', completers(casing=False, aliasing=False)) -@parametrize('text', ['SELECT * FROM ', - 'SELECT * FROM Orders o CROSS JOIN ']) +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", ["SELECT * FROM ", "SELECT * FROM Orders o CROSS JOIN "]) def test_schema_or_visible_table_completion(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items()) + testdata.schemas_and_from_clause_items() + ) -@parametrize('completer', completers(casing=False, aliasing=True)) -@parametrize('text', ['SELECT * FROM ']) +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", ["SELECT * FROM "]) def test_table_aliases(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.schemas() + aliased_rels) + testdata.schemas() + aliased_rels + ) -@parametrize('completer', completers(casing=False, aliasing=True)) -@parametrize('text', ['SELECT * FROM Orders o CROSS JOIN ']) +@parametrize("completer", 
completers(casing=False, aliasing=True)) +@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "]) def test_duplicate_table_aliases(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set(testdata.schemas() + [ - table('orders o2'), - table('users u'), - table('"Users" U'), - table('"select" s'), - view('user_emails ue'), - view('functions f'), - function('_custom_fun() cf'), - function('custom_fun() cf'), - function('custom_func1() cf'), - function('custom_func2() cf'), - function( - 'set_returning_func(x := , y := ) srf', - display='set_returning_func(x, y) srf' - ), - ]) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + [ + table("orders o2"), + table("users u"), + table('"Users" U'), + table('"select" s'), + view("user_emails ue"), + view("functions f"), + function("_custom_fun() cf"), + function("custom_fun() cf"), + function("custom_func1() cf"), + function("custom_func2() cf"), + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ), + ] + ) -@parametrize('completer', completers(casing=True, aliasing=True)) -@parametrize('text', ['SELECT * FROM Orders o CROSS JOIN ']) +@parametrize("completer", completers(casing=True, aliasing=True)) +@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "]) def test_duplicate_aliases_with_casing(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set([ - schema('PUBLIC'), - table('Orders O2'), - table('Users U'), - table('"Users" U'), - table('"select" s'), - view('User_Emails UE'), - view('Functions F'), - function('_custom_fun() cf'), - function('Custom_Fun() CF'), - function('Custom_Func1() CF'), - function('custom_func2() cf'), - function( - 'set_returning_func(x := , y := ) srf', - display='set_returning_func(x, y) srf' - ), - ]) + assert completions_to_set(result) == completions_to_set( + [ + schema("PUBLIC"), + 
table("Orders O2"), + table("Users U"), + table('"Users" U'), + table('"select" s'), + view("User_Emails UE"), + view("Functions F"), + function("_custom_fun() cf"), + function("Custom_Fun() CF"), + function("Custom_Func1() CF"), + function("custom_func2() cf"), + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ), + ] + ) -@parametrize('completer', completers(casing=True, aliasing=True)) -@parametrize('text', ['SELECT * FROM ']) +@parametrize("completer", completers(casing=True, aliasing=True)) +@parametrize("text", ["SELECT * FROM "]) def test_aliases_with_casing(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - [schema('PUBLIC')] + cased_aliased_rels) + [schema("PUBLIC")] + cased_aliased_rels + ) -@parametrize('completer', completers(casing=True, aliasing=False)) -@parametrize('text', ['SELECT * FROM ']) +@parametrize("completer", completers(casing=True, aliasing=False)) +@parametrize("text", ["SELECT * FROM "]) def test_table_casing(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - [schema('PUBLIC')] + cased_rels) + [schema("PUBLIC")] + cased_rels + ) -@parametrize('completer', completers(casing=False)) -@parametrize('text', [ - 'INSERT INTO users ()', - 'INSERT INTO users()', - 'INSERT INTO users () SELECT * FROM orders;', - 'INSERT INTO users() SELECT * FROM users u cross join orders o', -]) +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "INSERT INTO users ()", + "INSERT INTO users()", + "INSERT INTO users () SELECT * FROM orders;", + "INSERT INTO users() SELECT * FROM users u cross join orders o", + ], +) def test_insert(completer, text): - position = text.find('(') + 1 + position = text.find("(") + 1 result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - testdata.columns('users')) + assert 
completions_to_set(result) == completions_to_set(testdata.columns("users")) -@parametrize('completer', completers(casing=False, aliasing=False)) +@parametrize("completer", completers(casing=False, aliasing=False)) def test_suggest_cte_names(completer): - text = ''' + text = """ WITH cte1 AS (SELECT a, b, c FROM foo), cte2 AS (SELECT d, e, f FROM bar) SELECT * FROM - ''' + """ result = get_result(completer, text) - expected = completions_to_set([ - Completion('cte1', 0, display_meta='table'), - Completion('cte2', 0, display_meta='table'), - ]) + expected = completions_to_set( + [ + Completion("cte1", 0, display_meta="table"), + Completion("cte2", 0, display_meta="table"), + ] + ) assert expected <= completions_to_set(result) -@parametrize('completer', completers(casing=False, qualify=no_qual)) +@parametrize("completer", completers(casing=False, qualify=no_qual)) def test_suggest_columns_from_cte(completer): result = get_result( completer, - 'WITH cte AS (SELECT foo, bar FROM baz) SELECT FROM cte', - len('WITH cte AS (SELECT foo, bar FROM baz) SELECT ') - ) - expected = ( - [ - Completion('foo', 0, display_meta='column'), - Completion('bar', 0, display_meta='column'), - ] + testdata.functions_and_keywords() + "WITH cte AS (SELECT foo, bar FROM baz) SELECT FROM cte", + len("WITH cte AS (SELECT foo, bar FROM baz) SELECT "), ) + expected = [ + Completion("foo", 0, display_meta="column"), + Completion("bar", 0, display_meta="column"), + ] + testdata.functions_and_keywords() assert completions_to_set(expected) == completions_to_set(result) -@parametrize('completer', completers(casing=False, qualify=no_qual)) -@parametrize('text', [ - 'WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte WHERE cte.', - 'WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte c WHERE c.', -]) +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte WHERE cte.", + "WITH cte AS (SELECT foo FROM bar) 
SELECT * FROM cte c WHERE c.", + ], +) def test_cte_qualified_columns(completer, text): result = get_result(completer, text) - expected = [Completion('foo', 0, display_meta='column')] + expected = [Completion("foo", 0, display_meta="column")] assert completions_to_set(expected) == completions_to_set(result) -@parametrize('keyword_casing,expected,texts', [ - ('upper', 'SELECT', ('', 's', 'S', 'Sel')), - ('lower', 'select', ('', 's', 'S', 'Sel')), - ('auto', 'SELECT', ('', 'S', 'SEL', 'seL')), - ('auto', 'select', ('s', 'sel', 'SEl')), -]) +@parametrize( + "keyword_casing,expected,texts", + [ + ("upper", "SELECT", ("", "s", "S", "Sel")), + ("lower", "select", ("", "s", "S", "Sel")), + ("auto", "SELECT", ("", "S", "SEL", "seL")), + ("auto", "select", ("s", "sel", "SEl")), + ], +) def test_keyword_casing_upper(keyword_casing, expected, texts): for text in texts: - completer = testdata.get_completer({'keyword_casing': keyword_casing}) + completer = testdata.get_completer({"keyword_casing": keyword_casing}) completions = get_result(completer, text) assert expected in [cpl.text for cpl in completions] -@parametrize('completer', completers()) +@parametrize("completer", completers()) def test_keyword_after_alter(completer): - text = 'ALTER TABLE users ALTER ' - expected = Completion('COLUMN', start_position=0, display_meta='keyword') + text = "ALTER TABLE users ALTER " + expected = Completion("COLUMN", start_position=0, display_meta="keyword") completions = get_result(completer, text) assert expected in completions -@parametrize('completer', completers()) +@parametrize("completer", completers()) def test_set_schema(completer): - text = ('SET SCHEMA ') + text = "SET SCHEMA " result = get_result(completer, text) - expected = completions_to_set([schema(u"'public'")]) + expected = completions_to_set([schema("'public'")]) assert completions_to_set(result) == expected -@parametrize('completer', completers()) +@parametrize("completer", completers()) def 
test_special_name_completion(completer): - result = get_result(completer, '\\t') - assert completions_to_set(result) == completions_to_set([Completion( - text='\\timing', start_position=-2, display_meta='Toggle timing of commands.')]) + result = get_result(completer, "\\t") + assert completions_to_set(result) == completions_to_set( + [ + Completion( + text="\\timing", + start_position=-2, + display_meta="Toggle timing of commands.", + ) + ] + ) diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py index 7e6a8225..ed647fc6 100644 --- a/tests/test_sqlcompletion.py +++ b/tests/test_sqlcompletion.py @@ -1,685 +1,721 @@ from pgcli.packages.sqlcompletion import ( - suggest_type, Special, Database, Schema, Table, Column, View, Keyword, - FromClauseItem, Function, Datatype, Alias, JoinCondition, Join) + suggest_type, + Special, + Database, + Schema, + Table, + Column, + View, + Keyword, + FromClauseItem, + Function, + Datatype, + Alias, + JoinCondition, + Join, +) from pgcli.packages.parseutils.tables import TableReference import pytest -def cols_etc(table, schema=None, alias=None, is_function=False, parent=None, - last_keyword=None): +def cols_etc( + table, schema=None, alias=None, is_function=False, parent=None, last_keyword=None +): """Returns the expected select-clause suggestions for a single-table select.""" - return set([ - Column(table_refs=(TableReference(schema, table, alias, is_function),), - qualifiable=True), - Function(schema=parent), - Keyword(last_keyword)]) + return set( + [ + Column( + table_refs=(TableReference(schema, table, alias, is_function),), + qualifiable=True, + ), + Function(schema=parent), + Keyword(last_keyword), + ] + ) def test_select_suggests_cols_with_visible_table_scope(): - suggestions = suggest_type('SELECT FROM tabl', 'SELECT ') - assert set(suggestions) == cols_etc('tabl', last_keyword='SELECT') + suggestions = suggest_type("SELECT FROM tabl", "SELECT ") + assert set(suggestions) == cols_etc("tabl", 
last_keyword="SELECT") def test_select_suggests_cols_with_qualified_table_scope(): - suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ') - assert set(suggestions) == cols_etc('tabl', 'sch', last_keyword='SELECT') + suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ") + assert set(suggestions) == cols_etc("tabl", "sch", last_keyword="SELECT") def test_cte_does_not_crash(): - sql = 'WITH CTE AS (SELECT F.* FROM Foo F WHERE F.Bar > 23) SELECT C.* FROM CTE C WHERE C.FooID BETWEEN 123 AND 234;' + sql = "WITH CTE AS (SELECT F.* FROM Foo F WHERE F.Bar > 23) SELECT C.* FROM CTE C WHERE C.FooID BETWEEN 123 AND 234;" for i in range(len(sql)): - suggestions = suggest_type(sql[:i+1], sql[:i+1]) + suggestions = suggest_type(sql[: i + 1], sql[: i + 1]) -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM "tabl" WHERE ', -]) +@pytest.mark.parametrize("expression", ['SELECT * FROM "tabl" WHERE ']) def test_where_suggests_columns_functions_quoted_table(expression): - expected = cols_etc('tabl', alias='"tabl"', last_keyword='WHERE') + expected = cols_etc("tabl", alias='"tabl"', last_keyword="WHERE") suggestions = suggest_type(expression, expression) assert expected == set(suggestions) -@pytest.mark.parametrize('expression', [ - 'INSERT INTO OtherTabl(ID, Name) SELECT * FROM tabl WHERE ', - 'INSERT INTO OtherTabl SELECT * FROM tabl WHERE ', - 'SELECT * FROM tabl WHERE ', - 'SELECT * FROM tabl WHERE (', - 'SELECT * FROM tabl WHERE foo = ', - 'SELECT * FROM tabl WHERE bar OR ', - 'SELECT * FROM tabl WHERE foo = 1 AND ', - 'SELECT * FROM tabl WHERE (bar > 10 AND ', - 'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (', - 'SELECT * FROM tabl WHERE 10 < ', - 'SELECT * FROM tabl WHERE foo BETWEEN ', - 'SELECT * FROM tabl WHERE foo BETWEEN foo AND ', -]) +@pytest.mark.parametrize( + "expression", + [ + "INSERT INTO OtherTabl(ID, Name) SELECT * FROM tabl WHERE ", + "INSERT INTO OtherTabl SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE ", + "SELECT * FROM 
tabl WHERE (", + "SELECT * FROM tabl WHERE foo = ", + "SELECT * FROM tabl WHERE bar OR ", + "SELECT * FROM tabl WHERE foo = 1 AND ", + "SELECT * FROM tabl WHERE (bar > 10 AND ", + "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (", + "SELECT * FROM tabl WHERE 10 < ", + "SELECT * FROM tabl WHERE foo BETWEEN ", + "SELECT * FROM tabl WHERE foo BETWEEN foo AND ", + ], +) def test_where_suggests_columns_functions(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == cols_etc('tabl', last_keyword='WHERE') + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM tabl WHERE foo IN (', - 'SELECT * FROM tabl WHERE foo IN (bar, ', -]) +@pytest.mark.parametrize( + "expression", + ["SELECT * FROM tabl WHERE foo IN (", "SELECT * FROM tabl WHERE foo IN (bar, "], +) def test_where_in_suggests_columns(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == cols_etc('tabl', last_keyword='WHERE') + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") -@pytest.mark.parametrize('expression', [ - 'SELECT 1 AS ', - 'SELECT 1 FROM tabl AS ', -]) +@pytest.mark.parametrize("expression", ["SELECT 1 AS ", "SELECT 1 FROM tabl AS "]) def test_after_as(expression): suggestions = suggest_type(expression, expression) assert set(suggestions) == set() def test_where_equals_any_suggests_columns_or_keywords(): - text = 'SELECT * FROM tabl WHERE foo = ANY(' + text = "SELECT * FROM tabl WHERE foo = ANY(" suggestions = suggest_type(text, text) - assert set(suggestions) == cols_etc('tabl', last_keyword='WHERE') + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") def test_lparen_suggests_cols(): - suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(') - assert set(suggestion) == set([ - Column(table_refs=((None, 'tbl', None, False),), qualifiable=True)]) + suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(") 
+ assert set(suggestion) == set( + [Column(table_refs=((None, "tbl", None, False),), qualifiable=True)] + ) def test_select_suggests_cols_and_funcs(): - suggestions = suggest_type('SELECT ', 'SELECT ') - assert set(suggestions) == set([ - Column(table_refs=(), qualifiable=True), - Function(schema=None), - Keyword('SELECT'), - ]) + suggestions = suggest_type("SELECT ", "SELECT ") + assert set(suggestions) == set( + [ + Column(table_refs=(), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + ] + ) -@pytest.mark.parametrize('expression', [ - 'INSERT INTO ', - 'COPY ', - 'UPDATE ', - 'DESCRIBE ', -]) +@pytest.mark.parametrize( + "expression", ["INSERT INTO ", "COPY ", "UPDATE ", "DESCRIBE "] +) def test_suggests_tables_views_and_schemas(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == set([ - Table(schema=None), - View(schema=None), - Schema(), - ]) + assert set(suggestions) == set([Table(schema=None), View(schema=None), Schema()]) -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM ', -]) +@pytest.mark.parametrize("expression", ["SELECT * FROM "]) def test_suggest_tables_views_schemas_and_functions(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == set([ - FromClauseItem(schema=None), - Schema() - ]) + assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM foo JOIN bar on bar.barid = foo.barid JOIN ', - 'SELECT * FROM foo JOIN bar USING (barid) JOIN ' -]) +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM foo JOIN bar on bar.barid = foo.barid JOIN ", + "SELECT * FROM foo JOIN bar USING (barid) JOIN ", + ], +) def test_suggest_after_join_with_two_tables(expression): suggestions = suggest_type(expression, expression) - tables = tuple([(None, 'foo', None, False), (None, 'bar', None, False)]) - assert set(suggestions) == set([ - FromClauseItem(schema=None, table_refs=tables), - 
Join(tables, None), - Schema(), - ]) + tables = tuple([(None, "foo", None, False), (None, "bar", None, False)]) + assert set(suggestions) == set( + [FromClauseItem(schema=None, table_refs=tables), Join(tables, None), Schema()] + ) -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM foo JOIN ', - 'SELECT * FROM foo JOIN bar' -]) +@pytest.mark.parametrize( + "expression", ["SELECT * FROM foo JOIN ", "SELECT * FROM foo JOIN bar"] +) def test_suggest_after_join_with_one_table(expression): suggestions = suggest_type(expression, expression) - tables = ((None, 'foo', None, False),) - assert set(suggestions) == set([ - FromClauseItem(schema=None, table_refs=tables), - Join(((None, 'foo', None, False),), None), - Schema(), - ]) + tables = ((None, "foo", None, False),) + assert set(suggestions) == set( + [ + FromClauseItem(schema=None, table_refs=tables), + Join(((None, "foo", None, False),), None), + Schema(), + ] + ) -@pytest.mark.parametrize('expression', [ - 'INSERT INTO sch.', - 'COPY sch.', - 'DESCRIBE sch.', -]) +@pytest.mark.parametrize( + "expression", ["INSERT INTO sch.", "COPY sch.", "DESCRIBE sch."] +) def test_suggest_qualified_tables_and_views(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == set([ - Table(schema='sch'), - View(schema='sch'), - ]) + assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")]) -@pytest.mark.parametrize('expression', [ - 'UPDATE sch.', -]) +@pytest.mark.parametrize("expression", ["UPDATE sch."]) def test_suggest_qualified_aliasable_tables_and_views(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == set([ - Table(schema='sch'), - View(schema='sch'), - ]) + assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")]) -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM sch.', - 'SELECT * FROM sch."', - 'SELECT * FROM sch."foo', - 'SELECT * FROM "sch".', - 'SELECT * FROM "sch"."', -]) +@pytest.mark.parametrize( 
+ "expression", + [ + "SELECT * FROM sch.", + 'SELECT * FROM sch."', + 'SELECT * FROM sch."foo', + 'SELECT * FROM "sch".', + 'SELECT * FROM "sch"."', + ], +) def test_suggest_qualified_tables_views_and_functions(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == set([FromClauseItem(schema='sch')]) + assert set(suggestions) == set([FromClauseItem(schema="sch")]) -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM foo JOIN sch.', -]) +@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."]) def test_suggest_qualified_tables_views_functions_and_joins(expression): suggestions = suggest_type(expression, expression) - tbls = tuple([(None, 'foo', None, False)]) - assert set(suggestions) == set([ - FromClauseItem(schema='sch', table_refs=tbls), - Join(tbls, 'sch'), - ]) + tbls = tuple([(None, "foo", None, False)]) + assert set(suggestions) == set( + [FromClauseItem(schema="sch", table_refs=tbls), Join(tbls, "sch")] + ) def test_truncate_suggests_tables_and_schemas(): - suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ') - assert set(suggestions) == set([ - Table(schema=None), - Schema()]) + suggestions = suggest_type("TRUNCATE ", "TRUNCATE ") + assert set(suggestions) == set([Table(schema=None), Schema()]) def test_truncate_suggests_qualified_tables(): - suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.') - assert set(suggestions) == set([ - Table(schema='sch')]) + suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.") + assert set(suggestions) == set([Table(schema="sch")]) -@pytest.mark.parametrize('text', [ - 'SELECT DISTINCT ', - 'INSERT INTO foo SELECT DISTINCT ' -]) +@pytest.mark.parametrize( + "text", ["SELECT DISTINCT ", "INSERT INTO foo SELECT DISTINCT "] +) def test_distinct_suggests_cols(text): suggestions = suggest_type(text, text) - assert set(suggestions) == set([ - Column(table_refs=(), local_tables=(), qualifiable=True), - Function(schema=None), - Keyword('DISTINCT') - ]) - - 
-@pytest.mark.parametrize('text, text_before, last_keyword', [ - ( - 'SELECT DISTINCT FROM tbl x JOIN tbl1 y', - 'SELECT DISTINCT', - 'SELECT', - ), - ( - 'SELECT * FROM tbl x JOIN tbl1 y ORDER BY ', - 'SELECT * FROM tbl x JOIN tbl1 y ORDER BY ', - 'ORDER BY', + assert set(suggestions) == set( + [ + Column(table_refs=(), local_tables=(), qualifiable=True), + Function(schema=None), + Keyword("DISTINCT"), + ] ) -]) -def test_distinct_and_order_by_suggestions_with_aliases(text, text_before, - last_keyword): - suggestions = suggest_type(text, text_before) - assert set(suggestions) == set([ - Column( - table_refs=( - TableReference(None, 'tbl', 'x', False), - TableReference(None, 'tbl1', 'y', False), - ), - local_tables=(), - qualifiable=True + + +@pytest.mark.parametrize( + "text, text_before, last_keyword", + [ + ("SELECT DISTINCT FROM tbl x JOIN tbl1 y", "SELECT DISTINCT", "SELECT"), + ( + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ", + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ", + "ORDER BY", ), - Function(schema=None), - Keyword(last_keyword) - ]) - - -@pytest.mark.parametrize('text, text_before', [ - ( - 'SELECT DISTINCT x. FROM tbl x JOIN tbl1 y', - 'SELECT DISTINCT x.' - ), - ( - 'SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.', - 'SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.' + ], +) +def test_distinct_and_order_by_suggestions_with_aliases( + text, text_before, last_keyword +): + suggestions = suggest_type(text, text_before) + assert set(suggestions) == set( + [ + Column( + table_refs=( + TableReference(None, "tbl", "x", False), + TableReference(None, "tbl1", "y", False), + ), + local_tables=(), + qualifiable=True, + ), + Function(schema=None), + Keyword(last_keyword), + ] ) -]) + + +@pytest.mark.parametrize( + "text, text_before", + [ + ("SELECT DISTINCT x. 
FROM tbl x JOIN tbl1 y", "SELECT DISTINCT x."), + ( + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.", + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.", + ), + ], +) def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before): suggestions = suggest_type(text, text_before) - assert set(suggestions) == set([ - Column( - table_refs=(TableReference(None, 'tbl', 'x', False),), - local_tables=(), - qualifiable=False - ), - Table(schema='x'), - View(schema='x'), - Function(schema='x'), - ]) + assert set(suggestions) == set( + [ + Column( + table_refs=(TableReference(None, "tbl", "x", False),), + local_tables=(), + qualifiable=False, + ), + Table(schema="x"), + View(schema="x"), + Function(schema="x"), + ] + ) def test_col_comma_suggests_cols(): - suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,') - assert set(suggestions) == set([ - Column(table_refs=((None, 'tbl', None, False),), qualifiable=True), - Function(schema=None), - Keyword('SELECT'), - ]) + suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,") + assert set(suggestions) == set( + [ + Column(table_refs=((None, "tbl", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + ] + ) def test_table_comma_suggests_tables_and_schemas(): - suggestions = suggest_type('SELECT a, b FROM tbl1, ', - 'SELECT a, b FROM tbl1, ') - assert set(suggestions) == set([ - FromClauseItem(schema=None), - Schema(), - ]) + suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ") + assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) def test_into_suggests_tables_and_schemas(): - suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ') - assert set(suggestion) == set([ - Table(schema=None), - View(schema=None), - Schema(), - ]) + suggestion = suggest_type("INSERT INTO ", "INSERT INTO ") + assert set(suggestion) == set([Table(schema=None), View(schema=None), Schema()]) -@pytest.mark.parametrize('text', [ - 'INSERT INTO abc (', 
- 'INSERT INTO abc () SELECT * FROM hij;', -]) +@pytest.mark.parametrize( + "text", ["INSERT INTO abc (", "INSERT INTO abc () SELECT * FROM hij;"] +) def test_insert_into_lparen_suggests_cols(text): - suggestions = suggest_type(text, 'INSERT INTO abc (') + suggestions = suggest_type(text, "INSERT INTO abc (") assert suggestions == ( - Column( - table_refs=((None, 'abc', None, False),), - context='insert' - ), + Column(table_refs=((None, "abc", None, False),), context="insert"), ) def test_insert_into_lparen_partial_text_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i') + suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i") assert suggestions == ( - Column( - table_refs=((None, 'abc', None, False),), - context='insert' - ), + Column(table_refs=((None, "abc", None, False),), context="insert"), ) def test_insert_into_lparen_comma_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,') + suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,") assert suggestions == ( - Column( - table_refs=((None, 'abc', None, False),), - context='insert' - ), + Column(table_refs=((None, "abc", None, False),), context="insert"), ) def test_partially_typed_col_name_suggests_col_names(): - suggestions = suggest_type('SELECT * FROM tabl WHERE col_n', - 'SELECT * FROM tabl WHERE col_n') - assert set(suggestions) == cols_etc('tabl', last_keyword='WHERE') + suggestions = suggest_type( + "SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n" + ) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): - suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.') - assert set(suggestions) == set([ - Column(table_refs=((None, 'tabl', None, False),)), - Table(schema='tabl'), - View(schema='tabl'), - Function(schema='tabl'), - ]) + suggestions = suggest_type("SELECT tabl. 
FROM tabl", "SELECT tabl.") + assert set(suggestions) == set( + [ + Column(table_refs=((None, "tabl", None, False),)), + Table(schema="tabl"), + View(schema="tabl"), + Function(schema="tabl"), + ] + ) -@pytest.mark.parametrize('sql', [ - 'SELECT t1. FROM tabl1 t1', - 'SELECT t1. FROM tabl1 t1, tabl2 t2', - 'SELECT t1. FROM "tabl1" t1', - 'SELECT t1. FROM "tabl1" t1, "tabl2" t2', - ]) +@pytest.mark.parametrize( + "sql", + [ + "SELECT t1. FROM tabl1 t1", + "SELECT t1. FROM tabl1 t1, tabl2 t2", + 'SELECT t1. FROM "tabl1" t1', + 'SELECT t1. FROM "tabl1" t1, "tabl2" t2', + ], +) def test_dot_suggests_cols_of_an_alias(sql): - suggestions = suggest_type(sql, 'SELECT t1.') - assert set(suggestions) == set([ - Table(schema='t1'), - View(schema='t1'), - Column(table_refs=((None, 'tabl1', 't1', False),)), - Function(schema='t1'), - ]) + suggestions = suggest_type(sql, "SELECT t1.") + assert set(suggestions) == set( + [ + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + ] + ) -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM tabl1 t1 WHERE t1.', - 'SELECT * FROM tabl1 t1, tabl2 t2 WHERE t1.', - 'SELECT * FROM "tabl1" t1 WHERE t1.', - 'SELECT * FROM "tabl1" t1, tabl2 t2 WHERE t1.', - ]) +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM tabl1 t1 WHERE t1.", + "SELECT * FROM tabl1 t1, tabl2 t2 WHERE t1.", + 'SELECT * FROM "tabl1" t1 WHERE t1.', + 'SELECT * FROM "tabl1" t1, tabl2 t2 WHERE t1.', + ], +) def test_dot_suggests_cols_of_an_alias_where(sql): suggestions = suggest_type(sql, sql) - assert set(suggestions) == set([ - Table(schema='t1'), - View(schema='t1'), - Column(table_refs=((None, 'tabl1', 't1', False),)), - Function(schema='t1'), - ]) + assert set(suggestions) == set( + [ + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + ] + ) def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): - suggestions = 
suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2', - 'SELECT t1.a, t2.') - assert set(suggestions) == set([ - Column(table_refs=((None, 'tabl2', 't2', False),)), - Table(schema='t2'), - View(schema='t2'), - Function(schema='t2'), - ]) + suggestions = suggest_type( + "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2." + ) + assert set(suggestions) == set( + [ + Column(table_refs=((None, "tabl2", "t2", False),)), + Table(schema="t2"), + View(schema="t2"), + Function(schema="t2"), + ] + ) -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (', - 'SELECT * FROM foo WHERE EXISTS (', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (', -]) +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (", + "SELECT * FROM foo WHERE EXISTS (", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (", + ], +) def test_sub_select_suggests_keyword(expression): suggestion = suggest_type(expression, expression) assert suggestion == (Keyword(),) -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (S', - 'SELECT * FROM foo WHERE EXISTS (S', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (S', -]) +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (S", + "SELECT * FROM foo WHERE EXISTS (S", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (S", + ], +) def test_sub_select_partial_text_suggests_keyword(expression): suggestion = suggest_type(expression, expression) - assert suggestion ==(Keyword(),) + assert suggestion == (Keyword(),) def test_outer_table_reference_in_exists_subquery_suggests_columns(): - q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.' + q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f." 
suggestions = suggest_type(q, q) - assert set(suggestions) == set([ - Column(table_refs=((None, 'foo', 'f', False),)), - Table(schema='f'), - View(schema='f'), - Function(schema='f'), - ]) + assert set(suggestions) == set( + [ + Column(table_refs=((None, "foo", "f", False),)), + Table(schema="f"), + View(schema="f"), + Function(schema="f"), + ] + ) -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (SELECT * FROM ', -]) +@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "]) def test_sub_select_table_name_completion(expression): suggestion = suggest_type(expression, expression) - assert set(suggestion) == set([ - FromClauseItem(schema=None), - Schema(), - ]) + assert set(suggestion) == set([FromClauseItem(schema=None), Schema()]) -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ', -]) +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ", + ], +) def test_sub_select_table_name_completion_with_outer_table(expression): suggestion = suggest_type(expression, expression) - tbls = tuple([(None, 'foo', None, False)]) - assert set(suggestion) == set([ - FromClauseItem(schema=None, table_refs=tbls), - Schema(), - ]) + tbls = tuple([(None, "foo", None, False)]) + assert set(suggestion) == set( + [FromClauseItem(schema=None, table_refs=tbls), Schema()] + ) def test_sub_select_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT FROM abc', - 'SELECT * FROM (SELECT ') - assert set(suggestions) == set([ - Column(table_refs=((None, 'abc', None, False),), qualifiable=True), - Function(schema=None), - Keyword('SELECT'), - ]) + suggestions = suggest_type( + "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT " + ) + assert set(suggestions) == set( + [ + Column(table_refs=((None, "abc", None, False),), qualifiable=True), 
+ Function(schema=None), + Keyword("SELECT"), + ] + ) @pytest.mark.xfail def test_sub_select_multiple_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc', - 'SELECT * FROM (SELECT a, ') - assert set(suggestions) == cols_etc('abc') + suggestions = suggest_type( + "SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, " + ) + assert set(suggestions) == cols_etc("abc") def test_sub_select_dot_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t', - 'SELECT * FROM (SELECT t.') - assert set(suggestions) == set([ - Column(table_refs=((None, 'tabl', 't', False),)), - Table(schema='t'), - View(schema='t'), - Function(schema='t'), - ]) + suggestions = suggest_type( + "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t." + ) + assert set(suggestions) == set( + [ + Column(table_refs=((None, "tabl", "t", False),)), + Table(schema="t"), + View(schema="t"), + Function(schema="t"), + ] + ) -@pytest.mark.parametrize('join_type',('', 'INNER', 'LEFT', 'RIGHT OUTER',)) -@pytest.mark.parametrize('tbl_alias',('', 'foo',)) +@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER")) +@pytest.mark.parametrize("tbl_alias", ("", "foo")) def test_join_suggests_tables_and_schemas(tbl_alias, join_type): - text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type) + text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type) suggestion = suggest_type(text, text) - tbls = tuple([(None, 'abc', tbl_alias or None, False)]) - assert set(suggestion) == set([ - FromClauseItem(schema=None, table_refs=tbls), - Schema(), - Join(tbls, None), - ]) + tbls = tuple([(None, "abc", tbl_alias or None, False)]) + assert set(suggestion) == set( + [FromClauseItem(schema=None, table_refs=tbls), Schema(), Join(tbls, None)] + ) def test_left_join_with_comma(): - text = 'select * from foo f left join bar b,' + text = "select * from foo f left join bar b," suggestions = suggest_type(text, 
text) # tbls should also include (None, 'bar', 'b', False) # but there's a bug with commas - tbls = tuple([(None, 'foo', 'f', False)]) - assert set(suggestions) == set([ - FromClauseItem(schema=None, table_refs=tbls), - Schema(), - ]) + tbls = tuple([(None, "foo", "f", False)]) + assert set(suggestions) == set( + [FromClauseItem(schema=None, table_refs=tbls), Schema()] + ) -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.', -]) +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.", + ], +) def test_join_alias_dot_suggests_cols1(sql): suggestions = suggest_type(sql, sql) - tables = ((None, 'abc', 'a', False), (None, 'def', 'd', False)) - assert set(suggestions) == set([ - Column(table_refs=((None, 'abc', 'a', False),)), - Table(schema='a'), - View(schema='a'), - Function(schema='a'), - JoinCondition(table_refs=tables, parent=(None, 'abc', 'a', False)) - ]) + tables = ((None, "abc", "a", False), (None, "def", "d", False)) + assert set(suggestions) == set( + [ + Column(table_refs=((None, "abc", "a", False),)), + Table(schema="a"), + View(schema="a"), + Function(schema="a"), + JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)), + ] + ) -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.id = d.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.', -]) +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.", + ], +) def test_join_alias_dot_suggests_cols2(sql): suggestion = suggest_type(sql, sql) - assert set(suggestion) == set([ - Column(table_refs=((None, 'def', 'd', False),)), - Table(schema='d'), - View(schema='d'), - Function(schema='d'), - ]) + assert set(suggestion) == set( + [ + Column(table_refs=((None, "def", "d", False),)), + Table(schema="d"), + 
View(schema="d"), + Function(schema="d"), + ] + ) -@pytest.mark.parametrize('sql', [ - 'select a.x, b.y from abc a join bcd b on ', - '''select a.x, b.y +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on ", + """select a.x, b.y from abc a join bcd b on -''', - '''select a.x, b.y +""", + """select a.x, b.y from abc a join bcd b -on ''', - 'select a.x, b.y from abc a join bcd b on a.id = b.id OR ', -]) +on """, + "select a.x, b.y from abc a join bcd b on a.id = b.id OR ", + ], +) def test_on_suggests_aliases_and_join_conditions(sql): suggestions = suggest_type(sql, sql) - tables = ((None, 'abc', 'a', False), (None, 'bcd', 'b', False)) - assert set(suggestions) == set((JoinCondition(table_refs=tables, parent=None), - Alias(aliases=('a', 'b',)),)) + tables = ((None, "abc", "a", False), (None, "bcd", "b", False)) + assert set(suggestions) == set( + (JoinCondition(table_refs=tables, parent=None), Alias(aliases=("a", "b"))) + ) -@pytest.mark.parametrize('sql', [ - 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ', - 'select abc.x, bcd.y from abc join bcd on ', -]) +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ", + "select abc.x, bcd.y from abc join bcd on ", + ], +) def test_on_suggests_tables_and_join_conditions(sql): suggestions = suggest_type(sql, sql) - tables = ((None, 'abc', None, False), (None, 'bcd', None, False)) - assert set(suggestions) == set((JoinCondition(table_refs=tables, parent=None), - Alias(aliases=('abc', 'bcd',)),)) + tables = ((None, "abc", None, False), (None, "bcd", None, False)) + assert set(suggestions) == set( + (JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd"))) + ) -@pytest.mark.parametrize('sql', [ - 'select a.x, b.y from abc a join bcd b on a.id = ', - 'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ', -]) +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on a.id 
= ", + "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ", + ], +) def test_on_suggests_aliases_right_side(sql): suggestions = suggest_type(sql, sql) - assert suggestions == (Alias(aliases=('a', 'b',)),) + assert suggestions == (Alias(aliases=("a", "b")),) -@pytest.mark.parametrize('sql', [ - 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ', - 'select abc.x, bcd.y from abc join bcd on ', -]) +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ", + "select abc.x, bcd.y from abc join bcd on ", + ], +) def test_on_suggests_tables_and_join_conditions_right_side(sql): suggestions = suggest_type(sql, sql) - tables = ((None, 'abc', None, False), (None, 'bcd', None, False)) - assert set(suggestions) == set((JoinCondition(table_refs=tables, parent=None), - Alias(aliases=('abc', 'bcd',)),)) + tables = ((None, "abc", None, False), (None, "bcd", None, False)) + assert set(suggestions) == set( + (JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd"))) + ) -@pytest.mark.parametrize('text', ( - 'select * from abc inner join def using (', - 'select * from abc inner join def using (col1, ', - 'insert into hij select * from abc inner join def using (', - '''insert into hij(x, y, z) - select * from abc inner join def using (col1, ''', - '''insert into hij (a,b,c) - select * from abc inner join def using (col1, ''', -)) +@pytest.mark.parametrize( + "text", + ( + "select * from abc inner join def using (", + "select * from abc inner join def using (col1, ", + "insert into hij select * from abc inner join def using (", + """insert into hij(x, y, z) + select * from abc inner join def using (col1, """, + """insert into hij (a,b,c) + select * from abc inner join def using (col1, """, + ), +) def test_join_using_suggests_common_columns(text): - tables = ((None, 'abc', None, False), (None, 'def', None, False)) - assert set(suggest_type(text, text)) == set([ - Column(table_refs=tables, 
require_last_table=True),]) + tables = ((None, "abc", None, False), (None, "def", None, False)) + assert set(suggest_type(text, text)) == set( + [Column(table_refs=tables, require_last_table=True)] + ) def test_suggest_columns_after_multiple_joins(): - sql = '''select * from t1 + sql = """select * from t1 inner join t2 ON t1.id = t2.t1_id inner join t3 ON - t2.id = t3.''' + t2.id = t3.""" suggestions = suggest_type(sql, sql) - assert Column(table_refs=((None, 't3', None, False),)) in set(suggestions) + assert Column(table_refs=((None, "t3", None, False),)) in set(suggestions) def test_2_statements_2nd_current(): - suggestions = suggest_type('select * from a; select * from ', - 'select * from a; select * from ') - assert set(suggestions) == set([ - FromClauseItem(schema=None), - Schema(), - ]) + suggestions = suggest_type( + "select * from a; select * from ", "select * from a; select * from " + ) + assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) - suggestions = suggest_type('select * from a; select from b', - 'select * from a; select ') - assert set(suggestions) == set([ - Column(table_refs=((None, 'b', None, False),), qualifiable=True), - Function(schema=None), - Keyword('SELECT') - ]) + suggestions = suggest_type( + "select * from a; select from b", "select * from a; select " + ) + assert set(suggestions) == set( + [ + Column(table_refs=((None, "b", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + ] + ) # Should work even if first statement is invalid - suggestions = suggest_type('select * from; select * from ', - 'select * from; select * from ') - assert set(suggestions) == set([ - FromClauseItem(schema=None), - Schema(), - ]) + suggestions = suggest_type( + "select * from; select * from ", "select * from; select * from " + ) + assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) def test_2_statements_1st_current(): - suggestions = suggest_type('select * from ; select * from b', - 
'select * from ') - assert set(suggestions) == set([ - FromClauseItem(schema=None), - Schema(), - ]) + suggestions = suggest_type("select * from ; select * from b", "select * from ") + assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) - suggestions = suggest_type('select from a; select * from b', - 'select ') - assert set(suggestions) == cols_etc('a', last_keyword='SELECT') + suggestions = suggest_type("select from a; select * from b", "select ") + assert set(suggestions) == cols_etc("a", last_keyword="SELECT") def test_3_statements_2nd_current(): - suggestions = suggest_type('select * from a; select * from ; select * from c', - 'select * from a; select * from ') - assert set(suggestions) == set([ - FromClauseItem(schema=None), - Schema(), - ]) + suggestions = suggest_type( + "select * from a; select * from ; select * from c", + "select * from a; select * from ", + ) + assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) - suggestions = suggest_type('select * from a; select from b; select * from c', - 'select * from a; select ') - assert set(suggestions) == cols_etc('b', last_keyword='SELECT') + suggestions = suggest_type( + "select * from a; select from b; select * from c", "select * from a; select " + ) + assert set(suggestions) == cols_etc("b", last_keyword="SELECT") -@pytest.mark.parametrize('text', [ -''' +@pytest.mark.parametrize( + "text", + [ + """ CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$ SELECT FROM foo; SELECT 2 FROM bar; $$ language sql; - ''', - '''create function func2(int, varchar) + """, + """create function func2(int, varchar) RETURNS text language sql AS $func$ SELECT 2 FROM bar; SELECT FROM foo; $func$ - ''', -''' + """, + """ CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$ SELECT 3 FROM foo; SELECT 2 FROM bar; @@ -691,8 +727,8 @@ $func$ SELECT 2 FROM bar; SELECT FROM foo; $func$ - ''', -''' + """, + """ SELECT * FROM baz; CREATE OR REPLACE FUNCTION func() RETURNS setof int AS 
$func$ SELECT FROM foo; @@ -706,25 +742,28 @@ SELECT 3 FROM bar; SELECT FROM foo; $func$ SELECT * FROM qux; - ''' -]) + """, + ], +) def test_statements_in_function_body(text): - suggestions = suggest_type(text, text[:text.find(' ') + 1]) - assert set(suggestions) == set([ - Column(table_refs=((None, 'foo', None, False),), qualifiable=True), - Function(schema=None), - Keyword('SELECT'), - ]) + suggestions = suggest_type(text, text[: text.find(" ") + 1]) + assert set(suggestions) == set( + [ + Column(table_refs=((None, "foo", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + ] + ) functions = [ -''' + """ CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$ SELECT 1 FROM foo; SELECT 2 FROM bar; $$ language sql; - ''', - ''' + """, + """ create function func2(int, varchar) RETURNS text language sql AS @@ -732,44 +771,44 @@ language sql AS SELECT 2 FROM bar; SELECT 1 FROM foo; '; - ''' + """, ] -@pytest.mark.parametrize('text', functions) +@pytest.mark.parametrize("text", functions) def test_statements_with_cursor_after_function_body(text): - suggestions = suggest_type(text, text[:text.find('; ') + 1]) + suggestions = suggest_type(text, text[: text.find("; ") + 1]) assert set(suggestions) == set([Keyword(), Special()]) -@pytest.mark.parametrize('text', functions) +@pytest.mark.parametrize("text", functions) def test_statements_with_cursor_before_function_body(text): - suggestions = suggest_type(text, '') + suggestions = suggest_type(text, "") assert set(suggestions) == set([Keyword(), Special()]) def test_create_db_with_template(): - suggestions = suggest_type('create database foo with template ', - 'create database foo with template ') + suggestions = suggest_type( + "create database foo with template ", "create database foo with template " + ) assert set(suggestions) == set((Database(),)) -@pytest.mark.parametrize('initial_text',('', ' ', '\t \t',)) +@pytest.mark.parametrize("initial_text", ("", " ", "\t \t")) def 
test_specials_included_for_initial_completion(initial_text): suggestions = suggest_type(initial_text, initial_text) - assert set(suggestions) == \ - set([Keyword(), Special()]) + assert set(suggestions) == set([Keyword(), Special()]) def test_drop_schema_qualified_table_suggests_only_tables(): - text = 'DROP TABLE schema_name.table_name' + text = "DROP TABLE schema_name.table_name" suggestions = suggest_type(text, text) - assert suggestions ==(Table(schema='schema_name'),) + assert suggestions == (Table(schema="schema_name"),) -@pytest.mark.parametrize('text',(',', ' ,', 'sel ,',)) +@pytest.mark.parametrize("text", (",", " ,", "sel ,")) def test_handle_pre_completion_comma_gracefully(text): suggestions = suggest_type(text, text) @@ -777,159 +816,157 @@ def test_handle_pre_completion_comma_gracefully(text): def test_drop_schema_suggests_schemas(): - sql = 'DROP SCHEMA ' - assert suggest_type(sql, sql) ==(Schema(),) + sql = "DROP SCHEMA " + assert suggest_type(sql, sql) == (Schema(),) -@pytest.mark.parametrize('text', [ - 'SELECT x::', - 'SELECT x::y', - 'SELECT (x + y)::', -]) +@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"]) def test_cast_operator_suggests_types(text): - assert set(suggest_type(text, text)) == set([ - Datatype(schema=None), - Table(schema=None), - Schema()]) + assert set(suggest_type(text, text)) == set( + [Datatype(schema=None), Table(schema=None), Schema()] + ) -@pytest.mark.parametrize('text', [ - 'SELECT foo::bar.', - 'SELECT foo::bar.baz', - 'SELECT (x + y)::bar.', -]) +@pytest.mark.parametrize( + "text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."] +) def test_cast_operator_suggests_schema_qualified_types(text): - assert set(suggest_type(text, text)) == set([ - Datatype(schema='bar'), - Table(schema='bar')]) + assert set(suggest_type(text, text)) == set( + [Datatype(schema="bar"), Table(schema="bar")] + ) def test_alter_column_type_suggests_types(): - q = 'ALTER TABLE foo ALTER COLUMN 
bar TYPE ' - assert set(suggest_type(q, q)) == set([ - Datatype(schema=None), - Table(schema=None), - Schema()]) + q = "ALTER TABLE foo ALTER COLUMN bar TYPE " + assert set(suggest_type(q, q)) == set( + [Datatype(schema=None), Table(schema=None), Schema()] + ) -@pytest.mark.parametrize('text', [ - 'CREATE TABLE foo (bar ', - 'CREATE TABLE foo (bar DOU', - 'CREATE TABLE foo (bar INT, baz ', - 'CREATE TABLE foo (bar INT, baz TEXT, qux ', - 'CREATE FUNCTION foo (bar ', - 'CREATE FUNCTION foo (bar INT, baz ', - 'SELECT * FROM foo() AS bar (baz ', - 'SELECT * FROM foo() AS bar (baz INT, qux ', - - # make sure this doesnt trigger special completion - 'CREATE TABLE foo (dt d', -]) +@pytest.mark.parametrize( + "text", + [ + "CREATE TABLE foo (bar ", + "CREATE TABLE foo (bar DOU", + "CREATE TABLE foo (bar INT, baz ", + "CREATE TABLE foo (bar INT, baz TEXT, qux ", + "CREATE FUNCTION foo (bar ", + "CREATE FUNCTION foo (bar INT, baz ", + "SELECT * FROM foo() AS bar (baz ", + "SELECT * FROM foo() AS bar (baz INT, qux ", + # make sure this doesnt trigger special completion + "CREATE TABLE foo (dt d", + ], +) def test_identifier_suggests_types_in_parentheses(text): - assert set(suggest_type(text, text)) == set([ - Datatype(schema=None), - Table(schema=None), - Schema()]) + assert set(suggest_type(text, text)) == set( + [Datatype(schema=None), Table(schema=None), Schema()] + ) -@pytest.mark.parametrize('text', [ - 'SELECT foo ', - 'SELECT foo FROM bar ', - 'SELECT foo AS bar ', - 'SELECT foo bar ', - 'SELECT * FROM foo AS bar ', - 'SELECT * FROM foo bar ', - 'SELECT foo FROM (SELECT bar ' -]) +@pytest.mark.parametrize( + "text", + [ + "SELECT foo ", + "SELECT foo FROM bar ", + "SELECT foo AS bar ", + "SELECT foo bar ", + "SELECT * FROM foo AS bar ", + "SELECT * FROM foo bar ", + "SELECT foo FROM (SELECT bar ", + ], +) def test_alias_suggests_keywords(text): suggestions = suggest_type(text, text) - assert suggestions ==(Keyword(),) + assert suggestions == (Keyword(),) def 
test_invalid_sql(): # issue 317 - text = 'selt *' + text = "selt *" suggestions = suggest_type(text, text) - assert suggestions ==(Keyword(),) + assert suggestions == (Keyword(),) -@pytest.mark.parametrize('text', [ - 'SELECT * FROM foo where created > now() - ', - 'select * from foo where bar ', -]) +@pytest.mark.parametrize( + "text", + ["SELECT * FROM foo where created > now() - ", "select * from foo where bar "], +) def test_suggest_where_keyword(text): # https://github.com/dbcli/mycli/issues/135 suggestions = suggest_type(text, text) - assert set(suggestions) == cols_etc('foo', last_keyword='WHERE') + assert set(suggestions) == cols_etc("foo", last_keyword="WHERE") -@pytest.mark.parametrize('text, before, expected', [ - ('\\ns abc SELECT ', 'SELECT ', [ - Column(table_refs=(), qualifiable=True), - Function(schema=None), - Keyword('SELECT') - ]), - ('\\ns abc SELECT foo ', 'SELECT foo ', (Keyword(),)), - ('\\ns abc SELECT t1. FROM tabl1 t1', 'SELECT t1.', [ - Table(schema='t1'), - View(schema='t1'), - Column(table_refs=((None, 'tabl1', 't1', False),)), - Function(schema='t1') - ]) -]) +@pytest.mark.parametrize( + "text, before, expected", + [ + ( + "\\ns abc SELECT ", + "SELECT ", + [ + Column(table_refs=(), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + ], + ), + ("\\ns abc SELECT foo ", "SELECT foo ", (Keyword(),)), + ( + "\\ns abc SELECT t1. 
FROM tabl1 t1", + "SELECT t1.", + [ + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + ], + ), + ], +) def test_named_query_completion(text, before, expected): suggestions = suggest_type(text, before) assert set(expected) == set(suggestions) def test_select_suggests_fields_from_function(): - suggestions = suggest_type('SELECT FROM func()', 'SELECT ') - assert set(suggestions) == cols_etc( - 'func', is_function=True, last_keyword='SELECT') + suggestions = suggest_type("SELECT FROM func()", "SELECT ") + assert set(suggestions) == cols_etc("func", is_function=True, last_keyword="SELECT") -@pytest.mark.parametrize('sql', [ - '(', -]) +@pytest.mark.parametrize("sql", ["("]) def test_leading_parenthesis(sql): # No assertion for now; just make sure it doesn't crash suggest_type(sql, sql) -@pytest.mark.parametrize('sql', [ - 'select * from "', - 'select * from "foo', -]) +@pytest.mark.parametrize("sql", ['select * from "', 'select * from "foo']) def test_ignore_leading_double_quotes(sql): suggestions = suggest_type(sql, sql) assert FromClauseItem(schema=None) in set(suggestions) -@pytest.mark.parametrize('sql', [ - 'ALTER TABLE foo ALTER COLUMN ', - 'ALTER TABLE foo ALTER COLUMN bar', - 'ALTER TABLE foo DROP COLUMN ', - 'ALTER TABLE foo DROP COLUMN bar', -]) +@pytest.mark.parametrize( + "sql", + [ + "ALTER TABLE foo ALTER COLUMN ", + "ALTER TABLE foo ALTER COLUMN bar", + "ALTER TABLE foo DROP COLUMN ", + "ALTER TABLE foo DROP COLUMN bar", + ], +) def test_column_keyword_suggests_columns(sql): suggestions = suggest_type(sql, sql) - assert set(suggestions) == set([ - Column(table_refs=((None, 'foo', None, False),)), - ]) + assert set(suggestions) == set([Column(table_refs=((None, "foo", None, False),))]) def test_handle_unrecognized_kw_generously(): - sql = 'SELECT * FROM sessions WHERE session = 1 AND ' + sql = "SELECT * FROM sessions WHERE session = 1 AND " suggestions = suggest_type(sql, sql) - 
expected = Column(table_refs=((None, 'sessions', None, False),), - qualifiable=True) + expected = Column(table_refs=((None, "sessions", None, False),), qualifiable=True) assert expected in set(suggestions) -@pytest.mark.parametrize('sql', [ - 'ALTER ', - 'ALTER TABLE foo ALTER ', -]) +@pytest.mark.parametrize("sql", ["ALTER ", "ALTER TABLE foo ALTER "]) def test_keyword_after_alter(sql): - assert Keyword('ALTER') in set(suggest_type(sql, sql)) + assert Keyword("ALTER") in set(suggest_type(sql, sql)) diff --git a/tests/utils.py b/tests/utils.py index 2ef7aa18..2427c30d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,18 +5,20 @@ from pgcli.main import format_output, OutputSettings from pgcli.pgexecute import register_json_typecasters from os import getenv -POSTGRES_USER = getenv('PGUSER', 'postgres') -POSTGRES_HOST = getenv('PGHOST', 'localhost') -POSTGRES_PORT = getenv('PGPORT', 5432) -POSTGRES_PASSWORD = getenv('PGPASSWORD', '') +POSTGRES_USER = getenv("PGUSER", "postgres") +POSTGRES_HOST = getenv("PGHOST", "localhost") +POSTGRES_PORT = getenv("PGPORT", 5432) +POSTGRES_PASSWORD = getenv("PGPASSWORD", "") def db_connection(dbname=None): - conn = psycopg2.connect(user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - port=POSTGRES_PORT, - database=dbname) + conn = psycopg2.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + port=POSTGRES_PORT, + database=dbname, + ) conn.autocommit = True return conn @@ -26,8 +28,8 @@ try: CAN_CONNECT_TO_DB = True SERVER_VERSION = conn.server_version json_types = register_json_typecasters(conn, lambda x: x) - JSON_AVAILABLE = 'json' in json_types - JSONB_AVAILABLE = 'jsonb' in json_types + JSON_AVAILABLE = "json" in json_types + JSONB_AVAILABLE = "jsonb" in json_types except: CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False SERVER_VERSION = 0 @@ -35,51 +37,59 @@ except: dbtest = pytest.mark.skipif( not CAN_CONNECT_TO_DB, - reason="Need a postgres instance at 
localhost accessible by user 'postgres'") + reason="Need a postgres instance at localhost accessible by user 'postgres'", +) requires_json = pytest.mark.skipif( - not JSON_AVAILABLE, - reason='Postgres server unavailable or json type not defined') + not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined" +) requires_jsonb = pytest.mark.skipif( - not JSONB_AVAILABLE, - reason='Postgres server unavailable or jsonb type not defined') + not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined" +) def create_db(dbname): with db_connection().cursor() as cur: try: - cur.execute('''CREATE DATABASE _test_db''') + cur.execute("""CREATE DATABASE _test_db""") except: pass def drop_tables(conn): with conn.cursor() as cur: - cur.execute(''' + cur.execute( + """ DROP SCHEMA public CASCADE; CREATE SCHEMA public; DROP SCHEMA IF EXISTS schema1 CASCADE; - DROP SCHEMA IF EXISTS schema2 CASCADE''') + DROP SCHEMA IF EXISTS schema2 CASCADE""" + ) -def run(executor, sql, join=False, expanded=False, pgspecial=None, - exception_formatter=None): +def run( + executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None +): " Return string output for the sql to be run " results = executor.run(sql, pgspecial, exception_formatter) formatted = [] - settings = OutputSettings(table_format='psql', dcmlfmt='d', floatfmt='g', - expanded=expanded) + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded + ) for title, rows, headers, status, sql, success, is_special in results: formatted.extend(format_output(title, rows, headers, status, settings)) if join: - formatted = '\n'.join(formatted) + formatted = "\n".join(formatted) return formatted def completions_to_set(completions): - return set((completion.display_text, completion.display_meta_text) for completion in completions) + return set( + (completion.display_text, completion.display_meta_text) + for completion in completions + )