diff --git a/store/driver.go b/store/driver.go index 30f082e6..6c58da19 100644 --- a/store/driver.go +++ b/store/driver.go @@ -32,4 +32,8 @@ type Driver interface { GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error + + UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) + ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) + DeleteTag(ctx context.Context, delete *DeleteTag) error } diff --git a/store/sqlite/tag.go b/store/sqlite/tag.go new file mode 100644 index 00000000..73400e91 --- /dev/null +++ b/store/sqlite/tag.go @@ -0,0 +1,75 @@ +package sqlite + +import ( + "context" + "strings" + + "github.com/usememos/memos/store" +) + +func (d *Driver) UpsertTag(ctx context.Context, upsert *store.Tag) (*store.Tag, error) { + stmt := ` + INSERT INTO tag ( + name, creator_id + ) + VALUES (?, ?) + ON CONFLICT(name, creator_id) DO UPDATE + SET + name = EXCLUDED.name + ` + if _, err := d.db.ExecContext(ctx, stmt, upsert.Name, upsert.CreatorID); err != nil { + return nil, err + } + + tag := upsert + return tag, nil +} + +func (d *Driver) ListTags(ctx context.Context, find *store.FindTag) ([]*store.Tag, error) { + where, args := []string{"creator_id = ?"}, []any{find.CreatorID} + query := ` + SELECT + name, + creator_id + FROM tag + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY name ASC + ` + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.Tag{} + for rows.Next() { + tag := &store.Tag{} + if err := rows.Scan( + &tag.Name, + &tag.CreatorID, + ); err != nil { + return nil, err + } + + list = append(list, tag) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *Driver) DeleteTag(ctx context.Context, delete *store.DeleteTag) error { + where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID} + stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ") + result, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err = result.RowsAffected(); err != nil { + return err + } + return nil +} diff --git a/store/tag.go b/store/tag.go index 713a076e..afea0739 100644 --- a/store/tag.go +++ b/store/tag.go @@ -3,7 +3,6 @@ package store import ( "context" "database/sql" - "strings" ) type Tag struct { @@ -21,70 +20,15 @@ type DeleteTag struct { } func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) { - stmt := ` - INSERT INTO tag ( - name, creator_id - ) - VALUES (?, ?) - ON CONFLICT(name, creator_id) DO UPDATE - SET - name = EXCLUDED.name - ` - if _, err := s.db.ExecContext(ctx, stmt, upsert.Name, upsert.CreatorID); err != nil { - return nil, err - } - - tag := upsert - return tag, nil + return s.driver.UpsertTag(ctx, upsert) } func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { - where, args := []string{"creator_id = ?"}, []any{find.CreatorID} - query := ` - SELECT - name, - creator_id - FROM tag - WHERE ` + strings.Join(where, " AND ") + ` - ORDER BY name ASC - ` - rows, err := s.db.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - list := []*Tag{} - for rows.Next() { - tag := &Tag{} - if err := rows.Scan( - &tag.Name, - &tag.CreatorID, - ); err != nil { - return nil, err - } - - list = append(list, tag) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return list, nil + return s.driver.ListTags(ctx, find) } func (s *Store) DeleteTag(ctx context.Context, delete *DeleteTag) error { - where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID} - stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ") - result, err := s.db.ExecContext(ctx, stmt, args...) - if err != nil { - return err - } - if _, err = result.RowsAffected(); err != nil { - return err - } - return nil + return s.driver.DeleteTag(ctx, delete) } func vacuumTag(ctx context.Context, tx *sql.Tx) error { diff --git a/test/store/tag_test.go b/test/store/tag_test.go new file mode 100644 index 00000000..0b0624cd --- /dev/null +++ b/test/store/tag_test.go @@ -0,0 +1,38 @@ +package teststore + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/store" +) + +func TestTagStore(t *testing.T) { + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + tag, err := ts.UpsertTag(ctx, &store.Tag{ + CreatorID: user.ID, + Name: "test_tag", + }) + require.NoError(t, err) + require.Equal(t, "test_tag", tag.Name) + require.Equal(t, user.ID, tag.CreatorID) + tags, err := ts.ListTags(ctx, &store.FindTag{ + CreatorID: user.ID, + }) + require.NoError(t, err) + require.Equal(t, 1, len(tags)) + require.Equal(t, tag, tags[0]) + err = ts.DeleteTag(ctx, &store.DeleteTag{ + Name: "test_tag", + CreatorID: user.ID, + }) + require.NoError(t, err) + tags, err = ts.ListTags(ctx, &store.FindTag{}) + require.NoError(t, err) + require.Equal(t, 0, len(tags)) +}