diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index c632c009..30cc9fe4 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -306,8 +306,8 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier elif token_v in ('type', '::'): # ALTER TABLE foo SET DATA TYPE bar # SELECT foo::bar - # Note that tables are a form of composite type in postgresql, so they - # are be suggested here as well + # Note that tables are a form of composite type in postgresql, so + # they're suggested here as well schema = (identifier and identifier.get_parent_name()) or [] suggestions = [{'type': 'datatype', 'schema': schema}, {'type': 'table', 'schema': schema}] diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 30a65295..dd48ecdc 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -307,9 +307,17 @@ class PGCompleter(Completer): completions.extend(special) elif suggestion['type'] == 'datatype': - datatypes = self.find_matches(word_before_cursor, + # suggest custom datatypes + types = self.populate_schema_objects( + suggestion['schema'], 'datatypes') + types = self.find_matches(word_before_cursor, types) + completions.extend(types) + + if not suggestion['schema']: + # Also suggest hardcoded types + types = self.find_matches(word_before_cursor, self.datatypes, start_only=True) - completions.extend(datatypes) + completions.extend(types) return completions diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index 658a9eee..de2e76a3 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -16,8 +16,13 @@ metadata = { 'shipments': ['id', 'address', 'user_id'] }}, 'functions': { - 'public': ['func1', 'func2'], - 'custom': ['func3', 'func4']} + 'public': ['func1', 'func2'], + 'custom': ['func3', 'func4'], + }, + 'datatypes': { + 'public': ['typ1', 'typ2'], + 'custom': ['typ3', 'typ4'], + }, } @pytest.fixture @@ -39,10 +44,15 @@ def completer(): for schema, funcs in metadata['functions'].items() for func in funcs] + datatypes = [(schema, datatype) + for schema, datatypes in metadata['datatypes'].items() + for datatype in datatypes] + comp.extend_schemata(schemata) comp.extend_relations(tables, kind='tables') comp.extend_columns(columns, kind='tables') comp.extend_functions(functions) + comp.extend_datatypes(datatypes) comp.set_search_path(['public']) return comp @@ -249,3 +259,22 @@ def test_schema_qualified_function_name(completer, complete_event): assert result == set([ Completion(text='func3', start_position=-len('func')), Completion(text='func4', start_position=-len('func'))]) + + +@pytest.mark.parametrize('text', [ + 'SELECT 1::custom.', + 'CREATE TABLE foo (bar custom.', + 'CREATE FUNCTION foo (bar INT, baz custom.', + 'ALTER TABLE foo ALTER COLUMN bar TYPE custom.', +]) +def test_schema_qualified_type_name(text, completer, complete_event): + pos = len(text) + result = completer.get_completions( + Document(text=text, cursor_position=pos), complete_event) + assert set(result) == set([ + Completion(text='typ3'), + Completion(text='typ4'), + Completion(text='users'), + Completion(text='products'), + Completion(text='shipments'), + ]) \ No newline at end of file diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index 41a06d09..e28a9cb3 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -10,7 +10,8 @@ metadata = { 'select': ['id', 'insert', 'ABC']}, 'views': { 'user_emails': ['id', 'email']}, - 'functions': ['custom_func1', 'custom_func2'] + 'functions': ['custom_func1', 'custom_func2'], + 'datatypes': ['custom_type1', 'custom_type2'], } @pytest.fixture @@ -44,6 +45,10 @@ def completer(): functions = [('public', func) for func in metadata['functions']] comp.extend_functions(functions) + # types + datatypes = [('public', typ) for typ in metadata['datatypes']] + comp.extend_datatypes(datatypes) + comp.set_search_path(['public']) return comp @@ -357,12 +362,25 @@ def test_auto_escaped_col_names(completer, complete_event): @pytest.mark.parametrize('text', [ - 'SELECT 1::DOU', - 'CREATE TABLE foo (bar DOU', - 'CREATE FUNCTION foo (bar INT, baz DOU', + 'SELECT 1::', + 'CREATE TABLE foo (bar ', + 'CREATE FUNCTION foo (bar INT, baz ', + 'ALTER TABLE foo ALTER COLUMN bar TYPE ', ]) def test_suggest_datatype(text, completer, complete_event): pos = len(text) result = completer.get_completions( Document(text=text, cursor_position=pos), complete_event) - assert result == [Completion(text='DOUBLE PRECISION', start_position=-3)] + assert set(result) == set( + [Completion(t) for t in [ + # Custom types + 'custom_type1', 'custom_type2', + + # Tables + 'public', 'users', 'orders', '"select"', + + # Built-in datatypes + ] + completer.datatypes + ]) + +