Moonshark/core/runner/require.go

255 lines
5.6 KiB
Go

package runner
import (
"errors"
"os"
"path/filepath"
"strings"
"sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// ErrModuleNotFound is returned by the module resolver when none of the
// candidate paths for a module name exist.
var ErrModuleNotFound = errors.New("module not found")

// ErrPathTraversal describes a module path that escapes the allowed
// directories.
// NOTE(review): not returned by findAndCompileModule, which silently skips
// such candidates; confirm whether other callers rely on it.
var ErrPathTraversal = errors.New("path traversal not allowed")
// ModuleEntry is one cached module: its compiled bytecode together with the
// last time the cache stored or returned it. RequireCache uses LastUsed to
// choose which entry to evict when it is full.
type ModuleEntry struct {
	Bytecode []byte    // compiled Lua bytecode for the module file
	LastUsed time.Time // refreshed on every Get/Store of this entry
}
// RequireConfig tells the Lua require loader where it may look for modules.
// The script's own directory is searched first, then LibDirs in order.
type RequireConfig struct {
	ScriptDir string   // directory of the script currently executing; may be empty
	LibDirs   []string // extra library roots; empty strings are ignored by the resolver
}
// RequireCache holds compiled Lua module bytecode keyed by the module's
// cleaned file path, bounded by maxItems with least-recently-used eviction.
// It is intended for concurrent use: entries live in a sync.Map, and mu
// serializes updates to maxItems and the evict-then-insert sequence.
type RequireCache struct {
	modules  sync.Map   // file path (string) -> ModuleEntry
	mu       sync.Mutex // guards maxItems and multi-step cache mutations
	maxItems int        // capacity; eviction kicks in once the cache reaches it
}
// NewRequireCache returns an empty cache that holds at most 100 modules.
// Call SetCacheSize to change the limit.
func NewRequireCache() *RequireCache {
	c := new(RequireCache) // the zero sync.Map and sync.Mutex are ready to use
	c.maxItems = 100       // default capacity; tune to the expected number of modules
	return c
}
// SetCacheSize sets the maximum number of modules the cache may hold.
// Non-positive sizes are ignored and the current limit is kept. Entries that
// are already cached are not evicted here.
func (c *RequireCache) SetCacheSize(size int) {
	if size <= 0 {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.maxItems = size
}
// Size counts the cached modules by walking the map. The result is only
// approximate if other goroutines modify the cache during the walk.
func (c *RequireCache) Size() int {
	var count int
	c.modules.Range(func(interface{}, interface{}) bool {
		count++
		return true // keep walking
	})
	return count
}
// Get returns the cached bytecode for path and reports whether it was found.
// On a hit it refreshes the entry's LastUsed time so that LRU eviction
// treats it as recently used. Legacy entries stored as plain []byte are
// upgraded to a ModuleEntry. Values of any other type count as a miss.
//
// The load and the timestamp write-back run under c.mu, the same lock that
// Store holds while evicting. Without it, an entry evicted between the load
// and the write-back would be re-inserted, growing the cache past its limit,
// and a concurrent Store of fresher bytecode could be overwritten by the
// stale copy loaded here.
func (c *RequireCache) Get(path string) ([]byte, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	value, ok := c.modules.Load(path)
	if !ok {
		return nil, false
	}

	var bytecode []byte
	switch v := value.(type) {
	case ModuleEntry:
		bytecode = v.Bytecode
	case []byte: // legacy entry without a timestamp
		bytecode = v
	default:
		return nil, false
	}

	c.modules.Store(path, ModuleEntry{
		Bytecode: bytecode,
		LastUsed: time.Now(),
	})
	return bytecode, true
}
// Store caches bytecode under path and stamps it with the current time.
//
// Overwriting a path that is already cached does not grow the cache, so no
// entry is evicted in that case. When path is new, least-recently-used
// entries are evicted until there is room. Evicting in a loop, rather than
// once, also enforces a limit lowered through SetCacheSize.
func (c *RequireCache) Store(path string, bytecode []byte) {
	c.mu.Lock()
	defer c.mu.Unlock()

	entry := ModuleEntry{
		Bytecode: bytecode,
		LastUsed: time.Now(),
	}

	if _, exists := c.modules.Load(path); exists {
		c.modules.Store(path, entry)
		return
	}

	for size := c.Size(); size >= c.maxItems; {
		c.evictOldest()
		next := c.Size()
		if next >= size {
			break // nothing could be evicted (e.g. empty cache); avoid spinning
		}
		size = next
	}

	c.modules.Store(path, entry)
}
// evictOldest deletes the least recently used entry from the cache. It does
// nothing when the cache is empty.
//
// Entries that are not ModuleEntry values (legacy raw bytecode) have no
// timestamp, so they are treated as the oldest possible entry. The first one
// found is evicted immediately, regardless of where it falls in the map's
// iteration order.
//
// The caller is expected to hold c.mu (Store does). NOTE(review): nothing
// here enforces that; confirm for any new callers.
func (c *RequireCache) evictOldest() {
	var (
		victim     interface{}
		oldestTime time.Time
		found      bool // tracked separately so any key, even "", can be evicted
	)

	c.modules.Range(func(key, value interface{}) bool {
		entry, ok := value.(ModuleEntry)
		if !ok {
			victim, found = key, true
			return false // a legacy entry always loses; stop scanning
		}
		if !found || entry.LastUsed.Before(oldestTime) {
			victim, oldestTime, found = key, entry.LastUsed, true
		}
		return true
	})

	if found {
		c.modules.Delete(victim)
	}
}
// UpdateRequirePaths sets config.ScriptDir, in place, to the directory that
// contains scriptPath. An empty scriptPath leaves config unchanged. LibDirs
// is never modified.
func UpdateRequirePaths(config *RequireConfig, scriptPath string) {
	if scriptPath == "" {
		return
	}
	config.ScriptDir = filepath.Dir(scriptPath)
}
// findAndCompileModule resolves a dotted Lua module name (for example
// "lib.util") to a .lua file and returns its compiled bytecode.
//
// Candidates are tried in order: config.ScriptDir first, then each non-empty
// entry of config.LibDirs. A candidate is used only if its cleaned path lies
// inside ScriptDir or one of the LibDirs. A cached candidate is returned
// without touching the filesystem. Otherwise the first existing regular file
// is read, compiled with state, and stored in cache under its cleaned path.
//
// It returns ErrModuleNotFound when modName is empty or no candidate file
// exists. Stat, read, and compile errors for a candidate are returned as-is.
func findAndCompileModule(
	state *luajit.State,
	cache *RequireCache,
	config RequireConfig,
	modName string,
) ([]byte, error) {
	if modName == "" {
		return nil, ErrModuleNotFound
	}

	// "a.b.c" -> "a/b/c.lua". Every '.' becomes a separator, so a ".."
	// component can never survive into the relative path.
	fileName := strings.ReplaceAll(modName, ".", string(filepath.Separator)) + ".lua"

	candidates := make([]string, 0, 1+len(config.LibDirs))
	if config.ScriptDir != "" {
		candidates = append(candidates, filepath.Join(config.ScriptDir, fileName))
	}
	for _, libDir := range config.LibDirs {
		if libDir != "" {
			candidates = append(candidates, filepath.Join(libDir, fileName))
		}
	}

	for _, candidate := range candidates {
		cleanPath := filepath.Clean(candidate)

		// Defense in depth: never load anything outside the allowed roots.
		allowed := isSubPath(config.ScriptDir, cleanPath)
		for i := 0; !allowed && i < len(config.LibDirs); i++ {
			allowed = isSubPath(config.LibDirs[i], cleanPath)
		}
		if !allowed {
			continue
		}

		// Get also refreshes the entry's LRU timestamp.
		if bytecode, ok := cache.Get(cleanPath); ok {
			return bytecode, nil
		}

		info, err := os.Stat(cleanPath)
		if err != nil {
			if os.IsNotExist(err) {
				continue
			}
			return nil, err
		}
		if info.IsDir() {
			// A directory that happens to be named "<mod>.lua" is not a module;
			// keep searching instead of failing in ReadFile.
			continue
		}

		content, err := os.ReadFile(cleanPath)
		if err != nil {
			return nil, err
		}

		bytecode, err := state.CompileBytecode(string(content), cleanPath)
		if err != nil {
			return nil, err
		}

		// Store may evict least-recently-used entries to make room.
		cache.Store(cleanPath, bytecode)
		return bytecode, nil
	}

	return nil, ErrModuleNotFound
}
// isSubPath reports whether path, after cleaning, is baseDir itself or lies
// beneath it. An empty baseDir never contains anything, and paths that
// filepath.Rel cannot relate to baseDir are treated as outside it.
func isSubPath(baseDir, path string) bool {
	if len(baseDir) == 0 {
		return false
	}

	rel, err := filepath.Rel(filepath.Clean(baseDir), filepath.Clean(path))
	switch {
	case err != nil:
		return false
	case rel == "..":
		return false
	case strings.HasPrefix(rel, ".."+string(filepath.Separator)):
		return false // climbs above baseDir
	default:
		return true
	}
}