Moonshark/routers/luaRouter.go

788 lines
20 KiB
Go

package routers
import (
"encoding/binary"
"errors"
"hash/fnv"
"os"
"path/filepath"
"strings"
"sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/VictoriaMetrics/fastcache"
)
// maxParams is the maximum number of URL parameters a single route match can
// capture. It fixes the length of the Keys and Values arrays in Params.
const maxParams = 20
// Default capacities, in bytes, handed to fastcache.New for the router's two
// fastcache-backed caches.
const (
defaultBytecodeMaxBytes = 32 * 1024 * 1024 // 32MB for bytecode cache
defaultRouteMaxBytes = 8 * 1024 * 1024 // 8MB for route match cache
)
// Params holds the URL parameters captured while matching a route.
//
// Fixed-size arrays are used to avoid allocations. Only the first Count
// entries are meaningful, and Keys[i] pairs with Values[i].
type Params struct {
Keys [maxParams]string // Parameter names, in capture order
Values [maxParams]string // Captured path segments, parallel to Keys
Count int // Number of populated entries
}
// Get looks up a captured URL parameter by name.
//
// Only the first Count entries are examined and the first matching key wins.
// The empty string is returned when no key equals name.
func (p *Params) Get(name string) string {
	for idx := 0; idx < p.Count; idx++ {
		if p.Keys[idx] != name {
			continue
		}
		return p.Values[idx]
	}
	return ""
}
// LuaRouter is a filesystem-based HTTP router for Lua files.
//
// Every supported HTTP method owns a routing trie in routes. Handlers are
// compiled together with the middleware.lua files found in their directory
// and its ancestors, and the resulting bytecode is kept in bytecodeCache.
type LuaRouter struct {
routesDir string // Root directory containing route files
routes map[string]*node // Method -> route tree
failedRoutes map[string]*RouteError // Track failed routes
mu sync.RWMutex // Lock for concurrent access to routes
routeCache *fastcache.Cache // Cache for route lookups (hashed method+path -> bytecode key + script path)
bytecodeCache *fastcache.Cache // Cache for compiled bytecode (hashed script path -> bytecode)
// Middleware tracking for path hierarchy
middlewareFiles map[string][]string // URL path of a directory -> middleware file paths found in it
// Source-level caches. These are plain maps with no lock of their own.
middlewareCache map[string][]byte // file path -> file content (filled by getFileContent for middleware and handlers alike)
sourceCache map[string][]byte // combined source cache key -> compiled bytecode
sourceMtimes map[string]time.Time // freshness stamps; keyed by combined source key (compile) and by file path (read)
// Shared Lua state for compilation
compileState *luajit.State
compileStateMu sync.Mutex // Protect concurrent access to Lua state
}
// node represents a node in the routing trie.
//
// A node is an endpoint when handler or indexFile is set. Children are split
// into literal segments (staticChild) and a single parameter child
// (paramChild) created from a "[name]" directory segment.
type node struct {
handler string // Path to Lua file (empty if not an endpoint)
indexFile string // Path to index.lua file (catch-all)
paramName string // Parameter name (if this is a parameter node)
staticChild map[string]*node // Static children by segment name
paramChild *node // Parameter/wildcard child
err error // Most recent compilation error recorded by compileWithMiddleware, if any
modTime time.Time // Modification time of the route file when it was registered
}
// NewLuaRouter creates a router rooted at routesDir.
//
// routesDir must exist and be a directory. A shared Lua state is created for
// compilation, an empty trie is registered for each supported HTTP method,
// and the directory is scanned once via buildRoutes.
//
// When the scan leaves entries in failedRoutes the router is still returned,
// together with ErrRoutesCompilationErrors; otherwise the router is returned
// with whatever error buildRoutes produced. In both cases the caller owns the
// router and should Close it.
func NewLuaRouter(routesDir string) (*LuaRouter, error) {
	stat, err := os.Stat(routesDir)
	switch {
	case err != nil:
		return nil, err
	case !stat.IsDir():
		return nil, errors.New("routes path is not a directory")
	}

	// One Lua state is shared by every compilation this router performs.
	state := luajit.New()
	if state == nil {
		return nil, errors.New("failed to create Lua compile state")
	}

	router := &LuaRouter{
		routesDir:       routesDir,
		routes:          make(map[string]*node),
		failedRoutes:    make(map[string]*RouteError),
		middlewareFiles: make(map[string][]string),
		routeCache:      fastcache.New(defaultRouteMaxBytes),
		bytecodeCache:   fastcache.New(defaultBytecodeMaxBytes),
		middlewareCache: make(map[string][]byte),
		sourceCache:     make(map[string][]byte),
		sourceMtimes:    make(map[string]time.Time),
		compileState:    state,
	}

	for _, verb := range [...]string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"} {
		router.routes[verb] = &node{staticChild: make(map[string]*node)}
	}

	buildErr := router.buildRoutes()
	if len(router.failedRoutes) != 0 {
		return router, ErrRoutesCompilationErrors
	}
	return router, buildErr
}
// buildRoutes scans the routes directory and populates the routing tries.
//
// The directory is walked twice. The first walk records every middleware.lua
// under the URL path of its directory. The second walk registers the
// remaining .lua files: index.lua becomes the index file of its directory's
// node in every method trie, and a file named after an HTTP method (case
// insensitive) becomes that method's handler. Other .lua files are ignored.
//
// failedRoutes and middlewareFiles are reset on entry; nothing in this
// function adds entries to failedRoutes. Compilation errors are not returned
// from here: compileWithMiddleware records them on the node. Only walk and
// path errors abort the scan.
func (r *LuaRouter) buildRoutes() error {
	r.failedRoutes = make(map[string]*RouteError)
	r.middlewareFiles = make(map[string][]string)

	// dirURL maps the directory containing file to its URL path.
	dirURL := func(file string) (string, error) {
		rel, err := filepath.Rel(r.routesDir, filepath.Dir(file))
		if err != nil {
			return "", err
		}
		if rel == "." {
			return "/", nil
		}
		return "/" + strings.ReplaceAll(rel, "\\", "/"), nil
	}

	// Pass 1: remember which directories carry middleware.
	collectMiddleware := func(file string, fi os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if fi.IsDir() || !strings.HasSuffix(fi.Name(), ".lua") {
			return nil
		}
		if strings.TrimSuffix(fi.Name(), ".lua") != "middleware" {
			return nil
		}
		url, err := dirURL(file)
		if err != nil {
			return err
		}
		r.middlewareFiles[url] = append(r.middlewareFiles[url], file)
		return nil
	}
	if err := filepath.Walk(r.routesDir, collectMiddleware); err != nil {
		return err
	}

	// Pass 2: register handlers, each compiled together with its middleware.
	registerRoute := func(file string, fi os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if fi.IsDir() || !strings.HasSuffix(fi.Name(), ".lua") {
			return nil
		}
		base := strings.TrimSuffix(fi.Name(), ".lua")
		if base == "middleware" {
			// Already handled by the first pass.
			return nil
		}
		url, err := dirURL(file)
		if err != nil {
			return err
		}

		if base == "index" {
			// index.lua serves every method at this directory.
			for _, verb := range [...]string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"} {
				target := r.findOrCreateNode(r.routes[verb], url)
				target.indexFile = file
				target.modTime = fi.ModTime()
				r.compileWithMiddleware(target, url, file)
			}
			return nil
		}

		if tree, known := r.routes[strings.ToUpper(base)]; known {
			r.addRoute(tree, url, file, fi.ModTime())
		}
		return nil
	}
	return filepath.Walk(r.routesDir, registerRoute)
}
// addRoute registers handlerPath as the handler for urlPath in the trie
// rooted at root, creating intermediate nodes as needed, and compiles it
// together with the middleware that applies to urlPath.
//
// The trie descent is delegated to findOrCreateNode so the segment rules
// ("[name]" is a parameter segment, empty segments are skipped) live in one
// place. The returned error is the compilation result; it is also recorded on
// the node by compileWithMiddleware.
func (r *LuaRouter) addRoute(root *node, urlPath, handlerPath string, modTime time.Time) error {
	target := r.findOrCreateNode(root, urlPath)
	target.handler = handlerPath
	target.modTime = modTime
	return r.compileWithMiddleware(target, urlPath, handlerPath)
}
// compileWithMiddleware compiles scriptPath together with the middleware
// chain for urlPath and stores the bytecode under the script's bytecode key.
//
// A previous compilation is reused when neither the handler nor any
// middleware file has a modification time after the recorded stamp and the
// bytecode is still in sourceCache; in that case the bytecode is simply
// re-published to bytecodeCache. Otherwise the combined source is rebuilt and
// compiled on the shared Lua state.
//
// Every failure is recorded in n.err and returned; success clears n.err. An
// empty scriptPath is a no-op. Calling this after Close yields an error
// instead of dereferencing the released Lua state.
//
// The method mutates sourceCache and sourceMtimes without taking r.mu, so
// callers are responsible for serialising access to them.
func (r *LuaRouter) compileWithMiddleware(n *node, urlPath, scriptPath string) error {
	if scriptPath == "" {
		return nil
	}

	// fail records err on the node before handing it back.
	fail := func(err error) error {
		n.err = err
		return err
	}

	sourceKey := r.getSourceCacheKey(urlPath, scriptPath)

	handlerInfo, err := os.Stat(scriptPath)
	if err != nil {
		return fail(err)
	}

	lastCompiled, compiledBefore := r.sourceMtimes[sourceKey]
	stale := !compiledBefore || handlerInfo.ModTime().After(lastCompiled)

	// The handler looks fresh; make sure no middleware changed either.
	if !stale {
		for _, mwPath := range r.getMiddlewareChain(urlPath) {
			mwInfo, err := os.Stat(mwPath)
			if err != nil {
				return fail(err)
			}
			if mwInfo.ModTime().After(lastCompiled) {
				stale = true
				break
			}
		}
	}

	if !stale {
		if bytecode, ok := r.sourceCache[sourceKey]; ok {
			r.bytecodeCache.Set(getBytecodeKey(scriptPath), bytecode)
			// Fresh bytecode exists, so any error from an older attempt is obsolete.
			n.err = nil
			return nil
		}
	}

	// Take the freshness stamp before reading any source. An edit that lands
	// while we are reading or compiling then carries a later mtime and forces
	// a recompile next time, instead of being masked by a stamp taken after
	// the fact.
	compileStarted := time.Now()

	combinedSource, err := r.buildCombinedSource(urlPath, scriptPath)
	if err != nil {
		return fail(err)
	}

	r.compileStateMu.Lock()
	if r.compileState == nil {
		// Close sets compileState to nil.
		r.compileStateMu.Unlock()
		return fail(errors.New("lua compile state is closed"))
	}
	bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath)
	r.compileStateMu.Unlock()
	if err != nil {
		return fail(err)
	}

	r.bytecodeCache.Set(getBytecodeKey(scriptPath), bytecode)
	r.sourceCache[sourceKey] = bytecode
	r.sourceMtimes[sourceKey] = compileStarted
	n.err = nil
	return nil
}
// buildCombinedSource concatenates the middleware chain for urlPath and the
// handler at scriptPath into one Lua chunk.
//
// Middleware files come first, in chain order, followed by the handler. Each
// part is preceded by a one-line Lua comment naming its file, and every
// middleware body is followed by a newline. The first file that cannot be
// read aborts the build and its error is returned.
func (r *LuaRouter) buildCombinedSource(urlPath, scriptPath string) (string, error) {
	var src strings.Builder

	for _, mw := range r.getMiddlewareChain(urlPath) {
		body, err := r.getFileContent(mw)
		if err != nil {
			return "", err
		}
		src.WriteString("-- Middleware: " + mw + "\n")
		src.Write(body)
		src.WriteByte('\n')
	}

	handlerBody, err := r.getFileContent(scriptPath)
	if err != nil {
		return "", err
	}
	src.WriteString("-- Handler: " + scriptPath + "\n")
	src.Write(handlerBody)

	return src.String(), nil
}
// getFileContent returns the content of path, served from middlewareCache
// when the file has not been modified since it was last read.
//
// A cached copy is used only if the file can still be stat'ed and its
// modification time is not after the stamp recorded for path in
// sourceMtimes. In every other case the file is read from disk and the cache
// entry and stamp are replaced. A read error is returned as-is.
func (r *LuaRouter) getFileContent(path string) ([]byte, error) {
	if cached, ok := r.middlewareCache[path]; ok {
		if info, err := os.Stat(path); err == nil {
			if readAt, ok := r.sourceMtimes[path]; ok && !info.ModTime().After(readAt) {
				return cached, nil
			}
		}
	}

	// Stamp before reading: an edit made while the read is in progress then
	// has a later mtime and invalidates this entry on the next call, rather
	// than being hidden behind a stamp taken after the read finished.
	readStarted := time.Now()

	content, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	r.middlewareCache[path] = content
	r.sourceMtimes[path] = readStarted
	return content, nil
}
// getSourceCacheKey builds the key under which the combined source for
// scriptPath is cached: the middleware chain for urlPath followed by
// scriptPath, joined with "|". With no applicable middleware the key is
// scriptPath itself.
func (r *LuaRouter) getSourceCacheKey(urlPath, scriptPath string) string {
	chain := r.getMiddlewareChain(urlPath)
	parts := make([]string, 0, len(chain)+1)
	parts = append(parts, chain...)
	parts = append(parts, scriptPath)
	return strings.Join(parts, "|")
}
// getMiddlewareChain lists the middleware files that apply to urlPath,
// outermost first: the files registered for "/" and then those registered
// for each successively longer prefix of urlPath. The result is nil when no
// middleware applies.
func (r *LuaRouter) getMiddlewareChain(urlPath string) []string {
	var chain []string

	// Root middleware applies to every path. Appending a missing entry is a no-op.
	chain = append(chain, r.middlewareFiles["/"]...)

	trimmed := strings.Trim(urlPath, "/")
	if trimmed == "" {
		return chain
	}

	prefix := ""
	for _, seg := range strings.Split(trimmed, "/") {
		prefix = prefix + "/" + seg
		chain = append(chain, r.middlewareFiles[prefix]...)
	}
	return chain
}
// findOrCreateNode walks urlPath from root, creating any missing nodes, and
// returns the node for the final segment (root itself for "/").
//
// Empty segments are skipped. A segment of the form "[name]" descends into
// the single parameter child, which is created with that name if absent; an
// existing parameter child keeps the name it was created with. Every other
// segment descends into the static child of the same name.
func (r *LuaRouter) findOrCreateNode(root *node, urlPath string) *node {
	cursor := root
	for _, seg := range strings.Split(strings.Trim(urlPath, "/"), "/") {
		switch {
		case seg == "":
			// Nothing to descend into.

		case len(seg) >= 2 && strings.HasPrefix(seg, "[") && strings.HasSuffix(seg, "]"):
			if cursor.paramChild == nil {
				cursor.paramChild = &node{
					paramName:   seg[1 : len(seg)-1],
					staticChild: make(map[string]*node),
				}
			}
			cursor = cursor.paramChild

		default:
			next, ok := cursor.staticChild[seg]
			if !ok {
				next = &node{staticChild: make(map[string]*node)}
				cursor.staticChild[seg] = next
			}
			cursor = next
		}
	}
	return cursor
}
// getRouteKey builds a route key of the form "path:handler".
func getRouteKey(path, handler string) string {
	return strings.Join([]string{path, handler}, ":")
}
// hashString returns the 64-bit FNV-1a hash of s.
func hashString(s string) uint64 {
	hasher := fnv.New64a()
	_, _ = hasher.Write([]byte(s))
	return hasher.Sum64()
}
// uint64ToBytes encodes n as 8 little-endian bytes for use as a cache key.
func uint64ToBytes(n uint64) []byte {
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], n)
	return buf[:]
}
// getCacheKey generates a cache key for a method and path
func getCacheKey(method, path string) []byte {
key := hashString(method + ":" + path)
return uint64ToBytes(key)
}
// getBytecodeKey generates a cache key for a handler path
func getBytecodeKey(handlerPath string) []byte {
key := hashString(handlerPath)
return uint64ToBytes(key)
}
// Match resolves method and path to a trie node, filling params with the
// captured URL parameters.
//
// params.Count is reset before matching. The boolean is false when the method
// has no trie or no node matches. The read lock is held only while the
// method's root is looked up; the trie walk itself runs without it.
func (r *LuaRouter) Match(method, path string, params *Params) (*node, bool) {
	params.Count = 0

	r.mu.RLock()
	tree, ok := r.routes[method]
	r.mu.RUnlock()

	if ok {
		return r.matchPath(tree, strings.Split(strings.Trim(path, "/"), "/"), params, 0)
	}
	return nil, false
}
// matchPath recursively matches segments against the trie rooted at current.
//
// At each level a static child named after the next segment is tried first,
// then the parameter child (capturing the segment into params), and finally
// the current node's index file, which acts as a catch-all for whatever
// segments remain. With no segments left, the node matches if it has a
// handler or an index file.
//
// depth is 0 for the initial call and grows by one per consumed segment.
// Empty segments are filtered out (in place) at depth 0 only; deeper calls
// receive a tail of that already-filtered slice.
//
// A parameter capture that does not lead to a match is rolled back by
// restoring the previous params.Count. When params is already full
// (maxParams entries) the capture is skipped but matching continues.
func (r *LuaRouter) matchPath(current *node, segments []string, params *Params, depth int) (*node, bool) {
	if depth == 0 {
		kept := segments[:0]
		for _, segment := range segments {
			if segment != "" {
				kept = append(kept, segment)
			}
		}
		segments = kept
	}

	if len(segments) == 0 {
		if current.handler != "" || current.indexFile != "" {
			return current, true
		}
		return nil, false
	}

	segment, remaining := segments[0], segments[1:]

	// Literal segments take priority over parameters.
	if child, ok := current.staticChild[segment]; ok {
		if matched, found := r.matchPath(child, remaining, params, depth+1); found {
			return matched, true
		}
	}

	if child := current.paramChild; child != nil {
		// Remember the count so a failed branch is undone exactly, even when
		// the capture below was skipped because params was full.
		saved := params.Count
		if params.Count < maxParams {
			params.Keys[params.Count] = child.paramName
			params.Values[params.Count] = segment
			params.Count++
		}
		if matched, found := r.matchPath(child, remaining, params, depth+1); found {
			return matched, true
		}
		params.Count = saved
	}

	// Fall back to index.lua as a catch-all for the unmatched remainder.
	if current.indexFile != "" {
		return current, true
	}
	return nil, false
}
// GetRouteInfo resolves a request to its compiled Lua bytecode.
//
// It returns the bytecode, the path of the script it was compiled from, the
// compilation error recorded on the matched node (if any), and whether a
// route was found. When found is false the other results are zero values.
// When compilation fails the bytecode is nil but found is still true so the
// caller can report the error for that script.
//
// The trie is walked on every call: URL parameters exist only as a result of
// that walk, so answering from the route cache alone would leave params
// unfilled for repeat requests. The route-cache entry for the request is
// still kept up to date (bytecode key followed by script path) and is removed
// when the request no longer matches anything.
//
// If the bytecode is missing from the bytecode cache it is recompiled.
// Recompilation mutates router-wide maps, so it runs under the write lock.
func (r *LuaRouter) GetRouteInfo(method, path string, params *Params) ([]byte, string, error, bool) {
	routeCacheKey := getCacheKey(method, path)

	matched, found := r.Match(method, path, params)
	if !found {
		// Drop any entry left over from a route that no longer exists.
		r.routeCache.Del(routeCacheKey)
		return nil, "", nil, false
	}

	scriptPath := matched.handler
	if scriptPath == "" {
		scriptPath = matched.indexFile
	}
	if scriptPath == "" {
		return nil, "", nil, false
	}

	bytecodeKey := getBytecodeKey(scriptPath)
	bytecode := r.bytecodeCache.Get(nil, bytecodeKey)
	if len(bytecode) == 0 {
		// Never compiled successfully, evicted, or cleared: compile now.
		r.mu.Lock()
		compileErr := r.compileWithMiddleware(matched, r.getNodeURLPath(matched), scriptPath)
		nodeErr := matched.err
		r.mu.Unlock()
		if compileErr != nil {
			return nil, scriptPath, nodeErr, true
		}
		bytecode = r.bytecodeCache.Get(nil, bytecodeKey)
	}

	entry := make([]byte, 0, len(bytecodeKey)+len(scriptPath))
	entry = append(entry, bytecodeKey...)
	entry = append(entry, scriptPath...)
	r.routeCache.Set(routeCacheKey, entry)

	return bytecode, scriptPath, matched.err, true
}
// getNodeURLPath returns the URL path of the directory that holds the node's
// script, in the same form buildRoutes uses as the key for middleware lookup
// ("/" for the routes root, otherwise "/" plus the slash-separated relative
// directory).
//
// The path is derived from the node's handler, or its index file when there
// is no handler; both live in the same directory. "/" is returned for a nil
// node, a node without a script, or a script path that cannot be made
// relative to the routes directory.
func (r *LuaRouter) getNodeURLPath(node *node) string {
	if node == nil {
		return "/"
	}

	scriptPath := node.handler
	if scriptPath == "" {
		scriptPath = node.indexFile
	}
	if scriptPath == "" {
		return "/"
	}

	relDir, err := filepath.Rel(r.routesDir, filepath.Dir(scriptPath))
	if err != nil || relDir == "." {
		return "/"
	}
	// Same normalisation as buildRoutes so the result matches its map keys.
	return "/" + strings.ReplaceAll(relDir, "\\", "/")
}
// nodeForHandler searches every method trie, under the read lock, for a node
// whose handler or index file is handlerPath. Tries are visited in map
// iteration order and the first hit is returned.
func (r *LuaRouter) nodeForHandler(handlerPath string) (*node, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	for _, tree := range r.routes {
		hit := findNodeByHandler(tree, handlerPath)
		if hit == nil {
			continue
		}
		return hit, true
	}
	return nil, false
}
// findNodeByHandler does a depth-first search of the subtree at current for a
// node whose handler or index file equals handlerPath. Static children are
// searched before the parameter child. It returns nil when nothing matches or
// current is nil.
func findNodeByHandler(current *node, handlerPath string) *node {
	switch {
	case current == nil:
		return nil
	case current.handler == handlerPath, current.indexFile == handlerPath:
		return current
	}

	for _, kid := range current.staticChild {
		if hit := findNodeByHandler(kid, handlerPath); hit != nil {
			return hit
		}
	}
	// A nil parameter child is handled by the nil case above.
	return findNodeByHandler(current.paramChild, handlerPath)
}
// Refresh rebuilds the router by rescanning the routes directory.
//
// Under the write lock it replaces every method trie with an empty one,
// clears the middleware and source caches, and runs buildRoutes again. The
// route cache is reset as well: its entries are keyed by method and path, so
// after a rescan they may name a handler that the new tree would no longer
// choose for that request.
//
// It returns ErrRoutesCompilationErrors when failedRoutes is non-empty after
// the rebuild, otherwise the error from buildRoutes.
func (r *LuaRouter) Refresh() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	for method := range r.routes {
		r.routes[method] = &node{
			staticChild: make(map[string]*node),
		}
	}

	r.failedRoutes = make(map[string]*RouteError)
	r.middlewareFiles = make(map[string][]string)
	r.middlewareCache = make(map[string][]byte)
	r.sourceCache = make(map[string][]byte)
	r.sourceMtimes = make(map[string]time.Time)

	err := r.buildRoutes()

	// Forget route decisions made against the previous tree.
	r.routeCache.Reset()

	if len(r.failedRoutes) > 0 {
		return ErrRoutesCompilationErrors
	}
	return err
}
// ReportFailedRoutes returns a snapshot of the entries in failedRoutes, taken
// under the read lock. The order is unspecified; the slice is empty, not nil,
// when nothing failed.
func (r *LuaRouter) ReportFailedRoutes() []*RouteError {
	r.mu.RLock()
	defer r.mu.RUnlock()

	failed := make([]*RouteError, len(r.failedRoutes))
	next := 0
	for _, routeErr := range r.failedRoutes {
		failed[next] = routeErr
		next++
	}
	return failed
}
// ClearCache empties the route and bytecode caches and discards all cached
// file contents, compiled sources and freshness stamps. The routing tries are
// left untouched.
//
// The write lock is held while the source maps are replaced, because Refresh
// mutates the same maps while holding it. ClearCache must therefore not be
// called with r.mu already held.
func (r *LuaRouter) ClearCache() {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.routeCache.Reset()
	r.bytecodeCache.Reset()
	r.middlewareCache = make(map[string][]byte)
	r.sourceCache = make(map[string][]byte)
	r.sourceMtimes = make(map[string]time.Time)
}
// Close releases the shared Lua compile state. It is safe to call more than
// once: the state is closed and set to nil on the first call and later calls
// find nothing to do.
func (r *LuaRouter) Close() {
	r.compileStateMu.Lock()
	defer r.compileStateMu.Unlock()

	if state := r.compileState; state != nil {
		state.Close()
		r.compileState = nil
	}
}
// GetCacheStats reports entry counts, byte sizes and collision counts for the
// route cache and the bytecode cache, as collected by fastcache's UpdateStats.
func (r *LuaRouter) GetCacheStats() map[string]any {
	var routes, bytecode fastcache.Stats
	r.routeCache.UpdateStats(&routes)
	r.bytecodeCache.UpdateStats(&bytecode)

	stats := make(map[string]any, 6)
	stats["routeEntries"] = routes.EntriesCount
	stats["routeBytes"] = routes.BytesSize
	stats["routeCollisions"] = routes.Collisions
	stats["bytecodeEntries"] = bytecode.EntriesCount
	stats["bytecodeBytes"] = bytecode.BytesSize
	stats["bytecodeCollisions"] = bytecode.Collisions
	return stats
}
// GetRouteStats returns the number of endpoint nodes across all method tries
// and the summed bytecode size figure reported by countNodesAndBytecode. The
// tries are read under the read lock.
func (r *LuaRouter) GetRouteStats() (int, int64) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	var (
		endpoints int
		estimated int64
	)
	for _, tree := range r.routes {
		c, b := countNodesAndBytecode(tree)
		endpoints, estimated = endpoints+c, estimated+b
	}
	return endpoints, estimated
}
// countNodesAndBytecode walks the subtree at n and returns how many nodes are
// endpoints (have a handler or an index file) together with a bytecode size
// figure. The size is a flat 2048 bytes per endpoint, not a measurement of
// the cached bytecode. A nil node contributes nothing.
func countNodesAndBytecode(n *node) (count int, bytecodeBytes int64) {
	if n == nil {
		return 0, 0
	}

	if !(n.handler == "" && n.indexFile == "") {
		count, bytecodeBytes = 1, 2048
	}

	// accumulate folds a child subtree's totals into the running results.
	accumulate := func(sub *node) {
		c, b := countNodesAndBytecode(sub)
		count += c
		bytecodeBytes += b
	}

	for _, kid := range n.staticChild {
		accumulate(kid)
	}
	accumulate(n.paramChild) // nil-safe: a missing child adds zero

	return count, bytecodeBytes
}
// NodeWithError pairs a script path with an error.
//
// NOTE(review): nothing in the visible part of this file uses this type;
// presumably it reports per-script compilation failures to callers — confirm
// against its users.
type NodeWithError struct {
ScriptPath string // Path of the script the error refers to
Error error // The associated error
}