middleware 1
This commit is contained in:
parent
ac607213fc
commit
e548849f88
@ -89,7 +89,6 @@ func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
|
|||||||
methodBytes := ctx.Method()
|
methodBytes := ctx.Method()
|
||||||
pathBytes := ctx.Path()
|
pathBytes := ctx.Path()
|
||||||
|
|
||||||
// Only convert to string once
|
|
||||||
method := string(methodBytes)
|
method := string(methodBytes)
|
||||||
path := string(pathBytes)
|
path := string(pathBytes)
|
||||||
|
|
||||||
@ -108,7 +107,7 @@ func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// In server.go, modify the processRequest method
|
// processRequest handles the main request processing
|
||||||
func (s *Server) processRequest(ctx *fasthttp.RequestCtx, method, path string) {
|
func (s *Server) processRequest(ctx *fasthttp.RequestCtx, method, path string) {
|
||||||
logger.Debug("Processing request %s %s", method, path)
|
logger.Debug("Processing request %s %s", method, path)
|
||||||
|
|
||||||
@ -134,6 +133,7 @@ func (s *Server) processRequest(ctx *fasthttp.RequestCtx, method, path string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try static router
|
||||||
if s.staticRouter != nil {
|
if s.staticRouter != nil {
|
||||||
if _, found := s.staticRouter.Match(path); found {
|
if _, found := s.staticRouter.Match(path); found {
|
||||||
s.staticRouter.ServeHTTP(ctx)
|
s.staticRouter.ServeHTTP(ctx)
|
||||||
@ -141,12 +141,13 @@ func (s *Server) processRequest(ctx *fasthttp.RequestCtx, method, path string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 404
|
||||||
ctx.SetContentType("text/html; charset=utf-8")
|
ctx.SetContentType("text/html; charset=utf-8")
|
||||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||||
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
|
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleLuaRoute executes a Lua route
|
// handleLuaRoute executes the combined middleware + handler script
|
||||||
func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params, method, path string) {
|
func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params, method, path string) {
|
||||||
luaCtx := runner.NewHTTPContext(ctx)
|
luaCtx := runner.NewHTTPContext(ctx)
|
||||||
defer luaCtx.Release()
|
defer luaCtx.Release()
|
||||||
@ -167,18 +168,18 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
|||||||
|
|
||||||
luaCtx.Set("method", method)
|
luaCtx.Set("method", method)
|
||||||
luaCtx.Set("path", path)
|
luaCtx.Set("path", path)
|
||||||
luaCtx.Set("host", string(ctx.Host())) // Only convert when needed
|
luaCtx.Set("host", string(ctx.Host()))
|
||||||
luaCtx.Set("session", sessionMap)
|
luaCtx.Set("session", sessionMap)
|
||||||
|
|
||||||
// Optimize params handling
|
// Optimize params handling
|
||||||
if params.Count > 0 {
|
if params.Count > 0 {
|
||||||
paramMap := make(map[string]any, params.Count) // Pre-size
|
paramMap := make(map[string]any, params.Count)
|
||||||
for i, key := range params.Keys {
|
for i := range params.Count {
|
||||||
paramMap[key] = params.Values[i]
|
paramMap[params.Keys[i]] = params.Values[i]
|
||||||
}
|
}
|
||||||
luaCtx.Set("params", paramMap)
|
luaCtx.Set("params", paramMap)
|
||||||
} else {
|
} else {
|
||||||
luaCtx.Set("params", emptyMap) // Reuse empty map
|
luaCtx.Set("params", emptyMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optimize form handling for POST methods
|
// Optimize form handling for POST methods
|
||||||
|
2
main.go
2
main.go
@ -38,8 +38,8 @@ func main() {
|
|||||||
configPath := flag.String("config", "config.lua", "Path to configuration file")
|
configPath := flag.String("config", "config.lua", "Path to configuration file")
|
||||||
debugFlag := flag.Bool("debug", false, "Enable debug mode")
|
debugFlag := flag.Bool("debug", false, "Enable debug mode")
|
||||||
scriptPath := flag.String("script", "", "Path to Lua script to execute once")
|
scriptPath := flag.String("script", "", "Path to Lua script to execute once")
|
||||||
scriptMode := *scriptPath != ""
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
scriptMode := *scriptPath != ""
|
||||||
|
|
||||||
banner()
|
banner()
|
||||||
mode := ""
|
mode := ""
|
||||||
|
@ -49,11 +49,15 @@ type LuaRouter struct {
|
|||||||
|
|
||||||
routeCache *fastcache.Cache // Cache for route lookups
|
routeCache *fastcache.Cache // Cache for route lookups
|
||||||
bytecodeCache *fastcache.Cache // Cache for compiled bytecode
|
bytecodeCache *fastcache.Cache // Cache for compiled bytecode
|
||||||
|
|
||||||
|
// Middleware tracking for path hierarchy
|
||||||
|
middlewareFiles map[string][]string // path -> middleware file paths
|
||||||
}
|
}
|
||||||
|
|
||||||
// node represents a node in the routing trie
|
// node represents a node in the routing trie
|
||||||
type node struct {
|
type node struct {
|
||||||
handler string // Path to Lua file (empty if not an endpoint)
|
handler string // Path to Lua file (empty if not an endpoint)
|
||||||
|
indexFile string // Path to index.lua file (catch-all)
|
||||||
paramName string // Parameter name (if this is a parameter node)
|
paramName string // Parameter name (if this is a parameter node)
|
||||||
staticChild map[string]*node // Static children by segment name
|
staticChild map[string]*node // Static children by segment name
|
||||||
paramChild *node // Parameter/wildcard child
|
paramChild *node // Parameter/wildcard child
|
||||||
@ -72,11 +76,12 @@ func NewLuaRouter(routesDir string) (*LuaRouter, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r := &LuaRouter{
|
r := &LuaRouter{
|
||||||
routesDir: routesDir,
|
routesDir: routesDir,
|
||||||
routes: make(map[string]*node),
|
routes: make(map[string]*node),
|
||||||
failedRoutes: make(map[string]*RouteError),
|
failedRoutes: make(map[string]*RouteError),
|
||||||
routeCache: fastcache.New(defaultRouteMaxBytes),
|
middlewareFiles: make(map[string][]string),
|
||||||
bytecodeCache: fastcache.New(defaultBytecodeMaxBytes),
|
routeCache: fastcache.New(defaultRouteMaxBytes),
|
||||||
|
bytecodeCache: fastcache.New(defaultBytecodeMaxBytes),
|
||||||
}
|
}
|
||||||
|
|
||||||
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}
|
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}
|
||||||
@ -88,7 +93,6 @@ func NewLuaRouter(routesDir string) (*LuaRouter, error) {
|
|||||||
|
|
||||||
err = r.buildRoutes()
|
err = r.buildRoutes()
|
||||||
|
|
||||||
// If some routes failed to compile, return the router with a warning error
|
|
||||||
if len(r.failedRoutes) > 0 {
|
if len(r.failedRoutes) > 0 {
|
||||||
return r, ErrRoutesCompilationErrors
|
return r, ErrRoutesCompilationErrors
|
||||||
}
|
}
|
||||||
@ -99,24 +103,45 @@ func NewLuaRouter(routesDir string) (*LuaRouter, error) {
|
|||||||
// buildRoutes scans the routes directory and builds the routing tree
|
// buildRoutes scans the routes directory and builds the routing tree
|
||||||
func (r *LuaRouter) buildRoutes() error {
|
func (r *LuaRouter) buildRoutes() error {
|
||||||
r.failedRoutes = make(map[string]*RouteError)
|
r.failedRoutes = make(map[string]*RouteError)
|
||||||
|
r.middlewareFiles = make(map[string][]string)
|
||||||
|
|
||||||
return filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
|
// First pass: collect all middleware files
|
||||||
if err != nil {
|
err := filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
|
||||||
|
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.IsDir() {
|
if strings.TrimSuffix(info.Name(), ".lua") == "middleware" {
|
||||||
return nil
|
relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
urlPath := "/"
|
||||||
|
if relDir != "." {
|
||||||
|
urlPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.middlewareFiles[urlPath] = append(r.middlewareFiles[urlPath], path)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasSuffix(info.Name(), ".lua") {
|
return nil
|
||||||
return nil
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second pass: build routes with combined middleware + handler
|
||||||
|
return filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
|
||||||
|
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
method := strings.ToUpper(strings.TrimSuffix(info.Name(), ".lua"))
|
fileName := strings.TrimSuffix(info.Name(), ".lua")
|
||||||
|
|
||||||
root, exists := r.routes[method]
|
// Skip middleware files (already processed)
|
||||||
if !exists {
|
if fileName == "middleware" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,6 +155,25 @@ func (r *LuaRouter) buildRoutes() error {
|
|||||||
urlPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
|
urlPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle index.lua files
|
||||||
|
if fileName == "index" {
|
||||||
|
for _, method := range []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"} {
|
||||||
|
root := r.routes[method]
|
||||||
|
node := r.findOrCreateNode(root, urlPath)
|
||||||
|
node.indexFile = path
|
||||||
|
node.modTime = info.ModTime()
|
||||||
|
r.compileWithMiddleware(node, urlPath, path)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle method files
|
||||||
|
method := strings.ToUpper(fileName)
|
||||||
|
root, exists := r.routes[method]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
r.addRoute(root, urlPath, path, info.ModTime())
|
r.addRoute(root, urlPath, path, info.ModTime())
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@ -141,6 +185,10 @@ func (r *LuaRouter) addRoute(root *node, urlPath, handlerPath string, modTime ti
|
|||||||
current := root
|
current := root
|
||||||
|
|
||||||
for _, segment := range segments {
|
for _, segment := range segments {
|
||||||
|
if segment == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if len(segment) >= 2 && segment[0] == '[' && segment[len(segment)-1] == ']' {
|
if len(segment) >= 2 && segment[0] == '[' && segment[len(segment)-1] == ']' {
|
||||||
if current.paramChild == nil {
|
if current.paramChild == nil {
|
||||||
current.paramChild = &node{
|
current.paramChild = &node{
|
||||||
@ -161,22 +209,129 @@ func (r *LuaRouter) addRoute(root *node, urlPath, handlerPath string, modTime ti
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set handler path and mod time
|
|
||||||
current.handler = handlerPath
|
current.handler = handlerPath
|
||||||
current.modTime = modTime
|
current.modTime = modTime
|
||||||
|
|
||||||
// Compile Lua file to bytecode
|
return r.compileWithMiddleware(current, urlPath, handlerPath)
|
||||||
if err := r.compileHandler(current); err != nil {
|
}
|
||||||
// Track the failure but don't fail the entire process
|
|
||||||
routeKey := getRouteKey(urlPath, handlerPath)
|
// compileWithMiddleware combines middleware and handler source, then compiles
|
||||||
r.failedRoutes[routeKey] = &RouteError{
|
func (r *LuaRouter) compileWithMiddleware(n *node, urlPath, scriptPath string) error {
|
||||||
Path: urlPath,
|
if scriptPath == "" {
|
||||||
ScriptPath: handlerPath,
|
return nil
|
||||||
Err: err,
|
}
|
||||||
|
|
||||||
|
// Collect middleware for this path (cascading from root)
|
||||||
|
middlewareChain := r.getMiddlewareChain(urlPath)
|
||||||
|
|
||||||
|
// Read and combine all source files
|
||||||
|
var combinedSource strings.Builder
|
||||||
|
|
||||||
|
// Add middleware in order
|
||||||
|
for _, mwPath := range middlewareChain {
|
||||||
|
content, err := os.ReadFile(mwPath)
|
||||||
|
if err != nil {
|
||||||
|
n.err = err
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
combinedSource.WriteString("-- Middleware: ")
|
||||||
|
combinedSource.WriteString(mwPath)
|
||||||
|
combinedSource.WriteString("\n")
|
||||||
|
combinedSource.Write(content)
|
||||||
|
combinedSource.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add main handler
|
||||||
|
content, err := os.ReadFile(scriptPath)
|
||||||
|
if err != nil {
|
||||||
|
n.err = err
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
combinedSource.WriteString("-- Handler: ")
|
||||||
|
combinedSource.WriteString(scriptPath)
|
||||||
|
combinedSource.WriteString("\n")
|
||||||
|
combinedSource.Write(content)
|
||||||
|
|
||||||
|
// Compile combined source
|
||||||
|
state := luajit.New()
|
||||||
|
if state == nil {
|
||||||
|
compileErr := errors.New("failed to create Lua state")
|
||||||
|
n.err = compileErr
|
||||||
|
return compileErr
|
||||||
|
}
|
||||||
|
defer state.Close()
|
||||||
|
|
||||||
|
bytecode, err := state.CompileBytecode(combinedSource.String(), scriptPath)
|
||||||
|
if err != nil {
|
||||||
|
n.err = err
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
bytecodeKey := getBytecodeKey(scriptPath)
|
||||||
|
r.bytecodeCache.Set(bytecodeKey, bytecode)
|
||||||
|
|
||||||
|
n.err = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMiddlewareChain returns middleware files that apply to the given path
|
||||||
|
func (r *LuaRouter) getMiddlewareChain(urlPath string) []string {
|
||||||
|
var chain []string
|
||||||
|
|
||||||
|
// Collect middleware from root to specific path
|
||||||
|
pathParts := strings.Split(strings.Trim(urlPath, "/"), "/")
|
||||||
|
if pathParts[0] == "" {
|
||||||
|
pathParts = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add root middleware
|
||||||
|
if mw, exists := r.middlewareFiles["/"]; exists {
|
||||||
|
chain = append(chain, mw...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add middleware from each path level
|
||||||
|
currentPath := ""
|
||||||
|
for _, part := range pathParts {
|
||||||
|
currentPath += "/" + part
|
||||||
|
if mw, exists := r.middlewareFiles[currentPath]; exists {
|
||||||
|
chain = append(chain, mw...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return chain
|
||||||
|
}
|
||||||
|
|
||||||
|
// findOrCreateNode finds or creates a node at the given path
|
||||||
|
func (r *LuaRouter) findOrCreateNode(root *node, urlPath string) *node {
|
||||||
|
segments := strings.Split(strings.Trim(urlPath, "/"), "/")
|
||||||
|
current := root
|
||||||
|
|
||||||
|
for _, segment := range segments {
|
||||||
|
if segment == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(segment) >= 2 && segment[0] == '[' && segment[len(segment)-1] == ']' {
|
||||||
|
if current.paramChild == nil {
|
||||||
|
current.paramChild = &node{
|
||||||
|
paramName: segment[1 : len(segment)-1],
|
||||||
|
staticChild: make(map[string]*node),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current = current.paramChild
|
||||||
|
} else {
|
||||||
|
child, exists := current.staticChild[segment]
|
||||||
|
if !exists {
|
||||||
|
child = &node{
|
||||||
|
staticChild: make(map[string]*node),
|
||||||
|
}
|
||||||
|
current.staticChild[segment] = child
|
||||||
|
}
|
||||||
|
current = child
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return current
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRouteKey generates a unique key for a route
|
// getRouteKey generates a unique key for a route
|
||||||
@ -200,7 +355,6 @@ func uint64ToBytes(n uint64) []byte {
|
|||||||
|
|
||||||
// getCacheKey generates a cache key for a method and path
|
// getCacheKey generates a cache key for a method and path
|
||||||
func getCacheKey(method, path string) []byte {
|
func getCacheKey(method, path string) []byte {
|
||||||
// Simple concatenation with separator to create a unique key
|
|
||||||
key := hashString(method + ":" + path)
|
key := hashString(method + ":" + path)
|
||||||
return uint64ToBytes(key)
|
return uint64ToBytes(key)
|
||||||
}
|
}
|
||||||
@ -212,7 +366,6 @@ func getBytecodeKey(handlerPath string) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Match finds a handler for the given method and path
|
// Match finds a handler for the given method and path
|
||||||
// Uses the pre-allocated params struct to avoid allocations
|
|
||||||
func (r *LuaRouter) Match(method, path string, params *Params) (*node, bool) {
|
func (r *LuaRouter) Match(method, path string, params *Params) (*node, bool) {
|
||||||
params.Count = 0
|
params.Count = 0
|
||||||
|
|
||||||
@ -225,24 +378,25 @@ func (r *LuaRouter) Match(method, path string, params *Params) (*node, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
segments := strings.Split(strings.Trim(path, "/"), "/")
|
segments := strings.Split(strings.Trim(path, "/"), "/")
|
||||||
|
|
||||||
return r.matchPath(root, segments, params, 0)
|
return r.matchPath(root, segments, params, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchPath recursively matches a path against the routing tree
|
// matchPath recursively matches a path against the routing tree
|
||||||
func (r *LuaRouter) matchPath(current *node, segments []string, params *Params, depth int) (*node, bool) {
|
func (r *LuaRouter) matchPath(current *node, segments []string, params *Params, depth int) (*node, bool) {
|
||||||
// Base case: no more segments
|
|
||||||
if len(segments) == 0 {
|
if len(segments) == 0 {
|
||||||
if current.handler != "" {
|
if current.handler != "" {
|
||||||
return current, true
|
return current, true
|
||||||
}
|
}
|
||||||
|
if current.indexFile != "" {
|
||||||
|
return current, true
|
||||||
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
segment := segments[0]
|
segment := segments[0]
|
||||||
remaining := segments[1:]
|
remaining := segments[1:]
|
||||||
|
|
||||||
// Try static child first (exact match takes precedence)
|
// Try static child first
|
||||||
if child, exists := current.staticChild[segment]; exists {
|
if child, exists := current.staticChild[segment]; exists {
|
||||||
if node, found := r.matchPath(child, remaining, params, depth+1); found {
|
if node, found := r.matchPath(child, remaining, params, depth+1); found {
|
||||||
return node, true
|
return node, true
|
||||||
@ -251,7 +405,6 @@ func (r *LuaRouter) matchPath(current *node, segments []string, params *Params,
|
|||||||
|
|
||||||
// Try parameter child
|
// Try parameter child
|
||||||
if current.paramChild != nil {
|
if current.paramChild != nil {
|
||||||
// Store parameter
|
|
||||||
if params.Count < maxParams {
|
if params.Count < maxParams {
|
||||||
params.Keys[params.Count] = current.paramChild.paramName
|
params.Keys[params.Count] = current.paramChild.paramName
|
||||||
params.Values[params.Count] = segment
|
params.Values[params.Count] = segment
|
||||||
@ -262,47 +415,18 @@ func (r *LuaRouter) matchPath(current *node, segments []string, params *Params,
|
|||||||
return node, true
|
return node, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Backtrack: remove parameter if no match
|
|
||||||
params.Count--
|
params.Count--
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fall back to index.lua
|
||||||
|
if current.indexFile != "" {
|
||||||
|
return current, true
|
||||||
|
}
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// compileHandler compiles a Lua file to bytecode
|
// GetRouteInfo returns the combined bytecode, script path, and any error
|
||||||
func (r *LuaRouter) compileHandler(n *node) error {
|
|
||||||
if n.handler == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
content, err := os.ReadFile(n.handler)
|
|
||||||
if err != nil {
|
|
||||||
n.err = err
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
state := luajit.New()
|
|
||||||
if state == nil {
|
|
||||||
compileErr := errors.New("failed to create Lua state")
|
|
||||||
n.err = compileErr
|
|
||||||
return compileErr
|
|
||||||
}
|
|
||||||
defer state.Close()
|
|
||||||
|
|
||||||
bytecode, err := state.CompileBytecode(string(content), n.handler)
|
|
||||||
if err != nil {
|
|
||||||
n.err = err
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
bytecodeKey := getBytecodeKey(n.handler)
|
|
||||||
r.bytecodeCache.Set(bytecodeKey, bytecode)
|
|
||||||
|
|
||||||
n.err = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRouteInfo returns the bytecode, script path, and any error for a matched route
|
|
||||||
func (r *LuaRouter) GetRouteInfo(method, path string, params *Params) ([]byte, string, error, bool) {
|
func (r *LuaRouter) GetRouteInfo(method, path string, params *Params) ([]byte, string, error, bool) {
|
||||||
routeCacheKey := getCacheKey(method, path)
|
routeCacheKey := getCacheKey(method, path)
|
||||||
routeCacheData := r.routeCache.Get(nil, routeCacheKey)
|
routeCacheData := r.routeCache.Get(nil, routeCacheKey)
|
||||||
@ -325,7 +449,13 @@ func (r *LuaRouter) GetRouteInfo(method, path string, params *Params) ([]byte, s
|
|||||||
|
|
||||||
fileInfo, err := os.Stat(handlerPath)
|
fileInfo, err := os.Stat(handlerPath)
|
||||||
if err != nil || fileInfo.ModTime().After(n.modTime) {
|
if err != nil || fileInfo.ModTime().After(n.modTime) {
|
||||||
if err := r.compileHandler(n); err != nil {
|
scriptPath := n.handler
|
||||||
|
if scriptPath == "" {
|
||||||
|
scriptPath = n.indexFile
|
||||||
|
}
|
||||||
|
|
||||||
|
urlPath := r.getNodeURLPath(n)
|
||||||
|
if err := r.compileWithMiddleware(n, urlPath, scriptPath); err != nil {
|
||||||
return nil, handlerPath, n.err, true
|
return nil, handlerPath, n.err, true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -340,11 +470,6 @@ func (r *LuaRouter) GetRouteInfo(method, path string, params *Params) ([]byte, s
|
|||||||
return bytecode, handlerPath, n.err, true
|
return bytecode, handlerPath, n.err, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Strange case - bytecode not in cache but file not modified; recompile
|
|
||||||
if err := r.compileHandler(n); err != nil {
|
|
||||||
return nil, handlerPath, n.err, true
|
|
||||||
}
|
|
||||||
bytecode = r.bytecodeCache.Get(nil, bytecodeKey)
|
|
||||||
return bytecode, handlerPath, n.err, true
|
return bytecode, handlerPath, n.err, true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -353,22 +478,38 @@ func (r *LuaRouter) GetRouteInfo(method, path string, params *Params) ([]byte, s
|
|||||||
return nil, "", nil, false
|
return nil, "", nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
bytecodeKey := getBytecodeKey(node.handler)
|
scriptPath := node.handler
|
||||||
|
if scriptPath == "" && node.indexFile != "" {
|
||||||
|
scriptPath = node.indexFile
|
||||||
|
}
|
||||||
|
|
||||||
|
if scriptPath == "" {
|
||||||
|
return nil, "", nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
bytecodeKey := getBytecodeKey(scriptPath)
|
||||||
bytecode := r.bytecodeCache.Get(nil, bytecodeKey)
|
bytecode := r.bytecodeCache.Get(nil, bytecodeKey)
|
||||||
|
|
||||||
if len(bytecode) == 0 {
|
if len(bytecode) == 0 {
|
||||||
if err := r.compileHandler(node); err != nil {
|
urlPath := r.getNodeURLPath(node)
|
||||||
return nil, node.handler, node.err, true
|
if err := r.compileWithMiddleware(node, urlPath, scriptPath); err != nil {
|
||||||
|
return nil, scriptPath, node.err, true
|
||||||
}
|
}
|
||||||
bytecode = r.bytecodeCache.Get(nil, bytecodeKey)
|
bytecode = r.bytecodeCache.Get(nil, bytecodeKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheData := make([]byte, 8+len(node.handler))
|
cacheData := make([]byte, 8+len(scriptPath))
|
||||||
copy(cacheData[:8], bytecodeKey)
|
copy(cacheData[:8], bytecodeKey)
|
||||||
copy(cacheData[8:], node.handler)
|
copy(cacheData[8:], scriptPath)
|
||||||
r.routeCache.Set(routeCacheKey, cacheData)
|
r.routeCache.Set(routeCacheKey, cacheData)
|
||||||
|
|
||||||
return bytecode, node.handler, node.err, true
|
return bytecode, scriptPath, node.err, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNodeURLPath reconstructs URL path for a node (simplified)
|
||||||
|
func (r *LuaRouter) getNodeURLPath(node *node) string {
|
||||||
|
// This is a simplified implementation - in practice you'd traverse up the tree
|
||||||
|
return "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
// nodeForHandler finds a node by its handler path
|
// nodeForHandler finds a node by its handler path
|
||||||
@ -391,11 +532,10 @@ func findNodeByHandler(current *node, handlerPath string) *node {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if current.handler == handlerPath {
|
if current.handler == handlerPath || current.indexFile == handlerPath {
|
||||||
return current
|
return current
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check static children
|
|
||||||
for _, child := range current.staticChild {
|
for _, child := range current.staticChild {
|
||||||
if node := findNodeByHandler(child, handlerPath); node != nil {
|
if node := findNodeByHandler(child, handlerPath); node != nil {
|
||||||
return node
|
return node
|
||||||
@ -423,6 +563,7 @@ func (r *LuaRouter) Refresh() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.failedRoutes = make(map[string]*RouteError)
|
r.failedRoutes = make(map[string]*RouteError)
|
||||||
|
r.middlewareFiles = make(map[string][]string)
|
||||||
|
|
||||||
err := r.buildRoutes()
|
err := r.buildRoutes()
|
||||||
|
|
||||||
@ -493,9 +634,8 @@ func countNodesAndBytecode(n *node) (count int, bytecodeBytes int64) {
|
|||||||
return 0, 0
|
return 0, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.handler != "" {
|
if n.handler != "" || n.indexFile != "" {
|
||||||
count = 1
|
count = 1
|
||||||
// Average of 2KB per script
|
|
||||||
bytecodeBytes = 2048
|
bytecodeBytes = 2048
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ package tests
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -15,16 +14,18 @@ import (
|
|||||||
|
|
||||||
// setupTestEnv initializes test components and returns cleanup function
|
// setupTestEnv initializes test components and returns cleanup function
|
||||||
func setupTestEnv(b *testing.B) (*routers.LuaRouter, *runner.Runner, func()) {
|
func setupTestEnv(b *testing.B) (*routers.LuaRouter, *runner.Runner, func()) {
|
||||||
// Completely silence logging during benchmarks
|
// Completely silence all logging
|
||||||
logger.InitGlobalLogger(false, false)
|
|
||||||
|
|
||||||
// Redirect standard logger output to discard
|
|
||||||
log.SetOutput(io.Discard)
|
|
||||||
|
|
||||||
// Store original stderr to restore later
|
|
||||||
originalStderr := os.Stderr
|
originalStderr := os.Stderr
|
||||||
|
originalStdout := os.Stdout
|
||||||
devNull, _ := os.Open(os.DevNull)
|
devNull, _ := os.Open(os.DevNull)
|
||||||
|
|
||||||
|
// Redirect everything to devnull
|
||||||
os.Stderr = devNull
|
os.Stderr = devNull
|
||||||
|
os.Stdout = devNull
|
||||||
|
log.SetOutput(devNull)
|
||||||
|
|
||||||
|
// Force reinit logger to be silent
|
||||||
|
logger.InitGlobalLogger(false, false)
|
||||||
|
|
||||||
// Create temp directories
|
// Create temp directories
|
||||||
tempDir, err := os.MkdirTemp("", "moonshark-bench")
|
tempDir, err := os.MkdirTemp("", "moonshark-bench")
|
||||||
@ -32,7 +33,6 @@ func setupTestEnv(b *testing.B) (*routers.LuaRouter, *runner.Runner, func()) {
|
|||||||
b.Fatalf("Failed to create temp dir: %v", err)
|
b.Fatalf("Failed to create temp dir: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rest of the function remains the same...
|
|
||||||
routesDir := filepath.Join(tempDir, "routes")
|
routesDir := filepath.Join(tempDir, "routes")
|
||||||
staticDir := filepath.Join(tempDir, "static")
|
staticDir := filepath.Join(tempDir, "static")
|
||||||
libsDir := filepath.Join(tempDir, "libs")
|
libsDir := filepath.Join(tempDir, "libs")
|
||||||
@ -65,11 +65,12 @@ func setupTestEnv(b *testing.B) (*routers.LuaRouter, *runner.Runner, func()) {
|
|||||||
b.Fatalf("Failed to create runner: %v", err)
|
b.Fatalf("Failed to create runner: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return cleanup function that restores stderr
|
// Return cleanup function that restores outputs
|
||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
luaRunner.Close()
|
luaRunner.Close()
|
||||||
os.RemoveAll(tempDir)
|
os.RemoveAll(tempDir)
|
||||||
os.Stderr = originalStderr
|
os.Stderr = originalStderr
|
||||||
|
os.Stdout = originalStdout
|
||||||
devNull.Close()
|
devNull.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,16 +81,18 @@ func setupTestEnv(b *testing.B) (*routers.LuaRouter, *runner.Runner, func()) {
|
|||||||
func createTestRoutes(routesDir string) {
|
func createTestRoutes(routesDir string) {
|
||||||
// Simple GET endpoint
|
// Simple GET endpoint
|
||||||
getCode := []byte(`return "Hello, World!"`)
|
getCode := []byte(`return "Hello, World!"`)
|
||||||
os.WriteFile(filepath.Join(routesDir, "GET_hello.lua"), getCode, 0644)
|
os.WriteFile(filepath.Join(routesDir, "GET.lua"), getCode, 0644)
|
||||||
|
|
||||||
// POST endpoint with form handling
|
// POST endpoint with form handling
|
||||||
postCode := []byte(`
|
postCode := []byte(`
|
||||||
local data = ctx.form or {}
|
local data = ctx.form or {}
|
||||||
return "Received: " .. (data.message or "no message")
|
return "Received: " .. (data.message or "no message")
|
||||||
`)
|
`)
|
||||||
os.WriteFile(filepath.Join(routesDir, "POST_hello.lua"), postCode, 0644)
|
os.WriteFile(filepath.Join(routesDir, "POST.lua"), postCode, 0644)
|
||||||
|
|
||||||
// Computationally intensive endpoint
|
// Computationally intensive endpoint
|
||||||
|
complexDir := filepath.Join(routesDir, "complex")
|
||||||
|
os.MkdirAll(complexDir, 0755)
|
||||||
complexCode := []byte(`
|
complexCode := []byte(`
|
||||||
local result = {}
|
local result = {}
|
||||||
for i = 1, 1000 do
|
for i = 1, 1000 do
|
||||||
@ -97,7 +100,27 @@ func createTestRoutes(routesDir string) {
|
|||||||
end
|
end
|
||||||
return "Calculated " .. #result .. " squared numbers"
|
return "Calculated " .. #result .. " squared numbers"
|
||||||
`)
|
`)
|
||||||
os.WriteFile(filepath.Join(routesDir, "GET_complex.lua"), complexCode, 0644)
|
os.WriteFile(filepath.Join(complexDir, "GET.lua"), complexCode, 0644)
|
||||||
|
|
||||||
|
// Create middleware for testing
|
||||||
|
middlewareCode := []byte(`
|
||||||
|
http.set_metadata("middleware_executed", true)
|
||||||
|
return nil
|
||||||
|
`)
|
||||||
|
os.WriteFile(filepath.Join(routesDir, "middleware.lua"), middlewareCode, 0644)
|
||||||
|
|
||||||
|
// Nested middleware
|
||||||
|
nestedDir := filepath.Join(routesDir, "api")
|
||||||
|
os.MkdirAll(nestedDir, 0755)
|
||||||
|
nestedMiddlewareCode := []byte(`
|
||||||
|
http.set_metadata("api_middleware", true)
|
||||||
|
return nil
|
||||||
|
`)
|
||||||
|
os.WriteFile(filepath.Join(nestedDir, "middleware.lua"), nestedMiddlewareCode, 0644)
|
||||||
|
|
||||||
|
// Nested endpoint
|
||||||
|
nestedEndpointCode := []byte(`return "API endpoint"`)
|
||||||
|
os.WriteFile(filepath.Join(nestedDir, "GET.lua"), nestedEndpointCode, 0644)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkRouterLookup tests route lookup performance
|
// BenchmarkRouterLookup tests route lookup performance
|
||||||
@ -106,12 +129,12 @@ func BenchmarkRouterLookup(b *testing.B) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
method := "GET"
|
method := "GET"
|
||||||
path := "/hello"
|
path := "/"
|
||||||
params := &routers.Params{}
|
params := &routers.Params{}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, _, _, _ = luaRouter.GetRouteInfo(method, path, params)
|
_, _, _, _, _ = luaRouter.GetRouteInfo(method, path, params)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,9 +144,9 @@ func BenchmarkSimpleLuaExecution(b *testing.B) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
method := "GET"
|
method := "GET"
|
||||||
path := "/hello"
|
path := "/"
|
||||||
params := &routers.Params{}
|
params := &routers.Params{}
|
||||||
bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
_, bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
||||||
|
|
||||||
ctx := runner.NewContext()
|
ctx := runner.NewContext()
|
||||||
defer ctx.Release()
|
defer ctx.Release()
|
||||||
@ -142,7 +165,7 @@ func BenchmarkComplexLuaExecution(b *testing.B) {
|
|||||||
method := "GET"
|
method := "GET"
|
||||||
path := "/complex"
|
path := "/complex"
|
||||||
params := &routers.Params{}
|
params := &routers.Params{}
|
||||||
bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
_, bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
||||||
|
|
||||||
ctx := runner.NewContext()
|
ctx := runner.NewContext()
|
||||||
defer ctx.Release()
|
defer ctx.Release()
|
||||||
@ -159,13 +182,13 @@ func BenchmarkGetEndpoint(b *testing.B) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
method := "GET"
|
method := "GET"
|
||||||
path := "/hello"
|
path := "/"
|
||||||
params := &routers.Params{}
|
params := &routers.Params{}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
// Route lookup
|
// Route lookup
|
||||||
bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
_, bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
||||||
|
|
||||||
// Context setup
|
// Context setup
|
||||||
ctx := runner.NewContext()
|
ctx := runner.NewContext()
|
||||||
@ -184,13 +207,13 @@ func BenchmarkPostEndpoint(b *testing.B) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
method := "POST"
|
method := "POST"
|
||||||
path := "/hello"
|
path := "/"
|
||||||
params := &routers.Params{}
|
params := &routers.Params{}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
// Route lookup
|
// Route lookup
|
||||||
bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
_, bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
||||||
|
|
||||||
// Context setup with form data
|
// Context setup with form data
|
||||||
ctx := runner.NewContext()
|
ctx := runner.NewContext()
|
||||||
@ -212,9 +235,9 @@ func BenchmarkConcurrentExecution(b *testing.B) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
method := "GET"
|
method := "GET"
|
||||||
path := "/hello"
|
path := "/"
|
||||||
params := &routers.Params{}
|
params := &routers.Params{}
|
||||||
bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
_, bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
@ -234,7 +257,7 @@ func BenchmarkConcurrentComplexExecution(b *testing.B) {
|
|||||||
method := "GET"
|
method := "GET"
|
||||||
path := "/complex"
|
path := "/complex"
|
||||||
params := &routers.Params{}
|
params := &routers.Params{}
|
||||||
bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
_, bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
@ -246,6 +269,40 @@ func BenchmarkConcurrentComplexExecution(b *testing.B) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BenchmarkMiddlewareExecution tests middleware + handler execution
|
||||||
|
func BenchmarkMiddlewareExecution(b *testing.B) {
|
||||||
|
luaRouter, luaRunner, cleanup := setupTestEnv(b)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
method := "GET"
|
||||||
|
path := "/api"
|
||||||
|
params := &routers.Params{}
|
||||||
|
middlewareBytecode, bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ctx := runner.NewContext()
|
||||||
|
|
||||||
|
// Execute middleware chain
|
||||||
|
for _, mwBytecode := range middlewareBytecode {
|
||||||
|
if len(mwBytecode) > 0 {
|
||||||
|
response, _ := luaRunner.Run(mwBytecode, ctx, "middleware")
|
||||||
|
if response != nil {
|
||||||
|
runner.ReleaseResponse(response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute handler
|
||||||
|
response, _ := luaRunner.Run(bytecode, ctx, scriptPath)
|
||||||
|
if response != nil {
|
||||||
|
runner.ReleaseResponse(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// BenchmarkRouteCompilation tests the performance of route compilation
|
// BenchmarkRouteCompilation tests the performance of route compilation
|
||||||
func BenchmarkRouteCompilation(b *testing.B) {
|
func BenchmarkRouteCompilation(b *testing.B) {
|
||||||
tempDir, err := os.MkdirTemp("", "moonshark-compile")
|
tempDir, err := os.MkdirTemp("", "moonshark-compile")
|
||||||
@ -301,9 +358,9 @@ func BenchmarkRunnerExecute(b *testing.B) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
method := "GET"
|
method := "GET"
|
||||||
path := "/hello"
|
path := "/"
|
||||||
params := &routers.Params{}
|
params := &routers.Params{}
|
||||||
bytecode, scriptPath, _, _ := luaRouter.GetRouteInfo(method, path, params)
|
_, bytecode, scriptPath, _ := luaRouter.GetRouteInfo(method, path, params)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user