This commit is contained in:
Sky Johnson 2025-04-26 17:38:17 -05:00
parent 9dbfbee993
commit 5be8eac6d8
2 changed files with 204 additions and 328 deletions

428
router.go
View File

@ -3,58 +3,35 @@ package router
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"slices"
) )
// Res is an alias for http.ResponseWriter for shorter, cleaner code
type Res = http.ResponseWriter type Res = http.ResponseWriter
// Req is an alias for *http.Request for shorter, cleaner code
type Req = *http.Request type Req = *http.Request
// Handler is an interface for handling HTTP requests with path parameters. // Handler is a request handler with parameters.
type Handler interface { type Handler func(w Res, r Req, params []string)
Serve(params []string)
func (h Handler) Serve(w Res, r Req, params []string) {
h(w, r, params)
} }
// SimpleHandler implements the Handler interface
type SimpleHandler struct {
fn func(params []string)
}
func (h *SimpleHandler) Serve(params []string) {
h.fn(params)
}
// NewHandler creates a Handler from a function
func NewHandler(fn func(params []string)) Handler {
return &SimpleHandler{fn: fn}
}
// Middleware wraps a handler with additional functionality.
type Middleware func(Handler) Handler type Middleware func(Handler) Handler
// node represents a segment in the URL path and its handling logic.
type node struct { type node struct {
segment string // the path segment this node matches segment string
handler Handler // handler for this path, if it's an endpoint handler Handler
children []*node // child nodes for subsequent path segments children []*node
isDynamic bool // true for param segments like [id] isDynamic bool
isWildcard bool // true for catch-all segments like *filepath isWildcard bool
maxParams uint8 // maximum number of parameters in paths under this node maxParams uint8
} }
// Router routes HTTP requests by method and path.
// It supports static paths, path parameters, wildcards, and middleware.
type Router struct { type Router struct {
get *node get, post, put, patch, delete *node
post *node middleware []Middleware
put *node
patch *node
delete *node
middleware []Middleware // Global middleware
} }
// Group represents a route group with a path prefix and shared middleware.
type Group struct { type Group struct {
router *Router router *Router
prefix string prefix string
@ -73,90 +50,55 @@ func New() *Router {
} }
} }
// ServeHTTP implements http.Handler interface // ServeHTTP implements http.Handler.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
handler, params, found := r.Lookup(req.Method, req.URL.Path) h, params, ok := r.Lookup(req.Method, req.URL.Path)
if !found { if !ok {
http.NotFound(w, req) http.NotFound(w, req)
return return
} }
h(w, req, params)
// Create an HTTP-specific handler wrapper
httpHandler := &httpHandler{
w: w,
r: req,
h: func(w http.ResponseWriter, r *http.Request, params []string) {
handler.Serve(params)
},
}
httpHandler.Serve(params)
} }
// httpHandler adapts net/http handlers to the router. // Use adds middleware to the router.
type httpHandler struct { func (r *Router) Use(mw ...Middleware) *Router {
w Res r.middleware = append(r.middleware, mw...)
r Req
h func(w Res, r Req, params []string)
}
// Serve executes the http handler with parameters.
func (h *httpHandler) Serve(params []string) {
h.h(h.w, h.r, params)
}
// Use adds middleware to the router's global middleware stack.
func (r *Router) Use(middleware ...Middleware) *Router {
r.middleware = append(r.middleware, middleware...)
return r return r
} }
// Group creates a new route group with the given path prefix. // Group creates a new route group.
func (r *Router) Group(prefix string) *Group { func (r *Router) Group(prefix string) *Group {
return &Group{ return &Group{router: r, prefix: prefix, middleware: []Middleware{}}
router: r,
prefix: prefix,
middleware: []Middleware{},
}
} }
// Use adds middleware to the group's middleware stack. // Use adds middleware to the group.
func (g *Group) Use(middleware ...Middleware) *Group { func (g *Group) Use(mw ...Middleware) *Group {
g.middleware = append(g.middleware, middleware...) g.middleware = append(g.middleware, mw...)
return g return g
} }
// Group creates a nested group with an additional prefix. // Group creates a nested group.
func (g *Group) Group(prefix string) *Group { func (g *Group) Group(prefix string) *Group {
return &Group{ return &Group{router: g.router, prefix: g.prefix + prefix, middleware: slices.Clone(g.middleware)}
router: g.router,
prefix: g.prefix + prefix,
middleware: append([]Middleware{}, g.middleware...),
}
} }
// applyMiddleware wraps a handler with middleware in reverse order. // applyMiddleware applies middleware in reverse order.
func applyMiddleware(handler Handler, middleware []Middleware) Handler { func applyMiddleware(h Handler, mw []Middleware) Handler {
h := handler for i := len(mw) - 1; i >= 0; i-- {
for i := len(middleware) - 1; i >= 0; i-- { h = mw[i](h)
h = middleware[i](h)
} }
return h return h
} }
// HandlerFunc is a function that handles HTTP requests with parameters.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, params []string)
// Handle registers a handler for the given method and path. // Handle registers a handler for the given method and path.
func (r *Router) Handle(method, path string, handler HandlerFunc) error { func (r *Router) Handle(method, path string, h Handler) error {
root := r.methodNode(method) root := r.methodNode(method)
if root == nil { if root == nil {
return fmt.Errorf("unsupported method: %s", method) return fmt.Errorf("unsupported method: %s", method)
} }
return r.addRoute(root, path, &httpHandler{h: handler}, r.middleware) return r.addRoute(root, path, h, r.middleware)
} }
// methodNode returns the root node for the given HTTP method.
func (r *Router) methodNode(method string) *node { func (r *Router) methodNode(method string) *node {
switch method { switch method {
case "GET": case "GET":
@ -174,292 +116,244 @@ func (r *Router) methodNode(method string) *node {
} }
} }
// Get registers a handler for GET requests at the given path. // Get registers a GET handler.
func (r *Router) Get(path string, handler HandlerFunc) error { func (r *Router) Get(path string, h Handler) error {
return r.Handle("GET", path, handler) return r.Handle("GET", path, h)
} }
// Post registers a handler for POST requests at the given path. // Post registers a POST handler.
func (r *Router) Post(path string, handler HandlerFunc) error { func (r *Router) Post(path string, h Handler) error {
return r.Handle("POST", path, handler) return r.Handle("POST", path, h)
} }
// Put registers a handler for PUT requests at the given path. // Put registers a PUT handler.
func (r *Router) Put(path string, handler HandlerFunc) error { func (r *Router) Put(path string, h Handler) error {
return r.Handle("PUT", path, handler) return r.Handle("PUT", path, h)
} }
// Patch registers a handler for PATCH requests at the given path. // Patch registers a PATCH handler.
func (r *Router) Patch(path string, handler HandlerFunc) error { func (r *Router) Patch(path string, h Handler) error {
return r.Handle("PATCH", path, handler) return r.Handle("PATCH", path, h)
} }
// Delete registers a handler for DELETE requests at the given path. // Delete registers a DELETE handler.
func (r *Router) Delete(path string, handler HandlerFunc) error { func (r *Router) Delete(path string, h Handler) error {
return r.Handle("DELETE", path, handler) return r.Handle("DELETE", path, h)
} }
// buildGroupMiddleware returns combined middleware for the group
func (g *Group) buildGroupMiddleware() []Middleware { func (g *Group) buildGroupMiddleware() []Middleware {
middleware := append([]Middleware{}, g.router.middleware...) mw := slices.Clone(g.router.middleware)
return append(middleware, g.middleware...) return append(mw, g.middleware...)
} }
// Handle registers a handler for the given method and path. // Handle registers a handler in the group.
func (g *Group) Handle(method, path string, handler HandlerFunc) error { func (g *Group) Handle(method, path string, h Handler) error {
root := g.router.methodNode(method) root := g.router.methodNode(method)
if root == nil { if root == nil {
return fmt.Errorf("unsupported method: %s", method) return fmt.Errorf("unsupported method: %s", method)
} }
return g.router.addRoute(root, g.prefix+path, h, g.buildGroupMiddleware())
fullPath := g.prefix + path
return g.router.addRoute(root, fullPath, &httpHandler{h: handler}, g.buildGroupMiddleware())
} }
// Get registers a handler for GET requests at the given path. // Get registers a GET handler in the group.
func (g *Group) Get(path string, handler HandlerFunc) error { func (g *Group) Get(path string, h Handler) error {
return g.Handle("GET", path, handler) return g.Handle("GET", path, h)
} }
// Post registers a handler for POST requests at the given path. // Post registers a POST handler in the group.
func (g *Group) Post(path string, handler HandlerFunc) error { func (g *Group) Post(path string, h Handler) error {
return g.Handle("POST", path, handler) return g.Handle("POST", path, h)
} }
// Put registers a handler for PUT requests at the given path. // Put registers a PUT handler in the group.
func (g *Group) Put(path string, handler HandlerFunc) error { func (g *Group) Put(path string, h Handler) error {
return g.Handle("PUT", path, handler) return g.Handle("PUT", path, h)
} }
// Patch registers a handler for PATCH requests at the given path. // Patch registers a PATCH handler in the group.
func (g *Group) Patch(path string, handler HandlerFunc) error { func (g *Group) Patch(path string, h Handler) error {
return g.Handle("PATCH", path, handler) return g.Handle("PATCH", path, h)
} }
// Delete registers a handler for DELETE requests at the given path. // Delete registers a DELETE handler in the group.
func (g *Group) Delete(path string, handler HandlerFunc) error { func (g *Group) Delete(path string, h Handler) error {
return g.Handle("DELETE", path, handler) return g.Handle("DELETE", path, h)
} }
// WithMiddleware applies specific middleware to the next route registration. // WithMiddleware applies specific middleware for next registration.
func (r *Router) WithMiddleware(middleware ...Middleware) *MiddlewareRouter { func (r *Router) WithMiddleware(mw ...Middleware) *MiddlewareRouter {
return &MiddlewareRouter{ return &MiddlewareRouter{router: r, middleware: mw}
router: r,
middleware: middleware,
}
} }
// WithMiddleware applies specific middleware to the next route registration. // WithMiddleware applies specific middleware for next group route.
func (g *Group) WithMiddleware(middleware ...Middleware) *MiddlewareGroup { func (g *Group) WithMiddleware(mw ...Middleware) *MiddlewareGroup {
return &MiddlewareGroup{ return &MiddlewareGroup{group: g, middleware: mw}
group: g,
middleware: middleware,
}
} }
// MiddlewareRouter handles route registration with specific middleware.
type MiddlewareRouter struct { type MiddlewareRouter struct {
router *Router router *Router
middleware []Middleware middleware []Middleware
} }
// MiddlewareGroup handles group route registration with specific middleware.
type MiddlewareGroup struct { type MiddlewareGroup struct {
group *Group group *Group
middleware []Middleware middleware []Middleware
} }
// buildMiddleware returns combined middleware for the middleware router
func (mr *MiddlewareRouter) buildMiddleware() []Middleware { func (mr *MiddlewareRouter) buildMiddleware() []Middleware {
middleware := append([]Middleware{}, mr.router.middleware...) mw := slices.Clone(mr.router.middleware)
return append(middleware, mr.middleware...) return append(mw, mr.middleware...)
} }
// Handle registers a handler for the given method and path. // Handle registers a handler with middleware router.
func (mr *MiddlewareRouter) Handle(method, path string, handler HandlerFunc) error { func (mr *MiddlewareRouter) Handle(method, path string, h Handler) error {
root := mr.router.methodNode(method) root := mr.router.methodNode(method)
if root == nil { if root == nil {
return fmt.Errorf("unsupported method: %s", method) return fmt.Errorf("unsupported method: %s", method)
} }
return mr.router.addRoute(root, path, h, mr.buildMiddleware())
return mr.router.addRoute(root, path, &httpHandler{h: handler}, mr.buildMiddleware())
} }
// Get registers a handler for GET requests with specific middleware. // Get registers a GET handler with middleware router.
func (mr *MiddlewareRouter) Get(path string, handler HandlerFunc) error { func (mr *MiddlewareRouter) Get(path string, h Handler) error {
return mr.Handle("GET", path, handler) return mr.Handle("GET", path, h)
} }
// Post registers a handler for POST requests with specific middleware. // Post registers a POST handler with middleware router.
func (mr *MiddlewareRouter) Post(path string, handler HandlerFunc) error { func (mr *MiddlewareRouter) Post(path string, h Handler) error {
return mr.Handle("POST", path, handler) return mr.Handle("POST", path, h)
} }
// Put registers a handler for PUT requests with specific middleware. // Put registers a PUT handler with middleware router.
func (mr *MiddlewareRouter) Put(path string, handler HandlerFunc) error { func (mr *MiddlewareRouter) Put(path string, h Handler) error {
return mr.Handle("PUT", path, handler) return mr.Handle("PUT", path, h)
} }
// Patch registers a handler for PATCH requests with specific middleware. // Patch registers a PATCH handler with middleware router.
func (mr *MiddlewareRouter) Patch(path string, handler HandlerFunc) error { func (mr *MiddlewareRouter) Patch(path string, h Handler) error {
return mr.Handle("PATCH", path, handler) return mr.Handle("PATCH", path, h)
} }
// Delete registers a handler for DELETE requests with specific middleware. // Delete registers a DELETE handler with middleware router.
func (mr *MiddlewareRouter) Delete(path string, handler HandlerFunc) error { func (mr *MiddlewareRouter) Delete(path string, h Handler) error {
return mr.Handle("DELETE", path, handler) return mr.Handle("DELETE", path, h)
} }
// buildMiddleware returns combined middleware for the middleware group
func (mg *MiddlewareGroup) buildMiddleware() []Middleware { func (mg *MiddlewareGroup) buildMiddleware() []Middleware {
middleware := append([]Middleware{}, mg.group.router.middleware...) mw := slices.Clone(mg.group.router.middleware)
middleware = append(middleware, mg.group.middleware...) mw = append(mw, mg.group.middleware...)
return append(middleware, mg.middleware...) return append(mw, mg.middleware...)
} }
// Handle registers a handler for the given method and path. // Handle registers a handler with middleware group.
func (mg *MiddlewareGroup) Handle(method, path string, handler HandlerFunc) error { func (mg *MiddlewareGroup) Handle(method, path string, h Handler) error {
root := mg.group.router.methodNode(method) root := mg.group.router.methodNode(method)
if root == nil { if root == nil {
return fmt.Errorf("unsupported method: %s", method) return fmt.Errorf("unsupported method: %s", method)
} }
return mg.group.router.addRoute(root, mg.group.prefix+path, h, mg.buildMiddleware())
fullPath := mg.group.prefix + path
return mg.group.router.addRoute(root, fullPath, &httpHandler{h: handler}, mg.buildMiddleware())
} }
// Get registers a handler for GET requests with specific middleware. // Get registers a GET handler with middleware group.
func (mg *MiddlewareGroup) Get(path string, handler HandlerFunc) error { func (mg *MiddlewareGroup) Get(path string, h Handler) error {
return mg.Handle("GET", path, handler) return mg.Handle("GET", path, h)
} }
// Post registers a handler for POST requests with specific middleware. // Post registers a POST handler with middleware group.
func (mg *MiddlewareGroup) Post(path string, handler HandlerFunc) error { func (mg *MiddlewareGroup) Post(path string, h Handler) error {
return mg.Handle("POST", path, handler) return mg.Handle("POST", path, h)
} }
// Put registers a handler for PUT requests with specific middleware. // Put registers a PUT handler with middleware group.
func (mg *MiddlewareGroup) Put(path string, handler HandlerFunc) error { func (mg *MiddlewareGroup) Put(path string, h Handler) error {
return mg.Handle("PUT", path, handler) return mg.Handle("PUT", path, h)
} }
// Patch registers a handler for PATCH requests with specific middleware. // Patch registers a PATCH handler with middleware group.
func (mg *MiddlewareGroup) Patch(path string, handler HandlerFunc) error { func (mg *MiddlewareGroup) Patch(path string, h Handler) error {
return mg.Handle("PATCH", path, handler) return mg.Handle("PATCH", path, h)
} }
// Delete registers a handler for DELETE requests with specific middleware. // Delete registers a DELETE handler with middleware group.
func (mg *MiddlewareGroup) Delete(path string, handler HandlerFunc) error { func (mg *MiddlewareGroup) Delete(path string, h Handler) error {
return mg.Handle("DELETE", path, handler) return mg.Handle("DELETE", path, h)
} }
// Adapter for standard http.HandlerFunc // readSegment extracts the next path segment.
func StandardHandler(handler http.HandlerFunc) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, _ []string) {
handler(w, r)
}
}
// readSegment extracts the next path segment starting at the given position.
// Returns the segment, the position after it, and whether there are more segments.
func readSegment(path string, start int) (segment string, end int, hasMore bool) { func readSegment(path string, start int) (segment string, end int, hasMore bool) {
if start >= len(path) { if start >= len(path) {
return "", start, false return "", start, false
} }
if path[start] == '/' { if path[start] == '/' {
start++ start++
} }
if start >= len(path) { if start >= len(path) {
return "", start, false return "", start, false
} }
end = start end = start
for end < len(path) && path[end] != '/' { for end < len(path) && path[end] != '/' {
end++ end++
} }
return path[start:end], end, end < len(path) return path[start:end], end, end < len(path)
} }
// addRoute adds a new route to the prefix tree with middleware. // addRoute adds a new route to the trie.
func (r *Router) addRoute(root *node, path string, handler Handler, middleware []Middleware) error { func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) error {
wrappedHandler := applyMiddleware(handler, middleware) h = applyMiddleware(h, mw)
if path == "/" { if path == "/" {
root.handler = wrappedHandler root.handler = h
return nil return nil
} }
current := root current := root
pos := 0 pos := 0
var lastWildcard bool lastWC := false
paramsCount := uint8(0) count := uint8(0)
for { for {
segment, newPos, hasMore := readSegment(path, pos) seg, newPos, more := readSegment(path, pos)
if segment == "" { if seg == "" {
break break
} }
isDyn := len(seg) > 2 && seg[0] == '[' && seg[len(seg)-1] == ']'
isDynamic := len(segment) > 2 && segment[0] == '[' && segment[len(segment)-1] == ']' isWC := len(seg) > 0 && seg[0] == '*'
isWildcard := len(segment) > 0 && segment[0] == '*' if isWC {
if lastWC || more {
if isWildcard {
if lastWildcard {
return fmt.Errorf("wildcard must be the last segment in the path") return fmt.Errorf("wildcard must be the last segment in the path")
} }
if hasMore { lastWC = true
return fmt.Errorf("wildcard must be the last segment in the path")
}
lastWildcard = true
} }
if isDyn || isWC {
if isDynamic || isWildcard { count++
paramsCount++
} }
var child *node var child *node
for _, n := range current.children { for _, c := range current.children {
if n.segment == segment { if c.segment == seg {
child = n child = c
break break
} }
} }
if child == nil { if child == nil {
child = &node{ child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC}
segment: segment,
isDynamic: isDynamic,
isWildcard: isWildcard,
}
current.children = append(current.children, child) current.children = append(current.children, child)
} }
if child.maxParams < count {
if child.maxParams < paramsCount { child.maxParams = count
child.maxParams = paramsCount
} }
current = child current = child
pos = newPos pos = newPos
} }
current.handler = h
current.handler = wrappedHandler
return nil return nil
} }
// Lookup finds a handler matching the given method and path. // Lookup finds a handler matching method and path.
// Returns the handler, any captured parameters, and whether a match was found.
func (r *Router) Lookup(method, path string) (Handler, []string, bool) { func (r *Router) Lookup(method, path string) (Handler, []string, bool) {
root := r.methodNode(method) root := r.methodNode(method)
if root == nil { if root == nil {
return nil, nil, false return nil, nil, false
} }
if path == "/" { if path == "/" {
return root.handler, []string{}, root.handler != nil return root.handler, []string{}, root.handler != nil
} }
params := make([]string, 0, root.maxParams) params := make([]string, 0, root.maxParams)
h, found := match(root, path, 0, &params) h, found := match(root, path, 0, &params)
if !found { if !found {
@ -468,41 +362,35 @@ func (r *Router) Lookup(method, path string) (Handler, []string, bool) {
return h, params, true return h, params, true
} }
// match recursively traverses the prefix tree to find a matching handler. // match traverses the trie to find a handler.
// It populates params with any captured path parameters or wildcard matches.
func match(current *node, path string, start int, params *[]string) (Handler, bool) { func match(current *node, path string, start int, params *[]string) (Handler, bool) {
// Check for wildcard children first for _, c := range current.children {
for _, child := range current.children { if c.isWildcard {
if child.isWildcard { rem := path[start:]
remaining := path[start:] if len(rem) > 0 && rem[0] == '/' {
if len(remaining) > 0 && remaining[0] == '/' { rem = rem[1:]
remaining = remaining[1:]
} }
*params = append(*params, remaining) *params = append(*params, rem)
return child.handler, child.handler != nil return c.handler, c.handler != nil
} }
} }
seg, pos, more := readSegment(path, start)
// Read current segment if seg == "" {
segment, pos, hasMore := readSegment(path, start)
if segment == "" {
return current.handler, current.handler != nil return current.handler, current.handler != nil
} }
for _, c := range current.children {
// Try to match children if c.segment == seg || c.isDynamic {
for _, child := range current.children { if c.isDynamic {
if child.segment == segment || child.isDynamic { *params = append(*params, seg)
if child.isDynamic {
*params = append(*params, segment)
} }
if !hasMore { if !more {
return child.handler, child.handler != nil return c.handler, c.handler != nil
} }
if h, found := match(child, path, pos, params); found { h, ok := match(c, path, pos, params)
if ok {
return h, true return h, true
} }
} }
} }
return nil, false return nil, false
} }

View File

@ -10,24 +10,18 @@ import (
func TestRootPath(t *testing.T) { func TestRootPath(t *testing.T) {
r := New() r := New()
r.Get("/", func(w Res, r Req, params []string) { r.Get("/", func(w Res, r Req, params []string) {})
// No-op for testing
})
h, params, found := r.Lookup("GET", "/") _, _, found := r.Lookup("GET", "/")
assert.True(t, found) assert.True(t, found)
h.Serve(params)
} }
func TestStaticPath(t *testing.T) { func TestStaticPath(t *testing.T) {
r := New() r := New()
r.Get("/users/all", func(w Res, r Req, params []string) { r.Get("/users/all", func(w Res, r Req, params []string) {})
// No-op for testing
})
h, params, found := r.Lookup("GET", "/users/all") _, _, found := r.Lookup("GET", "/users/all")
assert.True(t, found) assert.True(t, found)
h.Serve(params)
} }
func TestSingleParameter(t *testing.T) { func TestSingleParameter(t *testing.T) {
@ -41,7 +35,7 @@ func TestSingleParameter(t *testing.T) {
h, params, found := r.Lookup("GET", "/users/123") h, params, found := r.Lookup("GET", "/users/123")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
assert.True(t, called) assert.True(t, called)
} }
@ -57,15 +51,13 @@ func TestMultipleParameters(t *testing.T) {
h, params, found := r.Lookup("GET", "/users/123/posts/456") h, params, found := r.Lookup("GET", "/users/123/posts/456")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
assert.True(t, called) assert.True(t, called)
} }
func TestNonExistentPath(t *testing.T) { func TestNonExistentPath(t *testing.T) {
r := New() r := New()
r.Get("/users/[id]", func(w Res, r Req, params []string) { r.Get("/users/[id]", func(w Res, r Req, params []string) {})
// No-op for testing
})
_, _, found := r.Lookup("GET", "/posts/123") _, _, found := r.Lookup("GET", "/posts/123")
assert.False(t, found) assert.False(t, found)
@ -73,9 +65,7 @@ func TestNonExistentPath(t *testing.T) {
func TestWrongMethod(t *testing.T) { func TestWrongMethod(t *testing.T) {
r := New() r := New()
r.Get("/users/[id]", func(w Res, r Req, params []string) { r.Get("/users/[id]", func(w Res, r Req, params []string) {})
// No-op for testing
})
_, _, found := r.Lookup("POST", "/users/123") _, _, found := r.Lookup("POST", "/users/123")
assert.False(t, found) assert.False(t, found)
@ -92,7 +82,7 @@ func TestTrailingSlash(t *testing.T) {
h, params, found := r.Lookup("GET", "/users/123/") h, params, found := r.Lookup("GET", "/users/123/")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
assert.True(t, called) assert.True(t, called)
} }
@ -129,7 +119,7 @@ func TestWildcardPath(t *testing.T) {
h, params, found := r.Lookup("GET", "/files/docs/report.pdf") h, params, found := r.Lookup("GET", "/files/docs/report.pdf")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
assert.True(t, called) assert.True(t, called)
}) })
@ -143,7 +133,7 @@ func TestWildcardPath(t *testing.T) {
h, params, found := r.Lookup("GET", "/download/") h, params, found := r.Lookup("GET", "/download/")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
assert.True(t, called) assert.True(t, called)
}) })
@ -158,7 +148,7 @@ func TestWildcardPath(t *testing.T) {
h, params, found := r.Lookup("GET", "/users/123/settings/profile/avatar") h, params, found := r.Lookup("GET", "/users/123/settings/profile/avatar")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
assert.True(t, called) assert.True(t, called)
}) })
@ -178,13 +168,11 @@ func TestMiddleware(t *testing.T) {
t.Run("global middleware", func(t *testing.T) { t.Run("global middleware", func(t *testing.T) {
r := New() r := New()
// Track middleware execution
executed := false executed := false
r.Use(func(next Handler) Handler { r.Use(func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
executed = true executed = true
next.Serve(params) next(w, r, params)
}) })
}) })
@ -192,7 +180,7 @@ func TestMiddleware(t *testing.T) {
h, params, found := r.Lookup("GET", "/test") h, params, found := r.Lookup("GET", "/test")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
assert.True(t, executed) assert.True(t, executed)
}) })
@ -203,17 +191,17 @@ func TestMiddleware(t *testing.T) {
order := []int{} order := []int{}
r.Use(func(next Handler) Handler { r.Use(func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
order = append(order, 1) order = append(order, 1)
next.Serve(params) next.Serve(nil, nil, params)
order = append(order, 4) order = append(order, 4)
}) })
}) })
r.Use(func(next Handler) Handler { r.Use(func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
order = append(order, 2) order = append(order, 2)
next.Serve(params) next.Serve(nil, nil, params)
order = append(order, 3) order = append(order, 3)
}) })
}) })
@ -224,7 +212,7 @@ func TestMiddleware(t *testing.T) {
h, params, found := r.Lookup("GET", "/test") h, params, found := r.Lookup("GET", "/test")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
// Check middleware execution order (first middleware wraps second) // Check middleware execution order (first middleware wraps second)
assert.Equal(t, len(order), 5) assert.Equal(t, len(order), 5)
@ -241,9 +229,9 @@ func TestMiddleware(t *testing.T) {
executed := false executed := false
middleware := func(next Handler) Handler { middleware := func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
executed = true executed = true
next.Serve(params) next.Serve(nil, nil, params)
}) })
} }
@ -251,7 +239,7 @@ func TestMiddleware(t *testing.T) {
h, params, found := r.Lookup("GET", "/test") h, params, found := r.Lookup("GET", "/test")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
assert.True(t, executed) assert.True(t, executed)
}) })
} }
@ -267,7 +255,7 @@ func TestGroup(t *testing.T) {
h, params, found := r.Lookup("GET", "/api/users") h, params, found := r.Lookup("GET", "/api/users")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
}) })
t.Run("nested groups", func(t *testing.T) { t.Run("nested groups", func(t *testing.T) {
@ -280,7 +268,7 @@ func TestGroup(t *testing.T) {
h, params, found := r.Lookup("GET", "/api/v1/users") h, params, found := r.Lookup("GET", "/api/v1/users")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
}) })
t.Run("group middleware", func(t *testing.T) { t.Run("group middleware", func(t *testing.T) {
@ -290,9 +278,9 @@ func TestGroup(t *testing.T) {
// Create group with middleware // Create group with middleware
api := r.Group("/api") api := r.Group("/api")
api.Use(func(next Handler) Handler { api.Use(func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
executed = true executed = true
next.Serve(params) next.Serve(nil, nil, params)
}) })
}) })
@ -300,7 +288,7 @@ func TestGroup(t *testing.T) {
h, params, found := r.Lookup("GET", "/api/users") h, params, found := r.Lookup("GET", "/api/users")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
assert.True(t, executed) assert.True(t, executed)
}) })
@ -311,18 +299,18 @@ func TestGroup(t *testing.T) {
// Create group with middleware // Create group with middleware
api := r.Group("/api") api := r.Group("/api")
api.Use(func(next Handler) Handler { api.Use(func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
order = append(order, 1) order = append(order, 1)
next.Serve(params) next.Serve(nil, nil, params)
}) })
}) })
// Create nested group with additional middleware // Create nested group with additional middleware
v1 := api.Group("/v1") v1 := api.Group("/v1")
v1.Use(func(next Handler) Handler { v1.Use(func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
order = append(order, 2) order = append(order, 2)
next.Serve(params) next.Serve(nil, nil, params)
}) })
}) })
@ -332,7 +320,7 @@ func TestGroup(t *testing.T) {
h, params, found := r.Lookup("GET", "/api/v1/users") h, params, found := r.Lookup("GET", "/api/v1/users")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
// Check middleware execution order // Check middleware execution order
assert.Equal(t, len(order), 3) assert.Equal(t, len(order), 3)
@ -348,17 +336,17 @@ func TestGroup(t *testing.T) {
// Create group with middleware // Create group with middleware
api := r.Group("/api") api := r.Group("/api")
api.Use(func(next Handler) Handler { api.Use(func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
order = append(order, 1) order = append(order, 1)
next.Serve(params) next.Serve(nil, nil, params)
}) })
}) })
// Add route with specific middleware // Add route with specific middleware
api.WithMiddleware(func(next Handler) Handler { api.WithMiddleware(func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
order = append(order, 2) order = append(order, 2)
next.Serve(params) next.Serve(nil, nil, params)
}) })
}).Get("/users", func(w Res, r Req, params []string) { }).Get("/users", func(w Res, r Req, params []string) {
order = append(order, 3) order = append(order, 3)
@ -366,7 +354,7 @@ func TestGroup(t *testing.T) {
h, params, found := r.Lookup("GET", "/api/users") h, params, found := r.Lookup("GET", "/api/users")
assert.True(t, found) assert.True(t, found)
h.Serve(params) h.Serve(nil, nil, params)
// Check middleware execution order // Check middleware execution order
assert.Equal(t, len(order), 3) assert.Equal(t, len(order), 3)
@ -465,8 +453,8 @@ func BenchmarkWildcardLookup(b *testing.B) {
func BenchmarkMiddleware(b *testing.B) { func BenchmarkMiddleware(b *testing.B) {
passthrough := func(next Handler) Handler { passthrough := func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
next.Serve(params) next.Serve(nil, nil, params)
}) })
} }
@ -477,7 +465,7 @@ func BenchmarkMiddleware(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
h, params, _ := r.Lookup("GET", "/test") h, params, _ := r.Lookup("GET", "/test")
h.Serve(params) h.Serve(nil, nil, params)
} }
}) })
@ -489,7 +477,7 @@ func BenchmarkMiddleware(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
h, params, _ := r.Lookup("GET", "/test") h, params, _ := r.Lookup("GET", "/test")
h.Serve(params) h.Serve(nil, nil, params)
} }
}) })
@ -503,7 +491,7 @@ func BenchmarkMiddleware(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
h, params, _ := r.Lookup("GET", "/test") h, params, _ := r.Lookup("GET", "/test")
h.Serve(params) h.Serve(nil, nil, params)
} }
}) })
} }
@ -537,8 +525,8 @@ func BenchmarkGroups(b *testing.B) {
r := New() r := New()
api := r.Group("/api") api := r.Group("/api")
api.Use(func(next Handler) Handler { api.Use(func(next Handler) Handler {
return NewHandler(func(params []string) { return Handler(func(w Res, r Req, params []string) {
next.Serve(params) next.Serve(nil, nil, params)
}) })
}) })
v1 := api.Group("/v1") v1 := api.Group("/v1")
@ -547,7 +535,7 @@ func BenchmarkGroups(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
h, params, _ := r.Lookup("GET", "/api/v1/users") h, params, _ := r.Lookup("GET", "/api/v1/users")
h.Serve(params) h.Serve(nil, nil, params)
} }
}) })
} }