mirror of
https://github.com/usememos/memos.git
synced 2024-12-26 04:42:54 +03:00
chore: update user store names (#1877)
* chore: update user store names * chore: update
This commit is contained in:
parent
ca770c87d6
commit
9a8d43bf88
@ -40,7 +40,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||||
Username: &signin.Username,
|
Username: &signin.Username,
|
||||||
})
|
})
|
||||||
if err != nil && common.ErrorCode(err) != common.NotFound {
|
if err != nil && common.ErrorCode(err) != common.NotFound {
|
||||||
@ -111,14 +111,14 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||||
Username: &userInfo.Identifier,
|
Username: &userInfo.Identifier,
|
||||||
})
|
})
|
||||||
if err != nil && common.ErrorCode(err) != common.NotFound {
|
if err != nil && common.ErrorCode(err) != common.NotFound {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
|
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
|
||||||
}
|
}
|
||||||
if user == nil {
|
if user == nil {
|
||||||
userCreate := &store.UserMessage{
|
userCreate := &store.User{
|
||||||
Username: userInfo.Identifier,
|
Username: userInfo.Identifier,
|
||||||
// The new signup user should be normal user by default.
|
// The new signup user should be normal user by default.
|
||||||
Role: store.NormalUser,
|
Role: store.NormalUser,
|
||||||
@ -161,14 +161,14 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hostUserType := store.Host
|
hostUserType := store.Host
|
||||||
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUserMessage{
|
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
|
||||||
Role: &hostUserType,
|
Role: &hostUserType,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "Failed to find users").SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, "Failed to find users").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
userCreate := &store.UserMessage{
|
userCreate := &store.User{
|
||||||
Username: signup.Username,
|
Username: signup.Username,
|
||||||
// The new signup user should be normal user by default.
|
// The new signup user should be normal user by default.
|
||||||
Role: store.NormalUser,
|
Role: store.NormalUser,
|
||||||
@ -224,7 +224,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.UserMessage) error {
|
func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User) error {
|
||||||
ctx := c.Request().Context()
|
ctx := c.Request().Context()
|
||||||
payload := ActivityUserAuthSignInPayload{
|
payload := ActivityUserAuthSignInPayload{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
@ -246,7 +246,7 @@ func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.UserMessage) error {
|
func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User) error {
|
||||||
ctx := c.Request().Context()
|
ctx := c.Request().Context()
|
||||||
payload := ActivityUserAuthSignUpPayload{
|
payload := ActivityUserAuthSignUpPayload{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
|
@ -68,7 +68,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -102,7 +102,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -147,7 +147,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
userID, ok := c.Get(getUserIDContextKey()).(int)
|
userID, ok := c.Get(getUserIDContextKey()).(int)
|
||||||
isHostUser := false
|
isHostUser := false
|
||||||
if ok {
|
if ok {
|
||||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -177,7 +177,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -211,7 +211,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -141,7 +141,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Even if there is no error, we still need to make sure the user still exists.
|
// Even if there is no error, we still need to make sure the user still exists.
|
||||||
user, err := server.Store.GetUser(ctx, &store.FindUserMessage{
|
user, err := server.Store.GetUser(ctx, &store.FindUser{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -222,7 +222,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool {
|
|||||||
// If there is openId in query string and related user is found, then skip auth.
|
// If there is openId in query string and related user is found, then skip auth.
|
||||||
openID := c.QueryParam("openId")
|
openID := c.QueryParam("openId")
|
||||||
if openID != "" {
|
if openID != "" {
|
||||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||||
OpenID: &openID,
|
OpenID: &openID,
|
||||||
})
|
})
|
||||||
if err != nil && common.ErrorCode(err) != common.NotFound {
|
if err != nil && common.ErrorCode(err) != common.NotFound {
|
||||||
|
1
go.mod
1
go.mod
@ -78,7 +78,6 @@ require (
|
|||||||
github.com/spf13/cast v1.5.0 // indirect
|
github.com/spf13/cast v1.5.0 // indirect
|
||||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.5 // indirect
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
github.com/stretchr/objx v0.5.0 // indirect
|
|
||||||
github.com/subosito/gotenv v1.4.2 // indirect
|
github.com/subosito/gotenv v1.4.2 // indirect
|
||||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
github.com/valyala/fasttemplate v1.2.1 // indirect
|
github.com/valyala/fasttemplate v1.2.1 // indirect
|
||||||
|
1
go.sum
1
go.sum
@ -248,7 +248,6 @@ github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU=
|
|||||||
github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA=
|
github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
@ -64,7 +64,7 @@ func GenerateRefreshToken(userName string, userID int, secret string) (string, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie.
|
// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie.
|
||||||
func GenerateTokensAndSetCookies(c echo.Context, user *store.UserMessage, secret string) error {
|
func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error {
|
||||||
accessToken, err := GenerateAccessToken(user.Username, user.ID, secret)
|
accessToken, err := GenerateAccessToken(user.Username, user.ID, secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to generate access token")
|
return errors.Wrap(err, "failed to generate access token")
|
||||||
|
@ -141,7 +141,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Even if there is no error, we still need to make sure the user still exists.
|
// Even if there is no error, we still need to make sure the user still exists.
|
||||||
user, err := server.Store.GetUser(ctx, &store.FindUserMessage{
|
user, err := server.Store.GetUser(ctx, &store.FindUser{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -9,30 +9,19 @@ import (
|
|||||||
|
|
||||||
"github.com/usememos/memos/api"
|
"github.com/usememos/memos/api"
|
||||||
"github.com/usememos/memos/common"
|
"github.com/usememos/memos/common"
|
||||||
|
"github.com/usememos/memos/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Execute(
|
func Execute(ctx context.Context, store *store.Store, hostUsername, hostPassword string) error {
|
||||||
ctx context.Context,
|
|
||||||
store store,
|
|
||||||
hostUsername, hostPassword string,
|
|
||||||
) error {
|
|
||||||
s := setupService{store: store}
|
s := setupService{store: store}
|
||||||
return s.Setup(ctx, hostUsername, hostPassword)
|
return s.Setup(ctx, hostUsername, hostPassword)
|
||||||
}
|
}
|
||||||
|
|
||||||
type store interface {
|
|
||||||
FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error)
|
|
||||||
CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type setupService struct {
|
type setupService struct {
|
||||||
store store
|
store *store.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s setupService) Setup(
|
func (s setupService) Setup(ctx context.Context, hostUsername, hostPassword string) error {
|
||||||
ctx context.Context,
|
|
||||||
hostUsername, hostPassword string,
|
|
||||||
) error {
|
|
||||||
if err := s.makeSureHostUserNotExists(ctx); err != nil {
|
if err := s.makeSureHostUserNotExists(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -59,31 +48,47 @@ func (s setupService) makeSureHostUserNotExists(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s setupService) createUser(
|
func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword string) error {
|
||||||
ctx context.Context,
|
userCreate := &store.User{
|
||||||
hostUsername, hostPassword string,
|
|
||||||
) error {
|
|
||||||
userCreate := &api.UserCreate{
|
|
||||||
Username: hostUsername,
|
Username: hostUsername,
|
||||||
// The new signup user should be normal user by default.
|
// The new signup user should be normal user by default.
|
||||||
Role: api.Host,
|
Role: store.Host,
|
||||||
Nickname: hostUsername,
|
Nickname: hostUsername,
|
||||||
Password: hostPassword,
|
|
||||||
OpenID: common.GenUUID(),
|
OpenID: common.GenUUID(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := userCreate.Validate(); err != nil {
|
if len(userCreate.Username) < 3 {
|
||||||
return fmt.Errorf("validate: %w", err)
|
return fmt.Errorf("username is too short, minimum length is 3")
|
||||||
|
}
|
||||||
|
if len(userCreate.Username) > 32 {
|
||||||
|
return fmt.Errorf("username is too long, maximum length is 32")
|
||||||
|
}
|
||||||
|
if len(hostPassword) < 3 {
|
||||||
|
return fmt.Errorf("password is too short, minimum length is 3")
|
||||||
|
}
|
||||||
|
if len(hostPassword) > 512 {
|
||||||
|
return fmt.Errorf("password is too long, maximum length is 512")
|
||||||
|
}
|
||||||
|
if len(userCreate.Nickname) > 64 {
|
||||||
|
return fmt.Errorf("nickname is too long, maximum length is 64")
|
||||||
|
}
|
||||||
|
if userCreate.Email != "" {
|
||||||
|
if len(userCreate.Email) > 256 {
|
||||||
|
return fmt.Errorf("email is too long, maximum length is 256")
|
||||||
|
}
|
||||||
|
if !common.ValidateEmail(userCreate.Email) {
|
||||||
|
return fmt.Errorf("invalid email format")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(hostPassword), bcrypt.DefaultCost)
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(hostPassword), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("hash password: %w", err)
|
return fmt.Errorf("failed to hash password: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
userCreate.PasswordHash = string(passwordHash)
|
userCreate.PasswordHash = string(passwordHash)
|
||||||
if _, err := s.store.CreateUser(ctx, userCreate); err != nil {
|
if _, err := s.store.CreateUserV1(ctx, userCreate); err != nil {
|
||||||
return fmt.Errorf("create user: %w", err)
|
return fmt.Errorf("failed to create user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -1,181 +0,0 @@
|
|||||||
package setup
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
|
|
||||||
"github.com/usememos/memos/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSetupServiceMakeSureHostUserNotExists(t *testing.T) {
|
|
||||||
cc := map[string]struct {
|
|
||||||
setupStore func(*storeMock)
|
|
||||||
expectedErr string
|
|
||||||
}{
|
|
||||||
"failed to get list": {
|
|
||||||
setupStore: func(m *storeMock) {
|
|
||||||
hostUserType := api.Host
|
|
||||||
m.
|
|
||||||
On("FindUserList", mock.Anything, &api.UserFind{
|
|
||||||
Role: &hostUserType,
|
|
||||||
}).
|
|
||||||
Return(nil, errors.New("fake error"))
|
|
||||||
},
|
|
||||||
expectedErr: "find user list: fake error",
|
|
||||||
},
|
|
||||||
"success, not empty": {
|
|
||||||
setupStore: func(m *storeMock) {
|
|
||||||
hostUserType := api.Host
|
|
||||||
m.
|
|
||||||
On("FindUserList", mock.Anything, &api.UserFind{
|
|
||||||
Role: &hostUserType,
|
|
||||||
}).
|
|
||||||
Return([]*api.User{
|
|
||||||
{},
|
|
||||||
}, nil)
|
|
||||||
},
|
|
||||||
expectedErr: "host user already exists",
|
|
||||||
},
|
|
||||||
"success, empty": {
|
|
||||||
setupStore: func(m *storeMock) {
|
|
||||||
hostUserType := api.Host
|
|
||||||
m.
|
|
||||||
On("FindUserList", mock.Anything, &api.UserFind{
|
|
||||||
Role: &hostUserType,
|
|
||||||
}).
|
|
||||||
Return(nil, nil)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for n, c := range cc {
|
|
||||||
c := c
|
|
||||||
t.Run(n, func(t *testing.T) {
|
|
||||||
sm := newStoreMock(t)
|
|
||||||
if c.setupStore != nil {
|
|
||||||
c.setupStore(sm)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := setupService{store: sm}
|
|
||||||
err := srv.makeSureHostUserNotExists(context.Background())
|
|
||||||
if c.expectedErr == "" {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
} else {
|
|
||||||
assert.EqualError(t, err, c.expectedErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetupServiceCreateUser(t *testing.T) {
|
|
||||||
expectedCreated := &api.UserCreate{
|
|
||||||
Username: "demohero",
|
|
||||||
Role: api.Host,
|
|
||||||
Nickname: "demohero",
|
|
||||||
Password: "123456",
|
|
||||||
}
|
|
||||||
|
|
||||||
userCreateMatcher := mock.MatchedBy(func(arg *api.UserCreate) bool {
|
|
||||||
return arg.Username == expectedCreated.Username &&
|
|
||||||
arg.Role == expectedCreated.Role &&
|
|
||||||
arg.Nickname == expectedCreated.Nickname &&
|
|
||||||
arg.Password == expectedCreated.Password &&
|
|
||||||
arg.PasswordHash != ""
|
|
||||||
})
|
|
||||||
|
|
||||||
cc := map[string]struct {
|
|
||||||
setupStore func(*storeMock)
|
|
||||||
hostUsername, hostPassword string
|
|
||||||
expectedErr string
|
|
||||||
}{
|
|
||||||
`username == "", password == ""`: {
|
|
||||||
expectedErr: "validate: username is too short, minimum length is 3",
|
|
||||||
},
|
|
||||||
`username == "", password != ""`: {
|
|
||||||
hostPassword: expectedCreated.Password,
|
|
||||||
expectedErr: "validate: username is too short, minimum length is 3",
|
|
||||||
},
|
|
||||||
`username != "", password == ""`: {
|
|
||||||
hostUsername: expectedCreated.Username,
|
|
||||||
expectedErr: "validate: password is too short, minimum length is 3",
|
|
||||||
},
|
|
||||||
"failed to create": {
|
|
||||||
setupStore: func(m *storeMock) {
|
|
||||||
m.
|
|
||||||
On("CreateUser", mock.Anything, userCreateMatcher).
|
|
||||||
Return(nil, errors.New("fake error"))
|
|
||||||
},
|
|
||||||
hostUsername: expectedCreated.Username,
|
|
||||||
hostPassword: expectedCreated.Password,
|
|
||||||
expectedErr: "create user: fake error",
|
|
||||||
},
|
|
||||||
"success": {
|
|
||||||
setupStore: func(m *storeMock) {
|
|
||||||
m.
|
|
||||||
On("CreateUser", mock.Anything, userCreateMatcher).
|
|
||||||
Return(nil, nil)
|
|
||||||
},
|
|
||||||
hostUsername: expectedCreated.Username,
|
|
||||||
hostPassword: expectedCreated.Password,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for n, c := range cc {
|
|
||||||
c := c
|
|
||||||
t.Run(n, func(t *testing.T) {
|
|
||||||
sm := newStoreMock(t)
|
|
||||||
if c.setupStore != nil {
|
|
||||||
c.setupStore(sm)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := setupService{store: sm}
|
|
||||||
err := srv.createUser(context.Background(), c.hostUsername, c.hostPassword)
|
|
||||||
if c.expectedErr == "" {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
} else {
|
|
||||||
assert.EqualError(t, err, c.expectedErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type storeMock struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeMock) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) {
|
|
||||||
ret := m.Called(ctx, find)
|
|
||||||
|
|
||||||
var u []*api.User
|
|
||||||
ret1 := ret.Get(0)
|
|
||||||
if ret1 != nil {
|
|
||||||
u = ret1.([]*api.User)
|
|
||||||
}
|
|
||||||
|
|
||||||
return u, ret.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeMock) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) {
|
|
||||||
ret := m.Called(ctx, create)
|
|
||||||
|
|
||||||
var u *api.User
|
|
||||||
ret1 := ret.Get(0)
|
|
||||||
if ret1 != nil {
|
|
||||||
u = ret1.(*api.User)
|
|
||||||
}
|
|
||||||
|
|
||||||
return u, ret.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newStoreMock(t *testing.T) *storeMock {
|
|
||||||
m := &storeMock{}
|
|
||||||
m.Mock.Test(t)
|
|
||||||
|
|
||||||
t.Cleanup(func() { m.AssertExpectations(t) })
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
@ -14,6 +14,7 @@ type Store struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
systemSettingCache sync.Map // map[string]*systemSettingRaw
|
systemSettingCache sync.Map // map[string]*systemSettingRaw
|
||||||
userCache sync.Map // map[int]*userRaw
|
userCache sync.Map // map[int]*userRaw
|
||||||
|
userV1Cache sync.Map // map[string]*User
|
||||||
userSettingCache sync.Map // map[string]*UserSetting
|
userSettingCache sync.Map // map[string]*UserSetting
|
||||||
shortcutCache sync.Map // map[int]*shortcutRaw
|
shortcutCache sync.Map // map[int]*shortcutRaw
|
||||||
idpCache sync.Map // map[int]*IdentityProvider
|
idpCache sync.Map // map[int]*IdentityProvider
|
||||||
|
160
store/user.go
160
store/user.go
@ -34,7 +34,7 @@ func (e Role) String() string {
|
|||||||
return "USER"
|
return "USER"
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserMessage struct {
|
type User struct {
|
||||||
ID int
|
ID int
|
||||||
|
|
||||||
// Standard fields
|
// Standard fields
|
||||||
@ -52,7 +52,22 @@ type UserMessage struct {
|
|||||||
AvatarURL string
|
AvatarURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
type FindUserMessage struct {
|
type UpdateUser struct {
|
||||||
|
ID int
|
||||||
|
|
||||||
|
UpdatedTs *int64
|
||||||
|
RowStatus *RowStatus
|
||||||
|
Username *string `json:"username"`
|
||||||
|
Email *string `json:"email"`
|
||||||
|
Nickname *string `json:"nickname"`
|
||||||
|
Password *string `json:"password"`
|
||||||
|
ResetOpenID *bool `json:"resetOpenId"`
|
||||||
|
AvatarURL *string `json:"avatarUrl"`
|
||||||
|
PasswordHash *string
|
||||||
|
OpenID *string
|
||||||
|
}
|
||||||
|
|
||||||
|
type FindUser struct {
|
||||||
ID *int
|
ID *int
|
||||||
|
|
||||||
// Standard fields
|
// Standard fields
|
||||||
@ -66,10 +81,10 @@ type FindUserMessage struct {
|
|||||||
OpenID *string
|
OpenID *string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMessage, error) {
|
func (s *Store) CreateUserV1(ctx context.Context, create *User) (*User, error) {
|
||||||
tx, err := s.db.BeginTx(ctx, nil)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
@ -99,19 +114,85 @@ func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMes
|
|||||||
&create.UpdatedTs,
|
&create.UpdatedTs,
|
||||||
&create.RowStatus,
|
&create.RowStatus,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
userMessage := create
|
user := create
|
||||||
return userMessage, nil
|
s.userV1Cache.Store(user.ID, user)
|
||||||
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMessage, error) {
|
func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) {
|
||||||
tx, err := s.db.BeginTx(ctx, nil)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
set, args := []string{}, []any{}
|
||||||
|
if v := update.UpdatedTs; v != nil {
|
||||||
|
set, args = append(set, "updated_ts = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := update.RowStatus; v != nil {
|
||||||
|
set, args = append(set, "row_status = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := update.Username; v != nil {
|
||||||
|
set, args = append(set, "username = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := update.Email; v != nil {
|
||||||
|
set, args = append(set, "email = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := update.Nickname; v != nil {
|
||||||
|
set, args = append(set, "nickname = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := update.AvatarURL; v != nil {
|
||||||
|
set, args = append(set, "avatar_url = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := update.PasswordHash; v != nil {
|
||||||
|
set, args = append(set, "password_hash = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := update.OpenID; v != nil {
|
||||||
|
set, args = append(set, "open_id = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
args = append(args, update.ID)
|
||||||
|
|
||||||
|
query := `
|
||||||
|
UPDATE user
|
||||||
|
SET ` + strings.Join(set, ", ") + `
|
||||||
|
WHERE id = ?
|
||||||
|
RETURNING id, username, role, email, nickname, password_hash, open_id, avatar_url, created_ts, updated_ts, row_status
|
||||||
|
`
|
||||||
|
user := &User{}
|
||||||
|
if err := tx.QueryRowContext(ctx, query, args...).Scan(
|
||||||
|
&user.ID,
|
||||||
|
&user.Username,
|
||||||
|
&user.Role,
|
||||||
|
&user.Email,
|
||||||
|
&user.Nickname,
|
||||||
|
&user.PasswordHash,
|
||||||
|
&user.OpenID,
|
||||||
|
&user.AvatarURL,
|
||||||
|
&user.CreatedTs,
|
||||||
|
&user.UpdatedTs,
|
||||||
|
&user.RowStatus,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.userV1Cache.Store(user.ID, user)
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) {
|
||||||
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
@ -120,13 +201,22 @@ func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMe
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, user := range list {
|
||||||
|
s.userV1Cache.Store(user.ID, user)
|
||||||
|
}
|
||||||
return list, nil
|
return list, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessage, error) {
|
func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
|
||||||
|
if find.ID != nil {
|
||||||
|
if user, ok := s.userV1Cache.Load(*find.ID); ok {
|
||||||
|
return user.(*User), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTx(ctx, nil)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
@ -135,14 +225,14 @@ func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessag
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(list) == 0 {
|
if len(list) == 0 {
|
||||||
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user not found")}
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
user := list[0]
|
||||||
memoMessage := list[0]
|
s.userV1Cache.Store(user.ID, user)
|
||||||
return memoMessage, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserMessage, error) {
|
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) {
|
||||||
where, args := []string{"1 = 1"}, []any{}
|
where, args := []string{"1 = 1"}, []any{}
|
||||||
|
|
||||||
if v := find.ID; v != nil {
|
if v := find.ID; v != nil {
|
||||||
@ -183,36 +273,36 @@ func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserM
|
|||||||
`
|
`
|
||||||
rows, err := tx.QueryContext(ctx, query, args...)
|
rows, err := tx.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
userMessageList := make([]*UserMessage, 0)
|
list := make([]*User, 0)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var userMessage UserMessage
|
var user User
|
||||||
if err := rows.Scan(
|
if err := rows.Scan(
|
||||||
&userMessage.ID,
|
&user.ID,
|
||||||
&userMessage.Username,
|
&user.Username,
|
||||||
&userMessage.Role,
|
&user.Role,
|
||||||
&userMessage.Email,
|
&user.Email,
|
||||||
&userMessage.Nickname,
|
&user.Nickname,
|
||||||
&userMessage.PasswordHash,
|
&user.PasswordHash,
|
||||||
&userMessage.OpenID,
|
&user.OpenID,
|
||||||
&userMessage.AvatarURL,
|
&user.AvatarURL,
|
||||||
&userMessage.CreatedTs,
|
&user.CreatedTs,
|
||||||
&userMessage.UpdatedTs,
|
&user.UpdatedTs,
|
||||||
&userMessage.RowStatus,
|
&user.RowStatus,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
userMessageList = append(userMessageList, &userMessage)
|
list = append(list, &user)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return userMessageList, nil
|
return list, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// userRaw is the store model for an User.
|
// userRaw is the store model for an User.
|
||||||
|
@ -12,45 +12,44 @@ import (
|
|||||||
|
|
||||||
func TestUserStore(t *testing.T) {
|
func TestUserStore(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
store := NewTestingStore(ctx, t)
|
ts := NewTestingStore(ctx, t)
|
||||||
user, err := createTestingHostUser(ctx, store)
|
user, err := createTestingHostUser(ctx, ts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
users, err := store.FindUserList(ctx, &api.UserFind{})
|
users, err := ts.ListUsers(ctx, &store.FindUser{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, len(users))
|
require.Equal(t, 1, len(users))
|
||||||
require.Equal(t, api.Host, users[0].Role)
|
require.Equal(t, store.Host, users[0].Role)
|
||||||
require.Equal(t, user, users[0])
|
require.Equal(t, user, users[0])
|
||||||
userPatchNickname := "test_nickname_2"
|
userPatchNickname := "test_nickname_2"
|
||||||
userPatch := &api.UserPatch{
|
userPatch := &store.UpdateUser{
|
||||||
ID: user.ID,
|
ID: user.ID,
|
||||||
Nickname: &userPatchNickname,
|
Nickname: &userPatchNickname,
|
||||||
}
|
}
|
||||||
user, err = store.PatchUser(ctx, userPatch)
|
user, err = ts.UpdateUser(ctx, userPatch)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, userPatchNickname, user.Nickname)
|
require.Equal(t, userPatchNickname, user.Nickname)
|
||||||
err = store.DeleteUser(ctx, &api.UserDelete{
|
err = ts.DeleteUser(ctx, &api.UserDelete{
|
||||||
ID: user.ID,
|
ID: user.ID,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
users, err = store.FindUserList(ctx, &api.UserFind{})
|
users, err = ts.ListUsers(ctx, &store.FindUser{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 0, len(users))
|
require.Equal(t, 0, len(users))
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTestingHostUser(ctx context.Context, store *store.Store) (*api.User, error) {
|
func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, error) {
|
||||||
userCreate := &api.UserCreate{
|
userCreate := &store.User{
|
||||||
Username: "test",
|
Username: "test",
|
||||||
Role: api.Host,
|
Role: store.Host,
|
||||||
Email: "test@test.com",
|
Email: "test@test.com",
|
||||||
Nickname: "test_nickname",
|
Nickname: "test_nickname",
|
||||||
Password: "test_password",
|
|
||||||
OpenID: "test_open_id",
|
OpenID: "test_open_id",
|
||||||
}
|
}
|
||||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(userCreate.Password), bcrypt.DefaultCost)
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte("test_password"), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
userCreate.PasswordHash = string(passwordHash)
|
userCreate.PasswordHash = string(passwordHash)
|
||||||
user, err := store.CreateUser(ctx, userCreate)
|
user, err := ts.CreateUserV1(ctx, userCreate)
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user