// Moonshark/router/router.go
//
// (source listing metadata: 502 lines, 11 KiB, Go)

package router
import (
"errors"
"os"
"path/filepath"
"strings"
"sync"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/VictoriaMetrics/fastcache"
)
// node is one path segment in a per-method routing trie. Routes are stored
// one node per URL segment; a node with non-nil bytecode terminates a route.
type node struct {
	segment    string  // segment text, e.g. "users", "[id]" or "*rest"
	bytecode   []byte  // compiled handler; nil when no route ends at this node
	scriptPath string  // Lua file the bytecode was compiled from
	children   []*node // child segments, kept in insertion order
	isDynamic  bool    // segment is a "[name]" parameter
	isWildcard bool    // segment is a trailing "*name" catch-all
	maxParams  uint8   // largest parameter count of any route inserted through this node
}
// Router is a filesystem-based HTTP router for Lua files with bytecode caching.
// Each routed HTTP method has its own trie root.
type Router struct {
	// routesDir is the directory scanned for .lua route and middleware files.
	routesDir string

	// Trie roots, one per routed HTTP method.
	get    *node
	post   *node
	put    *node
	patch  *node
	delete *node

	// bytecodeCache holds compiled bytecode keyed by hashString(scriptPath).
	bytecodeCache *fastcache.Cache

	// compileState is the Lua state used to compile route sources; every use
	// of it, including Close, happens with compileMu held.
	compileState *luajit.State
	compileMu    sync.Mutex

	// paramsBuffer is scratch storage for URL parameter values.
	paramsBuffer []string

	// middlewareFiles maps a slash-separated filesystem path (including group
	// segments) to the middleware.lua files found in that directory.
	middlewareFiles map[string][]string
}
// Params holds the URL parameters captured by a route match; Keys[i] is the
// name of the value in Values[i].
type Params struct {
	Keys   []string
	Values []string
}

// Get returns the value of the first parameter called name that has a
// corresponding entry in Values, or "" when there is none.
func (p *Params) Get(name string) string {
	limit := len(p.Keys)
	if n := len(p.Values); n < limit {
		limit = n
	}
	for i := 0; i < limit; i++ {
		if p.Keys[i] == name {
			return p.Values[i]
		}
	}
	return ""
}
// New creates a Router serving the Lua files under routesDir.
//
// It fails without returning a router if routesDir cannot be stat'ed, is not
// a directory, or the Lua compile state cannot be created. Otherwise it builds
// the routing tries and returns the router together with any build error, so
// a caller that discards the router on error should Close it.
func New(routesDir string) (*Router, error) {
	stat, statErr := os.Stat(routesDir)
	switch {
	case statErr != nil:
		return nil, statErr
	case !stat.IsDir():
		return nil, errors.New("routes path is not a directory")
	}

	state := luajit.New()
	if state == nil {
		return nil, errors.New("failed to create Lua compile state")
	}

	rt := &Router{
		routesDir:       routesDir,
		get:             &node{},
		post:            &node{},
		put:             &node{},
		patch:           &node{},
		delete:          &node{},
		bytecodeCache:   fastcache.New(32 * 1024 * 1024), // 32MB
		compileState:    state,
		paramsBuffer:    make([]string, 64),
		middlewareFiles: make(map[string][]string),
	}

	buildErr := rt.buildRoutes()
	return rt, buildErr
}
// methodNode returns the trie root for an HTTP method, or nil when the method
// is not routed. Matching is case-sensitive: "GET" is routed, "get" is not.
func (r *Router) methodNode(method string) *node {
	var root *node
	switch method {
	case "GET":
		root = r.get
	case "POST":
		root = r.post
	case "PUT":
		root = r.put
	case "PATCH":
		root = r.patch
	case "DELETE":
		root = r.delete
	}
	return root
}
// buildRoutes scans routesDir, records every middleware.lua by directory and
// registers every other .lua file as a route.
//
// Routes are registered in two phases: every index.lua first, then all other
// route files (method files such as get.lua, and named files), each phase in
// filepath.Walk's lexical order. A later registration for the same method and
// URL path overwrites an earlier one. The two phases therefore make a
// specific file such as get.lua or delete.lua always win over the catch-all
// index.lua in the same directory. With a single pass, "get.lua" < "index.lua"
// < "post.lua" lexically, so index.lua would shadow some method files but not
// others.
//
// All middleware is collected before any route is compiled, because each
// route is compiled together with its middleware chain.
func (r *Router) buildRoutes() error {
	r.middlewareFiles = make(map[string][]string)

	var indexFiles, routeFiles []string
	err := filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
			return nil
		}
		switch strings.TrimSuffix(info.Name(), ".lua") {
		case "middleware":
			fsPath, relErr := r.routeFSPath(path)
			if relErr != nil {
				return relErr
			}
			r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path)
		case "index":
			indexFiles = append(indexFiles, path)
		default:
			routeFiles = append(routeFiles, path)
		}
		return nil
	})
	if err != nil {
		return err
	}

	for _, files := range [][]string{indexFiles, routeFiles} {
		for _, path := range files {
			if err := r.registerRouteFile(path); err != nil {
				return err
			}
		}
	}
	return nil
}

// routeFSPath returns the slash-separated directory of file relative to
// routesDir, with group segments kept ("/" for files directly in routesDir).
func (r *Router) routeFSPath(file string) (string, error) {
	relDir, err := filepath.Rel(r.routesDir, filepath.Dir(file))
	if err != nil {
		return "", err
	}
	if relDir == "." {
		return "/", nil
	}
	return "/" + filepath.ToSlash(relDir), nil
}

// registerRouteFile registers one non-middleware route file according to its
// name:
//   - get.lua, post.lua, ... serve that method at the directory's URL;
//   - index.lua serves all routed methods at the directory's URL;
//   - any other name.lua serves GET at the directory's URL plus "/name".
func (r *Router) registerRouteFile(path string) error {
	fsPath, err := r.routeFSPath(path)
	if err != nil {
		return err
	}
	// URL path excludes "(group)" directories; fsPath keeps them so middleware
	// inside groups still applies.
	urlPath := r.parseURLPath(fsPath)
	fileName := strings.TrimSuffix(filepath.Base(path), ".lua")

	if root := r.methodNode(strings.ToUpper(fileName)); root != nil {
		return r.addRoute(root, urlPath, fsPath, path)
	}

	if fileName == "index" {
		for _, method := range []string{"GET", "POST", "PUT", "PATCH", "DELETE"} {
			root := r.methodNode(method)
			if root == nil {
				continue
			}
			if err := r.addRoute(root, urlPath, fsPath, path); err != nil {
				return err
			}
		}
		return nil
	}

	// urlPath is "/" or has no trailing slash, so this yields "/name" or
	// "/dir/name".
	namedPath := strings.TrimSuffix(urlPath, "/") + "/" + fileName
	return r.addRoute(r.get, namedPath, fsPath, path)
}
// parseURLPath converts a slash-separated filesystem path into its URL path.
// It drops empty segments and "(group)" segments (those wrapped in
// parentheses). The result is "/" when nothing remains.
func (r *Router) parseURLPath(fsPath string) string {
	isSlash := func(c rune) bool { return c == '/' }

	var url strings.Builder
	for _, part := range strings.FieldsFunc(fsPath, isSlash) {
		if strings.HasPrefix(part, "(") && strings.HasSuffix(part, ")") {
			continue // group directories organise files but do not appear in URLs
		}
		url.WriteByte('/')
		url.WriteString(part)
	}

	if url.Len() == 0 {
		return "/"
	}
	return url.String()
}
// getMiddlewareChain returns the middleware files that apply to fsPath, in
// outermost-first order. The root ("/") middleware comes first, followed by
// that of each successively deeper directory, group directories included.
// It returns nil when no middleware applies.
func (r *Router) getMiddlewareChain(fsPath string) []string {
	chain := append([]string(nil), r.middlewareFiles["/"]...)

	trimmed := strings.Trim(fsPath, "/")
	if trimmed == "" {
		return chain
	}

	prefix := ""
	for _, dir := range strings.Split(trimmed, "/") {
		prefix += "/" + dir
		// Indexing a missing key yields nil, which appends nothing.
		chain = append(chain, r.middlewareFiles[prefix]...)
	}
	return chain
}
// buildCombinedSource concatenates the middleware chain for fsPath and the
// handler at scriptPath into a single Lua chunk. Each part is preceded by a
// "-- Middleware: <path>" or "-- Handler: <path>" comment line, and each
// middleware part is followed by a newline. It returns the first file read
// error encountered.
func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) {
	var out strings.Builder

	section := func(header, file string) error {
		body, err := os.ReadFile(file)
		if err != nil {
			return err
		}
		out.WriteString(header)
		out.WriteString(file)
		out.WriteByte('\n')
		out.Write(body)
		return nil
	}

	for _, mw := range r.getMiddlewareChain(fsPath) {
		if err := section("-- Middleware: ", mw); err != nil {
			return "", err
		}
		out.WriteByte('\n')
	}

	if err := section("-- Handler: ", scriptPath); err != nil {
		return "", err
	}
	return out.String(), nil
}
// addRoute compiles scriptPath, prefixed by the middleware chain for fsPath,
// into bytecode and stores it in bytecodeCache under hashString(scriptPath).
// It then attaches the bytecode to the node for urlPath below root, creating
// nodes as needed. An existing handler at that node is overwritten.
//
// Segments of the form "[name]" become dynamic nodes and "*name" becomes a
// wildcard node, which must be the final segment.
//
// It returns an error if the router has been closed, if reading or compiling
// the sources fails, or if a wildcard is not the last segment.
func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error {
	combinedSource, err := r.buildCombinedSource(fsPath, scriptPath)
	if err != nil {
		return err
	}

	// Close sets compileState to nil under compileMu. Check it under the same
	// lock so a Refresh after Close returns an error instead of calling into a
	// nil state.
	r.compileMu.Lock()
	if r.compileState == nil {
		r.compileMu.Unlock()
		return errors.New("router is closed")
	}
	bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath)
	r.compileMu.Unlock()
	if err != nil {
		return err
	}

	cacheKey := hashString(scriptPath)
	r.bytecodeCache.Set(uint64ToBytes(cacheKey), bytecode)

	if urlPath == "/" {
		root.bytecode = bytecode
		root.scriptPath = scriptPath
		return nil
	}

	current := root
	pos := 0
	paramCount := uint8(0)
	for {
		seg, newPos, more := readSegment(urlPath, pos)
		if seg == "" {
			break
		}

		isDyn := len(seg) > 2 && seg[0] == '[' && seg[len(seg)-1] == ']'
		isWC := len(seg) > 0 && seg[0] == '*'
		if isWC && more {
			return errors.New("wildcard must be the last segment")
		}
		if isDyn || isWC {
			paramCount++
		}

		// Reuse an existing child with the exact same segment text, else add one.
		var child *node
		for _, c := range current.children {
			if c.segment == seg {
				child = c
				break
			}
		}
		if child == nil {
			child = &node{
				segment:    seg,
				isDynamic:  isDyn,
				isWildcard: isWC,
			}
			current.children = append(current.children, child)
		}

		if child.maxParams < paramCount {
			child.maxParams = paramCount
		}
		current = child
		pos = newPos
	}

	current.bytecode = bytecode
	current.scriptPath = scriptPath
	return nil
}
// readSegment returns the path segment that begins at start, skipping a
// single leading '/'. end is the index just past the segment, and hasMore
// reports whether more bytes follow it. At the end of path it returns "" with
// hasMore false. An empty segment, as in "//", is also returned as "".
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
	if start < len(path) && path[start] == '/' {
		start++
	}
	if start >= len(path) {
		return "", start, false
	}

	rest := path[start:]
	for i := 0; i < len(rest); i++ {
		if rest[i] == '/' {
			return rest[:i], start + i, true
		}
	}
	return rest, len(path), false
}
// Lookup finds the compiled bytecode and URL parameters for method and path.
//
// It returns (nil, nil, false) when the method is not routed or no handler
// matches. The path "/" only matches the method root's own handler.
//
// The returned Params are allocated for this call alone. Neither slice
// shares storage with the router or with other lookups, so a result stays
// valid after later (or concurrent) calls to Lookup.
func (r *Router) Lookup(method, path string) ([]byte, *Params, bool) {
	root := r.methodNode(method)
	if root == nil {
		return nil, nil, false
	}

	if path == "/" {
		if root.bytecode != nil {
			return root.bytecode, &Params{}, true
		}
		return nil, nil, false
	}

	// Fresh per-call slices: reusing r.paramsBuffer here made every returned
	// Params.Values alias the same backing array, so the next lookup overwrote
	// it. Routes without parameters do not allocate, because nil slices stay nil.
	var values, keys []string
	bytecode, paramCount, found := r.match(root, path, 0, &values, &keys)
	if !found {
		return nil, nil, false
	}

	return bytecode, &Params{
		Keys:   keys[:paramCount],
		Values: values[:paramCount],
	}, true
}
// match walks the trie below current, matching path from byte offset start.
//
// Candidates are tried in priority order:
//  1. a static child whose segment equals the next path segment;
//  2. dynamic "[name]" children, capturing the segment;
//  3. a "*name" wildcard child, capturing the rest of the path (one leading
//     '/' stripped).
//
// Static and dynamic branches backtrack, so a dead end in one subtree falls
// through to the next candidate. Static matches come first so that a file
// such as users/new.lua is not shadowed by a sibling users/[id]/ directory,
// and a wildcard never hides a more specific sibling, whatever the insertion
// order.
//
// Captured values and their names are appended to *params and *keys and
// removed again on backtrack. It returns the matched bytecode, the number of
// parameters captured at or below current, and whether a handler was found.
func (r *Router) match(current *node, path string, start int, params *[]string, keys *[]string) ([]byte, int, bool) {
	seg, pos, more := readSegment(path, start)

	if seg == "" {
		// End of the path (or a trailing slash): current handles it if it can;
		// otherwise a wildcard child may still match an empty remainder below.
		if current.bytecode != nil {
			return current.bytecode, 0, true
		}
	} else {
		// descend finishes the match inside child c: if seg was the last
		// segment, c must itself have a handler; otherwise recurse.
		descend := func(c *node) ([]byte, int, bool) {
			if !more {
				return c.bytecode, 0, c.bytecode != nil
			}
			return r.match(c, path, pos, params, keys)
		}

		// 1. Static children.
		for _, c := range current.children {
			if c.isDynamic || c.isWildcard || c.segment != seg {
				continue
			}
			if bytecode, n, ok := descend(c); ok {
				return bytecode, n, true
			}
		}

		// 2. Dynamic children.
		for _, c := range current.children {
			if !c.isDynamic {
				continue
			}
			*params = append(*params, seg)
			*keys = append(*keys, c.segment[1:len(c.segment)-1]) // strip [ ]
			if bytecode, n, ok := descend(c); ok {
				return bytecode, n + 1, true
			}
			// Backtrack this capture before trying the next candidate.
			*params = (*params)[:len(*params)-1]
			*keys = (*keys)[:len(*keys)-1]
		}
	}

	// 3. Wildcard child: captures everything after this node.
	for _, c := range current.children {
		if !c.isWildcard || c.bytecode == nil {
			continue
		}
		*params = append(*params, strings.TrimPrefix(path[start:], "/"))
		*keys = append(*keys, strings.TrimPrefix(c.segment, "*"))
		return c.bytecode, 1, true
	}

	return nil, 0, false
}
// GetBytecode returns the cached bytecode compiled from scriptPath. The cache
// key is hashString(scriptPath) encoded with uint64ToBytes. On a miss,
// fastcache appends nothing to the nil destination, so the result is nil.
func (r *Router) GetBytecode(scriptPath string) []byte {
	key := uint64ToBytes(hashString(scriptPath))
	return r.bytecodeCache.Get(nil, key)
}
// Refresh discards all routes, middleware mappings and cached bytecode, then
// rescans routesDir. If the rebuild fails, the router keeps only the routes
// that were registered before the error.
func (r *Router) Refresh() error {
	r.get, r.post, r.put, r.patch, r.delete = &node{}, &node{}, &node{}, &node{}, &node{}
	r.middlewareFiles = make(map[string][]string)
	r.bytecodeCache.Reset()
	return r.buildRoutes()
}
// Close releases the Lua compile state while holding compileMu. Calling it
// again after the state has been released does nothing.
func (r *Router) Close() {
	r.compileMu.Lock()
	defer r.compileMu.Unlock()

	if r.compileState == nil {
		return
	}
	r.compileState.Close()
	r.compileState = nil
}
// Helper functions from cache.go.

// hashString returns the 64-bit djb2 hash of s: seeded with 5381, then
// hash = hash*33 + byte for every byte, wrapping on overflow. It is used to
// derive bytecode cache keys.
func hashString(s string) uint64 {
	hash := uint64(5381)
	for _, c := range []byte(s) {
		hash = hash*33 + uint64(c)
	}
	return hash
}
// uint64ToBytes encodes n as 8 bytes in little-endian order (least
// significant byte first) in a newly allocated slice.
func uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	for i := range out {
		out[i] = byte(n >> (8 * uint(i)))
	}
	return out
}