fix tests, LRU cache require
This commit is contained in:
parent
03a03af96c
commit
fc57a03a8e
|
@ -20,17 +20,18 @@ type StateInitFunc func(*luajit.State) error
|
||||||
|
|
||||||
// LuaRunner runs Lua scripts using a single Lua state
|
// LuaRunner runs Lua scripts using a single Lua state
|
||||||
type LuaRunner struct {
|
type LuaRunner struct {
|
||||||
state *luajit.State // The Lua state
|
state *luajit.State // The Lua state
|
||||||
jobQueue chan job // Channel for incoming jobs
|
jobQueue chan job // Channel for incoming jobs
|
||||||
isRunning atomic.Bool // Flag indicating if the runner is active
|
isRunning atomic.Bool // Flag indicating if the runner is active
|
||||||
mu sync.RWMutex // Mutex for thread safety
|
mu sync.RWMutex // Mutex for thread safety
|
||||||
wg sync.WaitGroup // WaitGroup for clean shutdown
|
wg sync.WaitGroup // WaitGroup for clean shutdown
|
||||||
initFunc StateInitFunc // Optional function to initialize Lua state
|
initFunc StateInitFunc // Optional function to initialize Lua state
|
||||||
bufferSize int // Size of the job queue buffer
|
bufferSize int // Size of the job queue buffer
|
||||||
requireCache *RequireCache // Cache for required modules
|
requireCache *RequireCache // Cache for required modules
|
||||||
requireCfg *RequireConfig // Configuration for require paths
|
requireCfg *RequireConfig // Configuration for require paths
|
||||||
scriptDir string // Base directory for scripts
|
scriptDir string // Base directory for scripts
|
||||||
libDirs []string // Additional library directories
|
libDirs []string // Additional library directories
|
||||||
|
loaderFunc luajit.GoFunction // Keep reference to prevent GC
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRunner creates a new LuaRunner
|
// NewRunner creates a new LuaRunner
|
||||||
|
@ -67,7 +68,87 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up require functionality ONCE
|
// Set up require functionality ONCE
|
||||||
if err := SetupRequire(state, runner.requireCache, runner.requireCfg); err != nil {
|
// Create and register the module loader function
|
||||||
|
moduleLoader := func(s *luajit.State) int {
|
||||||
|
// Get module name
|
||||||
|
modName := s.ToString(1)
|
||||||
|
if modName == "" {
|
||||||
|
s.PushString("module name required")
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find and compile module
|
||||||
|
bytecode, err := findAndCompileModule(s, runner.requireCache, *runner.requireCfg, modName)
|
||||||
|
if err != nil {
|
||||||
|
if err == ErrModuleNotFound {
|
||||||
|
s.PushString("module '" + modName + "' not found")
|
||||||
|
} else {
|
||||||
|
s.PushString("error loading module: " + err.Error())
|
||||||
|
}
|
||||||
|
return -1 // Return error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the bytecode
|
||||||
|
if err := s.LoadBytecode(bytecode, modName); err != nil {
|
||||||
|
s.PushString("error loading bytecode: " + err.Error())
|
||||||
|
return -1 // Return error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the loaded function
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store reference to prevent garbage collection
|
||||||
|
runner.loaderFunc = moduleLoader
|
||||||
|
|
||||||
|
// Register with Lua state
|
||||||
|
if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil {
|
||||||
|
state.Close()
|
||||||
|
return nil, ErrInitFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up the require mechanism
|
||||||
|
setupRequireScript := `
|
||||||
|
-- Create a secure require function for sandboxed environments
|
||||||
|
function __setup_secure_require(env)
|
||||||
|
-- Replace env.require with our secure version
|
||||||
|
env.require = function(modname)
|
||||||
|
-- Check if already loaded in package.loaded
|
||||||
|
if package.loaded[modname] then
|
||||||
|
return package.loaded[modname]
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Try to load the module using our Go loader
|
||||||
|
local loader = __go_load_module
|
||||||
|
|
||||||
|
-- Load the module
|
||||||
|
local f, err = loader(modname)
|
||||||
|
if not f then
|
||||||
|
error(err or "failed to load module: " .. modname)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Set the environment for the module
|
||||||
|
setfenv(f, env)
|
||||||
|
|
||||||
|
-- Execute the module
|
||||||
|
local result = f()
|
||||||
|
|
||||||
|
-- If module didn't return a value, use true
|
||||||
|
if result == nil then
|
||||||
|
result = true
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Cache the result
|
||||||
|
package.loaded[modname] = result
|
||||||
|
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
|
return env
|
||||||
|
end
|
||||||
|
`
|
||||||
|
|
||||||
|
if err := state.DoString(setupRequireScript); err != nil {
|
||||||
state.Close()
|
state.Close()
|
||||||
return nil, ErrInitFailed
|
return nil, ErrInitFailed
|
||||||
}
|
}
|
||||||
|
@ -165,6 +246,9 @@ func (r *LuaRunner) setupSandbox() error {
|
||||||
loaded = {} -- Table to store loaded modules
|
loaded = {} -- Table to store loaded modules
|
||||||
}
|
}
|
||||||
|
|
||||||
|
-- Explicitly expose the module loader function
|
||||||
|
env.__go_load_module = __go_load_module
|
||||||
|
|
||||||
-- Set up secure require function
|
-- Set up secure require function
|
||||||
env = __setup_secure_require(env)
|
env = __setup_secure_require(env)
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
@ -16,6 +17,12 @@ var (
|
||||||
ErrPathTraversal = errors.New("path traversal not allowed")
|
ErrPathTraversal = errors.New("path traversal not allowed")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ModuleEntry represents a cached module with timestamp
|
||||||
|
type ModuleEntry struct {
|
||||||
|
Bytecode []byte
|
||||||
|
LastUsed time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// RequireConfig holds configuration for Lua's require function
|
// RequireConfig holds configuration for Lua's require function
|
||||||
type RequireConfig struct {
|
type RequireConfig struct {
|
||||||
ScriptDir string // Base directory for script being executed
|
ScriptDir string // Base directory for script being executed
|
||||||
|
@ -24,94 +31,122 @@ type RequireConfig struct {
|
||||||
|
|
||||||
// RequireCache is a thread-safe cache for loaded Lua modules
|
// RequireCache is a thread-safe cache for loaded Lua modules
|
||||||
type RequireCache struct {
|
type RequireCache struct {
|
||||||
modules sync.Map // Maps full file paths to compiled bytecode
|
modules sync.Map // Maps full file paths to ModuleEntry
|
||||||
|
mu sync.Mutex
|
||||||
|
maxItems int // Maximum number of modules to cache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRequireCache creates a new, empty require cache
|
// NewRequireCache creates a new, empty require cache
|
||||||
func NewRequireCache() *RequireCache {
|
func NewRequireCache() *RequireCache {
|
||||||
return &RequireCache{
|
return &RequireCache{
|
||||||
modules: sync.Map{},
|
modules: sync.Map{},
|
||||||
|
maxItems: 100, // Default cache size - can be adjusted based on expected module load
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupRequire configures the Lua state with a secure require function
|
// SetCacheSize adjusts the maximum cache size
|
||||||
func SetupRequire(state *luajit.State, cache *RequireCache, config *RequireConfig) error {
|
func (c *RequireCache) SetCacheSize(size int) {
|
||||||
// Register the loader function
|
if size > 0 {
|
||||||
err := state.RegisterGoFunction("__go_load_module", func(s *luajit.State) int {
|
c.mu.Lock()
|
||||||
// Get module name
|
c.maxItems = size
|
||||||
modName := s.ToString(1)
|
c.mu.Unlock()
|
||||||
if modName == "" {
|
}
|
||||||
s.PushString("module name required")
|
}
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use the pointer to the shared config
|
// Size returns the approximate number of items in the cache
|
||||||
bytecode, err := findAndCompileModule(s, cache, *config, modName)
|
func (c *RequireCache) Size() int {
|
||||||
if err != nil {
|
size := 0
|
||||||
if err == ErrModuleNotFound {
|
c.modules.Range(func(_, _ interface{}) bool {
|
||||||
s.PushString("module '" + modName + "' not found")
|
size++
|
||||||
} else {
|
return true
|
||||||
s.PushString("error loading module: " + err.Error())
|
|
||||||
}
|
|
||||||
return -1 // Return error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the bytecode
|
|
||||||
if err := s.LoadBytecode(bytecode, modName); err != nil {
|
|
||||||
s.PushString("error loading bytecode: " + err.Error())
|
|
||||||
return -1 // Return error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the loaded function
|
|
||||||
return 1
|
|
||||||
})
|
})
|
||||||
|
return size
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
// Get retrieves a module from the cache, updating its last used time
|
||||||
return err
|
func (c *RequireCache) Get(path string) ([]byte, bool) {
|
||||||
|
value, ok := c.modules.Load(path)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up the secure require implementation
|
entry, ok := value.(ModuleEntry)
|
||||||
setupScript := `
|
if !ok {
|
||||||
-- Create a secure require function for sandboxed environments
|
// Handle legacy entries (plain bytecode)
|
||||||
function __setup_secure_require(env)
|
bytecode, ok := value.([]byte)
|
||||||
-- Replace env.require with our secure version
|
if !ok {
|
||||||
env.require = function(modname)
|
return nil, false
|
||||||
-- Check if already loaded in package.loaded
|
}
|
||||||
if package.loaded[modname] then
|
|
||||||
return package.loaded[modname]
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Try to load the module using our Go loader
|
// Convert to ModuleEntry and update
|
||||||
local loader = __go_load_module
|
entry = ModuleEntry{
|
||||||
|
Bytecode: bytecode,
|
||||||
|
LastUsed: time.Now(),
|
||||||
|
}
|
||||||
|
c.modules.Store(path, entry)
|
||||||
|
return bytecode, true
|
||||||
|
}
|
||||||
|
|
||||||
-- Load the module
|
// Update last used time
|
||||||
local f, err = loader(modname)
|
entry.LastUsed = time.Now()
|
||||||
if not f then
|
c.modules.Store(path, entry)
|
||||||
error(err or "failed to load module: " .. modname)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Set the environment for the module
|
return entry.Bytecode, true
|
||||||
setfenv(f, env)
|
}
|
||||||
|
|
||||||
-- Execute the module
|
// Store adds a module to the cache with LRU eviction
|
||||||
local result = f()
|
func (c *RequireCache) Store(path string, bytecode []byte) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
-- If module didn't return a value, use true
|
// Check if we need to evict
|
||||||
if result == nil then
|
if c.Size() >= c.maxItems {
|
||||||
result = true
|
c.evictOldest()
|
||||||
end
|
}
|
||||||
|
|
||||||
-- Cache the result
|
// Store the new entry
|
||||||
package.loaded[modname] = result
|
c.modules.Store(path, ModuleEntry{
|
||||||
|
Bytecode: bytecode,
|
||||||
|
LastUsed: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return result
|
// evictOldest removes the least recently used item from the cache
|
||||||
end
|
func (c *RequireCache) evictOldest() {
|
||||||
|
var oldestTime time.Time
|
||||||
|
var oldestKey string
|
||||||
|
first := true
|
||||||
|
|
||||||
return env
|
// Find oldest entry
|
||||||
end
|
c.modules.Range(func(key, value interface{}) bool {
|
||||||
`
|
// Handle different value types
|
||||||
|
var lastUsed time.Time
|
||||||
|
|
||||||
return state.DoString(setupScript)
|
switch v := value.(type) {
|
||||||
|
case ModuleEntry:
|
||||||
|
lastUsed = v.LastUsed
|
||||||
|
default:
|
||||||
|
// For non-ModuleEntry values, treat as oldest
|
||||||
|
if first {
|
||||||
|
oldestKey = key.(string)
|
||||||
|
first = false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if first || lastUsed.Before(oldestTime) {
|
||||||
|
oldestTime = lastUsed
|
||||||
|
oldestKey = key.(string)
|
||||||
|
first = false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Remove oldest entry
|
||||||
|
if oldestKey != "" {
|
||||||
|
c.modules.Delete(oldestKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRequirePaths updates the require paths in the config without further allocations or re-registering the loader.
|
// UpdateRequirePaths updates the require paths in the config without further allocations or re-registering the loader.
|
||||||
|
@ -166,9 +201,9 @@ func findAndCompileModule(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if already in cache
|
// Check if already in cache - using our Get method to update LRU info
|
||||||
if bytecode, ok := cache.modules.Load(cleanPath); ok {
|
if bytecode, ok := cache.Get(cleanPath); ok {
|
||||||
return bytecode.([]byte), nil
|
return bytecode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if file exists
|
// Check if file exists
|
||||||
|
@ -189,8 +224,8 @@ func findAndCompileModule(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store in cache
|
// Store in cache - using our Store method with LRU eviction
|
||||||
cache.modules.Store(cleanPath, bytecode)
|
cache.Store(cleanPath, bytecode)
|
||||||
|
|
||||||
return bytecode, nil
|
return bytecode, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package runner_test
|
package runner_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -152,13 +151,16 @@ func TestRequireSecurityBoundaries(t *testing.T) {
|
||||||
libDir := filepath.Join(tempDir, "libs")
|
libDir := filepath.Join(tempDir, "libs")
|
||||||
secretDir := filepath.Join(tempDir, "secret")
|
secretDir := filepath.Join(tempDir, "secret")
|
||||||
|
|
||||||
if err := os.Mkdir(scriptDir, 0755); err != nil {
|
err = os.MkdirAll(scriptDir, 0755)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("Failed to create script directory: %v", err)
|
t.Fatalf("Failed to create script directory: %v", err)
|
||||||
}
|
}
|
||||||
if err := os.Mkdir(libDir, 0755); err != nil {
|
err = os.MkdirAll(libDir, 0755)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("Failed to create lib directory: %v", err)
|
t.Fatalf("Failed to create lib directory: %v", err)
|
||||||
}
|
}
|
||||||
if err := os.Mkdir(secretDir, 0755); err != nil {
|
err = os.MkdirAll(secretDir, 0755)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("Failed to create secret directory: %v", err)
|
t.Fatalf("Failed to create secret directory: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,17 +169,67 @@ func TestRequireSecurityBoundaries(t *testing.T) {
|
||||||
local secret = "TOP SECRET"
|
local secret = "TOP SECRET"
|
||||||
return secret
|
return secret
|
||||||
`
|
`
|
||||||
if err := os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644); err != nil {
|
err = os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("Failed to write secret module: %v", err)
|
t.Fatalf("Failed to write secret module: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a normal module in lib
|
// Create a normal module in lib
|
||||||
normalModule := `return "normal module"`
|
normalModule := `return "normal module"`
|
||||||
if err := os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644); err != nil {
|
err = os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("Failed to write normal module: %v", err)
|
t.Fatalf("Failed to write normal module: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test attempting to access file outside allowed paths
|
// Create a compile-and-run function that takes care of both compilation and execution
|
||||||
|
compileAndRun := func(scriptText, scriptName, scriptPath string) (interface{}, error) {
|
||||||
|
// Compile
|
||||||
|
state := luajit.New()
|
||||||
|
if state == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
defer state.Close()
|
||||||
|
|
||||||
|
bytecode, err := state.CompileBytecode(scriptText, scriptName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and configure a new runner each time
|
||||||
|
r, err := runner.NewRunner(
|
||||||
|
runner.WithScriptDir(scriptDir),
|
||||||
|
runner.WithLibDirs(libDir),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
// Run
|
||||||
|
return r.Run(bytecode, nil, scriptPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that normal require works
|
||||||
|
normalScript := `
|
||||||
|
local normal = require("normal")
|
||||||
|
return normal
|
||||||
|
`
|
||||||
|
normalPath := filepath.Join(scriptDir, "normal_test.lua")
|
||||||
|
err = os.WriteFile(normalPath, []byte(normalScript), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to write normal script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := compileAndRun(normalScript, "normal_test.lua", normalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to run normal script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != "normal module" {
|
||||||
|
t.Errorf("Expected 'normal module', got %v", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test path traversal attempts
|
||||||
pathTraversalTests := []struct {
|
pathTraversalTests := []struct {
|
||||||
name string
|
name string
|
||||||
script string
|
script string
|
||||||
|
@ -187,94 +239,43 @@ func TestRequireSecurityBoundaries(t *testing.T) {
|
||||||
script: `
|
script: `
|
||||||
-- Try path traversal
|
-- Try path traversal
|
||||||
local secret = require("../secret/secret")
|
local secret = require("../secret/secret")
|
||||||
return secret
|
return secret ~= nil
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Double dot traversal",
|
name: "Double dot traversal",
|
||||||
script: `
|
script: `
|
||||||
local secret = require("..secret.secret")
|
local secret = require("..secret.secret")
|
||||||
return secret
|
return secret ~= nil
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Absolute path",
|
name: "Absolute path traversal",
|
||||||
script: fmt.Sprintf(`
|
script: `
|
||||||
local secret = require("%s")
|
local secret = require("` + filepath.Join(secretDir, "secret") + `")
|
||||||
return secret
|
return secret ~= nil
|
||||||
`, filepath.Join(secretDir, "secret")),
|
`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create and configure runner
|
|
||||||
luaRunner, err := runner.NewRunner(
|
|
||||||
runner.WithScriptDir(scriptDir),
|
|
||||||
runner.WithLibDirs(libDir),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create LuaRunner: %v", err)
|
|
||||||
}
|
|
||||||
defer luaRunner.Close()
|
|
||||||
|
|
||||||
// Test each attempt at path traversal
|
|
||||||
for _, tt := range pathTraversalTests {
|
for _, tt := range pathTraversalTests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// Write the script
|
|
||||||
scriptPath := filepath.Join(scriptDir, tt.name+".lua")
|
scriptPath := filepath.Join(scriptDir, tt.name+".lua")
|
||||||
if err := os.WriteFile(scriptPath, []byte(tt.script), 0644); err != nil {
|
err := os.WriteFile(scriptPath, []byte(tt.script), 0644)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("Failed to write test script: %v", err)
|
t.Fatalf("Failed to write test script: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile
|
result, err := compileAndRun(tt.script, tt.name+".lua", scriptPath)
|
||||||
state := luajit.New()
|
// If there's an error, that's expected and good
|
||||||
if state == nil {
|
|
||||||
t.Fatal("Failed to create Lua state")
|
|
||||||
}
|
|
||||||
defer state.Close()
|
|
||||||
|
|
||||||
bytecode, err := state.CompileBytecode(tt.script, tt.name+".lua")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to compile script: %v", err)
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run and expect error
|
// If no error, then the script should have returned false (couldn't get the module)
|
||||||
_, err = luaRunner.Run(bytecode, nil, scriptPath)
|
if result == true {
|
||||||
if err == nil {
|
t.Errorf("Security breach! Script was able to access restricted module")
|
||||||
t.Error("Expected error for path traversal, got nil")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that we can still require valid modules
|
|
||||||
normalScript := `
|
|
||||||
local normal = require("normal")
|
|
||||||
return normal
|
|
||||||
`
|
|
||||||
scriptPath := filepath.Join(scriptDir, "normal_test.lua")
|
|
||||||
if err := os.WriteFile(scriptPath, []byte(normalScript), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to write normal test script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compile
|
|
||||||
state := luajit.New()
|
|
||||||
if state == nil {
|
|
||||||
t.Fatal("Failed to create Lua state")
|
|
||||||
}
|
|
||||||
defer state.Close()
|
|
||||||
|
|
||||||
bytecode, err := state.CompileBytecode(normalScript, "normal_test.lua")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to compile script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run and expect success
|
|
||||||
result, err := luaRunner.Run(bytecode, nil, scriptPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run normal script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check result
|
|
||||||
if result != "normal module" {
|
|
||||||
t.Errorf("Expected 'normal module', got %v", result)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user