package middleware import ( "fmt" "net/http" "strings" "git.trcreatives.at/trcreatives/go-kite" ) type CORSConfig struct { AllowedOrigins []string AllowedMethods []string AllowedHeaders []string ExposedHeaders []string AllowCredentials bool MaxAge int } func CORS(allowedOrigins ...string) kite.Middleware { return CORSWithConfig(CORSConfig{AllowedOrigins: allowedOrigins}) } func CORSWithConfig(cfg CORSConfig) kite.Middleware { if len(cfg.AllowedOrigins) == 0 { cfg.AllowedOrigins = []string{"*"} } if len(cfg.AllowedMethods) == 0 { cfg.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "CONNECT", "TRACE"} } if len(cfg.AllowedHeaders) == 0 { cfg.AllowedHeaders = []string{"Content-Type", "Authorization"} } allowAll := len(cfg.AllowedOrigins) == 1 && cfg.AllowedOrigins[0] == "*" methods := strings.Join(cfg.AllowedMethods, ", ") headers := strings.Join(cfg.AllowedHeaders, ", ") exposed := strings.Join(cfg.ExposedHeaders, ", ") originSet := make(map[string]struct{}, len(cfg.AllowedOrigins)) for _, o := range cfg.AllowedOrigins { originSet[o] = struct{}{} } return func(h kite.Handler) kite.Handler { return func(ctx *kite.Context) error { origin := ctx.GetHeader("Origin") allowOrigin := "" switch { case allowAll && !cfg.AllowCredentials: allowOrigin = "*" case allowAll && cfg.AllowCredentials: allowOrigin = origin default: if _, ok := originSet[origin]; ok { allowOrigin = origin } } if allowOrigin != "" { ctx.SetHeader("Access-Control-Allow-Origin", allowOrigin) ctx.AddHeader("Vary", "Origin") } ctx.SetHeader("Access-Control-Allow-Methods", methods) ctx.SetHeader("Access-Control-Allow-Headers", headers) if exposed != "" { ctx.SetHeader("Access-Control-Expose-Headers", exposed) } if cfg.AllowCredentials { ctx.SetHeader("Access-Control-Allow-Credentials", "true") } if cfg.MaxAge > 0 { ctx.SetHeader("Access-Control-Max-Age", fmt.Sprintf("%d", cfg.MaxAge)) } if ctx.IsMethod(http.MethodOptions) { return ctx.WriteNoContent() } return h(ctx) } } }