def tableFromCreateStatement(schema, stmt):
    """
    Build a L{Table} in C{schema} from a C{CREATE TABLE} sqlparse statement.

    @param schema: The schema to add the table statement to.
    @type schema: L{Schema}
    @param stmt: The C{CREATE TABLE} statement object.
    @type stmt: L{Statement}
    @return: the newly created table.
    """
    tokens = iterSignificant(stmt)
    expect(tokens, ttype=Keyword.DDL, value="CREATE")
    expect(tokens, ttype=Keyword, value="TABLE")
    function = expect(tokens, cls=Function)
    funcTokens = iterSignificant(function)
    tableName = expect(funcTokens, cls=Identifier).get_name().encode("utf-8")
    table = Table(schema, tableName)
    parens = expect(funcTokens, cls=Parenthesis)
    _ColumnParser(table, iterSignificant(parens), parens).parse()
    return table
# Example source code for Python parse() implementations
def parse_sql_tables(sql):
    """Return the table names referenced after FROM in the first statement.

    Scans the top-level tokens of the first parsed statement; once FROM has
    been seen, identifiers (single or comma-separated) are collected via
    SQLParser.get_table_name.
    """
    names = []
    statement = sqlparse.parse(sql)[0]
    past_from = False
    for tok in statement.tokens:
        if past_from:
            if tok.ttype is Keyword:
                continue
            if isinstance(tok, IdentifierList):
                names.extend(SQLParser.get_table_name(ident)
                             for ident in tok.get_identifiers())
            elif isinstance(tok, Identifier):
                names.append(SQLParser.get_table_name(tok))
        if tok.ttype is Keyword and tok.value.upper() == "FROM":
            past_from = True
    return names
def test_placeholder(self):
    """Each placeholder style (?, :1, :name, %s, $a) is tokenized as
    Name.Placeholder with the marker preserved as the token value."""
    def last_tokens(statement):
        return sqlparse.parse(statement)[0].tokens[-1].tokens
    for marker in ('?', ':1', ':name', '%s', '$a'):
        toks = last_tokens('select * from foo where user = ' + marker)
        self.assert_(toks[-1].ttype is sqlparse.tokens.Name.Placeholder)
        self.assertEqual(toks[-1].value, marker)
def test_issue26(self):
    """Stand-alone dash comments parse to exactly one Comment.Single token."""
    for source in ('--hello', '-- hello', '--hello\n', '--', '--\n'):
        parsed = sqlparse.parse(source)[0]
        self.assertEqual(len(parsed.tokens), 1)
        self.assert_(parsed.tokens[0].ttype is T.Comment.Single)
def extract_tables():
    """Return the table identifiers found in the FROM part of ``sql``.

    NOTE(review): ``sql`` is not defined in this function's scope — it must
    resolve from an enclosing/global scope. Confirm, or pass it as a
    parameter.
    """
    stream = extract_from_part(sqlparse.parse(sql)[0])
    return list(extract_table_identifiers(stream))
def run(self, connection):
    """Execute every statement from ``self.content`` on ``connection``.

    Verifies that sqlparse's split round-trips to the original input and
    raises SQLRunnerException otherwise. Returns the total affected rows.
    """
    statements = sqlparse.parse(self.content)
    rejoined = "".join(six.text_type(stmt) for stmt in statements)
    if rejoined != self.content:
        raise SQLRunnerException("sqlparse failed to properly split input")
    total = 0
    with connection.cursor() as cursor:
        for stmt in statements:
            text = str(stmt)
            # sqlparse sometimes yields empty/';'-only fragments here,
            # which could negatively affect libpq — skip them.
            if clean_sql_code(text).strip() in ("", ";"):
                continue
            logger.debug("Running one statement... <<%s>>", text)
            cursor.execute(text.replace("\\timing\n", ""))
            logger.debug("Affected %s rows", cursor.rowcount)
            total += cursor.rowcount
    return total
def can_apply_to_multiple_schemata(sql, cursor):
    """Return True when *sql* may safely be applied once per schema."""
    results = [
        get_table_and_schema(stmt, cursor)
        for stmt in sqlparse.parse(sql)
    ]
    # An explicit False from any statement vetoes multi-schema application.
    if any(result is False for result in results):
        return False
    references = [ref for ref in results if ref is not None]
    # If all of our references to tables are shared, or schema-qualified,
    # we can't apply multiple times.
    if all(
        is_shared_table(table_name) or schema_name
        for (table_name, schema_name) in references
    ):
        return False
    return True
def extract_tables(sql):
    """Extract the table names referenced by an SQL statement.

    Returns a tuple of TableReference namedtuples (empty when *sql* does
    not parse to any statement).
    """
    statements = sqlparse.parse(sql)
    if not statements:
        return ()
    first = statements[0]
    # INSERT statements must stop looking for tables at the first
    # punctuation: in "INSERT INTO abc (col1, col2) VALUES (1, 2)" only
    # abc is a table name, so don't descend past the first lparen.
    is_insert = first.token_first().value.lower() == 'insert'
    stream = extract_from_part(first, stop_at_punctuation=is_insert)
    # Kludge: sqlparse mistakenly reads "insert into foo (bar, baz)" as a
    # function call to foo with arguments (bar, baz), so identifiers in
    # insert statements must not be treated as functions.
    identifiers = extract_table_identifiers(stream,
                                            allow_functions=not is_insert)
    # 'sche.<cursor>' yields an empty TableReference; drop nameless entries.
    return tuple(ref for ref in identifiers if ref.name)
def _gather_sql_placeholders(body):
    """Return the sorted, de-duplicated placeholder names used in *body*.

    The leading marker character (e.g. ':') is stripped from each name.
    """
    statements = sqlparse.parse(body)
    assert len(statements) == 1
    names = {
        tok.value[1:]
        for tok in statements[0].flatten()
        if tok.ttype == Token.Name.Placeholder
    }
    return sorted(names)
def parse(self):
    """Parse ``self.identifier`` and return the extracted identifiers.

    Raises ParseError when parsing did not consume the whole input.
    """
    if self.identifier_qualified:
        identifiers = self._handle_qualified_identifier()
    else:
        identifiers = self._handle_identifier()
    if self.index != len(self.identifier):
        # Bug fix: the "Identifiers: {}" slot previously formatted
        # self.identifier a second time instead of the parsed result.
        log.error(
            "ParseError: {} failed to parse. Qualified: {}. "
            "Identifiers: {}".format(
                self.identifier,
                self.identifier_qualified,
                identifiers,
            )
        )
        raise ParseError()
    return identifiers
def __init__(self, statement):
    """Extract database and table names from a CREATE TABLE statement.

    Raises IncompatibleStatementError when the token stream does not match
    the expected CREATE TABLE shape.
    """
    super(CreateTableStatement, self).__init__(statement)
    if (
        # Optionally consume an "IF NOT EXISTS" clause; a token must still
        # remain afterwards to serve as the table name.
        self.token_matcher.matches(
            Optional([Compound(['if', 'not', 'exists'])]),
        ) and
        self.token_matcher.has_next()
    ):
        self.database_name = None
        # A qualified name looks like <db> . <table>.
        if self.token_matcher.has_matches(Compound([Any(), '.', Any()])):
            db = self.token_matcher.pop().value
            self.token_matcher.pop()  # discard the '.' separator
            self.database_name = MysqlQualifiedIdentifierParser(
                db,
                identifier_qualified=False
            ).parse()
        self.table = MysqlQualifiedIdentifierParser(
            self.token_matcher.pop().value,
            identifier_qualified=False
        ).parse()
    else:
        raise IncompatibleStatementError()
def get_parameters(cls):
    """Return the set of placeholder names used in ``cls.get_query()``.

    The leading ':' marker is stripped from each placeholder token.
    """
    tokens = sqlparse.parse(cls.get_query())[0].flatten()
    return {
        tok.value[1:]
        for tok in tokens
        if tok.ttype == sqlparse.tokens.Token.Name.Placeholder
    }
def __init__(self, sql: str, initial_offset: int) -> None:
    """Tokenize *sql*, recording each token with its group-nesting depth.

    :param sql: SQL text to tokenize.
    :param initial_offset: offset of *sql* within the enclosing document.
    """
    self._initial_offset = initial_offset
    # NOTE(review): annotation corrected — this is a list of (token, depth)
    # pairs, not a single tuple.
    self._tokens = []  # type: List[Tuple[sqlparse.sql.Token, int]]
    depth = 0
    for statement in sqlparse.parse(sql):
        for token in statement.tokens:
            if token.is_group:
                # Grouped tokens are expanded (with depth) via the helper.
                self._tokens.extend(_flatten_group(token, depth))
            else:
                self._tokens.append((token, depth))
def parse_sql_columns(sql):
    """Return the column names selected by *sql*.

    Scanning stops at the first keyword token (e.g. FROM).
    """
    names = []
    statement = sqlparse.parse(sql)[0]
    for tok in statement.tokens:
        if isinstance(tok, IdentifierList):
            names.extend(ident.get_real_name()
                         for ident in tok.get_identifiers())
        if isinstance(tok, Identifier):
            names.append(tok.get_real_name())
        if tok.ttype is Keyword:  # e.g. FROM ends the column list
            break
    return names
def to_hierarchical_list(self):
    """Parse ``self._sql`` and convert its first statement to a hierarchy."""
    first_statement = sqlparse.parse(self._sql)[0]
    return self._handle_level(first_statement)
def test_split_semicolon(self):
    """A semicolon inside a string literal must not split the statement."""
    second = 'select * from foo where bar = \'foo;bar\';'
    parsed = sqlparse.parse(''.join([self._sql1, second]))
    self.assertEqual(len(parsed), 2)
    self.ndiffAssertEqual(unicode(parsed[0]), self._sql1)
    self.ndiffAssertEqual(unicode(parsed[1]), second)
def test_split_backslash(self):
    """Backslash-escaped quotes do not confuse statement splitting."""
    parsed = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';")
    self.assertEqual(len(parsed), 3)
def test_create_function(self):
    """function.sql parses as exactly one statement and round-trips."""
    source = load_file('function.sql')
    parsed = sqlparse.parse(source)
    self.assertEqual(len(parsed), 1)
    self.ndiffAssertEqual(unicode(parsed[0]), source)
def test_create_function_psql(self):
    """function_psql.sql parses as exactly one statement and round-trips."""
    source = load_file('function_psql.sql')
    parsed = sqlparse.parse(source)
    self.assertEqual(len(parsed), 1)
    self.ndiffAssertEqual(unicode(parsed[0]), source)
def test_create_function_psql2(self):
    """function_psql2.sql parses as exactly one statement and round-trips."""
    source = load_file('function_psql2.sql')
    parsed = sqlparse.parse(source)
    self.assertEqual(len(parsed), 1)
    self.ndiffAssertEqual(unicode(parsed[0]), source)
def test_dashcomments(self):
    """dashcomment.sql splits into three statements that round-trip."""
    source = load_file('dashcomment.sql')
    parsed = sqlparse.parse(source)
    self.assertEqual(len(parsed), 3)
    self.ndiffAssertEqual(''.join(unicode(stmt) for stmt in parsed), source)
def test_dashcomments_eol(self):
    """A trailing dash comment yields one statement for every line ending."""
    for ending in ('\n', '\r', '\r\n', ''):
        parsed = sqlparse.parse('select foo; -- comment' + ending)
        self.assertEqual(len(parsed), 1)
def test_begintag(self):
    """begintag.sql splits into three statements that round-trip."""
    source = load_file('begintag.sql')
    parsed = sqlparse.parse(source)
    self.assertEqual(len(parsed), 3)
    self.ndiffAssertEqual(''.join(unicode(stmt) for stmt in parsed), source)
def test_begintag_2(self):
    """begintag_2.sql parses as a single statement that round-trips."""
    source = load_file('begintag_2.sql')
    parsed = sqlparse.parse(source)
    self.assertEqual(len(parsed), 1)
    self.ndiffAssertEqual(''.join(unicode(stmt) for stmt in parsed), source)
def test_comment_with_umlaut(self):
    """Non-ASCII characters in comments survive parsing and round-trip."""
    source = (u'select * from foo;\n'
              u'-- Testing an umlaut: ä\n'
              u'select * from bar;')
    parsed = sqlparse.parse(source)
    self.assertEqual(len(parsed), 2)
    self.ndiffAssertEqual(''.join(unicode(stmt) for stmt in parsed), source)
def test_comment_end_of_line(self):
    """An end-of-line comment belongs to the statement that precedes it."""
    source = ('select * from foo; -- foo\n'
              'select * from bar;')
    parsed = sqlparse.parse(source)
    self.assertEqual(len(parsed), 2)
    self.ndiffAssertEqual(''.join(unicode(stmt) for stmt in parsed), source)
    # make sure the comment belongs to first query
    self.ndiffAssertEqual(unicode(parsed[0]), 'select * from foo; -- foo\n')
def test_parenthesis(self):
    """Nested parentheses group into Parenthesis/Identifier token trees."""
    source = 'select (select (x3) x2) and (y2) bar'
    parsed = sqlparse.parse(source)[0]
    self.ndiffAssertEqual(source, str(parsed))
    self.assertEqual(len(parsed.tokens), 7)
    outer = parsed.tokens[2]
    self.assert_(isinstance(outer, sql.Parenthesis))
    self.assert_(isinstance(parsed.tokens[-1], sql.Identifier))
    self.assertEqual(len(outer.tokens), 5)
    inner_ident = outer.tokens[3]
    self.assert_(isinstance(inner_ident, sql.Identifier))
    self.assert_(isinstance(inner_ident.tokens[0], sql.Parenthesis))
    self.assertEqual(len(inner_ident.tokens), 3)
def test_comments(self):
    """A block comment plus trailing text forms exactly two tokens."""
    source = '/*\n * foo\n */ \n bar'
    parsed = sqlparse.parse(source)[0]
    self.ndiffAssertEqual(source, unicode(parsed))
    self.assertEqual(len(parsed.tokens), 2)
def test_assignment(self):
    """':=' assignments group into a single Assignment token, with or
    without a trailing semicolon."""
    for source in ('foo := 1;', 'foo := 1'):
        parsed = sqlparse.parse(source)[0]
        self.assertEqual(len(parsed.tokens), 1)
        self.assert_(isinstance(parsed.tokens[0], sql.Assignment))
def test_identifier_wildcard(self):
    """'a.*, b.id' parses as an IdentifierList whose first and last
    members are Identifiers."""
    parsed = sqlparse.parse('a.*, b.id')[0]
    id_list = parsed.tokens[0]
    self.assert_(isinstance(id_list, sql.IdentifierList))
    self.assert_(isinstance(id_list.tokens[0], sql.Identifier))
    self.assert_(isinstance(id_list.tokens[-1], sql.Identifier))