mirror of
https://github.com/usememos/memos.git
synced 2024-12-25 04:13:07 +03:00
chore: update user store names (#1877)
* chore: update user store names * chore: update
This commit is contained in:
parent
ca770c87d6
commit
9a8d43bf88
@ -40,7 +40,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &signin.Username,
|
||||
})
|
||||
if err != nil && common.ErrorCode(err) != common.NotFound {
|
||||
@ -111,14 +111,14 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
|
||||
}
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &userInfo.Identifier,
|
||||
})
|
||||
if err != nil && common.ErrorCode(err) != common.NotFound {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
|
||||
}
|
||||
if user == nil {
|
||||
userCreate := &store.UserMessage{
|
||||
userCreate := &store.User{
|
||||
Username: userInfo.Identifier,
|
||||
// The new signup user should be normal user by default.
|
||||
Role: store.NormalUser,
|
||||
@ -161,14 +161,14 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
|
||||
}
|
||||
|
||||
hostUserType := store.Host
|
||||
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUserMessage{
|
||||
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
|
||||
Role: &hostUserType,
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Failed to find users").SetInternal(err)
|
||||
}
|
||||
|
||||
userCreate := &store.UserMessage{
|
||||
userCreate := &store.User{
|
||||
Username: signup.Username,
|
||||
// The new signup user should be normal user by default.
|
||||
Role: store.NormalUser,
|
||||
@ -224,7 +224,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.UserMessage) error {
|
||||
func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User) error {
|
||||
ctx := c.Request().Context()
|
||||
payload := ActivityUserAuthSignInPayload{
|
||||
UserID: user.ID,
|
||||
@ -246,7 +246,7 @@ func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.UserMessage) error {
|
||||
func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User) error {
|
||||
ctx := c.Request().Context()
|
||||
payload := ActivityUserAuthSignUpPayload{
|
||||
Username: user.Username,
|
||||
|
@ -68,7 +68,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
@ -102,7 +102,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
@ -147,7 +147,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
userID, ok := c.Get(getUserIDContextKey()).(int)
|
||||
isHostUser := false
|
||||
if ok {
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
@ -177,7 +177,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
@ -211,7 +211,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -141,7 +141,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
|
||||
}
|
||||
|
||||
// Even if there is no error, we still need to make sure the user still exists.
|
||||
user, err := server.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
user, err := server.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
@ -222,7 +222,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool {
|
||||
// If there is openId in query string and related user is found, then skip auth.
|
||||
openID := c.QueryParam("openId")
|
||||
if openID != "" {
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
OpenID: &openID,
|
||||
})
|
||||
if err != nil && common.ErrorCode(err) != common.NotFound {
|
||||
|
1
go.mod
1
go.mod
@ -78,7 +78,6 @@ require (
|
||||
github.com/spf13/cast v1.5.0 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/stretchr/objx v0.5.0 // indirect
|
||||
github.com/subosito/gotenv v1.4.2 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasttemplate v1.2.1 // indirect
|
||||
|
1
go.sum
1
go.sum
@ -248,7 +248,6 @@ github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU=
|
||||
github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
|
@ -64,7 +64,7 @@ func GenerateRefreshToken(userName string, userID int, secret string) (string, e
|
||||
}
|
||||
|
||||
// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie.
|
||||
func GenerateTokensAndSetCookies(c echo.Context, user *store.UserMessage, secret string) error {
|
||||
func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error {
|
||||
accessToken, err := GenerateAccessToken(user.Username, user.ID, secret)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to generate access token")
|
||||
|
@ -141,7 +141,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
|
||||
}
|
||||
|
||||
// Even if there is no error, we still need to make sure the user still exists.
|
||||
user, err := server.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
user, err := server.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -9,30 +9,19 @@ import (
|
||||
|
||||
"github.com/usememos/memos/api"
|
||||
"github.com/usememos/memos/common"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func Execute(
|
||||
ctx context.Context,
|
||||
store store,
|
||||
hostUsername, hostPassword string,
|
||||
) error {
|
||||
func Execute(ctx context.Context, store *store.Store, hostUsername, hostPassword string) error {
|
||||
s := setupService{store: store}
|
||||
return s.Setup(ctx, hostUsername, hostPassword)
|
||||
}
|
||||
|
||||
type store interface {
|
||||
FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error)
|
||||
CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error)
|
||||
}
|
||||
|
||||
type setupService struct {
|
||||
store store
|
||||
store *store.Store
|
||||
}
|
||||
|
||||
func (s setupService) Setup(
|
||||
ctx context.Context,
|
||||
hostUsername, hostPassword string,
|
||||
) error {
|
||||
func (s setupService) Setup(ctx context.Context, hostUsername, hostPassword string) error {
|
||||
if err := s.makeSureHostUserNotExists(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -59,31 +48,47 @@ func (s setupService) makeSureHostUserNotExists(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s setupService) createUser(
|
||||
ctx context.Context,
|
||||
hostUsername, hostPassword string,
|
||||
) error {
|
||||
userCreate := &api.UserCreate{
|
||||
func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword string) error {
|
||||
userCreate := &store.User{
|
||||
Username: hostUsername,
|
||||
// The new signup user should be normal user by default.
|
||||
Role: api.Host,
|
||||
Role: store.Host,
|
||||
Nickname: hostUsername,
|
||||
Password: hostPassword,
|
||||
OpenID: common.GenUUID(),
|
||||
}
|
||||
|
||||
if err := userCreate.Validate(); err != nil {
|
||||
return fmt.Errorf("validate: %w", err)
|
||||
if len(userCreate.Username) < 3 {
|
||||
return fmt.Errorf("username is too short, minimum length is 3")
|
||||
}
|
||||
if len(userCreate.Username) > 32 {
|
||||
return fmt.Errorf("username is too long, maximum length is 32")
|
||||
}
|
||||
if len(hostPassword) < 3 {
|
||||
return fmt.Errorf("password is too short, minimum length is 3")
|
||||
}
|
||||
if len(hostPassword) > 512 {
|
||||
return fmt.Errorf("password is too long, maximum length is 512")
|
||||
}
|
||||
if len(userCreate.Nickname) > 64 {
|
||||
return fmt.Errorf("nickname is too long, maximum length is 64")
|
||||
}
|
||||
if userCreate.Email != "" {
|
||||
if len(userCreate.Email) > 256 {
|
||||
return fmt.Errorf("email is too long, maximum length is 256")
|
||||
}
|
||||
if !common.ValidateEmail(userCreate.Email) {
|
||||
return fmt.Errorf("invalid email format")
|
||||
}
|
||||
}
|
||||
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(hostPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
userCreate.PasswordHash = string(passwordHash)
|
||||
if _, err := s.store.CreateUser(ctx, userCreate); err != nil {
|
||||
return fmt.Errorf("create user: %w", err)
|
||||
if _, err := s.store.CreateUserV1(ctx, userCreate); err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -1,181 +0,0 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/usememos/memos/api"
|
||||
)
|
||||
|
||||
func TestSetupServiceMakeSureHostUserNotExists(t *testing.T) {
|
||||
cc := map[string]struct {
|
||||
setupStore func(*storeMock)
|
||||
expectedErr string
|
||||
}{
|
||||
"failed to get list": {
|
||||
setupStore: func(m *storeMock) {
|
||||
hostUserType := api.Host
|
||||
m.
|
||||
On("FindUserList", mock.Anything, &api.UserFind{
|
||||
Role: &hostUserType,
|
||||
}).
|
||||
Return(nil, errors.New("fake error"))
|
||||
},
|
||||
expectedErr: "find user list: fake error",
|
||||
},
|
||||
"success, not empty": {
|
||||
setupStore: func(m *storeMock) {
|
||||
hostUserType := api.Host
|
||||
m.
|
||||
On("FindUserList", mock.Anything, &api.UserFind{
|
||||
Role: &hostUserType,
|
||||
}).
|
||||
Return([]*api.User{
|
||||
{},
|
||||
}, nil)
|
||||
},
|
||||
expectedErr: "host user already exists",
|
||||
},
|
||||
"success, empty": {
|
||||
setupStore: func(m *storeMock) {
|
||||
hostUserType := api.Host
|
||||
m.
|
||||
On("FindUserList", mock.Anything, &api.UserFind{
|
||||
Role: &hostUserType,
|
||||
}).
|
||||
Return(nil, nil)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for n, c := range cc {
|
||||
c := c
|
||||
t.Run(n, func(t *testing.T) {
|
||||
sm := newStoreMock(t)
|
||||
if c.setupStore != nil {
|
||||
c.setupStore(sm)
|
||||
}
|
||||
|
||||
srv := setupService{store: sm}
|
||||
err := srv.makeSureHostUserNotExists(context.Background())
|
||||
if c.expectedErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.EqualError(t, err, c.expectedErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupServiceCreateUser(t *testing.T) {
|
||||
expectedCreated := &api.UserCreate{
|
||||
Username: "demohero",
|
||||
Role: api.Host,
|
||||
Nickname: "demohero",
|
||||
Password: "123456",
|
||||
}
|
||||
|
||||
userCreateMatcher := mock.MatchedBy(func(arg *api.UserCreate) bool {
|
||||
return arg.Username == expectedCreated.Username &&
|
||||
arg.Role == expectedCreated.Role &&
|
||||
arg.Nickname == expectedCreated.Nickname &&
|
||||
arg.Password == expectedCreated.Password &&
|
||||
arg.PasswordHash != ""
|
||||
})
|
||||
|
||||
cc := map[string]struct {
|
||||
setupStore func(*storeMock)
|
||||
hostUsername, hostPassword string
|
||||
expectedErr string
|
||||
}{
|
||||
`username == "", password == ""`: {
|
||||
expectedErr: "validate: username is too short, minimum length is 3",
|
||||
},
|
||||
`username == "", password != ""`: {
|
||||
hostPassword: expectedCreated.Password,
|
||||
expectedErr: "validate: username is too short, minimum length is 3",
|
||||
},
|
||||
`username != "", password == ""`: {
|
||||
hostUsername: expectedCreated.Username,
|
||||
expectedErr: "validate: password is too short, minimum length is 3",
|
||||
},
|
||||
"failed to create": {
|
||||
setupStore: func(m *storeMock) {
|
||||
m.
|
||||
On("CreateUser", mock.Anything, userCreateMatcher).
|
||||
Return(nil, errors.New("fake error"))
|
||||
},
|
||||
hostUsername: expectedCreated.Username,
|
||||
hostPassword: expectedCreated.Password,
|
||||
expectedErr: "create user: fake error",
|
||||
},
|
||||
"success": {
|
||||
setupStore: func(m *storeMock) {
|
||||
m.
|
||||
On("CreateUser", mock.Anything, userCreateMatcher).
|
||||
Return(nil, nil)
|
||||
},
|
||||
hostUsername: expectedCreated.Username,
|
||||
hostPassword: expectedCreated.Password,
|
||||
},
|
||||
}
|
||||
|
||||
for n, c := range cc {
|
||||
c := c
|
||||
t.Run(n, func(t *testing.T) {
|
||||
sm := newStoreMock(t)
|
||||
if c.setupStore != nil {
|
||||
c.setupStore(sm)
|
||||
}
|
||||
|
||||
srv := setupService{store: sm}
|
||||
err := srv.createUser(context.Background(), c.hostUsername, c.hostPassword)
|
||||
if c.expectedErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.EqualError(t, err, c.expectedErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type storeMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *storeMock) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) {
|
||||
ret := m.Called(ctx, find)
|
||||
|
||||
var u []*api.User
|
||||
ret1 := ret.Get(0)
|
||||
if ret1 != nil {
|
||||
u = ret1.([]*api.User)
|
||||
}
|
||||
|
||||
return u, ret.Error(1)
|
||||
}
|
||||
|
||||
func (m *storeMock) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) {
|
||||
ret := m.Called(ctx, create)
|
||||
|
||||
var u *api.User
|
||||
ret1 := ret.Get(0)
|
||||
if ret1 != nil {
|
||||
u = ret1.(*api.User)
|
||||
}
|
||||
|
||||
return u, ret.Error(1)
|
||||
}
|
||||
|
||||
func newStoreMock(t *testing.T) *storeMock {
|
||||
m := &storeMock{}
|
||||
m.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { m.AssertExpectations(t) })
|
||||
|
||||
return m
|
||||
}
|
@ -14,6 +14,7 @@ type Store struct {
|
||||
db *sql.DB
|
||||
systemSettingCache sync.Map // map[string]*systemSettingRaw
|
||||
userCache sync.Map // map[int]*userRaw
|
||||
userV1Cache sync.Map // map[string]*User
|
||||
userSettingCache sync.Map // map[string]*UserSetting
|
||||
shortcutCache sync.Map // map[int]*shortcutRaw
|
||||
idpCache sync.Map // map[int]*IdentityProvider
|
||||
|
160
store/user.go
160
store/user.go
@ -34,7 +34,7 @@ func (e Role) String() string {
|
||||
return "USER"
|
||||
}
|
||||
|
||||
type UserMessage struct {
|
||||
type User struct {
|
||||
ID int
|
||||
|
||||
// Standard fields
|
||||
@ -52,7 +52,22 @@ type UserMessage struct {
|
||||
AvatarURL string
|
||||
}
|
||||
|
||||
type FindUserMessage struct {
|
||||
type UpdateUser struct {
|
||||
ID int
|
||||
|
||||
UpdatedTs *int64
|
||||
RowStatus *RowStatus
|
||||
Username *string `json:"username"`
|
||||
Email *string `json:"email"`
|
||||
Nickname *string `json:"nickname"`
|
||||
Password *string `json:"password"`
|
||||
ResetOpenID *bool `json:"resetOpenId"`
|
||||
AvatarURL *string `json:"avatarUrl"`
|
||||
PasswordHash *string
|
||||
OpenID *string
|
||||
}
|
||||
|
||||
type FindUser struct {
|
||||
ID *int
|
||||
|
||||
// Standard fields
|
||||
@ -66,10 +81,10 @@ type FindUserMessage struct {
|
||||
OpenID *string
|
||||
}
|
||||
|
||||
func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMessage, error) {
|
||||
func (s *Store) CreateUserV1(ctx context.Context, create *User) (*User, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
@ -99,19 +114,85 @@ func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMes
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, FormatError(err)
|
||||
return nil, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, FormatError(err)
|
||||
return nil, err
|
||||
}
|
||||
userMessage := create
|
||||
return userMessage, nil
|
||||
user := create
|
||||
s.userV1Cache.Store(user.ID, user)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMessage, error) {
|
||||
func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
set, args = append(set, "username = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
set, args = append(set, "avatar_url = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "password_hash = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.OpenID; v != nil {
|
||||
set, args = append(set, "open_id = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
query := `
|
||||
UPDATE user
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
RETURNING id, username, role, email, nickname, password_hash, open_id, avatar_url, created_ts, updated_ts, row_status
|
||||
`
|
||||
user := &User{}
|
||||
if err := tx.QueryRowContext(ctx, query, args...).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.OpenID,
|
||||
&user.AvatarURL,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.userV1Cache.Store(user.ID, user)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
@ -120,13 +201,22 @@ func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, user := range list {
|
||||
s.userV1Cache.Store(user.ID, user)
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessage, error) {
|
||||
func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
|
||||
if find.ID != nil {
|
||||
if user, ok := s.userV1Cache.Load(*find.ID); ok {
|
||||
return user.(*User), nil
|
||||
}
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
@ -135,14 +225,14 @@ func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessag
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user not found")}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
memoMessage := list[0]
|
||||
return memoMessage, nil
|
||||
user := list[0]
|
||||
s.userV1Cache.Store(user.ID, user)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserMessage, error) {
|
||||
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
@ -183,36 +273,36 @@ func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserM
|
||||
`
|
||||
rows, err := tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userMessageList := make([]*UserMessage, 0)
|
||||
list := make([]*User, 0)
|
||||
for rows.Next() {
|
||||
var userMessage UserMessage
|
||||
var user User
|
||||
if err := rows.Scan(
|
||||
&userMessage.ID,
|
||||
&userMessage.Username,
|
||||
&userMessage.Role,
|
||||
&userMessage.Email,
|
||||
&userMessage.Nickname,
|
||||
&userMessage.PasswordHash,
|
||||
&userMessage.OpenID,
|
||||
&userMessage.AvatarURL,
|
||||
&userMessage.CreatedTs,
|
||||
&userMessage.UpdatedTs,
|
||||
&userMessage.RowStatus,
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.OpenID,
|
||||
&user.AvatarURL,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, FormatError(err)
|
||||
return nil, err
|
||||
}
|
||||
userMessageList = append(userMessageList, &userMessage)
|
||||
list = append(list, &user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, FormatError(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userMessageList, nil
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// userRaw is the store model for an User.
|
||||
|
@ -12,45 +12,44 @@ import (
|
||||
|
||||
func TestUserStore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, store)
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
users, err := store.FindUserList(ctx, &api.UserFind{})
|
||||
users, err := ts.ListUsers(ctx, &store.FindUser{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(users))
|
||||
require.Equal(t, api.Host, users[0].Role)
|
||||
require.Equal(t, store.Host, users[0].Role)
|
||||
require.Equal(t, user, users[0])
|
||||
userPatchNickname := "test_nickname_2"
|
||||
userPatch := &api.UserPatch{
|
||||
userPatch := &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
Nickname: &userPatchNickname,
|
||||
}
|
||||
user, err = store.PatchUser(ctx, userPatch)
|
||||
user, err = ts.UpdateUser(ctx, userPatch)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userPatchNickname, user.Nickname)
|
||||
err = store.DeleteUser(ctx, &api.UserDelete{
|
||||
err = ts.DeleteUser(ctx, &api.UserDelete{
|
||||
ID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
users, err = store.FindUserList(ctx, &api.UserFind{})
|
||||
users, err = ts.ListUsers(ctx, &store.FindUser{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(users))
|
||||
}
|
||||
|
||||
func createTestingHostUser(ctx context.Context, store *store.Store) (*api.User, error) {
|
||||
userCreate := &api.UserCreate{
|
||||
func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, error) {
|
||||
userCreate := &store.User{
|
||||
Username: "test",
|
||||
Role: api.Host,
|
||||
Role: store.Host,
|
||||
Email: "test@test.com",
|
||||
Nickname: "test_nickname",
|
||||
Password: "test_password",
|
||||
OpenID: "test_open_id",
|
||||
}
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(userCreate.Password), bcrypt.DefaultCost)
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte("test_password"), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userCreate.PasswordHash = string(passwordHash)
|
||||
user, err := store.CreateUser(ctx, userCreate)
|
||||
user, err := ts.CreateUserV1(ctx, userCreate)
|
||||
return user, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user