package runner

import (
	"errors"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)

// Common errors
var (
	ErrModuleNotFound = errors.New("module not found")
	ErrPathTraversal  = errors.New("path traversal not allowed")
)

// defaultMaxModules is the cache capacity used when no positive limit has
// been configured (including on a zero-value RequireCache).
const defaultMaxModules = 100

// ModuleEntry represents a cached module together with its timestamps.
type ModuleEntry struct {
	Bytecode []byte
	LastUsed time.Time // Last time the entry was stored or returned by Get; drives LRU eviction
	LoadedAt time.Time // When the bytecode was cached; drives staleness checks against file mtime
}

// loadTime returns the timestamp that file modification times are compared
// against. Entries created without LoadedAt (zero value) fall back to
// LastUsed, which is the only timestamp such entries carry.
func (e ModuleEntry) loadTime() time.Time {
	if e.LoadedAt.IsZero() {
		return e.LastUsed
	}
	return e.LoadedAt
}

// RequireConfig holds configuration for Lua's require function
type RequireConfig struct {
	ScriptDir string   // Base directory for script being executed
	LibDirs   []string // Additional library directories
}

// RequireCache is a cache for loaded Lua modules, keyed by full file path.
//
// Entries live in a sync.Map so lookups do not take mu; mu serializes
// eviction, clearing, and updates to maxItems and lastRefresh.
type RequireCache struct {
	modules      sync.Map // Maps full file paths to ModuleEntry
	mu           sync.Mutex
	maxItems     int         // Maximum number of modules to cache
	lastRefresh  time.Time   // When we last did a full refresh check
	needsRefresh atomic.Bool // Flag for watchers to signal refresh needed
}

// NewRequireCache creates a new, empty require cache with the default size.
func NewRequireCache() *RequireCache {
	return &RequireCache{
		maxItems:    defaultMaxModules,
		lastRefresh: time.Now(),
	}
}

// SetCacheSize adjusts the maximum cache size. Non-positive sizes are ignored.
// Existing entries are not evicted immediately; the new limit is applied the
// next time a new path is stored.
func (c *RequireCache) SetCacheSize(size int) {
	if size <= 0 {
		return
	}
	c.mu.Lock()
	c.maxItems = size
	c.mu.Unlock()
}

// Size returns the approximate number of items in the cache. It walks the
// whole map, so the cost grows with the number of entries, and the result
// may be out of date if other goroutines modify the cache concurrently.
func (c *RequireCache) Size() int {
	count := 0
	c.modules.Range(func(_, _ any) bool {
		count++
		return true
	})
	return count
}

// MarkNeedsRefresh signals that modules have changed and need refresh.
func (c *RequireCache) MarkNeedsRefresh() {
	c.needsRefresh.Store(true)
}

// Get retrieves a module's bytecode from the cache, updating its last used
// time. The boolean is false when the path is not cached or the cached value
// has an unexpected type.
//
// NOTE(review): the Load followed by Store is not atomic, so a Get racing
// with a delete of the same path can re-insert the entry.
func (c *RequireCache) Get(path string) ([]byte, bool) {
	value, ok := c.modules.Load(path)
	if !ok {
		return nil, false
	}

	now := time.Now()
	switch v := value.(type) {
	case ModuleEntry:
		// Only LastUsed moves; LoadedAt must keep reflecting when the
		// bytecode was produced so staleness checks stay meaningful.
		v.LastUsed = now
		c.modules.Store(path, v)
		return v.Bytecode, true
	case []byte:
		// Legacy entries (plain bytecode): upgrade in place. The real load
		// time is unknown, so the upgrade time is used.
		c.modules.Store(path, ModuleEntry{Bytecode: v, LastUsed: now, LoadedAt: now})
		return v, true
	default:
		return nil, false
	}
}

// Store adds a module to the cache, evicting least recently used entries
// when a new path would exceed the configured size.
func (c *RequireCache) Store(path string, bytecode []byte) {
	c.storeAt(path, bytecode, time.Now())
}

// storeAt is Store with an explicit load timestamp, so callers that read a
// file can record the time from before the read.
func (c *RequireCache) storeAt(path string, bytecode []byte, loadedAt time.Time) {
	c.mu.Lock()
	defer c.mu.Unlock()

	limit := c.maxItems
	if limit <= 0 {
		limit = defaultMaxModules
	}

	// Overwriting an existing path does not grow the cache, so only make
	// room for genuinely new paths. The loop (rather than a single eviction)
	// lets the cache shrink after SetCacheSize lowered the limit.
	if _, exists := c.modules.Load(path); !exists {
		for c.Size() >= limit {
			if !c.evictOldest() {
				break
			}
		}
	}

	c.modules.Store(path, ModuleEntry{
		Bytecode: bytecode,
		LastUsed: time.Now(),
		LoadedAt: loadedAt,
	})
}

// evictOldest removes the least recently used item from the cache and
// reports whether anything was removed. Values that are not ModuleEntry
// carry no timestamp and are always treated as the oldest.
func (c *RequireCache) evictOldest() bool {
	var (
		victim     any
		victimTime time.Time
		found      bool
	)

	c.modules.Range(func(key, value any) bool {
		entry, ok := value.(ModuleEntry)
		if !ok {
			// Nothing can be older than an untimestamped value; stop here.
			victim, found = key, true
			return false
		}
		if !found || entry.LastUsed.Before(victimTime) {
			victim, victimTime, found = key, entry.LastUsed, true
		}
		return true
	})

	if found {
		c.modules.Delete(victim)
	}
	return found
}

// Clear empties the entire cache.
func (c *RequireCache) Clear() {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Delete in place instead of assigning a fresh sync.Map: readers access
	// c.modules without holding mu, and a sync.Map must not be copied or
	// overwritten after first use.
	c.modules.Range(func(key, _ any) bool {
		c.modules.Delete(key)
		return true
	})
}

// RefreshModule checks whether the cached module at path is still valid and
// removes it from the cache if it is not. It returns true when an entry was
// removed (file missing or unreadable, entry of an unexpected type, or file
// modified after the entry was loaded) and false when the path is not cached
// or the cached entry is still current.
func (c *RequireCache) RefreshModule(path string) bool {
	val, ok := c.modules.Load(path)
	if !ok {
		// Not in cache, nothing to refresh
		return false
	}

	entry, valid := val.(ModuleEntry)
	if valid {
		info, err := os.Stat(path)
		// Compare against the load time, not LastUsed: LastUsed moves on
		// every Get and would hide edits made between load and last use.
		if err == nil && !info.ModTime().After(entry.loadTime()) {
			return false
		}
	}

	c.modules.Delete(path)
	return true
}

// RefreshAll checks all cached modules and drops those that have changed.
// It does nothing unless MarkNeedsRefresh was called since the last run, and
// returns the number of entries removed.
//
// NOTE(review): the flag is reset after the scan, so a MarkNeedsRefresh that
// arrives while the scan is running may be cleared without its change being
// examined.
func (c *RequireCache) RefreshAll() int {
	if !c.needsRefresh.Load() {
		return 0
	}

	// Snapshot the keys first so entries are not deleted while ranging.
	var paths []string
	c.modules.Range(func(key, _ any) bool {
		if path, ok := key.(string); ok {
			paths = append(paths, path)
		}
		return true
	})

	refreshed := 0
	for _, path := range paths {
		if c.RefreshModule(path) {
			refreshed++
		}
	}

	c.needsRefresh.Store(false)

	c.mu.Lock()
	c.lastRefresh = time.Now()
	c.mu.Unlock()

	return refreshed
}

// lookup returns the cached bytecode for path, if any. When a refresh has
// been requested, the file's modification time is checked first and a stale
// or inaccessible module is dropped from the cache so the caller recompiles
// it.
func (c *RequireCache) lookup(path string) ([]byte, bool) {
	value, ok := c.modules.Load(path)
	if !ok {
		return nil, false
	}

	switch v := value.(type) {
	case ModuleEntry:
		// Only do the refresh check if marked as needed (by watcher)
		if !c.needsRefresh.Load() {
			return v.Bytecode, true
		}
		info, err := os.Stat(path)
		if err != nil || info.ModTime().After(v.loadTime()) {
			c.modules.Delete(path)
			return nil, false
		}
		return v.Bytecode, true
	case []byte:
		// Legacy format, use it anyway
		return v, true
	default:
		// Unknown value type: discard it rather than panic on an assertion.
		c.modules.Delete(path)
		return nil, false
	}
}

// UpdateRequirePaths updates the require paths in the config without further
// allocations or re-registering the loader. An empty scriptPath leaves the
// config untouched.
func UpdateRequirePaths(config *RequireConfig, scriptPath string) {
	if scriptPath != "" {
		config.ScriptDir = filepath.Dir(scriptPath)
	}
}

// findAndCompileModule finds a module in the allowed directories and returns
// its bytecode, compiling and caching it when it is not already cached.
//
// Dots in modName are translated to path separators and ".lua" is appended.
// The script directory is searched first, then each library directory in
// order. It returns ErrModuleNotFound when no candidate file exists, or the
// underlying read/compile error for the first candidate that exists but
// cannot be loaded.
func findAndCompileModule(
	state *luajit.State,
	cache *RequireCache,
	config RequireConfig,
	modName string,
) ([]byte, error) {
	// Convert module name to relative path
	relPath := strings.ReplaceAll(modName, ".", string(filepath.Separator)) + ".lua"

	candidates := make([]string, 0, 1+len(config.LibDirs))
	// 1. Check adjacent to script directory first
	if config.ScriptDir != "" {
		candidates = append(candidates, filepath.Join(config.ScriptDir, relPath))
	}
	// 2. Check in lib directories
	for _, libDir := range config.LibDirs {
		if libDir != "" {
			candidates = append(candidates, filepath.Join(libDir, relPath))
		}
	}

	for _, candidate := range candidates {
		// Clean the path to handle ".." and such (security)
		cleanPath := filepath.Clean(candidate)

		// Skip paths outside allowed directories (extra safety)
		if !isAllowedPath(config, cleanPath) {
			continue
		}

		if bytecode, ok := cache.lookup(cleanPath); ok {
			return bytecode, nil
		}

		// Take the timestamp before reading so an edit that lands during
		// the read is still newer than the recorded load time.
		loadedAt := time.Now()

		content, err := os.ReadFile(cleanPath)
		if os.IsNotExist(err) {
			continue
		}
		if err != nil {
			return nil, err
		}

		bytecode, err := state.CompileBytecode(string(content), cleanPath)
		if err != nil {
			return nil, err
		}

		// Go through the evicting store so the cache size limit is honored.
		cache.storeAt(cleanPath, bytecode, loadedAt)
		return bytecode, nil
	}

	return nil, ErrModuleNotFound
}

// isAllowedPath reports whether path lies inside the script directory or one
// of the configured library directories.
func isAllowedPath(config RequireConfig, path string) bool {
	if isSubPath(config.ScriptDir, path) {
		return true
	}
	for _, libDir := range config.LibDirs {
		if isSubPath(libDir, path) {
			return true
		}
	}
	return false
}

// isSubPath checks if path is contained within base directory. An empty
// baseDir never contains anything.
func isSubPath(baseDir, path string) bool {
	if baseDir == "" {
		return false
	}

	// Clean and normalize paths
	baseDir = filepath.Clean(baseDir)
	path = filepath.Clean(path)

	rel, err := filepath.Rel(baseDir, path)
	if err != nil {
		return false
	}

	// A relative path that is ".." or starts with "../" escapes baseDir
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}