minify handlers by moving params to ctx

This commit is contained in:
Sky Johnson 2025-08-18 22:05:05 -05:00
parent 09f66cfaa4
commit f7d344d3e0
9 changed files with 183 additions and 35 deletions

View File

@ -9,7 +9,7 @@ const UserCtxKey = "user"
// Middleware adds authentication handling // Middleware adds authentication handling
func Middleware(userLookup func(int) any) sushi.Middleware { func Middleware(userLookup func(int) any) sushi.Middleware {
return func(ctx sushi.Ctx, params []any, next func()) { return func(ctx sushi.Ctx, next func()) {
sess := sushi.GetCurrentSession(ctx) sess := sushi.GetCurrentSession(ctx)
if sess != nil && sess.UserID > 0 && userLookup != nil { if sess != nil && sess.UserID > 0 && userLookup != nil {
user := userLookup(sess.UserID) user := userLookup(sess.UserID)
@ -31,7 +31,7 @@ func RequireAuth(redirectPath ...string) sushi.Middleware {
redirect = redirectPath[0] redirect = redirectPath[0]
} }
return func(ctx sushi.Ctx, params []any, next func()) { return func(ctx sushi.Ctx, next func()) {
if !ctx.IsAuthenticated() { if !ctx.IsAuthenticated() {
ctx.Redirect(redirect, fasthttp.StatusFound) ctx.Redirect(redirect, fasthttp.StatusFound)
return return
@ -47,7 +47,7 @@ func RequireGuest(redirectPath ...string) sushi.Middleware {
redirect = redirectPath[0] redirect = redirectPath[0]
} }
return func(ctx sushi.Ctx, params []any, next func()) { return func(ctx sushi.Ctx, next func()) {
if ctx.IsAuthenticated() { if ctx.IsAuthenticated() {
ctx.Redirect(redirect, fasthttp.StatusFound) ctx.Redirect(redirect, fasthttp.StatusFound)
return return

View File

@ -110,7 +110,7 @@ func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
// Middleware returns middleware that automatically validates CSRF tokens // Middleware returns middleware that automatically validates CSRF tokens
func Middleware() sushi.Middleware { func Middleware() sushi.Middleware {
return func(ctx sushi.Ctx, params []any, next func()) { return func(ctx sushi.Ctx, next func()) {
method := string(ctx.Method()) method := string(ctx.Method())
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {

6
fs.go
View File

@ -50,7 +50,7 @@ func StaticFS(fsOptions StaticOptions) Handler {
fsHandler := fs.NewRequestHandler() fsHandler := fs.NewRequestHandler()
return func(ctx Ctx, params []any) { return func(ctx Ctx) {
fsHandler(ctx.RequestCtx) fsHandler(ctx.RequestCtx)
} }
} }
@ -62,14 +62,14 @@ func Static(root string) Handler {
// StaticFile serves a single file // StaticFile serves a single file
func StaticFile(filePath string) Handler { func StaticFile(filePath string) Handler {
return func(ctx Ctx, params []any) { return func(ctx Ctx) {
fasthttp.ServeFile(ctx.RequestCtx, filePath) fasthttp.ServeFile(ctx.RequestCtx, filePath)
} }
} }
// StaticEmbed creates a handler for embedded files // StaticEmbed creates a handler for embedded files
func StaticEmbed(files map[string][]byte) Handler { func StaticEmbed(files map[string][]byte) Handler {
return func(ctx Ctx, params []any) { return func(ctx Ctx) {
requestPath := string(ctx.Path()) requestPath := string(ctx.Path())
// Try to find the file // Try to find the file

114
params.go Normal file
View File

@ -0,0 +1,114 @@
package sushi
import (
"strconv"
"strings"
)
const RouteParamsCtxKey = "route_params"
type ParamValue struct {
value string
exists bool
}
// RouteParam gets a route parameter by index for chaining
func (ctx Ctx) RouteParam(index int) ParamValue {
if params, ok := ctx.UserValue(RouteParamsCtxKey).([]string); ok {
if index >= 0 && index < len(params) {
return ParamValue{value: params[index], exists: true}
}
}
return ParamValue{value: "", exists: false}
}
// Param gets a route parameter by name for chaining (requires named params)
func (ctx Ctx) Param(name string) ParamValue {
if paramMap, ok := ctx.UserValue("param_names").(map[string]string); ok {
if value, exists := paramMap[name]; exists {
return ParamValue{value: value, exists: true}
}
}
return ParamValue{value: "", exists: false}
}
// String returns the value as string
func (p ParamValue) String() string {
return p.value
}
// StringDefault returns string with default value
func (p ParamValue) StringDefault(defaultValue string) string {
if p.value == "" {
return defaultValue
}
return p.value
}
// Int returns the value as integer
func (p ParamValue) Int() int {
if p.value == "" {
return 0
}
if parsed, err := strconv.Atoi(p.value); err == nil {
return parsed
}
return 0
}
// IntDefault returns integer with default value
func (p ParamValue) IntDefault(defaultValue int) int {
if p.value == "" {
return defaultValue
}
if parsed, err := strconv.Atoi(p.value); err == nil {
return parsed
}
return defaultValue
}
// Float returns the value as float64
func (p ParamValue) Float() float64 {
if p.value == "" {
return 0.0
}
if parsed, err := strconv.ParseFloat(p.value, 64); err == nil {
return parsed
}
return 0.0
}
// FloatDefault returns float64 with default value
func (p ParamValue) FloatDefault(defaultValue float64) float64 {
if p.value == "" {
return defaultValue
}
if parsed, err := strconv.ParseFloat(p.value, 64); err == nil {
return parsed
}
return defaultValue
}
// Bool returns the value as boolean
func (p ParamValue) Bool() bool {
value := strings.ToLower(p.value)
return value == "true" || value == "on" || value == "1" || value == "yes"
}
// BoolDefault returns boolean with default value
func (p ParamValue) BoolDefault(defaultValue bool) bool {
if p.value == "" {
return defaultValue
}
return p.Bool()
}
// Exists returns true if the parameter exists
func (p ParamValue) Exists() bool {
return p.exists
}
// IsEmpty returns true if the parameter is empty or doesn't exist
func (p ParamValue) IsEmpty() bool {
return p.value == ""
}

View File

@ -12,6 +12,7 @@ type node struct {
children []*node children []*node
isDynamic bool isDynamic bool
isWildcard bool isWildcard bool
paramNames []string
maxParams uint8 maxParams uint8
} }
@ -22,7 +23,7 @@ type Router struct {
patch *node patch *node
delete *node delete *node
middleware []Middleware middleware []Middleware
paramsBuffer []any paramsBuffer []string
} }
type Group struct { type Group struct {
@ -40,7 +41,7 @@ func NewRouter() *Router {
patch: &node{}, patch: &node{},
delete: &node{}, delete: &node{},
middleware: []Middleware{}, middleware: []Middleware{},
paramsBuffer: make([]any, 64), paramsBuffer: make([]string, 64),
} }
} }
@ -49,13 +50,32 @@ func (r *Router) ServeHTTP(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path()) path := string(ctx.Path())
method := string(ctx.Method()) method := string(ctx.Method())
h, params, found := r.Lookup(method, path) h, params, paramNames, found := r.Lookup(method, path)
if !found { if !found {
ctx.SetStatusCode(fasthttp.StatusNotFound) ctx.SetStatusCode(fasthttp.StatusNotFound)
return return
} }
h(Ctx{ctx}, params) // Store params in context
sushiCtx := Ctx{ctx}
if len(params) > 0 {
sushiCtx.SetUserValue(RouteParamsCtxKey, params)
// Create named param map if param names exist
if len(paramNames) > 0 {
paramMap := make(map[string]string)
for i, name := range paramNames {
if i < len(params) && name != "" {
paramMap[name] = params[i]
}
}
if len(paramMap) > 0 {
sushiCtx.SetUserValue("param_names", paramMap)
}
}
}
h(sushiCtx)
} }
// Handler returns a fasthttp request handler // Handler returns a fasthttp request handler
@ -145,19 +165,19 @@ func applyMiddleware(h Handler, mw []Middleware) Handler {
return h return h
} }
return func(ctx Ctx, params []any) { return func(ctx Ctx) {
var index int var index int
var next func() var next func()
next = func() { next = func() {
if index >= len(mw) { if index >= len(mw) {
h(ctx, params) h(ctx)
return return
} }
currentMW := mw[index] currentMW := mw[index]
index++ index++
currentMW(ctx, params, next) currentMW(ctx, next)
} }
next() next()
@ -181,6 +201,16 @@ func readSegment(path string, start int) (segment string, end int, hasMore bool)
return path[start:end], end, end < len(path) return path[start:end], end, end < len(path)
} }
func extractParamName(segment string) string {
if len(segment) > 1 && segment[0] == ':' {
return segment[1:]
}
if len(segment) > 1 && segment[0] == '*' {
return segment[1:]
}
return ""
}
func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) error { func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) error {
h = applyMiddleware(h, mw) h = applyMiddleware(h, mw)
if path == "/" { if path == "/" {
@ -191,6 +221,8 @@ func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) e
pos := 0 pos := 0
lastWC := false lastWC := false
count := uint8(0) count := uint8(0)
var paramNames []string
for { for {
seg, newPos, more := readSegment(path, pos) seg, newPos, more := readSegment(path, pos)
if seg == "" { if seg == "" {
@ -206,6 +238,7 @@ func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) e
} }
if isDyn || isWC { if isDyn || isWC {
count++ count++
paramNames = append(paramNames, extractParamName(seg))
} }
var child *node var child *node
for _, c := range current.children { for _, c := range current.children {
@ -225,35 +258,36 @@ func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) e
pos = newPos pos = newPos
} }
current.handler = h current.handler = h
current.paramNames = paramNames
return nil return nil
} }
// Lookup finds a handler matching method and path // Lookup finds a handler matching method and path
func (r *Router) Lookup(method, path string) (Handler, []any, bool) { func (r *Router) Lookup(method, path string) (Handler, []string, []string, bool) {
root := r.methodNode(method) root := r.methodNode(method)
if root == nil { if root == nil {
return nil, nil, false return nil, nil, nil, false
} }
if path == "/" { if path == "/" {
return root.handler, nil, root.handler != nil return root.handler, nil, nil, root.handler != nil
} }
buffer := r.paramsBuffer buffer := r.paramsBuffer
if cap(buffer) < int(root.maxParams) { if cap(buffer) < int(root.maxParams) {
buffer = make([]any, root.maxParams) buffer = make([]string, root.maxParams)
r.paramsBuffer = buffer r.paramsBuffer = buffer
} }
buffer = buffer[:0] buffer = buffer[:0]
h, paramCount, found := match(root, path, 0, &buffer) h, paramCount, paramNames, found := match(root, path, 0, &buffer)
if !found { if !found {
return nil, nil, false return nil, nil, nil, false
} }
return h, buffer[:paramCount], true return h, buffer[:paramCount], paramNames, true
} }
func match(current *node, path string, start int, params *[]any) (Handler, int, bool) { func match(current *node, path string, start int, params *[]string) (Handler, int, []string, bool) {
paramCount := 0 paramCount := 0
for _, c := range current.children { for _, c := range current.children {
@ -263,13 +297,13 @@ func match(current *node, path string, start int, params *[]any) (Handler, int,
rem = rem[1:] rem = rem[1:]
} }
*params = append(*params, rem) *params = append(*params, rem)
return c.handler, 1, c.handler != nil return c.handler, 1, c.paramNames, c.handler != nil
} }
} }
seg, pos, more := readSegment(path, start) seg, pos, more := readSegment(path, start)
if seg == "" { if seg == "" {
return current.handler, 0, current.handler != nil return current.handler, 0, current.paramNames, current.handler != nil
} }
for _, c := range current.children { for _, c := range current.children {
@ -279,14 +313,14 @@ func match(current *node, path string, start int, params *[]any) (Handler, int,
paramCount++ paramCount++
} }
if !more { if !more {
return c.handler, paramCount, c.handler != nil return c.handler, paramCount, c.paramNames, c.handler != nil
} }
h, nestedCount, ok := match(c, path, pos, params) h, nestedCount, paramNames, ok := match(c, path, pos, params)
if ok { if ok {
return h, paramCount + nestedCount, true return h, paramCount + nestedCount, paramNames, true
} }
} }
} }
return nil, 0, false return nil, 0, nil, false
} }

View File

@ -4,7 +4,7 @@ import sushi "git.sharkk.net/Sharkk/Sushi"
// Middleware provides session handling // Middleware provides session handling
func Middleware() sushi.Middleware { func Middleware() sushi.Middleware {
return func(ctx sushi.Ctx, params []any, next func()) { return func(ctx sushi.Ctx, next func()) {
sessionID := sushi.GetCookie(ctx, sushi.SessionCookieName) sessionID := sushi.GetCookie(ctx, sushi.SessionCookieName)
var sess *sushi.Session var sess *sushi.Session

View File

@ -8,8 +8,8 @@ import (
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
func (h Handler) Serve(ctx Ctx, params []any) { func (h Handler) Serve(ctx Ctx) {
h(ctx, params) h(ctx)
} }
func IsHTTPS(ctx Ctx) bool { func IsHTTPS(ctx Ctx) bool {
@ -20,7 +20,7 @@ func IsHTTPS(ctx Ctx) bool {
// StandardHandler adapts a standard fasthttp.RequestHandler to the router's Handler // StandardHandler adapts a standard fasthttp.RequestHandler to the router's Handler
func StandardHandler(handler fasthttp.RequestHandler) Handler { func StandardHandler(handler fasthttp.RequestHandler) Handler {
return func(ctx Ctx, _ []any) { return func(ctx Ctx) {
handler(ctx.RequestCtx) handler(ctx.RequestCtx)
} }
} }

View File

@ -11,7 +11,7 @@ const RequestTimerKey = "request_start_time"
// Middleware adds request timing functionality // Middleware adds request timing functionality
func Middleware() sushi.Middleware { func Middleware() sushi.Middleware {
return func(ctx sushi.Ctx, params []any, next func()) { return func(ctx sushi.Ctx, next func()) {
startTime := time.Now() startTime := time.Now()
ctx.SetUserValue(RequestTimerKey, startTime) ctx.SetUserValue(RequestTimerKey, startTime)
next() next()

View File

@ -10,8 +10,8 @@ type Ctx struct {
*fasthttp.RequestCtx *fasthttp.RequestCtx
} }
type Handler func(ctx Ctx, params []any) type Handler func(ctx Ctx)
type Middleware func(ctx Ctx, params []any, next func()) type Middleware func(ctx Ctx, next func())
// SendHTML sends an HTML response // SendHTML sends an HTML response
func (ctx Ctx) SendHTML(html string) { func (ctx Ctx) SendHTML(html string) {