363 lines
8.0 KiB
Go
363 lines
8.0 KiB
Go
package runner
|
|
|
|
import (
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
// Sentinel errors for module resolution; match them with errors.Is.
var (
	// ErrModuleNotFound is returned by findAndCompileModule when none of the
	// candidate paths for a module name exists.
	ErrModuleNotFound = errors.New("module not found")

	// ErrPathTraversal describes a module path that would escape the allowed
	// directories. findAndCompileModule skips such candidates instead of
	// returning this error.
	ErrPathTraversal = errors.New("path traversal not allowed")
)
|
|
|
|
// ModuleEntry is one compiled module held by RequireCache.
type ModuleEntry struct {
	// Bytecode is the compiled Lua chunk.
	Bytecode []byte

	// LastUsed is set when the entry is stored and refreshed by Get. It
	// drives LRU eviction and is also compared against the source file's
	// modification time to decide whether the entry is stale.
	LastUsed time.Time
}
|
|
|
|
// RequireConfig tells the Lua require loader where modules may be found.
type RequireConfig struct {
	// ScriptDir is the directory of the script being executed; it is
	// searched first.
	ScriptDir string

	// LibDirs are additional library directories, searched in order after
	// ScriptDir. Empty entries are ignored.
	LibDirs []string
}
|
|
|
|
// RequireCache caches compiled Lua module bytecode keyed by file path.
// The underlying sync.Map is safe for concurrent use on its own; mu
// additionally serializes cache-wide operations such as Store,
// SetCacheSize and Clear.
type RequireCache struct {
	// modules maps module file paths to ModuleEntry values. Older entries
	// may hold plain []byte bytecode, which the accessors still accept.
	modules sync.Map

	// mu guards maxItems and serializes cache-wide mutations.
	mu sync.Mutex

	// maxItems is the entry count at which Store starts evicting
	// (NewRequireCache sets it to 100).
	maxItems int

	// lastRefresh records when a full refresh check last ran.
	lastRefresh time.Time

	// needsRefresh is raised by MarkNeedsRefresh (e.g. from a file
	// watcher) to request that cached entries be re-validated.
	needsRefresh atomic.Bool
}
|
|
|
|
// NewRequireCache creates a new, empty require cache
|
|
func NewRequireCache() *RequireCache {
|
|
cache := &RequireCache{
|
|
modules: sync.Map{},
|
|
maxItems: 100, // Default cache size
|
|
lastRefresh: time.Now(),
|
|
}
|
|
return cache
|
|
}
|
|
|
|
// SetCacheSize adjusts the maximum cache size
|
|
func (c *RequireCache) SetCacheSize(size int) {
|
|
if size > 0 {
|
|
c.mu.Lock()
|
|
c.maxItems = size
|
|
c.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// Size returns the approximate number of items in the cache
|
|
func (c *RequireCache) Size() int {
|
|
size := 0
|
|
c.modules.Range(func(_, _ any) bool {
|
|
size++
|
|
return true
|
|
})
|
|
return size
|
|
}
|
|
|
|
// MarkNeedsRefresh signals that modules have changed and need refresh
|
|
func (c *RequireCache) MarkNeedsRefresh() {
|
|
c.needsRefresh.Store(true)
|
|
}
|
|
|
|
// Get retrieves a module from the cache, updating its last used time
|
|
func (c *RequireCache) Get(path string) ([]byte, bool) {
|
|
value, ok := c.modules.Load(path)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
|
|
entry, ok := value.(ModuleEntry)
|
|
if !ok {
|
|
// Handle legacy entries (plain bytecode)
|
|
bytecode, ok := value.([]byte)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
|
|
// Convert to ModuleEntry and update
|
|
entry = ModuleEntry{
|
|
Bytecode: bytecode,
|
|
LastUsed: time.Now(),
|
|
}
|
|
c.modules.Store(path, entry)
|
|
return bytecode, true
|
|
}
|
|
|
|
// Update last used time
|
|
entry.LastUsed = time.Now()
|
|
c.modules.Store(path, entry)
|
|
|
|
return entry.Bytecode, true
|
|
}
|
|
|
|
// Store adds a module to the cache with LRU eviction
|
|
func (c *RequireCache) Store(path string, bytecode []byte) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
// Check if we need to evict
|
|
if c.Size() >= c.maxItems {
|
|
c.evictOldest()
|
|
}
|
|
|
|
// Store the new entry
|
|
c.modules.Store(path, ModuleEntry{
|
|
Bytecode: bytecode,
|
|
LastUsed: time.Now(),
|
|
})
|
|
}
|
|
|
|
// evictOldest removes the least recently used item from the cache
|
|
func (c *RequireCache) evictOldest() {
|
|
var oldestTime time.Time
|
|
var oldestKey string
|
|
first := true
|
|
|
|
// Find oldest entry
|
|
c.modules.Range(func(key, value any) bool {
|
|
// Handle different value types
|
|
var lastUsed time.Time
|
|
|
|
switch v := value.(type) {
|
|
case ModuleEntry:
|
|
lastUsed = v.LastUsed
|
|
default:
|
|
// For non-ModuleEntry values, treat as oldest
|
|
if first {
|
|
oldestKey = key.(string)
|
|
first = false
|
|
return true
|
|
}
|
|
return true
|
|
}
|
|
|
|
if first || lastUsed.Before(oldestTime) {
|
|
oldestTime = lastUsed
|
|
oldestKey = key.(string)
|
|
first = false
|
|
}
|
|
return true
|
|
})
|
|
|
|
// Remove oldest entry
|
|
if oldestKey != "" {
|
|
c.modules.Delete(oldestKey)
|
|
}
|
|
}
|
|
|
|
// Clear empties the entire cache
|
|
func (c *RequireCache) Clear() {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
// Create a new sync.Map to replace the existing one
|
|
c.modules = sync.Map{}
|
|
}
|
|
|
|
// RefreshModule checks if a specific module needs to be refreshed
|
|
func (c *RequireCache) RefreshModule(path string) bool {
|
|
// Get the cached module
|
|
val, ok := c.modules.Load(path)
|
|
if !ok {
|
|
// Not in cache, nothing to refresh
|
|
return false
|
|
}
|
|
|
|
// Get file info
|
|
fileInfo, err := os.Stat(path)
|
|
if err != nil {
|
|
// File no longer exists or can't be accessed, remove from cache
|
|
c.modules.Delete(path)
|
|
return true
|
|
}
|
|
|
|
// Check if the cached module is up-to-date
|
|
entry, ok := val.(ModuleEntry)
|
|
if !ok {
|
|
// Invalid entry, remove it
|
|
c.modules.Delete(path)
|
|
return true
|
|
}
|
|
|
|
// Check if the file has been modified since it was cached
|
|
if fileInfo.ModTime().After(entry.LastUsed) {
|
|
// File is newer than the cached version, remove from cache
|
|
c.modules.Delete(path)
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// RefreshAll checks all cached modules and refreshes those that have changed
|
|
func (c *RequireCache) RefreshAll() int {
|
|
refreshed := 0
|
|
|
|
// No need to refresh if flag isn't set
|
|
if !c.needsRefresh.Load() {
|
|
return 0
|
|
}
|
|
|
|
// Collect paths to check
|
|
var paths []string
|
|
c.modules.Range(func(key, _ any) bool {
|
|
if path, ok := key.(string); ok {
|
|
paths = append(paths, path)
|
|
}
|
|
return true
|
|
})
|
|
|
|
// Check each path
|
|
for _, path := range paths {
|
|
if c.RefreshModule(path) {
|
|
refreshed++
|
|
}
|
|
}
|
|
|
|
// Reset the needsRefresh flag
|
|
c.needsRefresh.Store(false)
|
|
c.lastRefresh = time.Now()
|
|
|
|
return refreshed
|
|
}
|
|
|
|
// UpdateRequirePaths updates the require paths in the config without further allocations or re-registering the loader.
|
|
func UpdateRequirePaths(config *RequireConfig, scriptPath string) {
|
|
if scriptPath != "" {
|
|
config.ScriptDir = filepath.Dir(scriptPath)
|
|
}
|
|
}
|
|
|
|
// findAndCompileModule finds a module in allowed directories and compiles it to bytecode
|
|
func findAndCompileModule(
|
|
state *luajit.State,
|
|
cache *RequireCache,
|
|
config RequireConfig,
|
|
modName string,
|
|
) ([]byte, error) {
|
|
// Convert module name to relative path
|
|
modPath := strings.ReplaceAll(modName, ".", string(filepath.Separator))
|
|
|
|
// List of paths to check
|
|
paths := []string{}
|
|
|
|
// 1. Check adjacent to script directory first
|
|
if config.ScriptDir != "" {
|
|
paths = append(paths, filepath.Join(config.ScriptDir, modPath+".lua"))
|
|
}
|
|
|
|
// 2. Check in lib directories
|
|
for _, libDir := range config.LibDirs {
|
|
if libDir != "" {
|
|
paths = append(paths, filepath.Join(libDir, modPath+".lua"))
|
|
}
|
|
}
|
|
|
|
// Try each path
|
|
for _, path := range paths {
|
|
// Clean the path to handle .. and such (security)
|
|
cleanPath := filepath.Clean(path)
|
|
|
|
// Check for path traversal (extra safety)
|
|
if !isSubPath(config.ScriptDir, cleanPath) {
|
|
isValidLib := false
|
|
for _, libDir := range config.LibDirs {
|
|
if isSubPath(libDir, cleanPath) {
|
|
isValidLib = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !isValidLib {
|
|
continue // Skip paths outside allowed directories
|
|
}
|
|
}
|
|
|
|
// Check if already in cache
|
|
if value, ok := cache.modules.Load(cleanPath); ok {
|
|
entry, ok := value.(ModuleEntry)
|
|
if !ok {
|
|
// Legacy format, use it anyway
|
|
return value.([]byte), nil
|
|
}
|
|
|
|
// Only do refresh check if marked as needed (by watcher)
|
|
if cache.needsRefresh.Load() {
|
|
fileInfo, err := os.Stat(cleanPath)
|
|
// Remove from cache if file changed or doesn't exist
|
|
if err != nil || (entry.LastUsed.Before(fileInfo.ModTime())) {
|
|
cache.modules.Delete(cleanPath)
|
|
// Continue to recompile
|
|
} else {
|
|
return entry.Bytecode, nil
|
|
}
|
|
} else {
|
|
// No refresh needed, use cached bytecode
|
|
return entry.Bytecode, nil
|
|
}
|
|
}
|
|
|
|
// Check if file exists
|
|
_, err := os.Stat(cleanPath)
|
|
if os.IsNotExist(err) {
|
|
continue
|
|
}
|
|
|
|
// Read and compile the file
|
|
content, err := os.ReadFile(cleanPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Compile to bytecode
|
|
bytecode, err := state.CompileBytecode(string(content), cleanPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Store in cache with current time
|
|
cache.modules.Store(cleanPath, ModuleEntry{
|
|
Bytecode: bytecode,
|
|
LastUsed: time.Now(),
|
|
})
|
|
|
|
return bytecode, nil
|
|
}
|
|
|
|
return nil, ErrModuleNotFound
|
|
}
|
|
|
|
// isSubPath reports whether path, after cleaning, is baseDir itself or lies
// beneath it. The check is purely lexical via filepath.Rel, so symlinks are
// not resolved. It returns false for an empty baseDir, or when no relative
// path exists (for example, one argument is absolute and the other is not).
func isSubPath(baseDir, path string) bool {
	if len(baseDir) == 0 {
		return false
	}

	rel, err := filepath.Rel(filepath.Clean(baseDir), filepath.Clean(path))
	if err != nil {
		return false
	}

	// A leading ".." element means path climbs out of baseDir. Names that
	// merely start with "..", such as "..cache", stay inside.
	escapes := rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator))
	return !escapes
}
|