255 lines
5.6 KiB
Go
255 lines
5.6 KiB
Go
package runner
|
|
|
|
import (
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
// Sentinel errors reported by the Lua module loader. Compare against them
// with errors.Is.
var (
	// ErrModuleNotFound is returned by findAndCompileModule when none of the
	// candidate locations contains a file for the requested module.
	ErrModuleNotFound = errors.New("module not found")

	// ErrPathTraversal describes a module path that escapes the allowed
	// directories. NOTE(review): not returned by the loader code shown here —
	// confirm where (or whether) it is used.
	ErrPathTraversal = errors.New("path traversal not allowed")
)
|
|
|
|
// ModuleEntry is one cached module: its compiled bytecode plus the time it was
// last stored or fetched. RequireCache evicts the entry with the oldest
// LastUsed when it runs out of room.
type ModuleEntry struct {
	Bytecode []byte    // compiled Lua bytecode
	LastUsed time.Time // time of the most recent Store or Get for this entry
}
|
|
|
|
// RequireConfig tells the Lua require implementation where modules may be
// loaded from.
type RequireConfig struct {
	// ScriptDir is the directory of the script currently being executed.
	// Modules next to it are searched first; an empty value skips it.
	ScriptDir string

	// LibDirs are extra library directories, searched in order after
	// ScriptDir. Empty entries are ignored.
	LibDirs []string
}
|
|
|
|
// RequireCache is a thread-safe, size-bounded cache of compiled Lua modules.
// When it is full, Store evicts the least recently used entry.
type RequireCache struct {
	// modules maps a module's cleaned file path to its ModuleEntry (older
	// code may have stored raw []byte values, which Get upgrades on access).
	modules sync.Map

	// mu serializes Store and the evictions it performs, and guards maxItems.
	mu sync.Mutex

	// maxItems is the capacity enforced by Store; see SetCacheSize.
	maxItems int
}
|
|
|
|
// NewRequireCache creates a new, empty require cache
|
|
func NewRequireCache() *RequireCache {
|
|
return &RequireCache{
|
|
modules: sync.Map{},
|
|
maxItems: 100, // Default cache size - can be adjusted based on expected module load
|
|
}
|
|
}
|
|
|
|
// SetCacheSize adjusts the maximum cache size
|
|
func (c *RequireCache) SetCacheSize(size int) {
|
|
if size > 0 {
|
|
c.mu.Lock()
|
|
c.maxItems = size
|
|
c.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// Size returns the approximate number of items in the cache
|
|
func (c *RequireCache) Size() int {
|
|
size := 0
|
|
c.modules.Range(func(_, _ interface{}) bool {
|
|
size++
|
|
return true
|
|
})
|
|
return size
|
|
}
|
|
|
|
// Get retrieves a module from the cache, updating its last used time
|
|
func (c *RequireCache) Get(path string) ([]byte, bool) {
|
|
value, ok := c.modules.Load(path)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
|
|
entry, ok := value.(ModuleEntry)
|
|
if !ok {
|
|
// Handle legacy entries (plain bytecode)
|
|
bytecode, ok := value.([]byte)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
|
|
// Convert to ModuleEntry and update
|
|
entry = ModuleEntry{
|
|
Bytecode: bytecode,
|
|
LastUsed: time.Now(),
|
|
}
|
|
c.modules.Store(path, entry)
|
|
return bytecode, true
|
|
}
|
|
|
|
// Update last used time
|
|
entry.LastUsed = time.Now()
|
|
c.modules.Store(path, entry)
|
|
|
|
return entry.Bytecode, true
|
|
}
|
|
|
|
// Store adds a module to the cache with LRU eviction
|
|
func (c *RequireCache) Store(path string, bytecode []byte) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
// Check if we need to evict
|
|
if c.Size() >= c.maxItems {
|
|
c.evictOldest()
|
|
}
|
|
|
|
// Store the new entry
|
|
c.modules.Store(path, ModuleEntry{
|
|
Bytecode: bytecode,
|
|
LastUsed: time.Now(),
|
|
})
|
|
}
|
|
|
|
// evictOldest removes the least recently used item from the cache
|
|
func (c *RequireCache) evictOldest() {
|
|
var oldestTime time.Time
|
|
var oldestKey string
|
|
first := true
|
|
|
|
// Find oldest entry
|
|
c.modules.Range(func(key, value interface{}) bool {
|
|
// Handle different value types
|
|
var lastUsed time.Time
|
|
|
|
switch v := value.(type) {
|
|
case ModuleEntry:
|
|
lastUsed = v.LastUsed
|
|
default:
|
|
// For non-ModuleEntry values, treat as oldest
|
|
if first {
|
|
oldestKey = key.(string)
|
|
first = false
|
|
return true
|
|
}
|
|
return true
|
|
}
|
|
|
|
if first || lastUsed.Before(oldestTime) {
|
|
oldestTime = lastUsed
|
|
oldestKey = key.(string)
|
|
first = false
|
|
}
|
|
return true
|
|
})
|
|
|
|
// Remove oldest entry
|
|
if oldestKey != "" {
|
|
c.modules.Delete(oldestKey)
|
|
}
|
|
}
|
|
|
|
// UpdateRequirePaths updates the require paths in the config without further allocations or re-registering the loader.
|
|
func UpdateRequirePaths(config *RequireConfig, scriptPath string) {
|
|
if scriptPath != "" {
|
|
config.ScriptDir = filepath.Dir(scriptPath)
|
|
}
|
|
}
|
|
|
|
// findAndCompileModule finds a module in allowed directories and compiles it to bytecode
|
|
func findAndCompileModule(
|
|
state *luajit.State,
|
|
cache *RequireCache,
|
|
config RequireConfig,
|
|
modName string,
|
|
) ([]byte, error) {
|
|
// Convert module name to relative path
|
|
modPath := strings.ReplaceAll(modName, ".", string(filepath.Separator))
|
|
|
|
// List of paths to check
|
|
paths := []string{}
|
|
|
|
// 1. Check adjacent to script directory first
|
|
if config.ScriptDir != "" {
|
|
paths = append(paths, filepath.Join(config.ScriptDir, modPath+".lua"))
|
|
}
|
|
|
|
// 2. Check in lib directories
|
|
for _, libDir := range config.LibDirs {
|
|
if libDir != "" {
|
|
paths = append(paths, filepath.Join(libDir, modPath+".lua"))
|
|
}
|
|
}
|
|
|
|
// Try each path
|
|
for _, path := range paths {
|
|
// Clean the path to handle .. and such (security)
|
|
cleanPath := filepath.Clean(path)
|
|
|
|
// Check for path traversal (extra safety)
|
|
if !isSubPath(config.ScriptDir, cleanPath) {
|
|
isValidLib := false
|
|
for _, libDir := range config.LibDirs {
|
|
if isSubPath(libDir, cleanPath) {
|
|
isValidLib = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !isValidLib {
|
|
continue // Skip paths outside allowed directories
|
|
}
|
|
}
|
|
|
|
// Check if already in cache - using our Get method to update LRU info
|
|
if bytecode, ok := cache.Get(cleanPath); ok {
|
|
return bytecode, nil
|
|
}
|
|
|
|
// Check if file exists
|
|
_, err := os.Stat(cleanPath)
|
|
if os.IsNotExist(err) {
|
|
continue
|
|
}
|
|
|
|
// Read and compile the file
|
|
content, err := os.ReadFile(cleanPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Compile to bytecode
|
|
bytecode, err := state.CompileBytecode(string(content), cleanPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Store in cache - using our Store method with LRU eviction
|
|
cache.Store(cleanPath, bytecode)
|
|
|
|
return bytecode, nil
|
|
}
|
|
|
|
return nil, ErrModuleNotFound
|
|
}
|
|
|
|
// isSubPath reports whether path lies inside baseDir, or is baseDir itself.
//
// The check is purely lexical (filepath.Clean + filepath.Rel); symlinks are
// not resolved. An empty baseDir never contains anything. The result is also
// false when a relative path cannot be computed, e.g. when one argument is
// absolute and the other relative.
func isSubPath(baseDir, path string) bool {
	if baseDir == "" {
		return false
	}

	rel, err := filepath.Rel(filepath.Clean(baseDir), filepath.Clean(path))
	if err != nil {
		return false
	}

	// Only a leading ".." component means path escapes baseDir; a name such as
	// "..foo" is an ordinary child entry.
	escapes := rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator))
	return !escapes
}
|