502 lines
11 KiB
Go
502 lines
11 KiB
Go
package router
|
|
|
|
import (
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
"github.com/VictoriaMetrics/fastcache"
|
|
)
|
|
|
|
// node is one segment in a per-method routing trie (one node per path
// segment). Method roots are created empty; a node on which a route ends
// carries that route's compiled handler.
type node struct {
	// segment is the literal path segment, "[name]" for a dynamic segment,
	// or "*name" for a trailing wildcard.
	segment string

	// bytecode is the compiled handler for a route ending here (nil if none);
	// scriptPath is the Lua file it was compiled from.
	bytecode   []byte
	scriptPath string

	children []*node

	isDynamic, isWildcard bool

	// maxParams is the largest number of parameters any registered route has
	// accumulated by the time it reaches this node.
	maxParams uint8
}
|
|
|
|
// Router is a filesystem-based HTTP router for Lua files with bytecode caching.
//
// Handlers are discovered under routesDir, compiled to LuaJIT bytecode when
// routes are built, and stored both on per-method trie nodes and in
// bytecodeCache.
type Router struct {
	// routesDir is the directory scanned for .lua route files.
	routesDir string

	// Trie roots, one per supported HTTP method (selected by methodNode).
	get, post, put, patch, delete *node

	// bytecodeCache holds compiled handlers keyed by hashString(scriptPath).
	bytecodeCache *fastcache.Cache

	// compileState is the Lua state used to compile handlers; Close sets it
	// to nil. compileMu serializes access to it.
	compileState *luajit.State
	compileMu    sync.Mutex

	// paramsBuffer is a scratch slice for URL parameter values (allocated
	// with 64 entries by New).
	paramsBuffer []string

	middlewareFiles map[string][]string // filesystem path -> middleware file paths
}
|
|
|
|
// Params holds the URL parameters captured by a route match. Keys[i] is the
// parameter name and Values[i] is the value captured for it.
type Params struct {
	Keys, Values []string
}
|
|
|
|
// Get returns a parameter value by name
|
|
func (p *Params) Get(name string) string {
|
|
for i, key := range p.Keys {
|
|
if key == name && i < len(p.Values) {
|
|
return p.Values[i]
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// New creates a new Router instance
|
|
func New(routesDir string) (*Router, error) {
|
|
info, err := os.Stat(routesDir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !info.IsDir() {
|
|
return nil, errors.New("routes path is not a directory")
|
|
}
|
|
|
|
compileState := luajit.New()
|
|
if compileState == nil {
|
|
return nil, errors.New("failed to create Lua compile state")
|
|
}
|
|
|
|
r := &Router{
|
|
routesDir: routesDir,
|
|
get: &node{},
|
|
post: &node{},
|
|
put: &node{},
|
|
patch: &node{},
|
|
delete: &node{},
|
|
bytecodeCache: fastcache.New(32 * 1024 * 1024), // 32MB
|
|
compileState: compileState,
|
|
paramsBuffer: make([]string, 64),
|
|
middlewareFiles: make(map[string][]string),
|
|
}
|
|
|
|
return r, r.buildRoutes()
|
|
}
|
|
|
|
// methodNode returns the root node for a method
|
|
func (r *Router) methodNode(method string) *node {
|
|
switch method {
|
|
case "GET":
|
|
return r.get
|
|
case "POST":
|
|
return r.post
|
|
case "PUT":
|
|
return r.put
|
|
case "PATCH":
|
|
return r.patch
|
|
case "DELETE":
|
|
return r.delete
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// buildRoutes scans the routes directory and builds the routing tree
|
|
func (r *Router) buildRoutes() error {
|
|
r.middlewareFiles = make(map[string][]string)
|
|
|
|
// First pass: collect all middleware files
|
|
err := filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
|
|
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
|
|
return err
|
|
}
|
|
|
|
if strings.TrimSuffix(info.Name(), ".lua") == "middleware" {
|
|
relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
fsPath := "/"
|
|
if relDir != "." {
|
|
fsPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
|
|
}
|
|
|
|
r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Second pass: build routes
|
|
return filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
|
|
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
|
|
return err
|
|
}
|
|
|
|
fileName := strings.TrimSuffix(info.Name(), ".lua")
|
|
|
|
// Skip middleware files
|
|
if fileName == "middleware" {
|
|
return nil
|
|
}
|
|
|
|
// Get relative path from routes directory
|
|
relPath, err := filepath.Rel(r.routesDir, path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Get filesystem path (includes groups)
|
|
fsPath := "/" + strings.ReplaceAll(filepath.Dir(relPath), "\\", "/")
|
|
if fsPath == "/." {
|
|
fsPath = "/"
|
|
}
|
|
|
|
// Get URL path (excludes groups)
|
|
urlPath := r.parseURLPath(fsPath)
|
|
|
|
// Handle method files (get.lua, post.lua, etc.)
|
|
method := strings.ToUpper(fileName)
|
|
root := r.methodNode(method)
|
|
if root != nil {
|
|
return r.addRoute(root, urlPath, fsPath, path)
|
|
}
|
|
|
|
// Handle index files - register for all methods
|
|
if fileName == "index" {
|
|
for _, method := range []string{"GET", "POST", "PUT", "PATCH", "DELETE"} {
|
|
if root := r.methodNode(method); root != nil {
|
|
if err := r.addRoute(root, urlPath, fsPath, path); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Handle named route files - register as GET by default
|
|
namedPath := urlPath
|
|
if urlPath == "/" {
|
|
namedPath = "/" + fileName
|
|
} else {
|
|
namedPath = urlPath + "/" + fileName
|
|
}
|
|
return r.addRoute(r.get, namedPath, fsPath, path)
|
|
})
|
|
}
|
|
|
|
// parseURLPath strips group segments from filesystem path
|
|
func (r *Router) parseURLPath(fsPath string) string {
|
|
segments := strings.Split(strings.Trim(fsPath, "/"), "/")
|
|
var urlSegments []string
|
|
|
|
for _, segment := range segments {
|
|
if segment == "" {
|
|
continue
|
|
}
|
|
// Skip group segments (enclosed in parentheses)
|
|
if strings.HasPrefix(segment, "(") && strings.HasSuffix(segment, ")") {
|
|
continue
|
|
}
|
|
urlSegments = append(urlSegments, segment)
|
|
}
|
|
|
|
if len(urlSegments) == 0 {
|
|
return "/"
|
|
}
|
|
return "/" + strings.Join(urlSegments, "/")
|
|
}
|
|
|
|
// getMiddlewareChain returns middleware files that apply to the given filesystem path
|
|
func (r *Router) getMiddlewareChain(fsPath string) []string {
|
|
var chain []string
|
|
|
|
pathParts := strings.Split(strings.Trim(fsPath, "/"), "/")
|
|
if pathParts[0] == "" {
|
|
pathParts = []string{}
|
|
}
|
|
|
|
// Add root middleware
|
|
if mw, exists := r.middlewareFiles["/"]; exists {
|
|
chain = append(chain, mw...)
|
|
}
|
|
|
|
// Add middleware from each path level (including groups)
|
|
currentPath := ""
|
|
for _, part := range pathParts {
|
|
currentPath += "/" + part
|
|
if mw, exists := r.middlewareFiles[currentPath]; exists {
|
|
chain = append(chain, mw...)
|
|
}
|
|
}
|
|
|
|
return chain
|
|
}
|
|
|
|
// buildCombinedSource combines middleware and handler source
|
|
func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) {
|
|
var combined strings.Builder
|
|
|
|
// Add middleware in order
|
|
middlewareChain := r.getMiddlewareChain(fsPath)
|
|
for _, mwPath := range middlewareChain {
|
|
content, err := os.ReadFile(mwPath)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
combined.WriteString("-- Middleware: ")
|
|
combined.WriteString(mwPath)
|
|
combined.WriteString("\n")
|
|
combined.Write(content)
|
|
combined.WriteString("\n")
|
|
}
|
|
|
|
// Add main handler
|
|
content, err := os.ReadFile(scriptPath)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
combined.WriteString("-- Handler: ")
|
|
combined.WriteString(scriptPath)
|
|
combined.WriteString("\n")
|
|
combined.Write(content)
|
|
|
|
return combined.String(), nil
|
|
}
|
|
|
|
// addRoute adds a new route to the trie with bytecode compilation
|
|
func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error {
|
|
// Build combined source with middleware
|
|
combinedSource, err := r.buildCombinedSource(fsPath, scriptPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Compile bytecode
|
|
r.compileMu.Lock()
|
|
bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath)
|
|
r.compileMu.Unlock()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Cache bytecode
|
|
cacheKey := hashString(scriptPath)
|
|
r.bytecodeCache.Set(uint64ToBytes(cacheKey), bytecode)
|
|
|
|
if urlPath == "/" {
|
|
root.bytecode = bytecode
|
|
root.scriptPath = scriptPath
|
|
return nil
|
|
}
|
|
|
|
current := root
|
|
pos := 0
|
|
paramCount := uint8(0)
|
|
|
|
for {
|
|
seg, newPos, more := readSegment(urlPath, pos)
|
|
if seg == "" {
|
|
break
|
|
}
|
|
|
|
isDyn := len(seg) > 2 && seg[0] == '[' && seg[len(seg)-1] == ']'
|
|
isWC := len(seg) > 0 && seg[0] == '*'
|
|
|
|
if isWC && more {
|
|
return errors.New("wildcard must be the last segment")
|
|
}
|
|
|
|
if isDyn || isWC {
|
|
paramCount++
|
|
}
|
|
|
|
// Find or create child
|
|
var child *node
|
|
for _, c := range current.children {
|
|
if c.segment == seg {
|
|
child = c
|
|
break
|
|
}
|
|
}
|
|
|
|
if child == nil {
|
|
child = &node{
|
|
segment: seg,
|
|
isDynamic: isDyn,
|
|
isWildcard: isWC,
|
|
}
|
|
current.children = append(current.children, child)
|
|
}
|
|
|
|
if child.maxParams < paramCount {
|
|
child.maxParams = paramCount
|
|
}
|
|
|
|
current = child
|
|
pos = newPos
|
|
}
|
|
|
|
current.bytecode = bytecode
|
|
current.scriptPath = scriptPath
|
|
return nil
|
|
}
|
|
|
|
// readSegment returns the path segment beginning at byte offset start,
// skipping a single leading '/'.
//
// end is the offset of the '/' that terminates the segment, or len(path) if
// none does. hasMore reports whether such a terminating '/' exists. When
// start is at or past the end of path (after the optional skip), segment is ""
// and hasMore is false. Consecutive slashes yield an empty segment with
// hasMore true.
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
	if start < len(path) && path[start] == '/' {
		start++
	}
	if start >= len(path) {
		return "", start, false
	}

	slash := strings.IndexByte(path[start:], '/')
	if slash < 0 {
		return path[start:], len(path), false
	}
	end = start + slash
	return path[start:end], end, true
}
|
|
|
|
// Lookup finds bytecode and parameters for a method and path
|
|
func (r *Router) Lookup(method, path string) ([]byte, *Params, bool) {
|
|
root := r.methodNode(method)
|
|
if root == nil {
|
|
return nil, nil, false
|
|
}
|
|
|
|
if path == "/" {
|
|
if root.bytecode != nil {
|
|
return root.bytecode, &Params{}, true
|
|
}
|
|
return nil, nil, false
|
|
}
|
|
|
|
// Prepare params buffer
|
|
buffer := r.paramsBuffer
|
|
if cap(buffer) < int(root.maxParams) {
|
|
buffer = make([]string, root.maxParams)
|
|
r.paramsBuffer = buffer
|
|
}
|
|
buffer = buffer[:0]
|
|
|
|
var keys []string
|
|
bytecode, paramCount, found := r.match(root, path, 0, &buffer, &keys)
|
|
if !found {
|
|
return nil, nil, false
|
|
}
|
|
|
|
params := &Params{
|
|
Keys: keys[:paramCount],
|
|
Values: buffer[:paramCount],
|
|
}
|
|
|
|
return bytecode, params, true
|
|
}
|
|
|
|
// match traverses the trie to find bytecode
|
|
func (r *Router) match(current *node, path string, start int, params *[]string, keys *[]string) ([]byte, int, bool) {
|
|
paramCount := 0
|
|
|
|
// Check wildcard first
|
|
for _, c := range current.children {
|
|
if c.isWildcard {
|
|
rem := path[start:]
|
|
if len(rem) > 0 && rem[0] == '/' {
|
|
rem = rem[1:]
|
|
}
|
|
*params = append(*params, rem)
|
|
*keys = append(*keys, strings.TrimPrefix(c.segment, "*"))
|
|
return c.bytecode, 1, c.bytecode != nil
|
|
}
|
|
}
|
|
|
|
seg, pos, more := readSegment(path, start)
|
|
if seg == "" {
|
|
return current.bytecode, 0, current.bytecode != nil
|
|
}
|
|
|
|
for _, c := range current.children {
|
|
if c.segment == seg || c.isDynamic {
|
|
if c.isDynamic {
|
|
*params = append(*params, seg)
|
|
paramName := c.segment[1 : len(c.segment)-1] // Remove [ ]
|
|
*keys = append(*keys, paramName)
|
|
paramCount++
|
|
}
|
|
|
|
if !more {
|
|
return c.bytecode, paramCount, c.bytecode != nil
|
|
}
|
|
|
|
bytecode, nestedCount, ok := r.match(c, path, pos, params, keys)
|
|
if ok {
|
|
return bytecode, paramCount + nestedCount, true
|
|
}
|
|
|
|
// Backtrack on failure
|
|
if c.isDynamic {
|
|
*params = (*params)[:len(*params)-1]
|
|
*keys = (*keys)[:len(*keys)-1]
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, 0, false
|
|
}
|
|
|
|
// GetBytecode gets cached bytecode by script path
|
|
func (r *Router) GetBytecode(scriptPath string) []byte {
|
|
cacheKey := hashString(scriptPath)
|
|
return r.bytecodeCache.Get(nil, uint64ToBytes(cacheKey))
|
|
}
|
|
|
|
// Refresh rebuilds the router
|
|
func (r *Router) Refresh() error {
|
|
r.get = &node{}
|
|
r.post = &node{}
|
|
r.put = &node{}
|
|
r.patch = &node{}
|
|
r.delete = &node{}
|
|
r.middlewareFiles = make(map[string][]string)
|
|
r.bytecodeCache.Reset()
|
|
return r.buildRoutes()
|
|
}
|
|
|
|
// Close cleans up resources
|
|
func (r *Router) Close() {
|
|
r.compileMu.Lock()
|
|
if r.compileState != nil {
|
|
r.compileState.Close()
|
|
r.compileState = nil
|
|
}
|
|
r.compileMu.Unlock()
|
|
}
|
|
|
|
// hashString computes the 64-bit djb2 hash of s (seed 5381, h = h*33 + byte),
// processing s byte by byte and wrapping on overflow. It is used to derive
// bytecode cache keys from script paths.
func hashString(s string) uint64 {
	var hash uint64 = 5381
	for _, b := range []byte(s) {
		hash = hash*33 + uint64(b)
	}
	return hash
}
|
|
|
|
// uint64ToBytes encodes n as 8 bytes in little-endian order (least
// significant byte first) in a newly allocated slice.
func uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	for i := range out {
		out[i] = byte(n >> (8 * uint(i)))
	}
	return out
}
|