From d50ad9433fae7302816641990b59fab49e4cb1f9 Mon Sep 17 00:00:00 2001 From: boojack Date: Tue, 3 Jan 2023 23:05:42 +0800 Subject: [PATCH] feat: persistent session name (#902) * feat: persistent session name * chore: update --- api/system_setting.go | 4 ++ bin/server/main.go | 19 ++-------- go.mod | 2 +- server/server.go | 70 +++++++++++++++++------------------ server/system.go | 46 +++++++++++++++++++++++ store/db/db.go | 10 ++--- store/db/migration_history.go | 4 +- 7 files changed, 96 insertions(+), 59 deletions(-) diff --git a/api/system_setting.go b/api/system_setting.go index 7e4f69f6..e4136841 100644 --- a/api/system_setting.go +++ b/api/system_setting.go @@ -13,6 +13,8 @@ type SystemSettingName string const ( // SystemSettingServerID is the key type of server id. SystemSettingServerID SystemSettingName = "serverId" + // SystemSettingSecretSessionName is the key type of secret session name. + SystemSettingSecretSessionName SystemSettingName = "secretSessionName" // SystemSettingAllowSignUpName is the key type of allow signup setting. SystemSettingAllowSignUpName SystemSettingName = "allowSignUp" // SystemSettingAdditionalStyleName is the key type of additional style. @@ -43,6 +45,8 @@ func (key SystemSettingName) String() string { switch key { case SystemSettingServerID: return "serverId" + case SystemSettingSecretSessionName: + return "secretSessionName" case SystemSettingAllowSignUpName: return "allowSignUp" case SystemSettingAdditionalStyleName: diff --git a/bin/server/main.go b/bin/server/main.go index ccae898c..0bfb1755 100644 --- a/bin/server/main.go +++ b/bin/server/main.go @@ -4,15 +4,13 @@ import ( "os" _ "github.com/mattn/go-sqlite3" + "github.com/pkg/errors" "context" "fmt" "github.com/usememos/memos/server" "github.com/usememos/memos/server/profile" - "github.com/usememos/memos/store" - - DB "github.com/usememos/memos/store/db" ) const ( @@ -40,20 +38,11 @@ func run() error { println("version:", profile.Version) println("---") - db := DB.NewDB(profile) - if err := db.Open(ctx); err != nil { - return fmt.Errorf("cannot open db: %w", err) + serverInstance, err := server.NewServer(ctx, profile) + if err != nil { + return errors.Wrap(err, "failed to start server") } - serverInstance := server.NewServer(profile) - storeInstance := store.New(db.Db, profile) - serverInstance.Store = storeInstance - - metricCollector := server.NewMetricCollector(profile, storeInstance) - // Disable metrics collector. - metricCollector.Enabled = false - serverInstance.Collector = &metricCollector - println(greetingBanner) fmt.Printf("Version %s has started at :%d\n", profile.Version, profile.Port) return serverInstance.Run(ctx) diff --git a/go.mod b/go.mod index 121944b1..674c26c3 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require github.com/labstack/echo/v4 v4.9.0 require ( github.com/VictoriaMetrics/fastcache v1.10.0 github.com/gorilla/feeds v1.1.1 - github.com/gorilla/securecookie v1.1.1 + github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 github.com/labstack/echo-contrib v0.13.0 github.com/stretchr/testify v1.8.1 diff --git a/server/server.go b/server/server.go index 2b83a93c..0028f521 100644 --- a/server/server.go +++ b/server/server.go @@ -6,14 +6,12 @@ import ( "fmt" "time" - "github.com/google/uuid" "github.com/pkg/errors" "github.com/usememos/memos/api" - "github.com/usememos/memos/common" "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" + "github.com/usememos/memos/store/db" - "github.com/gorilla/securecookie" "github.com/gorilla/sessions" "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" @@ -23,16 +21,13 @@ import ( type Server struct { e *echo.Echo - ID string - + ID string + Profile *profile.Profile + Store *store.Store Collector *MetricCollector - - Profile *profile.Profile - - Store *store.Store } -func NewServer(profile *profile.Profile) *Server { +func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) { e := echo.New() e.Debug = true e.HideBanner = true @@ -43,6 +38,19 @@ func NewServer(profile *profile.Profile) *Server { Profile: profile, } + db := db.NewDB(profile) + if err := db.Open(ctx); err != nil { + return nil, errors.Wrap(err, "cannot open db") + } + + storeInstance := store.New(db.DBInstance, profile) + s.Store = storeInstance + + metricCollector := NewMetricCollector(profile, storeInstance) + // Disable metrics collector. + metricCollector.Enabled = false + s.Collector = &metricCollector + e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ Format: `{"time":"${time_rfc3339}",` + `"method":"${method}","uri":"${uri}",` + @@ -68,14 +76,22 @@ func NewServer(profile *profile.Profile) *Server { Timeout: 30 * time.Second, })) - embedFrontend(e) - - // In dev mode, set the const secret key to make signin session persistence. - secret := []byte("usememos") - if profile.Mode == "prod" { - secret = securecookie.GenerateRandomKey(16) + serverID, err := s.getSystemServerID(ctx) + if err != nil { + return nil, err } - e.Use(session.Middleware(sessions.NewCookieStore(secret))) + s.ID = serverID + + secretSessionName := "usememos" + if profile.Mode == "prod" { + secretSessionName, err = s.getSystemSecretSessionName(ctx) + if err != nil { + return nil, err + } + } + e.Use(session.Middleware(sessions.NewCookieStore([]byte(secretSessionName)))) + + embedFrontend(e) rootGroup := e.Group("") s.registerRSSRoutes(rootGroup) @@ -99,28 +115,10 @@ func NewServer(profile *profile.Profile) *Server { s.registerResourceRoutes(apiGroup) s.registerTagRoutes(apiGroup) - return s + return s, nil } func (s *Server) Run(ctx context.Context) error { - serverIDKey := api.SystemSettingServerID - serverIDValue, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ - Name: &serverIDKey, - }) - if err != nil && common.ErrorCode(err) != common.NotFound { - return err - } - if serverIDValue == nil || serverIDValue.Value == "" { - serverIDValue, err = s.Store.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ - Name: serverIDKey, - Value: uuid.NewString(), - }) - if err != nil { - return err - } - } - s.ID = serverIDValue.Value - if err := s.createServerStartActivity(ctx); err != nil { return errors.Wrap(err, "failed to create activity") } diff --git a/server/system.go b/server/system.go index cc75c3a0..7cbb8ec4 100644 --- a/server/system.go +++ b/server/system.go @@ -1,10 +1,12 @@ package server import ( + "context" "encoding/json" "net/http" "os" + "github.com/google/uuid" "github.com/usememos/memos/api" "github.com/usememos/memos/common" @@ -61,6 +63,10 @@ func (s *Server) registerSystemRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err) } for _, systemSetting := range systemSettingList { + if systemSetting.Name == api.SystemSettingServerID || systemSetting.Name == api.SystemSettingSecretSessionName { + continue + } + var value interface{} err := json.Unmarshal([]byte(systemSetting.Value), &value) if err != nil { @@ -195,3 +201,43 @@ func (s *Server) registerSystemRoutes(g *echo.Group) { return nil }) } + +func (s *Server) getSystemServerID(ctx context.Context) (string, error) { + serverIDKey := api.SystemSettingServerID + serverIDValue, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ + Name: &serverIDKey, + }) + if err != nil && common.ErrorCode(err) != common.NotFound { + return "", err + } + if serverIDValue == nil || serverIDValue.Value == "" { + serverIDValue, err = s.Store.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ + Name: serverIDKey, + Value: uuid.NewString(), + }) + if err != nil { + return "", err + } + } + return serverIDValue.Value, nil +} + +func (s *Server) getSystemSecretSessionName(ctx context.Context) (string, error) { + secretSessionNameKey := api.SystemSettingSecretSessionName + secretSessionNameValue, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ + Name: &secretSessionNameKey, + }) + if err != nil && common.ErrorCode(err) != common.NotFound { + return "", err + } + if secretSessionNameValue == nil || secretSessionNameValue.Value == "" { + secretSessionNameValue, err = s.Store.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ + Name: secretSessionNameKey, + Value: uuid.NewString(), + }) + if err != nil { + return "", err + } + } + return secretSessionNameValue.Value, nil +} diff --git a/store/db/db.go b/store/db/db.go index 5ab544b7..e6c47593 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -24,8 +24,8 @@ var seedFS embed.FS type DB struct { // sqlite db connection instance - Db *sql.DB - profile *profile.Profile + DBInstance *sql.DB + profile *profile.Profile } // NewDB returns a new instance of DB associated with the given datasource name. @@ -47,7 +47,7 @@ func (db *DB) Open(ctx context.Context) (err error) { if err != nil { return fmt.Errorf("failed to open db with dsn: %s, err: %w", db.profile.DSN, err) } - db.Db = sqliteDB + db.DBInstance = sqliteDB if db.profile.Mode == "dev" { // In dev mode, we should migrate and seed the database. @@ -156,7 +156,7 @@ func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion st } } - tx, err := db.Db.Begin() + tx, err := db.DBInstance.Begin() if err != nil { return err } @@ -197,7 +197,7 @@ func (db *DB) seed(ctx context.Context) error { // execute runs a single SQL statement within a transaction. func (db *DB) execute(ctx context.Context, stmt string) error { - tx, err := db.Db.Begin() + tx, err := db.DBInstance.Begin() if err != nil { return err } diff --git a/store/db/migration_history.go b/store/db/migration_history.go index b8207613..33e95c5e 100644 --- a/store/db/migration_history.go +++ b/store/db/migration_history.go @@ -20,7 +20,7 @@ type MigrationHistoryFind struct { } func (db *DB) FindMigrationHistory(ctx context.Context, find *MigrationHistoryFind) (*MigrationHistory, error) { - tx, err := db.Db.BeginTx(ctx, nil) + tx, err := db.DBInstance.BeginTx(ctx, nil) if err != nil { return nil, err } @@ -40,7 +40,7 @@ func (db *DB) FindMigrationHistory(ctx context.Context, find *MigrationHistoryFi } func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { - tx, err := db.Db.BeginTx(ctx, nil) + tx, err := db.DBInstance.BeginTx(ctx, nil) if err != nil { return nil, err }