Moonshark/core/runner/require.go
2025-03-21 22:25:05 -05:00

363 lines
8.0 KiB
Go

package runner
import (
"errors"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Sentinel errors for module resolution. Compare with errors.Is.
var (
// ErrModuleNotFound is returned by findAndCompileModule when none of the
// candidate paths yields a loadable module.
ErrModuleNotFound = errors.New("module not found")
// ErrPathTraversal describes a module path that escapes the allowed directories.
// NOTE(review): nothing in the visible part of this file returns it; confirm its users.
ErrPathTraversal = errors.New("path traversal not allowed")
)
// ModuleEntry is one cached module: its compiled bytecode plus a timestamp.
//
// LastUsed is set when the entry is stored and bumped by RequireCache.Get.
// RefreshModule and findAndCompileModule also compare it against the file's
// modification time to decide whether the entry is stale.
type ModuleEntry struct {
Bytecode []byte // Compiled bytecode for the module
LastUsed time.Time // Time the entry was last stored or fetched via Get
}
// RequireConfig lists the directories that Lua's require is allowed to load
// modules from. findAndCompileModule searches ScriptDir first, then each
// entry of LibDirs in order; empty strings are skipped.
type RequireConfig struct {
ScriptDir string // Base directory for script being executed
LibDirs []string // Additional library directories
}
// RequireCache caches compiled Lua modules keyed by their cleaned file path.
//
// Entries live in a sync.Map; mu additionally serializes Store, Clear and
// updates to maxItems.
// NOTE(review): several code paths touch modules without holding mu, so the
// size cap looks best-effort rather than strict; confirm that is intended.
type RequireCache struct {
modules sync.Map // Maps full file paths to ModuleEntry
mu sync.Mutex // Held by SetCacheSize, Store and Clear
maxItems int // Maximum number of modules to cache
lastRefresh time.Time // When we last did a full refresh check
needsRefresh atomic.Bool // Flag for watchers to signal refresh needed
}
// NewRequireCache returns an empty cache that holds up to 100 modules and
// records the construction time as its last refresh.
func NewRequireCache() *RequireCache {
	// The zero value of sync.Map is ready to use, so modules needs no initializer.
	return &RequireCache{
		maxItems:    100, // default capacity; change with SetCacheSize
		lastRefresh: time.Now(),
	}
}
// SetCacheSize changes the maximum number of cached modules.
// Non-positive sizes are ignored. Existing entries are not evicted here.
func (c *RequireCache) SetCacheSize(size int) {
	if size <= 0 {
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	c.maxItems = size
}
// Size counts the cached entries by walking the whole map. The result is
// approximate when other goroutines modify the cache during the walk.
func (c *RequireCache) Size() int {
	var count int
	c.modules.Range(func(any, any) bool {
		count++
		return true // keep going until every entry has been visited
	})
	return count
}
// MarkNeedsRefresh sets the needsRefresh flag. While the flag is set,
// RefreshAll performs its scan and findAndCompileModule re-checks file
// modification times before serving cached bytecode. RefreshAll clears it.
func (c *RequireCache) MarkNeedsRefresh() {
c.needsRefresh.Store(true)
}
// Get returns the cached bytecode for path and reports whether it was found.
// A hit re-stores the entry with LastUsed set to the current time.
//
// NOTE(review): the staleness checks in this file compare LastUsed against the
// file's modification time, so bumping it here can hide an edit made between
// caching and this call. Confirm whether that is acceptable.
func (c *RequireCache) Get(path string) ([]byte, bool) {
	value, found := c.modules.Load(path)
	if !found {
		return nil, false
	}

	switch cached := value.(type) {
	case ModuleEntry:
		cached.LastUsed = time.Now()
		c.modules.Store(path, cached)
		return cached.Bytecode, true
	case []byte:
		// Legacy entry holding raw bytecode: upgrade it to a ModuleEntry in place.
		c.modules.Store(path, ModuleEntry{
			Bytecode: cached,
			LastUsed: time.Now(),
		})
		return cached, true
	default:
		return nil, false
	}
}
// Store caches bytecode under path, stamping the entry with the current time.
//
// When path is a new key and the cache already holds maxItems entries, the
// least recently used entry is evicted first. Replacing an existing key never
// evicts, because the cache does not grow in that case.
func (c *RequireCache) Store(path string, bytecode []byte) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Evicting on overwrite would drop an unrelated entry (or the key itself)
	// for no gain, so only make room when the key is genuinely new.
	if _, exists := c.modules.Load(path); !exists && c.Size() >= c.maxItems {
		c.evictOldest()
	}

	c.modules.Store(path, ModuleEntry{
		Bytecode: bytecode,
		LastUsed: time.Now(),
	})
}
// evictOldest removes one entry from the cache: the ModuleEntry with the
// earliest LastUsed, or any entry that is not a ModuleEntry (such entries
// carry no timestamp and are treated as older than everything else).
// It does nothing when the cache is empty. Store calls it with c.mu held.
func (c *RequireCache) evictOldest() {
	var (
		victim     any
		victimTime time.Time
		found      bool
	)

	c.modules.Range(func(key, value any) bool {
		entry, ok := value.(ModuleEntry)
		if !ok {
			// No timestamp to compare: nothing can be older, so pick it and stop.
			victim, found = key, true
			return false
		}
		if !found || entry.LastUsed.Before(victimTime) {
			victim, victimTime, found = key, entry.LastUsed, true
		}
		return true
	})

	// An explicit flag (rather than an empty-key sentinel) keeps every key,
	// including "", eligible for eviction.
	if found {
		c.modules.Delete(victim)
	}
}
// Clear removes every entry from the cache.
//
// Entries are deleted one by one instead of overwriting the sync.Map with a
// fresh value: other methods read and write modules without holding mu, and a
// sync.Map must not be copied or replaced while it is in use. Entries stored
// by other goroutines while Clear is running may survive.
func (c *RequireCache) Clear() {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.modules.Range(func(key, _ any) bool {
		c.modules.Delete(key)
		return true
	})
}
// RefreshModule drops the cached entry for path when it can no longer be
// trusted and reports whether it did so. An entry is dropped when the file
// cannot be stat'ed, when the cached value is not a ModuleEntry, or when the
// file's modification time is later than the entry's LastUsed.
// It returns false when path is not cached or the entry is still current.
//
// NOTE(review): LastUsed is also bumped by Get, so an edit made before the
// most recent Get is not detected here. Confirm whether that matters.
func (c *RequireCache) RefreshModule(path string) bool {
	cached, present := c.modules.Load(path)
	if !present {
		return false
	}

	info, statErr := os.Stat(path)
	entry, valid := cached.(ModuleEntry)

	// Short-circuit order matters: info is only dereferenced after a successful stat.
	stale := statErr != nil || !valid || info.ModTime().After(entry.LastUsed)
	if stale {
		c.modules.Delete(path)
	}
	return stale
}
// RefreshAll re-validates every cached module with RefreshModule and returns
// how many entries were dropped. It is a no-op returning 0 unless
// MarkNeedsRefresh has been called since the previous scan.
//
// The pending flag is claimed atomically before the scan starts, so a
// MarkNeedsRefresh that arrives while the scan is running stays set and
// triggers another scan on the next call instead of being lost.
func (c *RequireCache) RefreshAll() int {
	// Only the caller that flips the flag from true to false runs the scan.
	if !c.needsRefresh.CompareAndSwap(true, false) {
		return 0
	}

	// Snapshot the keys first; RefreshModule deletes entries as it goes.
	var paths []string
	c.modules.Range(func(key, _ any) bool {
		if path, ok := key.(string); ok {
			paths = append(paths, path)
		}
		return true
	})

	refreshed := 0
	for _, path := range paths {
		if c.RefreshModule(path) {
			refreshed++
		}
	}

	// lastRefresh is a plain field, so guard the write against concurrent callers.
	c.mu.Lock()
	c.lastRefresh = time.Now()
	c.mu.Unlock()

	return refreshed
}
// UpdateRequirePaths points config.ScriptDir at the directory containing
// scriptPath. An empty scriptPath leaves config untouched. The loader does not
// need to be re-registered afterwards.
func UpdateRequirePaths(config *RequireConfig, scriptPath string) {
	if scriptPath == "" {
		return
	}
	config.ScriptDir = filepath.Dir(scriptPath)
}
// findAndCompileModule resolves modName to a .lua file and returns its
// compiled bytecode, using cache where possible.
//
// Dots in modName become path separators ("a.b" -> "a/b.lua"). Candidates are
// tried in order: config.ScriptDir first, then each non-empty config.LibDirs
// entry. A candidate outside every allowed directory is skipped. The first
// candidate that is cached or exists on disk wins.
//
// It returns ErrModuleNotFound when no candidate exists, or the underlying
// error when a candidate cannot be read or compiled. Freshly compiled modules
// go through cache.Store so the cache's size limit applies to them.
func findAndCompileModule(
	state *luajit.State,
	cache *RequireCache,
	config RequireConfig,
	modName string,
) ([]byte, error) {
	relPath := strings.ReplaceAll(modName, ".", string(filepath.Separator)) + ".lua"

	candidates := make([]string, 0, len(config.LibDirs)+1)
	if config.ScriptDir != "" {
		candidates = append(candidates, filepath.Join(config.ScriptDir, relPath))
	}
	for _, libDir := range config.LibDirs {
		if libDir != "" {
			candidates = append(candidates, filepath.Join(libDir, relPath))
		}
	}

	for _, candidate := range candidates {
		cleanPath := filepath.Clean(candidate)
		if !modulePathAllowed(config, cleanPath) {
			continue // outside every allowed directory
		}

		if bytecode, ok := cachedModuleBytecode(cache, cleanPath); ok {
			return bytecode, nil
		}

		// Read directly instead of Stat-then-read: one syscall fewer and no
		// window in which the file can vanish between the two calls.
		content, err := os.ReadFile(cleanPath)
		if os.IsNotExist(err) {
			continue
		}
		if err != nil {
			return nil, err
		}

		bytecode, err := state.CompileBytecode(string(content), cleanPath)
		if err != nil {
			return nil, err
		}

		cache.Store(cleanPath, bytecode)
		return bytecode, nil
	}

	return nil, ErrModuleNotFound
}

// modulePathAllowed reports whether path lies inside config.ScriptDir or one
// of config.LibDirs.
func modulePathAllowed(config RequireConfig, path string) bool {
	if isSubPath(config.ScriptDir, path) {
		return true
	}
	for _, libDir := range config.LibDirs {
		if isSubPath(libDir, path) {
			return true
		}
	}
	return false
}

// cachedModuleBytecode returns the usable cached bytecode for path, if any.
// Stale or unrecognized entries are removed so the caller recompiles.
func cachedModuleBytecode(cache *RequireCache, path string) ([]byte, bool) {
	value, ok := cache.modules.Load(path)
	if !ok {
		return nil, false
	}

	switch cached := value.(type) {
	case []byte:
		// Legacy entry holding raw bytecode: served as-is.
		return cached, true
	case ModuleEntry:
		// The file is only re-checked when a watcher has flagged a change.
		if !cache.needsRefresh.Load() {
			return cached.Bytecode, true
		}
		info, err := os.Stat(path)
		if err == nil && !cached.LastUsed.Before(info.ModTime()) {
			return cached.Bytecode, true
		}
	}

	// Stale, unreadable, or of an unexpected type: drop it.
	cache.modules.Delete(path)
	return nil, false
}
// isSubPath reports whether path is baseDir itself or lies somewhere beneath
// it. Both arguments are cleaned first. An empty baseDir, or a pair that
// cannot be made relative to each other, yields false.
func isSubPath(baseDir, path string) bool {
	if baseDir == "" {
		return false
	}

	rel, err := filepath.Rel(filepath.Clean(baseDir), filepath.Clean(path))
	if err != nil {
		return false
	}

	// A relative path that is ".." or begins with "../" climbs out of baseDir.
	// Names that merely start with two dots (e.g. "..cache") do not.
	if rel == ".." {
		return false
	}
	escapes := strings.HasPrefix(rel, ".."+string(filepath.Separator))
	return !escapes
}