memos/api/v2/user_service.go

504 lines
17 KiB
Go
Raw Normal View History

package v2
import (
"context"
2023-11-05 18:03:43 +03:00
"fmt"
2023-09-10 13:56:24 +03:00
"net/http"
2023-09-18 17:34:31 +03:00
"regexp"
2023-09-18 17:37:13 +03:00
"strings"
2023-09-10 13:56:24 +03:00
"time"
2023-09-14 15:16:17 +03:00
"github.com/golang-jwt/jwt/v4"
2023-09-10 13:56:24 +03:00
"github.com/labstack/echo/v4"
2023-09-14 15:16:17 +03:00
"github.com/pkg/errors"
2023-09-10 13:56:24 +03:00
"golang.org/x/crypto/bcrypt"
2023-09-14 15:16:17 +03:00
"golang.org/x/exp/slices"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
2023-09-10 13:56:24 +03:00
"google.golang.org/protobuf/types/known/timestamppb"
2023-09-17 17:55:13 +03:00
"github.com/usememos/memos/api/auth"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
2023-09-18 17:34:31 +03:00
// usernameMatcher accepts usernames of 3 to 32 characters drawn from lowercase
// letters, digits and hyphens, where the first and last character must not be
// a hyphen. Callers lowercase the username before matching.
var usernameMatcher = regexp.MustCompile("^[a-z0-9]([a-z0-9-]{1,30}[a-z0-9])$")
2023-10-27 04:07:35 +03:00
func (s *APIV2Service) GetUser(ctx context.Context, request *apiv2pb.GetUserRequest) (*apiv2pb.GetUserResponse, error) {
2023-11-05 18:28:09 +03:00
username, err := ExtractUsernameFromName(request.Name)
2023-11-05 18:03:43 +03:00
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "name is required")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
2023-11-05 18:03:43 +03:00
Username: &username,
})
if err != nil {
2023-09-10 13:56:24 +03:00
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
userMessage := convertUserFromStore(user)
response := &apiv2pb.GetUserResponse{
User: userMessage,
}
return response, nil
}
2023-10-27 04:07:35 +03:00
func (s *APIV2Service) CreateUser(ctx context.Context, request *apiv2pb.CreateUserRequest) (*apiv2pb.CreateUserResponse, error) {
2023-10-21 07:19:06 +03:00
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
2023-11-05 18:28:09 +03:00
username, err := ExtractUsernameFromName(request.User.Name)
2023-11-05 18:03:43 +03:00
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "name is required")
}
if !usernameMatcher.MatchString(strings.ToLower(username)) {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", username)
2023-10-21 07:19:06 +03:00
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
}
user, err := s.Store.CreateUser(ctx, &store.User{
2023-11-05 18:03:43 +03:00
Username: username,
2023-10-21 07:19:06 +03:00
Role: convertUserRoleToStore(request.User.Role),
Email: request.User.Email,
Nickname: request.User.Nickname,
PasswordHash: string(passwordHash),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create user: %v", err)
}
response := &apiv2pb.CreateUserResponse{
User: convertUserFromStore(user),
}
return response, nil
}
2023-10-27 04:07:35 +03:00
func (s *APIV2Service) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUserRequest) (*apiv2pb.UpdateUserResponse, error) {
2023-11-05 18:28:09 +03:00
username, err := ExtractUsernameFromName(request.User.Name)
2023-11-05 18:03:43 +03:00
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "name is required")
}
2023-09-14 15:16:17 +03:00
currentUser, err := getCurrentUser(ctx, s.Store)
2023-09-10 13:56:24 +03:00
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Username != username && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
2023-09-10 13:56:24 +03:00
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
2023-10-21 07:19:06 +03:00
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
2023-09-10 13:56:24 +03:00
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
}
2023-11-18 07:37:24 +03:00
user, err := s.Store.GetUser(ctx, &store.FindUser{Username: &username})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
2023-09-10 13:56:24 +03:00
currentTs := time.Now().Unix()
update := &store.UpdateUser{
2023-11-18 07:37:24 +03:00
ID: user.ID,
2023-09-10 13:56:24 +03:00
UpdatedTs: &currentTs,
}
2023-10-21 07:19:06 +03:00
for _, field := range request.UpdateMask.Paths {
2023-10-03 18:44:14 +03:00
if field == "username" {
2023-11-05 18:03:43 +03:00
if !usernameMatcher.MatchString(strings.ToLower(username)) {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", username)
2023-09-18 17:34:31 +03:00
}
2023-11-05 18:03:43 +03:00
update.Username = &username
2023-10-03 18:44:14 +03:00
} else if field == "nickname" {
2023-09-10 13:56:24 +03:00
update.Nickname = &request.User.Nickname
2023-10-03 18:44:14 +03:00
} else if field == "email" {
2023-09-10 13:56:24 +03:00
update.Email = &request.User.Email
2023-10-03 18:44:14 +03:00
} else if field == "avatar_url" {
2023-09-10 13:56:24 +03:00
update.AvatarURL = &request.User.AvatarUrl
2023-10-03 18:44:14 +03:00
} else if field == "role" {
2023-09-10 13:56:24 +03:00
role := convertUserRoleToStore(request.User.Role)
update.Role = &role
2023-10-03 18:44:14 +03:00
} else if field == "password" {
2023-09-10 13:56:24 +03:00
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
}
passwordHashStr := string(passwordHash)
update.PasswordHash = &passwordHashStr
2023-10-03 18:44:14 +03:00
} else if field == "row_status" {
2023-09-10 13:56:24 +03:00
rowStatus := convertRowStatusToStore(request.User.RowStatus)
update.RowStatus = &rowStatus
} else {
2023-10-03 18:44:14 +03:00
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
2023-09-10 13:56:24 +03:00
}
}
2023-11-18 07:37:24 +03:00
updatedUser, err := s.Store.UpdateUser(ctx, update)
2023-09-10 13:56:24 +03:00
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update user: %v", err)
}
response := &apiv2pb.UpdateUserResponse{
2023-11-18 07:37:24 +03:00
User: convertUserFromStore(updatedUser),
2023-09-10 13:56:24 +03:00
}
return response, nil
}
2023-11-22 17:52:19 +03:00
// DeleteUser deletes the user named by request.Name.
//
// Callers may delete their own account; admins and the host may delete any
// account. Returns NotFound when the user does not exist.
func (s *APIV2Service) DeleteUser(ctx context.Context, request *apiv2pb.DeleteUserRequest) (*apiv2pb.DeleteUserResponse, error) {
	username, err := ExtractUsernameFromName(request.Name)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "name is required")
	}
	currentUser, err := getCurrentUser(ctx, s.Store)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
	}
	// getCurrentUser can apparently yield a nil user (ListUserAccessTokens guards
	// for it); deny instead of panicking on the field accesses below.
	if currentUser == nil {
		return nil, status.Errorf(codes.PermissionDenied, "permission denied")
	}
	if currentUser.Username != username && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
		return nil, status.Errorf(codes.PermissionDenied, "permission denied")
	}

	user, err := s.Store.GetUser(ctx, &store.FindUser{Username: &username})
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
	}
	if user == nil {
		return nil, status.Errorf(codes.NotFound, "user not found")
	}
	if err := s.Store.DeleteUser(ctx, &store.DeleteUser{
		ID: user.ID,
	}); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to delete user: %v", err)
	}
	return &apiv2pb.DeleteUserResponse{}, nil
}
2023-12-01 04:03:30 +03:00
func (s *APIV2Service) GetUserSetting(ctx context.Context, _ *apiv2pb.GetUserSettingRequest) (*apiv2pb.GetUserSettingResponse, error) {
2023-11-30 18:08:54 +03:00
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
userSettings, err := s.Store.ListUserSettingsV1(ctx, &store.FindUserSettingV1{
UserID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list user settings: %v", err)
}
2023-12-01 04:03:30 +03:00
userSettingMessage := &apiv2pb.UserSetting{}
2023-11-30 18:08:54 +03:00
for _, setting := range userSettings {
if setting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
2023-12-01 04:03:30 +03:00
userSettingMessage.Locale = setting.GetLocale()
2023-11-30 18:08:54 +03:00
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_APPEARANCE {
2023-12-01 04:03:30 +03:00
userSettingMessage.Appearance = setting.GetAppearance()
2023-11-30 18:08:54 +03:00
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_MEMO_VISIBILITY {
2023-12-01 04:03:30 +03:00
userSettingMessage.MemoVisibility = setting.GetMemoVisibility()
2023-11-30 18:08:54 +03:00
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_TELEGRAM_USER_ID {
2023-12-01 04:03:30 +03:00
userSettingMessage.TelegramUserId = setting.GetTelegramUserId()
2023-11-30 18:08:54 +03:00
}
}
2023-12-01 04:03:30 +03:00
return &apiv2pb.GetUserSettingResponse{
Setting: userSettingMessage,
2023-11-30 18:08:54 +03:00
}, nil
}
2023-12-01 04:03:30 +03:00
func (s *APIV2Service) UpdateUserSetting(ctx context.Context, request *apiv2pb.UpdateUserSettingRequest) (*apiv2pb.UpdateUserSettingResponse, error) {
2023-11-30 18:08:54 +03:00
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
}
for _, field := range request.UpdateMask.Paths {
if field == "locale" {
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_LOCALE,
Value: &storepb.UserSetting_Locale{
2023-12-01 04:03:30 +03:00
Locale: request.Setting.Locale,
2023-11-30 18:08:54 +03:00
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else if field == "appearance" {
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_APPEARANCE,
Value: &storepb.UserSetting_Appearance{
2023-12-01 04:03:30 +03:00
Appearance: request.Setting.Appearance,
2023-11-30 18:08:54 +03:00
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else if field == "memo_visibility" {
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_MEMO_VISIBILITY,
Value: &storepb.UserSetting_MemoVisibility{
2023-12-01 04:03:30 +03:00
MemoVisibility: request.Setting.MemoVisibility,
2023-11-30 18:08:54 +03:00
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else if field == "telegram_user_id" {
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_TELEGRAM_USER_ID,
Value: &storepb.UserSetting_TelegramUserId{
2023-12-01 04:03:30 +03:00
TelegramUserId: request.Setting.TelegramUserId,
2023-11-30 18:08:54 +03:00
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else {
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
}
}
2023-12-01 04:03:30 +03:00
userSettingResponse, err := s.GetUserSetting(ctx, &apiv2pb.GetUserSettingRequest{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user setting: %v", err)
}
return &apiv2pb.UpdateUserSettingResponse{
Setting: userSettingResponse.Setting,
}, nil
2023-11-30 18:08:54 +03:00
}
2023-10-27 04:07:35 +03:00
func (s *APIV2Service) ListUserAccessTokens(ctx context.Context, request *apiv2pb.ListUserAccessTokensRequest) (*apiv2pb.ListUserAccessTokensResponse, error) {
2023-09-14 15:16:17 +03:00
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userID := user.ID
2023-11-05 18:28:09 +03:00
username, err := ExtractUsernameFromName(request.Name)
2023-11-05 18:03:43 +03:00
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "name is required")
}
// List access token for other users need to be verified.
2023-11-05 18:03:43 +03:00
if user.Username != username {
// Normal users can only list their access tokens.
if user.Role == store.RoleUser {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// The request user must be exist.
2023-11-05 18:03:43 +03:00
requestUser, err := s.Store.GetUser(ctx, &store.FindUser{Username: &username})
if requestUser == nil || err != nil {
2023-11-05 18:03:43 +03:00
return nil, status.Errorf(codes.NotFound, "fail to find user %s", username)
}
userID = requestUser.ID
2023-09-14 15:16:17 +03:00
}
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, userID)
2023-09-14 15:16:17 +03:00
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
accessTokens := []*apiv2pb.UserAccessToken{}
for _, userAccessToken := range userAccessTokens {
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
// If the access token is invalid or expired, just ignore it.
continue
}
userAccessToken := &apiv2pb.UserAccessToken{
AccessToken: userAccessToken.AccessToken,
Description: userAccessToken.Description,
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
}
if claims.ExpiresAt != nil {
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
}
accessTokens = append(accessTokens, userAccessToken)
}
// Sort by issued time in descending order.
2023-10-12 19:13:13 +03:00
slices.SortFunc(accessTokens, func(i, j *apiv2pb.UserAccessToken) int {
return int(i.IssuedAt.Seconds - j.IssuedAt.Seconds)
2023-09-14 15:16:17 +03:00
})
response := &apiv2pb.ListUserAccessTokensResponse{
AccessTokens: accessTokens,
}
return response, nil
}
2023-10-27 04:07:35 +03:00
func (s *APIV2Service) CreateUserAccessToken(ctx context.Context, request *apiv2pb.CreateUserAccessTokenRequest) (*apiv2pb.CreateUserAccessTokenResponse, error) {
2023-09-14 15:16:17 +03:00
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
2023-09-20 15:48:34 +03:00
expiresAt := time.Time{}
if request.ExpiresAt != nil {
expiresAt = request.ExpiresAt.AsTime()
}
2023-11-05 18:03:43 +03:00
accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret))
2023-09-14 15:16:17 +03:00
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}
claims := &auth.ClaimsMessage{}
_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
}
// Upsert the access token to user setting store.
2023-09-20 15:48:34 +03:00
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, request.Description); err != nil {
2023-09-14 15:16:17 +03:00
return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
}
userAccessToken := &apiv2pb.UserAccessToken{
AccessToken: accessToken,
2023-09-20 15:48:34 +03:00
Description: request.Description,
2023-09-14 15:16:17 +03:00
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
}
if claims.ExpiresAt != nil {
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
}
response := &apiv2pb.CreateUserAccessTokenResponse{
AccessToken: userAccessToken,
}
return response, nil
}
2023-10-27 04:07:35 +03:00
func (s *APIV2Service) DeleteUserAccessToken(ctx context.Context, request *apiv2pb.DeleteUserAccessTokenRequest) (*apiv2pb.DeleteUserAccessTokenResponse, error) {
2023-09-14 15:16:17 +03:00
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
for _, userAccessToken := range userAccessTokens {
if userAccessToken.AccessToken == request.AccessToken {
continue
}
updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
}
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: updatedUserAccessTokens,
},
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
return &apiv2pb.DeleteUserAccessTokenResponse{}, nil
}
2023-10-27 04:07:35 +03:00
func (s *APIV2Service) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error {
2023-09-14 15:16:17 +03:00
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return errors.Wrap(err, "failed to get user access tokens")
}
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
AccessToken: accessToken,
Description: description,
}
userAccessTokens = append(userAccessTokens, &userAccessToken)
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: userAccessTokens,
},
},
}); err != nil {
return errors.Wrap(err, "failed to upsert user setting")
}
return nil
}
func convertUserFromStore(user *store.User) *apiv2pb.User {
return &apiv2pb.User{
2023-11-05 18:03:43 +03:00
Name: fmt.Sprintf("%s%s", UserNamePrefix, user.Username),
Id: user.ID,
2023-09-10 13:56:24 +03:00
RowStatus: convertRowStatusFromStore(user.RowStatus),
CreateTime: timestamppb.New(time.Unix(user.CreatedTs, 0)),
UpdateTime: timestamppb.New(time.Unix(user.UpdatedTs, 0)),
Role: convertUserRoleFromStore(user.Role),
Email: user.Email,
Nickname: user.Nickname,
AvatarUrl: user.AvatarURL,
}
}
2023-09-10 13:56:24 +03:00
func convertUserRoleFromStore(role store.Role) apiv2pb.User_Role {
switch role {
case store.RoleHost:
2023-09-10 13:56:24 +03:00
return apiv2pb.User_HOST
case store.RoleAdmin:
2023-09-10 13:56:24 +03:00
return apiv2pb.User_ADMIN
case store.RoleUser:
2023-09-10 13:56:24 +03:00
return apiv2pb.User_USER
default:
return apiv2pb.User_ROLE_UNSPECIFIED
}
}
func convertUserRoleToStore(role apiv2pb.User_Role) store.Role {
switch role {
case apiv2pb.User_HOST:
return store.RoleHost
case apiv2pb.User_ADMIN:
return store.RoleAdmin
case apiv2pb.User_USER:
return store.RoleUser
default:
2023-09-10 13:56:24 +03:00
return store.RoleUser
}
}