package sushi import ( "fmt" "github.com/valyala/fasthttp" ) type node struct { segment string handler Handler children []*node isDynamic bool isWildcard bool paramNames []string maxParams uint8 } type Router struct { get *node post *node put *node patch *node delete *node middleware []Middleware paramsBuffer []string } type Group struct { router *Router prefix string middleware []Middleware } // New creates a new Router instance func NewRouter() *Router { return &Router{ get: &node{}, post: &node{}, put: &node{}, patch: &node{}, delete: &node{}, middleware: []Middleware{}, paramsBuffer: make([]string, 64), } } // ServeHTTP implements the Handler interface for fasthttp func (r *Router) ServeHTTP(ctx *fasthttp.RequestCtx) { path := string(ctx.Path()) method := string(ctx.Method()) h, params, paramNames, found := r.Lookup(method, path) if !found { ctx.SetStatusCode(fasthttp.StatusNotFound) return } // Store params in context sushiCtx := Ctx{ctx} if len(params) > 0 { sushiCtx.SetUserValue(RouteParamsCtxKey, params) // Create named param map if param names exist if len(paramNames) > 0 { paramMap := make(map[string]string) for i, name := range paramNames { if i < len(params) && name != "" { paramMap[name] = params[i] } } if len(paramMap) > 0 { sushiCtx.SetUserValue("param_names", paramMap) } } } h(sushiCtx) } // Handler returns a fasthttp request handler func (r *Router) Handler() fasthttp.RequestHandler { return r.ServeHTTP } // Use adds middleware to the router func (r *Router) Use(mw ...Middleware) *Router { r.middleware = append(r.middleware, mw...) return r } // Group creates a new route group func (r *Router) Group(prefix string) *Group { return &Group{router: r, prefix: prefix, middleware: []Middleware{}} } // Use adds middleware to the group func (g *Group) Use(mw ...Middleware) *Group { g.middleware = append(g.middleware, mw...) return g } // Group creates a nested group func (g *Group) Group(prefix string) *Group { return &Group{ router: g.router, prefix: g.prefix + prefix, middleware: append([]Middleware{}, g.middleware...), } } // HTTP method handlers for Router func (r *Router) Get(path string, h Handler) error { return r.Handle("GET", path, h) } func (r *Router) Post(path string, h Handler) error { return r.Handle("POST", path, h) } func (r *Router) Put(path string, h Handler) error { return r.Handle("PUT", path, h) } func (r *Router) Patch(path string, h Handler) error { return r.Handle("PATCH", path, h) } func (r *Router) Delete(path string, h Handler) error { return r.Handle("DELETE", path, h) } // HTTP method handlers for Group func (g *Group) Get(path string, h Handler) error { return g.Handle("GET", path, h) } func (g *Group) Post(path string, h Handler) error { return g.Handle("POST", path, h) } func (g *Group) Put(path string, h Handler) error { return g.Handle("PUT", path, h) } func (g *Group) Patch(path string, h Handler) error { return g.Handle("PATCH", path, h) } func (g *Group) Delete(path string, h Handler) error { return g.Handle("DELETE", path, h) } // Handle registers a handler for the given method and path func (r *Router) Handle(method, path string, h Handler) error { root := r.methodNode(method) if root == nil { return fmt.Errorf("unsupported method: %s", method) } return r.addRoute(root, path, h, r.middleware) } // Handle registers a handler in the group func (g *Group) Handle(method, path string, h Handler) error { root := g.router.methodNode(method) if root == nil { return fmt.Errorf("unsupported method: %s", method) } mw := append([]Middleware{}, g.router.middleware...) mw = append(mw, g.middleware...) return g.router.addRoute(root, g.prefix+path, h, mw) } func (r *Router) methodNode(method string) *node { switch method { case "GET": return r.get case "POST": return r.post case "PUT": return r.put case "PATCH": return r.patch case "DELETE": return r.delete default: return nil } } func applyMiddleware(h Handler, mw []Middleware) Handler { if len(mw) == 0 { return h } return func(ctx Ctx) { var index int var next func() next = func() { if index >= len(mw) { h(ctx) return } currentMW := mw[index] index++ currentMW(ctx, next) } next() } } func readSegment(path string, start int) (segment string, end int, hasMore bool) { if start >= len(path) { return "", start, false } if path[start] == '/' { start++ } if start >= len(path) { return "", start, false } end = start for end < len(path) && path[end] != '/' { end++ } return path[start:end], end, end < len(path) } func extractParamName(segment string) string { if len(segment) > 1 && segment[0] == ':' { return segment[1:] } if len(segment) > 1 && segment[0] == '*' { return segment[1:] } return "" } func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) error { h = applyMiddleware(h, mw) if path == "/" { root.handler = h return nil } current := root pos := 0 lastWC := false count := uint8(0) var paramNames []string for { seg, newPos, more := readSegment(path, pos) if seg == "" { break } isDyn := len(seg) > 1 && seg[0] == ':' isWC := len(seg) > 0 && seg[0] == '*' if isWC { if lastWC || more { return fmt.Errorf("wildcard must be the last segment in the path") } lastWC = true } if isDyn || isWC { count++ paramNames = append(paramNames, extractParamName(seg)) } var child *node for _, c := range current.children { if c.segment == seg { child = c break } } if child == nil { child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC} current.children = append(current.children, child) } if child.maxParams < count { child.maxParams = count } current = child pos = newPos } current.handler = h current.paramNames = paramNames return nil } // Lookup finds a handler matching method and path func (r *Router) Lookup(method, path string) (Handler, []string, []string, bool) { root := r.methodNode(method) if root == nil { return nil, nil, nil, false } if path == "/" { return root.handler, nil, nil, root.handler != nil } buffer := r.paramsBuffer if cap(buffer) < int(root.maxParams) { buffer = make([]string, root.maxParams) r.paramsBuffer = buffer } buffer = buffer[:0] h, paramCount, paramNames, found := match(root, path, 0, &buffer) if !found { return nil, nil, nil, false } return h, buffer[:paramCount], paramNames, true } func match(current *node, path string, start int, params *[]string) (Handler, int, []string, bool) { paramCount := 0 for _, c := range current.children { if c.isWildcard { rem := path[start:] if len(rem) > 0 && rem[0] == '/' { rem = rem[1:] } *params = append(*params, rem) return c.handler, 1, c.paramNames, c.handler != nil } } seg, pos, more := readSegment(path, start) if seg == "" { return current.handler, 0, current.paramNames, current.handler != nil } for _, c := range current.children { if c.segment == seg || c.isDynamic { if c.isDynamic { *params = append(*params, seg) paramCount++ } if !more { return c.handler, paramCount, c.paramNames, c.handler != nil } h, nestedCount, paramNames, ok := match(c, path, pos, params) if ok { return h, paramCount + nestedCount, paramNames, true } } } return nil, 0, nil, false }