chore: update user store names (#1877)

* chore: update user store names

* chore: update
This commit is contained in:
boojack 2023-07-02 14:27:23 +08:00 committed by GitHub
parent ca770c87d6
commit 9a8d43bf88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 187 additions and 275 deletions

View File

@ -40,7 +40,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
} }
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{ user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &signin.Username, Username: &signin.Username,
}) })
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil && common.ErrorCode(err) != common.NotFound {
@ -111,14 +111,14 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
} }
} }
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{ user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier, Username: &userInfo.Identifier,
}) })
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again") return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
} }
if user == nil { if user == nil {
userCreate := &store.UserMessage{ userCreate := &store.User{
Username: userInfo.Identifier, Username: userInfo.Identifier,
// The new signup user should be normal user by default. // The new signup user should be normal user by default.
Role: store.NormalUser, Role: store.NormalUser,
@ -161,14 +161,14 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
} }
hostUserType := store.Host hostUserType := store.Host
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUserMessage{ existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
Role: &hostUserType, Role: &hostUserType,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Failed to find users").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Failed to find users").SetInternal(err)
} }
userCreate := &store.UserMessage{ userCreate := &store.User{
Username: signup.Username, Username: signup.Username,
// The new signup user should be normal user by default. // The new signup user should be normal user by default.
Role: store.NormalUser, Role: store.NormalUser,
@ -224,7 +224,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
}) })
} }
func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.UserMessage) error { func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User) error {
ctx := c.Request().Context() ctx := c.Request().Context()
payload := ActivityUserAuthSignInPayload{ payload := ActivityUserAuthSignInPayload{
UserID: user.ID, UserID: user.ID,
@ -246,7 +246,7 @@ func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User
return err return err
} }
func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.UserMessage) error { func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User) error {
ctx := c.Request().Context() ctx := c.Request().Context()
payload := ActivityUserAuthSignUpPayload{ payload := ActivityUserAuthSignUpPayload{
Username: user.Username, Username: user.Username,

View File

@ -68,7 +68,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
@ -102,7 +102,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
@ -147,7 +147,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
isHostUser := false isHostUser := false
if ok { if ok {
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
@ -177,7 +177,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
@ -211,7 +211,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {

View File

@ -141,7 +141,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
} }
// Even if there is no error, we still need to make sure the user still exists. // Even if there is no error, we still need to make sure the user still exists.
user, err := server.Store.GetUser(ctx, &store.FindUserMessage{ user, err := server.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
@ -222,7 +222,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool {
// If there is openId in query string and related user is found, then skip auth. // If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId") openID := c.QueryParam("openId")
if openID != "" { if openID != "" {
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{ user, err := s.Store.GetUser(ctx, &store.FindUser{
OpenID: &openID, OpenID: &openID,
}) })
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil && common.ErrorCode(err) != common.NotFound {

1
go.mod
View File

@ -78,7 +78,6 @@ require (
github.com/spf13/cast v1.5.0 // indirect github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/subosito/gotenv v1.4.2 // indirect github.com/subosito/gotenv v1.4.2 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.1 // indirect github.com/valyala/fasttemplate v1.2.1 // indirect

1
go.sum
View File

@ -248,7 +248,6 @@ github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU=
github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA= github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=

View File

@ -64,7 +64,7 @@ func GenerateRefreshToken(userName string, userID int, secret string) (string, e
} }
// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. // GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie.
func GenerateTokensAndSetCookies(c echo.Context, user *store.UserMessage, secret string) error { func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error {
accessToken, err := GenerateAccessToken(user.Username, user.ID, secret) accessToken, err := GenerateAccessToken(user.Username, user.ID, secret)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to generate access token") return errors.Wrap(err, "failed to generate access token")

View File

@ -141,7 +141,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
} }
// Even if there is no error, we still need to make sure the user still exists. // Even if there is no error, we still need to make sure the user still exists.
user, err := server.Store.GetUser(ctx, &store.FindUserMessage{ user, err := server.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {

View File

@ -9,30 +9,19 @@ import (
"github.com/usememos/memos/api" "github.com/usememos/memos/api"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/store"
) )
func Execute( func Execute(ctx context.Context, store *store.Store, hostUsername, hostPassword string) error {
ctx context.Context,
store store,
hostUsername, hostPassword string,
) error {
s := setupService{store: store} s := setupService{store: store}
return s.Setup(ctx, hostUsername, hostPassword) return s.Setup(ctx, hostUsername, hostPassword)
} }
type store interface {
FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error)
CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error)
}
type setupService struct { type setupService struct {
store store store *store.Store
} }
func (s setupService) Setup( func (s setupService) Setup(ctx context.Context, hostUsername, hostPassword string) error {
ctx context.Context,
hostUsername, hostPassword string,
) error {
if err := s.makeSureHostUserNotExists(ctx); err != nil { if err := s.makeSureHostUserNotExists(ctx); err != nil {
return err return err
} }
@ -59,31 +48,47 @@ func (s setupService) makeSureHostUserNotExists(ctx context.Context) error {
return nil return nil
} }
func (s setupService) createUser( func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword string) error {
ctx context.Context, userCreate := &store.User{
hostUsername, hostPassword string,
) error {
userCreate := &api.UserCreate{
Username: hostUsername, Username: hostUsername,
// The new signup user should be normal user by default. // The new signup user should be normal user by default.
Role: api.Host, Role: store.Host,
Nickname: hostUsername, Nickname: hostUsername,
Password: hostPassword,
OpenID: common.GenUUID(), OpenID: common.GenUUID(),
} }
if err := userCreate.Validate(); err != nil { if len(userCreate.Username) < 3 {
return fmt.Errorf("validate: %w", err) return fmt.Errorf("username is too short, minimum length is 3")
}
if len(userCreate.Username) > 32 {
return fmt.Errorf("username is too long, maximum length is 32")
}
if len(hostPassword) < 3 {
return fmt.Errorf("password is too short, minimum length is 3")
}
if len(hostPassword) > 512 {
return fmt.Errorf("password is too long, maximum length is 512")
}
if len(userCreate.Nickname) > 64 {
return fmt.Errorf("nickname is too long, maximum length is 64")
}
if userCreate.Email != "" {
if len(userCreate.Email) > 256 {
return fmt.Errorf("email is too long, maximum length is 256")
}
if !common.ValidateEmail(userCreate.Email) {
return fmt.Errorf("invalid email format")
}
} }
passwordHash, err := bcrypt.GenerateFromPassword([]byte(hostPassword), bcrypt.DefaultCost) passwordHash, err := bcrypt.GenerateFromPassword([]byte(hostPassword), bcrypt.DefaultCost)
if err != nil { if err != nil {
return fmt.Errorf("hash password: %w", err) return fmt.Errorf("failed to hash password: %w", err)
} }
userCreate.PasswordHash = string(passwordHash) userCreate.PasswordHash = string(passwordHash)
if _, err := s.store.CreateUser(ctx, userCreate); err != nil { if _, err := s.store.CreateUserV1(ctx, userCreate); err != nil {
return fmt.Errorf("create user: %w", err) return fmt.Errorf("failed to create user: %w", err)
} }
return nil return nil

View File

@ -1,181 +0,0 @@
package setup
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/usememos/memos/api"
)
func TestSetupServiceMakeSureHostUserNotExists(t *testing.T) {
cc := map[string]struct {
setupStore func(*storeMock)
expectedErr string
}{
"failed to get list": {
setupStore: func(m *storeMock) {
hostUserType := api.Host
m.
On("FindUserList", mock.Anything, &api.UserFind{
Role: &hostUserType,
}).
Return(nil, errors.New("fake error"))
},
expectedErr: "find user list: fake error",
},
"success, not empty": {
setupStore: func(m *storeMock) {
hostUserType := api.Host
m.
On("FindUserList", mock.Anything, &api.UserFind{
Role: &hostUserType,
}).
Return([]*api.User{
{},
}, nil)
},
expectedErr: "host user already exists",
},
"success, empty": {
setupStore: func(m *storeMock) {
hostUserType := api.Host
m.
On("FindUserList", mock.Anything, &api.UserFind{
Role: &hostUserType,
}).
Return(nil, nil)
},
},
}
for n, c := range cc {
c := c
t.Run(n, func(t *testing.T) {
sm := newStoreMock(t)
if c.setupStore != nil {
c.setupStore(sm)
}
srv := setupService{store: sm}
err := srv.makeSureHostUserNotExists(context.Background())
if c.expectedErr == "" {
assert.NoError(t, err)
} else {
assert.EqualError(t, err, c.expectedErr)
}
})
}
}
func TestSetupServiceCreateUser(t *testing.T) {
expectedCreated := &api.UserCreate{
Username: "demohero",
Role: api.Host,
Nickname: "demohero",
Password: "123456",
}
userCreateMatcher := mock.MatchedBy(func(arg *api.UserCreate) bool {
return arg.Username == expectedCreated.Username &&
arg.Role == expectedCreated.Role &&
arg.Nickname == expectedCreated.Nickname &&
arg.Password == expectedCreated.Password &&
arg.PasswordHash != ""
})
cc := map[string]struct {
setupStore func(*storeMock)
hostUsername, hostPassword string
expectedErr string
}{
`username == "", password == ""`: {
expectedErr: "validate: username is too short, minimum length is 3",
},
`username == "", password != ""`: {
hostPassword: expectedCreated.Password,
expectedErr: "validate: username is too short, minimum length is 3",
},
`username != "", password == ""`: {
hostUsername: expectedCreated.Username,
expectedErr: "validate: password is too short, minimum length is 3",
},
"failed to create": {
setupStore: func(m *storeMock) {
m.
On("CreateUser", mock.Anything, userCreateMatcher).
Return(nil, errors.New("fake error"))
},
hostUsername: expectedCreated.Username,
hostPassword: expectedCreated.Password,
expectedErr: "create user: fake error",
},
"success": {
setupStore: func(m *storeMock) {
m.
On("CreateUser", mock.Anything, userCreateMatcher).
Return(nil, nil)
},
hostUsername: expectedCreated.Username,
hostPassword: expectedCreated.Password,
},
}
for n, c := range cc {
c := c
t.Run(n, func(t *testing.T) {
sm := newStoreMock(t)
if c.setupStore != nil {
c.setupStore(sm)
}
srv := setupService{store: sm}
err := srv.createUser(context.Background(), c.hostUsername, c.hostPassword)
if c.expectedErr == "" {
assert.NoError(t, err)
} else {
assert.EqualError(t, err, c.expectedErr)
}
})
}
}
type storeMock struct {
mock.Mock
}
func (m *storeMock) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) {
ret := m.Called(ctx, find)
var u []*api.User
ret1 := ret.Get(0)
if ret1 != nil {
u = ret1.([]*api.User)
}
return u, ret.Error(1)
}
func (m *storeMock) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) {
ret := m.Called(ctx, create)
var u *api.User
ret1 := ret.Get(0)
if ret1 != nil {
u = ret1.(*api.User)
}
return u, ret.Error(1)
}
func newStoreMock(t *testing.T) *storeMock {
m := &storeMock{}
m.Mock.Test(t)
t.Cleanup(func() { m.AssertExpectations(t) })
return m
}

View File

@ -14,6 +14,7 @@ type Store struct {
db *sql.DB db *sql.DB
systemSettingCache sync.Map // map[string]*systemSettingRaw systemSettingCache sync.Map // map[string]*systemSettingRaw
userCache sync.Map // map[int]*userRaw userCache sync.Map // map[int]*userRaw
userV1Cache sync.Map // map[string]*User
userSettingCache sync.Map // map[string]*UserSetting userSettingCache sync.Map // map[string]*UserSetting
shortcutCache sync.Map // map[int]*shortcutRaw shortcutCache sync.Map // map[int]*shortcutRaw
idpCache sync.Map // map[int]*IdentityProvider idpCache sync.Map // map[int]*IdentityProvider

View File

@ -34,7 +34,7 @@ func (e Role) String() string {
return "USER" return "USER"
} }
type UserMessage struct { type User struct {
ID int ID int
// Standard fields // Standard fields
@ -52,7 +52,22 @@ type UserMessage struct {
AvatarURL string AvatarURL string
} }
type FindUserMessage struct { type UpdateUser struct {
ID int
UpdatedTs *int64
RowStatus *RowStatus
Username *string `json:"username"`
Email *string `json:"email"`
Nickname *string `json:"nickname"`
Password *string `json:"password"`
ResetOpenID *bool `json:"resetOpenId"`
AvatarURL *string `json:"avatarUrl"`
PasswordHash *string
OpenID *string
}
type FindUser struct {
ID *int ID *int
// Standard fields // Standard fields
@ -66,10 +81,10 @@ type FindUserMessage struct {
OpenID *string OpenID *string
} }
func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMessage, error) { func (s *Store) CreateUserV1(ctx context.Context, create *User) (*User, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
@ -99,19 +114,85 @@ func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMes
&create.UpdatedTs, &create.UpdatedTs,
&create.RowStatus, &create.RowStatus,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, err
} }
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return nil, FormatError(err) return nil, err
} }
userMessage := create user := create
return userMessage, nil s.userV1Cache.Store(user.ID, user)
return user, nil
} }
func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMessage, error) { func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
if v := update.Username; v != nil {
set, args = append(set, "username = ?"), append(args, *v)
}
if v := update.Email; v != nil {
set, args = append(set, "email = ?"), append(args, *v)
}
if v := update.Nickname; v != nil {
set, args = append(set, "nickname = ?"), append(args, *v)
}
if v := update.AvatarURL; v != nil {
set, args = append(set, "avatar_url = ?"), append(args, *v)
}
if v := update.PasswordHash; v != nil {
set, args = append(set, "password_hash = ?"), append(args, *v)
}
if v := update.OpenID; v != nil {
set, args = append(set, "open_id = ?"), append(args, *v)
}
args = append(args, update.ID)
query := `
UPDATE user
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, username, role, email, nickname, password_hash, open_id, avatar_url, created_ts, updated_ts, row_status
`
user := &User{}
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&user.ID,
&user.Username,
&user.Role,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.OpenID,
&user.AvatarURL,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
s.userV1Cache.Store(user.ID, user)
return user, nil
}
func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
@ -120,13 +201,22 @@ func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMe
return nil, err return nil, err
} }
for _, user := range list {
s.userV1Cache.Store(user.ID, user)
}
return list, nil return list, nil
} }
func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessage, error) { func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
if find.ID != nil {
if user, ok := s.userV1Cache.Load(*find.ID); ok {
return user.(*User), nil
}
}
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
@ -135,14 +225,14 @@ func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessag
return nil, err return nil, err
} }
if len(list) == 0 { if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user not found")} return nil, nil
} }
user := list[0]
memoMessage := list[0] s.userV1Cache.Store(user.ID, user)
return memoMessage, nil return user, nil
} }
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserMessage, error) { func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil { if v := find.ID; v != nil {
@ -183,36 +273,36 @@ func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserM
` `
rows, err := tx.QueryContext(ctx, query, args...) rows, err := tx.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer rows.Close() defer rows.Close()
userMessageList := make([]*UserMessage, 0) list := make([]*User, 0)
for rows.Next() { for rows.Next() {
var userMessage UserMessage var user User
if err := rows.Scan( if err := rows.Scan(
&userMessage.ID, &user.ID,
&userMessage.Username, &user.Username,
&userMessage.Role, &user.Role,
&userMessage.Email, &user.Email,
&userMessage.Nickname, &user.Nickname,
&userMessage.PasswordHash, &user.PasswordHash,
&userMessage.OpenID, &user.OpenID,
&userMessage.AvatarURL, &user.AvatarURL,
&userMessage.CreatedTs, &user.CreatedTs,
&userMessage.UpdatedTs, &user.UpdatedTs,
&userMessage.RowStatus, &user.RowStatus,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, err
} }
userMessageList = append(userMessageList, &userMessage) list = append(list, &user)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, FormatError(err) return nil, err
} }
return userMessageList, nil return list, nil
} }
// userRaw is the store model for an User. // userRaw is the store model for an User.

View File

@ -12,45 +12,44 @@ import (
func TestUserStore(t *testing.T) { func TestUserStore(t *testing.T) {
ctx := context.Background() ctx := context.Background()
store := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, store) user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err) require.NoError(t, err)
users, err := store.FindUserList(ctx, &api.UserFind{}) users, err := ts.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(users)) require.Equal(t, 1, len(users))
require.Equal(t, api.Host, users[0].Role) require.Equal(t, store.Host, users[0].Role)
require.Equal(t, user, users[0]) require.Equal(t, user, users[0])
userPatchNickname := "test_nickname_2" userPatchNickname := "test_nickname_2"
userPatch := &api.UserPatch{ userPatch := &store.UpdateUser{
ID: user.ID, ID: user.ID,
Nickname: &userPatchNickname, Nickname: &userPatchNickname,
} }
user, err = store.PatchUser(ctx, userPatch) user, err = ts.UpdateUser(ctx, userPatch)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, userPatchNickname, user.Nickname) require.Equal(t, userPatchNickname, user.Nickname)
err = store.DeleteUser(ctx, &api.UserDelete{ err = ts.DeleteUser(ctx, &api.UserDelete{
ID: user.ID, ID: user.ID,
}) })
require.NoError(t, err) require.NoError(t, err)
users, err = store.FindUserList(ctx, &api.UserFind{}) users, err = ts.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(users)) require.Equal(t, 0, len(users))
} }
func createTestingHostUser(ctx context.Context, store *store.Store) (*api.User, error) { func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, error) {
userCreate := &api.UserCreate{ userCreate := &store.User{
Username: "test", Username: "test",
Role: api.Host, Role: store.Host,
Email: "test@test.com", Email: "test@test.com",
Nickname: "test_nickname", Nickname: "test_nickname",
Password: "test_password",
OpenID: "test_open_id", OpenID: "test_open_id",
} }
passwordHash, err := bcrypt.GenerateFromPassword([]byte(userCreate.Password), bcrypt.DefaultCost) passwordHash, err := bcrypt.GenerateFromPassword([]byte("test_password"), bcrypt.DefaultCost)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userCreate.PasswordHash = string(passwordHash) userCreate.PasswordHash = string(passwordHash)
user, err := store.CreateUser(ctx, userCreate) user, err := ts.CreateUserV1(ctx, userCreate)
return user, err return user, err
} }