diff --git a/moonshark.go b/moonshark.go index b739613..c8f8b13 100644 --- a/moonshark.go +++ b/moonshark.go @@ -125,7 +125,7 @@ func requestMux(ctx *fasthttp.RequestCtx) { } // See if the requested route even exists - bytecode, params, found := rtr.Lookup(method, path) + bytecode, params, found := rtr.Lookup(string(method), string(path)) if !found { http.Send404(ctx) logRequest(ctx, method, path, start) diff --git a/router/router.go b/router/router.go index 2a923ff..590c9d1 100644 --- a/router/router.go +++ b/router/router.go @@ -1,7 +1,6 @@ package router import ( - "bytes" "errors" "os" "path/filepath" @@ -12,18 +11,9 @@ import ( "github.com/VictoriaMetrics/fastcache" ) -var ( - slash = []byte("/") - get = []byte("GET") - post = []byte("POST") - put = []byte("PUT") - patch = []byte("PATCH") - delete = []byte("DELETE") -) - // node represents a node in the radix trie type node struct { - segment []byte + segment string bytecode []byte scriptPath string children []*node @@ -40,7 +30,7 @@ type Router struct { compileState *luajit.State compileMu sync.Mutex paramsBuffer []string - middlewareFiles map[string][]string + middlewareFiles map[string][]string // filesystem path -> middleware file paths } // Params holds URL parameters @@ -81,7 +71,7 @@ func New(routesDir string) (*Router, error) { put: &node{}, patch: &node{}, delete: &node{}, - bytecodeCache: fastcache.New(32 * 1024 * 1024), + bytecodeCache: fastcache.New(32 * 1024 * 1024), // 32MB compileState: compileState, paramsBuffer: make([]string, 64), middlewareFiles: make(map[string][]string), @@ -91,17 +81,17 @@ func New(routesDir string) (*Router, error) { } // methodNode returns the root node for a method -func (r *Router) methodNode(method []byte) *node { - switch { - case bytes.Equal(method, get): +func (r *Router) methodNode(method string) *node { + switch method { + case "GET": return r.get - case bytes.Equal(method, post): + case "POST": return r.post - case bytes.Equal(method, put): + case "PUT": return r.put - case bytes.Equal(method, patch): + case "PATCH": return r.patch - case bytes.Equal(method, delete): + case "DELETE": return r.delete default: return nil @@ -118,8 +108,7 @@ func (r *Router) buildRoutes() error { return err } - fileName := strings.TrimSuffix(info.Name(), ".lua") - if fileName == "middleware" { + if strings.TrimSuffix(info.Name(), ".lua") == "middleware" { relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path)) if err != nil { return err @@ -132,6 +121,7 @@ func (r *Router) buildRoutes() error { r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path) } + return nil }) @@ -146,36 +136,39 @@ func (r *Router) buildRoutes() error { } fileName := strings.TrimSuffix(info.Name(), ".lua") + + // Skip middleware files if fileName == "middleware" { return nil } + // Get relative path from routes directory relPath, err := filepath.Rel(r.routesDir, path) if err != nil { return err } + // Get filesystem path (includes groups) fsPath := "/" + strings.ReplaceAll(filepath.Dir(relPath), "\\", "/") if fsPath == "/." { fsPath = "/" } + // Get URL path (excludes groups) urlPath := r.parseURLPath(fsPath) - urlPathBytes := []byte(urlPath) - // Handle method files - methodBytes := []byte(strings.ToUpper(fileName)) - root := r.methodNode(methodBytes) + // Handle method files (get.lua, post.lua, etc.) + method := strings.ToUpper(fileName) + root := r.methodNode(method) if root != nil { - return r.addRoute(root, urlPathBytes, fsPath, path) + return r.addRoute(root, urlPath, fsPath, path) } - // Handle index files + // Handle index files - register for all methods if fileName == "index" { - methods := [][]byte{get, post, put, patch, delete} - for _, method := range methods { + for _, method := range []string{"GET", "POST", "PUT", "PATCH", "DELETE"} { if root := r.methodNode(method); root != nil { - if err := r.addRoute(root, urlPathBytes, fsPath, path); err != nil { + if err := r.addRoute(root, urlPath, fsPath, path); err != nil { return err } } @@ -183,13 +176,12 @@ func (r *Router) buildRoutes() error { return nil } - // Handle named route files - var namedPath []byte + // Handle named route files - register as GET by default + namedPath := urlPath if urlPath == "/" { - namedPath = append(slash, fileName...) + namedPath = "/" + fileName } else { - namedPath = append(urlPathBytes, '/') - namedPath = append(namedPath, fileName...) + namedPath = urlPath + "/" + fileName } return r.addRoute(r.get, namedPath, fsPath, path) }) @@ -204,6 +196,7 @@ func (r *Router) parseURLPath(fsPath string) string { if segment == "" { continue } + // Skip group segments (enclosed in parentheses) if strings.HasPrefix(segment, "(") && strings.HasSuffix(segment, ")") { continue } @@ -225,10 +218,12 @@ func (r *Router) getMiddlewareChain(fsPath string) []string { pathParts = []string{} } + // Add root middleware if mw, exists := r.middlewareFiles["/"]; exists { chain = append(chain, mw...) } + // Add middleware from each path level (including groups) currentPath := "" for _, part := range pathParts { currentPath += "/" + part @@ -244,6 +239,7 @@ func (r *Router) getMiddlewareChain(fsPath string) []string { func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) { var combined strings.Builder + // Add middleware in order middlewareChain := r.getMiddlewareChain(fsPath) for _, mwPath := range middlewareChain { content, err := os.ReadFile(mwPath) @@ -257,6 +253,7 @@ func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) combined.WriteString("\n") } + // Add main handler content, err := os.ReadFile(scriptPath) if err != nil { return "", err @@ -270,12 +267,14 @@ func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) } // addRoute adds a new route to the trie with bytecode compilation -func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string) error { +func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error { + // Build combined source with middleware combinedSource, err := r.buildCombinedSource(fsPath, scriptPath) if err != nil { return err } + // Compile bytecode r.compileMu.Lock() bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath) r.compileMu.Unlock() @@ -284,10 +283,11 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string) return err } + // Cache bytecode cacheKey := hashString(scriptPath) r.bytecodeCache.Set(uint64ToBytes(cacheKey), bytecode) - if len(urlPath) == 1 && urlPath[0] == '/' { + if urlPath == "/" { root.bytecode = bytecode root.scriptPath = scriptPath return nil @@ -298,8 +298,8 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string) paramCount := uint8(0) for { - seg, newPos, more := readSegmentBytes(urlPath, pos) - if len(seg) == 0 { + seg, newPos, more := readSegment(urlPath, pos) + if seg == "" { break } @@ -314,9 +314,10 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string) paramCount++ } + // Find or create child var child *node for _, c := range current.children { - if bytes.Equal(c.segment, seg) { + if c.segment == seg { child = c break } @@ -324,7 +325,7 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string) if child == nil { child = &node{ - segment: append([]byte(nil), seg...), + segment: seg, isDynamic: isDyn, isWildcard: isWC, } @@ -344,16 +345,16 @@ func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string) return nil } -// readSegmentBytes extracts the next path segment from byte slice -func readSegmentBytes(path []byte, start int) (segment []byte, end int, hasMore bool) { +// readSegment extracts the next path segment +func readSegment(path string, start int) (segment string, end int, hasMore bool) { if start >= len(path) { - return nil, start, false + return "", start, false } if path[start] == '/' { start++ } if start >= len(path) { - return nil, start, false + return "", start, false } end = start for end < len(path) && path[end] != '/' { @@ -363,19 +364,20 @@ func readSegmentBytes(path []byte, start int) (segment []byte, end int, hasMore } // Lookup finds bytecode and parameters for a method and path -func (r *Router) Lookup(method, path []byte) ([]byte, *Params, bool) { +func (r *Router) Lookup(method, path string) ([]byte, *Params, bool) { root := r.methodNode(method) if root == nil { return nil, nil, false } - if len(path) == 1 && path[0] == '/' { + if path == "/" { if root.bytecode != nil { return root.bytecode, &Params{}, true } return nil, nil, false } + // Prepare params buffer buffer := r.paramsBuffer if cap(buffer) < int(root.maxParams) { buffer = make([]string, root.maxParams) @@ -398,7 +400,7 @@ func (r *Router) Lookup(method, path []byte) ([]byte, *Params, bool) { } // match traverses the trie to find bytecode -func (r *Router) match(current *node, path []byte, start int, params *[]string, keys *[]string) ([]byte, int, bool) { +func (r *Router) match(current *node, path string, start int, params *[]string, keys *[]string) ([]byte, int, bool) { paramCount := 0 // Check wildcard first @@ -408,23 +410,22 @@ func (r *Router) match(current *node, path []byte, start int, params *[]string, if len(rem) > 0 && rem[0] == '/' { rem = rem[1:] } - *params = append(*params, string(rem)) - paramName := string(c.segment[1:]) // Remove * - *keys = append(*keys, paramName) + *params = append(*params, rem) + *keys = append(*keys, strings.TrimPrefix(c.segment, "*")) return c.bytecode, 1, c.bytecode != nil } } - seg, pos, more := readSegmentBytes(path, start) - if len(seg) == 0 { + seg, pos, more := readSegment(path, start) + if seg == "" { return current.bytecode, 0, current.bytecode != nil } for _, c := range current.children { - if bytes.Equal(c.segment, seg) || c.isDynamic { + if c.segment == seg || c.isDynamic { if c.isDynamic { - *params = append(*params, string(seg)) - paramName := string(c.segment[1 : len(c.segment)-1]) // Remove [ ] + *params = append(*params, seg) + paramName := c.segment[1 : len(c.segment)-1] // Remove [ ] *keys = append(*keys, paramName) paramCount++ } @@ -477,6 +478,7 @@ func (r *Router) Close() { r.compileMu.Unlock() } +// Helper functions from cache.go func hashString(s string) uint64 { h := uint64(5381) for i := 0; i < len(s); i++ { diff --git a/router/router_test.go b/router/router_test.go new file mode 100644 index 0000000..be44d27 --- /dev/null +++ b/router/router_test.go @@ -0,0 +1,318 @@ +package router + +import ( + "os" + "path/filepath" + "testing" +) + +func setupTestRoutes(t testing.TB) string { + t.Helper() + + tempDir := t.TempDir() + + // Create test route files + routes := map[string]string{ + "index.lua": `return "home"`, + "about.lua": `return "about"`, + "api/users.lua": `return "users"`, + "api/users/get.lua": `return "get_users"`, + "api/users/post.lua": `return "create_user"`, + "api/users/[id].lua": `return "user_" .. id`, + "api/posts/[slug]/comments.lua": `return "comments_" .. slug`, + "files/*path.lua": `return "file_" .. path`, + "middleware.lua": `-- root middleware`, + "api/middleware.lua": `-- api middleware`, + } + + for path, content := range routes { + fullPath := filepath.Join(tempDir, path) + dir := filepath.Dir(fullPath) + + if err := os.MkdirAll(dir, 0755); err != nil { + t.Fatal(err) + } + + if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { + t.Fatal(err) + } + } + + return tempDir +} + +func TestRouterBasicFunctionality(t *testing.T) { + routesDir := setupTestRoutes(t) + router, err := New(routesDir) + if err != nil { + t.Fatal(err) + } + defer router.Close() + + tests := []struct { + method string + path string + expected bool + params map[string]string + }{ + {"GET", "/", true, nil}, + {"GET", "/about", true, nil}, + {"GET", "/api/users", true, nil}, + {"GET", "/api/users", true, nil}, + {"POST", "/api/users", true, nil}, + {"GET", "/api/users/123", true, map[string]string{"id": "123"}}, + {"GET", "/api/posts/hello-world/comments", true, map[string]string{"slug": "hello-world"}}, + {"GET", "/files/docs/readme.txt", true, map[string]string{"path": "docs/readme.txt"}}, + {"GET", "/nonexistent", false, nil}, + {"DELETE", "/api/users", false, nil}, + } + + for _, tt := range tests { + t.Run(tt.method+"_"+tt.path, func(t *testing.T) { + bytecode, params, found := router.Lookup(tt.method, tt.path) + + if found != tt.expected { + t.Errorf("expected found=%v, got %v", tt.expected, found) + } + + if tt.expected { + if bytecode == nil { + t.Error("expected bytecode, got nil") + } + + if tt.params != nil { + for key, expectedValue := range tt.params { + if actualValue := params.Get(key); actualValue != expectedValue { + t.Errorf("param %s: expected %s, got %s", key, expectedValue, actualValue) + } + } + } + } + }) + } +} + +func TestRouterParamsStruct(t *testing.T) { + params := &Params{ + Keys: []string{"id", "slug"}, + Values: []string{"123", "hello"}, + } + + if params.Get("id") != "123" { + t.Errorf("expected '123', got '%s'", params.Get("id")) + } + + if params.Get("slug") != "hello" { + t.Errorf("expected 'hello', got '%s'", params.Get("slug")) + } + + if params.Get("missing") != "" { + t.Errorf("expected empty string for missing param, got '%s'", params.Get("missing")) + } +} + +func TestRouterMethodNodes(t *testing.T) { + routesDir := setupTestRoutes(t) + router, err := New(routesDir) + if err != nil { + t.Fatal(err) + } + defer router.Close() + + // Test that different methods work independently + _, _, foundGet := router.Lookup("GET", "/api/users") + _, _, foundPost := router.Lookup("POST", "/api/users") + _, _, foundPut := router.Lookup("PUT", "/api/users") + + if !foundGet { + t.Error("GET /api/users should be found") + } + if !foundPost { + t.Error("POST /api/users should be found") + } + if foundPut { + t.Error("PUT /api/users should not be found") + } +} + +func TestRouterWildcardValidation(t *testing.T) { + tempDir := t.TempDir() + + // Create invalid wildcard route (not at end) + invalidPath := filepath.Join(tempDir, "bad/*path/more.lua") + if err := os.MkdirAll(filepath.Dir(invalidPath), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(invalidPath, []byte(`return "bad"`), 0644); err != nil { + t.Fatal(err) + } + + _, err := New(tempDir) + if err == nil { + t.Error("expected error for wildcard not at end") + } +} + +func BenchmarkLookupStatic(b *testing.B) { + routesDir := setupTestRoutes(b) + router, err := New(routesDir) + if err != nil { + b.Fatal(err) + } + defer router.Close() + + method := "GET" + path := "/api/users" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = router.Lookup(method, path) + } +} + +func BenchmarkLookupDynamic(b *testing.B) { + routesDir := setupTestRoutes(b) + router, err := New(routesDir) + if err != nil { + b.Fatal(err) + } + defer router.Close() + + method := "GET" + path := "/api/users/12345" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = router.Lookup(method, path) + } +} + +func BenchmarkLookupWildcard(b *testing.B) { + routesDir := setupTestRoutes(b) + router, err := New(routesDir) + if err != nil { + b.Fatal(err) + } + defer router.Close() + + method := "GET" + path := "/files/docs/deep/nested/file.txt" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = router.Lookup(method, path) + } +} + +func BenchmarkLookupComplex(b *testing.B) { + routesDir := setupTestRoutes(b) + router, err := New(routesDir) + if err != nil { + b.Fatal(err) + } + defer router.Close() + + method := "GET" + path := "/api/posts/my-blog-post-title/comments" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = router.Lookup(method, path) + } +} + +func BenchmarkLookupNotFound(b *testing.B) { + routesDir := setupTestRoutes(b) + router, err := New(routesDir) + if err != nil { + b.Fatal(err) + } + defer router.Close() + + method := "GET" + path := "/this/path/does/not/exist" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = router.Lookup(method, path) + } +} + +func BenchmarkLookupMixed(b *testing.B) { + routesDir := setupTestRoutes(b) + router, err := New(routesDir) + if err != nil { + b.Fatal(err) + } + defer router.Close() + + paths := []string{ + "/", + "/about", + "/api/users", + "/api/users/123", + "/api/posts/hello/comments", + "/files/document.pdf", + "/nonexistent", + } + method := "GET" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + path := paths[i%len(paths)] + _, _, _ = router.Lookup(method, path) + } +} + +// Comparison benchmarks for string vs byte slice performance +func BenchmarkLookupStringConversion(b *testing.B) { + routesDir := setupTestRoutes(b) + router, err := New(routesDir) + if err != nil { + b.Fatal(err) + } + defer router.Close() + + methodStr := "GET" + pathStr := "/api/users/12345" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Direct string usage + _, _, _ = router.Lookup(methodStr, pathStr) + } +} + +func BenchmarkLookupPreallocated(b *testing.B) { + routesDir := setupTestRoutes(b) + router, err := New(routesDir) + if err != nil { + b.Fatal(err) + } + defer router.Close() + + // Pre-allocated strings (optimal case) + method := "GET" + path := "/api/users/12345" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = router.Lookup(method, path) + } +} diff --git a/runner/httpContext.go b/runner/httpContext.go index 254a51a..b7b1142 100644 --- a/runner/httpContext.go +++ b/runner/httpContext.go @@ -12,31 +12,30 @@ import ( // A prebuilt, ready-to-go context for HTTP requests to the runner. type HTTPContext struct { - Method []byte - Path []byte - Host []byte - Headers map[string]any - Cookies map[string]string - Query map[string]string - Params map[string]any - Form map[string]any - Session map[string]any - Env map[string]any - Values map[string]any // Extra context vars just in case + Values map[string]any // Contains all context data for Lua + + // Separate maps for efficient access during context building + headers map[string]string + cookies map[string]string + query map[string]string + params map[string]string + form map[string]any + session map[string]any + env map[string]any } // HTTP context pool to reduce allocations var httpContextPool = sync.Pool{ New: func() any { return &HTTPContext{ - Headers: make(map[string]any, 16), - Cookies: make(map[string]string, 8), - Query: make(map[string]string, 8), - Params: make(map[string]any, 4), - Form: make(map[string]any, 8), - Session: make(map[string]any, 4), - Env: make(map[string]any, 16), Values: make(map[string]any, 32), + headers: make(map[string]string, 16), + cookies: make(map[string]string, 8), + query: make(map[string]string, 8), + params: make(map[string]string, 4), + form: make(map[string]any, 8), + session: make(map[string]any, 4), + env: make(map[string]any, 16), } }, } @@ -45,30 +44,25 @@ var httpContextPool = sync.Pool{ func NewHTTPContext(httpCtx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) *HTTPContext { ctx := httpContextPool.Get().(*HTTPContext) - // Extract basic HTTP info - ctx.Method = httpCtx.Method() - ctx.Path = httpCtx.Path() - ctx.Host = httpCtx.Host() - // Extract headers httpCtx.Request.Header.VisitAll(func(key, value []byte) { - ctx.Headers[string(key)] = string(value) + ctx.headers[string(key)] = string(value) }) // Extract cookies httpCtx.Request.Header.VisitAllCookie(func(key, value []byte) { - ctx.Cookies[string(key)] = string(value) + ctx.cookies[string(key)] = string(value) }) // Extract query params httpCtx.QueryArgs().VisitAll(func(key, value []byte) { - ctx.Query[string(key)] = string(value) + ctx.query[string(key)] = string(value) }) // Extract route parameters if params != nil { - for i := 0; i < min(len(params.Keys), len(params.Values)); i++ { - ctx.Params[params.Keys[i]] = params.Values[i] + for i := range min(len(params.Keys), len(params.Values)) { + ctx.params[params.Keys[i]] = params.Values[i] } } @@ -76,90 +70,69 @@ func NewHTTPContext(httpCtx *fasthttp.RequestCtx, params *router.Params, session if httpCtx.IsPost() || httpCtx.IsPut() || httpCtx.IsPatch() { if form, err := utils.ParseForm(httpCtx); err == nil { for k, v := range form { - ctx.Form[k] = v + ctx.form[k] = v } } } // Extract session data session.AdvanceFlash() - ctx.Session["id"] = session.ID + ctx.session["id"] = session.ID if session.IsEmpty() { - ctx.Session["data"] = emptyMap - ctx.Session["flash"] = emptyMap + ctx.session["data"] = emptyMap + ctx.session["flash"] = emptyMap } else { - ctx.Session["data"] = session.GetAll() - ctx.Session["flash"] = session.GetAllFlash() + ctx.session["data"] = session.GetAll() + ctx.session["flash"] = session.GetAllFlash() } // Add environment vars if envMgr := lualibs.GetGlobalEnvManager(); envMgr != nil { for k, v := range envMgr.GetAll() { - ctx.Env[k] = v + ctx.env[k] = v } } + // Populate Values with all context data + ctx.Values["method"] = string(httpCtx.Method()) + ctx.Values["path"] = string(httpCtx.Path()) + ctx.Values["host"] = string(httpCtx.Host()) + ctx.Values["headers"] = ctx.headers + ctx.Values["cookies"] = ctx.cookies + ctx.Values["query"] = ctx.query + ctx.Values["params"] = ctx.params + ctx.Values["form"] = ctx.form + ctx.Values["session"] = ctx.session + ctx.Values["env"] = ctx.env + return ctx } -// Clear out all the request data from the context and give it back to the pool. Keeps the contexts and inner maps -// allocated to prevent GC churn. +// Clear out all the request data from the context and give it back to the pool. func (c *HTTPContext) Release() { - for k := range c.Headers { - delete(c.Headers, k) - } - for k := range c.Cookies { - delete(c.Cookies, k) - } - for k := range c.Query { - delete(c.Query, k) - } - for k := range c.Params { - delete(c.Params, k) - } - for k := range c.Form { - delete(c.Form, k) - } - for k := range c.Session { - delete(c.Session, k) - } - for k := range c.Env { - delete(c.Env, k) - } - for k := range c.Values { - delete(c.Values, k) - } - - c.Method = nil - c.Path = nil - c.Host = nil + clear(c.Values) + clear(c.headers) + clear(c.cookies) + clear(c.query) + clear(c.params) + clear(c.form) + clear(c.session) + clear(c.env) httpContextPool.Put(c) } -// Add a value to the extras map +// Add a value to the extras section func (c *HTTPContext) Set(key string, value any) { c.Values[key] = value } -// Get a value from the extras map +// Get a value from the context func (c *HTTPContext) Get(key string) any { return c.Values[key] } -// Returns a representation of the context ready for Lua +// Returns the Values map directly - zero overhead func (c *HTTPContext) ToMap() map[string]any { - return map[string]any{ - "method": string(c.Method), - "path": string(c.Path), - "host": string(c.Host), - "headers": c.Headers, - "cookies": c.Cookies, - "query": c.Query, - "params": c.Params, - "form": c.Form, - "session": c.Session, - "env": c.Env, - "values": c.Values, - } + return c.Values }