package kite import ( "context" "log" "net/http" "os" "os/signal" "syscall" "time" "github.com/julienschmidt/httprouter" ) type Kite struct { r *httprouter.Router errorHandler ErrorHandler mws []Middleware customNotFound bool serverReadTimeout time.Duration serverWriteTimeout time.Duration serverIdleTimeout time.Duration serverShutdownTimeout time.Duration } type kiteContextKey struct{} func New() *Kite { k := &Kite{ r: httprouter.New(), serverReadTimeout: 10 * time.Second, serverWriteTimeout: 30 * time.Second, serverIdleTimeout: 60 * time.Second, serverShutdownTimeout: 30 * time.Second, } k.SetErrorHandler(defaultErrorHandler) k.setNotFoundHandler(defaultNotFoundHandler) k.SetMethodNotAllowedHandler(defaultMethodNotAllowedHandler) return k } func (k *Kite) SetErrorHandler(errorHandler ErrorHandler) { k.errorHandler = errorHandler } func (k *Kite) setNotFoundHandler(notFoundHandler NotFoundHandler) { k.r.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context().Value(kiteContextKey{}).(*Context) if err := notFoundHandler(ctx); err != nil { if err := k.errorHandler(ctx, err); err != nil { log.Printf("[Kite] Error handler failed: %v\n", err) } } }) } func (k *Kite) SetNotFoundHandler(notFoundHandler NotFoundHandler) { if k.customNotFound { log.Println("[Kite] SetNotFoundHandler is overriding a previously configured NotFound handler") } k.setNotFoundHandler(notFoundHandler) k.customNotFound = true } func (k *Kite) SetMethodNotAllowedHandler(methodNotAllowedHandler MethodNotAllowedHandler) { k.r.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context().Value(kiteContextKey{}).(*Context) if err := methodNotAllowedHandler(ctx); err != nil { if err := k.errorHandler(ctx, err); err != nil { log.Printf("[Kite] Error handler failed: %v\n", err) } } }) } func (k *Kite) SetServerReadTimeout(d time.Duration) { k.serverReadTimeout = d } func (k *Kite) SetServerWriteTimeout(d time.Duration) { k.serverWriteTimeout = d } func (k *Kite) SetServerIdleTimeout(d time.Duration) { k.serverIdleTimeout = d } func (k *Kite) SetServerShutdownTimeout(d time.Duration) { k.serverShutdownTimeout = d } func (k *Kite) Start(listenAddr string) error { srv := &http.Server{ Addr: listenAddr, Handler: k.buildHandler(), ReadTimeout: k.serverReadTimeout, WriteTimeout: k.serverWriteTimeout, IdleTimeout: k.serverIdleTimeout, } go func() { if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Printf("[Kite] Server error: %v\n", err) } }() log.Printf("[Kite] Server started and listening on http://%s\n", listenAddr) quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit log.Println("[Kite] Shutting down...") ctx, cancel := context.WithTimeout(context.Background(), k.serverShutdownTimeout) defer cancel() return srv.Shutdown(ctx) } func (k *Kite) StartTLS(listenAddr, certFile, keyFile string) error { srv := &http.Server{ Addr: listenAddr, Handler: k.buildHandler(), ReadTimeout: k.serverReadTimeout, WriteTimeout: k.serverWriteTimeout, IdleTimeout: k.serverIdleTimeout, } go func() { if err := srv.ListenAndServeTLS(certFile, keyFile); err != nil && err != http.ErrServerClosed { log.Printf("[Kite] Server error: %v\n", err) } }() log.Printf("[Kite] Server started and listening on https://%s\n", listenAddr) quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit log.Println("[Kite] Shutting down...") ctx, cancel := context.WithTimeout(context.Background(), k.serverShutdownTimeout) defer cancel() return srv.Shutdown(ctx) } func (k *Kite) SPA(root string) { k.SPAFS(http.Dir(root)) } func (k *Kite) SPAFS(fs http.FileSystem) { if k.customNotFound { log.Println("[Kite] SPAFS is overriding a previously configured NotFound handler") } k.customNotFound = true fileServer := http.FileServer(fs) k.r.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { f, err := fs.Open(r.URL.Path) if err != nil { r.URL.Path = "/" fileServer.ServeHTTP(w, r) return } f.Close() fileServer.ServeHTTP(w, r) }) } func (k *Kite) Static(prefix, root string, mws ...Middleware) { k.StaticFS(prefix, http.Dir(root), mws...) } func (k *Kite) StaticFS(prefix string, fs http.FileSystem, mws ...Middleware) { if prefix == "" || prefix[len(prefix)-1] != '/' { prefix += "/" } handler := http.StripPrefix(prefix, http.FileServer(fs)) h := func(ctx *Context) error { handler.ServeHTTP(ctx.w, ctx.r) return nil } k.handle(http.MethodGet, prefix+"*filepath", h, mws...) k.handle(http.MethodHead, prefix+"*filepath", h, mws...) } func (k *Kite) Group(prefix string, mws ...Middleware) *Group { return &Group{ k: k, prefix: prefix, mws: mws, } } func (k *Kite) Use(mws ...Middleware) { k.mws = append(k.mws, mws...) } func (k *Kite) CONNECT(path string, h Handler, mws ...Middleware) { k.handle(http.MethodConnect, path, h, mws...) } func (k *Kite) DELETE(path string, h Handler, mws ...Middleware) { k.handle(http.MethodDelete, path, h, mws...) } func (k *Kite) GET(path string, h Handler, mws ...Middleware) { k.handle(http.MethodGet, path, h, mws...) } func (k *Kite) HEAD(path string, h Handler, mws ...Middleware) { k.handle(http.MethodHead, path, h, mws...) } func (k *Kite) OPTIONS(path string, h Handler, mws ...Middleware) { k.handle(http.MethodOptions, path, h, mws...) } func (k *Kite) PATCH(path string, h Handler, mws ...Middleware) { k.handle(http.MethodPatch, path, h, mws...) } func (k *Kite) POST(path string, h Handler, mws ...Middleware) { k.handle(http.MethodPost, path, h, mws...) } func (k *Kite) PUT(path string, h Handler, mws ...Middleware) { k.handle(http.MethodPut, path, h, mws...) } func (k *Kite) TRACE(path string, h Handler, mws ...Middleware) { k.handle(http.MethodTrace, path, h, mws...) } func (k *Kite) buildHandler() http.Handler { inner := func(ctx *Context) error { ctx.r = ctx.r.WithContext(context.WithValue(ctx.r.Context(), kiteContextKey{}, ctx)) k.r.ServeHTTP(ctx.w, ctx.r) return nil } wrapped := inner for i := len(k.mws) - 1; i >= 0; i-- { wrapped = k.mws[i](wrapped) } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := &Context{ w: newResponseWriter(w), r: r, } if err := wrapped(ctx); err != nil { if err := k.errorHandler(ctx, err); err != nil { log.Printf("[Kite] Error handler failed: %v\n", err) } } }) } func (k *Kite) handle(method, path string, h Handler, mws ...Middleware) { k.r.Handle(method, path, func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { ctx := r.Context().Value(kiteContextKey{}).(*Context) ctx.p = p wrappedHandler := h for i := len(mws) - 1; i >= 0; i-- { wrappedHandler = mws[i](wrappedHandler) } if err := wrappedHandler(ctx); err != nil { if err := k.errorHandler(ctx, err); err != nil { log.Printf("[Kite] Error handler failed: %v\n", err) } } }) }