optimized module loader

This commit is contained in:
Sky Johnson 2025-03-22 16:39:13 -05:00
parent 522a5770ed
commit 7bc5194b10
4 changed files with 565 additions and 622 deletions

View File

@ -3,9 +3,7 @@ package runner
import (
"context"
"errors"
"fmt"
"path/filepath"
"strings"
"sync"
"sync/atomic"
@ -21,6 +19,9 @@ var (
// StateInitFunc is a function that initializes a Lua state
type StateInitFunc func(*luajit.State) error
// RunnerOption defines a functional option for configuring the LuaRunner
type RunnerOption func(*LuaRunner)
// LuaRunner runs Lua scripts using a single Lua state
type LuaRunner struct {
state *luajit.State // The Lua state
@ -30,21 +31,44 @@ type LuaRunner struct {
wg sync.WaitGroup // WaitGroup for clean shutdown
initFunc StateInitFunc // Optional function to initialize Lua state
bufferSize int // Size of the job queue buffer
requireCache *RequireCache // Cache for required modules
requireCfg *RequireConfig // Configuration for require paths
moduleLoader luajit.GoFunction // Keep reference to prevent GC
moduleLoader *NativeModuleLoader // Native module loader for require
sandbox *Sandbox // The sandbox environment
}
// WithBufferSize sets the job queue buffer size
func WithBufferSize(size int) RunnerOption {
return func(r *LuaRunner) {
if size > 0 {
r.bufferSize = size
}
}
}
// WithInitFunc sets the init function for the Lua state
func WithInitFunc(initFunc StateInitFunc) RunnerOption {
return func(r *LuaRunner) {
r.initFunc = initFunc
}
}
// WithLibDirs sets additional library directories
func WithLibDirs(dirs ...string) RunnerOption {
return func(r *LuaRunner) {
if r.moduleLoader == nil || r.moduleLoader.config == nil {
r.moduleLoader = NewNativeModuleLoader(&RequireConfig{
LibDirs: dirs,
})
} else {
r.moduleLoader.config.LibDirs = dirs
}
}
}
// NewRunner creates a new LuaRunner
func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
// Default configuration
runner := &LuaRunner{
bufferSize: 10, // Default buffer size
requireCache: NewRequireCache(),
requireCfg: &RequireConfig{
LibDirs: []string{},
},
sandbox: NewSandbox(),
}
@ -64,55 +88,25 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
runner.jobQueue = make(chan job, runner.bufferSize)
runner.isRunning.Store(true)
// Create a shared config pointer that will be updated per request
runner.requireCfg = &RequireConfig{
ScriptDir: runner.scriptDir(),
LibDirs: runner.libDirs(),
// Set up module loader if not already initialized
if runner.moduleLoader == nil {
requireConfig := &RequireConfig{
ScriptDir: "",
LibDirs: []string{},
}
runner.moduleLoader = NewNativeModuleLoader(requireConfig)
}
// Set up require functionality
moduleLoader := func(s *luajit.State) int {
// Get module name
modName := s.ToString(1)
if modName == "" {
s.PushString("module name required")
return -1
}
// Find and compile module
bytecode, err := findAndCompileModule(s, runner.requireCache, *runner.requireCfg, modName)
if err != nil {
if err == ErrModuleNotFound {
s.PushString("module '" + modName + "' not found")
} else {
s.PushString("error loading module: " + err.Error())
}
return -1 // Return error
}
// Load the bytecode
if err := s.LoadBytecode(bytecode, modName); err != nil {
s.PushString("error loading bytecode: " + err.Error())
return -1 // Return error
}
// Return the loaded function
return 1
}
// Store reference to prevent garbage collection
runner.moduleLoader = moduleLoader
// Register with Lua state
if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil {
// Set up require paths and mechanism
if err := runner.moduleLoader.SetupRequire(state); err != nil {
state.Close()
return nil, ErrInitFailed
}
// Set up the require mechanism
if err := setupRequireFunction(state); err != nil {
// Preload all modules into package.loaded
if err := runner.moduleLoader.PreloadAllModules(state); err != nil {
state.Close()
return nil, ErrInitFailed
return nil, errors.New("failed to preload modules")
}
// Set up sandbox
@ -136,93 +130,10 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
return runner, nil
}
// setupRequireFunction adds the secure require implementation
func setupRequireFunction(state *luajit.State) error {
return state.DoString(`
function __setup_secure_require(env)
-- Replace env.require with our secure version
env.require = function(modname)
-- Check if already loaded in this environment's package.loaded
if env.package.loaded[modname] then
return env.package.loaded[modname]
end
-- Try to load the module using our Go loader
local loader = __go_load_module
-- Load the module
local f, err = loader(modname)
if not f then
error(err or "failed to load module: " .. modname)
end
-- Set the environment for the module
setfenv(f, env)
-- Execute the module
local result = f()
-- If module didn't return a value, use true
if result == nil then
result = true
end
-- Cache the result in this environment only
env.package.loaded[modname] = result
return result
end
return env
end
`)
}
// RunnerOption defines a functional option for configuring the LuaRunner
type RunnerOption func(*LuaRunner)
// WithBufferSize sets the job queue buffer size
func WithBufferSize(size int) RunnerOption {
return func(r *LuaRunner) {
if size > 0 {
r.bufferSize = size
}
}
}
// WithInitFunc sets the init function for the Lua state
func WithInitFunc(initFunc StateInitFunc) RunnerOption {
return func(r *LuaRunner) {
r.initFunc = initFunc
}
}
// WithScriptDir sets the base directory for scripts
func WithScriptDir(dir string) RunnerOption {
return func(r *LuaRunner) {
r.requireCfg.ScriptDir = dir
}
}
// WithLibDirs sets additional library directories
func WithLibDirs(dirs ...string) RunnerOption {
return func(r *LuaRunner) {
r.requireCfg.LibDirs = dirs
}
}
// scriptDir returns the current script directory
func (r *LuaRunner) scriptDir() string {
if r.requireCfg != nil {
return r.requireCfg.ScriptDir
}
return ""
}
// libDirs returns the current library directories
func (r *LuaRunner) libDirs() []string {
if r.requireCfg != nil {
return r.requireCfg.LibDirs
if r.moduleLoader != nil && r.moduleLoader.config != nil {
return r.moduleLoader.config.LibDirs
}
return nil
}
@ -246,20 +157,13 @@ func (r *LuaRunner) processJobs() {
// executeJob runs a script in the sandbox environment
func (r *LuaRunner) executeJob(j job) JobResult {
// If the job has a script path, update paths without re-registering
// If the job has a script path, update script dir for module resolution
if j.ScriptPath != "" {
r.mu.Lock()
UpdateRequirePaths(r.requireCfg, j.ScriptPath)
r.moduleLoader.config.ScriptDir = filepath.Dir(j.ScriptPath)
r.mu.Unlock()
}
// Re-run init function if needed
if r.initFunc != nil {
if err := r.initFunc(r.state); err != nil {
return JobResult{nil, err}
}
}
// Convert context for sandbox
var ctx map[string]any
if j.Context != nil {
@ -328,87 +232,41 @@ func (r *LuaRunner) Close() error {
return nil
}
// RequireCache returns the require cache for external access
func (r *LuaRunner) RequireCache() *RequireCache {
return r.requireCache
// NotifyFileChanged handles file change notifications from watchers
func (r *LuaRunner) NotifyFileChanged(filePath string) bool {
if r.moduleLoader != nil {
return r.moduleLoader.NotifyFileChanged(r.state, filePath)
}
return false
}
// ClearRequireCache clears the cache of loaded modules
func (r *LuaRunner) ClearRequireCache() {
r.requireCache.Clear()
// ResetModuleCache clears non-core modules from package.loaded
func (r *LuaRunner) ResetModuleCache() {
if r.moduleLoader != nil {
r.moduleLoader.ResetModules(r.state)
}
}
// ReloadAllModules reloads all modules into package.loaded
func (r *LuaRunner) ReloadAllModules() error {
if r.moduleLoader != nil {
return r.moduleLoader.PreloadAllModules(r.state)
}
return nil
}
// RefreshModuleByName invalidates a specific module in package.loaded
func (r *LuaRunner) RefreshModuleByName(modName string) bool {
if r.state != nil {
if err := r.state.DoString(`package.loaded["` + modName + `"] = nil`); err != nil {
return false
}
return true
}
return false
}
// AddModule adds a module to the sandbox environment
func (r *LuaRunner) AddModule(name string, module any) {
r.sandbox.AddModule(name, module)
}
// RefreshRequireCache refreshes the module cache if needed
func (r *LuaRunner) RefreshRequireCache() int {
count := r.requireCache.RefreshAll()
return count
}
// ResetPackageLoaded resets the Lua package.loaded table
func (r *LuaRunner) ResetPackageLoaded() error {
return r.state.DoString(`
-- Create list of modules to unload (excluding core modules)
local to_unload = {}
for name, _ in pairs(package.loaded) do
-- Skip core modules
if name ~= "string" and
name ~= "table" and
name ~= "math" and
name ~= "os" and
name ~= "package" and
name ~= "io" and
name ~= "coroutine" and
name ~= "debug" and
name ~= "_G" then
table.insert(to_unload, name)
end
end
-- Unload each module
for _, name in ipairs(to_unload) do
package.loaded[name] = nil
end
`)
}
// RefreshModuleByName clears a specific module from Lua's package.loaded table
func (r *LuaRunner) RefreshModuleByName(modName string) error {
if r.state == nil {
return nil
}
return r.state.DoString(fmt.Sprintf(`
package.loaded["%s"] = nil
`, modName))
}
// ProcessModuleChange handles a file change notification from a watcher
func (r *LuaRunner) ProcessModuleChange(filePath string) {
// Mark cache as needing refresh
r.requireCache.MarkNeedsRefresh()
// Extract module name from file path for package.loaded clearing
ext := filepath.Ext(filePath)
if ext == ".lua" {
// Get relative path from lib directories
var modName string
for _, libDir := range r.requireCfg.LibDirs {
if rel, err := filepath.Rel(libDir, filePath); err == nil && !strings.HasPrefix(rel, "..") {
// Convert path to module name format
modName = strings.TrimSuffix(rel, ext)
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
break
}
}
if modName != "" {
// Clear from Lua's package.loaded (non-blocking)
go r.RefreshModuleByName(modName)
}
}
}

View File

@ -1,361 +1,429 @@
package runner
import (
"errors"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Common errors
var (
ErrModuleNotFound = errors.New("module not found")
ErrPathTraversal = errors.New("path traversal not allowed")
)
// ModuleEntry represents a cached module with timestamp
type ModuleEntry struct {
Bytecode []byte
LastUsed time.Time
}
// RequireConfig holds configuration for Lua's require function
// RequireConfig holds configuration for Lua's require mechanism
type RequireConfig struct {
ScriptDir string // Base directory for script being executed
LibDirs []string // Additional library directories
}
// RequireCache is a thread-safe cache for loaded Lua modules
type RequireCache struct {
modules sync.Map // Maps full file paths to ModuleEntry
mu sync.Mutex
maxItems int // Maximum number of modules to cache
lastRefresh time.Time // When we last did a full refresh check
needsRefresh atomic.Bool // Flag for watchers to signal refresh needed
// NativeModuleLoader uses Lua's native package.loaded as the cache
type NativeModuleLoader struct {
registry *ModuleRegistry
config *RequireConfig
mu sync.RWMutex
}
// NewRequireCache creates a new, empty require cache
func NewRequireCache() *RequireCache {
cache := &RequireCache{
modules: sync.Map{},
maxItems: 100, // Default cache size
lastRefresh: time.Now(),
}
return cache
// ModuleRegistry keeps track of Lua modules for file watching
type ModuleRegistry struct {
// Maps file paths to module names
pathToModule sync.Map
// Maps module names to file paths (for direct access)
moduleToPath sync.Map
}
// SetCacheSize adjusts the maximum cache size
func (c *RequireCache) SetCacheSize(size int) {
if size > 0 {
c.mu.Lock()
c.maxItems = size
c.mu.Unlock()
// NewModuleRegistry creates a new module registry
func NewModuleRegistry() *ModuleRegistry {
return &ModuleRegistry{
pathToModule: sync.Map{},
moduleToPath: sync.Map{},
}
}
// Size returns the approximate number of items in the cache
func (c *RequireCache) Size() int {
size := 0
c.modules.Range(func(_, _ any) bool {
size++
return true
})
return size
// Register adds a module path to the registry
func (r *ModuleRegistry) Register(path string, name string) {
r.pathToModule.Store(path, name)
r.moduleToPath.Store(name, path)
}
// MarkNeedsRefresh signals that modules have changed and need refresh
func (c *RequireCache) MarkNeedsRefresh() {
c.needsRefresh.Store(true)
}
// Get retrieves a module from the cache, updating its last used time
func (c *RequireCache) Get(path string) ([]byte, bool) {
value, ok := c.modules.Load(path)
// GetModuleName retrieves a module name by path
func (r *ModuleRegistry) GetModuleName(path string) (string, bool) {
value, ok := r.pathToModule.Load(path)
if !ok {
return nil, false
return "", false
}
return value.(string), true
}
entry, ok := value.(ModuleEntry)
// GetModulePath retrieves a path by module name
func (r *ModuleRegistry) GetModulePath(name string) (string, bool) {
value, ok := r.moduleToPath.Load(name)
if !ok {
// Handle legacy entries (plain bytecode)
bytecode, ok := value.([]byte)
if !ok {
return nil, false
return "", false
}
// Convert to ModuleEntry and update
entry = ModuleEntry{
Bytecode: bytecode,
LastUsed: time.Now(),
}
c.modules.Store(path, entry)
return bytecode, true
}
// Update last used time
entry.LastUsed = time.Now()
c.modules.Store(path, entry)
return entry.Bytecode, true
return value.(string), true
}
// Store adds a module to the cache with LRU eviction
func (c *RequireCache) Store(path string, bytecode []byte) {
c.mu.Lock()
defer c.mu.Unlock()
// Check if we need to evict
if c.Size() >= c.maxItems {
c.evictOldest()
}
// Store the new entry
c.modules.Store(path, ModuleEntry{
Bytecode: bytecode,
LastUsed: time.Now(),
})
}
// evictOldest removes the least recently used item from the cache
func (c *RequireCache) evictOldest() {
var oldestTime time.Time
var oldestKey string
first := true
// Find oldest entry
c.modules.Range(func(key, value any) bool {
// Handle different value types
var lastUsed time.Time
switch v := value.(type) {
case ModuleEntry:
lastUsed = v.LastUsed
default:
// For non-ModuleEntry values, treat as oldest
if first {
oldestKey = key.(string)
first = false
return true
}
return true
}
if first || lastUsed.Before(oldestTime) {
oldestTime = lastUsed
oldestKey = key.(string)
first = false
}
return true
})
// Remove oldest entry
if oldestKey != "" {
c.modules.Delete(oldestKey)
// NewNativeModuleLoader creates a new native module loader
func NewNativeModuleLoader(config *RequireConfig) *NativeModuleLoader {
return &NativeModuleLoader{
registry: NewModuleRegistry(),
config: config,
}
}
// Clear empties the entire cache
func (c *RequireCache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
// Create a new sync.Map to replace the existing one
c.modules = sync.Map{}
// escapeLuaString escapes special characters in a string for Lua
func escapeLuaString(s string) string {
replacer := strings.NewReplacer(
"\\", "\\\\",
"\"", "\\\"",
"\n", "\\n",
"\r", "\\r",
"\t", "\\t",
)
return replacer.Replace(s)
}
// RefreshModule checks if a specific module needs to be refreshed
func (c *RequireCache) RefreshModule(path string) bool {
// Get the cached module
val, ok := c.modules.Load(path)
if !ok {
// Not in cache, nothing to refresh
return false
}
// SetupRequire configures the require system
func (l *NativeModuleLoader) SetupRequire(state *luajit.State) error {
// Initialize our module registry in Lua
return state.DoString(`
-- Initialize global module registry
__module_paths = {}
// Get file info
fileInfo, err := os.Stat(path)
if err != nil {
// File no longer exists or can't be accessed, remove from cache
c.modules.Delete(path)
return true
}
-- Setup fast module loading system
__module_results = {}
// Check if the cached module is up-to-date
entry, ok := val.(ModuleEntry)
if !ok {
// Invalid entry, remove it
c.modules.Delete(path)
return true
}
-- Create module preload table
package.preload = package.preload or {}
// Check if the file has been modified since it was cached
if fileInfo.ModTime().After(entry.LastUsed) {
// File is newer than the cached version, remove from cache
c.modules.Delete(path)
return true
}
return false
-- Setup module loader registry
__ready_modules = {}
`)
}
// RefreshAll checks all cached modules and refreshes those that have changed
func (c *RequireCache) RefreshAll() int {
refreshed := 0
// PreloadAllModules fully preloads modules for maximum performance
func (l *NativeModuleLoader) PreloadAllModules(state *luajit.State) error {
l.mu.Lock()
defer l.mu.Unlock()
// No need to refresh if flag isn't set
if !c.needsRefresh.Load() {
return 0
// Reset registry
l.registry = NewModuleRegistry()
// Reset preloaded modules in Lua
if err := state.DoString(`
-- Reset module registry
__module_paths = {}
__module_results = {}
-- Clear non-core modules from package.loaded
local core_modules = {
string = true, table = true, math = true, os = true,
package = true, io = true, coroutine = true, debug = true, _G = true
}
// For maximum performance, just clear everything
c.Clear()
for name in pairs(package.loaded) do
if not core_modules[name] then
package.loaded[name] = nil
end
end
// Reset the needsRefresh flag
c.needsRefresh.Store(false)
c.lastRefresh = time.Now()
-- Reset preload table
package.preload = package.preload or {}
for name in pairs(package.preload) do
package.preload[name] = nil
end
return refreshed
}
// UpdateRequirePaths updates the require paths in the config without further allocations or re-registering the loader.
func UpdateRequirePaths(config *RequireConfig, scriptPath string) {
if scriptPath != "" {
config.ScriptDir = filepath.Dir(scriptPath)
}
}
// findAndCompileModule finds a module in allowed directories and compiles it to bytecode
func findAndCompileModule(
state *luajit.State,
cache *RequireCache,
config RequireConfig,
modName string,
) ([]byte, error) {
// Convert module name to relative path
modPath := strings.ReplaceAll(modName, ".", string(filepath.Separator))
// List of paths to check
paths := []string{}
// 1. Check adjacent to script directory first
if config.ScriptDir != "" {
paths = append(paths, filepath.Join(config.ScriptDir, modPath+".lua"))
-- Reset ready modules
__ready_modules = {}
`); err != nil {
return err
}
// 2. Check in lib directories
for _, libDir := range config.LibDirs {
if libDir != "" {
paths = append(paths, filepath.Join(libDir, modPath+".lua"))
// Set up paths for require
absPaths := []string{}
pathsMap := map[string]bool{}
// Add script directory (absolute path)
if l.config.ScriptDir != "" {
absPath, err := filepath.Abs(l.config.ScriptDir)
if err == nil && !pathsMap[absPath] {
absPaths = append(absPaths, filepath.Join(absPath, "?.lua"))
pathsMap[absPath] = true
}
}
// If the cache needs refresh, handle it immediately
if cache.needsRefresh.Load() {
cache.Clear() // Complete reset for max performance
cache.needsRefresh.Store(false)
cache.lastRefresh = time.Now()
}
// Try each path
for _, path := range paths {
// Clean the path to handle .. and such (security)
cleanPath := filepath.Clean(path)
// Check for path traversal (extra safety)
if !isSubPath(config.ScriptDir, cleanPath) {
isValidLib := false
for _, libDir := range config.LibDirs {
if isSubPath(libDir, cleanPath) {
isValidLib = true
break
}
}
if !isValidLib {
continue // Skip paths outside allowed directories
}
}
// Check if already in cache
if value, ok := cache.modules.Load(cleanPath); ok {
entry, ok := value.(ModuleEntry)
if !ok {
// Legacy format, use it anyway
return value.([]byte), nil
}
// Check file modification time if cache is marked for refresh
if cache.needsRefresh.Load() {
fileInfo, err := os.Stat(cleanPath)
// Remove from cache if file changed or doesn't exist
if err != nil || (entry.LastUsed.Before(fileInfo.ModTime())) {
cache.modules.Delete(cleanPath)
// Continue to recompile
} else {
// Update last used time and return cached bytecode
entry.LastUsed = time.Now()
cache.modules.Store(cleanPath, entry)
return entry.Bytecode, nil
}
} else {
// Update last used time and return cached bytecode
entry.LastUsed = time.Now()
cache.modules.Store(cleanPath, entry)
return entry.Bytecode, nil
}
}
// Check if file exists
_, err := os.Stat(cleanPath)
if os.IsNotExist(err) {
// Add lib directories (absolute paths)
for _, dir := range l.config.LibDirs {
if dir == "" {
continue
}
// Read and compile the file
content, err := os.ReadFile(cleanPath)
if err != nil {
return nil, err
absPath, err := filepath.Abs(dir)
if err == nil && !pathsMap[absPath] {
absPaths = append(absPaths, filepath.Join(absPath, "?.lua"))
pathsMap[absPath] = true
}
}
// Compile to bytecode
bytecode, err := state.CompileBytecode(string(content), cleanPath)
if err != nil {
return nil, err
// Set package.path
escapedPathStr := escapeLuaString(strings.Join(absPaths, ";"))
if err := state.DoString(`package.path = "` + escapedPathStr + `"`); err != nil {
return err
}
// Store in cache with current time
cache.modules.Store(cleanPath, ModuleEntry{
Bytecode: bytecode,
LastUsed: time.Now(),
// Process and preload all modules from lib directories
for _, dir := range l.config.LibDirs {
if dir == "" {
continue
}
absDir, err := filepath.Abs(dir)
if err != nil {
continue
}
// Find all Lua files
err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") {
return nil
}
// Get module name
relPath, err := filepath.Rel(absDir, path)
if err != nil || strings.HasPrefix(relPath, "..") {
return nil
}
modName := strings.TrimSuffix(relPath, ".lua")
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
// Register module path
l.registry.Register(path, modName)
// Register path in Lua
escapedPath := escapeLuaString(path)
escapedName := escapeLuaString(modName)
if err := state.DoString(`__module_paths["` + escapedName + `"] = "` + escapedPath + `"`); err != nil {
return nil
}
// Compile the module
content, err := os.ReadFile(path)
if err != nil {
return nil
}
// Precompile bytecode
bytecode, err := state.CompileBytecode(string(content), path)
if err != nil {
return nil
}
// Load bytecode
if err := state.LoadBytecode(bytecode, path); err != nil {
return nil
}
// Store in package.preload for fast loading
// We use string concat for efficiency (no string.format overhead)
luaCode := `
local modname = "` + escapedName + `"
local chunk = ...
package.preload[modname] = chunk
__ready_modules[modname] = true
`
if err := state.DoString(luaCode); err != nil {
state.Pop(1) // Remove chunk from stack
return nil
}
state.Pop(1) // Remove chunk from stack
return nil
})
return bytecode, nil
if err != nil {
return err
}
}
return nil, ErrModuleNotFound
// Install optimized require implementation
return state.DoString(`
-- Ultra-fast module loader
function __fast_require(env, modname)
-- 1. Check already loaded modules
if package.loaded[modname] then
return package.loaded[modname]
end
-- 2. Check preloaded chunks
if __ready_modules[modname] then
local loader = package.preload[modname]
if loader then
-- Set environment
setfenv(loader, env)
-- Execute and store result
local result = loader()
if result == nil then
result = true
end
-- Cache in shared registry
package.loaded[modname] = result
return result
end
end
-- 3. Direct file load as fallback
if __module_paths[modname] then
local path = __module_paths[modname]
local chunk, err = loadfile(path)
if chunk then
setfenv(chunk, env)
local result = chunk()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
end
-- 4. Full path search as last resort
local err_msgs = {}
for path in package.path:gmatch("[^;]+") do
local file_path = path:gsub("?", modname:gsub("%.", "/"))
local chunk, err = loadfile(file_path)
if chunk then
setfenv(chunk, env)
local result = chunk()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
table.insert(err_msgs, "no file '" .. file_path .. "'")
end
error("module '" .. modname .. "' not found:\n" .. table.concat(err_msgs, "\n"), 2)
end
-- Install require factory
function __setup_require(env)
-- Create highly optimized require with closure
env.require = function(modname)
return __fast_require(env, modname)
end
return env
end
`)
}
// isSubPath checks if path is contained within base directory
func isSubPath(baseDir, path string) bool {
if baseDir == "" {
return false
}
// Clean and normalize paths
baseDir = filepath.Clean(baseDir)
// NotifyFileChanged invalidates modules when files change
func (l *NativeModuleLoader) NotifyFileChanged(state *luajit.State, path string) bool {
path = filepath.Clean(path)
// Get relative path
rel, err := filepath.Rel(baseDir, path)
// Get module name from registry
modName, found := l.registry.GetModuleName(path)
if !found {
// Try to find by path for lib dirs
for _, libDir := range l.config.LibDirs {
absDir, err := filepath.Abs(libDir)
if err != nil {
continue
}
relPath, err := filepath.Rel(absDir, path)
if err != nil || strings.HasPrefix(relPath, "..") {
continue
}
if strings.HasSuffix(relPath, ".lua") {
modName = strings.TrimSuffix(relPath, ".lua")
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
found = true
break
}
}
}
if !found {
return false
}
// Check if path goes outside baseDir
return !strings.HasPrefix(rel, ".."+string(filepath.Separator)) && rel != ".."
// Update bytecode and invalidate caches
content, err := os.ReadFile(path)
if err != nil {
// File might have been deleted - just invalidate
escapedName := escapeLuaString(modName)
state.DoString(`
package.loaded["` + escapedName + `"] = nil
__ready_modules["` + escapedName + `"] = nil
if package.preload then
package.preload["` + escapedName + `"] = nil
end
`)
return true
}
// Recompile module
bytecode, err := state.CompileBytecode(string(content), path)
if err != nil {
// Invalid Lua - just invalidate
escapedName := escapeLuaString(modName)
state.DoString(`
package.loaded["` + escapedName + `"] = nil
__ready_modules["` + escapedName + `"] = nil
if package.preload then
package.preload["` + escapedName + `"] = nil
end
`)
return true
}
// Load bytecode
if err := state.LoadBytecode(bytecode, path); err != nil {
// Unable to load - just invalidate
escapedName := escapeLuaString(modName)
state.DoString(`
package.loaded["` + escapedName + `"] = nil
__ready_modules["` + escapedName + `"] = nil
if package.preload then
package.preload["` + escapedName + `"] = nil
end
`)
return true
}
// Update preload with new chunk
escapedName := escapeLuaString(modName)
luaCode := `
-- Update module in package.preload and clear loaded
package.loaded["` + escapedName + `"] = nil
package.preload["` + escapedName + `"] = ...
__ready_modules["` + escapedName + `"] = true
`
if err := state.DoString(luaCode); err != nil {
state.Pop(1) // Remove chunk from stack
return false
}
state.Pop(1) // Remove chunk from stack
return true
}
// ResetModules clears all non-core modules
func (l *NativeModuleLoader) ResetModules(state *luajit.State) error {
return state.DoString(`
local core_modules = {
string = true, table = true, math = true, os = true,
package = true, io = true, coroutine = true, debug = true, _G = true
}
for name in pairs(package.loaded) do
if not core_modules[name] then
package.loaded[name] = nil
end
end
`)
}

View File

@ -30,120 +30,134 @@ func (s *Sandbox) Setup(state *luajit.State) error {
return err
}
// Setup the sandbox creation logic with base environment reuse
// Create high-performance persistent environment
return state.DoString(`
-- Create the base environment once (static parts)
local __base_env = nil
-- Global shared environment (created once)
__env_system = __env_system or {
base_env = nil, -- Template environment
initialized = false, -- Initialization flag
env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size
}
-- Create function to initialize base environment
function __init_base_env()
if __base_env then return end
-- Initialize base environment once
if not __env_system.initialized then
-- Create base environment with all standard libraries
local base = {}
local env = {}
-- Add standard library modules (restricted)
env.string = string
env.table = table
env.math = math
env.os = {
-- Safe standard libraries
base.string = string
base.table = table
base.math = math
base.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
env.tonumber = tonumber
env.tostring = tostring
env.type = type
env.pairs = pairs
env.ipairs = ipairs
env.next = next
env.select = select
env.unpack = unpack
env.pcall = pcall
env.xpcall = xpcall
env.error = error
env.assert = assert
-- Add module loader
env.__go_load_module = __go_load_module
-- Basic functions
base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.unpack = unpack
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
-- Add custom modules from sandbox registry
if __sandbox_modules then
for name, module in pairs(__sandbox_modules) do
env[name] = module
end
end
-- Copy custom global functions
for k, v in pairs(_G) do
if (type(v) == "function" or type(v) == "table") and
k ~= "__sandbox_modules" and
k ~= "__base_env" and
k ~= "__init_base_env" and
k ~= "__create_sandbox_env" and
k ~= "__run_sandboxed" and
k ~= "__setup_secure_require" and
k ~= "__go_load_module" and
k ~= "string" and k ~= "table" and k ~= "math" and
k ~= "os" and k ~= "io" and k ~= "debug" and
k ~= "package" and k ~= "bit" and k ~= "jit" and
k ~= "coroutine" and k ~= "_G" and k ~= "_VERSION" then
env[k] = v
end
end
__base_env = env
end
-- Create function that builds sandbox from base env
function __create_sandbox_env(ctx)
-- Initialize base env if needed
__init_base_env()
-- Create new environment using base as prototype
local env = {}
-- Copy from base environment
for k, v in pairs(__base_env) do
env[k] = v
end
-- Add isolated package.loaded table
env.package = {
loaded = {}
-- Package system is shared for performance
base.package = {
loaded = package.loaded,
path = package.path,
preload = package.preload
}
-- Add context if provided
-- Add registered custom modules
if __sandbox_modules then
for name, mod in pairs(__sandbox_modules) do
base[name] = mod
end
end
-- Store base environment
__env_system.base_env = base
__env_system.initialized = true
end
-- Fast environment creation with pre-allocation
function __get_sandbox_env(ctx)
local env
-- Try to reuse from pool
if __env_system.pool_size > 0 then
env = table.remove(__env_system.env_pool)
__env_system.pool_size = __env_system.pool_size - 1
-- Clear any previous context
env.ctx = ctx or nil
else
-- Create new environment with metatable inheritance
env = setmetatable({}, {
__index = __env_system.base_env
})
-- Set context if provided
if ctx then
env.ctx = ctx
end
-- Setup require function
env = __setup_secure_require(env)
-- Create metatable for isolation
local mt = {
__index = function(t, k)
return rawget(env, k)
end,
__newindex = function(t, k, v)
rawset(env, k, v)
-- Install the fast require implementation
env.require = function(modname)
return __fast_require(env, modname)
end
end
}
setmetatable(env, mt)
return env
end
-- Function to run code in sandbox
function __run_sandboxed(bytecode, ctx)
-- Create fresh environment for this request
local env = __create_sandbox_env(ctx)
-- Return environment to pool for reuse
function __recycle_env(env)
-- Only recycle if pool isn't full
if __env_system.pool_size < __env_system.max_pool_size then
-- Clear context reference to avoid memory leaks
env.ctx = nil
-- Set environment and execute
setfenv(bytecode, env)
return bytecode()
-- Add to pool
table.insert(__env_system.env_pool, env)
__env_system.pool_size = __env_system.pool_size + 1
end
end
-- Hyper-optimized sandbox executor
function __execute_sandbox(bytecode, ctx)
-- Get environment (from pool if available)
local env = __get_sandbox_env(ctx)
-- Set environment for bytecode
setfenv(bytecode, env)
-- Execute with protected call
local success, result = pcall(bytecode)
-- Recycle environment for future use
__recycle_env(env)
-- Process result
if not success then
error(result, 0)
end
return result
end
-- Run minimal GC for overall health
collectgarbage("step", 10)
`)
}
@ -191,11 +205,14 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
// Create context table if provided
if len(ctx) > 0 {
state.NewTable()
// Preallocate table with appropriate size
state.CreateTable(0, len(ctx))
// Add context entries
for k, v := range ctx {
state.PushString(k)
if err := state.PushValue(v); err != nil {
state.Pop(3)
state.Pop(2)
return nil, err
}
state.SetTable(-3)
@ -204,8 +221,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
state.PushNil() // No context
}
// Get sandbox function
state.GetGlobal("__run_sandboxed")
// Get optimized sandbox executor
state.GetGlobal("__execute_sandbox")
// Setup call with correct argument order
state.PushCopy(-3) // Copy bytecode function
@ -215,7 +232,7 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
state.Remove(-5) // Remove original bytecode
state.Remove(-4) // Remove original context
// Call sandbox function
// Call optimized sandbox executor
if err := state.Call(2, 1); err != nil {
return nil, err
}

View File

@ -10,17 +10,17 @@ func WatchLuaModules(luaRunner *runner.LuaRunner, libDirs []string, log *logger.
watchers := make([]*Watcher, 0, len(libDirs))
for _, dir := range libDirs {
// Create a directory-specific callback that only does minimal work
// Create a directory-specific callback that identifies changed files
dirCopy := dir // Capture for closure
callback := func() error {
log.Debug("Detected changes in Lua module directory: %s", dirCopy)
// Completely reset the cache to match fresh-start performance
luaRunner.RequireCache().Clear()
// Force reset of Lua's module registry
luaRunner.ResetPackageLoaded()
// Instead of clearing everything, use directory-level smart refresh
// This will scan lib directory and refresh all modified Lua modules
if err := luaRunner.ReloadAllModules(); err != nil {
log.Warning("Error reloading modules: %v", err)
}
return nil
}