diff --git a/Userland/Libraries/LibSQL/AST.h b/Userland/Libraries/LibSQL/AST.h index 179d4b326aa..677d94d8740 100644 --- a/Userland/Libraries/LibSQL/AST.h +++ b/Userland/Libraries/LibSQL/AST.h @@ -97,6 +97,23 @@ private: Vector m_column_names; }; +class CommonTableExpressionList : public ASTNode { +public: + CommonTableExpressionList(bool recursive, NonnullRefPtrVector common_table_expressions) + : m_recursive(recursive) + , m_common_table_expressions(move(common_table_expressions)) + { + VERIFY(!m_common_table_expressions.is_empty()); + } + + bool recursive() const { return m_recursive; } + const NonnullRefPtrVector& common_table_expressions() const { return m_common_table_expressions; } + +private: + bool m_recursive; + NonnullRefPtrVector m_common_table_expressions; +}; + class QualifiedTableName : public ASTNode { public: QualifiedTableName(String schema_name, String table_name, String alias) @@ -533,24 +550,21 @@ private: class Delete : public Statement { public: - Delete(bool recursive, RefPtr common_table_expression, NonnullRefPtr qualified_table_name, RefPtr where_clause, RefPtr returning_clause) - : m_recursive(recursive) - , m_common_table_expression(move(common_table_expression)) + Delete(RefPtr common_table_expression_list, NonnullRefPtr qualified_table_name, RefPtr where_clause, RefPtr returning_clause) + : m_common_table_expression_list(move(common_table_expression_list)) , m_qualified_table_name(move(qualified_table_name)) , m_where_clause(move(where_clause)) , m_returning_clause(move(returning_clause)) { } - bool recursive() const { return m_recursive; } - const RefPtr& common_table_expression() const { return m_common_table_expression; } + const RefPtr& common_table_expression_list() const { return m_common_table_expression_list; } const NonnullRefPtr& qualified_table_name() const { return m_qualified_table_name; } const RefPtr& where_clause() const { return m_where_clause; } const RefPtr& returning_clause() const { return m_returning_clause; } private: - bool m_recursive; - RefPtr m_common_table_expression; + RefPtr m_common_table_expression_list; NonnullRefPtr m_qualified_table_name; RefPtr m_where_clause; RefPtr m_returning_clause; diff --git a/Userland/Libraries/LibSQL/Forward.h b/Userland/Libraries/LibSQL/Forward.h index 51a0303fcd0..d0949a5ddc6 100644 --- a/Userland/Libraries/LibSQL/Forward.h +++ b/Userland/Libraries/LibSQL/Forward.h @@ -18,6 +18,7 @@ class CollateExpression; class ColumnDefinition; class ColumnNameExpression; class CommonTableExpression; +class CommonTableExpressionList; class CreateTable; class Delete; class DropTable; diff --git a/Userland/Libraries/LibSQL/Parser.cpp b/Userland/Libraries/LibSQL/Parser.cpp index 61a226dfead..c215fcedc9f 100644 --- a/Userland/Libraries/LibSQL/Parser.cpp +++ b/Userland/Libraries/LibSQL/Parser.cpp @@ -112,11 +112,20 @@ NonnullRefPtr Parser::parse_delete_statement() { // https://sqlite.org/lang_delete.html - bool recursive = false; - RefPtr common_table_expression; + RefPtr common_table_expression_list; if (consume_if(TokenType::With)) { - recursive = consume_if(TokenType::Recursive); - common_table_expression = parse_common_table_expression(); + NonnullRefPtrVector common_table_expression; + bool recursive = consume_if(TokenType::Recursive); + + do { + common_table_expression.append(parse_common_table_expression()); + if (!match(TokenType::Comma)) + break; + + consume(TokenType::Comma); + } while (!match(TokenType::Eof)); + + common_table_expression_list = create_ast_node(recursive, move(common_table_expression)); } consume(TokenType::Delete); @@ -133,7 +142,7 @@ NonnullRefPtr Parser::parse_delete_statement() consume(TokenType::SemiColon); - return create_ast_node(recursive, move(common_table_expression), move(qualified_table_name), move(where_clause), move(returning_clause)); + return create_ast_node(move(common_table_expression_list), move(qualified_table_name), move(where_clause), move(returning_clause)); } NonnullRefPtr Parser::parse_expression() diff --git a/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp b/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp index 3da71ff54df..6b45a252c39 100644 --- a/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp +++ b/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp @@ -153,13 +153,17 @@ TEST_CASE(delete_) EXPECT(parse("WITH RECURSIVE table DELETE FROM table;").is_error()); EXPECT(parse("WITH RECURSIVE table AS DELETE FROM table;").is_error()); - struct SelectedTable { + struct SelectedTableList { + struct SelectedTable { + StringView table_name {}; + Vector column_names {}; + }; + bool recursive { false }; - StringView table_name {}; - Vector column_names {}; + Vector selected_tables {}; }; - auto validate = [](StringView sql, SelectedTable expected_selected_table, StringView expected_schema, StringView expected_table, StringView expected_alias, bool expect_where_clause, bool expect_returning_clause, Vector expected_returned_column_aliases) { + auto validate = [](StringView sql, SelectedTableList expected_selected_tables, StringView expected_schema, StringView expected_table, StringView expected_alias, bool expect_where_clause, bool expect_returning_clause, Vector expected_returned_column_aliases) { auto result = parse(sql); EXPECT(!result.is_error()); @@ -167,15 +171,24 @@ TEST_CASE(delete_) EXPECT(is(*statement)); const auto& delete_ = static_cast(*statement); - EXPECT_EQ(delete_.recursive(), expected_selected_table.recursive); - const auto& common_table_expression = delete_.common_table_expression(); - EXPECT_EQ(common_table_expression.is_null(), expected_selected_table.table_name.is_empty()); - if (common_table_expression) { - EXPECT_EQ(common_table_expression->table_name(), expected_selected_table.table_name); - EXPECT_EQ(common_table_expression->column_names().size(), expected_selected_table.column_names.size()); - for (size_t i = 0; i < common_table_expression->column_names().size(); ++i) - EXPECT_EQ(common_table_expression->column_names()[i], expected_selected_table.column_names[i]); + const auto& common_table_expression_list = delete_.common_table_expression_list(); + EXPECT_EQ(common_table_expression_list.is_null(), expected_selected_tables.selected_tables.is_empty()); + if (common_table_expression_list) { + EXPECT_EQ(common_table_expression_list->recursive(), expected_selected_tables.recursive); + + const auto& common_table_expressions = common_table_expression_list->common_table_expressions(); + EXPECT_EQ(common_table_expressions.size(), expected_selected_tables.selected_tables.size()); + + for (size_t i = 0; i < common_table_expressions.size(); ++i) { + const auto& common_table_expression = common_table_expressions[i]; + const auto& expected_common_table_expression = expected_selected_tables.selected_tables[i]; + EXPECT_EQ(common_table_expression.table_name(), expected_common_table_expression.table_name); + EXPECT_EQ(common_table_expression.column_names().size(), expected_common_table_expression.column_names.size()); + + for (size_t j = 0; j < common_table_expression.column_names().size(); ++j) + EXPECT_EQ(common_table_expression.column_names()[j], expected_common_table_expression.column_names[j]); + } } const auto& qualified_table_name = delete_.qualified_table_name(); @@ -213,10 +226,10 @@ TEST_CASE(delete_) validate("DELETE FROM table RETURNING column1 AS alias1, column2 AS alias2;", {}, {}, "table", {}, false, true, { "alias1", "alias2" }); // FIXME: When parsing of SELECT statements are supported, the common-table-expressions below will become invalid due to the empty "AS ()" clause. - validate("WITH table AS () DELETE FROM table;", { false, "table", {} }, {}, "table", {}, false, false, {}); - validate("WITH table (column) AS () DELETE FROM table;", { false, "table", { "column" } }, {}, "table", {}, false, false, {}); - validate("WITH table (column1, column2) AS () DELETE FROM table;", { false, "table", { "column1", "column2" } }, {}, "table", {}, false, false, {}); - validate("WITH RECURSIVE table AS () DELETE FROM table;", { true, "table", {} }, {}, "table", {}, false, false, {}); + validate("WITH table AS () DELETE FROM table;", { false, { { "table" } } }, {}, "table", {}, false, false, {}); + validate("WITH table (column) AS () DELETE FROM table;", { false, { { "table", { "column" } } } }, {}, "table", {}, false, false, {}); + validate("WITH table (column1, column2) AS () DELETE FROM table;", { false, { { "table", { "column1", "column2" } } } }, {}, "table", {}, false, false, {}); + validate("WITH RECURSIVE table AS () DELETE FROM table;", { true, { { "table", {} } } }, {}, "table", {}, false, false, {}); } TEST_MAIN(SqlStatementParser)