1
1
mirror of https://github.com/dbcli/pgcli.git synced 2024-10-06 02:07:53 +03:00

Add function to check for unclosed quotations

This commit is contained in:
Darik Gamble 2015-08-17 19:04:31 -04:00 committed by Darik Gamble
parent 27c2eabc08
commit 9be09b8bf4
2 changed files with 67 additions and 2 deletions

View File

@ -2,7 +2,7 @@ from __future__ import print_function
import re
import sqlparse
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation
from sqlparse.tokens import Keyword, DML, Punctuation, Token
cleanup_regex = {
# This matches only alphanumerics and underscores.
@ -188,6 +188,43 @@ def find_prev_keyword(sql):
return None, ''
# Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
def is_open_quote(sql):
"""Returns true if the query contains an unclosed quote"""
# parsed can contain one or more semi-colon separated commands
parsed = sqlparse.parse(sql)
return any(_parsed_is_open_quote(p) for p in parsed)
def _parsed_is_open_quote(parsed):
tokens = list(parsed.flatten())
i = 0
while i < len(tokens):
tok = tokens[i]
if tok.match(Token.Error, "'"):
# An unmatched single quote
return True
elif (tok.ttype in Token.Name.Builtin
and dollar_quote_regex.match(tok.value)):
# Find the matching closing dollar quote sign
for (i, tok2) in enumerate(tokens[i+1:], i+1):
if tok2.match(Token.Name.Builtin, tok.value):
break
else:
# No matching dollar sign quote
return True
i += 1
return False
if __name__ == '__main__':
sql = 'select * from (select t. from tabl t'
print (extract_tables(sql))

View File

@ -1,6 +1,6 @@
import pytest
from pgcli.packages.parseutils import extract_tables
from pgcli.packages.parseutils import find_prev_keyword
from pgcli.packages.parseutils import find_prev_keyword, is_open_quote
def test_empty_string():
tables = extract_tables('')
@ -122,3 +122,31 @@ def test_find_prev_keyword_where(sql):
def test_find_prev_keyword_open_parens(sql):
kw, _ = find_prev_keyword(sql)
assert kw.value == '('
@pytest.mark.parametrize('sql', [
'',
'$$ foo $$',
"$$ 'foo' $$",
'$$ "foo" $$',
'$$ $a$ $$',
'$a$ $$ $a$',
'foo bar $$ baz $$',
])
def test_is_open_quote__closed(sql):
assert not is_open_quote(sql)
@pytest.mark.parametrize('sql', [
'$$',
';;;$$',
'foo $$ bar $$; foo $$',
'$$ foo $a$',
"foo 'bar baz",
"$a$ foo ",
'$$ "foo" ',
'$$ $a$ ',
'foo bar $$ baz',
])
def test_is_open_quote__open(sql):
assert is_open_quote(sql)