revert to string router

This commit is contained in:
Sky Johnson 2025-06-06 22:25:19 -05:00
parent d44a9b5b28
commit 2c731b9cbf
4 changed files with 432 additions and 139 deletions

View File

@ -125,7 +125,7 @@ func requestMux(ctx *fasthttp.RequestCtx) {
}
// See if the requested route even exists
bytecode, params, found := rtr.Lookup(method, path)
bytecode, params, found := rtr.Lookup(string(method), string(path))
if !found {
http.Send404(ctx)
logRequest(ctx, method, path, start)

View File

@ -1,7 +1,6 @@
package router
import (
"bytes"
"errors"
"os"
"path/filepath"
@ -12,18 +11,9 @@ import (
"github.com/VictoriaMetrics/fastcache"
)
var (
slash = []byte("/")
get = []byte("GET")
post = []byte("POST")
put = []byte("PUT")
patch = []byte("PATCH")
delete = []byte("DELETE")
)
// node represents a node in the radix trie
type node struct {
segment []byte
segment string
bytecode []byte
scriptPath string
children []*node
@ -40,7 +30,7 @@ type Router struct {
compileState *luajit.State
compileMu sync.Mutex
paramsBuffer []string
middlewareFiles map[string][]string
middlewareFiles map[string][]string // filesystem path -> middleware file paths
}
// Params holds URL parameters
@ -81,7 +71,7 @@ func New(routesDir string) (*Router, error) {
put: &node{},
patch: &node{},
delete: &node{},
bytecodeCache: fastcache.New(32 * 1024 * 1024),
bytecodeCache: fastcache.New(32 * 1024 * 1024), // 32MB
compileState: compileState,
paramsBuffer: make([]string, 64),
middlewareFiles: make(map[string][]string),
@ -91,17 +81,17 @@ func New(routesDir string) (*Router, error) {
}
// methodNode returns the root node for a method
func (r *Router) methodNode(method []byte) *node {
switch {
case bytes.Equal(method, get):
func (r *Router) methodNode(method string) *node {
switch method {
case "GET":
return r.get
case bytes.Equal(method, post):
case "POST":
return r.post
case bytes.Equal(method, put):
case "PUT":
return r.put
case bytes.Equal(method, patch):
case "PATCH":
return r.patch
case bytes.Equal(method, delete):
case "DELETE":
return r.delete
default:
return nil
@ -118,8 +108,7 @@ func (r *Router) buildRoutes() error {
return err
}
fileName := strings.TrimSuffix(info.Name(), ".lua")
if fileName == "middleware" {
if strings.TrimSuffix(info.Name(), ".lua") == "middleware" {
relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
if err != nil {
return err
@ -132,6 +121,7 @@ func (r *Router) buildRoutes() error {
r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path)
}
return nil
})
@ -146,36 +136,39 @@ func (r *Router) buildRoutes() error {
}
fileName := strings.TrimSuffix(info.Name(), ".lua")
// Skip middleware files
if fileName == "middleware" {
return nil
}
// Get relative path from routes directory
relPath, err := filepath.Rel(r.routesDir, path)
if err != nil {
return err
}
// Get filesystem path (includes groups)
fsPath := "/" + strings.ReplaceAll(filepath.Dir(relPath), "\\", "/")
if fsPath == "/." {
fsPath = "/"
}
// Get URL path (excludes groups)
urlPath := r.parseURLPath(fsPath)
urlPathBytes := []byte(urlPath)
// Handle method files
methodBytes := []byte(strings.ToUpper(fileName))
root := r.methodNode(methodBytes)
// Handle method files (get.lua, post.lua, etc.)
method := strings.ToUpper(fileName)
root := r.methodNode(method)
if root != nil {
return r.addRoute(root, urlPathBytes, fsPath, path)
return r.addRoute(root, urlPath, fsPath, path)
}
// Handle index files
// Handle index files - register for all methods
if fileName == "index" {
methods := [][]byte{get, post, put, patch, delete}
for _, method := range methods {
for _, method := range []string{"GET", "POST", "PUT", "PATCH", "DELETE"} {
if root := r.methodNode(method); root != nil {
if err := r.addRoute(root, urlPathBytes, fsPath, path); err != nil {
if err := r.addRoute(root, urlPath, fsPath, path); err != nil {
return err
}
}
@ -183,13 +176,12 @@ func (r *Router) buildRoutes() error {
return nil
}
// Handle named route files
var namedPath []byte
// Handle named route files - register as GET by default
namedPath := urlPath
if urlPath == "/" {
namedPath = append(slash, fileName...)
namedPath = "/" + fileName
} else {
namedPath = append(urlPathBytes, '/')
namedPath = append(namedPath, fileName...)
namedPath = urlPath + "/" + fileName
}
return r.addRoute(r.get, namedPath, fsPath, path)
})
@ -204,6 +196,7 @@ func (r *Router) parseURLPath(fsPath string) string {
if segment == "" {
continue
}
// Skip group segments (enclosed in parentheses)
if strings.HasPrefix(segment, "(") && strings.HasSuffix(segment, ")") {
continue
}
@ -225,10 +218,12 @@ func (r *Router) getMiddlewareChain(fsPath string) []string {
pathParts = []string{}
}
// Add root middleware
if mw, exists := r.middlewareFiles["/"]; exists {
chain = append(chain, mw...)
}
// Add middleware from each path level (including groups)
currentPath := ""
for _, part := range pathParts {
currentPath += "/" + part
@ -244,6 +239,7 @@ func (r *Router) getMiddlewareChain(fsPath string) []string {
func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) {
var combined strings.Builder
// Add middleware in order
middlewareChain := r.getMiddlewareChain(fsPath)
for _, mwPath := range middlewareChain {
content, err := os.ReadFile(mwPath)
@ -257,6 +253,7 @@ func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error)
combined.WriteString("\n")
}
// Add main handler
content, err := os.ReadFile(scriptPath)
if err != nil {
return "", err
@ -270,12 +267,14 @@ func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error)
}
// addRoute adds a new route to the trie with bytecode compilation
func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string) error {
func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error {
// Build combined source with middleware
combinedSource, err := r.buildCombinedSource(fsPath, scriptPath)
if err != nil {
return err
}
// Compile bytecode
r.compileMu.Lock()
bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath)
r.compileMu.Unlock()
@ -284,10 +283,11 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string)
return err
}
// Cache bytecode
cacheKey := hashString(scriptPath)
r.bytecodeCache.Set(uint64ToBytes(cacheKey), bytecode)
if len(urlPath) == 1 && urlPath[0] == '/' {
if urlPath == "/" {
root.bytecode = bytecode
root.scriptPath = scriptPath
return nil
@ -298,8 +298,8 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string)
paramCount := uint8(0)
for {
seg, newPos, more := readSegmentBytes(urlPath, pos)
if len(seg) == 0 {
seg, newPos, more := readSegment(urlPath, pos)
if seg == "" {
break
}
@ -314,9 +314,10 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string)
paramCount++
}
// Find or create child
var child *node
for _, c := range current.children {
if bytes.Equal(c.segment, seg) {
if c.segment == seg {
child = c
break
}
@ -324,7 +325,7 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string)
if child == nil {
child = &node{
segment: append([]byte(nil), seg...),
segment: seg,
isDynamic: isDyn,
isWildcard: isWC,
}
@ -344,16 +345,16 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string)
return nil
}
// readSegmentBytes extracts the next path segment from byte slice
func readSegmentBytes(path []byte, start int) (segment []byte, end int, hasMore bool) {
// readSegment extracts the next path segment
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
if start >= len(path) {
return nil, start, false
return "", start, false
}
if path[start] == '/' {
start++
}
if start >= len(path) {
return nil, start, false
return "", start, false
}
end = start
for end < len(path) && path[end] != '/' {
@ -363,19 +364,20 @@ func readSegmentBytes(path []byte, start int) (segment []byte, end int, hasMore
}
// Lookup finds bytecode and parameters for a method and path
func (r *Router) Lookup(method, path []byte) ([]byte, *Params, bool) {
func (r *Router) Lookup(method, path string) ([]byte, *Params, bool) {
root := r.methodNode(method)
if root == nil {
return nil, nil, false
}
if len(path) == 1 && path[0] == '/' {
if path == "/" {
if root.bytecode != nil {
return root.bytecode, &Params{}, true
}
return nil, nil, false
}
// Prepare params buffer
buffer := r.paramsBuffer
if cap(buffer) < int(root.maxParams) {
buffer = make([]string, root.maxParams)
@ -398,7 +400,7 @@ func (r *Router) Lookup(method, path []byte) ([]byte, *Params, bool) {
}
// match traverses the trie to find bytecode
func (r *Router) match(current *node, path []byte, start int, params *[]string, keys *[]string) ([]byte, int, bool) {
func (r *Router) match(current *node, path string, start int, params *[]string, keys *[]string) ([]byte, int, bool) {
paramCount := 0
// Check wildcard first
@ -408,23 +410,22 @@ func (r *Router) match(current *node, path []byte, start int, params *[]string,
if len(rem) > 0 && rem[0] == '/' {
rem = rem[1:]
}
*params = append(*params, string(rem))
paramName := string(c.segment[1:]) // Remove *
*keys = append(*keys, paramName)
*params = append(*params, rem)
*keys = append(*keys, strings.TrimPrefix(c.segment, "*"))
return c.bytecode, 1, c.bytecode != nil
}
}
seg, pos, more := readSegmentBytes(path, start)
if len(seg) == 0 {
seg, pos, more := readSegment(path, start)
if seg == "" {
return current.bytecode, 0, current.bytecode != nil
}
for _, c := range current.children {
if bytes.Equal(c.segment, seg) || c.isDynamic {
if c.segment == seg || c.isDynamic {
if c.isDynamic {
*params = append(*params, string(seg))
paramName := string(c.segment[1 : len(c.segment)-1]) // Remove [ ]
*params = append(*params, seg)
paramName := c.segment[1 : len(c.segment)-1] // Remove [ ]
*keys = append(*keys, paramName)
paramCount++
}
@ -477,6 +478,7 @@ func (r *Router) Close() {
r.compileMu.Unlock()
}
// Helper functions from cache.go
func hashString(s string) uint64 {
h := uint64(5381)
for i := 0; i < len(s); i++ {

318
router/router_test.go Normal file
View File

@ -0,0 +1,318 @@
package router
import (
"os"
"path/filepath"
"testing"
)
func setupTestRoutes(t testing.TB) string {
t.Helper()
tempDir := t.TempDir()
// Create test route files
routes := map[string]string{
"index.lua": `return "home"`,
"about.lua": `return "about"`,
"api/users.lua": `return "users"`,
"api/users/get.lua": `return "get_users"`,
"api/users/post.lua": `return "create_user"`,
"api/users/[id].lua": `return "user_" .. id`,
"api/posts/[slug]/comments.lua": `return "comments_" .. slug`,
"files/*path.lua": `return "file_" .. path`,
"middleware.lua": `-- root middleware`,
"api/middleware.lua": `-- api middleware`,
}
for path, content := range routes {
fullPath := filepath.Join(tempDir, path)
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil {
t.Fatal(err)
}
}
return tempDir
}
func TestRouterBasicFunctionality(t *testing.T) {
routesDir := setupTestRoutes(t)
router, err := New(routesDir)
if err != nil {
t.Fatal(err)
}
defer router.Close()
tests := []struct {
method string
path string
expected bool
params map[string]string
}{
{"GET", "/", true, nil},
{"GET", "/about", true, nil},
{"GET", "/api/users", true, nil},
{"GET", "/api/users", true, nil},
{"POST", "/api/users", true, nil},
{"GET", "/api/users/123", true, map[string]string{"id": "123"}},
{"GET", "/api/posts/hello-world/comments", true, map[string]string{"slug": "hello-world"}},
{"GET", "/files/docs/readme.txt", true, map[string]string{"path": "docs/readme.txt"}},
{"GET", "/nonexistent", false, nil},
{"DELETE", "/api/users", false, nil},
}
for _, tt := range tests {
t.Run(tt.method+"_"+tt.path, func(t *testing.T) {
bytecode, params, found := router.Lookup(tt.method, tt.path)
if found != tt.expected {
t.Errorf("expected found=%v, got %v", tt.expected, found)
}
if tt.expected {
if bytecode == nil {
t.Error("expected bytecode, got nil")
}
if tt.params != nil {
for key, expectedValue := range tt.params {
if actualValue := params.Get(key); actualValue != expectedValue {
t.Errorf("param %s: expected %s, got %s", key, expectedValue, actualValue)
}
}
}
}
})
}
}
func TestRouterParamsStruct(t *testing.T) {
params := &Params{
Keys: []string{"id", "slug"},
Values: []string{"123", "hello"},
}
if params.Get("id") != "123" {
t.Errorf("expected '123', got '%s'", params.Get("id"))
}
if params.Get("slug") != "hello" {
t.Errorf("expected 'hello', got '%s'", params.Get("slug"))
}
if params.Get("missing") != "" {
t.Errorf("expected empty string for missing param, got '%s'", params.Get("missing"))
}
}
func TestRouterMethodNodes(t *testing.T) {
routesDir := setupTestRoutes(t)
router, err := New(routesDir)
if err != nil {
t.Fatal(err)
}
defer router.Close()
// Test that different methods work independently
_, _, foundGet := router.Lookup("GET", "/api/users")
_, _, foundPost := router.Lookup("POST", "/api/users")
_, _, foundPut := router.Lookup("PUT", "/api/users")
if !foundGet {
t.Error("GET /api/users should be found")
}
if !foundPost {
t.Error("POST /api/users should be found")
}
if foundPut {
t.Error("PUT /api/users should not be found")
}
}
func TestRouterWildcardValidation(t *testing.T) {
tempDir := t.TempDir()
// Create invalid wildcard route (not at end)
invalidPath := filepath.Join(tempDir, "bad/*path/more.lua")
if err := os.MkdirAll(filepath.Dir(invalidPath), 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(invalidPath, []byte(`return "bad"`), 0644); err != nil {
t.Fatal(err)
}
_, err := New(tempDir)
if err == nil {
t.Error("expected error for wildcard not at end")
}
}
func BenchmarkLookupStatic(b *testing.B) {
routesDir := setupTestRoutes(b)
router, err := New(routesDir)
if err != nil {
b.Fatal(err)
}
defer router.Close()
method := "GET"
path := "/api/users"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _, _ = router.Lookup(method, path)
}
}
func BenchmarkLookupDynamic(b *testing.B) {
routesDir := setupTestRoutes(b)
router, err := New(routesDir)
if err != nil {
b.Fatal(err)
}
defer router.Close()
method := "GET"
path := "/api/users/12345"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _, _ = router.Lookup(method, path)
}
}
func BenchmarkLookupWildcard(b *testing.B) {
routesDir := setupTestRoutes(b)
router, err := New(routesDir)
if err != nil {
b.Fatal(err)
}
defer router.Close()
method := "GET"
path := "/files/docs/deep/nested/file.txt"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _, _ = router.Lookup(method, path)
}
}
func BenchmarkLookupComplex(b *testing.B) {
routesDir := setupTestRoutes(b)
router, err := New(routesDir)
if err != nil {
b.Fatal(err)
}
defer router.Close()
method := "GET"
path := "/api/posts/my-blog-post-title/comments"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _, _ = router.Lookup(method, path)
}
}
func BenchmarkLookupNotFound(b *testing.B) {
routesDir := setupTestRoutes(b)
router, err := New(routesDir)
if err != nil {
b.Fatal(err)
}
defer router.Close()
method := "GET"
path := "/this/path/does/not/exist"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _, _ = router.Lookup(method, path)
}
}
func BenchmarkLookupMixed(b *testing.B) {
routesDir := setupTestRoutes(b)
router, err := New(routesDir)
if err != nil {
b.Fatal(err)
}
defer router.Close()
paths := []string{
"/",
"/about",
"/api/users",
"/api/users/123",
"/api/posts/hello/comments",
"/files/document.pdf",
"/nonexistent",
}
method := "GET"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
path := paths[i%len(paths)]
_, _, _ = router.Lookup(method, path)
}
}
// Comparison benchmarks for string vs byte slice performance
func BenchmarkLookupStringConversion(b *testing.B) {
routesDir := setupTestRoutes(b)
router, err := New(routesDir)
if err != nil {
b.Fatal(err)
}
defer router.Close()
methodStr := "GET"
pathStr := "/api/users/12345"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Direct string usage
_, _, _ = router.Lookup(methodStr, pathStr)
}
}
func BenchmarkLookupPreallocated(b *testing.B) {
routesDir := setupTestRoutes(b)
router, err := New(routesDir)
if err != nil {
b.Fatal(err)
}
defer router.Close()
// Pre-allocated strings (optimal case)
method := "GET"
path := "/api/users/12345"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _, _ = router.Lookup(method, path)
}
}

View File

@ -12,31 +12,30 @@ import (
// A prebuilt, ready-to-go context for HTTP requests to the runner.
type HTTPContext struct {
Method []byte
Path []byte
Host []byte
Headers map[string]any
Cookies map[string]string
Query map[string]string
Params map[string]any
Form map[string]any
Session map[string]any
Env map[string]any
Values map[string]any // Extra context vars just in case
Values map[string]any // Contains all context data for Lua
// Separate maps for efficient access during context building
headers map[string]string
cookies map[string]string
query map[string]string
params map[string]string
form map[string]any
session map[string]any
env map[string]any
}
// HTTP context pool to reduce allocations
var httpContextPool = sync.Pool{
New: func() any {
return &HTTPContext{
Headers: make(map[string]any, 16),
Cookies: make(map[string]string, 8),
Query: make(map[string]string, 8),
Params: make(map[string]any, 4),
Form: make(map[string]any, 8),
Session: make(map[string]any, 4),
Env: make(map[string]any, 16),
Values: make(map[string]any, 32),
headers: make(map[string]string, 16),
cookies: make(map[string]string, 8),
query: make(map[string]string, 8),
params: make(map[string]string, 4),
form: make(map[string]any, 8),
session: make(map[string]any, 4),
env: make(map[string]any, 16),
}
},
}
@ -45,30 +44,25 @@ var httpContextPool = sync.Pool{
func NewHTTPContext(httpCtx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) *HTTPContext {
ctx := httpContextPool.Get().(*HTTPContext)
// Extract basic HTTP info
ctx.Method = httpCtx.Method()
ctx.Path = httpCtx.Path()
ctx.Host = httpCtx.Host()
// Extract headers
httpCtx.Request.Header.VisitAll(func(key, value []byte) {
ctx.Headers[string(key)] = string(value)
ctx.headers[string(key)] = string(value)
})
// Extract cookies
httpCtx.Request.Header.VisitAllCookie(func(key, value []byte) {
ctx.Cookies[string(key)] = string(value)
ctx.cookies[string(key)] = string(value)
})
// Extract query params
httpCtx.QueryArgs().VisitAll(func(key, value []byte) {
ctx.Query[string(key)] = string(value)
ctx.query[string(key)] = string(value)
})
// Extract route parameters
if params != nil {
for i := 0; i < min(len(params.Keys), len(params.Values)); i++ {
ctx.Params[params.Keys[i]] = params.Values[i]
for i := range min(len(params.Keys), len(params.Values)) {
ctx.params[params.Keys[i]] = params.Values[i]
}
}
@ -76,90 +70,69 @@ func NewHTTPContext(httpCtx *fasthttp.RequestCtx, params *router.Params, session
if httpCtx.IsPost() || httpCtx.IsPut() || httpCtx.IsPatch() {
if form, err := utils.ParseForm(httpCtx); err == nil {
for k, v := range form {
ctx.Form[k] = v
ctx.form[k] = v
}
}
}
// Extract session data
session.AdvanceFlash()
ctx.Session["id"] = session.ID
ctx.session["id"] = session.ID
if session.IsEmpty() {
ctx.Session["data"] = emptyMap
ctx.Session["flash"] = emptyMap
ctx.session["data"] = emptyMap
ctx.session["flash"] = emptyMap
} else {
ctx.Session["data"] = session.GetAll()
ctx.Session["flash"] = session.GetAllFlash()
ctx.session["data"] = session.GetAll()
ctx.session["flash"] = session.GetAllFlash()
}
// Add environment vars
if envMgr := lualibs.GetGlobalEnvManager(); envMgr != nil {
for k, v := range envMgr.GetAll() {
ctx.Env[k] = v
ctx.env[k] = v
}
}
// Populate Values with all context data
ctx.Values["method"] = string(httpCtx.Method())
ctx.Values["path"] = string(httpCtx.Path())
ctx.Values["host"] = string(httpCtx.Host())
ctx.Values["headers"] = ctx.headers
ctx.Values["cookies"] = ctx.cookies
ctx.Values["query"] = ctx.query
ctx.Values["params"] = ctx.params
ctx.Values["form"] = ctx.form
ctx.Values["session"] = ctx.session
ctx.Values["env"] = ctx.env
return ctx
}
// Clear out all the request data from the context and give it back to the pool. Keeps the contexts and inner maps
// allocated to prevent GC churn.
// Clear out all the request data from the context and give it back to the pool.
func (c *HTTPContext) Release() {
for k := range c.Headers {
delete(c.Headers, k)
}
for k := range c.Cookies {
delete(c.Cookies, k)
}
for k := range c.Query {
delete(c.Query, k)
}
for k := range c.Params {
delete(c.Params, k)
}
for k := range c.Form {
delete(c.Form, k)
}
for k := range c.Session {
delete(c.Session, k)
}
for k := range c.Env {
delete(c.Env, k)
}
for k := range c.Values {
delete(c.Values, k)
}
c.Method = nil
c.Path = nil
c.Host = nil
clear(c.Values)
clear(c.headers)
clear(c.cookies)
clear(c.query)
clear(c.params)
clear(c.form)
clear(c.session)
clear(c.env)
httpContextPool.Put(c)
}
// Add a value to the extras map
// Add a value to the extras section
func (c *HTTPContext) Set(key string, value any) {
c.Values[key] = value
}
// Get a value from the extras map
// Get a value from the context
func (c *HTTPContext) Get(key string) any {
return c.Values[key]
}
// Returns a representation of the context ready for Lua
// Returns the Values map directly - zero overhead
func (c *HTTPContext) ToMap() map[string]any {
return map[string]any{
"method": string(c.Method),
"path": string(c.Path),
"host": string(c.Host),
"headers": c.Headers,
"cookies": c.Cookies,
"query": c.Query,
"params": c.Params,
"form": c.Form,
"session": c.Session,
"env": c.Env,
"values": c.Values,
}
return c.Values
}