package sqlmodel_test import ( "strings" "testing" "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/core/sqlmodel" "github.com/neilotoole/sq/testh/proj" ) func TestSplitInput(t *testing.T) { const sel, other = sqlmodel.StmtSelect, sqlmodel.StmtOther // convenience func to return a slice of n core.sqlmodel.StmtType. nTypes := func(n int, typ sqlmodel.StmtType) []sqlmodel.StmtType { types := make([]sqlmodel.StmtType, n) for i := range types { types[i] = typ } return types } // testCases key is a group of tests, e.g. "postgres" or "sqlserver" testCases := map[string][]struct { name string delim string moreDelims []string input string wantCount int wantContains []string wantTypes []sqlmodel.StmtType }{ "simple": { { name: "select_1", delim: ";", input: "select * from my_table; ", wantCount: 1, wantContains: []string{"select * from my_table"}, wantTypes: []sqlmodel.StmtType{sel}, }, { name: "select_1_no_delim", delim: ";", input: "select * from my_table", wantCount: 1, wantContains: []string{"select * from my_table"}, wantTypes: []sqlmodel.StmtType{sel}, }, { name: "select_1_many_delim", delim: ";", input: "select * from my_table;\n;;", wantCount: 1, wantContains: []string{"select * from my_table"}, wantTypes: []sqlmodel.StmtType{sel}, }, { name: "select_2", delim: ";", input: "select * from my_table;\ndrop table my_table;", wantCount: 2, wantContains: []string{"select * from my_table", "drop table my_table"}, wantTypes: []sqlmodel.StmtType{sel, other}, }, { name: "select_2_no_trailing_delim", delim: ";", input: "select * from my_table;\ndrop table my_table", wantCount: 2, wantContains: []string{"select * from my_table", "drop table my_table"}, wantTypes: []sqlmodel.StmtType{sel, other}, }, }, "sqlite3": { { name: "sqlite3_type_test.sql", delim: ";", input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/sqlite3_type_test.sql")), wantCount: 4, wantTypes: nTypes(4, other), }, { name: "sqlite3_address.sql", delim: ";", input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/sqlite3_address.sql")), wantCount: 4, wantTypes: nTypes(4, other), }, { name: "sqlite3_person.sql", delim: ";", input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/sqlite3_person.sql")), wantCount: 10, wantTypes: nTypes(10, other), }, }, "postgres": { { name: "sqtype_public_type_test.sql", delim: ";", input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/postgres_public_type_test.sql")), wantCount: 5, wantTypes: nTypes(5, other), }, { name: "sqtest_public_address.sql", delim: ";", input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/postgres_public_address.sql")), wantCount: 5, wantTypes: nTypes(5, other), }, { name: "sqtest_public_person.sql", delim: ";", input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/postgres_public_person.sql")), wantCount: 11, wantTypes: nTypes(11, other), }, }, "sqlserver": { { name: "select_1_semi_go", delim: ";", moreDelims: []string{"go"}, input: "select * from my_table;\ngo", wantCount: 1, wantContains: []string{"select * from my_table"}, wantTypes: []sqlmodel.StmtType{sel}, }, { name: "select_2", delim: ";", moreDelims: []string{"go"}, input: "select * from my_table;\ndrop table my_table;", wantCount: 2, wantContains: []string{"select * from my_table", "drop table my_table"}, wantTypes: []sqlmodel.StmtType{sel, other}, }, { name: "sqtype_dbo_type_test.sql", delim: ";", moreDelims: []string{"go"}, input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/sqlserver_dbo_type_test.sql")), wantCount: 5, wantTypes: nTypes(5, other), }, { name: "sqtest_dbo_address.sql", delim: ";", moreDelims: []string{"go"}, input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/sqlserver_dbo_address.sql")), wantCount: 4, wantTypes: nTypes(5, other), }, { name: "sqtest_dbo_person.sql", delim: ";", moreDelims: []string{"go"}, input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/sqlserver_dbo_person.sql")), wantCount: 10, wantTypes: nTypes(10, other), }, }, "mysql": { { name: "mysql_type_test.sql", delim: ";", input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/mysql_type_test.sql")), wantCount: 5, wantTypes: nTypes(5, other), }, { name: "mysql_address.sql", delim: ";", input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/mysql_address.sql")), wantCount: 4, wantTypes: nTypes(5, other), }, { name: "mysql_person.sql", delim: ";", input: string(proj.ReadFile("libsq/core/sqlmodel/testdata/mysql_person.sql")), wantCount: 9, wantTypes: nTypes(9, other), }, }, } for groupName, testGroup := range testCases { testGroup := testGroup t.Run(groupName, func(t *testing.T) { for _, tc := range testGroup { tc := tc t.Run(tc.name, func(t *testing.T) { stmts, stmtTypes, err := sqlmodel.SplitSQL(strings.NewReader(tc.input), tc.delim, tc.moreDelims...) require.NoError(t, err) require.Equal(t, tc.wantCount, len(stmts)) require.Equal(t, len(stmts), len(stmtTypes)) for i, stmtType := range stmtTypes { require.Equal(t, tc.wantTypes[i], stmtType) } for i, wantContains := range tc.wantContains { require.Contains(t, stmts[i], wantContains) } // Sanity check to verify that we've stripped trailing delims allDelims := append([]string{tc.delim}, tc.moreDelims...) for _, sep := range allDelims { for _, stmt := range stmts { require.False(t, strings.HasSuffix(stmt, sep)) } } }) } }) } }