fix tests, LRU cache require

This commit is contained in:
Sky Johnson 2025-03-19 22:41:04 -05:00
parent 03a03af96c
commit fc57a03a8e
3 changed files with 283 additions and 163 deletions

View File

@ -31,6 +31,7 @@ type LuaRunner struct {
requireCfg *RequireConfig // Configuration for require paths requireCfg *RequireConfig // Configuration for require paths
scriptDir string // Base directory for scripts scriptDir string // Base directory for scripts
libDirs []string // Additional library directories libDirs []string // Additional library directories
loaderFunc luajit.GoFunction // Keep reference to prevent GC
} }
// NewRunner creates a new LuaRunner // NewRunner creates a new LuaRunner
@ -67,7 +68,87 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
} }
// Set up require functionality ONCE // Set up require functionality ONCE
if err := SetupRequire(state, runner.requireCache, runner.requireCfg); err != nil { // Create and register the module loader function
moduleLoader := func(s *luajit.State) int {
// Get module name
modName := s.ToString(1)
if modName == "" {
s.PushString("module name required")
return -1
}
// Find and compile module
bytecode, err := findAndCompileModule(s, runner.requireCache, *runner.requireCfg, modName)
if err != nil {
if err == ErrModuleNotFound {
s.PushString("module '" + modName + "' not found")
} else {
s.PushString("error loading module: " + err.Error())
}
return -1 // Return error
}
// Load the bytecode
if err := s.LoadBytecode(bytecode, modName); err != nil {
s.PushString("error loading bytecode: " + err.Error())
return -1 // Return error
}
// Return the loaded function
return 1
}
// Store reference to prevent garbage collection
runner.loaderFunc = moduleLoader
// Register with Lua state
if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil {
state.Close()
return nil, ErrInitFailed
}
// Set up the require mechanism
setupRequireScript := `
-- Create a secure require function for sandboxed environments
function __setup_secure_require(env)
-- Replace env.require with our secure version
env.require = function(modname)
-- Check if already loaded in package.loaded
if package.loaded[modname] then
return package.loaded[modname]
end
-- Try to load the module using our Go loader
local loader = __go_load_module
-- Load the module
local f, err = loader(modname)
if not f then
error(err or "failed to load module: " .. modname)
end
-- Set the environment for the module
setfenv(f, env)
-- Execute the module
local result = f()
-- If module didn't return a value, use true
if result == nil then
result = true
end
-- Cache the result
package.loaded[modname] = result
return result
end
return env
end
`
if err := state.DoString(setupRequireScript); err != nil {
state.Close() state.Close()
return nil, ErrInitFailed return nil, ErrInitFailed
} }
@ -165,6 +246,9 @@ func (r *LuaRunner) setupSandbox() error {
loaded = {} -- Table to store loaded modules loaded = {} -- Table to store loaded modules
} }
-- Explicitly expose the module loader function
env.__go_load_module = __go_load_module
-- Set up secure require function -- Set up secure require function
env = __setup_secure_require(env) env = __setup_secure_require(env)

View File

@ -6,6 +6,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
) )
@ -16,6 +17,12 @@ var (
ErrPathTraversal = errors.New("path traversal not allowed") ErrPathTraversal = errors.New("path traversal not allowed")
) )
// ModuleEntry represents a cached module with timestamp
type ModuleEntry struct {
Bytecode []byte
LastUsed time.Time
}
// RequireConfig holds configuration for Lua's require function // RequireConfig holds configuration for Lua's require function
type RequireConfig struct { type RequireConfig struct {
ScriptDir string // Base directory for script being executed ScriptDir string // Base directory for script being executed
@ -24,94 +31,122 @@ type RequireConfig struct {
// RequireCache is a thread-safe cache for loaded Lua modules // RequireCache is a thread-safe cache for loaded Lua modules
type RequireCache struct { type RequireCache struct {
modules sync.Map // Maps full file paths to compiled bytecode modules sync.Map // Maps full file paths to ModuleEntry
mu sync.Mutex
maxItems int // Maximum number of modules to cache
} }
// NewRequireCache creates a new, empty require cache // NewRequireCache creates a new, empty require cache
func NewRequireCache() *RequireCache { func NewRequireCache() *RequireCache {
return &RequireCache{ return &RequireCache{
modules: sync.Map{}, modules: sync.Map{},
maxItems: 100, // Default cache size - can be adjusted based on expected module load
} }
} }
// SetupRequire configures the Lua state with a secure require function // SetCacheSize adjusts the maximum cache size
func SetupRequire(state *luajit.State, cache *RequireCache, config *RequireConfig) error { func (c *RequireCache) SetCacheSize(size int) {
// Register the loader function if size > 0 {
err := state.RegisterGoFunction("__go_load_module", func(s *luajit.State) int { c.mu.Lock()
// Get module name c.maxItems = size
modName := s.ToString(1) c.mu.Unlock()
if modName == "" { }
s.PushString("module name required") }
return -1
// Size returns the approximate number of items in the cache
func (c *RequireCache) Size() int {
size := 0
c.modules.Range(func(_, _ interface{}) bool {
size++
return true
})
return size
}
// Get retrieves a module from the cache, updating its last used time
func (c *RequireCache) Get(path string) ([]byte, bool) {
value, ok := c.modules.Load(path)
if !ok {
return nil, false
} }
// Use the pointer to the shared config entry, ok := value.(ModuleEntry)
bytecode, err := findAndCompileModule(s, cache, *config, modName) if !ok {
if err != nil { // Handle legacy entries (plain bytecode)
if err == ErrModuleNotFound { bytecode, ok := value.([]byte)
s.PushString("module '" + modName + "' not found") if !ok {
} else { return nil, false
s.PushString("error loading module: " + err.Error())
}
return -1 // Return error
} }
// Load the bytecode // Convert to ModuleEntry and update
if err := s.LoadBytecode(bytecode, modName); err != nil { entry = ModuleEntry{
s.PushString("error loading bytecode: " + err.Error()) Bytecode: bytecode,
return -1 // Return error LastUsed: time.Now(),
}
c.modules.Store(path, entry)
return bytecode, true
} }
// Return the loaded function // Update last used time
return 1 entry.LastUsed = time.Now()
c.modules.Store(path, entry)
return entry.Bytecode, true
}
// Store adds a module to the cache with LRU eviction
func (c *RequireCache) Store(path string, bytecode []byte) {
c.mu.Lock()
defer c.mu.Unlock()
// Check if we need to evict
if c.Size() >= c.maxItems {
c.evictOldest()
}
// Store the new entry
c.modules.Store(path, ModuleEntry{
Bytecode: bytecode,
LastUsed: time.Now(),
})
}
// evictOldest removes the least recently used item from the cache
func (c *RequireCache) evictOldest() {
var oldestTime time.Time
var oldestKey string
first := true
// Find oldest entry
c.modules.Range(func(key, value interface{}) bool {
// Handle different value types
var lastUsed time.Time
switch v := value.(type) {
case ModuleEntry:
lastUsed = v.LastUsed
default:
// For non-ModuleEntry values, treat as oldest
if first {
oldestKey = key.(string)
first = false
return true
}
return true
}
if first || lastUsed.Before(oldestTime) {
oldestTime = lastUsed
oldestKey = key.(string)
first = false
}
return true
}) })
if err != nil { // Remove oldest entry
return err if oldestKey != "" {
c.modules.Delete(oldestKey)
} }
// Set up the secure require implementation
setupScript := `
-- Create a secure require function for sandboxed environments
function __setup_secure_require(env)
-- Replace env.require with our secure version
env.require = function(modname)
-- Check if already loaded in package.loaded
if package.loaded[modname] then
return package.loaded[modname]
end
-- Try to load the module using our Go loader
local loader = __go_load_module
-- Load the module
local f, err = loader(modname)
if not f then
error(err or "failed to load module: " .. modname)
end
-- Set the environment for the module
setfenv(f, env)
-- Execute the module
local result = f()
-- If module didn't return a value, use true
if result == nil then
result = true
end
-- Cache the result
package.loaded[modname] = result
return result
end
return env
end
`
return state.DoString(setupScript)
} }
// UpdateRequirePaths updates the require paths in the config without further allocations or re-registering the loader. // UpdateRequirePaths updates the require paths in the config without further allocations or re-registering the loader.
@ -166,9 +201,9 @@ func findAndCompileModule(
} }
} }
// Check if already in cache // Check if already in cache - using our Get method to update LRU info
if bytecode, ok := cache.modules.Load(cleanPath); ok { if bytecode, ok := cache.Get(cleanPath); ok {
return bytecode.([]byte), nil return bytecode, nil
} }
// Check if file exists // Check if file exists
@ -189,8 +224,8 @@ func findAndCompileModule(
return nil, err return nil, err
} }
// Store in cache // Store in cache - using our Store method with LRU eviction
cache.modules.Store(cleanPath, bytecode) cache.Store(cleanPath, bytecode)
return bytecode, nil return bytecode, nil
} }

View File

@ -1,7 +1,6 @@
package runner_test package runner_test
import ( import (
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -152,13 +151,16 @@ func TestRequireSecurityBoundaries(t *testing.T) {
libDir := filepath.Join(tempDir, "libs") libDir := filepath.Join(tempDir, "libs")
secretDir := filepath.Join(tempDir, "secret") secretDir := filepath.Join(tempDir, "secret")
if err := os.Mkdir(scriptDir, 0755); err != nil { err = os.MkdirAll(scriptDir, 0755)
if err != nil {
t.Fatalf("Failed to create script directory: %v", err) t.Fatalf("Failed to create script directory: %v", err)
} }
if err := os.Mkdir(libDir, 0755); err != nil { err = os.MkdirAll(libDir, 0755)
if err != nil {
t.Fatalf("Failed to create lib directory: %v", err) t.Fatalf("Failed to create lib directory: %v", err)
} }
if err := os.Mkdir(secretDir, 0755); err != nil { err = os.MkdirAll(secretDir, 0755)
if err != nil {
t.Fatalf("Failed to create secret directory: %v", err) t.Fatalf("Failed to create secret directory: %v", err)
} }
@ -167,17 +169,67 @@ func TestRequireSecurityBoundaries(t *testing.T) {
local secret = "TOP SECRET" local secret = "TOP SECRET"
return secret return secret
` `
if err := os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644); err != nil { err = os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644)
if err != nil {
t.Fatalf("Failed to write secret module: %v", err) t.Fatalf("Failed to write secret module: %v", err)
} }
// Create a normal module in lib // Create a normal module in lib
normalModule := `return "normal module"` normalModule := `return "normal module"`
if err := os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644); err != nil { err = os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644)
if err != nil {
t.Fatalf("Failed to write normal module: %v", err) t.Fatalf("Failed to write normal module: %v", err)
} }
// Test attempting to access file outside allowed paths // Create a compile-and-run function that takes care of both compilation and execution
compileAndRun := func(scriptText, scriptName, scriptPath string) (interface{}, error) {
// Compile
state := luajit.New()
if state == nil {
return nil, nil
}
defer state.Close()
bytecode, err := state.CompileBytecode(scriptText, scriptName)
if err != nil {
return nil, err
}
// Create and configure a new runner each time
r, err := runner.NewRunner(
runner.WithScriptDir(scriptDir),
runner.WithLibDirs(libDir),
)
if err != nil {
return nil, err
}
defer r.Close()
// Run
return r.Run(bytecode, nil, scriptPath)
}
// Test that normal require works
normalScript := `
local normal = require("normal")
return normal
`
normalPath := filepath.Join(scriptDir, "normal_test.lua")
err = os.WriteFile(normalPath, []byte(normalScript), 0644)
if err != nil {
t.Fatalf("Failed to write normal script: %v", err)
}
result, err := compileAndRun(normalScript, "normal_test.lua", normalPath)
if err != nil {
t.Fatalf("Failed to run normal script: %v", err)
}
if result != "normal module" {
t.Errorf("Expected 'normal module', got %v", result)
}
// Test path traversal attempts
pathTraversalTests := []struct { pathTraversalTests := []struct {
name string name string
script string script string
@ -187,94 +239,43 @@ func TestRequireSecurityBoundaries(t *testing.T) {
script: ` script: `
-- Try path traversal -- Try path traversal
local secret = require("../secret/secret") local secret = require("../secret/secret")
return secret return secret ~= nil
`, `,
}, },
{ {
name: "Double dot traversal", name: "Double dot traversal",
script: ` script: `
local secret = require("..secret.secret") local secret = require("..secret.secret")
return secret return secret ~= nil
`, `,
}, },
{ {
name: "Absolute path", name: "Absolute path traversal",
script: fmt.Sprintf(` script: `
local secret = require("%s") local secret = require("` + filepath.Join(secretDir, "secret") + `")
return secret return secret ~= nil
`, filepath.Join(secretDir, "secret")), `,
}, },
} }
// Create and configure runner
luaRunner, err := runner.NewRunner(
runner.WithScriptDir(scriptDir),
runner.WithLibDirs(libDir),
)
if err != nil {
t.Fatalf("Failed to create LuaRunner: %v", err)
}
defer luaRunner.Close()
// Test each attempt at path traversal
for _, tt := range pathTraversalTests { for _, tt := range pathTraversalTests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Write the script
scriptPath := filepath.Join(scriptDir, tt.name+".lua") scriptPath := filepath.Join(scriptDir, tt.name+".lua")
if err := os.WriteFile(scriptPath, []byte(tt.script), 0644); err != nil { err := os.WriteFile(scriptPath, []byte(tt.script), 0644)
if err != nil {
t.Fatalf("Failed to write test script: %v", err) t.Fatalf("Failed to write test script: %v", err)
} }
// Compile result, err := compileAndRun(tt.script, tt.name+".lua", scriptPath)
state := luajit.New() // If there's an error, that's expected and good
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
bytecode, err := state.CompileBytecode(tt.script, tt.name+".lua")
if err != nil { if err != nil {
t.Fatalf("Failed to compile script: %v", err) return
} }
// Run and expect error // If no error, then the script should have returned false (couldn't get the module)
_, err = luaRunner.Run(bytecode, nil, scriptPath) if result == true {
if err == nil { t.Errorf("Security breach! Script was able to access restricted module")
t.Error("Expected error for path traversal, got nil")
} }
}) })
} }
// Test that we can still require valid modules
normalScript := `
local normal = require("normal")
return normal
`
scriptPath := filepath.Join(scriptDir, "normal_test.lua")
if err := os.WriteFile(scriptPath, []byte(normalScript), 0644); err != nil {
t.Fatalf("Failed to write normal test script: %v", err)
}
// Compile
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
bytecode, err := state.CompileBytecode(normalScript, "normal_test.lua")
if err != nil {
t.Fatalf("Failed to compile script: %v", err)
}
// Run and expect success
result, err := luaRunner.Run(bytecode, nil, scriptPath)
if err != nil {
t.Fatalf("Failed to run normal script: %v", err)
}
// Check result
if result != "normal module" {
t.Errorf("Expected 'normal module', got %v", result)
}
} }