// File: Sushi/router.go (code-viewer header: 327 lines, 7.4 KiB, Go)

package sushi
import (
"fmt"
"github.com/valyala/fasthttp"
)
// node is one segment of a per-method routing tree.
type node struct {
	segment  string  // literal segment text; parameters keep their ':' or '*' prefix
	handler  Handler // set on the node where a registered route ends
	children []*node

	isDynamic, isWildcard bool     // ":name" parameter / "*name" catch-all segment
	paramNames            []string // parameter names of the route ending here, in path order
	maxParams             uint8    // most parameters accumulated up to this node by any route
}
// Router dispatches requests through one routing tree per supported HTTP
// method and wraps every route it registers with the router-wide middleware.
type Router struct {
	// Per-method routing-tree roots, selected by methodNode.
	get, post, put, patch, delete *node

	middleware   []Middleware // router-wide middleware, applied when a route is registered
	paramsBuffer []string     // scratch slice for captured route parameters, created by NewRouter
}
// Group registers routes on a Router under a shared path prefix, with extra
// middleware that runs after the router-wide middleware.
type Group struct {
	router     *Router      // router whose trees receive the group's routes
	prefix     string       // prepended verbatim to each path registered through the group
	middleware []Middleware // group-level middleware, appended after the router's
}
// NewRouter creates a Router with an empty routing tree for each supported
// method (GET, POST, PUT, PATCH, DELETE), no middleware, and a 64-entry
// parameter buffer.
func NewRouter() *Router {
	return &Router{
		get:          &node{},
		post:         &node{},
		put:          &node{},
		patch:        &node{},
		delete:       &node{},
		middleware:   []Middleware{},
		paramsBuffer: make([]string, 64),
	}
}
// ServeHTTP routes a single fasthttp request. It looks up the handler for the
// request's method and path and replies 404 Not Found when none matches.
//
// When the route captured parameters, the positional values are stored on
// the request under RouteParamsCtxKey. If the route also has named
// parameters, a name-to-value map is stored under "param_names". Values
// without a name (an unnamed "*" wildcard) appear only in the positional list.
func (r *Router) ServeHTTP(ctx *fasthttp.RequestCtx) {
	reqPath := string(ctx.Path())
	reqMethod := string(ctx.Method())

	handler, values, names, ok := r.Lookup(reqMethod, reqPath)
	if !ok {
		ctx.SetStatusCode(fasthttp.StatusNotFound)
		return
	}

	c := Ctx{ctx}
	if len(values) == 0 {
		handler(c)
		return
	}

	c.SetUserValue(RouteParamsCtxKey, values)
	if len(names) > 0 {
		byName := make(map[string]string, len(names))
		for i := 0; i < len(names) && i < len(values); i++ {
			if names[i] != "" {
				byName[names[i]] = values[i]
			}
		}
		if len(byName) > 0 {
			c.SetUserValue("param_names", byName)
		}
	}
	handler(c)
}
// Handler returns r.ServeHTTP as a fasthttp.RequestHandler so the router can
// be handed directly to a fasthttp server.
func (r *Router) Handler() fasthttp.RequestHandler {
	var serve fasthttp.RequestHandler = r.ServeHTTP
	return serve
}
// Use appends mw to the router-wide middleware and returns r for chaining.
// Routes take the router middleware when they are registered, so only
// routes registered after this call run the new middleware.
func (r *Router) Use(mw ...Middleware) *Router {
	for _, m := range mw {
		r.middleware = append(r.middleware, m)
	}
	return r
}
// Group returns a new, middleware-free route group on r. prefix is
// concatenated verbatim (no '/' normalisation) with every path registered
// through the group.
func (r *Router) Group(prefix string) *Group {
	g := new(Group)
	g.router = r
	g.prefix = prefix
	g.middleware = make([]Middleware, 0)
	return g
}
// Use appends mw to the group's middleware and returns g for chaining.
// Group middleware runs after the router-wide middleware. It is read when a
// route is registered, so only routes added through g afterwards (and child
// groups created afterwards) pick it up.
func (g *Group) Use(mw ...Middleware) *Group {
	for _, m := range mw {
		g.middleware = append(g.middleware, m)
	}
	return g
}
// Group returns a nested group under g. Its prefix is g's prefix followed
// verbatim by prefix. g's middleware is copied at this point, so later Use
// calls on either group do not affect the other.
func (g *Group) Group(prefix string) *Group {
	inherited := make([]Middleware, len(g.middleware))
	copy(inherited, g.middleware)
	return &Group{
		router:     g.router,
		prefix:     g.prefix + prefix,
		middleware: inherited,
	}
}
// Get registers h for GET requests to path; shorthand for r.Handle("GET", path, h).
func (r *Router) Get(path string, h Handler) error {
	return r.Handle("GET", path, h)
}

// Post registers h for POST requests to path; shorthand for r.Handle("POST", path, h).
func (r *Router) Post(path string, h Handler) error {
	return r.Handle("POST", path, h)
}

// Put registers h for PUT requests to path; shorthand for r.Handle("PUT", path, h).
func (r *Router) Put(path string, h Handler) error {
	return r.Handle("PUT", path, h)
}

// Patch registers h for PATCH requests to path; shorthand for r.Handle("PATCH", path, h).
func (r *Router) Patch(path string, h Handler) error {
	return r.Handle("PATCH", path, h)
}

// Delete registers h for DELETE requests to path; shorthand for r.Handle("DELETE", path, h).
func (r *Router) Delete(path string, h Handler) error {
	return r.Handle("DELETE", path, h)
}
// Get registers h for GET requests to g's prefix + path.
func (g *Group) Get(path string, h Handler) error {
	return g.Handle("GET", path, h)
}

// Post registers h for POST requests to g's prefix + path.
func (g *Group) Post(path string, h Handler) error {
	return g.Handle("POST", path, h)
}

// Put registers h for PUT requests to g's prefix + path.
func (g *Group) Put(path string, h Handler) error {
	return g.Handle("PUT", path, h)
}

// Patch registers h for PATCH requests to g's prefix + path.
func (g *Group) Patch(path string, h Handler) error {
	return g.Handle("PATCH", path, h)
}

// Delete registers h for DELETE requests to g's prefix + path.
func (g *Group) Delete(path string, h Handler) error {
	return g.Handle("DELETE", path, h)
}
// Handle registers h for method and path, wrapped in the router-wide
// middleware that is installed at the time of the call. method must be one
// of "GET", "POST", "PUT", "PATCH" or "DELETE" (case-sensitive); any other
// value yields an "unsupported method" error.
func (r *Router) Handle(method, path string, h Handler) error {
	if root := r.methodNode(method); root != nil {
		return r.addRoute(root, path, h, r.middleware)
	}
	return fmt.Errorf("unsupported method: %s", method)
}
// Handle registers h for method under g's prefix + path. The route runs the
// router-wide middleware first, then the group's. Both lists are copied into
// a fresh slice, so neither the router's nor the group's slice is shared
// with the route.
func (g *Group) Handle(method, path string, h Handler) error {
	root := g.router.methodNode(method)
	if root == nil {
		return fmt.Errorf("unsupported method: %s", method)
	}
	chain := make([]Middleware, 0, len(g.router.middleware)+len(g.middleware))
	chain = append(chain, g.router.middleware...)
	chain = append(chain, g.middleware...)
	return g.router.addRoute(root, g.prefix+path, h, chain)
}
// methodNode returns the routing-tree root for method, or nil when method is
// not one of "GET", "POST", "PUT", "PATCH" or "DELETE". Matching is
// case-sensitive.
func (r *Router) methodNode(method string) *node {
	var root *node
	switch method {
	case "GET":
		root = r.get
	case "POST":
		root = r.post
	case "PUT":
		root = r.put
	case "PATCH":
		root = r.patch
	case "DELETE":
		root = r.delete
	}
	return root
}
// applyMiddleware wraps h so that every invocation runs mw in order before h.
// Each middleware receives a next func that advances to the following
// middleware (and finally to h); a middleware that never calls next stops
// the chain. Every invocation starts again from the first middleware. With
// no middleware, h is returned as-is.
func applyMiddleware(h Handler, mw []Middleware) Handler {
	if len(mw) == 0 {
		return h
	}
	return func(ctx Ctx) {
		pos := 0
		var step func()
		step = func() {
			if pos < len(mw) {
				m := mw[pos]
				pos++
				m(ctx, step)
				return
			}
			h(ctx)
		}
		step()
	}
}
// readSegment returns the next '/'-delimited segment of path at or after
// start, skipping a single leading '/'. end is the index just past the
// segment, and hasMore reports whether a '/' follows it. An empty segment is
// returned when the path is exhausted or when two slashes are adjacent.
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
	if start < len(path) && path[start] == '/' {
		start++
	}
	if start >= len(path) {
		return "", start, false
	}
	for end = start; end < len(path); end++ {
		if path[end] == '/' {
			return path[start:end], end, true
		}
	}
	return path[start:], len(path), false
}
// extractParamName returns the name of a ":name" or "*name" segment. It
// returns "" when segment is not a parameter or has nothing after its
// prefix.
func extractParamName(segment string) string {
	if len(segment) < 2 {
		return ""
	}
	switch segment[0] {
	case ':', '*':
		return segment[1:]
	default:
		return ""
	}
}
// addRoute wraps h in mw and inserts it into the tree rooted at root under
// path.
//
// path is split on '/'. Segments that start with ':' and are longer than one
// byte become dynamic nodes. A segment that starts with '*' becomes a
// wildcard node and must be the final segment. Existing children with the
// same segment text are reused. Registering the same path twice replaces the
// earlier handler. The route's parameter names are stored on its final node.
// Every node on the way records the most parameters accumulated up to and
// including it. root records the largest total over all routes, which Lookup
// reads when sizing its parameter slice.
//
// It returns an error when a wildcard is not the last segment or when a path
// has more than 255 parameters. On error, nodes already created for the
// leading segments remain in the tree without a handler.
func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) error {
	h = applyMiddleware(h, mw)
	if path == "/" {
		root.handler = h
		return nil
	}

	current := root
	pos := 0
	count := uint8(0)
	var paramNames []string
	for {
		seg, newPos, more := readSegment(path, pos)
		if seg == "" {
			break
		}
		isDyn := len(seg) > 1 && seg[0] == ':'
		isWC := seg[0] == '*'
		// A wildcard consumes the rest of the request path, so nothing may follow it.
		if isWC && more {
			return fmt.Errorf("wildcard must be the last segment in the path")
		}
		if isDyn || isWC {
			// maxParams is a uint8; refuse to wrap around silently.
			if count == ^uint8(0) {
				return fmt.Errorf("too many parameters in path %q", path)
			}
			count++
			paramNames = append(paramNames, extractParamName(seg))
		}

		var child *node
		for _, c := range current.children {
			if c.segment == seg {
				child = c
				break
			}
		}
		if child == nil {
			child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC}
			current.children = append(current.children, child)
		}
		if child.maxParams < count {
			child.maxParams = count
		}
		current = child
		pos = newPos
	}

	// Lookup consults root.maxParams, so keep it in sync with the deepest route.
	if root.maxParams < count {
		root.maxParams = count
	}
	current.handler = h
	current.paramNames = paramNames
	return nil
}
// Lookup finds the handler registered for method and path.
//
// It returns four values:
//   - the handler,
//   - the captured parameter values in path order,
//   - the parameter names recorded for the matched route,
//   - whether a route matched.
//
// Unsupported methods and unmatched paths report found == false.
//
// The params slice is allocated per call and owned by the caller. ServeHTTP
// stores it on the request, and fasthttp may run many handlers at once, so
// it must never alias state shared across requests. Lookup does not write to
// the Router.
func (r *Router) Lookup(method, path string) (Handler, []string, []string, bool) {
	root := r.methodNode(method)
	if root == nil {
		return nil, nil, nil, false
	}
	// A handler registered for "/" itself wins. Otherwise "/" is matched like
	// any other path, so a root-level wildcard route can serve it.
	if path == "/" && root.handler != nil {
		return root.handler, nil, nil, true
	}
	// root.maxParams is only a capacity hint; append grows the slice if needed.
	params := make([]string, 0, root.maxParams)
	h, n, names, found := match(root, path, 0, &params)
	if !found {
		return nil, nil, nil, false
	}
	return h, params[:n], names, true
}
// match resolves the remainder of path, starting at byte offset start,
// against the subtree rooted at current.
//
// For a non-empty next segment, children are tried in order of specificity,
// backtracking to the next candidate whenever a branch fails:
//  1. the static child whose segment equals the next path segment,
//  2. dynamic (":name") children, each capturing that segment,
//  3. wildcard ("*name") children, capturing the rest of the path without
//     its leading '/'.
//
// When no segment remains at current (end of path, a trailing '/', or an
// empty segment from "//"), a wildcard child capturing the remainder is
// preferred over current's own handler.
//
// On success it returns the handler, the number of values this call appended
// to *params, the matched route's parameter names, and true. Values captured
// on failed branches are removed again, so *params ends up holding exactly
// the matched route's values in path order. On failure *params is restored
// to its length on entry.
func match(current *node, path string, start int, params *[]string) (Handler, int, []string, bool) {
	base := len(*params)
	seg, next, more := readSegment(path, start)

	if seg != "" {
		// Static children first: an exact segment beats a parameter.
		for _, c := range current.children {
			if c.isDynamic || c.isWildcard || c.segment != seg {
				continue
			}
			if h, names, ok := matchChild(c, path, next, more, params); ok {
				return h, len(*params) - base, names, true
			}
			*params = (*params)[:base]
		}
		// Then dynamic children, each capturing seg.
		for _, c := range current.children {
			if !c.isDynamic {
				continue
			}
			*params = append(*params, seg)
			if h, names, ok := matchChild(c, path, next, more, params); ok {
				return h, len(*params) - base, names, true
			}
			// Backtrack: drop seg and anything the failed branch captured.
			*params = (*params)[:base]
		}
	}

	// Wildcards swallow whatever remains, so they are the last resort.
	for _, c := range current.children {
		if !c.isWildcard || c.handler == nil {
			continue
		}
		rest := path[start:]
		if len(rest) > 0 && rest[0] == '/' {
			rest = rest[1:]
		}
		*params = append(*params, rest)
		return c.handler, len(*params) - base, c.paramNames, true
	}

	// Path exhausted here: this node matches only if a route ends on it.
	if seg == "" && current.handler != nil {
		return current.handler, 0, current.paramNames, true
	}
	return nil, 0, nil, false
}

// matchChild finishes a match through child c after c has been accepted for
// the current segment. If that segment was the last one, c itself must carry
// a handler; otherwise matching continues below c at offset next.
func matchChild(c *node, path string, next int, more bool, params *[]string) (Handler, []string, bool) {
	if !more {
		return c.handler, c.paramNames, c.handler != nil
	}
	h, _, names, ok := match(c, path, next, params)
	return h, names, ok
}