diff --git a/auth/middleware.go b/auth/middleware.go index 20f903a..0b19b0b 100644 --- a/auth/middleware.go +++ b/auth/middleware.go @@ -9,7 +9,7 @@ const UserCtxKey = "user" // Middleware adds authentication handling func Middleware(userLookup func(int) any) sushi.Middleware { - return func(ctx sushi.Ctx, params []any, next func()) { + return func(ctx sushi.Ctx, next func()) { sess := sushi.GetCurrentSession(ctx) if sess != nil && sess.UserID > 0 && userLookup != nil { user := userLookup(sess.UserID) @@ -31,7 +31,7 @@ func RequireAuth(redirectPath ...string) sushi.Middleware { redirect = redirectPath[0] } - return func(ctx sushi.Ctx, params []any, next func()) { + return func(ctx sushi.Ctx, next func()) { if !ctx.IsAuthenticated() { ctx.Redirect(redirect, fasthttp.StatusFound) return @@ -47,7 +47,7 @@ func RequireGuest(redirectPath ...string) sushi.Middleware { redirect = redirectPath[0] } - return func(ctx sushi.Ctx, params []any, next func()) { + return func(ctx sushi.Ctx, next func()) { if ctx.IsAuthenticated() { ctx.Redirect(redirect, fasthttp.StatusFound) return diff --git a/csrf/csrf.go b/csrf/csrf.go index 7f15fac..65d423f 100644 --- a/csrf/csrf.go +++ b/csrf/csrf.go @@ -110,7 +110,7 @@ func ValidateFormCSRFToken(ctx sushi.Ctx) bool { // Middleware returns middleware that automatically validates CSRF tokens func Middleware() sushi.Middleware { - return func(ctx sushi.Ctx, params []any, next func()) { + return func(ctx sushi.Ctx, next func()) { method := string(ctx.Method()) if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { diff --git a/fs.go b/fs.go index e44c520..85074b5 100644 --- a/fs.go +++ b/fs.go @@ -50,7 +50,7 @@ func StaticFS(fsOptions StaticOptions) Handler { fsHandler := fs.NewRequestHandler() - return func(ctx Ctx, params []any) { + return func(ctx Ctx) { fsHandler(ctx.RequestCtx) } } @@ -62,14 +62,14 @@ func Static(root string) Handler { // StaticFile serves a single file func StaticFile(filePath string) Handler { - return func(ctx Ctx, params []any) { + return func(ctx Ctx) { fasthttp.ServeFile(ctx.RequestCtx, filePath) } } // StaticEmbed creates a handler for embedded files func StaticEmbed(files map[string][]byte) Handler { - return func(ctx Ctx, params []any) { + return func(ctx Ctx) { requestPath := string(ctx.Path()) // Try to find the file diff --git a/params.go b/params.go new file mode 100644 index 0000000..d4c9144 --- /dev/null +++ b/params.go @@ -0,0 +1,114 @@ +package sushi + +import ( + "strconv" + "strings" +) + +const RouteParamsCtxKey = "route_params" + +type ParamValue struct { + value string + exists bool +} + +// RouteParam gets a route parameter by index for chaining +func (ctx Ctx) RouteParam(index int) ParamValue { + if params, ok := ctx.UserValue(RouteParamsCtxKey).([]string); ok { + if index >= 0 && index < len(params) { + return ParamValue{value: params[index], exists: true} + } + } + return ParamValue{value: "", exists: false} +} + +// Param gets a route parameter by name for chaining (requires named params) +func (ctx Ctx) Param(name string) ParamValue { + if paramMap, ok := ctx.UserValue("param_names").(map[string]string); ok { + if value, exists := paramMap[name]; exists { + return ParamValue{value: value, exists: true} + } + } + return ParamValue{value: "", exists: false} +} + +// String returns the value as string +func (p ParamValue) String() string { + return p.value +} + +// StringDefault returns string with default value +func (p ParamValue) StringDefault(defaultValue string) string { + if p.value == "" { + return defaultValue + } + return p.value +} + +// Int returns the value as integer +func (p ParamValue) Int() int { + if p.value == "" { + return 0 + } + if parsed, err := strconv.Atoi(p.value); err == nil { + return parsed + } + return 0 +} + +// IntDefault returns integer with default value +func (p ParamValue) IntDefault(defaultValue int) int { + if p.value == "" { + return defaultValue + } + if parsed, err := strconv.Atoi(p.value); err == nil { + return parsed + } + return defaultValue +} + +// Float returns the value as float64 +func (p ParamValue) Float() float64 { + if p.value == "" { + return 0.0 + } + if parsed, err := strconv.ParseFloat(p.value, 64); err == nil { + return parsed + } + return 0.0 +} + +// FloatDefault returns float64 with default value +func (p ParamValue) FloatDefault(defaultValue float64) float64 { + if p.value == "" { + return defaultValue + } + if parsed, err := strconv.ParseFloat(p.value, 64); err == nil { + return parsed + } + return defaultValue +} + +// Bool returns the value as boolean +func (p ParamValue) Bool() bool { + value := strings.ToLower(p.value) + return value == "true" || value == "on" || value == "1" || value == "yes" +} + +// BoolDefault returns boolean with default value +func (p ParamValue) BoolDefault(defaultValue bool) bool { + if p.value == "" { + return defaultValue + } + return p.Bool() +} + +// Exists returns true if the parameter exists +func (p ParamValue) Exists() bool { + return p.exists +} + +// IsEmpty returns true if the parameter is empty or doesn't exist +func (p ParamValue) IsEmpty() bool { + return p.value == "" +} diff --git a/router.go b/router.go index 5e86dfc..944881a 100644 --- a/router.go +++ b/router.go @@ -12,6 +12,7 @@ type node struct { children []*node isDynamic bool isWildcard bool + paramNames []string maxParams uint8 } @@ -22,7 +23,7 @@ type Router struct { patch *node delete *node middleware []Middleware - paramsBuffer []any + paramsBuffer []string } type Group struct { @@ -40,7 +41,7 @@ func NewRouter() *Router { patch: &node{}, delete: &node{}, middleware: []Middleware{}, - paramsBuffer: make([]any, 64), + paramsBuffer: make([]string, 64), } } @@ -49,13 +50,32 @@ func (r *Router) ServeHTTP(ctx *fasthttp.RequestCtx) { path := string(ctx.Path()) method := string(ctx.Method()) - h, params, found := r.Lookup(method, path) + h, params, paramNames, found := r.Lookup(method, path) if !found { ctx.SetStatusCode(fasthttp.StatusNotFound) return } - h(Ctx{ctx}, params) + // Store params in context + sushiCtx := Ctx{ctx} + if len(params) > 0 { + sushiCtx.SetUserValue(RouteParamsCtxKey, params) + + // Create named param map if param names exist + if len(paramNames) > 0 { + paramMap := make(map[string]string) + for i, name := range paramNames { + if i < len(params) && name != "" { + paramMap[name] = params[i] + } + } + if len(paramMap) > 0 { + sushiCtx.SetUserValue("param_names", paramMap) + } + } + } + + h(sushiCtx) } // Handler returns a fasthttp request handler @@ -145,19 +165,19 @@ func applyMiddleware(h Handler, mw []Middleware) Handler { return h } - return func(ctx Ctx, params []any) { + return func(ctx Ctx) { var index int var next func() next = func() { if index >= len(mw) { - h(ctx, params) + h(ctx) return } currentMW := mw[index] index++ - currentMW(ctx, params, next) + currentMW(ctx, next) } next() @@ -181,6 +201,16 @@ func readSegment(path string, start int) (segment string, end int, hasMore bool) return path[start:end], end, end < len(path) } +func extractParamName(segment string) string { + if len(segment) > 1 && segment[0] == ':' { + return segment[1:] + } + if len(segment) > 1 && segment[0] == '*' { + return segment[1:] + } + return "" +} + func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) error { h = applyMiddleware(h, mw) if path == "/" { @@ -191,6 +221,8 @@ func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) e pos := 0 lastWC := false count := uint8(0) + var paramNames []string + for { seg, newPos, more := readSegment(path, pos) if seg == "" { @@ -206,6 +238,7 @@ func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) e } if isDyn || isWC { count++ + paramNames = append(paramNames, extractParamName(seg)) } var child *node for _, c := range current.children { @@ -225,35 +258,36 @@ func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) e pos = newPos } current.handler = h + current.paramNames = paramNames return nil } // Lookup finds a handler matching method and path -func (r *Router) Lookup(method, path string) (Handler, []any, bool) { +func (r *Router) Lookup(method, path string) (Handler, []string, []string, bool) { root := r.methodNode(method) if root == nil { - return nil, nil, false + return nil, nil, nil, false } if path == "/" { - return root.handler, nil, root.handler != nil + return root.handler, nil, nil, root.handler != nil } buffer := r.paramsBuffer if cap(buffer) < int(root.maxParams) { - buffer = make([]any, root.maxParams) + buffer = make([]string, root.maxParams) r.paramsBuffer = buffer } buffer = buffer[:0] - h, paramCount, found := match(root, path, 0, &buffer) + h, paramCount, paramNames, found := match(root, path, 0, &buffer) if !found { - return nil, nil, false + return nil, nil, nil, false } - return h, buffer[:paramCount], true + return h, buffer[:paramCount], paramNames, true } -func match(current *node, path string, start int, params *[]any) (Handler, int, bool) { +func match(current *node, path string, start int, params *[]string) (Handler, int, []string, bool) { paramCount := 0 for _, c := range current.children { @@ -263,13 +297,13 @@ func match(current *node, path string, start int, params *[]any) (Handler, int, rem = rem[1:] } *params = append(*params, rem) - return c.handler, 1, c.handler != nil + return c.handler, 1, c.paramNames, c.handler != nil } } seg, pos, more := readSegment(path, start) if seg == "" { - return current.handler, 0, current.handler != nil + return current.handler, 0, current.paramNames, current.handler != nil } for _, c := range current.children { @@ -279,14 +313,14 @@ func match(current *node, path string, start int, params *[]any) (Handler, int, paramCount++ } if !more { - return c.handler, paramCount, c.handler != nil + return c.handler, paramCount, c.paramNames, c.handler != nil } - h, nestedCount, ok := match(c, path, pos, params) + h, nestedCount, paramNames, ok := match(c, path, pos, params) if ok { - return h, paramCount + nestedCount, true + return h, paramCount + nestedCount, paramNames, true } } } - return nil, 0, false + return nil, 0, nil, false } diff --git a/session/middleware.go b/session/middleware.go index 675cbac..24a0046 100644 --- a/session/middleware.go +++ b/session/middleware.go @@ -4,7 +4,7 @@ import sushi "git.sharkk.net/Sharkk/Sushi" // Middleware provides session handling func Middleware() sushi.Middleware { - return func(ctx sushi.Ctx, params []any, next func()) { + return func(ctx sushi.Ctx, next func()) { sessionID := sushi.GetCookie(ctx, sushi.SessionCookieName) var sess *sushi.Session diff --git a/sushi.go b/sushi.go index 8094167..9991c15 100644 --- a/sushi.go +++ b/sushi.go @@ -8,8 +8,8 @@ import ( "github.com/valyala/fasthttp" ) -func (h Handler) Serve(ctx Ctx, params []any) { - h(ctx, params) +func (h Handler) Serve(ctx Ctx) { + h(ctx) } func IsHTTPS(ctx Ctx) bool { @@ -20,7 +20,7 @@ func IsHTTPS(ctx Ctx) bool { // StandardHandler adapts a standard fasthttp.RequestHandler to the router's Handler func StandardHandler(handler fasthttp.RequestHandler) Handler { - return func(ctx Ctx, _ []any) { + return func(ctx Ctx) { handler(ctx.RequestCtx) } } diff --git a/timing/timing.go b/timing/timing.go index cd82553..b8b95d5 100644 --- a/timing/timing.go +++ b/timing/timing.go @@ -11,7 +11,7 @@ const RequestTimerKey = "request_start_time" // Middleware adds request timing functionality func Middleware() sushi.Middleware { - return func(ctx sushi.Ctx, params []any, next func()) { + return func(ctx sushi.Ctx, next func()) { startTime := time.Now() ctx.SetUserValue(RequestTimerKey, startTime) next() diff --git a/types.go b/types.go index 7510328..ba6c033 100644 --- a/types.go +++ b/types.go @@ -10,8 +10,8 @@ type Ctx struct { *fasthttp.RequestCtx } -type Handler func(ctx Ctx, params []any) -type Middleware func(ctx Ctx, params []any, next func()) +type Handler func(ctx Ctx) +type Middleware func(ctx Ctx, next func()) // SendHTML sends an HTML response func (ctx Ctx) SendHTML(html string) {