feat: middleware reach, response writer interfaces, and CORS fixes
- Run global middleware for all requests, including OPTIONS preflights, NotFound, and MethodNotAllowed — previously bypassed by httprouter's internal handling - Implement Hijacker, Flusher, and Pusher on the response writer for WebSocket, SSE, and HTTP/2 push support - Fix CORS: echo a single matching origin, handle AllowCredentials with wildcard, append Vary: Origin - Logger logs from a defer to capture correct status on panicked requests - Static and StaticFS accept route middleware; add Context.AddHeader; warn on NotFound handler override
This commit is contained in:
93
README.md
93
README.md
@@ -9,10 +9,12 @@ A fast, lightweight, and expressive HTTP framework for Go, built on top of [http
|
|||||||
- **Fast** — ~132k RPS, sitting just below raw `net/http` and above Gin, Echo, and Chi
|
- **Fast** — ~132k RPS, sitting just below raw `net/http` and above Gin, Echo, and Chi
|
||||||
- **Middleware** — global, group, and route-level middleware with correct onion ordering
|
- **Middleware** — global, group, and route-level middleware with correct onion ordering
|
||||||
- **Groups** — nestable route groups with prefix and middleware inheritance
|
- **Groups** — nestable route groups with prefix and middleware inheritance
|
||||||
|
- **Static files & SPA** — serve assets and single-page apps with optional middleware
|
||||||
- **Graceful shutdown** — in-flight requests finish cleanly on `SIGTERM` / `SIGINT`
|
- **Graceful shutdown** — in-flight requests finish cleanly on `SIGTERM` / `SIGINT`
|
||||||
- **Configurable** — server timeouts, error handlers, not-found handlers all configurable
|
- **Configurable** — server timeouts, error/not-found/method-not-allowed handlers
|
||||||
- **Rich context** — typed request/response helpers, value store, cookie and form support
|
- **Rich context** — typed request/response helpers, value store, cookie and form support
|
||||||
- **Built-in middleware** — logger, recovery, CORS, request ID, max body size
|
- **Built-in middleware** — logger, recovery, CORS, request ID, max body size
|
||||||
|
- **WebSocket & SSE ready** — response writer implements `Hijacker`, `Flusher`, and `Pusher`
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
@@ -37,8 +39,8 @@ func main() {
|
|||||||
k := kite.New()
|
k := kite.New()
|
||||||
|
|
||||||
k.Use(
|
k.Use(
|
||||||
|
middleware.Logger(), // first, so it observes the final status after Recovery
|
||||||
middleware.Recovery(),
|
middleware.Recovery(),
|
||||||
middleware.Logger(),
|
|
||||||
middleware.RequestID(),
|
middleware.RequestID(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,7 +68,7 @@ k.PUT("/users/:id", updateUser)
|
|||||||
k.DELETE("/users/:id", deleteUser)
|
k.DELETE("/users/:id", deleteUser)
|
||||||
```
|
```
|
||||||
|
|
||||||
All standard HTTP methods are supported: `GET`, `POST`, `PUT`, `DELETE`, `HEAD`, `OPTIONS`, `CONNECT`, `TRACE`, `PATCH`.
|
Supported HTTP methods: `GET`, `POST`, `PUT`, `DELETE`, `HEAD`, `OPTIONS`, `CONNECT`, `TRACE`.
|
||||||
|
|
||||||
## Route Groups
|
## Route Groups
|
||||||
|
|
||||||
@@ -94,6 +96,46 @@ admin := api.Group("/admin", adminOnlyMiddleware)
|
|||||||
admin.GET("/stats", getStats) // chain: global → auth → adminOnly → handler
|
admin.GET("/stats", getStats) // chain: global → auth → adminOnly → handler
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Static Files & Single-Page Apps
|
||||||
|
|
||||||
|
### Static Assets
|
||||||
|
|
||||||
|
`Static` serves files from a directory on disk under a URL prefix:
|
||||||
|
|
||||||
|
```go
|
||||||
|
k.Static("/assets", "./public")
|
||||||
|
// GET /assets/css/main.css → ./public/css/main.css
|
||||||
|
```
|
||||||
|
|
||||||
|
Middleware can be applied per-prefix:
|
||||||
|
|
||||||
|
```go
|
||||||
|
k.Static("/admin/assets", "./admin-dist", authMiddleware)
|
||||||
|
```
|
||||||
|
|
||||||
|
`StaticFS` is the same but takes an `http.FileSystem`, useful for embedded assets via `embed`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
//go:embed dist
|
||||||
|
var distFS embed.FS
|
||||||
|
|
||||||
|
k.StaticFS("/assets", http.FS(distFS))
|
||||||
|
```
|
||||||
|
|
||||||
|
Both are also available on `Group` for prefixed/scoped serving.
|
||||||
|
|
||||||
|
### Single-Page Applications
|
||||||
|
|
||||||
|
For Vue, React, Svelte, and similar frameworks, `SPA` serves the directory and falls back to `index.html` for unknown paths — letting client-side routing work correctly:
|
||||||
|
|
||||||
|
```go
|
||||||
|
k.SPA("./dist")
|
||||||
|
```
|
||||||
|
|
||||||
|
`SPAFS` takes an `http.FileSystem` for embedded builds.
|
||||||
|
|
||||||
|
> **Note:** `SPA` and `SPAFS` register the file server as the NotFound handler. If you've already called `SetNotFoundHandler` (or call it afterward), a warning is logged about the override.
|
||||||
|
|
||||||
## Middleware
|
## Middleware
|
||||||
|
|
||||||
Middleware follows the standard onion model — each middleware wraps the next.
|
Middleware follows the standard onion model — each middleware wraps the next.
|
||||||
@@ -104,11 +146,11 @@ type Middleware func(h Handler) Handler
|
|||||||
|
|
||||||
### Global Middleware
|
### Global Middleware
|
||||||
|
|
||||||
Applied to every route:
|
Applied to every request, including NotFound, MethodNotAllowed, and OPTIONS preflights:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
k.Use(middleware.Recovery())
|
|
||||||
k.Use(middleware.Logger())
|
k.Use(middleware.Logger())
|
||||||
|
k.Use(middleware.Recovery())
|
||||||
```
|
```
|
||||||
|
|
||||||
### Route Middleware
|
### Route Middleware
|
||||||
@@ -133,6 +175,20 @@ k.GET("/admin", adminHandler, authMiddleware, rateLimitMiddleware)
|
|||||||
← Global 1
|
← Global 1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The first middleware registered is the outermost — it runs first on the way in, and last on the way out.
|
||||||
|
|
||||||
|
### Recommended Order
|
||||||
|
|
||||||
|
Register `Logger` **before** `Recovery`. The reverse order causes `Logger`'s deferred log line to read the status before `Recovery` writes the 500, which means panicked requests log as `200`.
|
||||||
|
|
||||||
|
```go
|
||||||
|
k.Use(
|
||||||
|
middleware.Logger(), // outermost — sees the final status
|
||||||
|
middleware.Recovery(), // catches panics from inner handlers
|
||||||
|
middleware.RequestID(),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
### Writing Custom Middleware
|
### Writing Custom Middleware
|
||||||
|
|
||||||
```go
|
```go
|
||||||
@@ -156,7 +212,7 @@ import "git.trcreatives.at/trcreatives/go-kite/middleware"
|
|||||||
|
|
||||||
### Logger
|
### Logger
|
||||||
|
|
||||||
Logs method, path, status code and latency for every request.
|
Logs method, path, status code and latency for every request. Logs from a `defer`, so panicked and errored requests are captured too (provided Logger is registered before Recovery).
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// simple
|
// simple
|
||||||
@@ -179,7 +235,7 @@ k.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
|
|||||||
Catches panics and returns a `500` instead of crashing the server.
|
Catches panics and returns a `500` instead of crashing the server.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// simple — always register this first
|
// simple
|
||||||
k.Use(middleware.Recovery())
|
k.Use(middleware.Recovery())
|
||||||
|
|
||||||
// custom — integrate with Sentry, Datadog, etc.
|
// custom — integrate with Sentry, Datadog, etc.
|
||||||
@@ -207,7 +263,7 @@ k.Use(middleware.RequestIDWithConfig(middleware.RequestIDConfig{
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
// retrieve in handler
|
// retrieve in handler
|
||||||
id := ctx.GetValue("X-Request-ID")
|
id := ctx.GetValue("X-Request-ID").(string)
|
||||||
```
|
```
|
||||||
|
|
||||||
### CORS
|
### CORS
|
||||||
@@ -215,11 +271,11 @@ id := ctx.GetValue("X-Request-ID")
|
|||||||
Sets Cross-Origin Resource Sharing headers and handles preflight requests.
|
Sets Cross-Origin Resource Sharing headers and handles preflight requests.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// allow all origins
|
// allow all origins (no credentials)
|
||||||
k.Use(middleware.CORS())
|
k.Use(middleware.CORS())
|
||||||
|
|
||||||
// restrict origins
|
// restrict to specific origins
|
||||||
k.Use(middleware.CORS("https://myapp.com"))
|
k.Use(middleware.CORS("https://myapp.com", "https://staging.myapp.com"))
|
||||||
|
|
||||||
// full control
|
// full control
|
||||||
k.Use(middleware.CORSWithConfig(middleware.CORSConfig{
|
k.Use(middleware.CORSWithConfig(middleware.CORSConfig{
|
||||||
@@ -232,7 +288,9 @@ k.Use(middleware.CORSWithConfig(middleware.CORSConfig{
|
|||||||
}))
|
}))
|
||||||
```
|
```
|
||||||
|
|
||||||
> **Note:** `AllowCredentials: true` cannot be combined with `AllowedOrigins: ["*"]` — browsers will reject it.
|
For multi-origin allowlists, the middleware checks the request's `Origin` against the list and echoes back only the matching one — `Access-Control-Allow-Origin` always contains a single origin, never a comma-separated list. A `Vary: Origin` header is added so caches behave correctly.
|
||||||
|
|
||||||
|
> **Note:** When `AllowCredentials: true` is combined with `AllowedOrigins: ["*"]`, the middleware echoes the request's `Origin` instead of sending `*`, since browsers reject that combination.
|
||||||
|
|
||||||
### MaxBodySize
|
### MaxBodySize
|
||||||
|
|
||||||
@@ -293,14 +351,15 @@ Every handler receives a `*kite.Context` with the following methods:
|
|||||||
### Response
|
### Response
|
||||||
|
|
||||||
| Method | Description |
|
| Method | Description |
|
||||||
| ----------------------------------------- | ------------------- |
|
| ----------------------------------------- | ---------------------------------------- |
|
||||||
| `WriteBytes(status int, v []byte) error` | Raw bytes response |
|
| `WriteBytes(status int, v []byte) error` | Raw bytes response |
|
||||||
| `WriteString(status int, v string) error` | Plain text response |
|
| `WriteString(status int, v string) error` | Plain text response (UTF-8) |
|
||||||
| `WriteJSON(status int, v any) error` | JSON response |
|
| `WriteJSON(status int, v any) error` | JSON response (UTF-8) |
|
||||||
| `WriteXML(status int, v any) error` | XML response |
|
| `WriteXML(status int, v any) error` | XML response (UTF-8) |
|
||||||
| `WriteNoContent() error` | 204 No Content |
|
| `WriteNoContent() error` | 204 No Content |
|
||||||
| `Redirect(status int, url string) error` | Redirect |
|
| `Redirect(status int, url string) error` | Redirect |
|
||||||
| `SetHeader(key, value string)` | Set response header |
|
| `SetHeader(key, value string)` | Set response header (replaces existing) |
|
||||||
|
| `AddHeader(key, value string)` | Append response header (for multi-value) |
|
||||||
| `SetCookie(cookie *http.Cookie)` | Set response cookie |
|
| `SetCookie(cookie *http.Cookie)` | Set response cookie |
|
||||||
|
|
||||||
## Error Handling
|
## Error Handling
|
||||||
|
|||||||
17
context.go
17
context.go
@@ -16,7 +16,6 @@ import (
|
|||||||
type contextKey string
|
type contextKey string
|
||||||
|
|
||||||
type Context struct {
|
type Context struct {
|
||||||
ctx context.Context
|
|
||||||
w *responseWriter
|
w *responseWriter
|
||||||
r *http.Request
|
r *http.Request
|
||||||
p httprouter.Params
|
p httprouter.Params
|
||||||
@@ -35,7 +34,7 @@ func (c *Context) GetStatusCode() int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetContext() context.Context {
|
func (c *Context) GetContext() context.Context {
|
||||||
return c.ctx
|
return c.r.Context()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetPathParam(key string) string {
|
func (c *Context) GetPathParam(key string) string {
|
||||||
@@ -47,11 +46,15 @@ func (c *Context) GetQueryParam(key string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) SetValue(key string, v any) {
|
func (c *Context) SetValue(key string, v any) {
|
||||||
c.ctx = context.WithValue(c.ctx, contextKey(key), v)
|
c.r = c.r.WithContext(context.WithValue(c.r.Context(), contextKey(key), v))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetValue(key string) any {
|
func (c *Context) GetValue(key string) any {
|
||||||
return c.ctx.Value(contextKey(key))
|
return c.r.Context().Value(contextKey(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) AddHeader(key, v string) {
|
||||||
|
c.w.Header().Add(key, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) SetHeader(key, v string) {
|
func (c *Context) SetHeader(key, v string) {
|
||||||
@@ -126,20 +129,20 @@ func (c *Context) WriteBytes(statusCode int, v []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) WriteString(statusCode int, v string) error {
|
func (c *Context) WriteString(statusCode int, v string) error {
|
||||||
c.w.Header().Set("Content-Type", "text/plain")
|
c.w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
c.w.WriteHeader(statusCode)
|
c.w.WriteHeader(statusCode)
|
||||||
_, err := c.w.Write([]byte(v))
|
_, err := c.w.Write([]byte(v))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) WriteJSON(statusCode int, v any) error {
|
func (c *Context) WriteJSON(statusCode int, v any) error {
|
||||||
c.w.Header().Set("Content-Type", "application/json")
|
c.w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
c.w.WriteHeader(statusCode)
|
c.w.WriteHeader(statusCode)
|
||||||
return json.NewEncoder(c.w).Encode(v)
|
return json.NewEncoder(c.w).Encode(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) WriteXML(statusCode int, v any) error {
|
func (c *Context) WriteXML(statusCode int, v any) error {
|
||||||
c.w.Header().Set("Content-Type", "application/xml")
|
c.w.Header().Set("Content-Type", "application/xml; charset=utf-8")
|
||||||
c.w.WriteHeader(statusCode)
|
c.w.WriteHeader(statusCode)
|
||||||
return xml.NewEncoder(c.w).Encode(v)
|
return xml.NewEncoder(c.w).Encode(v)
|
||||||
}
|
}
|
||||||
|
|||||||
22
group.go
22
group.go
@@ -8,19 +8,25 @@ type Group struct {
|
|||||||
mws []Middleware
|
mws []Middleware
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) Static(prefix, root string) {
|
func (g *Group) Static(prefix, root string, mws ...Middleware) {
|
||||||
g.StaticFS(prefix, http.Dir(root))
|
g.StaticFS(prefix, http.Dir(root), mws...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) StaticFS(prefix string, fs http.FileSystem) {
|
func (g *Group) StaticFS(prefix string, fs http.FileSystem, mws ...Middleware) {
|
||||||
if prefix == "" || prefix[len(prefix)-1] != '/' {
|
if prefix == "" || prefix[len(prefix)-1] != '/' {
|
||||||
prefix += "/"
|
prefix += "/"
|
||||||
}
|
}
|
||||||
fullPrefix := g.prefix + prefix
|
fullPrefix := g.prefix + prefix
|
||||||
|
|
||||||
handler := http.StripPrefix(fullPrefix, http.FileServer(fs))
|
handler := http.StripPrefix(fullPrefix, http.FileServer(fs))
|
||||||
g.k.r.Handler(http.MethodGet, fullPrefix+"*filepath", handler)
|
|
||||||
g.k.r.Handler(http.MethodHead, fullPrefix+"*filepath", handler)
|
h := func(ctx *Context) error {
|
||||||
|
handler.ServeHTTP(ctx.w, ctx.r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
g.k.handle(http.MethodGet, fullPrefix+"*filepath", h, mws...)
|
||||||
|
g.k.handle(http.MethodHead, fullPrefix+"*filepath", h, mws...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) Group(prefix string, mws ...Middleware) *Group {
|
func (g *Group) Group(prefix string, mws ...Middleware) *Group {
|
||||||
@@ -68,5 +74,9 @@ func (g *Group) TRACE(path string, h Handler, mws ...Middleware) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) handle(method, path string, h Handler, mws ...Middleware) {
|
func (g *Group) handle(method, path string, h Handler, mws ...Middleware) {
|
||||||
g.k.handle(method, g.prefix+path, h, append(g.mws, mws...)...)
|
combined := make([]Middleware, 0, len(g.mws)+len(mws))
|
||||||
|
combined = append(combined, g.mws...)
|
||||||
|
combined = append(combined, mws...)
|
||||||
|
|
||||||
|
g.k.handle(method, g.prefix+path, h, combined...)
|
||||||
}
|
}
|
||||||
|
|||||||
91
kite.go
91
kite.go
@@ -17,12 +17,16 @@ type Kite struct {
|
|||||||
errorHandler ErrorHandler
|
errorHandler ErrorHandler
|
||||||
mws []Middleware
|
mws []Middleware
|
||||||
|
|
||||||
|
customNotFound bool
|
||||||
|
|
||||||
serverReadTimeout time.Duration
|
serverReadTimeout time.Duration
|
||||||
serverWriteTimeout time.Duration
|
serverWriteTimeout time.Duration
|
||||||
serverIdleTimeout time.Duration
|
serverIdleTimeout time.Duration
|
||||||
serverShutdownTimeout time.Duration
|
serverShutdownTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type kiteContextKey struct{}
|
||||||
|
|
||||||
func New() *Kite {
|
func New() *Kite {
|
||||||
k := &Kite{
|
k := &Kite{
|
||||||
r: httprouter.New(),
|
r: httprouter.New(),
|
||||||
@@ -34,7 +38,7 @@ func New() *Kite {
|
|||||||
}
|
}
|
||||||
|
|
||||||
k.SetErrorHandler(defaultErrorHandler)
|
k.SetErrorHandler(defaultErrorHandler)
|
||||||
k.SetNotFoundHandler(defaultNotFoundHandler)
|
k.setNotFoundHandler(defaultNotFoundHandler)
|
||||||
k.SetMethodNotAllowedHandler(defaultMethodNotAllowedHandler)
|
k.SetMethodNotAllowedHandler(defaultMethodNotAllowedHandler)
|
||||||
|
|
||||||
return k
|
return k
|
||||||
@@ -44,17 +48,33 @@ func (k *Kite) SetErrorHandler(errorHandler ErrorHandler) {
|
|||||||
k.errorHandler = errorHandler
|
k.errorHandler = errorHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *Kite) SetNotFoundHandler(notFoundHandler NotFoundHandler) {
|
func (k *Kite) setNotFoundHandler(notFoundHandler NotFoundHandler) {
|
||||||
k.r.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
k.r.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := &Context{ctx: r.Context(), w: newResponseWriter(w), r: r}
|
ctx := r.Context().Value(kiteContextKey{}).(*Context)
|
||||||
notFoundHandler(ctx)
|
if err := notFoundHandler(ctx); err != nil {
|
||||||
|
if err := k.errorHandler(ctx, err); err != nil {
|
||||||
|
log.Printf("[Kite] Error handler failed: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (k *Kite) SetNotFoundHandler(notFoundHandler NotFoundHandler) {
|
||||||
|
if k.customNotFound {
|
||||||
|
log.Println("[Kite] SetNotFoundHandler is overriding a previously configured NotFound handler")
|
||||||
|
}
|
||||||
|
k.setNotFoundHandler(notFoundHandler)
|
||||||
|
k.customNotFound = true
|
||||||
|
}
|
||||||
|
|
||||||
func (k *Kite) SetMethodNotAllowedHandler(methodNotAllowedHandler MethodNotAllowedHandler) {
|
func (k *Kite) SetMethodNotAllowedHandler(methodNotAllowedHandler MethodNotAllowedHandler) {
|
||||||
k.r.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
k.r.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := &Context{ctx: r.Context(), w: newResponseWriter(w), r: r}
|
ctx := r.Context().Value(kiteContextKey{}).(*Context)
|
||||||
methodNotAllowedHandler(ctx)
|
if err := methodNotAllowedHandler(ctx); err != nil {
|
||||||
|
if err := k.errorHandler(ctx, err); err != nil {
|
||||||
|
log.Printf("[Kite] Error handler failed: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,7 +97,7 @@ func (k *Kite) SetServerShutdownTimeout(d time.Duration) {
|
|||||||
func (k *Kite) Start(listenAddr string) error {
|
func (k *Kite) Start(listenAddr string) error {
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: listenAddr,
|
Addr: listenAddr,
|
||||||
Handler: k.r,
|
Handler: k.buildHandler(),
|
||||||
|
|
||||||
ReadTimeout: k.serverReadTimeout,
|
ReadTimeout: k.serverReadTimeout,
|
||||||
WriteTimeout: k.serverWriteTimeout,
|
WriteTimeout: k.serverWriteTimeout,
|
||||||
@@ -105,7 +125,7 @@ func (k *Kite) Start(listenAddr string) error {
|
|||||||
func (k *Kite) StartTLS(listenAddr, certFile, keyFile string) error {
|
func (k *Kite) StartTLS(listenAddr, certFile, keyFile string) error {
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: listenAddr,
|
Addr: listenAddr,
|
||||||
Handler: k.r,
|
Handler: k.buildHandler(),
|
||||||
|
|
||||||
ReadTimeout: k.serverReadTimeout,
|
ReadTimeout: k.serverReadTimeout,
|
||||||
WriteTimeout: k.serverWriteTimeout,
|
WriteTimeout: k.serverWriteTimeout,
|
||||||
@@ -135,6 +155,11 @@ func (k *Kite) SPA(root string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (k *Kite) SPAFS(fs http.FileSystem) {
|
func (k *Kite) SPAFS(fs http.FileSystem) {
|
||||||
|
if k.customNotFound {
|
||||||
|
log.Println("[Kite] SPAFS is overriding a previously configured NotFound handler")
|
||||||
|
}
|
||||||
|
k.customNotFound = true
|
||||||
|
|
||||||
fileServer := http.FileServer(fs)
|
fileServer := http.FileServer(fs)
|
||||||
|
|
||||||
k.r.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
k.r.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -150,18 +175,24 @@ func (k *Kite) SPAFS(fs http.FileSystem) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *Kite) Static(prefix, root string) {
|
func (k *Kite) Static(prefix, root string, mws ...Middleware) {
|
||||||
k.StaticFS(prefix, http.Dir(root))
|
k.StaticFS(prefix, http.Dir(root), mws...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *Kite) StaticFS(prefix string, fs http.FileSystem) {
|
func (k *Kite) StaticFS(prefix string, fs http.FileSystem, mws ...Middleware) {
|
||||||
if prefix == "" || prefix[len(prefix)-1] != '/' {
|
if prefix == "" || prefix[len(prefix)-1] != '/' {
|
||||||
prefix += "/"
|
prefix += "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := http.StripPrefix(prefix, http.FileServer(fs))
|
handler := http.StripPrefix(prefix, http.FileServer(fs))
|
||||||
k.r.Handler(http.MethodGet, prefix+"*filepath", handler)
|
|
||||||
k.r.Handler(http.MethodHead, prefix+"*filepath", handler)
|
h := func(ctx *Context) error {
|
||||||
|
handler.ServeHTTP(ctx.w, ctx.r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
k.handle(http.MethodGet, prefix+"*filepath", h, mws...)
|
||||||
|
k.handle(http.MethodHead, prefix+"*filepath", h, mws...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *Kite) Group(prefix string, mws ...Middleware) *Group {
|
func (k *Kite) Group(prefix string, mws ...Middleware) *Group {
|
||||||
@@ -208,25 +239,41 @@ func (k *Kite) TRACE(path string, h Handler, mws ...Middleware) {
|
|||||||
k.handle(http.MethodTrace, path, h, mws...)
|
k.handle(http.MethodTrace, path, h, mws...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *Kite) handle(method, path string, h Handler, mws ...Middleware) {
|
func (k *Kite) buildHandler() http.Handler {
|
||||||
k.r.Handle(method, path, func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
inner := func(ctx *Context) error {
|
||||||
|
ctx.r = ctx.r.WithContext(context.WithValue(ctx.r.Context(), kiteContextKey{}, ctx))
|
||||||
|
k.r.ServeHTTP(ctx.w, ctx.r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapped := inner
|
||||||
|
for i := len(k.mws) - 1; i >= 0; i-- {
|
||||||
|
wrapped = k.mws[i](wrapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := &Context{
|
ctx := &Context{
|
||||||
ctx: r.Context(),
|
|
||||||
w: newResponseWriter(w),
|
w: newResponseWriter(w),
|
||||||
r: r,
|
r: r,
|
||||||
p: p,
|
|
||||||
}
|
}
|
||||||
|
if err := wrapped(ctx); err != nil {
|
||||||
|
if err := k.errorHandler(ctx, err); err != nil {
|
||||||
|
log.Printf("[Kite] Error handler failed: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *Kite) handle(method, path string, h Handler, mws ...Middleware) {
|
||||||
|
k.r.Handle(method, path, func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||||
|
ctx := r.Context().Value(kiteContextKey{}).(*Context)
|
||||||
|
ctx.p = p
|
||||||
|
|
||||||
wrappedHandler := h
|
wrappedHandler := h
|
||||||
|
|
||||||
for i := len(mws) - 1; i >= 0; i-- {
|
for i := len(mws) - 1; i >= 0; i-- {
|
||||||
wrappedHandler = mws[i](wrappedHandler)
|
wrappedHandler = mws[i](wrappedHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := len(k.mws) - 1; i >= 0; i-- {
|
|
||||||
wrappedHandler = k.mws[i](wrappedHandler)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := wrappedHandler(ctx); err != nil {
|
if err := wrappedHandler(ctx); err != nil {
|
||||||
if err := k.errorHandler(ctx, err); err != nil {
|
if err := k.errorHandler(ctx, err); err != nil {
|
||||||
log.Printf("[Kite] Error handler failed: %v\n", err)
|
log.Printf("[Kite] Error handler failed: %v\n", err)
|
||||||
|
|||||||
@@ -34,14 +34,37 @@ func CORSWithConfig(cfg CORSConfig) kite.Middleware {
|
|||||||
cfg.AllowedHeaders = []string{"Content-Type", "Authorization"}
|
cfg.AllowedHeaders = []string{"Content-Type", "Authorization"}
|
||||||
}
|
}
|
||||||
|
|
||||||
origins := strings.Join(cfg.AllowedOrigins, ", ")
|
allowAll := len(cfg.AllowedOrigins) == 1 && cfg.AllowedOrigins[0] == "*"
|
||||||
methods := strings.Join(cfg.AllowedMethods, ", ")
|
methods := strings.Join(cfg.AllowedMethods, ", ")
|
||||||
headers := strings.Join(cfg.AllowedHeaders, ", ")
|
headers := strings.Join(cfg.AllowedHeaders, ", ")
|
||||||
exposed := strings.Join(cfg.ExposedHeaders, ", ")
|
exposed := strings.Join(cfg.ExposedHeaders, ", ")
|
||||||
|
|
||||||
|
originSet := make(map[string]struct{}, len(cfg.AllowedOrigins))
|
||||||
|
for _, o := range cfg.AllowedOrigins {
|
||||||
|
originSet[o] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
return func(h kite.Handler) kite.Handler {
|
return func(h kite.Handler) kite.Handler {
|
||||||
return func(ctx *kite.Context) error {
|
return func(ctx *kite.Context) error {
|
||||||
ctx.SetHeader("Access-Control-Allow-Origin", origins)
|
origin := ctx.GetHeader("Origin")
|
||||||
|
|
||||||
|
allowOrigin := ""
|
||||||
|
switch {
|
||||||
|
case allowAll && !cfg.AllowCredentials:
|
||||||
|
allowOrigin = "*"
|
||||||
|
case allowAll && cfg.AllowCredentials:
|
||||||
|
allowOrigin = origin
|
||||||
|
default:
|
||||||
|
if _, ok := originSet[origin]; ok {
|
||||||
|
allowOrigin = origin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if allowOrigin != "" {
|
||||||
|
ctx.SetHeader("Access-Control-Allow-Origin", allowOrigin)
|
||||||
|
ctx.AddHeader("Vary", "Origin")
|
||||||
|
}
|
||||||
|
|
||||||
ctx.SetHeader("Access-Control-Allow-Methods", methods)
|
ctx.SetHeader("Access-Control-Allow-Methods", methods)
|
||||||
ctx.SetHeader("Access-Control-Allow-Headers", headers)
|
ctx.SetHeader("Access-Control-Allow-Headers", headers)
|
||||||
|
|
||||||
|
|||||||
@@ -42,10 +42,10 @@ func LoggerWithConfig(cfg LoggerConfig) kite.Middleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
err := h(ctx)
|
defer func() {
|
||||||
logger.Println(cfg.Format(ctx.GetMethod(), ctx.GetPath(), ctx.GetStatusCode(), time.Since(start)))
|
logger.Println(cfg.Format(ctx.GetMethod(), ctx.GetPath(), ctx.GetStatusCode(), time.Since(start)))
|
||||||
|
}()
|
||||||
return err
|
return h(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package kite
|
package kite
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
type responseWriter struct {
|
type responseWriter struct {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
@@ -31,3 +36,24 @@ func (rw *responseWriter) Write(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
return rw.ResponseWriter.Write(b)
|
return rw.ResponseWriter.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
h, ok := rw.ResponseWriter.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, fmt.Errorf("underlying ResponseWriter does not implement http.Hijacker")
|
||||||
|
}
|
||||||
|
return h.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *responseWriter) Flush() {
|
||||||
|
if f, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
|
f.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *responseWriter) Push(target string, opts *http.PushOptions) error {
|
||||||
|
if p, ok := rw.ResponseWriter.(http.Pusher); ok {
|
||||||
|
return p.Push(target, opts)
|
||||||
|
}
|
||||||
|
return http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user