diff --git a/api/auth/auth.go b/api/auth/auth.go new file mode 100644 index 00000000..e7ca7803 --- /dev/null +++ b/api/auth/auth.go @@ -0,0 +1,27 @@ +package auth + +import ( + "time" +) + +const ( + // The key name used to store user id in the context + // user id is extracted from the jwt token subject field. + UserIDContextKey = "user-id" + // issuer is the issuer of the jwt token. + Issuer = "memos" + // Signing key section. For now, this is only used for signing, not for verifying since we only + // have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism. + KeyID = "v1" + // AccessTokenAudienceName is the audience name of the access token. + AccessTokenAudienceName = "user.access-token" + AccessTokenDuration = 7 * 24 * time.Hour + + // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user + // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. + // Suppose we have a valid refresh token, we will refresh the token in cases: + // 1. The access token has already expired, we refresh the token so that the ongoing request can pass through. + CookieExpDuration = AccessTokenDuration - 1*time.Minute + // AccessTokenCookieName is the cookie name of access token. + AccessTokenCookieName = "memos.access-token" +) diff --git a/api/v1/auth.go b/api/v1/auth.go index d5633d0f..cc2030a9 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -8,7 +8,6 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/common/util" "github.com/usememos/memos/plugin/idp" "github.com/usememos/memos/plugin/idp/oauth2" @@ -77,7 +76,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again") } - if err := auth.GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { + if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) } if err := s.createAuthSignInActivity(c, user); err != nil { @@ -165,7 +164,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier)) } - if err := auth.GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { + if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) } if err := s.createAuthSignInActivity(c, user); err != nil { @@ -231,7 +230,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) } - if err := auth.GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { + if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) } if err := s.createAuthSignUpActivity(c, user); err != nil { @@ -244,7 +243,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { // POST /auth/signout - Sign out. g.POST("/auth/signout", func(c echo.Context) error { - auth.RemoveTokensAndCookies(c) + RemoveTokensAndCookies(c) return c.JSON(http.StatusOK, true) }) } diff --git a/api/v1/auth/auth.go b/api/v1/auth/auth.go deleted file mode 100644 index 47c4178a..00000000 --- a/api/v1/auth/auth.go +++ /dev/null @@ -1,132 +0,0 @@ -package auth - -import ( - "net/http" - "strconv" - "time" - - "github.com/golang-jwt/jwt/v4" - "github.com/labstack/echo/v4" - "github.com/pkg/errors" - "github.com/usememos/memos/store" -) - -const ( - // The key name used to store user id in the context - // user id is extracted from the jwt token subject field. - UserIDContextKey = "user-id" - // issuer is the issuer of the jwt token. - issuer = "memos" - // Signing key section. For now, this is only used for signing, not for verifying since we only - // have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism. - keyID = "v1" - // AccessTokenAudienceName is the audience name of the access token. - AccessTokenAudienceName = "user.access-token" - // RefreshTokenAudienceName is the audience name of the refresh token. - RefreshTokenAudienceName = "user.refresh-token" - apiTokenDuration = 2 * time.Hour - accessTokenDuration = 24 * time.Hour - refreshTokenDuration = 7 * 24 * time.Hour - - // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user - // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. - // Suppose we have a valid refresh token, we will refresh the token in cases: - // 1. The access token has already expired, we refresh the token so that the ongoing request can pass through. - CookieExpDuration = refreshTokenDuration - 1*time.Minute - // AccessTokenCookieName is the cookie name of access token. - AccessTokenCookieName = "memos.access-token" - // RefreshTokenCookieName is the cookie name of refresh token. - RefreshTokenCookieName = "memos.refresh-token" -) - -type claimsMessage struct { - Name string `json:"name"` - jwt.RegisteredClaims -} - -// GenerateAPIToken generates an API token. -func GenerateAPIToken(userName string, userID int, secret string) (string, error) { - expirationTime := time.Now().Add(apiTokenDuration) - return generateToken(userName, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) -} - -// GenerateAccessToken generates an access token for web. -func GenerateAccessToken(userName string, userID int, secret string) (string, error) { - expirationTime := time.Now().Add(accessTokenDuration) - return generateToken(userName, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) -} - -// GenerateRefreshToken generates a refresh token for web. -func GenerateRefreshToken(userName string, userID int, secret string) (string, error) { - expirationTime := time.Now().Add(refreshTokenDuration) - return generateToken(userName, userID, RefreshTokenAudienceName, expirationTime, []byte(secret)) -} - -// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. -func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error { - accessToken, err := GenerateAccessToken(user.Username, user.ID, secret) - if err != nil { - return errors.Wrap(err, "failed to generate access token") - } - - cookieExp := time.Now().Add(CookieExpDuration) - setTokenCookie(c, AccessTokenCookieName, accessToken, cookieExp) - - // We generate here a new refresh token and saving it to the cookie. - refreshToken, err := GenerateRefreshToken(user.Username, user.ID, secret) - if err != nil { - return errors.Wrap(err, "failed to generate refresh token") - } - setTokenCookie(c, RefreshTokenCookieName, refreshToken, cookieExp) - - return nil -} - -// RemoveTokensAndCookies removes the jwt token and refresh token from the cookies. -func RemoveTokensAndCookies(c echo.Context) { - // We set the expiration time to the past, so that the cookie will be removed. - cookieExp := time.Now().Add(-1 * time.Hour) - setTokenCookie(c, AccessTokenCookieName, "", cookieExp) - setTokenCookie(c, RefreshTokenCookieName, "", cookieExp) -} - -// setTokenCookie sets the token to the cookie. -func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { - cookie := new(http.Cookie) - cookie.Name = name - cookie.Value = token - cookie.Expires = expiration - cookie.Path = "/" - // Http-only helps mitigate the risk of client side script accessing the protected cookie. - cookie.HttpOnly = true - cookie.SameSite = http.SameSiteStrictMode - c.SetCookie(cookie) -} - -// generateToken generates a jwt token. -func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { - // Create the JWT claims, which includes the username and expiry time. - claims := &claimsMessage{ - Name: username, - RegisteredClaims: jwt.RegisteredClaims{ - Audience: jwt.ClaimStrings{aud}, - // In JWT, the expiry time is expressed as unix milliseconds. - ExpiresAt: jwt.NewNumericDate(expirationTime), - IssuedAt: jwt.NewNumericDate(time.Now()), - Issuer: issuer, - Subject: strconv.Itoa(userID), - }, - } - - // Declare the token with the HS256 algorithm used for signing, and the claims. - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - token.Header["kid"] = keyID - - // Create the JWT string. - tokenString, err := token.SignedString(secret) - if err != nil { - return "", err - } - - return tokenString, nil -} diff --git a/api/v1/idp.go b/api/v1/idp.go index 5889c499..2580c2e6 100644 --- a/api/v1/idp.go +++ b/api/v1/idp.go @@ -7,7 +7,7 @@ import ( "strconv" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" ) diff --git a/api/v1/jwt.go b/api/v1/jwt.go index 0e3f646b..a3e1696c 100644 --- a/api/v1/jwt.go +++ b/api/v1/jwt.go @@ -5,22 +5,86 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/golang-jwt/jwt/v4" "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" ) -// Claims creates a struct that will be encoded to a JWT. -// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name. -type Claims struct { +type claimsMessage struct { Name string `json:"name"` jwt.RegisteredClaims } +// GenerateAccessToken generates an access token for web. +func GenerateAccessToken(username string, userID int, secret string) (string, error) { + expirationTime := time.Now().Add(auth.AccessTokenDuration) + return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret)) +} + +// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. +func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error { + accessToken, err := GenerateAccessToken(user.Username, user.ID, secret) + if err != nil { + return errors.Wrap(err, "failed to generate access token") + } + + cookieExp := time.Now().Add(auth.CookieExpDuration) + setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp) + return nil +} + +// RemoveTokensAndCookies removes the jwt token and refresh token from the cookies. +func RemoveTokensAndCookies(c echo.Context) { + cookieExp := time.Now().Add(-1 * time.Hour) + setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp) +} + +// setTokenCookie sets the token to the cookie. +func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { + cookie := new(http.Cookie) + cookie.Name = name + cookie.Value = token + cookie.Expires = expiration + cookie.Path = "/" + // Http-only helps mitigate the risk of client side script accessing the protected cookie. + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteStrictMode + c.SetCookie(cookie) +} + +// generateToken generates a jwt token. +func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { + // Create the JWT claims, which includes the username and expiry time. + claims := &claimsMessage{ + Name: username, + RegisteredClaims: jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{aud}, + // In JWT, the expiry time is expressed as unix milliseconds. + ExpiresAt: jwt.NewNumericDate(expirationTime), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: auth.Issuer, + Subject: strconv.Itoa(userID), + }, + } + + // Declare the token with the HS256 algorithm used for signing, and the claims. + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["kid"] = auth.KeyID + + // Create the JWT string. + tokenString, err := token.SignedString(secret) + if err != nil { + return "", err + } + + return tokenString, nil +} + func extractTokenFromHeader(c echo.Context) (string, error) { authHeader := c.Request().Header.Get("Authorization") if authHeader == "" { @@ -62,7 +126,8 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool { // will try to generate new access token and refresh token. func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc { return func(c echo.Context) error { - path := c.Path() + ctx := c.Request().Context() + path := c.Request().URL.Path method := c.Request().Method if server.defaultAuthSkipper(c) { @@ -87,8 +152,8 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") } - claims := &Claims{} - accessToken, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { + claims := &claimsMessage{} + _, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) } @@ -100,27 +165,15 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) }) - generateToken := false if err != nil { - var ve *jwt.ValidationError - if errors.As(err, &ve) { - // If expiration error is the only error, we will clear the err - // and generate new access token and refresh token - if ve.Errors == jwt.ValidationErrorExpired { - generateToken = true - } - } else { - auth.RemoveTokensAndCookies(c) - return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) - } + RemoveTokensAndCookies(c) + return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) } - if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName)) } // We either have a valid access token or we will attempt to generate new access token and refresh token - ctx := c.Request().Context() userID, err := strconv.Atoi(claims.Subject) if err != nil { return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.") @@ -137,59 +190,6 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) } - if generateToken { - generateTokenFunc := func() error { - rc, err := c.Cookie(auth.RefreshTokenCookieName) - if err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.") - } - - // Parses token and checks if it's valid. - refreshTokenClaims := &Claims{} - refreshToken, err := jwt.ParseWithClaims(rc.Value, refreshTokenClaims, func(t *jwt.Token) (any, error) { - if t.Method.Alg() != jwt.SigningMethodHS256.Name { - return nil, errors.Errorf("unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) - } - - if kid, ok := t.Header["kid"].(string); ok { - if kid == "v1" { - return []byte(secret), nil - } - } - return nil, errors.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) - }) - if err != nil { - if err == jwt.ErrSignatureInvalid { - return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Invalid refresh token signature.") - } - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) - } - - if !audienceContains(refreshTokenClaims.Audience, auth.RefreshTokenAudienceName) { - return echo.NewHTTPError(http.StatusUnauthorized, - fmt.Sprintf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", - refreshTokenClaims.Audience, - auth.RefreshTokenAudienceName, - )) - } - - // If we have a valid refresh token, we will generate new access token and refresh token - if refreshToken != nil && refreshToken.Valid { - if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) - } - } - - return nil - } - - // It may happen that we still have a valid access token, but we encounter issue when trying to generate new token - // In such case, we won't return the error. - if err := generateTokenFunc(); err != nil && !accessToken.Valid { - return err - } - } - // Stores userID into context. c.Set(auth.UserIDContextKey, userID) return next(c) diff --git a/api/v1/memo.go b/api/v1/memo.go index b70f561e..1ce53cb2 100644 --- a/api/v1/memo.go +++ b/api/v1/memo.go @@ -10,7 +10,7 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" ) diff --git a/api/v1/memo_organizer.go b/api/v1/memo_organizer.go index f88d4712..1e02912c 100644 --- a/api/v1/memo_organizer.go +++ b/api/v1/memo_organizer.go @@ -7,7 +7,7 @@ import ( "strconv" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" ) diff --git a/api/v1/memo_resource.go b/api/v1/memo_resource.go index 2a197c23..197cacff 100644 --- a/api/v1/memo_resource.go +++ b/api/v1/memo_resource.go @@ -8,7 +8,7 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" ) diff --git a/api/v1/resource.go b/api/v1/resource.go index f004caf5..b65fb286 100644 --- a/api/v1/resource.go +++ b/api/v1/resource.go @@ -21,7 +21,7 @@ import ( "github.com/disintegration/imaging" "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/log" "github.com/usememos/memos/common/util" "github.com/usememos/memos/plugin/storage/s3" diff --git a/api/v1/shortcut.go b/api/v1/shortcut.go index f804db58..96734cfb 100644 --- a/api/v1/shortcut.go +++ b/api/v1/shortcut.go @@ -9,7 +9,7 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" ) diff --git a/api/v1/storage.go b/api/v1/storage.go index 31500721..f5b181fc 100644 --- a/api/v1/storage.go +++ b/api/v1/storage.go @@ -7,7 +7,7 @@ import ( "strconv" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" ) diff --git a/api/v1/system.go b/api/v1/system.go index 861384b9..1f48f1fb 100644 --- a/api/v1/system.go +++ b/api/v1/system.go @@ -5,7 +5,7 @@ import ( "net/http" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/log" "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" diff --git a/api/v1/system_setting.go b/api/v1/system_setting.go index 188c7241..cb5f28fd 100644 --- a/api/v1/system_setting.go +++ b/api/v1/system_setting.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" ) diff --git a/api/v1/tag.go b/api/v1/tag.go index 8c3eb6ef..1c4488a4 100644 --- a/api/v1/tag.go +++ b/api/v1/tag.go @@ -9,7 +9,7 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" "golang.org/x/exp/slices" ) diff --git a/api/v1/user.go b/api/v1/user.go index dbd89a17..7fa672e3 100644 --- a/api/v1/user.go +++ b/api/v1/user.go @@ -9,7 +9,7 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" "golang.org/x/crypto/bcrypt" diff --git a/api/v1/user_setting.go b/api/v1/user_setting.go index dedf04b7..cab0f53b 100644 --- a/api/v1/user_setting.go +++ b/api/v1/user_setting.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" "golang.org/x/exp/slices" ) diff --git a/api/v2/auth/auth.go b/api/v2/auth/auth.go deleted file mode 100644 index b7bc1ef4..00000000 --- a/api/v2/auth/auth.go +++ /dev/null @@ -1,297 +0,0 @@ -// Package auth handles the auth of gRPC server. -package auth - -import ( - "context" - "errors" - "net/http" - "strconv" - "strings" - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" - - "github.com/golang-jwt/jwt/v4" - errs "github.com/pkg/errors" - - "github.com/usememos/memos/store" -) - -// ContextKey is the key type of context value. -type ContextKey int - -const ( - // The key name used to store user id in the context - // user id is extracted from the jwt token subject field. - UserIDContextKey ContextKey = iota - // issuer is the issuer of the jwt token. - issuer = "memos" - // Signing key section. For now, this is only used for signing, not for verifying since we only - // have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism. - keyID = "v1" - // AccessTokenAudienceName is the audience name of the access token. - AccessTokenAudienceName = "user.access-token" - // RefreshTokenAudienceName is the audience name of the refresh token. - RefreshTokenAudienceName = "user.refresh-token" - apiTokenDuration = 2 * time.Hour - accessTokenDuration = 24 * time.Hour - refreshTokenDuration = 7 * 24 * time.Hour - - // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user - // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. - // Suppose we have a valid refresh token, we will refresh the token in cases: - // 1. The access token has already expired, we refresh the token so that the ongoing request can pass through. - CookieExpDuration = refreshTokenDuration - 1*time.Minute - // AccessTokenCookieName is the cookie name of access token. - AccessTokenCookieName = "memos.access-token" - // RefreshTokenCookieName is the cookie name of refresh token. - RefreshTokenCookieName = "memos.refresh-token" -) - -// GRPCAuthInterceptor is the auth interceptor for gRPC server. -type GRPCAuthInterceptor struct { - store *store.Store - secret string -} - -// NewGRPCAuthInterceptor returns a new API auth interceptor. -func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor { - return &GRPCAuthInterceptor{ - store: store, - secret: secret, - } -} - -// AuthenticationInterceptor is the unary interceptor for gRPC API. -func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context") - } - accessTokenStr, refreshTokenStr, err := getTokenFromMetadata(md) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, err.Error()) - } - - userID, err := in.authenticate(ctx, accessTokenStr, refreshTokenStr) - if err != nil { - if IsAuthenticationAllowed(serverInfo.FullMethod) { - return handler(ctx, request) - } - return nil, err - } - - // Stores userID into context. - childCtx := context.WithValue(ctx, UserIDContextKey, userID) - return handler(childCtx, request) -} - -func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr, refreshTokenStr string) (int, error) { - if accessTokenStr == "" { - return 0, status.Errorf(codes.Unauthenticated, "access token not found") - } - claims := &claimsMessage{} - generateToken := false - accessToken, err := jwt.ParseWithClaims(accessTokenStr, claims, func(t *jwt.Token) (any, error) { - if t.Method.Alg() != jwt.SigningMethodHS256.Name { - return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) - } - if kid, ok := t.Header["kid"].(string); ok { - if kid == "v1" { - return []byte(in.secret), nil - } - } - return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"]) - }) - if err != nil { - var ve *jwt.ValidationError - if errors.As(err, &ve) && ve.Errors == jwt.ValidationErrorExpired { - // If expiration error is the only error, we will clear the err - // and generate new access token and refresh token - if refreshTokenStr == "" { - return 0, status.Errorf(codes.Unauthenticated, "access token is expired") - } - generateToken = true - } else { - return 0, status.Errorf(codes.Unauthenticated, "failed to parse claim") - } - } - if !audienceContains(claims.Audience, AccessTokenAudienceName) { - return 0, status.Errorf(codes.Unauthenticated, - "invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", - claims.Audience, - AccessTokenAudienceName, - ) - } - - userID, err := strconv.Atoi(claims.Subject) - if err != nil { - return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject) - } - user, err := in.store.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) - if err != nil { - return 0, status.Errorf(codes.Unauthenticated, "failed to find user ID %q in the access token", userID) - } - if user == nil { - return 0, status.Errorf(codes.Unauthenticated, "user ID %q not exists in the access token", userID) - } - if user.RowStatus == store.Archived { - return 0, status.Errorf(codes.Unauthenticated, "user ID %q has been deactivated by administrators", userID) - } - - if generateToken { - generateTokenFunc := func() error { - // Parses token and checks if it's valid. - refreshTokenClaims := &claimsMessage{} - refreshToken, err := jwt.ParseWithClaims(refreshTokenStr, refreshTokenClaims, func(t *jwt.Token) (any, error) { - if t.Method.Alg() != jwt.SigningMethodHS256.Name { - return nil, status.Errorf(codes.Unauthenticated, "unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) - } - - if kid, ok := t.Header["kid"].(string); ok { - if kid == "v1" { - return []byte(in.secret), nil - } - } - return nil, errs.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) - }) - if err != nil { - if err == jwt.ErrSignatureInvalid { - return errs.Errorf("failed to generate access token: invalid refresh token signature") - } - return errs.Errorf("Server error to refresh expired token, user ID %d", userID) - } - - if !audienceContains(refreshTokenClaims.Audience, RefreshTokenAudienceName) { - return errs.Errorf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", - refreshTokenClaims.Audience, - RefreshTokenAudienceName, - ) - } - - // If we have a valid refresh token, we will generate new access token and refresh token - if refreshToken != nil && refreshToken.Valid { - if err := generateTokensAndSetCookies(ctx, user.Username, user.ID, in.secret); err != nil { - return errs.Wrapf(err, "failed to regenerate token") - } - } - - return nil - } - - // It may happen that we still have a valid access token, but we encounter issue when trying to generate new token - // In such case, we won't return the error. - if err := generateTokenFunc(); err != nil && !accessToken.Valid { - return 0, status.Errorf(codes.Unauthenticated, err.Error()) - } - } - return userID, nil -} - -func getTokenFromMetadata(md metadata.MD) (string, string, error) { - authorizationHeaders := md.Get("Authorization") - if len(md.Get("Authorization")) > 0 { - authHeaderParts := strings.Fields(authorizationHeaders[0]) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", "", errs.Errorf("authorization header format must be Bearer {token}") - } - return authHeaderParts[1], "", nil - } - // check the HTTP cookie - var accessToken, refreshToken string - for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) { - header := http.Header{} - header.Add("Cookie", t) - request := http.Request{Header: header} - if v, _ := request.Cookie(AccessTokenCookieName); v != nil { - accessToken = v.Value - } - if v, _ := request.Cookie(RefreshTokenCookieName); v != nil { - refreshToken = v.Value - } - } - if accessToken != "" && refreshToken != "" { - return accessToken, refreshToken, nil - } - return "", "", nil -} - -func audienceContains(audience jwt.ClaimStrings, token string) bool { - for _, v := range audience { - if v == token { - return true - } - } - return false -} - -type claimsMessage struct { - Name string `json:"name"` - jwt.RegisteredClaims -} - -// generateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. -func generateTokensAndSetCookies(ctx context.Context, username string, userID int, secret string) error { - accessToken, err := GenerateAccessToken(username, userID, secret) - if err != nil { - return errs.Wrap(err, "failed to generate access token") - } - // We generate here a new refresh token and saving it to the cookie. - refreshToken, err := GenerateRefreshToken(username, userID, secret) - if err != nil { - return errs.Wrap(err, "failed to generate refresh token") - } - - if err := grpc.SetHeader(ctx, metadata.New(map[string]string{ - AccessTokenCookieName: accessToken, - RefreshTokenCookieName: refreshToken, - })); err != nil { - return errs.Wrapf(err, "failed to set grpc header") - } - return nil -} - -// GenerateAccessToken generates an access token for web. -func GenerateAccessToken(username string, userID int, secret string) (string, error) { - expirationTime := time.Now().Add(accessTokenDuration) - return generateToken(username, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) -} - -// GenerateRefreshToken generates a refresh token for web. -func GenerateRefreshToken(username string, userID int, secret string) (string, error) { - expirationTime := time.Now().Add(refreshTokenDuration) - return generateToken(username, userID, RefreshTokenAudienceName, expirationTime, []byte(secret)) -} - -// Pay attention to this function. It holds the main JWT token generation logic. -func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { - // Create the JWT claims, which includes the username and expiry time. - claims := &claimsMessage{ - Name: username, - RegisteredClaims: jwt.RegisteredClaims{ - Audience: jwt.ClaimStrings{aud}, - // In JWT, the expiry time is expressed as unix milliseconds. - ExpiresAt: jwt.NewNumericDate(expirationTime), - IssuedAt: jwt.NewNumericDate(time.Now()), - Issuer: issuer, - Subject: strconv.Itoa(userID), - }, - } - - // Declare the token with the HS256 algorithm used for signing, and the claims. - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - token.Header["kid"] = keyID - - // Create the JWT string. - tokenString, err := token.SignedString(secret) - if err != nil { - return "", err - } - - return tokenString, nil -} diff --git a/api/v2/auth/config.go b/api/v2/auth/config.go deleted file mode 100644 index 6f52b1b4..00000000 --- a/api/v2/auth/config.go +++ /dev/null @@ -1,15 +0,0 @@ -package auth - -import "strings" - -var authenticationAllowlistMethods = map[string]bool{ - "/memos.api.v2.UserService/GetUser": true, -} - -// IsAuthenticationAllowed returns whether the method is exempted from authentication. -func IsAuthenticationAllowed(fullMethodName string) bool { - if strings.HasPrefix(fullMethodName, "/grpc.reflection") { - return true - } - return authenticationAllowlistMethods[fullMethodName] -} diff --git a/api/v2/jwt.go b/api/v2/jwt.go new file mode 100644 index 00000000..7fdf3cad --- /dev/null +++ b/api/v2/jwt.go @@ -0,0 +1,193 @@ +package v2 + +import ( + "context" + "net/http" + "strconv" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/pkg/errors" + "github.com/usememos/memos/api/auth" + "github.com/usememos/memos/store" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +var authenticationAllowlistMethods = map[string]bool{ + "/memos.api.v2.UserService/GetUser": true, +} + +// IsAuthenticationAllowed returns whether the method is exempted from authentication. +func IsAuthenticationAllowed(fullMethodName string) bool { + if strings.HasPrefix(fullMethodName, "/grpc.reflection") { + return true + } + return authenticationAllowlistMethods[fullMethodName] +} + +// ContextKey is the key type of context value. +type ContextKey int + +const ( + // The key name used to store user id in the context + // user id is extracted from the jwt token subject field. + UserIDContextKey ContextKey = iota +) + +// GRPCAuthInterceptor is the auth interceptor for gRPC server. +type GRPCAuthInterceptor struct { + store *store.Store + secret string +} + +// NewGRPCAuthInterceptor returns a new API auth interceptor. +func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor { + return &GRPCAuthInterceptor{ + store: store, + secret: secret, + } +} + +// AuthenticationInterceptor is the unary interceptor for gRPC API. +func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context") + } + accessTokenStr, err := getTokenFromMetadata(md) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, err.Error()) + } + + userID, err := in.authenticate(ctx, accessTokenStr) + if err != nil { + if IsAuthenticationAllowed(serverInfo.FullMethod) { + return handler(ctx, request) + } + return nil, err + } + + // Stores userID into context. + childCtx := context.WithValue(ctx, UserIDContextKey, userID) + return handler(childCtx, request) +} + +func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr string) (int, error) { + if accessTokenStr == "" { + return 0, status.Errorf(codes.Unauthenticated, "access token not found") + } + claims := &claimsMessage{} + _, err := jwt.ParseWithClaims(accessTokenStr, claims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) + } + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(in.secret), nil + } + } + return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"]) + }) + if err != nil { + return 0, status.Errorf(codes.Unauthenticated, "Invalid or expired access token") + } + if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { + return 0, status.Errorf(codes.Unauthenticated, + "invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", + claims.Audience, + auth.AccessTokenAudienceName, + ) + } + + userID, err := strconv.Atoi(claims.Subject) + if err != nil { + return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject) + } + user, err := in.store.GetUser(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return 0, status.Errorf(codes.Unauthenticated, "failed to find user ID %q in the access token", userID) + } + if user == nil { + return 0, status.Errorf(codes.Unauthenticated, "user ID %q not exists in the access token", userID) + } + if user.RowStatus == store.Archived { + return 0, status.Errorf(codes.Unauthenticated, "user ID %q has been deactivated by administrators", userID) + } + + return userID, nil +} + +func getTokenFromMetadata(md metadata.MD) (string, error) { + authorizationHeaders := md.Get("Authorization") + if len(md.Get("Authorization")) > 0 { + authHeaderParts := strings.Fields(authorizationHeaders[0]) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.Errorf("authorization header format must be Bearer {token}") + } + return authHeaderParts[1], nil + } + // check the HTTP cookie + var accessToken string + for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) { + header := http.Header{} + header.Add("Cookie", t) + request := http.Request{Header: header} + if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil { + accessToken = v.Value + } + } + return accessToken, nil +} + +func audienceContains(audience jwt.ClaimStrings, token string) bool { + for _, v := range audience { + if v == token { + return true + } + } + return false +} + +type claimsMessage struct { + Name string `json:"name"` + jwt.RegisteredClaims +} + +// GenerateAccessToken generates an access token for web. +func GenerateAccessToken(username string, userID int, secret string) (string, error) { + expirationTime := time.Now().Add(auth.AccessTokenDuration) + return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret)) +} + +func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { + // Create the JWT claims, which includes the username and expiry time. + claims := &claimsMessage{ + Name: username, + RegisteredClaims: jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{aud}, + // In JWT, the expiry time is expressed as unix milliseconds. + ExpiresAt: jwt.NewNumericDate(expirationTime), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: auth.Issuer, + Subject: strconv.Itoa(userID), + }, + } + + // Declare the token with the HS256 algorithm used for signing, and the claims. + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["kid"] = auth.KeyID + + // Create the JWT string. + tokenString, err := token.SignedString(secret) + if err != nil { + return "", err + } + + return tokenString, nil +} diff --git a/api/v2/user_service.go b/api/v2/user_service.go index e3f2c3e8..c9b51403 100644 --- a/api/v2/user_service.go +++ b/api/v2/user_service.go @@ -3,7 +3,6 @@ package v2 import ( "context" - "github.com/usememos/memos/api/v2/auth" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" "github.com/usememos/memos/store" "google.golang.org/grpc/codes" @@ -46,7 +45,7 @@ func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserReque return nil, status.Errorf(codes.Internal, "failed to list user settings: %v", err) } - userID, ok := ctx.Value(auth.UserIDContextKey).(int) + userID, ok := ctx.Value(UserIDContextKey).(int) if ok && userID == int(userMessage.Id) { for _, userSetting := range userSettings { userMessage.Settings = append(userMessage.Settings, convertUserSettingFromStore(userSetting)) diff --git a/api/v2/v2.go b/api/v2/v2.go index c5130694..b5f9aeb8 100644 --- a/api/v2/v2.go +++ b/api/v2/v2.go @@ -6,7 +6,6 @@ import ( grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/v2/auth" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" @@ -24,7 +23,7 @@ type APIV2Service struct { } func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store, grpcServerPort int) *APIV2Service { - authProvider := auth.NewGRPCAuthInterceptor(store, secret) + authProvider := NewGRPCAuthInterceptor(store, secret) grpcServer := grpc.NewServer( grpc.ChainUnaryInterceptor( authProvider.AuthenticationInterceptor, diff --git a/test/server/server.go b/test/server/server.go index ae48f7d9..79694e34 100644 --- a/test/server/server.go +++ b/test/server/server.go @@ -11,7 +11,7 @@ import ( "time" "github.com/pkg/errors" - "github.com/usememos/memos/api/v1/auth" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/server" "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store"