optimized module loader

This commit is contained in:
Sky Johnson 2025-03-22 16:39:13 -05:00
parent 522a5770ed
commit 7bc5194b10
4 changed files with 565 additions and 622 deletions

View File

@ -3,9 +3,7 @@ package runner
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -21,6 +19,9 @@ var (
// StateInitFunc is a function that initializes a Lua state // StateInitFunc is a function that initializes a Lua state
type StateInitFunc func(*luajit.State) error type StateInitFunc func(*luajit.State) error
// RunnerOption defines a functional option for configuring the LuaRunner
type RunnerOption func(*LuaRunner)
// LuaRunner runs Lua scripts using a single Lua state // LuaRunner runs Lua scripts using a single Lua state
type LuaRunner struct { type LuaRunner struct {
state *luajit.State // The Lua state state *luajit.State // The Lua state
@ -30,21 +31,44 @@ type LuaRunner struct {
wg sync.WaitGroup // WaitGroup for clean shutdown wg sync.WaitGroup // WaitGroup for clean shutdown
initFunc StateInitFunc // Optional function to initialize Lua state initFunc StateInitFunc // Optional function to initialize Lua state
bufferSize int // Size of the job queue buffer bufferSize int // Size of the job queue buffer
requireCache *RequireCache // Cache for required modules moduleLoader *NativeModuleLoader // Native module loader for require
requireCfg *RequireConfig // Configuration for require paths
moduleLoader luajit.GoFunction // Keep reference to prevent GC
sandbox *Sandbox // The sandbox environment sandbox *Sandbox // The sandbox environment
} }
// WithBufferSize sets the job queue buffer size
func WithBufferSize(size int) RunnerOption {
return func(r *LuaRunner) {
if size > 0 {
r.bufferSize = size
}
}
}
// WithInitFunc sets the init function for the Lua state
func WithInitFunc(initFunc StateInitFunc) RunnerOption {
return func(r *LuaRunner) {
r.initFunc = initFunc
}
}
// WithLibDirs sets additional library directories
func WithLibDirs(dirs ...string) RunnerOption {
return func(r *LuaRunner) {
if r.moduleLoader == nil || r.moduleLoader.config == nil {
r.moduleLoader = NewNativeModuleLoader(&RequireConfig{
LibDirs: dirs,
})
} else {
r.moduleLoader.config.LibDirs = dirs
}
}
}
// NewRunner creates a new LuaRunner // NewRunner creates a new LuaRunner
func NewRunner(options ...RunnerOption) (*LuaRunner, error) { func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
// Default configuration // Default configuration
runner := &LuaRunner{ runner := &LuaRunner{
bufferSize: 10, // Default buffer size bufferSize: 10, // Default buffer size
requireCache: NewRequireCache(),
requireCfg: &RequireConfig{
LibDirs: []string{},
},
sandbox: NewSandbox(), sandbox: NewSandbox(),
} }
@ -64,55 +88,25 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
runner.jobQueue = make(chan job, runner.bufferSize) runner.jobQueue = make(chan job, runner.bufferSize)
runner.isRunning.Store(true) runner.isRunning.Store(true)
// Create a shared config pointer that will be updated per request // Set up module loader if not already initialized
runner.requireCfg = &RequireConfig{ if runner.moduleLoader == nil {
ScriptDir: runner.scriptDir(), requireConfig := &RequireConfig{
LibDirs: runner.libDirs(), ScriptDir: "",
LibDirs: []string{},
}
runner.moduleLoader = NewNativeModuleLoader(requireConfig)
} }
// Set up require functionality // Set up require paths and mechanism
moduleLoader := func(s *luajit.State) int { if err := runner.moduleLoader.SetupRequire(state); err != nil {
// Get module name
modName := s.ToString(1)
if modName == "" {
s.PushString("module name required")
return -1
}
// Find and compile module
bytecode, err := findAndCompileModule(s, runner.requireCache, *runner.requireCfg, modName)
if err != nil {
if err == ErrModuleNotFound {
s.PushString("module '" + modName + "' not found")
} else {
s.PushString("error loading module: " + err.Error())
}
return -1 // Return error
}
// Load the bytecode
if err := s.LoadBytecode(bytecode, modName); err != nil {
s.PushString("error loading bytecode: " + err.Error())
return -1 // Return error
}
// Return the loaded function
return 1
}
// Store reference to prevent garbage collection
runner.moduleLoader = moduleLoader
// Register with Lua state
if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil {
state.Close() state.Close()
return nil, ErrInitFailed return nil, ErrInitFailed
} }
// Set up the require mechanism // Preload all modules into package.loaded
if err := setupRequireFunction(state); err != nil { if err := runner.moduleLoader.PreloadAllModules(state); err != nil {
state.Close() state.Close()
return nil, ErrInitFailed return nil, errors.New("failed to preload modules")
} }
// Set up sandbox // Set up sandbox
@ -136,93 +130,10 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
return runner, nil return runner, nil
} }
// setupRequireFunction adds the secure require implementation
func setupRequireFunction(state *luajit.State) error {
return state.DoString(`
function __setup_secure_require(env)
-- Replace env.require with our secure version
env.require = function(modname)
-- Check if already loaded in this environment's package.loaded
if env.package.loaded[modname] then
return env.package.loaded[modname]
end
-- Try to load the module using our Go loader
local loader = __go_load_module
-- Load the module
local f, err = loader(modname)
if not f then
error(err or "failed to load module: " .. modname)
end
-- Set the environment for the module
setfenv(f, env)
-- Execute the module
local result = f()
-- If module didn't return a value, use true
if result == nil then
result = true
end
-- Cache the result in this environment only
env.package.loaded[modname] = result
return result
end
return env
end
`)
}
// RunnerOption defines a functional option for configuring the LuaRunner
type RunnerOption func(*LuaRunner)
// WithBufferSize sets the job queue buffer size
func WithBufferSize(size int) RunnerOption {
return func(r *LuaRunner) {
if size > 0 {
r.bufferSize = size
}
}
}
// WithInitFunc sets the init function for the Lua state
func WithInitFunc(initFunc StateInitFunc) RunnerOption {
return func(r *LuaRunner) {
r.initFunc = initFunc
}
}
// WithScriptDir sets the base directory for scripts
func WithScriptDir(dir string) RunnerOption {
return func(r *LuaRunner) {
r.requireCfg.ScriptDir = dir
}
}
// WithLibDirs sets additional library directories
func WithLibDirs(dirs ...string) RunnerOption {
return func(r *LuaRunner) {
r.requireCfg.LibDirs = dirs
}
}
// scriptDir returns the current script directory
func (r *LuaRunner) scriptDir() string {
if r.requireCfg != nil {
return r.requireCfg.ScriptDir
}
return ""
}
// libDirs returns the current library directories // libDirs returns the current library directories
func (r *LuaRunner) libDirs() []string { func (r *LuaRunner) libDirs() []string {
if r.requireCfg != nil { if r.moduleLoader != nil && r.moduleLoader.config != nil {
return r.requireCfg.LibDirs return r.moduleLoader.config.LibDirs
} }
return nil return nil
} }
@ -246,20 +157,13 @@ func (r *LuaRunner) processJobs() {
// executeJob runs a script in the sandbox environment // executeJob runs a script in the sandbox environment
func (r *LuaRunner) executeJob(j job) JobResult { func (r *LuaRunner) executeJob(j job) JobResult {
// If the job has a script path, update paths without re-registering // If the job has a script path, update script dir for module resolution
if j.ScriptPath != "" { if j.ScriptPath != "" {
r.mu.Lock() r.mu.Lock()
UpdateRequirePaths(r.requireCfg, j.ScriptPath) r.moduleLoader.config.ScriptDir = filepath.Dir(j.ScriptPath)
r.mu.Unlock() r.mu.Unlock()
} }
// Re-run init function if needed
if r.initFunc != nil {
if err := r.initFunc(r.state); err != nil {
return JobResult{nil, err}
}
}
// Convert context for sandbox // Convert context for sandbox
var ctx map[string]any var ctx map[string]any
if j.Context != nil { if j.Context != nil {
@ -328,87 +232,41 @@ func (r *LuaRunner) Close() error {
return nil return nil
} }
// RequireCache returns the require cache for external access // NotifyFileChanged handles file change notifications from watchers
func (r *LuaRunner) RequireCache() *RequireCache { func (r *LuaRunner) NotifyFileChanged(filePath string) bool {
return r.requireCache if r.moduleLoader != nil {
return r.moduleLoader.NotifyFileChanged(r.state, filePath)
}
return false
} }
// ClearRequireCache clears the cache of loaded modules // ResetModuleCache clears non-core modules from package.loaded
func (r *LuaRunner) ClearRequireCache() { func (r *LuaRunner) ResetModuleCache() {
r.requireCache.Clear() if r.moduleLoader != nil {
r.moduleLoader.ResetModules(r.state)
}
}
// ReloadAllModules reloads all modules into package.loaded
func (r *LuaRunner) ReloadAllModules() error {
if r.moduleLoader != nil {
return r.moduleLoader.PreloadAllModules(r.state)
}
return nil
}
// RefreshModuleByName invalidates a specific module in package.loaded
func (r *LuaRunner) RefreshModuleByName(modName string) bool {
if r.state != nil {
if err := r.state.DoString(`package.loaded["` + modName + `"] = nil`); err != nil {
return false
}
return true
}
return false
} }
// AddModule adds a module to the sandbox environment // AddModule adds a module to the sandbox environment
func (r *LuaRunner) AddModule(name string, module any) { func (r *LuaRunner) AddModule(name string, module any) {
r.sandbox.AddModule(name, module) r.sandbox.AddModule(name, module)
} }
// RefreshRequireCache refreshes the module cache if needed
func (r *LuaRunner) RefreshRequireCache() int {
count := r.requireCache.RefreshAll()
return count
}
// ResetPackageLoaded resets the Lua package.loaded table
func (r *LuaRunner) ResetPackageLoaded() error {
return r.state.DoString(`
-- Create list of modules to unload (excluding core modules)
local to_unload = {}
for name, _ in pairs(package.loaded) do
-- Skip core modules
if name ~= "string" and
name ~= "table" and
name ~= "math" and
name ~= "os" and
name ~= "package" and
name ~= "io" and
name ~= "coroutine" and
name ~= "debug" and
name ~= "_G" then
table.insert(to_unload, name)
end
end
-- Unload each module
for _, name in ipairs(to_unload) do
package.loaded[name] = nil
end
`)
}
// RefreshModuleByName clears a specific module from Lua's package.loaded table
func (r *LuaRunner) RefreshModuleByName(modName string) error {
if r.state == nil {
return nil
}
return r.state.DoString(fmt.Sprintf(`
package.loaded["%s"] = nil
`, modName))
}
// ProcessModuleChange handles a file change notification from a watcher
func (r *LuaRunner) ProcessModuleChange(filePath string) {
// Mark cache as needing refresh
r.requireCache.MarkNeedsRefresh()
// Extract module name from file path for package.loaded clearing
ext := filepath.Ext(filePath)
if ext == ".lua" {
// Get relative path from lib directories
var modName string
for _, libDir := range r.requireCfg.LibDirs {
if rel, err := filepath.Rel(libDir, filePath); err == nil && !strings.HasPrefix(rel, "..") {
// Convert path to module name format
modName = strings.TrimSuffix(rel, ext)
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
break
}
}
if modName != "" {
// Clear from Lua's package.loaded (non-blocking)
go r.RefreshModuleByName(modName)
}
}
}

View File

@ -1,361 +1,429 @@
package runner package runner
import ( import (
"errors"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
) )
// Common errors // RequireConfig holds configuration for Lua's require mechanism
var (
ErrModuleNotFound = errors.New("module not found")
ErrPathTraversal = errors.New("path traversal not allowed")
)
// ModuleEntry represents a cached module with timestamp
type ModuleEntry struct {
Bytecode []byte
LastUsed time.Time
}
// RequireConfig holds configuration for Lua's require function
type RequireConfig struct { type RequireConfig struct {
ScriptDir string // Base directory for script being executed ScriptDir string // Base directory for script being executed
LibDirs []string // Additional library directories LibDirs []string // Additional library directories
} }
// RequireCache is a thread-safe cache for loaded Lua modules // NativeModuleLoader uses Lua's native package.loaded as the cache
type RequireCache struct { type NativeModuleLoader struct {
modules sync.Map // Maps full file paths to ModuleEntry registry *ModuleRegistry
mu sync.Mutex config *RequireConfig
maxItems int // Maximum number of modules to cache mu sync.RWMutex
lastRefresh time.Time // When we last did a full refresh check
needsRefresh atomic.Bool // Flag for watchers to signal refresh needed
} }
// NewRequireCache creates a new, empty require cache // ModuleRegistry keeps track of Lua modules for file watching
func NewRequireCache() *RequireCache { type ModuleRegistry struct {
cache := &RequireCache{ // Maps file paths to module names
modules: sync.Map{}, pathToModule sync.Map
maxItems: 100, // Default cache size // Maps module names to file paths (for direct access)
lastRefresh: time.Now(), moduleToPath sync.Map
}
return cache
} }
// SetCacheSize adjusts the maximum cache size // NewModuleRegistry creates a new module registry
func (c *RequireCache) SetCacheSize(size int) { func NewModuleRegistry() *ModuleRegistry {
if size > 0 { return &ModuleRegistry{
c.mu.Lock() pathToModule: sync.Map{},
c.maxItems = size moduleToPath: sync.Map{},
c.mu.Unlock()
} }
} }
// Size returns the approximate number of items in the cache // Register adds a module path to the registry
func (c *RequireCache) Size() int { func (r *ModuleRegistry) Register(path string, name string) {
size := 0 r.pathToModule.Store(path, name)
c.modules.Range(func(_, _ any) bool { r.moduleToPath.Store(name, path)
size++
return true
})
return size
} }
// MarkNeedsRefresh signals that modules have changed and need refresh // GetModuleName retrieves a module name by path
func (c *RequireCache) MarkNeedsRefresh() { func (r *ModuleRegistry) GetModuleName(path string) (string, bool) {
c.needsRefresh.Store(true) value, ok := r.pathToModule.Load(path)
}
// Get retrieves a module from the cache, updating its last used time
func (c *RequireCache) Get(path string) ([]byte, bool) {
value, ok := c.modules.Load(path)
if !ok { if !ok {
return nil, false return "", false
}
return value.(string), true
} }
entry, ok := value.(ModuleEntry) // GetModulePath retrieves a path by module name
func (r *ModuleRegistry) GetModulePath(name string) (string, bool) {
value, ok := r.moduleToPath.Load(name)
if !ok { if !ok {
// Handle legacy entries (plain bytecode) return "", false
bytecode, ok := value.([]byte) }
if !ok { return value.(string), true
return nil, false
} }
// Convert to ModuleEntry and update // NewNativeModuleLoader creates a new native module loader
entry = ModuleEntry{ func NewNativeModuleLoader(config *RequireConfig) *NativeModuleLoader {
Bytecode: bytecode, return &NativeModuleLoader{
LastUsed: time.Now(), registry: NewModuleRegistry(),
} config: config,
c.modules.Store(path, entry)
return bytecode, true
}
// Update last used time
entry.LastUsed = time.Now()
c.modules.Store(path, entry)
return entry.Bytecode, true
}
// Store adds a module to the cache with LRU eviction
func (c *RequireCache) Store(path string, bytecode []byte) {
c.mu.Lock()
defer c.mu.Unlock()
// Check if we need to evict
if c.Size() >= c.maxItems {
c.evictOldest()
}
// Store the new entry
c.modules.Store(path, ModuleEntry{
Bytecode: bytecode,
LastUsed: time.Now(),
})
}
// evictOldest removes the least recently used item from the cache
func (c *RequireCache) evictOldest() {
var oldestTime time.Time
var oldestKey string
first := true
// Find oldest entry
c.modules.Range(func(key, value any) bool {
// Handle different value types
var lastUsed time.Time
switch v := value.(type) {
case ModuleEntry:
lastUsed = v.LastUsed
default:
// For non-ModuleEntry values, treat as oldest
if first {
oldestKey = key.(string)
first = false
return true
}
return true
}
if first || lastUsed.Before(oldestTime) {
oldestTime = lastUsed
oldestKey = key.(string)
first = false
}
return true
})
// Remove oldest entry
if oldestKey != "" {
c.modules.Delete(oldestKey)
} }
} }
// Clear empties the entire cache // escapeLuaString escapes special characters in a string for Lua
func (c *RequireCache) Clear() { func escapeLuaString(s string) string {
c.mu.Lock() replacer := strings.NewReplacer(
defer c.mu.Unlock() "\\", "\\\\",
"\"", "\\\"",
// Create a new sync.Map to replace the existing one "\n", "\\n",
c.modules = sync.Map{} "\r", "\\r",
"\t", "\\t",
)
return replacer.Replace(s)
} }
// RefreshModule checks if a specific module needs to be refreshed // SetupRequire configures the require system
func (c *RequireCache) RefreshModule(path string) bool { func (l *NativeModuleLoader) SetupRequire(state *luajit.State) error {
// Get the cached module // Initialize our module registry in Lua
val, ok := c.modules.Load(path) return state.DoString(`
if !ok { -- Initialize global module registry
// Not in cache, nothing to refresh __module_paths = {}
return false
-- Setup fast module loading system
__module_results = {}
-- Create module preload table
package.preload = package.preload or {}
-- Setup module loader registry
__ready_modules = {}
`)
} }
// Get file info // PreloadAllModules fully preloads modules for maximum performance
fileInfo, err := os.Stat(path) func (l *NativeModuleLoader) PreloadAllModules(state *luajit.State) error {
if err != nil { l.mu.Lock()
// File no longer exists or can't be accessed, remove from cache defer l.mu.Unlock()
c.modules.Delete(path)
return true // Reset registry
l.registry = NewModuleRegistry()
// Reset preloaded modules in Lua
if err := state.DoString(`
-- Reset module registry
__module_paths = {}
__module_results = {}
-- Clear non-core modules from package.loaded
local core_modules = {
string = true, table = true, math = true, os = true,
package = true, io = true, coroutine = true, debug = true, _G = true
} }
// Check if the cached module is up-to-date for name in pairs(package.loaded) do
entry, ok := val.(ModuleEntry) if not core_modules[name] then
if !ok { package.loaded[name] = nil
// Invalid entry, remove it end
c.modules.Delete(path) end
return true
-- Reset preload table
package.preload = package.preload or {}
for name in pairs(package.preload) do
package.preload[name] = nil
end
-- Reset ready modules
__ready_modules = {}
`); err != nil {
return err
} }
// Check if the file has been modified since it was cached // Set up paths for require
if fileInfo.ModTime().After(entry.LastUsed) { absPaths := []string{}
// File is newer than the cached version, remove from cache pathsMap := map[string]bool{}
c.modules.Delete(path)
return true
}
return false // Add script directory (absolute path)
} if l.config.ScriptDir != "" {
absPath, err := filepath.Abs(l.config.ScriptDir)
// RefreshAll checks all cached modules and refreshes those that have changed if err == nil && !pathsMap[absPath] {
func (c *RequireCache) RefreshAll() int { absPaths = append(absPaths, filepath.Join(absPath, "?.lua"))
refreshed := 0 pathsMap[absPath] = true
// No need to refresh if flag isn't set
if !c.needsRefresh.Load() {
return 0
}
// For maximum performance, just clear everything
c.Clear()
// Reset the needsRefresh flag
c.needsRefresh.Store(false)
c.lastRefresh = time.Now()
return refreshed
}
// UpdateRequirePaths updates the require paths in the config without further allocations or re-registering the loader.
func UpdateRequirePaths(config *RequireConfig, scriptPath string) {
if scriptPath != "" {
config.ScriptDir = filepath.Dir(scriptPath)
} }
} }
// findAndCompileModule finds a module in allowed directories and compiles it to bytecode // Add lib directories (absolute paths)
func findAndCompileModule( for _, dir := range l.config.LibDirs {
state *luajit.State, if dir == "" {
cache *RequireCache,
config RequireConfig,
modName string,
) ([]byte, error) {
// Convert module name to relative path
modPath := strings.ReplaceAll(modName, ".", string(filepath.Separator))
// List of paths to check
paths := []string{}
// 1. Check adjacent to script directory first
if config.ScriptDir != "" {
paths = append(paths, filepath.Join(config.ScriptDir, modPath+".lua"))
}
// 2. Check in lib directories
for _, libDir := range config.LibDirs {
if libDir != "" {
paths = append(paths, filepath.Join(libDir, modPath+".lua"))
}
}
// If the cache needs refresh, handle it immediately
if cache.needsRefresh.Load() {
cache.Clear() // Complete reset for max performance
cache.needsRefresh.Store(false)
cache.lastRefresh = time.Now()
}
// Try each path
for _, path := range paths {
// Clean the path to handle .. and such (security)
cleanPath := filepath.Clean(path)
// Check for path traversal (extra safety)
if !isSubPath(config.ScriptDir, cleanPath) {
isValidLib := false
for _, libDir := range config.LibDirs {
if isSubPath(libDir, cleanPath) {
isValidLib = true
break
}
}
if !isValidLib {
continue // Skip paths outside allowed directories
}
}
// Check if already in cache
if value, ok := cache.modules.Load(cleanPath); ok {
entry, ok := value.(ModuleEntry)
if !ok {
// Legacy format, use it anyway
return value.([]byte), nil
}
// Check file modification time if cache is marked for refresh
if cache.needsRefresh.Load() {
fileInfo, err := os.Stat(cleanPath)
// Remove from cache if file changed or doesn't exist
if err != nil || (entry.LastUsed.Before(fileInfo.ModTime())) {
cache.modules.Delete(cleanPath)
// Continue to recompile
} else {
// Update last used time and return cached bytecode
entry.LastUsed = time.Now()
cache.modules.Store(cleanPath, entry)
return entry.Bytecode, nil
}
} else {
// Update last used time and return cached bytecode
entry.LastUsed = time.Now()
cache.modules.Store(cleanPath, entry)
return entry.Bytecode, nil
}
}
// Check if file exists
_, err := os.Stat(cleanPath)
if os.IsNotExist(err) {
continue continue
} }
// Read and compile the file absPath, err := filepath.Abs(dir)
content, err := os.ReadFile(cleanPath) if err == nil && !pathsMap[absPath] {
if err != nil { absPaths = append(absPaths, filepath.Join(absPath, "?.lua"))
return nil, err pathsMap[absPath] = true
}
} }
// Compile to bytecode // Set package.path
bytecode, err := state.CompileBytecode(string(content), cleanPath) escapedPathStr := escapeLuaString(strings.Join(absPaths, ";"))
if err != nil { if err := state.DoString(`package.path = "` + escapedPathStr + `"`); err != nil {
return nil, err return err
} }
// Store in cache with current time // Process and preload all modules from lib directories
cache.modules.Store(cleanPath, ModuleEntry{ for _, dir := range l.config.LibDirs {
Bytecode: bytecode, if dir == "" {
LastUsed: time.Now(), continue
}
absDir, err := filepath.Abs(dir)
if err != nil {
continue
}
// Find all Lua files
err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") {
return nil
}
// Get module name
relPath, err := filepath.Rel(absDir, path)
if err != nil || strings.HasPrefix(relPath, "..") {
return nil
}
modName := strings.TrimSuffix(relPath, ".lua")
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
// Register module path
l.registry.Register(path, modName)
// Register path in Lua
escapedPath := escapeLuaString(path)
escapedName := escapeLuaString(modName)
if err := state.DoString(`__module_paths["` + escapedName + `"] = "` + escapedPath + `"`); err != nil {
return nil
}
// Compile the module
content, err := os.ReadFile(path)
if err != nil {
return nil
}
// Precompile bytecode
bytecode, err := state.CompileBytecode(string(content), path)
if err != nil {
return nil
}
// Load bytecode
if err := state.LoadBytecode(bytecode, path); err != nil {
return nil
}
// Store in package.preload for fast loading
// We use string concat for efficiency (no string.format overhead)
luaCode := `
local modname = "` + escapedName + `"
local chunk = ...
package.preload[modname] = chunk
__ready_modules[modname] = true
`
if err := state.DoString(luaCode); err != nil {
state.Pop(1) // Remove chunk from stack
return nil
}
state.Pop(1) // Remove chunk from stack
return nil
}) })
return bytecode, nil if err != nil {
return err
}
} }
return nil, ErrModuleNotFound // Install optimized require implementation
return state.DoString(`
-- Ultra-fast module loader
function __fast_require(env, modname)
-- 1. Check already loaded modules
if package.loaded[modname] then
return package.loaded[modname]
end
-- 2. Check preloaded chunks
if __ready_modules[modname] then
local loader = package.preload[modname]
if loader then
-- Set environment
setfenv(loader, env)
-- Execute and store result
local result = loader()
if result == nil then
result = true
end
-- Cache in shared registry
package.loaded[modname] = result
return result
end
end
-- 3. Direct file load as fallback
if __module_paths[modname] then
local path = __module_paths[modname]
local chunk, err = loadfile(path)
if chunk then
setfenv(chunk, env)
local result = chunk()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
end
-- 4. Full path search as last resort
local err_msgs = {}
for path in package.path:gmatch("[^;]+") do
local file_path = path:gsub("?", modname:gsub("%.", "/"))
local chunk, err = loadfile(file_path)
if chunk then
setfenv(chunk, env)
local result = chunk()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
table.insert(err_msgs, "no file '" .. file_path .. "'")
end
error("module '" .. modname .. "' not found:\n" .. table.concat(err_msgs, "\n"), 2)
end
-- Install require factory
function __setup_require(env)
-- Create highly optimized require with closure
env.require = function(modname)
return __fast_require(env, modname)
end
return env
end
`)
} }
// isSubPath checks if path is contained within base directory // NotifyFileChanged invalidates modules when files change
func isSubPath(baseDir, path string) bool { func (l *NativeModuleLoader) NotifyFileChanged(state *luajit.State, path string) bool {
if baseDir == "" {
return false
}
// Clean and normalize paths
baseDir = filepath.Clean(baseDir)
path = filepath.Clean(path) path = filepath.Clean(path)
// Get relative path // Get module name from registry
rel, err := filepath.Rel(baseDir, path) modName, found := l.registry.GetModuleName(path)
if !found {
// Try to find by path for lib dirs
for _, libDir := range l.config.LibDirs {
absDir, err := filepath.Abs(libDir)
if err != nil { if err != nil {
continue
}
relPath, err := filepath.Rel(absDir, path)
if err != nil || strings.HasPrefix(relPath, "..") {
continue
}
if strings.HasSuffix(relPath, ".lua") {
modName = strings.TrimSuffix(relPath, ".lua")
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
found = true
break
}
}
}
if !found {
return false return false
} }
// Check if path goes outside baseDir // Update bytecode and invalidate caches
return !strings.HasPrefix(rel, ".."+string(filepath.Separator)) && rel != ".." content, err := os.ReadFile(path)
if err != nil {
// File might have been deleted - just invalidate
escapedName := escapeLuaString(modName)
state.DoString(`
package.loaded["` + escapedName + `"] = nil
__ready_modules["` + escapedName + `"] = nil
if package.preload then
package.preload["` + escapedName + `"] = nil
end
`)
return true
}
// Recompile module
bytecode, err := state.CompileBytecode(string(content), path)
if err != nil {
// Invalid Lua - just invalidate
escapedName := escapeLuaString(modName)
state.DoString(`
package.loaded["` + escapedName + `"] = nil
__ready_modules["` + escapedName + `"] = nil
if package.preload then
package.preload["` + escapedName + `"] = nil
end
`)
return true
}
// Load bytecode
if err := state.LoadBytecode(bytecode, path); err != nil {
// Unable to load - just invalidate
escapedName := escapeLuaString(modName)
state.DoString(`
package.loaded["` + escapedName + `"] = nil
__ready_modules["` + escapedName + `"] = nil
if package.preload then
package.preload["` + escapedName + `"] = nil
end
`)
return true
}
// Update preload with new chunk
escapedName := escapeLuaString(modName)
luaCode := `
-- Update module in package.preload and clear loaded
package.loaded["` + escapedName + `"] = nil
package.preload["` + escapedName + `"] = ...
__ready_modules["` + escapedName + `"] = true
`
if err := state.DoString(luaCode); err != nil {
state.Pop(1) // Remove chunk from stack
return false
}
state.Pop(1) // Remove chunk from stack
return true
}
// ResetModules clears all non-core modules
func (l *NativeModuleLoader) ResetModules(state *luajit.State) error {
return state.DoString(`
local core_modules = {
string = true, table = true, math = true, os = true,
package = true, io = true, coroutine = true, debug = true, _G = true
}
for name in pairs(package.loaded) do
if not core_modules[name] then
package.loaded[name] = nil
end
end
`)
} }

View File

@ -30,120 +30,134 @@ func (s *Sandbox) Setup(state *luajit.State) error {
return err return err
} }
// Setup the sandbox creation logic with base environment reuse // Create high-performance persistent environment
return state.DoString(` return state.DoString(`
-- Create the base environment once (static parts) -- Global shared environment (created once)
local __base_env = nil __env_system = __env_system or {
base_env = nil, -- Template environment
initialized = false, -- Initialization flag
env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size
}
-- Create function to initialize base environment -- Initialize base environment once
function __init_base_env() if not __env_system.initialized then
if __base_env then return end -- Create base environment with all standard libraries
local base = {}
local env = {} -- Safe standard libraries
base.string = string
-- Add standard library modules (restricted) base.table = table
env.string = string base.math = math
env.table = table base.os = {
env.math = math
env.os = {
time = os.time, time = os.time,
date = os.date, date = os.date,
difftime = os.difftime, difftime = os.difftime,
clock = os.clock clock = os.clock
} }
env.tonumber = tonumber
env.tostring = tostring
env.type = type
env.pairs = pairs
env.ipairs = ipairs
env.next = next
env.select = select
env.unpack = unpack
env.pcall = pcall
env.xpcall = xpcall
env.error = error
env.assert = assert
-- Add module loader -- Basic functions
env.__go_load_module = __go_load_module base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.unpack = unpack
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
-- Add custom modules from sandbox registry -- Package system is shared for performance
if __sandbox_modules then base.package = {
for name, module in pairs(__sandbox_modules) do loaded = package.loaded,
env[name] = module path = package.path,
end preload = package.preload
end
-- Copy custom global functions
for k, v in pairs(_G) do
if (type(v) == "function" or type(v) == "table") and
k ~= "__sandbox_modules" and
k ~= "__base_env" and
k ~= "__init_base_env" and
k ~= "__create_sandbox_env" and
k ~= "__run_sandboxed" and
k ~= "__setup_secure_require" and
k ~= "__go_load_module" and
k ~= "string" and k ~= "table" and k ~= "math" and
k ~= "os" and k ~= "io" and k ~= "debug" and
k ~= "package" and k ~= "bit" and k ~= "jit" and
k ~= "coroutine" and k ~= "_G" and k ~= "_VERSION" then
env[k] = v
end
end
__base_env = env
end
-- Create function that builds sandbox from base env
function __create_sandbox_env(ctx)
-- Initialize base env if needed
__init_base_env()
-- Create new environment using base as prototype
local env = {}
-- Copy from base environment
for k, v in pairs(__base_env) do
env[k] = v
end
-- Add isolated package.loaded table
env.package = {
loaded = {}
} }
-- Add context if provided -- Add registered custom modules
if __sandbox_modules then
for name, mod in pairs(__sandbox_modules) do
base[name] = mod
end
end
-- Store base environment
__env_system.base_env = base
__env_system.initialized = true
end
-- Fast environment creation with pre-allocation
function __get_sandbox_env(ctx)
local env
-- Try to reuse from pool
if __env_system.pool_size > 0 then
env = table.remove(__env_system.env_pool)
__env_system.pool_size = __env_system.pool_size - 1
-- Clear any previous context
env.ctx = ctx or nil
else
-- Create new environment with metatable inheritance
env = setmetatable({}, {
__index = __env_system.base_env
})
-- Set context if provided
if ctx then if ctx then
env.ctx = ctx env.ctx = ctx
end end
-- Setup require function -- Install the fast require implementation
env = __setup_secure_require(env) env.require = function(modname)
return __fast_require(env, modname)
-- Create metatable for isolation end
local mt = {
__index = function(t, k)
return rawget(env, k)
end,
__newindex = function(t, k, v)
rawset(env, k, v)
end end
}
setmetatable(env, mt)
return env return env
end end
-- Function to run code in sandbox -- Return environment to pool for reuse
function __run_sandboxed(bytecode, ctx) function __recycle_env(env)
-- Create fresh environment for this request -- Only recycle if pool isn't full
local env = __create_sandbox_env(ctx) if __env_system.pool_size < __env_system.max_pool_size then
-- Clear context reference to avoid memory leaks
env.ctx = nil
-- Set environment and execute -- Add to pool
setfenv(bytecode, env) table.insert(__env_system.env_pool, env)
return bytecode() __env_system.pool_size = __env_system.pool_size + 1
end end
end
-- Hyper-optimized sandbox executor
function __execute_sandbox(bytecode, ctx)
-- Get environment (from pool if available)
local env = __get_sandbox_env(ctx)
-- Set environment for bytecode
setfenv(bytecode, env)
-- Execute with protected call
local success, result = pcall(bytecode)
-- Recycle environment for future use
__recycle_env(env)
-- Process result
if not success then
error(result, 0)
end
return result
end
-- Run minimal GC for overall health
collectgarbage("step", 10)
`) `)
} }
@ -191,11 +205,14 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
// Create context table if provided // Create context table if provided
if len(ctx) > 0 { if len(ctx) > 0 {
state.NewTable() // Preallocate table with appropriate size
state.CreateTable(0, len(ctx))
// Add context entries
for k, v := range ctx { for k, v := range ctx {
state.PushString(k) state.PushString(k)
if err := state.PushValue(v); err != nil { if err := state.PushValue(v); err != nil {
state.Pop(3) state.Pop(2)
return nil, err return nil, err
} }
state.SetTable(-3) state.SetTable(-3)
@ -204,8 +221,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
state.PushNil() // No context state.PushNil() // No context
} }
// Get sandbox function // Get optimized sandbox executor
state.GetGlobal("__run_sandboxed") state.GetGlobal("__execute_sandbox")
// Setup call with correct argument order // Setup call with correct argument order
state.PushCopy(-3) // Copy bytecode function state.PushCopy(-3) // Copy bytecode function
@ -215,7 +232,7 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
state.Remove(-5) // Remove original bytecode state.Remove(-5) // Remove original bytecode
state.Remove(-4) // Remove original context state.Remove(-4) // Remove original context
// Call sandbox function // Call optimized sandbox executor
if err := state.Call(2, 1); err != nil { if err := state.Call(2, 1); err != nil {
return nil, err return nil, err
} }

View File

@ -10,17 +10,17 @@ func WatchLuaModules(luaRunner *runner.LuaRunner, libDirs []string, log *logger.
watchers := make([]*Watcher, 0, len(libDirs)) watchers := make([]*Watcher, 0, len(libDirs))
for _, dir := range libDirs { for _, dir := range libDirs {
// Create a directory-specific callback that only does minimal work // Create a directory-specific callback that identifies changed files
dirCopy := dir // Capture for closure dirCopy := dir // Capture for closure
callback := func() error { callback := func() error {
log.Debug("Detected changes in Lua module directory: %s", dirCopy) log.Debug("Detected changes in Lua module directory: %s", dirCopy)
// Completely reset the cache to match fresh-start performance // Instead of clearing everything, use directory-level smart refresh
luaRunner.RequireCache().Clear() // This will scan lib directory and refresh all modified Lua modules
if err := luaRunner.ReloadAllModules(); err != nil {
// Force reset of Lua's module registry log.Warning("Error reloading modules: %v", err)
luaRunner.ResetPackageLoaded() }
return nil return nil
} }