mirror of
https://github.com/dbcli/pgcli.git
synced 2024-10-06 02:07:53 +03:00
Add function to check for unclosed quotations
This commit is contained in:
parent
27c2eabc08
commit
9be09b8bf4
@ -2,7 +2,7 @@ from __future__ import print_function
|
||||
import re
|
||||
import sqlparse
|
||||
from sqlparse.sql import IdentifierList, Identifier, Function
|
||||
from sqlparse.tokens import Keyword, DML, Punctuation
|
||||
from sqlparse.tokens import Keyword, DML, Punctuation, Token
|
||||
|
||||
cleanup_regex = {
|
||||
# This matches only alphanumerics and underscores.
|
||||
@ -188,6 +188,43 @@ def find_prev_keyword(sql):
|
||||
|
||||
return None, ''
|
||||
|
||||
|
||||
# Postgresql dollar quote signs look like `$$` or `$tag$`
|
||||
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
|
||||
|
||||
|
||||
def is_open_quote(sql):
|
||||
"""Returns true if the query contains an unclosed quote"""
|
||||
|
||||
# parsed can contain one or more semi-colon separated commands
|
||||
parsed = sqlparse.parse(sql)
|
||||
return any(_parsed_is_open_quote(p) for p in parsed)
|
||||
|
||||
|
||||
def _parsed_is_open_quote(parsed):
|
||||
tokens = list(parsed.flatten())
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
tok = tokens[i]
|
||||
if tok.match(Token.Error, "'"):
|
||||
# An unmatched single quote
|
||||
return True
|
||||
elif (tok.ttype in Token.Name.Builtin
|
||||
and dollar_quote_regex.match(tok.value)):
|
||||
# Find the matching closing dollar quote sign
|
||||
for (i, tok2) in enumerate(tokens[i+1:], i+1):
|
||||
if tok2.match(Token.Name.Builtin, tok.value):
|
||||
break
|
||||
else:
|
||||
# No matching dollar sign quote
|
||||
return True
|
||||
|
||||
i += 1
|
||||
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sql = 'select * from (select t. from tabl t'
|
||||
print (extract_tables(sql))
|
||||
|
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from pgcli.packages.parseutils import extract_tables
|
||||
from pgcli.packages.parseutils import find_prev_keyword
|
||||
from pgcli.packages.parseutils import find_prev_keyword, is_open_quote
|
||||
|
||||
def test_empty_string():
|
||||
tables = extract_tables('')
|
||||
@ -122,3 +122,31 @@ def test_find_prev_keyword_where(sql):
|
||||
def test_find_prev_keyword_open_parens(sql):
|
||||
kw, _ = find_prev_keyword(sql)
|
||||
assert kw.value == '('
|
||||
|
||||
|
||||
@pytest.mark.parametrize('sql', [
|
||||
'',
|
||||
'$$ foo $$',
|
||||
"$$ 'foo' $$",
|
||||
'$$ "foo" $$',
|
||||
'$$ $a$ $$',
|
||||
'$a$ $$ $a$',
|
||||
'foo bar $$ baz $$',
|
||||
])
|
||||
def test_is_open_quote__closed(sql):
|
||||
assert not is_open_quote(sql)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('sql', [
|
||||
'$$',
|
||||
';;;$$',
|
||||
'foo $$ bar $$; foo $$',
|
||||
'$$ foo $a$',
|
||||
"foo 'bar baz",
|
||||
"$a$ foo ",
|
||||
'$$ "foo" ',
|
||||
'$$ $a$ ',
|
||||
'foo bar $$ baz',
|
||||
])
|
||||
def test_is_open_quote__open(sql):
|
||||
assert is_open_quote(sql)
|
||||
|
Loading…
Reference in New Issue
Block a user