diff --git a/README.md b/README.md index 07ea1bf..1c0dcd6 100644 --- a/README.md +++ b/README.md @@ -9,10 +9,12 @@ A fast, lightweight, and expressive HTTP framework for Go, built on top of [http - **Fast** — ~132k RPS, sitting just below raw `net/http` and above Gin, Echo, and Chi - **Middleware** — global, group, and route-level middleware with correct onion ordering - **Groups** — nestable route groups with prefix and middleware inheritance +- **Static files & SPA** — serve assets and single-page apps with optional middleware - **Graceful shutdown** — in-flight requests finish cleanly on `SIGTERM` / `SIGINT` -- **Configurable** — server timeouts, error handlers, not-found handlers all configurable +- **Configurable** — server timeouts, error/not-found/method-not-allowed handlers - **Rich context** — typed request/response helpers, value store, cookie and form support - **Built-in middleware** — logger, recovery, CORS, request ID, max body size +- **WebSocket & SSE ready** — response writer implements `Hijacker`, `Flusher`, and `Pusher` ## Installation @@ -37,8 +39,8 @@ func main() { k := kite.New() k.Use( + middleware.Logger(), // first, so it observes the final status after Recovery middleware.Recovery(), - middleware.Logger(), middleware.RequestID(), ) @@ -66,7 +68,7 @@ k.PUT("/users/:id", updateUser) k.DELETE("/users/:id", deleteUser) ``` -All standard HTTP methods are supported: `GET`, `POST`, `PUT`, `DELETE`, `HEAD`, `OPTIONS`, `CONNECT`, `TRACE`, `PATCH`. +Supported HTTP methods: `GET`, `POST`, `PUT`, `DELETE`, `HEAD`, `OPTIONS`, `CONNECT`, `TRACE`. ## Route Groups @@ -94,6 +96,46 @@ admin := api.Group("/admin", adminOnlyMiddleware) admin.GET("/stats", getStats) // chain: global → auth → adminOnly → handler ``` +## Static Files & Single-Page Apps + +### Static Assets + +`Static` serves files from a directory on disk under a URL prefix: + +```go +k.Static("/assets", "./public") +// GET /assets/css/main.css → ./public/css/main.css +``` + +Middleware can be applied per-prefix: + +```go +k.Static("/admin/assets", "./admin-dist", authMiddleware) +``` + +`StaticFS` is the same but takes an `http.FileSystem`, useful for embedded assets via `embed`: + +```go +//go:embed dist +var distFS embed.FS + +k.StaticFS("/assets", http.FS(distFS)) +``` + +Both are also available on `Group` for prefixed/scoped serving. + +### Single-Page Applications + +For Vue, React, Svelte, and similar frameworks, `SPA` serves the directory and falls back to `index.html` for unknown paths — letting client-side routing work correctly: + +```go +k.SPA("./dist") +``` + +`SPAFS` takes an `http.FileSystem` for embedded builds. + +> **Note:** `SPA` and `SPAFS` register the file server as the NotFound handler. If you've already called `SetNotFoundHandler` (or call it afterward), a warning is logged about the override. + ## Middleware Middleware follows the standard onion model — each middleware wraps the next. @@ -104,11 +146,11 @@ type Middleware func(h Handler) Handler ### Global Middleware -Applied to every route: +Applied to every request, including NotFound, MethodNotAllowed, and OPTIONS preflights: ```go -k.Use(middleware.Recovery()) k.Use(middleware.Logger()) +k.Use(middleware.Recovery()) ``` ### Route Middleware @@ -133,6 +175,20 @@ k.GET("/admin", adminHandler, authMiddleware, rateLimitMiddleware) ← Global 1 ``` +The first middleware registered is the outermost — it runs first on the way in, and last on the way out. + +### Recommended Order + +Register `Logger` **before** `Recovery`. The reverse order causes `Logger`'s deferred log line to read the status before `Recovery` writes the 500, which means panicked requests log as `200`. + +```go +k.Use( + middleware.Logger(), // outermost — sees the final status + middleware.Recovery(), // catches panics from inner handlers + middleware.RequestID(), +) +``` + ### Writing Custom Middleware ```go @@ -156,7 +212,7 @@ import "git.trcreatives.at/trcreatives/go-kite/middleware" ### Logger -Logs method, path, status code and latency for every request. +Logs method, path, status code and latency for every request. Logs from a `defer`, so panicked and errored requests are captured too (provided Logger is registered before Recovery). ```go // simple @@ -179,7 +235,7 @@ k.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ Catches panics and returns a `500` instead of crashing the server. ```go -// simple — always register this first +// simple k.Use(middleware.Recovery()) // custom — integrate with Sentry, Datadog, etc. @@ -207,7 +263,7 @@ k.Use(middleware.RequestIDWithConfig(middleware.RequestIDConfig{ })) // retrieve in handler -id := ctx.GetValue("X-Request-ID") +id := ctx.GetValue("X-Request-ID").(string) ``` ### CORS @@ -215,11 +271,11 @@ id := ctx.GetValue("X-Request-ID") Sets Cross-Origin Resource Sharing headers and handles preflight requests. ```go -// allow all origins +// allow all origins (no credentials) k.Use(middleware.CORS()) -// restrict origins -k.Use(middleware.CORS("https://myapp.com")) +// restrict to specific origins +k.Use(middleware.CORS("https://myapp.com", "https://staging.myapp.com")) // full control k.Use(middleware.CORSWithConfig(middleware.CORSConfig{ @@ -232,7 +288,9 @@ k.Use(middleware.CORSWithConfig(middleware.CORSConfig{ })) ``` -> **Note:** `AllowCredentials: true` cannot be combined with `AllowedOrigins: ["*"]` — browsers will reject it. +For multi-origin allowlists, the middleware checks the request's `Origin` against the list and echoes back only the matching one — `Access-Control-Allow-Origin` always contains a single origin, never a comma-separated list. A `Vary: Origin` header is added so caches behave correctly. + +> **Note:** When `AllowCredentials: true` is combined with `AllowedOrigins: ["*"]`, the middleware echoes the request's `Origin` instead of sending `*`, since browsers reject that combination. ### MaxBodySize @@ -292,16 +350,17 @@ Every handler receives a `*kite.Context` with the following methods: ### Response -| Method | Description | -| ----------------------------------------- | ------------------- | -| `WriteBytes(status int, v []byte) error` | Raw bytes response | -| `WriteString(status int, v string) error` | Plain text response | -| `WriteJSON(status int, v any) error` | JSON response | -| `WriteXML(status int, v any) error` | XML response | -| `WriteNoContent() error` | 204 No Content | -| `Redirect(status int, url string) error` | Redirect | -| `SetHeader(key, value string)` | Set response header | -| `SetCookie(cookie *http.Cookie)` | Set response cookie | +| Method | Description | +| ----------------------------------------- | ---------------------------------------- | +| `WriteBytes(status int, v []byte) error` | Raw bytes response | +| `WriteString(status int, v string) error` | Plain text response (UTF-8) | +| `WriteJSON(status int, v any) error` | JSON response (UTF-8) | +| `WriteXML(status int, v any) error` | XML response (UTF-8) | +| `WriteNoContent() error` | 204 No Content | +| `Redirect(status int, url string) error` | Redirect | +| `SetHeader(key, value string)` | Set response header (replaces existing) | +| `AddHeader(key, value string)` | Append response header (for multi-value) | +| `SetCookie(cookie *http.Cookie)` | Set response cookie | ## Error Handling diff --git a/context.go b/context.go index c022df8..7bdcc6b 100644 --- a/context.go +++ b/context.go @@ -16,10 +16,9 @@ import ( type contextKey string type Context struct { - ctx context.Context - w *responseWriter - r *http.Request - p httprouter.Params + w *responseWriter + r *http.Request + p httprouter.Params } func (c *Context) GetRequest() *http.Request { @@ -35,7 +34,7 @@ func (c *Context) GetStatusCode() int { } func (c *Context) GetContext() context.Context { - return c.ctx + return c.r.Context() } func (c *Context) GetPathParam(key string) string { @@ -47,11 +46,15 @@ func (c *Context) GetQueryParam(key string) string { } func (c *Context) SetValue(key string, v any) { - c.ctx = context.WithValue(c.ctx, contextKey(key), v) + c.r = c.r.WithContext(context.WithValue(c.r.Context(), contextKey(key), v)) } func (c *Context) GetValue(key string) any { - return c.ctx.Value(contextKey(key)) + return c.r.Context().Value(contextKey(key)) +} + +func (c *Context) AddHeader(key, v string) { + c.w.Header().Add(key, v) } func (c *Context) SetHeader(key, v string) { @@ -126,20 +129,20 @@ func (c *Context) WriteBytes(statusCode int, v []byte) error { } func (c *Context) WriteString(statusCode int, v string) error { - c.w.Header().Set("Content-Type", "text/plain") + c.w.Header().Set("Content-Type", "text/plain; charset=utf-8") c.w.WriteHeader(statusCode) _, err := c.w.Write([]byte(v)) return err } func (c *Context) WriteJSON(statusCode int, v any) error { - c.w.Header().Set("Content-Type", "application/json") + c.w.Header().Set("Content-Type", "application/json; charset=utf-8") c.w.WriteHeader(statusCode) return json.NewEncoder(c.w).Encode(v) } func (c *Context) WriteXML(statusCode int, v any) error { - c.w.Header().Set("Content-Type", "application/xml") + c.w.Header().Set("Content-Type", "application/xml; charset=utf-8") c.w.WriteHeader(statusCode) return xml.NewEncoder(c.w).Encode(v) } diff --git a/group.go b/group.go index f8885f3..0db50f4 100644 --- a/group.go +++ b/group.go @@ -8,19 +8,25 @@ type Group struct { mws []Middleware } -func (g *Group) Static(prefix, root string) { - g.StaticFS(prefix, http.Dir(root)) +func (g *Group) Static(prefix, root string, mws ...Middleware) { + g.StaticFS(prefix, http.Dir(root), mws...) } -func (g *Group) StaticFS(prefix string, fs http.FileSystem) { +func (g *Group) StaticFS(prefix string, fs http.FileSystem, mws ...Middleware) { if prefix == "" || prefix[len(prefix)-1] != '/' { prefix += "/" } fullPrefix := g.prefix + prefix handler := http.StripPrefix(fullPrefix, http.FileServer(fs)) - g.k.r.Handler(http.MethodGet, fullPrefix+"*filepath", handler) - g.k.r.Handler(http.MethodHead, fullPrefix+"*filepath", handler) + + h := func(ctx *Context) error { + handler.ServeHTTP(ctx.w, ctx.r) + return nil + } + + g.k.handle(http.MethodGet, fullPrefix+"*filepath", h, mws...) + g.k.handle(http.MethodHead, fullPrefix+"*filepath", h, mws...) } func (g *Group) Group(prefix string, mws ...Middleware) *Group { @@ -68,5 +74,9 @@ func (g *Group) TRACE(path string, h Handler, mws ...Middleware) { } func (g *Group) handle(method, path string, h Handler, mws ...Middleware) { - g.k.handle(method, g.prefix+path, h, append(g.mws, mws...)...) + combined := make([]Middleware, 0, len(g.mws)+len(mws)) + combined = append(combined, g.mws...) + combined = append(combined, mws...) + + g.k.handle(method, g.prefix+path, h, combined...) } diff --git a/kite.go b/kite.go index 92cab87..5438b5f 100644 --- a/kite.go +++ b/kite.go @@ -17,12 +17,16 @@ type Kite struct { errorHandler ErrorHandler mws []Middleware + customNotFound bool + serverReadTimeout time.Duration serverWriteTimeout time.Duration serverIdleTimeout time.Duration serverShutdownTimeout time.Duration } +type kiteContextKey struct{} + func New() *Kite { k := &Kite{ r: httprouter.New(), @@ -34,7 +38,7 @@ func New() *Kite { } k.SetErrorHandler(defaultErrorHandler) - k.SetNotFoundHandler(defaultNotFoundHandler) + k.setNotFoundHandler(defaultNotFoundHandler) k.SetMethodNotAllowedHandler(defaultMethodNotAllowedHandler) return k @@ -44,17 +48,33 @@ func (k *Kite) SetErrorHandler(errorHandler ErrorHandler) { k.errorHandler = errorHandler } -func (k *Kite) SetNotFoundHandler(notFoundHandler NotFoundHandler) { +func (k *Kite) setNotFoundHandler(notFoundHandler NotFoundHandler) { k.r.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := &Context{ctx: r.Context(), w: newResponseWriter(w), r: r} - notFoundHandler(ctx) + ctx := r.Context().Value(kiteContextKey{}).(*Context) + if err := notFoundHandler(ctx); err != nil { + if err := k.errorHandler(ctx, err); err != nil { + log.Printf("[Kite] Error handler failed: %v\n", err) + } + } }) } +func (k *Kite) SetNotFoundHandler(notFoundHandler NotFoundHandler) { + if k.customNotFound { + log.Println("[Kite] SetNotFoundHandler is overriding a previously configured NotFound handler") + } + k.setNotFoundHandler(notFoundHandler) + k.customNotFound = true +} + func (k *Kite) SetMethodNotAllowedHandler(methodNotAllowedHandler MethodNotAllowedHandler) { k.r.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := &Context{ctx: r.Context(), w: newResponseWriter(w), r: r} - methodNotAllowedHandler(ctx) + ctx := r.Context().Value(kiteContextKey{}).(*Context) + if err := methodNotAllowedHandler(ctx); err != nil { + if err := k.errorHandler(ctx, err); err != nil { + log.Printf("[Kite] Error handler failed: %v\n", err) + } + } }) } @@ -77,7 +97,7 @@ func (k *Kite) SetServerShutdownTimeout(d time.Duration) { func (k *Kite) Start(listenAddr string) error { srv := &http.Server{ Addr: listenAddr, - Handler: k.r, + Handler: k.buildHandler(), ReadTimeout: k.serverReadTimeout, WriteTimeout: k.serverWriteTimeout, @@ -105,7 +125,7 @@ func (k *Kite) Start(listenAddr string) error { func (k *Kite) StartTLS(listenAddr, certFile, keyFile string) error { srv := &http.Server{ Addr: listenAddr, - Handler: k.r, + Handler: k.buildHandler(), ReadTimeout: k.serverReadTimeout, WriteTimeout: k.serverWriteTimeout, @@ -135,6 +155,11 @@ func (k *Kite) SPA(root string) { } func (k *Kite) SPAFS(fs http.FileSystem) { + if k.customNotFound { + log.Println("[Kite] SPAFS is overriding a previously configured NotFound handler") + } + k.customNotFound = true + fileServer := http.FileServer(fs) k.r.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -150,18 +175,24 @@ func (k *Kite) SPAFS(fs http.FileSystem) { }) } -func (k *Kite) Static(prefix, root string) { - k.StaticFS(prefix, http.Dir(root)) +func (k *Kite) Static(prefix, root string, mws ...Middleware) { + k.StaticFS(prefix, http.Dir(root), mws...) } -func (k *Kite) StaticFS(prefix string, fs http.FileSystem) { +func (k *Kite) StaticFS(prefix string, fs http.FileSystem, mws ...Middleware) { if prefix == "" || prefix[len(prefix)-1] != '/' { prefix += "/" } handler := http.StripPrefix(prefix, http.FileServer(fs)) - k.r.Handler(http.MethodGet, prefix+"*filepath", handler) - k.r.Handler(http.MethodHead, prefix+"*filepath", handler) + + h := func(ctx *Context) error { + handler.ServeHTTP(ctx.w, ctx.r) + return nil + } + + k.handle(http.MethodGet, prefix+"*filepath", h, mws...) + k.handle(http.MethodHead, prefix+"*filepath", h, mws...) } func (k *Kite) Group(prefix string, mws ...Middleware) *Group { @@ -208,25 +239,41 @@ func (k *Kite) TRACE(path string, h Handler, mws ...Middleware) { k.handle(http.MethodTrace, path, h, mws...) } +func (k *Kite) buildHandler() http.Handler { + inner := func(ctx *Context) error { + ctx.r = ctx.r.WithContext(context.WithValue(ctx.r.Context(), kiteContextKey{}, ctx)) + k.r.ServeHTTP(ctx.w, ctx.r) + return nil + } + + wrapped := inner + for i := len(k.mws) - 1; i >= 0; i-- { + wrapped = k.mws[i](wrapped) + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := &Context{ + w: newResponseWriter(w), + r: r, + } + if err := wrapped(ctx); err != nil { + if err := k.errorHandler(ctx, err); err != nil { + log.Printf("[Kite] Error handler failed: %v\n", err) + } + } + }) +} + func (k *Kite) handle(method, path string, h Handler, mws ...Middleware) { k.r.Handle(method, path, func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { - ctx := &Context{ - ctx: r.Context(), - w: newResponseWriter(w), - r: r, - p: p, - } + ctx := r.Context().Value(kiteContextKey{}).(*Context) + ctx.p = p wrappedHandler := h - for i := len(mws) - 1; i >= 0; i-- { wrappedHandler = mws[i](wrappedHandler) } - for i := len(k.mws) - 1; i >= 0; i-- { - wrappedHandler = k.mws[i](wrappedHandler) - } - if err := wrappedHandler(ctx); err != nil { if err := k.errorHandler(ctx, err); err != nil { log.Printf("[Kite] Error handler failed: %v\n", err) diff --git a/middleware/cors.go b/middleware/cors.go index 45b2c85..a959d23 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -34,14 +34,37 @@ func CORSWithConfig(cfg CORSConfig) kite.Middleware { cfg.AllowedHeaders = []string{"Content-Type", "Authorization"} } - origins := strings.Join(cfg.AllowedOrigins, ", ") + allowAll := len(cfg.AllowedOrigins) == 1 && cfg.AllowedOrigins[0] == "*" methods := strings.Join(cfg.AllowedMethods, ", ") headers := strings.Join(cfg.AllowedHeaders, ", ") exposed := strings.Join(cfg.ExposedHeaders, ", ") + originSet := make(map[string]struct{}, len(cfg.AllowedOrigins)) + for _, o := range cfg.AllowedOrigins { + originSet[o] = struct{}{} + } + return func(h kite.Handler) kite.Handler { return func(ctx *kite.Context) error { - ctx.SetHeader("Access-Control-Allow-Origin", origins) + origin := ctx.GetHeader("Origin") + + allowOrigin := "" + switch { + case allowAll && !cfg.AllowCredentials: + allowOrigin = "*" + case allowAll && cfg.AllowCredentials: + allowOrigin = origin + default: + if _, ok := originSet[origin]; ok { + allowOrigin = origin + } + } + + if allowOrigin != "" { + ctx.SetHeader("Access-Control-Allow-Origin", allowOrigin) + ctx.AddHeader("Vary", "Origin") + } + ctx.SetHeader("Access-Control-Allow-Methods", methods) ctx.SetHeader("Access-Control-Allow-Headers", headers) diff --git a/middleware/logger.go b/middleware/logger.go index 2aee1d1..2fd0ff8 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -42,10 +42,10 @@ func LoggerWithConfig(cfg LoggerConfig) kite.Middleware { } start := time.Now() - err := h(ctx) - logger.Println(cfg.Format(ctx.GetMethod(), ctx.GetPath(), ctx.GetStatusCode(), time.Since(start))) - - return err + defer func() { + logger.Println(cfg.Format(ctx.GetMethod(), ctx.GetPath(), ctx.GetStatusCode(), time.Since(start))) + }() + return h(ctx) } } } diff --git a/response_writer.go b/response_writer.go index 21121a8..6a074e2 100644 --- a/response_writer.go +++ b/response_writer.go @@ -1,6 +1,11 @@ package kite -import "net/http" +import ( + "bufio" + "fmt" + "net" + "net/http" +) type responseWriter struct { http.ResponseWriter @@ -31,3 +36,24 @@ func (rw *responseWriter) Write(b []byte) (int, error) { } return rw.ResponseWriter.Write(b) } + +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := rw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("underlying ResponseWriter does not implement http.Hijacker") + } + return h.Hijack() +} + +func (rw *responseWriter) Flush() { + if f, ok := rw.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func (rw *responseWriter) Push(target string, opts *http.PushOptions) error { + if p, ok := rw.ResponseWriter.(http.Pusher); ok { + return p.Push(target, opts) + } + return http.ErrNotSupported +}