diff --git a/core/runner/luarunner.go b/core/runner/luarunner.go index d28308d..2abdf96 100644 --- a/core/runner/luarunner.go +++ b/core/runner/luarunner.go @@ -20,17 +20,18 @@ type StateInitFunc func(*luajit.State) error // LuaRunner runs Lua scripts using a single Lua state type LuaRunner struct { - state *luajit.State // The Lua state - jobQueue chan job // Channel for incoming jobs - isRunning atomic.Bool // Flag indicating if the runner is active - mu sync.RWMutex // Mutex for thread safety - wg sync.WaitGroup // WaitGroup for clean shutdown - initFunc StateInitFunc // Optional function to initialize Lua state - bufferSize int // Size of the job queue buffer - requireCache *RequireCache // Cache for required modules - requireCfg *RequireConfig // Configuration for require paths - scriptDir string // Base directory for scripts - libDirs []string // Additional library directories + state *luajit.State // The Lua state + jobQueue chan job // Channel for incoming jobs + isRunning atomic.Bool // Flag indicating if the runner is active + mu sync.RWMutex // Mutex for thread safety + wg sync.WaitGroup // WaitGroup for clean shutdown + initFunc StateInitFunc // Optional function to initialize Lua state + bufferSize int // Size of the job queue buffer + requireCache *RequireCache // Cache for required modules + requireCfg *RequireConfig // Configuration for require paths + scriptDir string // Base directory for scripts + libDirs []string // Additional library directories + loaderFunc luajit.GoFunction // Keep reference to prevent GC } // NewRunner creates a new LuaRunner @@ -67,7 +68,87 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) { } // Set up require functionality ONCE - if err := SetupRequire(state, runner.requireCache, runner.requireCfg); err != nil { + // Create and register the module loader function + moduleLoader := func(s *luajit.State) int { + // Get module name + modName := s.ToString(1) + if modName == "" { + s.PushString("module name required") + return -1 + } + + // Find and compile module + bytecode, err := findAndCompileModule(s, runner.requireCache, *runner.requireCfg, modName) + if err != nil { + if err == ErrModuleNotFound { + s.PushString("module '" + modName + "' not found") + } else { + s.PushString("error loading module: " + err.Error()) + } + return -1 // Return error + } + + // Load the bytecode + if err := s.LoadBytecode(bytecode, modName); err != nil { + s.PushString("error loading bytecode: " + err.Error()) + return -1 // Return error + } + + // Return the loaded function + return 1 + } + + // Store reference to prevent garbage collection + runner.loaderFunc = moduleLoader + + // Register with Lua state + if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil { + state.Close() + return nil, ErrInitFailed + } + + // Set up the require mechanism + setupRequireScript := ` + -- Create a secure require function for sandboxed environments + function __setup_secure_require(env) + -- Replace env.require with our secure version + env.require = function(modname) + -- Check if already loaded in package.loaded + if package.loaded[modname] then + return package.loaded[modname] + end + + -- Try to load the module using our Go loader + local loader = __go_load_module + + -- Load the module + local f, err = loader(modname) + if not f then + error(err or "failed to load module: " .. modname) + end + + -- Set the environment for the module + setfenv(f, env) + + -- Execute the module + local result = f() + + -- If module didn't return a value, use true + if result == nil then + result = true + end + + -- Cache the result + package.loaded[modname] = result + + return result + end + + return env + end + ` + + if err := state.DoString(setupRequireScript); err != nil { state.Close() return nil, ErrInitFailed } @@ -165,6 +246,9 @@ func (r *LuaRunner) setupSandbox() error { loaded = {} -- Table to store loaded modules } + -- Explicitly expose the module loader function + env.__go_load_module = __go_load_module + -- Set up secure require function env = __setup_secure_require(env) diff --git a/core/runner/require.go b/core/runner/require.go index 9a9da55..f9599d3 100644 --- a/core/runner/require.go +++ b/core/runner/require.go @@ -6,6 +6,7 @@ import ( "path/filepath" "strings" "sync" + "time" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) @@ -16,6 +17,12 @@ var ( ErrPathTraversal = errors.New("path traversal not allowed") ) +// ModuleEntry represents a cached module with timestamp +type ModuleEntry struct { + Bytecode []byte + LastUsed time.Time +} + // RequireConfig holds configuration for Lua's require function type RequireConfig struct { ScriptDir string // Base directory for script being executed @@ -24,94 +31,122 @@ type RequireConfig struct { // RequireCache is a thread-safe cache for loaded Lua modules type RequireCache struct { - modules sync.Map // Maps full file paths to compiled bytecode + modules sync.Map // Maps full file paths to ModuleEntry + mu sync.Mutex + maxItems int // Maximum number of modules to cache } // NewRequireCache creates a new, empty require cache func NewRequireCache() *RequireCache { return &RequireCache{ - modules: sync.Map{}, + modules: sync.Map{}, + maxItems: 100, // Default cache size - can be adjusted based on expected module load } } -// SetupRequire configures the Lua state with a secure require function -func SetupRequire(state *luajit.State, cache *RequireCache, config *RequireConfig) error { - // Register the loader function - err := state.RegisterGoFunction("__go_load_module", func(s *luajit.State) int { - // Get module name - modName := s.ToString(1) - if modName == "" { - s.PushString("module name required") - return -1 - } +// SetCacheSize adjusts the maximum cache size +func (c *RequireCache) SetCacheSize(size int) { + if size > 0 { + c.mu.Lock() + c.maxItems = size + c.mu.Unlock() + } +} - // Use the pointer to the shared config - bytecode, err := findAndCompileModule(s, cache, *config, modName) - if err != nil { - if err == ErrModuleNotFound { - s.PushString("module '" + modName + "' not found") - } else { - s.PushString("error loading module: " + err.Error()) - } - return -1 // Return error - } - - // Load the bytecode - if err := s.LoadBytecode(bytecode, modName); err != nil { - s.PushString("error loading bytecode: " + err.Error()) - return -1 // Return error - } - - // Return the loaded function - return 1 +// Size returns the approximate number of items in the cache +func (c *RequireCache) Size() int { + size := 0 + c.modules.Range(func(_, _ interface{}) bool { + size++ + return true }) + return size +} - if err != nil { - return err +// Get retrieves a module from the cache, updating its last used time +func (c *RequireCache) Get(path string) ([]byte, bool) { + value, ok := c.modules.Load(path) + if !ok { + return nil, false } - // Set up the secure require implementation - setupScript := ` - -- Create a secure require function for sandboxed environments - function __setup_secure_require(env) - -- Replace env.require with our secure version - env.require = function(modname) - -- Check if already loaded in package.loaded - if package.loaded[modname] then - return package.loaded[modname] - end + entry, ok := value.(ModuleEntry) + if !ok { + // Handle legacy entries (plain bytecode) + bytecode, ok := value.([]byte) + if !ok { + return nil, false + } - -- Try to load the module using our Go loader - local loader = __go_load_module - - -- Load the module - local f, err = loader(modname) - if not f then - error(err or "failed to load module: " .. modname) - end - - -- Set the environment for the module - setfenv(f, env) - - -- Execute the module - local result = f() - - -- If module didn't return a value, use true - if result == nil then - result = true - end - - -- Cache the result - package.loaded[modname] = result - - return result - end - - return env - end - ` + // Convert to ModuleEntry and update + entry = ModuleEntry{ + Bytecode: bytecode, + LastUsed: time.Now(), + } + c.modules.Store(path, entry) + return bytecode, true + } - return state.DoString(setupScript) + // Update last used time + entry.LastUsed = time.Now() + c.modules.Store(path, entry) + + return entry.Bytecode, true +} + +// Store adds a module to the cache with LRU eviction +func (c *RequireCache) Store(path string, bytecode []byte) { + c.mu.Lock() + defer c.mu.Unlock() + + // Check if we need to evict + if c.Size() >= c.maxItems { + c.evictOldest() + } + + // Store the new entry + c.modules.Store(path, ModuleEntry{ + Bytecode: bytecode, + LastUsed: time.Now(), + }) +} + +// evictOldest removes the least recently used item from the cache +func (c *RequireCache) evictOldest() { + var oldestTime time.Time + var oldestKey string + first := true + + // Find oldest entry + c.modules.Range(func(key, value interface{}) bool { + // Handle different value types + var lastUsed time.Time + + switch v := value.(type) { + case ModuleEntry: + lastUsed = v.LastUsed + default: + // For non-ModuleEntry values, treat as oldest + if first { + oldestKey = key.(string) + first = false + return true + } + return true + } + + if first || lastUsed.Before(oldestTime) { + oldestTime = lastUsed + oldestKey = key.(string) + first = false + } + return true + }) + + // Remove oldest entry + if oldestKey != "" { + c.modules.Delete(oldestKey) + } } // UpdateRequirePaths updates the require paths in the config without further allocations or re-registering the loader. @@ -166,9 +201,9 @@ func findAndCompileModule( } } - // Check if already in cache - if bytecode, ok := cache.modules.Load(cleanPath); ok { - return bytecode.([]byte), nil + // Check if already in cache - using our Get method to update LRU info + if bytecode, ok := cache.Get(cleanPath); ok { + return bytecode, nil } // Check if file exists @@ -189,8 +224,8 @@ func findAndCompileModule( return nil, err } - // Store in cache - cache.modules.Store(cleanPath, bytecode) + // Store in cache - using our Store method with LRU eviction + cache.Store(cleanPath, bytecode) return bytecode, nil } diff --git a/core/runner/require_test.go b/core/runner/require_test.go index 0663c65..3cf9fdd 100644 --- a/core/runner/require_test.go +++ b/core/runner/require_test.go @@ -1,7 +1,6 @@ package runner_test import ( - "fmt" "os" "path/filepath" "testing" @@ -152,13 +151,16 @@ func TestRequireSecurityBoundaries(t *testing.T) { libDir := filepath.Join(tempDir, "libs") secretDir := filepath.Join(tempDir, "secret") - if err := os.Mkdir(scriptDir, 0755); err != nil { + err = os.MkdirAll(scriptDir, 0755) + if err != nil { t.Fatalf("Failed to create script directory: %v", err) } - if err := os.Mkdir(libDir, 0755); err != nil { + err = os.MkdirAll(libDir, 0755) + if err != nil { t.Fatalf("Failed to create lib directory: %v", err) } - if err := os.Mkdir(secretDir, 0755); err != nil { + err = os.MkdirAll(secretDir, 0755) + if err != nil { t.Fatalf("Failed to create secret directory: %v", err) } @@ -167,17 +169,67 @@ func TestRequireSecurityBoundaries(t *testing.T) { local secret = "TOP SECRET" return secret ` - if err := os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644); err != nil { + err = os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644) + if err != nil { t.Fatalf("Failed to write secret module: %v", err) } // Create a normal module in lib normalModule := `return "normal module"` - if err := os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644); err != nil { + err = os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644) + if err != nil { t.Fatalf("Failed to write normal module: %v", err) } - // Test attempting to access file outside allowed paths + // Create a compile-and-run function that takes care of both compilation and execution + compileAndRun := func(scriptText, scriptName, scriptPath string) (interface{}, error) { + // Compile + state := luajit.New() + if state == nil { + return nil, nil + } + defer state.Close() + + bytecode, err := state.CompileBytecode(scriptText, scriptName) + if err != nil { + return nil, err + } + + // Create and configure a new runner each time + r, err := runner.NewRunner( + runner.WithScriptDir(scriptDir), + runner.WithLibDirs(libDir), + ) + if err != nil { + return nil, err + } + defer r.Close() + + // Run + return r.Run(bytecode, nil, scriptPath) + } + + // Test that normal require works + normalScript := ` + local normal = require("normal") + return normal + ` + normalPath := filepath.Join(scriptDir, "normal_test.lua") + err = os.WriteFile(normalPath, []byte(normalScript), 0644) + if err != nil { + t.Fatalf("Failed to write normal script: %v", err) + } + + result, err := compileAndRun(normalScript, "normal_test.lua", normalPath) + if err != nil { + t.Fatalf("Failed to run normal script: %v", err) + } + + if result != "normal module" { + t.Errorf("Expected 'normal module', got %v", result) + } + + // Test path traversal attempts pathTraversalTests := []struct { name string script string @@ -187,94 +239,43 @@ func TestRequireSecurityBoundaries(t *testing.T) { script: ` -- Try path traversal local secret = require("../secret/secret") - return secret + return secret ~= nil `, }, { name: "Double dot traversal", script: ` local secret = require("..secret.secret") - return secret + return secret ~= nil `, }, { - name: "Absolute path", - script: fmt.Sprintf(` - local secret = require("%s") - return secret - `, filepath.Join(secretDir, "secret")), + name: "Absolute path traversal", + script: ` + local secret = require("` + filepath.Join(secretDir, "secret") + `") + return secret ~= nil + `, }, } - // Create and configure runner - luaRunner, err := runner.NewRunner( - runner.WithScriptDir(scriptDir), - runner.WithLibDirs(libDir), - ) - if err != nil { - t.Fatalf("Failed to create LuaRunner: %v", err) - } - defer luaRunner.Close() - - // Test each attempt at path traversal for _, tt := range pathTraversalTests { t.Run(tt.name, func(t *testing.T) { - // Write the script scriptPath := filepath.Join(scriptDir, tt.name+".lua") - if err := os.WriteFile(scriptPath, []byte(tt.script), 0644); err != nil { + err := os.WriteFile(scriptPath, []byte(tt.script), 0644) + if err != nil { t.Fatalf("Failed to write test script: %v", err) } - // Compile - state := luajit.New() - if state == nil { - t.Fatal("Failed to create Lua state") - } - defer state.Close() - - bytecode, err := state.CompileBytecode(tt.script, tt.name+".lua") + result, err := compileAndRun(tt.script, tt.name+".lua", scriptPath) + // If there's an error, that's expected and good if err != nil { - t.Fatalf("Failed to compile script: %v", err) + return } - // Run and expect error - _, err = luaRunner.Run(bytecode, nil, scriptPath) - if err == nil { - t.Error("Expected error for path traversal, got nil") + // If no error, then the script should have returned false (couldn't get the module) + if result == true { + t.Errorf("Security breach! Script was able to access restricted module") } }) } - - // Test that we can still require valid modules - normalScript := ` - local normal = require("normal") - return normal - ` - scriptPath := filepath.Join(scriptDir, "normal_test.lua") - if err := os.WriteFile(scriptPath, []byte(normalScript), 0644); err != nil { - t.Fatalf("Failed to write normal test script: %v", err) - } - - // Compile - state := luajit.New() - if state == nil { - t.Fatal("Failed to create Lua state") - } - defer state.Close() - - bytecode, err := state.CompileBytecode(normalScript, "normal_test.lua") - if err != nil { - t.Fatalf("Failed to compile script: %v", err) - } - - // Run and expect success - result, err := luaRunner.Run(bytecode, nil, scriptPath) - if err != nil { - t.Fatalf("Failed to run normal script: %v", err) - } - - // Check result - if result != "normal module" { - t.Errorf("Expected 'normal module', got %v", result) - } }