1
1
mirror of https://github.com/dbcli/pgcli.git synced 2024-10-06 02:07:53 +03:00

Add function to check for unclosed quotations

This commit is contained in:
Darik Gamble 2015-08-17 19:04:31 -04:00 committed by Darik Gamble
parent 27c2eabc08
commit 9be09b8bf4
2 changed files with 67 additions and 2 deletions

View File

@ -2,7 +2,7 @@ from __future__ import print_function
import re import re
import sqlparse import sqlparse
from sqlparse.sql import IdentifierList, Identifier, Function from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation from sqlparse.tokens import Keyword, DML, Punctuation, Token
cleanup_regex = { cleanup_regex = {
# This matches only alphanumerics and underscores. # This matches only alphanumerics and underscores.
@ -188,6 +188,43 @@ def find_prev_keyword(sql):
return None, '' return None, ''
# Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
def is_open_quote(sql):
"""Returns true if the query contains an unclosed quote"""
# parsed can contain one or more semi-colon separated commands
parsed = sqlparse.parse(sql)
return any(_parsed_is_open_quote(p) for p in parsed)
def _parsed_is_open_quote(parsed):
tokens = list(parsed.flatten())
i = 0
while i < len(tokens):
tok = tokens[i]
if tok.match(Token.Error, "'"):
# An unmatched single quote
return True
elif (tok.ttype in Token.Name.Builtin
and dollar_quote_regex.match(tok.value)):
# Find the matching closing dollar quote sign
for (i, tok2) in enumerate(tokens[i+1:], i+1):
if tok2.match(Token.Name.Builtin, tok.value):
break
else:
# No matching dollar sign quote
return True
i += 1
return False
if __name__ == '__main__': if __name__ == '__main__':
sql = 'select * from (select t. from tabl t' sql = 'select * from (select t. from tabl t'
print (extract_tables(sql)) print (extract_tables(sql))

View File

@ -1,6 +1,6 @@
import pytest import pytest
from pgcli.packages.parseutils import extract_tables from pgcli.packages.parseutils import extract_tables
from pgcli.packages.parseutils import find_prev_keyword from pgcli.packages.parseutils import find_prev_keyword, is_open_quote
def test_empty_string(): def test_empty_string():
tables = extract_tables('') tables = extract_tables('')
@ -122,3 +122,31 @@ def test_find_prev_keyword_where(sql):
def test_find_prev_keyword_open_parens(sql): def test_find_prev_keyword_open_parens(sql):
kw, _ = find_prev_keyword(sql) kw, _ = find_prev_keyword(sql)
assert kw.value == '(' assert kw.value == '('
@pytest.mark.parametrize('sql', [
'',
'$$ foo $$',
"$$ 'foo' $$",
'$$ "foo" $$',
'$$ $a$ $$',
'$a$ $$ $a$',
'foo bar $$ baz $$',
])
def test_is_open_quote__closed(sql):
assert not is_open_quote(sql)
@pytest.mark.parametrize('sql', [
'$$',
';;;$$',
'foo $$ bar $$; foo $$',
'$$ foo $a$',
"foo 'bar baz",
"$a$ foo ",
'$$ "foo" ',
'$$ $a$ ',
'foo bar $$ baz',
])
def test_is_open_quote__open(sql):
assert is_open_quote(sql)