From 861eeb7b0fcda967e99293c5fc26e55c0009f524 Mon Sep 17 00:00:00 2001 From: boojack Date: Sun, 1 Jan 2023 21:32:17 +0800 Subject: [PATCH] chore: add skipper in CSRF (#885) --- server/server.go | 42 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/server/server.go b/server/server.go index b66e5198..3be61e46 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,8 @@ import ( "fmt" "time" + "github.com/usememos/memos/api" + "github.com/usememos/memos/common" "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" @@ -30,6 +32,11 @@ func NewServer(profile *profile.Profile) *Server { e.HideBanner = true e.HidePort = true + s := &Server{ + e: e, + Profile: profile, + } + e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ Format: `{"time":"${time_rfc3339}",` + `"method":"${method}","uri":"${uri}",` + @@ -37,6 +44,7 @@ func NewServer(profile *profile.Profile) *Server { })) e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ + Skipper: s.OpenAPISkipper, TokenLookup: "cookie:_csrf", })) @@ -59,11 +67,6 @@ func NewServer(profile *profile.Profile) *Server { } e.Use(session.Middleware(sessions.NewCookieStore(secret))) - s := &Server{ - e: e, - Profile: profile, - } - rootGroup := e.Group("") s.registerRSSRoutes(rootGroup) @@ -92,3 +95,32 @@ func NewServer(profile *profile.Profile) *Server { func (server *Server) Run() error { return server.e.Start(fmt.Sprintf(":%d", server.Profile.Port)) } + +func (server *Server) OpenAPISkipper(c echo.Context) bool { + ctx := c.Request().Context() + path := c.Path() + + // Skip auth. + if common.HasPrefixes(path, "/api/auth") { + return true + } + + // If there is openId in query string and related user is found, then skip auth. + openID := c.QueryParam("openId") + if openID != "" { + userFind := &api.UserFind{ + OpenID: &openID, + } + user, err := server.Store.FindUser(ctx, userFind) + if err != nil && common.ErrorCode(err) != common.NotFound { + return false + } + if user != nil { + // Stores userID into context. + c.Set(getUserIDContextKey(), user.ID) + return true + } + } + + return false +}