revert to string router
This commit is contained in:
parent
d44a9b5b28
commit
2c731b9cbf
@ -125,7 +125,7 @@ func requestMux(ctx *fasthttp.RequestCtx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// See if the requested route even exists
|
// See if the requested route even exists
|
||||||
bytecode, params, found := rtr.Lookup(method, path)
|
bytecode, params, found := rtr.Lookup(string(method), string(path))
|
||||||
if !found {
|
if !found {
|
||||||
http.Send404(ctx)
|
http.Send404(ctx)
|
||||||
logRequest(ctx, method, path, start)
|
logRequest(ctx, method, path, start)
|
||||||
|
116
router/router.go
116
router/router.go
@ -1,7 +1,6 @@
|
|||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -12,18 +11,9 @@ import (
|
|||||||
"github.com/VictoriaMetrics/fastcache"
|
"github.com/VictoriaMetrics/fastcache"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
slash = []byte("/")
|
|
||||||
get = []byte("GET")
|
|
||||||
post = []byte("POST")
|
|
||||||
put = []byte("PUT")
|
|
||||||
patch = []byte("PATCH")
|
|
||||||
delete = []byte("DELETE")
|
|
||||||
)
|
|
||||||
|
|
||||||
// node represents a node in the radix trie
|
// node represents a node in the radix trie
|
||||||
type node struct {
|
type node struct {
|
||||||
segment []byte
|
segment string
|
||||||
bytecode []byte
|
bytecode []byte
|
||||||
scriptPath string
|
scriptPath string
|
||||||
children []*node
|
children []*node
|
||||||
@ -40,7 +30,7 @@ type Router struct {
|
|||||||
compileState *luajit.State
|
compileState *luajit.State
|
||||||
compileMu sync.Mutex
|
compileMu sync.Mutex
|
||||||
paramsBuffer []string
|
paramsBuffer []string
|
||||||
middlewareFiles map[string][]string
|
middlewareFiles map[string][]string // filesystem path -> middleware file paths
|
||||||
}
|
}
|
||||||
|
|
||||||
// Params holds URL parameters
|
// Params holds URL parameters
|
||||||
@ -81,7 +71,7 @@ func New(routesDir string) (*Router, error) {
|
|||||||
put: &node{},
|
put: &node{},
|
||||||
patch: &node{},
|
patch: &node{},
|
||||||
delete: &node{},
|
delete: &node{},
|
||||||
bytecodeCache: fastcache.New(32 * 1024 * 1024),
|
bytecodeCache: fastcache.New(32 * 1024 * 1024), // 32MB
|
||||||
compileState: compileState,
|
compileState: compileState,
|
||||||
paramsBuffer: make([]string, 64),
|
paramsBuffer: make([]string, 64),
|
||||||
middlewareFiles: make(map[string][]string),
|
middlewareFiles: make(map[string][]string),
|
||||||
@ -91,17 +81,17 @@ func New(routesDir string) (*Router, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// methodNode returns the root node for a method
|
// methodNode returns the root node for a method
|
||||||
func (r *Router) methodNode(method []byte) *node {
|
func (r *Router) methodNode(method string) *node {
|
||||||
switch {
|
switch method {
|
||||||
case bytes.Equal(method, get):
|
case "GET":
|
||||||
return r.get
|
return r.get
|
||||||
case bytes.Equal(method, post):
|
case "POST":
|
||||||
return r.post
|
return r.post
|
||||||
case bytes.Equal(method, put):
|
case "PUT":
|
||||||
return r.put
|
return r.put
|
||||||
case bytes.Equal(method, patch):
|
case "PATCH":
|
||||||
return r.patch
|
return r.patch
|
||||||
case bytes.Equal(method, delete):
|
case "DELETE":
|
||||||
return r.delete
|
return r.delete
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
@ -118,8 +108,7 @@ func (r *Router) buildRoutes() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := strings.TrimSuffix(info.Name(), ".lua")
|
if strings.TrimSuffix(info.Name(), ".lua") == "middleware" {
|
||||||
if fileName == "middleware" {
|
|
||||||
relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
|
relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -132,6 +121,7 @@ func (r *Router) buildRoutes() error {
|
|||||||
|
|
||||||
r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path)
|
r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -146,36 +136,39 @@ func (r *Router) buildRoutes() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fileName := strings.TrimSuffix(info.Name(), ".lua")
|
fileName := strings.TrimSuffix(info.Name(), ".lua")
|
||||||
|
|
||||||
|
// Skip middleware files
|
||||||
if fileName == "middleware" {
|
if fileName == "middleware" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get relative path from routes directory
|
||||||
relPath, err := filepath.Rel(r.routesDir, path)
|
relPath, err := filepath.Rel(r.routesDir, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get filesystem path (includes groups)
|
||||||
fsPath := "/" + strings.ReplaceAll(filepath.Dir(relPath), "\\", "/")
|
fsPath := "/" + strings.ReplaceAll(filepath.Dir(relPath), "\\", "/")
|
||||||
if fsPath == "/." {
|
if fsPath == "/." {
|
||||||
fsPath = "/"
|
fsPath = "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get URL path (excludes groups)
|
||||||
urlPath := r.parseURLPath(fsPath)
|
urlPath := r.parseURLPath(fsPath)
|
||||||
urlPathBytes := []byte(urlPath)
|
|
||||||
|
|
||||||
// Handle method files
|
// Handle method files (get.lua, post.lua, etc.)
|
||||||
methodBytes := []byte(strings.ToUpper(fileName))
|
method := strings.ToUpper(fileName)
|
||||||
root := r.methodNode(methodBytes)
|
root := r.methodNode(method)
|
||||||
if root != nil {
|
if root != nil {
|
||||||
return r.addRoute(root, urlPathBytes, fsPath, path)
|
return r.addRoute(root, urlPath, fsPath, path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle index files
|
// Handle index files - register for all methods
|
||||||
if fileName == "index" {
|
if fileName == "index" {
|
||||||
methods := [][]byte{get, post, put, patch, delete}
|
for _, method := range []string{"GET", "POST", "PUT", "PATCH", "DELETE"} {
|
||||||
for _, method := range methods {
|
|
||||||
if root := r.methodNode(method); root != nil {
|
if root := r.methodNode(method); root != nil {
|
||||||
if err := r.addRoute(root, urlPathBytes, fsPath, path); err != nil {
|
if err := r.addRoute(root, urlPath, fsPath, path); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -183,13 +176,12 @@ func (r *Router) buildRoutes() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle named route files
|
// Handle named route files - register as GET by default
|
||||||
var namedPath []byte
|
namedPath := urlPath
|
||||||
if urlPath == "/" {
|
if urlPath == "/" {
|
||||||
namedPath = append(slash, fileName...)
|
namedPath = "/" + fileName
|
||||||
} else {
|
} else {
|
||||||
namedPath = append(urlPathBytes, '/')
|
namedPath = urlPath + "/" + fileName
|
||||||
namedPath = append(namedPath, fileName...)
|
|
||||||
}
|
}
|
||||||
return r.addRoute(r.get, namedPath, fsPath, path)
|
return r.addRoute(r.get, namedPath, fsPath, path)
|
||||||
})
|
})
|
||||||
@ -204,6 +196,7 @@ func (r *Router) parseURLPath(fsPath string) string {
|
|||||||
if segment == "" {
|
if segment == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Skip group segments (enclosed in parentheses)
|
||||||
if strings.HasPrefix(segment, "(") && strings.HasSuffix(segment, ")") {
|
if strings.HasPrefix(segment, "(") && strings.HasSuffix(segment, ")") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -225,10 +218,12 @@ func (r *Router) getMiddlewareChain(fsPath string) []string {
|
|||||||
pathParts = []string{}
|
pathParts = []string{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add root middleware
|
||||||
if mw, exists := r.middlewareFiles["/"]; exists {
|
if mw, exists := r.middlewareFiles["/"]; exists {
|
||||||
chain = append(chain, mw...)
|
chain = append(chain, mw...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add middleware from each path level (including groups)
|
||||||
currentPath := ""
|
currentPath := ""
|
||||||
for _, part := range pathParts {
|
for _, part := range pathParts {
|
||||||
currentPath += "/" + part
|
currentPath += "/" + part
|
||||||
@ -244,6 +239,7 @@ func (r *Router) getMiddlewareChain(fsPath string) []string {
|
|||||||
func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) {
|
func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) {
|
||||||
var combined strings.Builder
|
var combined strings.Builder
|
||||||
|
|
||||||
|
// Add middleware in order
|
||||||
middlewareChain := r.getMiddlewareChain(fsPath)
|
middlewareChain := r.getMiddlewareChain(fsPath)
|
||||||
for _, mwPath := range middlewareChain {
|
for _, mwPath := range middlewareChain {
|
||||||
content, err := os.ReadFile(mwPath)
|
content, err := os.ReadFile(mwPath)
|
||||||
@ -257,6 +253,7 @@ func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error)
|
|||||||
combined.WriteString("\n")
|
combined.WriteString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add main handler
|
||||||
content, err := os.ReadFile(scriptPath)
|
content, err := os.ReadFile(scriptPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@ -270,12 +267,14 @@ func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addRoute adds a new route to the trie with bytecode compilation
|
// addRoute adds a new route to the trie with bytecode compilation
|
||||||
func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string) error {
|
func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error {
|
||||||
|
// Build combined source with middleware
|
||||||
combinedSource, err := r.buildCombinedSource(fsPath, scriptPath)
|
combinedSource, err := r.buildCombinedSource(fsPath, scriptPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compile bytecode
|
||||||
r.compileMu.Lock()
|
r.compileMu.Lock()
|
||||||
bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath)
|
bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath)
|
||||||
r.compileMu.Unlock()
|
r.compileMu.Unlock()
|
||||||
@ -284,10 +283,11 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cache bytecode
|
||||||
cacheKey := hashString(scriptPath)
|
cacheKey := hashString(scriptPath)
|
||||||
r.bytecodeCache.Set(uint64ToBytes(cacheKey), bytecode)
|
r.bytecodeCache.Set(uint64ToBytes(cacheKey), bytecode)
|
||||||
|
|
||||||
if len(urlPath) == 1 && urlPath[0] == '/' {
|
if urlPath == "/" {
|
||||||
root.bytecode = bytecode
|
root.bytecode = bytecode
|
||||||
root.scriptPath = scriptPath
|
root.scriptPath = scriptPath
|
||||||
return nil
|
return nil
|
||||||
@ -298,8 +298,8 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string)
|
|||||||
paramCount := uint8(0)
|
paramCount := uint8(0)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
seg, newPos, more := readSegmentBytes(urlPath, pos)
|
seg, newPos, more := readSegment(urlPath, pos)
|
||||||
if len(seg) == 0 {
|
if seg == "" {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -314,9 +314,10 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string)
|
|||||||
paramCount++
|
paramCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Find or create child
|
||||||
var child *node
|
var child *node
|
||||||
for _, c := range current.children {
|
for _, c := range current.children {
|
||||||
if bytes.Equal(c.segment, seg) {
|
if c.segment == seg {
|
||||||
child = c
|
child = c
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -324,7 +325,7 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string)
|
|||||||
|
|
||||||
if child == nil {
|
if child == nil {
|
||||||
child = &node{
|
child = &node{
|
||||||
segment: append([]byte(nil), seg...),
|
segment: seg,
|
||||||
isDynamic: isDyn,
|
isDynamic: isDyn,
|
||||||
isWildcard: isWC,
|
isWildcard: isWC,
|
||||||
}
|
}
|
||||||
@ -344,16 +345,16 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// readSegmentBytes extracts the next path segment from byte slice
|
// readSegment extracts the next path segment
|
||||||
func readSegmentBytes(path []byte, start int) (segment []byte, end int, hasMore bool) {
|
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
|
||||||
if start >= len(path) {
|
if start >= len(path) {
|
||||||
return nil, start, false
|
return "", start, false
|
||||||
}
|
}
|
||||||
if path[start] == '/' {
|
if path[start] == '/' {
|
||||||
start++
|
start++
|
||||||
}
|
}
|
||||||
if start >= len(path) {
|
if start >= len(path) {
|
||||||
return nil, start, false
|
return "", start, false
|
||||||
}
|
}
|
||||||
end = start
|
end = start
|
||||||
for end < len(path) && path[end] != '/' {
|
for end < len(path) && path[end] != '/' {
|
||||||
@ -363,19 +364,20 @@ func readSegmentBytes(path []byte, start int) (segment []byte, end int, hasMore
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Lookup finds bytecode and parameters for a method and path
|
// Lookup finds bytecode and parameters for a method and path
|
||||||
func (r *Router) Lookup(method, path []byte) ([]byte, *Params, bool) {
|
func (r *Router) Lookup(method, path string) ([]byte, *Params, bool) {
|
||||||
root := r.methodNode(method)
|
root := r.methodNode(method)
|
||||||
if root == nil {
|
if root == nil {
|
||||||
return nil, nil, false
|
return nil, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(path) == 1 && path[0] == '/' {
|
if path == "/" {
|
||||||
if root.bytecode != nil {
|
if root.bytecode != nil {
|
||||||
return root.bytecode, &Params{}, true
|
return root.bytecode, &Params{}, true
|
||||||
}
|
}
|
||||||
return nil, nil, false
|
return nil, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prepare params buffer
|
||||||
buffer := r.paramsBuffer
|
buffer := r.paramsBuffer
|
||||||
if cap(buffer) < int(root.maxParams) {
|
if cap(buffer) < int(root.maxParams) {
|
||||||
buffer = make([]string, root.maxParams)
|
buffer = make([]string, root.maxParams)
|
||||||
@ -398,7 +400,7 @@ func (r *Router) Lookup(method, path []byte) ([]byte, *Params, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// match traverses the trie to find bytecode
|
// match traverses the trie to find bytecode
|
||||||
func (r *Router) match(current *node, path []byte, start int, params *[]string, keys *[]string) ([]byte, int, bool) {
|
func (r *Router) match(current *node, path string, start int, params *[]string, keys *[]string) ([]byte, int, bool) {
|
||||||
paramCount := 0
|
paramCount := 0
|
||||||
|
|
||||||
// Check wildcard first
|
// Check wildcard first
|
||||||
@ -408,23 +410,22 @@ func (r *Router) match(current *node, path []byte, start int, params *[]string,
|
|||||||
if len(rem) > 0 && rem[0] == '/' {
|
if len(rem) > 0 && rem[0] == '/' {
|
||||||
rem = rem[1:]
|
rem = rem[1:]
|
||||||
}
|
}
|
||||||
*params = append(*params, string(rem))
|
*params = append(*params, rem)
|
||||||
paramName := string(c.segment[1:]) // Remove *
|
*keys = append(*keys, strings.TrimPrefix(c.segment, "*"))
|
||||||
*keys = append(*keys, paramName)
|
|
||||||
return c.bytecode, 1, c.bytecode != nil
|
return c.bytecode, 1, c.bytecode != nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
seg, pos, more := readSegmentBytes(path, start)
|
seg, pos, more := readSegment(path, start)
|
||||||
if len(seg) == 0 {
|
if seg == "" {
|
||||||
return current.bytecode, 0, current.bytecode != nil
|
return current.bytecode, 0, current.bytecode != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, c := range current.children {
|
for _, c := range current.children {
|
||||||
if bytes.Equal(c.segment, seg) || c.isDynamic {
|
if c.segment == seg || c.isDynamic {
|
||||||
if c.isDynamic {
|
if c.isDynamic {
|
||||||
*params = append(*params, string(seg))
|
*params = append(*params, seg)
|
||||||
paramName := string(c.segment[1 : len(c.segment)-1]) // Remove [ ]
|
paramName := c.segment[1 : len(c.segment)-1] // Remove [ ]
|
||||||
*keys = append(*keys, paramName)
|
*keys = append(*keys, paramName)
|
||||||
paramCount++
|
paramCount++
|
||||||
}
|
}
|
||||||
@ -477,6 +478,7 @@ func (r *Router) Close() {
|
|||||||
r.compileMu.Unlock()
|
r.compileMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper functions from cache.go
|
||||||
func hashString(s string) uint64 {
|
func hashString(s string) uint64 {
|
||||||
h := uint64(5381)
|
h := uint64(5381)
|
||||||
for i := 0; i < len(s); i++ {
|
for i := 0; i < len(s); i++ {
|
||||||
|
318
router/router_test.go
Normal file
318
router/router_test.go
Normal file
@ -0,0 +1,318 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTestRoutes(t testing.TB) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create test route files
|
||||||
|
routes := map[string]string{
|
||||||
|
"index.lua": `return "home"`,
|
||||||
|
"about.lua": `return "about"`,
|
||||||
|
"api/users.lua": `return "users"`,
|
||||||
|
"api/users/get.lua": `return "get_users"`,
|
||||||
|
"api/users/post.lua": `return "create_user"`,
|
||||||
|
"api/users/[id].lua": `return "user_" .. id`,
|
||||||
|
"api/posts/[slug]/comments.lua": `return "comments_" .. slug`,
|
||||||
|
"files/*path.lua": `return "file_" .. path`,
|
||||||
|
"middleware.lua": `-- root middleware`,
|
||||||
|
"api/middleware.lua": `-- api middleware`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for path, content := range routes {
|
||||||
|
fullPath := filepath.Join(tempDir, path)
|
||||||
|
dir := filepath.Dir(fullPath)
|
||||||
|
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tempDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterBasicFunctionality(t *testing.T) {
|
||||||
|
routesDir := setupTestRoutes(t)
|
||||||
|
router, err := New(routesDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer router.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
expected bool
|
||||||
|
params map[string]string
|
||||||
|
}{
|
||||||
|
{"GET", "/", true, nil},
|
||||||
|
{"GET", "/about", true, nil},
|
||||||
|
{"GET", "/api/users", true, nil},
|
||||||
|
{"GET", "/api/users", true, nil},
|
||||||
|
{"POST", "/api/users", true, nil},
|
||||||
|
{"GET", "/api/users/123", true, map[string]string{"id": "123"}},
|
||||||
|
{"GET", "/api/posts/hello-world/comments", true, map[string]string{"slug": "hello-world"}},
|
||||||
|
{"GET", "/files/docs/readme.txt", true, map[string]string{"path": "docs/readme.txt"}},
|
||||||
|
{"GET", "/nonexistent", false, nil},
|
||||||
|
{"DELETE", "/api/users", false, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.method+"_"+tt.path, func(t *testing.T) {
|
||||||
|
bytecode, params, found := router.Lookup(tt.method, tt.path)
|
||||||
|
|
||||||
|
if found != tt.expected {
|
||||||
|
t.Errorf("expected found=%v, got %v", tt.expected, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expected {
|
||||||
|
if bytecode == nil {
|
||||||
|
t.Error("expected bytecode, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.params != nil {
|
||||||
|
for key, expectedValue := range tt.params {
|
||||||
|
if actualValue := params.Get(key); actualValue != expectedValue {
|
||||||
|
t.Errorf("param %s: expected %s, got %s", key, expectedValue, actualValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterParamsStruct(t *testing.T) {
|
||||||
|
params := &Params{
|
||||||
|
Keys: []string{"id", "slug"},
|
||||||
|
Values: []string{"123", "hello"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.Get("id") != "123" {
|
||||||
|
t.Errorf("expected '123', got '%s'", params.Get("id"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.Get("slug") != "hello" {
|
||||||
|
t.Errorf("expected 'hello', got '%s'", params.Get("slug"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.Get("missing") != "" {
|
||||||
|
t.Errorf("expected empty string for missing param, got '%s'", params.Get("missing"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterMethodNodes(t *testing.T) {
|
||||||
|
routesDir := setupTestRoutes(t)
|
||||||
|
router, err := New(routesDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer router.Close()
|
||||||
|
|
||||||
|
// Test that different methods work independently
|
||||||
|
_, _, foundGet := router.Lookup("GET", "/api/users")
|
||||||
|
_, _, foundPost := router.Lookup("POST", "/api/users")
|
||||||
|
_, _, foundPut := router.Lookup("PUT", "/api/users")
|
||||||
|
|
||||||
|
if !foundGet {
|
||||||
|
t.Error("GET /api/users should be found")
|
||||||
|
}
|
||||||
|
if !foundPost {
|
||||||
|
t.Error("POST /api/users should be found")
|
||||||
|
}
|
||||||
|
if foundPut {
|
||||||
|
t.Error("PUT /api/users should not be found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterWildcardValidation(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create invalid wildcard route (not at end)
|
||||||
|
invalidPath := filepath.Join(tempDir, "bad/*path/more.lua")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(invalidPath), 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(invalidPath, []byte(`return "bad"`), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := New(tempDir)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for wildcard not at end")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLookupStatic(b *testing.B) {
|
||||||
|
routesDir := setupTestRoutes(b)
|
||||||
|
router, err := New(routesDir)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer router.Close()
|
||||||
|
|
||||||
|
method := "GET"
|
||||||
|
path := "/api/users"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _ = router.Lookup(method, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLookupDynamic(b *testing.B) {
|
||||||
|
routesDir := setupTestRoutes(b)
|
||||||
|
router, err := New(routesDir)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer router.Close()
|
||||||
|
|
||||||
|
method := "GET"
|
||||||
|
path := "/api/users/12345"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _ = router.Lookup(method, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLookupWildcard(b *testing.B) {
|
||||||
|
routesDir := setupTestRoutes(b)
|
||||||
|
router, err := New(routesDir)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer router.Close()
|
||||||
|
|
||||||
|
method := "GET"
|
||||||
|
path := "/files/docs/deep/nested/file.txt"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _ = router.Lookup(method, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLookupComplex(b *testing.B) {
|
||||||
|
routesDir := setupTestRoutes(b)
|
||||||
|
router, err := New(routesDir)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer router.Close()
|
||||||
|
|
||||||
|
method := "GET"
|
||||||
|
path := "/api/posts/my-blog-post-title/comments"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _ = router.Lookup(method, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLookupNotFound(b *testing.B) {
|
||||||
|
routesDir := setupTestRoutes(b)
|
||||||
|
router, err := New(routesDir)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer router.Close()
|
||||||
|
|
||||||
|
method := "GET"
|
||||||
|
path := "/this/path/does/not/exist"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _ = router.Lookup(method, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLookupMixed(b *testing.B) {
|
||||||
|
routesDir := setupTestRoutes(b)
|
||||||
|
router, err := New(routesDir)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer router.Close()
|
||||||
|
|
||||||
|
paths := []string{
|
||||||
|
"/",
|
||||||
|
"/about",
|
||||||
|
"/api/users",
|
||||||
|
"/api/users/123",
|
||||||
|
"/api/posts/hello/comments",
|
||||||
|
"/files/document.pdf",
|
||||||
|
"/nonexistent",
|
||||||
|
}
|
||||||
|
method := "GET"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
path := paths[i%len(paths)]
|
||||||
|
_, _, _ = router.Lookup(method, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Comparison benchmarks for string vs byte slice performance
|
||||||
|
func BenchmarkLookupStringConversion(b *testing.B) {
|
||||||
|
routesDir := setupTestRoutes(b)
|
||||||
|
router, err := New(routesDir)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer router.Close()
|
||||||
|
|
||||||
|
methodStr := "GET"
|
||||||
|
pathStr := "/api/users/12345"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Direct string usage
|
||||||
|
_, _, _ = router.Lookup(methodStr, pathStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLookupPreallocated(b *testing.B) {
|
||||||
|
routesDir := setupTestRoutes(b)
|
||||||
|
router, err := New(routesDir)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer router.Close()
|
||||||
|
|
||||||
|
// Pre-allocated strings (optimal case)
|
||||||
|
method := "GET"
|
||||||
|
path := "/api/users/12345"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _ = router.Lookup(method, path)
|
||||||
|
}
|
||||||
|
}
|
@ -12,31 +12,30 @@ import (
|
|||||||
|
|
||||||
// A prebuilt, ready-to-go context for HTTP requests to the runner.
|
// A prebuilt, ready-to-go context for HTTP requests to the runner.
|
||||||
type HTTPContext struct {
|
type HTTPContext struct {
|
||||||
Method []byte
|
Values map[string]any // Contains all context data for Lua
|
||||||
Path []byte
|
|
||||||
Host []byte
|
// Separate maps for efficient access during context building
|
||||||
Headers map[string]any
|
headers map[string]string
|
||||||
Cookies map[string]string
|
cookies map[string]string
|
||||||
Query map[string]string
|
query map[string]string
|
||||||
Params map[string]any
|
params map[string]string
|
||||||
Form map[string]any
|
form map[string]any
|
||||||
Session map[string]any
|
session map[string]any
|
||||||
Env map[string]any
|
env map[string]any
|
||||||
Values map[string]any // Extra context vars just in case
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTP context pool to reduce allocations
|
// HTTP context pool to reduce allocations
|
||||||
var httpContextPool = sync.Pool{
|
var httpContextPool = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
return &HTTPContext{
|
return &HTTPContext{
|
||||||
Headers: make(map[string]any, 16),
|
|
||||||
Cookies: make(map[string]string, 8),
|
|
||||||
Query: make(map[string]string, 8),
|
|
||||||
Params: make(map[string]any, 4),
|
|
||||||
Form: make(map[string]any, 8),
|
|
||||||
Session: make(map[string]any, 4),
|
|
||||||
Env: make(map[string]any, 16),
|
|
||||||
Values: make(map[string]any, 32),
|
Values: make(map[string]any, 32),
|
||||||
|
headers: make(map[string]string, 16),
|
||||||
|
cookies: make(map[string]string, 8),
|
||||||
|
query: make(map[string]string, 8),
|
||||||
|
params: make(map[string]string, 4),
|
||||||
|
form: make(map[string]any, 8),
|
||||||
|
session: make(map[string]any, 4),
|
||||||
|
env: make(map[string]any, 16),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -45,30 +44,25 @@ var httpContextPool = sync.Pool{
|
|||||||
func NewHTTPContext(httpCtx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) *HTTPContext {
|
func NewHTTPContext(httpCtx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) *HTTPContext {
|
||||||
ctx := httpContextPool.Get().(*HTTPContext)
|
ctx := httpContextPool.Get().(*HTTPContext)
|
||||||
|
|
||||||
// Extract basic HTTP info
|
|
||||||
ctx.Method = httpCtx.Method()
|
|
||||||
ctx.Path = httpCtx.Path()
|
|
||||||
ctx.Host = httpCtx.Host()
|
|
||||||
|
|
||||||
// Extract headers
|
// Extract headers
|
||||||
httpCtx.Request.Header.VisitAll(func(key, value []byte) {
|
httpCtx.Request.Header.VisitAll(func(key, value []byte) {
|
||||||
ctx.Headers[string(key)] = string(value)
|
ctx.headers[string(key)] = string(value)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Extract cookies
|
// Extract cookies
|
||||||
httpCtx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
httpCtx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||||
ctx.Cookies[string(key)] = string(value)
|
ctx.cookies[string(key)] = string(value)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Extract query params
|
// Extract query params
|
||||||
httpCtx.QueryArgs().VisitAll(func(key, value []byte) {
|
httpCtx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||||
ctx.Query[string(key)] = string(value)
|
ctx.query[string(key)] = string(value)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Extract route parameters
|
// Extract route parameters
|
||||||
if params != nil {
|
if params != nil {
|
||||||
for i := 0; i < min(len(params.Keys), len(params.Values)); i++ {
|
for i := range min(len(params.Keys), len(params.Values)) {
|
||||||
ctx.Params[params.Keys[i]] = params.Values[i]
|
ctx.params[params.Keys[i]] = params.Values[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,90 +70,69 @@ func NewHTTPContext(httpCtx *fasthttp.RequestCtx, params *router.Params, session
|
|||||||
if httpCtx.IsPost() || httpCtx.IsPut() || httpCtx.IsPatch() {
|
if httpCtx.IsPost() || httpCtx.IsPut() || httpCtx.IsPatch() {
|
||||||
if form, err := utils.ParseForm(httpCtx); err == nil {
|
if form, err := utils.ParseForm(httpCtx); err == nil {
|
||||||
for k, v := range form {
|
for k, v := range form {
|
||||||
ctx.Form[k] = v
|
ctx.form[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract session data
|
// Extract session data
|
||||||
session.AdvanceFlash()
|
session.AdvanceFlash()
|
||||||
ctx.Session["id"] = session.ID
|
ctx.session["id"] = session.ID
|
||||||
if session.IsEmpty() {
|
if session.IsEmpty() {
|
||||||
ctx.Session["data"] = emptyMap
|
ctx.session["data"] = emptyMap
|
||||||
ctx.Session["flash"] = emptyMap
|
ctx.session["flash"] = emptyMap
|
||||||
} else {
|
} else {
|
||||||
ctx.Session["data"] = session.GetAll()
|
ctx.session["data"] = session.GetAll()
|
||||||
ctx.Session["flash"] = session.GetAllFlash()
|
ctx.session["flash"] = session.GetAllFlash()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add environment vars
|
// Add environment vars
|
||||||
if envMgr := lualibs.GetGlobalEnvManager(); envMgr != nil {
|
if envMgr := lualibs.GetGlobalEnvManager(); envMgr != nil {
|
||||||
for k, v := range envMgr.GetAll() {
|
for k, v := range envMgr.GetAll() {
|
||||||
ctx.Env[k] = v
|
ctx.env[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Populate Values with all context data
|
||||||
|
ctx.Values["method"] = string(httpCtx.Method())
|
||||||
|
ctx.Values["path"] = string(httpCtx.Path())
|
||||||
|
ctx.Values["host"] = string(httpCtx.Host())
|
||||||
|
ctx.Values["headers"] = ctx.headers
|
||||||
|
ctx.Values["cookies"] = ctx.cookies
|
||||||
|
ctx.Values["query"] = ctx.query
|
||||||
|
ctx.Values["params"] = ctx.params
|
||||||
|
ctx.Values["form"] = ctx.form
|
||||||
|
ctx.Values["session"] = ctx.session
|
||||||
|
ctx.Values["env"] = ctx.env
|
||||||
|
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear out all the request data from the context and give it back to the pool. Keeps the contexts and inner maps
|
// Clear out all the request data from the context and give it back to the pool.
|
||||||
// allocated to prevent GC churn.
|
|
||||||
func (c *HTTPContext) Release() {
|
func (c *HTTPContext) Release() {
|
||||||
for k := range c.Headers {
|
clear(c.Values)
|
||||||
delete(c.Headers, k)
|
clear(c.headers)
|
||||||
}
|
clear(c.cookies)
|
||||||
for k := range c.Cookies {
|
clear(c.query)
|
||||||
delete(c.Cookies, k)
|
clear(c.params)
|
||||||
}
|
clear(c.form)
|
||||||
for k := range c.Query {
|
clear(c.session)
|
||||||
delete(c.Query, k)
|
clear(c.env)
|
||||||
}
|
|
||||||
for k := range c.Params {
|
|
||||||
delete(c.Params, k)
|
|
||||||
}
|
|
||||||
for k := range c.Form {
|
|
||||||
delete(c.Form, k)
|
|
||||||
}
|
|
||||||
for k := range c.Session {
|
|
||||||
delete(c.Session, k)
|
|
||||||
}
|
|
||||||
for k := range c.Env {
|
|
||||||
delete(c.Env, k)
|
|
||||||
}
|
|
||||||
for k := range c.Values {
|
|
||||||
delete(c.Values, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Method = nil
|
|
||||||
c.Path = nil
|
|
||||||
c.Host = nil
|
|
||||||
|
|
||||||
httpContextPool.Put(c)
|
httpContextPool.Put(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a value to the extras map
|
// Add a value to the extras section
|
||||||
func (c *HTTPContext) Set(key string, value any) {
|
func (c *HTTPContext) Set(key string, value any) {
|
||||||
c.Values[key] = value
|
c.Values[key] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a value from the extras map
|
// Get a value from the context
|
||||||
func (c *HTTPContext) Get(key string) any {
|
func (c *HTTPContext) Get(key string) any {
|
||||||
return c.Values[key]
|
return c.Values[key]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a representation of the context ready for Lua
|
// Returns the Values map directly - zero overhead
|
||||||
func (c *HTTPContext) ToMap() map[string]any {
|
func (c *HTTPContext) ToMap() map[string]any {
|
||||||
return map[string]any{
|
return c.Values
|
||||||
"method": string(c.Method),
|
|
||||||
"path": string(c.Path),
|
|
||||||
"host": string(c.Host),
|
|
||||||
"headers": c.Headers,
|
|
||||||
"cookies": c.Cookies,
|
|
||||||
"query": c.Query,
|
|
||||||
"params": c.Params,
|
|
||||||
"form": c.Form,
|
|
||||||
"session": c.Session,
|
|
||||||
"env": c.Env,
|
|
||||||
"values": c.Values,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user