Files
go-kite/middleware/cors.go
Timo Riegebauer ae5d1f610a feat: middleware reach, response writer interfaces, and CORS fixes
- Run global middleware for all requests, including OPTIONS preflights, NotFound, and MethodNotAllowed — previously bypassed by httprouter's internal handling
- Implement Hijacker, Flusher, and Pusher on the response writer for WebSocket, SSE, and HTTP/2 push support
- Fix CORS: echo a single matching origin, handle AllowCredentials with wildcard, append Vary: Origin
- Logger logs from a defer to capture correct status on panicked requests
- Static and StaticFS accept route middleware; add Context.AddHeader; warn on NotFound handler override
2026-05-17 17:51:56 +02:00

91 lines
2.1 KiB
Go

package middleware
import (
"fmt"
"net/http"
"strings"
"git.trcreatives.at/trcreatives/go-kite"
)
// CORSConfig configures the middleware returned by CORSWithConfig.
// CORSWithConfig replaces an empty AllowedOrigins, AllowedMethods or
// AllowedHeaders with its own defaults.
type CORSConfig struct {
	// AllowedOrigins lists the origins permitted to make cross-origin
	// requests. A single "*" entry allows every origin.
	AllowedOrigins []string
	// AllowedMethods is joined with ", " and sent as
	// Access-Control-Allow-Methods.
	AllowedMethods []string
	// AllowedHeaders is joined with ", " and sent as
	// Access-Control-Allow-Headers.
	AllowedHeaders []string
	// ExposedHeaders is joined with ", " and sent as
	// Access-Control-Expose-Headers when non-empty.
	ExposedHeaders []string
	// AllowCredentials, when true, sends
	// Access-Control-Allow-Credentials: true.
	AllowCredentials bool
	// MaxAge is sent as Access-Control-Max-Age (interpreted by browsers as
	// seconds) when it is greater than zero.
	MaxAge int
}
// CORS returns CORS middleware that permits the given origins and uses the
// defaults of CORSWithConfig for every other setting. Called with no
// arguments, it falls back to CORSWithConfig's default of allowing every
// origin.
func CORS(allowedOrigins ...string) kite.Middleware {
	var cfg CORSConfig
	cfg.AllowedOrigins = allowedOrigins
	return CORSWithConfig(cfg)
}
// CORSWithConfig returns a middleware that applies Cross-Origin Resource
// Sharing headers according to cfg.
//
// Empty fields in cfg are filled with defaults:
//   - AllowedOrigins becomes ["*"].
//   - AllowedMethods becomes GET, HEAD, POST, PUT, PATCH, DELETE, OPTIONS.
//   - AllowedHeaders becomes Content-Type, Authorization.
//
// Origin handling:
//   - If AllowedOrigins contains "*" and AllowCredentials is false, the
//     response carries "Access-Control-Allow-Origin: *".
//   - If AllowedOrigins contains "*" and AllowCredentials is true, the
//     request's Origin is echoed back, because browsers reject a literal "*"
//     on credentialed requests. This effectively trusts every origin with
//     credentials; enable that combination deliberately.
//   - Otherwise the request's Origin is echoed only when it exactly matches
//     an entry in AllowedOrigins.
//
// Whenever the Access-Control-Allow-Origin value depends on the request's
// Origin, "Vary: Origin" is appended. This happens even when the origin is
// rejected or absent, so shared caches never replay one origin's response
// to another.
//
// Preflight handling:
//   - A CORS preflight is an OPTIONS request carrying both Origin and
//     Access-Control-Request-Method.
//   - It is answered directly with 204 No Content and does not reach the
//     wrapped handler.
//   - Every other request, including a plain OPTIONS request, is passed on
//     to the wrapped handler.
func CORSWithConfig(cfg CORSConfig) kite.Middleware {
	if len(cfg.AllowedOrigins) == 0 {
		cfg.AllowedOrigins = []string{"*"}
	}
	if len(cfg.AllowedMethods) == 0 {
		// CONNECT and TRACE are deliberately absent: the Fetch standard lists
		// them as forbidden methods, so a browser never preflights them.
		cfg.AllowedMethods = []string{"GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}
	}
	if len(cfg.AllowedHeaders) == 0 {
		cfg.AllowedHeaders = []string{"Content-Type", "Authorization"}
	}

	// A "*" anywhere in the list switches to allow-all mode instead of being
	// stored as a literal (and never-matching) origin.
	allowAll := false
	originSet := make(map[string]struct{}, len(cfg.AllowedOrigins))
	for _, o := range cfg.AllowedOrigins {
		if o == "*" {
			allowAll = true
			continue
		}
		originSet[o] = struct{}{}
	}
	// In this mode every response carries the same "*" value, so the
	// response does not vary by Origin.
	staticWildcard := allowAll && !cfg.AllowCredentials

	// Header values are fixed for the lifetime of this middleware; build
	// them once instead of on every request.
	methods := strings.Join(cfg.AllowedMethods, ", ")
	headers := strings.Join(cfg.AllowedHeaders, ", ")
	exposed := strings.Join(cfg.ExposedHeaders, ", ")
	maxAge := ""
	if cfg.MaxAge > 0 {
		maxAge = fmt.Sprintf("%d", cfg.MaxAge)
	}

	return func(h kite.Handler) kite.Handler {
		return func(ctx *kite.Context) error {
			origin := ctx.GetHeader("Origin")
			preflight := ctx.IsMethod(http.MethodOptions) &&
				origin != "" &&
				ctx.GetHeader("Access-Control-Request-Method") != ""

			allowOrigin := ""
			switch {
			case staticWildcard:
				allowOrigin = "*"
			case allowAll:
				// Credentialed wildcard: echo the caller's origin (empty, and
				// therefore omitted, when the request has no Origin).
				allowOrigin = origin
			default:
				if _, ok := originSet[origin]; ok {
					allowOrigin = origin
				}
			}

			if !staticWildcard {
				// Allow-Origin depends on the request's Origin here, including
				// when it is omitted, so caches must key on Origin.
				ctx.AddHeader("Vary", "Origin")
			}

			if allowOrigin != "" {
				ctx.SetHeader("Access-Control-Allow-Origin", allowOrigin)
				if cfg.AllowCredentials {
					ctx.SetHeader("Access-Control-Allow-Credentials", "true")
				}
				if preflight {
					// These headers are only consulted on preflight responses.
					ctx.SetHeader("Access-Control-Allow-Methods", methods)
					ctx.SetHeader("Access-Control-Allow-Headers", headers)
					if maxAge != "" {
						ctx.SetHeader("Access-Control-Max-Age", maxAge)
					}
				} else if exposed != "" {
					// Expose-Headers only applies to actual (non-preflight) responses.
					ctx.SetHeader("Access-Control-Expose-Headers", exposed)
				}
			}

			if preflight {
				return ctx.WriteNoContent()
			}
			return h(ctx)
		}
	}
}