Moonshark/router/router.go
2025-06-06 18:57:47 -05:00

500 lines
11 KiB
Go

package router
import (
"bytes"
"errors"
"os"
"path/filepath"
"strings"
"sync"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/VictoriaMetrics/fastcache"
)
// slash is the URL path separator as a byte slice.
var slash = []byte{'/'}

// Byte-slice spellings of the HTTP methods the router keeps a trie for.
// NOTE: the identifier delete shadows the builtin delete function everywhere
// in this package.
var (
	get    = []byte("GET")
	post   = []byte("POST")
	put    = []byte("PUT")
	patch  = []byte("PATCH")
	delete = []byte("DELETE")
)
// node is a single path segment in a per-method routing trie. Children are
// matched by whole segment (bytes.Equal), not by shared prefix, so there is no
// radix-style prefix compression.
type node struct {
// segment is the raw route segment, e.g. "users", "[id]" or "*rest"; it is
// empty on a method's root node.
segment []byte
// bytecode is the compiled handler for the route ending at this node, or nil
// when no route ends here.
bytecode []byte
// scriptPath is the Lua file that bytecode was compiled from.
scriptPath string
// children are the next-level segments, in the order they were added.
children []*node
// isDynamic marks a "[name]" segment (longer than two bytes, bracketed),
// which matches any single request segment and binds it to name.
isDynamic bool
// isWildcard marks a segment starting with '*'; it must be the last segment
// of its route and captures the remainder of the request path.
isWildcard bool
// maxParams is the count of [name]/*name segments from the root down to this
// node, as recorded by addRoute; it stays 0 on root nodes.
maxParams uint8
}
// Router is a filesystem-based HTTP router for Lua files with bytecode caching.
// New builds it from a routes directory; Refresh rebuilds it from scratch.
//
// NOTE(review): Refresh replaces the trie roots and middlewareFiles without
// taking any lock, while Lookup reads them without one; concurrent Lookup and
// Refresh calls would race — confirm callers serialize them.
type Router struct {
// routesDir is the directory scanned for .lua route and middleware files.
routesDir string
// get … delete are the trie roots, one per supported HTTP method.
get, post, put, patch, delete *node
// bytecodeCache maps a hash of each script path to its compiled bytecode
// (read back by GetBytecode).
bytecodeCache *fastcache.Cache
// compileState is the Lua state used for compiling route scripts; compileMu
// guards it. Close sets it to nil.
compileState *luajit.State
compileMu sync.Mutex
// paramsBuffer is preallocated parameter scratch space (sized in New).
paramsBuffer []string
// middlewareFiles maps a slash-separated directory relative to routesDir
// ("/" for the root) to the middleware.lua paths found in it.
middlewareFiles map[string][]string
}
// Params holds the URL parameters captured while matching a route: Keys[i] is
// a parameter name and Values[i] the request text bound to it.
type Params struct {
	Keys   []string
	Values []string
}

// Get returns the value bound to the parameter called name. If name occurs more
// than once, the first occurrence that has a value wins. It returns "" when no
// such parameter exists.
func (p *Params) Get(name string) string {
	for i := 0; i < len(p.Keys); i++ {
		if p.Keys[i] != name {
			continue
		}
		if i < len(p.Values) {
			return p.Values[i]
		}
	}
	return ""
}
// New scans routesDir and returns a Router with every route compiled to Lua
// bytecode.
//
// It fails without a Router if routesDir cannot be stat'ed, is not a directory,
// or the Lua compile state cannot be created. An error from the initial route
// build is returned together with the (possibly partially populated) Router,
// whose compile state stays open until Close is called.
func New(routesDir string) (*Router, error) {
	if info, err := os.Stat(routesDir); err != nil {
		return nil, err
	} else if !info.IsDir() {
		return nil, errors.New("routes path is not a directory")
	}

	state := luajit.New()
	if state == nil {
		return nil, errors.New("failed to create Lua compile state")
	}

	r := &Router{
		routesDir:       routesDir,
		get:             &node{},
		post:            &node{},
		put:             &node{},
		patch:           &node{},
		delete:          &node{},
		bytecodeCache:   fastcache.New(32 << 20), // 32 MiB
		compileState:    state,
		paramsBuffer:    make([]string, 64),
		middlewareFiles: map[string][]string{},
	}
	buildErr := r.buildRoutes()
	return r, buildErr
}
// methodNode maps an HTTP method name to the root of its routing trie. The
// comparison is exact and case-sensitive; unsupported methods yield nil.
func (r *Router) methodNode(method []byte) *node {
	roots := [...]struct {
		name []byte
		root *node
	}{
		{get, r.get},
		{post, r.post},
		{put, r.put},
		{patch, r.patch},
		{delete, r.delete},
	}
	for i := range roots {
		if bytes.Equal(method, roots[i].name) {
			return roots[i].root
		}
	}
	return nil
}
// buildRoutes walks routesDir once and rebuilds the routing tries from the
// .lua files it finds:
//
//   - middleware.lua files are collected into r.middlewareFiles, keyed by their
//     slash-separated directory relative to routesDir ("/" for the root);
//   - index.lua is registered for every supported method at its directory's URL;
//   - get.lua, post.lua, put.lua, patch.lua and delete.lua (any letter case) are
//     registered for that one method at their directory's URL;
//   - any other name.lua is registered for GET at <directory URL>/name.
//
// Directory URLs come from parseURLPath, which drops "(group)" segments.
// Registration starts only after the whole tree has been scanned, so every
// route sees the complete middleware set. Index files are registered before all
// other files, so a method-specific file always overrides index.lua in the same
// directory regardless of the lexical order in which filepath.Walk visits them.
// The first error from the walk or from addRoute aborts the build.
func (r *Router) buildRoutes() error {
	r.middlewareFiles = make(map[string][]string)

	// routeFile is a handler script discovered during the scan.
	type routeFile struct {
		name   string // file name without the .lua extension
		fsPath string // slash-separated directory relative to routesDir, "/" for the root
		path   string // script path as reported by filepath.Walk
	}
	var indexFiles, otherFiles []routeFile

	err := filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
		if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
			return err
		}
		relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
		if err != nil {
			return err
		}
		fsPath := "/"
		if relDir != "." {
			fsPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
		}

		f := routeFile{name: strings.TrimSuffix(info.Name(), ".lua"), fsPath: fsPath, path: path}
		switch f.name {
		case "middleware":
			r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path)
		case "index":
			indexFiles = append(indexFiles, f)
		default:
			otherFiles = append(otherFiles, f)
		}
		return nil
	})
	if err != nil {
		return err
	}

	// NOTE: addRoute compiles the script again for each method an index file
	// is registered under.
	allMethods := [][]byte{get, post, put, patch, delete}
	for _, f := range indexFiles {
		urlPath := []byte(r.parseURLPath(f.fsPath))
		for _, method := range allMethods {
			root := r.methodNode(method)
			if root == nil {
				continue
			}
			if err := r.addRoute(root, urlPath, f.fsPath, f.path); err != nil {
				return err
			}
		}
	}

	for _, f := range otherFiles {
		urlPath := r.parseURLPath(f.fsPath)

		// Method files: get.lua, POST.lua, ...
		if root := r.methodNode([]byte(strings.ToUpper(f.name))); root != nil {
			if err := r.addRoute(root, []byte(urlPath), f.fsPath, f.path); err != nil {
				return err
			}
			continue
		}

		// Named files become GET <dir>/<name>. Build the path in a fresh slice
		// instead of appending to the shared package-level slash slice.
		namedPath := make([]byte, 0, len(urlPath)+len(slash)+len(f.name))
		if urlPath != "/" {
			namedPath = append(namedPath, urlPath...)
		}
		namedPath = append(namedPath, slash...)
		namedPath = append(namedPath, f.name...)
		if err := r.addRoute(r.get, namedPath, f.fsPath, f.path); err != nil {
			return err
		}
	}
	return nil
}
// parseURLPath turns a slash-separated filesystem path into its URL path. Empty
// segments and "(group)" segments (wrapped in parentheses) are dropped, and
// "/" is returned when nothing remains.
func (r *Router) parseURLPath(fsPath string) string {
	var b strings.Builder
	for _, part := range strings.Split(strings.Trim(fsPath, "/"), "/") {
		isGroup := strings.HasPrefix(part, "(") && strings.HasSuffix(part, ")")
		if part == "" || isGroup {
			continue
		}
		b.WriteByte('/')
		b.WriteString(part)
	}
	if b.Len() == 0 {
		return "/"
	}
	return b.String()
}
// getMiddlewareChain returns middleware files that apply to the given filesystem path
func (r *Router) getMiddlewareChain(fsPath string) []string {
var chain []string
pathParts := strings.Split(strings.Trim(fsPath, "/"), "/")
if pathParts[0] == "" {
pathParts = []string{}
}
if mw, exists := r.middlewareFiles["/"]; exists {
chain = append(chain, mw...)
}
currentPath := ""
for _, part := range pathParts {
currentPath += "/" + part
if mw, exists := r.middlewareFiles[currentPath]; exists {
chain = append(chain, mw...)
}
}
return chain
}
// buildCombinedSource concatenates the middleware chain for fsPath and the
// handler at scriptPath into one Lua chunk. Each file is preceded by a
// "-- Middleware: <path>" or "-- Handler: <path>" comment line, and each
// middleware body is followed by a newline. A read error aborts the build and
// yields "" with that error.
func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) {
	var sb strings.Builder

	emit := func(label, file string, newlineAfter bool) error {
		data, err := os.ReadFile(file)
		if err != nil {
			return err
		}
		sb.WriteString(label)
		sb.WriteString(file)
		sb.WriteByte('\n')
		sb.Write(data)
		if newlineAfter {
			sb.WriteByte('\n')
		}
		return nil
	}

	for _, mw := range r.getMiddlewareChain(fsPath) {
		if err := emit("-- Middleware: ", mw, true); err != nil {
			return "", err
		}
	}
	if err := emit("-- Handler: ", scriptPath, false); err != nil {
		return "", err
	}
	return sb.String(), nil
}
// addRoute compiles scriptPath, prefixed by the middleware chain that applies to
// fsPath, to Lua bytecode. It stores the bytecode in the bytecode cache keyed by
// a hash of scriptPath and attaches it to the trie node for urlPath under root.
//
// Segments of the form "[name]" become dynamic nodes and segments starting with
// '*' become wildcard nodes; a wildcard must be the final segment. Registering
// the same urlPath again overwrites the earlier handler. It returns an error if
// reading or compiling the sources fails, if the Router has been closed, or if
// a wildcard is not the last segment.
func (r *Router) addRoute(root *node, urlPath []byte, fsPath, scriptPath string) error {
	combinedSource, err := r.buildCombinedSource(fsPath, scriptPath)
	if err != nil {
		return err
	}

	// Compile under compileMu. The deferred unlock also releases the lock if
	// compilation panics. A nil state (after Close) is reported as an error
	// instead of being dereferenced.
	bytecode, err := func() ([]byte, error) {
		r.compileMu.Lock()
		defer r.compileMu.Unlock()
		if r.compileState == nil {
			return nil, errors.New("router is closed")
		}
		return r.compileState.CompileBytecode(combinedSource, scriptPath)
	}()
	if err != nil {
		return err
	}

	r.bytecodeCache.Set(uint64ToBytes(hashString(scriptPath)), bytecode)

	// "/" is served directly by the method's root node.
	if len(urlPath) == 1 && urlPath[0] == '/' {
		root.bytecode = bytecode
		root.scriptPath = scriptPath
		return nil
	}

	current := root
	pos := 0
	paramCount := uint8(0)
	for {
		seg, next, more := readSegmentBytes(urlPath, pos)
		if len(seg) == 0 {
			break
		}

		isDyn := len(seg) > 2 && seg[0] == '[' && seg[len(seg)-1] == ']'
		isWC := seg[0] == '*'
		if isWC && more {
			return errors.New("wildcard must be the last segment")
		}
		if isDyn || isWC {
			paramCount++
		}

		var child *node
		for _, c := range current.children {
			if bytes.Equal(c.segment, seg) {
				child = c
				break
			}
		}
		if child == nil {
			// Copy the segment: urlPath belongs to the caller.
			child = &node{
				segment:    append([]byte(nil), seg...),
				isDynamic:  isDyn,
				isWildcard: isWC,
			}
			current.children = append(current.children, child)
		}
		if child.maxParams < paramCount {
			child.maxParams = paramCount
		}

		current = child
		pos = next
	}

	current.bytecode = bytecode
	current.scriptPath = scriptPath
	return nil
}
// readSegmentBytes returns the path segment that begins at start, after
// skipping at most one leading '/'. It also returns the index just past the
// segment and whether any bytes follow it. At end of input it returns a nil
// segment and the (possibly advanced) start index. The segment is a sub-slice
// of path, not a copy.
func readSegmentBytes(path []byte, start int) (segment []byte, end int, hasMore bool) {
	if start < len(path) && path[start] == '/' {
		start++
	}
	if start >= len(path) {
		return nil, start, false
	}

	rest := path[start:]
	for i, c := range rest {
		if c == '/' {
			return rest[:i], start + i, true
		}
	}
	return rest, len(path), false
}
// Lookup finds the handler bytecode registered for method and path.
//
// On success it returns the bytecode, the parameters captured from "[name]" and
// "*name" route segments, and true. It returns nil, nil, false when the method
// has no routing trie or nothing matches. The path "/" is answered only by the
// method's root handler.
//
// Each call builds its own parameter slices. The returned Params therefore never
// aliases memory shared with other calls and stays valid after later or
// concurrent lookups. Lookup itself writes no Router state.
func (r *Router) Lookup(method, path []byte) ([]byte, *Params, bool) {
	root := r.methodNode(method)
	if root == nil {
		return nil, nil, false
	}

	if len(path) == 1 && path[0] == '/' {
		if root.bytecode == nil {
			return nil, nil, false
		}
		return root.bytecode, &Params{}, true
	}

	// Start from nil slices: routes without parameters allocate nothing here.
	var values, keys []string
	bytecode, paramCount, found := r.match(root, path, 0, &values, &keys)
	if !found {
		return nil, nil, false
	}
	return bytecode, &Params{Keys: keys[:paramCount], Values: values[:paramCount]}, true
}
// match resolves the rest of path, starting at byte offset start, below
// current. It returns the handler bytecode, the number of parameters appended
// to *params and *keys along the successful branch, and whether a handler was
// found.
//
// At each level the candidates are tried from most to least specific:
//  1. a static child whose segment equals the request segment exactly;
//  2. dynamic "[name]" children, binding the segment text to name;
//  3. a "*name" wildcard child, binding the rest of the path (without its
//     leading '/') to name.
//
// A branch that fails is backtracked and leaves *params and *keys unchanged.
// The result therefore does not depend on route registration order
// (filepath.Walk visits "[id]" before lowercase names such as "new").
//
// An empty segment (end of path, trailing '/', or "//") resolves to current's
// own handler if it has one, otherwise to a wildcard child.
func (r *Router) match(current *node, path []byte, start int, params *[]string, keys *[]string) ([]byte, int, bool) {
	seg, pos, more := readSegmentBytes(path, start)
	if len(seg) == 0 {
		if current.bytecode != nil {
			return current.bytecode, 0, true
		}
		return matchWildcard(current, path, start, params, keys)
	}

	// 1. Exact static match.
	for _, c := range current.children {
		if c.isDynamic || c.isWildcard || !bytes.Equal(c.segment, seg) {
			continue
		}
		if !more {
			if c.bytecode != nil {
				return c.bytecode, 0, true
			}
		} else if bytecode, n, ok := r.match(c, path, pos, params, keys); ok {
			return bytecode, n, true
		}
	}

	// 2. Dynamic [name] segments.
	for _, c := range current.children {
		if !c.isDynamic {
			continue
		}
		*params = append(*params, string(seg))
		*keys = append(*keys, string(c.segment[1:len(c.segment)-1])) // strip [ ]
		if !more {
			if c.bytecode != nil {
				return c.bytecode, 1, true
			}
		} else if bytecode, n, ok := r.match(c, path, pos, params, keys); ok {
			return bytecode, n + 1, true
		}
		// Backtrack: this binding did not lead to a handler.
		*params = (*params)[:len(*params)-1]
		*keys = (*keys)[:len(*keys)-1]
	}

	// 3. Catch-all *name.
	return matchWildcard(current, path, start, params, keys)
}

// matchWildcard binds the remainder of path (from start, minus one leading '/')
// to the first wildcard child of current that has a handler. It appends nothing
// when no such child exists.
func matchWildcard(current *node, path []byte, start int, params, keys *[]string) ([]byte, int, bool) {
	for _, c := range current.children {
		if !c.isWildcard || c.bytecode == nil {
			continue
		}
		rest := path[start:]
		if len(rest) > 0 && rest[0] == '/' {
			rest = rest[1:]
		}
		*params = append(*params, string(rest))
		*keys = append(*keys, string(c.segment[1:])) // strip *
		return c.bytecode, 1, true
	}
	return nil, 0, false
}
// GetBytecode returns the bytecode cached for scriptPath by the most recent
// build. The cache key is only a 64-bit hash of scriptPath, so two paths with
// the same hash share one entry. NOTE(review): presumably an empty result means
// the entry is absent (never compiled or evicted from the fixed-size cache) —
// confirm against fastcache.Cache.Get.
func (r *Router) GetBytecode(scriptPath string) []byte {
	key := uint64ToBytes(hashString(scriptPath))
	return r.bytecodeCache.Get(nil, key)
}
// Refresh discards every route, middleware entry and cached bytecode, then
// rebuilds them from routesDir. The old tries are dropped before the rebuild
// starts, so if the rebuild fails the Router keeps only the routes added before
// the error. NOTE(review): nothing here is locked against concurrent Lookup
// calls.
func (r *Router) Refresh() error {
	r.bytecodeCache.Reset()
	r.middlewareFiles = make(map[string][]string)
	r.get, r.post, r.put, r.patch, r.delete = &node{}, &node{}, &node{}, &node{}, &node{}
	return r.buildRoutes()
}
// Close releases the Lua compile state under compileMu and clears it.
// Calling Close again is a no-op.
func (r *Router) Close() {
	r.compileMu.Lock()
	defer r.compileMu.Unlock()

	if r.compileState == nil {
		return
	}
	r.compileState.Close()
	r.compileState = nil
}
// hashString computes the djb2 hash of the bytes of s (h = h*33 + c, seeded
// with 5381, wrapping modulo 2^64). It is used to key the bytecode cache.
func hashString(s string) uint64 {
	var h uint64 = 5381
	for _, c := range []byte(s) {
		h = h*33 + uint64(c)
	}
	return h
}
// uint64ToBytes encodes n as 8 bytes in little-endian order (least significant
// byte first) in a newly allocated slice.
func uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	for i := range out {
		out[i] = byte(n >> (8 * uint(i)))
	}
	return out
}