chore: move sql code of User into driver (#2281)

Move SQL code of User into Driver
This commit is contained in:
Athurg Gooth 2023-09-26 18:23:45 +08:00 committed by GitHub
parent 41eba71f0f
commit fcba3ffa26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 184 additions and 137 deletions

View File

@ -7,4 +7,9 @@ type Driver interface {
UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error)
ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error)
CreateUser(ctx context.Context, create *User) (*User, error)
UpdateUser(ctx context.Context, update *UpdateUser) (*User, error)
ListUsers(ctx context.Context, find *FindUser) ([]*User, error)
DeleteUser(ctx context.Context, delete *DeleteUser) error
} }

172
store/sqlite3/user.go Normal file
View File

@ -0,0 +1,172 @@
package sqlite3
import (
"context"
"strings"
"github.com/usememos/memos/store"
)
func (d *Driver) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
stmt := `
INSERT INTO user (
username,
role,
email,
nickname,
password_hash
)
VALUES (?, ?, ?, ?, ?)
RETURNING id, avatar_url, created_ts, updated_ts, row_status
`
if err := d.db.QueryRowContext(
ctx,
stmt,
create.Username,
create.Role,
create.Email,
create.Nickname,
create.PasswordHash,
).Scan(
&create.ID,
&create.AvatarURL,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err
}
return create, nil
}
func (d *Driver) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
if v := update.Username; v != nil {
set, args = append(set, "username = ?"), append(args, *v)
}
if v := update.Email; v != nil {
set, args = append(set, "email = ?"), append(args, *v)
}
if v := update.Nickname; v != nil {
set, args = append(set, "nickname = ?"), append(args, *v)
}
if v := update.AvatarURL; v != nil {
set, args = append(set, "avatar_url = ?"), append(args, *v)
}
if v := update.PasswordHash; v != nil {
set, args = append(set, "password_hash = ?"), append(args, *v)
}
args = append(args, update.ID)
query := `
UPDATE user
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, username, role, email, nickname, password_hash, avatar_url, created_ts, updated_ts, row_status
`
user := &store.User{}
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
&user.ID,
&user.Username,
&user.Role,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.AvatarURL,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
); err != nil {
return nil, err
}
return user, nil
}
func (d *Driver) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.Username; v != nil {
where, args = append(where, "username = ?"), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = ?"), append(args, *v)
}
if v := find.Email; v != nil {
where, args = append(where, "email = ?"), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = ?"), append(args, *v)
}
query := `
SELECT
id,
username,
role,
email,
nickname,
password_hash,
avatar_url,
created_ts,
updated_ts,
row_status
FROM user
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC, row_status DESC
`
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.User, 0)
for rows.Next() {
var user store.User
if err := rows.Scan(
&user.ID,
&user.Username,
&user.Role,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.AvatarURL,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
); err != nil {
return nil, err
}
list = append(list, &user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *Driver) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
result, err := d.db.ExecContext(ctx, `
DELETE FROM user WHERE id = ?
`, delete.ID)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return nil
}

View File

@ -2,7 +2,6 @@ package store
import ( import (
"context" "context"
"strings"
) )
// Role is the type of a role. // Role is the type of a role.
@ -74,84 +73,18 @@ type DeleteUser struct {
} }
func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) {
stmt := ` user, err := s.driver.CreateUser(ctx, create)
INSERT INTO user ( if err != nil {
username,
role,
email,
nickname,
password_hash
)
VALUES (?, ?, ?, ?, ?)
RETURNING id, avatar_url, created_ts, updated_ts, row_status
`
if err := s.db.QueryRowContext(
ctx,
stmt,
create.Username,
create.Role,
create.Email,
create.Nickname,
create.PasswordHash,
).Scan(
&create.ID,
&create.AvatarURL,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err return nil, err
} }
user := create
s.userCache.Store(user.ID, user) s.userCache.Store(user.ID, user)
return user, nil return user, nil
} }
func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) { func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) {
set, args := []string{}, []any{} user, err := s.driver.UpdateUser(ctx, update)
if v := update.UpdatedTs; v != nil { if err != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
if v := update.Username; v != nil {
set, args = append(set, "username = ?"), append(args, *v)
}
if v := update.Email; v != nil {
set, args = append(set, "email = ?"), append(args, *v)
}
if v := update.Nickname; v != nil {
set, args = append(set, "nickname = ?"), append(args, *v)
}
if v := update.AvatarURL; v != nil {
set, args = append(set, "avatar_url = ?"), append(args, *v)
}
if v := update.PasswordHash; v != nil {
set, args = append(set, "password_hash = ?"), append(args, *v)
}
args = append(args, update.ID)
query := `
UPDATE user
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, username, role, email, nickname, password_hash, avatar_url, created_ts, updated_ts, row_status
`
user := &User{}
if err := s.db.QueryRowContext(ctx, query, args...).Scan(
&user.ID,
&user.Username,
&user.Role,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.AvatarURL,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
); err != nil {
return nil, err return nil, err
} }
@ -160,69 +93,10 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro
} }
func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) { func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) {
where, args := []string{"1 = 1"}, []any{} list, err := s.driver.ListUsers(ctx, find)
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.Username; v != nil {
where, args = append(where, "username = ?"), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = ?"), append(args, *v)
}
if v := find.Email; v != nil {
where, args = append(where, "email = ?"), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = ?"), append(args, *v)
}
query := `
SELECT
id,
username,
role,
email,
nickname,
password_hash,
avatar_url,
created_ts,
updated_ts,
row_status
FROM user
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC, row_status DESC
`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
list := make([]*User, 0)
for rows.Next() {
var user User
if err := rows.Scan(
&user.ID,
&user.Username,
&user.Role,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.AvatarURL,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
); err != nil {
return nil, err
}
list = append(list, &user)
}
if err := rows.Err(); err != nil {
return nil, err
}
for _, user := range list { for _, user := range list {
s.userCache.Store(user.ID, user) s.userCache.Store(user.ID, user)
@ -251,15 +125,11 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
} }
func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error {
result, err := s.db.ExecContext(ctx, ` err := s.driver.DeleteUser(ctx, delete)
DELETE FROM user WHERE id = ?
`, delete.ID)
if err != nil { if err != nil {
return err return err
} }
if _, err := result.RowsAffected(); err != nil {
return err
}
if err := s.Vacuum(ctx); err != nil { if err := s.Vacuum(ctx); err != nil {
// Prevent linter warning. // Prevent linter warning.
return err return err