chore: add origin flag to config cors

This commit is contained in:
Steven 2024-04-07 22:15:15 +08:00
parent b5893aa60b
commit 8101a5e0b1
3 changed files with 32 additions and 12 deletions

View File

@ -31,18 +31,19 @@ const (
)
var (
profile *_profile.Profile
mode string
addr string
port int
data string
driver string
dsn string
serveFrontend bool
profile *_profile.Profile
mode string
addr string
port int
data string
driver string
dsn string
serveFrontend bool
allowedOrigins []string
rootCmd = &cobra.Command{
Use: "memos",
Short: `An open-source, self-hosted memo hub with knowledge management and social networking.`,
Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`,
Run: func(_cmd *cobra.Command, _args []string) {
ctx, cancel := context.WithCancel(context.Background())
dbDriver, err := db.NewDBDriver(profile)
@ -114,6 +115,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver")
rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)")
rootCmd.PersistentFlags().BoolVarP(&serveFrontend, "frontend", "", true, "serve frontend files")
rootCmd.PersistentFlags().StringArrayVarP(&allowedOrigins, "origins", "", []string{}, "CORS allowed domain origins")
err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode"))
if err != nil {
@ -143,12 +145,17 @@ func init() {
if err != nil {
panic(err)
}
err = viper.BindPFlag("origins", rootCmd.PersistentFlags().Lookup("origins"))
if err != nil {
panic(err)
}
viper.SetDefault("mode", "demo")
viper.SetDefault("driver", "sqlite")
viper.SetDefault("addr", "")
viper.SetDefault("port", 8081)
viper.SetDefault("frontend", true)
viper.SetDefault("origins", []string{})
viper.SetEnvPrefix("memos")
}

View File

@ -32,6 +32,8 @@ type Profile struct {
Version string `json:"version"`
// Frontend indicate the frontend is enabled or not
Frontend bool `json:"-"`
// Origins is the list of allowed origins
Origins []string `json:"-"`
}
func (p *Profile) IsDev() bool {

View File

@ -49,7 +49,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
}
// Register CORS middleware.
e.Use(CORSMiddleware())
e.Use(CORSMiddleware(s.Profile.Origins))
serverID, err := s.getSystemServerID(ctx)
if err != nil {
@ -160,7 +160,7 @@ func grpcRequestSkipper(c echo.Context) bool {
return strings.HasPrefix(c.Request().URL.Path, "/memos.api.v2.")
}
func CORSMiddleware() echo.MiddlewareFunc {
func CORSMiddleware(origins []string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if grpcRequestSkipper(c) {
@ -170,7 +170,18 @@ func CORSMiddleware() echo.MiddlewareFunc {
r := c.Request()
w := c.Response().Writer
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
requestOrigin := r.Header.Get("Origin")
if len(origins) == 0 {
w.Header().Set("Access-Control-Allow-Origin", requestOrigin)
} else {
for _, origin := range origins {
if origin == requestOrigin {
w.Header().Set("Access-Control-Allow-Origin", origin)
break
}
}
}
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Allow-Credentials", "true")