Files
go-kite/kite.go
Timo Riegebauer ae5d1f610a feat: middleware reach, response writer interfaces, and CORS fixes
- Run global middleware for all requests, including OPTIONS preflights, NotFound, and MethodNotAllowed — previously bypassed by httprouter's internal handling
- Implement Hijacker, Flusher, and Pusher on the response writer for WebSocket, SSE, and HTTP/2 push support
- Fix CORS: echo a single matching origin, handle AllowCredentials with wildcard, append Vary: Origin
- Logger logs from a defer to capture correct status on panicked requests
- Static and StaticFS accept route middleware; add Context.AddHeader; warn on NotFound handler override
2026-05-17 17:51:56 +02:00

284 lines
6.9 KiB
Go

package kite
import (
"context"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/julienschmidt/httprouter"
)
// Kite is the application instance. It owns the router, the global
// middleware chain, the error / not-found / method-not-allowed handlers,
// and the http.Server settings used by Start and StartTLS.
type Kite struct {
	// r is the httprouter instance that dispatches requests to routes.
	r *httprouter.Router
	// errorHandler receives every non-nil error that escapes a handler chain.
	errorHandler ErrorHandler
	// mws is the global middleware, composed around the router by buildHandler.
	mws []Middleware
	// customNotFound records that the NotFound handler was replaced by
	// SetNotFoundHandler or SPAFS, so a later replacement can log a warning.
	customNotFound bool

	// Timeouts copied onto the http.Server built by Start/StartTLS;
	// serverShutdownTimeout bounds the graceful shutdown on SIGINT/SIGTERM.
	serverReadTimeout, serverWriteTimeout, serverIdleTimeout, serverShutdownTimeout time.Duration
}

// kiteContextKey is the request-context key under which buildHandler stores
// the per-request *Context so router callbacks can retrieve it.
type kiteContextKey struct{}
// New returns a Kite with the default server timeouts (read 10s, write 30s,
// idle 60s, shutdown 30s) and the default error, not-found and
// method-not-allowed handlers installed.
func New() *Kite {
	app := new(Kite)
	app.r = httprouter.New()

	app.serverReadTimeout = 10 * time.Second
	app.serverWriteTimeout = 30 * time.Second
	app.serverIdleTimeout = 60 * time.Second
	app.serverShutdownTimeout = 30 * time.Second

	app.SetErrorHandler(defaultErrorHandler)
	// The unexported setter leaves customNotFound false, so the default
	// handler is not treated as a user override that later calls warn about.
	app.setNotFoundHandler(defaultNotFoundHandler)
	app.SetMethodNotAllowedHandler(defaultMethodNotAllowedHandler)

	return app
}
// SetErrorHandler sets the handler that receives errors returned by route
// handlers, middleware, and the not-found / method-not-allowed handlers.
//
// Passing nil restores the default error handler. Storing nil would make
// the first error of any request panic on a nil function call.
func (k *Kite) SetErrorHandler(errorHandler ErrorHandler) {
	if errorHandler == nil {
		errorHandler = defaultErrorHandler
	}
	k.errorHandler = errorHandler
}
// setNotFoundHandler installs notFoundHandler as the router's NotFound
// callback without touching customNotFound. An error it returns is passed
// to the error handler configured at request time; if that also fails, the
// failure is logged.
func (k *Kite) setNotFoundHandler(notFoundHandler NotFoundHandler) {
	k.r.NotFound = http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
		// buildHandler stored the *Context in the request context before
		// dispatching to the router.
		c := req.Context().Value(kiteContextKey{}).(*Context)
		err := notFoundHandler(c)
		if err == nil {
			return
		}
		if handlerErr := k.errorHandler(c, err); handlerErr != nil {
			log.Printf("[Kite] Error handler failed: %v\n", handlerErr)
		}
	})
}
// SetNotFoundHandler replaces the handler used when no route matches the
// request. A warning is logged when this overrides a handler installed by
// an earlier SetNotFoundHandler or SPAFS call.
func (k *Kite) SetNotFoundHandler(notFoundHandler NotFoundHandler) {
	alreadyCustom := k.customNotFound
	if alreadyCustom {
		log.Println("[Kite] SetNotFoundHandler is overriding a previously configured NotFound handler")
	}
	k.setNotFoundHandler(notFoundHandler)
	k.customNotFound = true
}
// SetMethodNotAllowedHandler installs the router's MethodNotAllowed
// callback. An error returned by methodNotAllowedHandler is passed to the
// error handler configured at request time; if that also fails, the failure
// is logged.
func (k *Kite) SetMethodNotAllowedHandler(methodNotAllowedHandler MethodNotAllowedHandler) {
	k.r.MethodNotAllowed = http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
		// buildHandler stored the *Context in the request context before
		// dispatching to the router.
		c := req.Context().Value(kiteContextKey{}).(*Context)
		err := methodNotAllowedHandler(c)
		if err == nil {
			return
		}
		if handlerErr := k.errorHandler(c, err); handlerErr != nil {
			log.Printf("[Kite] Error handler failed: %v\n", handlerErr)
		}
	})
}
// SetServerReadTimeout sets the ReadTimeout of the http.Server that Start
// and StartTLS build. It does not affect a server that is already running.
func (k *Kite) SetServerReadTimeout(timeout time.Duration) {
	k.serverReadTimeout = timeout
}

// SetServerWriteTimeout sets the WriteTimeout of the http.Server that Start
// and StartTLS build. It does not affect a server that is already running.
func (k *Kite) SetServerWriteTimeout(timeout time.Duration) {
	k.serverWriteTimeout = timeout
}

// SetServerIdleTimeout sets the IdleTimeout of the http.Server that Start
// and StartTLS build. It does not affect a server that is already running.
func (k *Kite) SetServerIdleTimeout(timeout time.Duration) {
	k.serverIdleTimeout = timeout
}

// SetServerShutdownTimeout sets how long Start and StartTLS wait for
// in-flight requests during graceful shutdown after SIGINT/SIGTERM.
func (k *Kite) SetServerShutdownTimeout(timeout time.Duration) {
	k.serverShutdownTimeout = timeout
}
// Start serves plain HTTP on listenAddr and blocks until the process
// receives SIGINT or SIGTERM. It then shuts the server down gracefully,
// waiting at most the configured shutdown timeout for in-flight requests,
// and returns the result of that shutdown.
//
// If the server stops with an error before a signal arrives (for example
// because listenAddr is already in use), Start returns that error instead
// of blocking.
func (k *Kite) Start(listenAddr string) error {
	srv := k.buildServer(listenAddr)
	return k.serveUntilSignal(srv, "http", srv.ListenAndServe)
}

// StartTLS is like Start but serves HTTPS using the certificate and private
// key in certFile and keyFile. A failure to load them is returned like any
// other serve error.
func (k *Kite) StartTLS(listenAddr, certFile, keyFile string) error {
	srv := k.buildServer(listenAddr)
	return k.serveUntilSignal(srv, "https", func() error {
		return srv.ListenAndServeTLS(certFile, keyFile)
	})
}

// buildServer returns an http.Server for addr. The server serves the chain
// built by buildHandler and uses the configured read, write and idle timeouts.
func (k *Kite) buildServer(addr string) *http.Server {
	return &http.Server{
		Addr:         addr,
		Handler:      k.buildHandler(),
		ReadTimeout:  k.serverReadTimeout,
		WriteTimeout: k.serverWriteTimeout,
		IdleTimeout:  k.serverIdleTimeout,
	}
}

// serveUntilSignal runs serve in the background and waits for either
// SIGINT/SIGTERM or a serve failure. On a signal it shuts srv down, bounded
// by serverShutdownTimeout, and returns Shutdown's result. On a failure it
// returns the serve error.
func (k *Kite) serveUntilSignal(srv *http.Server, scheme string, serve func() error) error {
	// Subscribe before serving so that a signal arriving right after startup
	// still triggers a graceful shutdown. Stop relaying on return so the
	// channel no longer captures signals once Start has returned.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(quit)

	// Buffered so the goroutine can always deliver its error and exit, even
	// if nothing is receiving anymore.
	serveErr := make(chan error, 1)
	go func() {
		// ErrServerClosed is the expected result after Shutdown, not a failure.
		if err := serve(); err != nil && err != http.ErrServerClosed {
			serveErr <- err
		}
	}()
	log.Printf("[Kite] Server started and listening on %s://%s\n", scheme, srv.Addr)

	select {
	case err := <-serveErr:
		return err
	case <-quit:
	}

	log.Println("[Kite] Shutting down...")
	ctx, cancel := context.WithTimeout(context.Background(), k.serverShutdownTimeout)
	defer cancel()
	return srv.Shutdown(ctx)
}
// SPA serves a single-page application from the directory root on the
// local filesystem; it is SPAFS over http.Dir(root).
func (k *Kite) SPA(root string) {
	dir := http.Dir(root)
	k.SPAFS(dir)
}
// SPAFS serves a single-page application from fs by replacing the router's
// NotFound handler. If the unmatched request path names something that
// fs.Open can open, it is served by http.FileServer. Otherwise the request
// is served as "/" (the file server's root document, e.g. index.html if fs
// contains one), so client-side routes resolve to the app.
//
// SPAFS overrides any handler set with SetNotFoundHandler; a warning is
// logged in that case.
func (k *Kite) SPAFS(fs http.FileSystem) {
	if k.customNotFound {
		log.Println("[Kite] SPAFS is overriding a previously configured NotFound handler")
	}
	k.customNotFound = true
	fileServer := http.FileServer(fs)
	k.r.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if f, err := fs.Open(r.URL.Path); err == nil {
			// Only existence matters here; the file server reopens the file.
			_ = f.Close()
			fileServer.ServeHTTP(w, r)
			return
		}
		// Rewrite a copy instead of r.URL itself. Request.WithContext is a
		// shallow copy, so r.URL is the same object global middleware (e.g.
		// a logger) still holds; mutating it would make every SPA fallback
		// look like a request for "/".
		fallback := new(http.Request)
		*fallback = *r
		u := *r.URL
		u.Path = "/"
		u.RawPath = ""
		fallback.URL = &u
		fileServer.ServeHTTP(w, fallback)
	})
}
// Static serves files from the directory root on the local filesystem under
// the URL prefix; it is StaticFS over http.Dir(root). Any mws are applied
// as route middleware.
func (k *Kite) Static(prefix, root string, mws ...Middleware) {
	dir := http.Dir(root)
	k.StaticFS(prefix, dir, mws...)
}
// StaticFS serves files from fs under the URL prefix for GET and HEAD
// requests. A trailing slash is appended to prefix when missing, and the
// prefix is stripped from the request path before the file lookup. Any mws
// wrap the file handler as route middleware.
func (k *Kite) StaticFS(prefix string, fs http.FileSystem, mws ...Middleware) {
	if n := len(prefix); n == 0 || prefix[n-1] != '/' {
		prefix += "/"
	}
	route := prefix + "*filepath"
	files := http.StripPrefix(prefix, http.FileServer(fs))
	serve := func(c *Context) error {
		files.ServeHTTP(c.w, c.r)
		return nil
	}
	for _, method := range []string{http.MethodGet, http.MethodHead} {
		k.handle(method, route, serve, mws...)
	}
}
// Group returns a route group rooted at prefix whose routes are wrapped by
// mws in addition to the global middleware.
//
// mws is copied into a slice whose capacity equals its length. The group
// therefore never shares a backing array with the caller, and any append to
// the group's middleware allocates fresh storage instead of overwriting the
// caller's slice or one shared with another group. NOTE(review): Group's
// methods are not visible here; this guards against them appending to mws.
func (k *Kite) Group(prefix string, mws ...Middleware) *Group {
	var own []Middleware
	if len(mws) > 0 {
		own = make([]Middleware, len(mws))
		copy(own, mws)
	}
	return &Group{
		k:      k,
		prefix: prefix,
		mws:    own,
	}
}
// Use appends mws to the global middleware. Global middleware wraps the
// whole router, so it runs for every request, including NotFound and
// MethodNotAllowed. The first registered middleware is the outermost. The
// chain is captured when Start/StartTLS builds the handler, so later calls
// do not affect a running server.
func (k *Kite) Use(mws ...Middleware) {
	for _, mw := range mws {
		k.mws = append(k.mws, mw)
	}
}
// CONNECT registers handler for CONNECT requests on path, wrapped by the
// optional route middleware mws.
func (k *Kite) CONNECT(path string, handler Handler, mws ...Middleware) {
	k.handle(http.MethodConnect, path, handler, mws...)
}

// DELETE registers handler for DELETE requests on path, wrapped by the
// optional route middleware mws.
func (k *Kite) DELETE(path string, handler Handler, mws ...Middleware) {
	k.handle(http.MethodDelete, path, handler, mws...)
}

// GET registers handler for GET requests on path, wrapped by the optional
// route middleware mws.
func (k *Kite) GET(path string, handler Handler, mws ...Middleware) {
	k.handle(http.MethodGet, path, handler, mws...)
}

// HEAD registers handler for HEAD requests on path, wrapped by the optional
// route middleware mws.
func (k *Kite) HEAD(path string, handler Handler, mws ...Middleware) {
	k.handle(http.MethodHead, path, handler, mws...)
}

// OPTIONS registers handler for OPTIONS requests on path, wrapped by the
// optional route middleware mws.
func (k *Kite) OPTIONS(path string, handler Handler, mws ...Middleware) {
	k.handle(http.MethodOptions, path, handler, mws...)
}

// POST registers handler for POST requests on path, wrapped by the optional
// route middleware mws.
func (k *Kite) POST(path string, handler Handler, mws ...Middleware) {
	k.handle(http.MethodPost, path, handler, mws...)
}

// PUT registers handler for PUT requests on path, wrapped by the optional
// route middleware mws.
func (k *Kite) PUT(path string, handler Handler, mws ...Middleware) {
	k.handle(http.MethodPut, path, handler, mws...)
}

// TRACE registers handler for TRACE requests on path, wrapped by the
// optional route middleware mws.
func (k *Kite) TRACE(path string, handler Handler, mws ...Middleware) {
	k.handle(http.MethodTrace, path, handler, mws...)
}
// buildHandler assembles the root http.Handler. For each request it:
//
//  1. creates a *Context around a newResponseWriter-wrapped writer;
//  2. runs the global middleware, with k.mws[0] outermost;
//  3. stores the Context in the request context and dispatches to the router.
//
// A non-nil error escaping the chain goes to the error handler; if that
// also fails, the failure is logged. The middleware chain is composed once,
// when buildHandler is called.
func (k *Kite) buildHandler() http.Handler {
	dispatch := func(c *Context) error {
		withKite := context.WithValue(c.r.Context(), kiteContextKey{}, c)
		c.r = c.r.WithContext(withKite)
		k.r.ServeHTTP(c.w, c.r)
		return nil
	}

	chain := dispatch
	for n := len(k.mws); n > 0; n-- {
		chain = k.mws[n-1](chain)
	}

	serve := func(w http.ResponseWriter, r *http.Request) {
		c := &Context{r: r, w: newResponseWriter(w)}
		err := chain(c)
		if err == nil {
			return
		}
		if handlerErr := k.errorHandler(c, err); handlerErr != nil {
			log.Printf("[Kite] Error handler failed: %v\n", handlerErr)
		}
	}
	return http.HandlerFunc(serve)
}
// handle registers h on the router for method and path.
//
// The route middleware mws is composed around h once, at registration time,
// with mws[0] as the outermost layer. This matches how buildHandler composes
// the global middleware. The composition used to be rebuilt on every
// request, which re-ran the middleware factories (resetting any state they
// set up), allocated closures per request, and let later mutation of the
// caller's slice leak into the route.
//
// At request time the *Context stored by buildHandler is retrieved, the
// route params are attached, and a non-nil error from the chain goes to the
// error handler configured at that moment; if that fails too, it is logged.
func (k *Kite) handle(method, path string, h Handler, mws ...Middleware) {
	chain := h
	for i := len(mws) - 1; i >= 0; i-- {
		chain = mws[i](chain)
	}
	k.r.Handle(method, path, func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
		ctx := r.Context().Value(kiteContextKey{}).(*Context)
		ctx.p = p
		if err := chain(ctx); err != nil {
			if err := k.errorHandler(ctx, err); err != nil {
				log.Printf("[Kite] Error handler failed: %v\n", err)
			}
		}
	})
}