remove sandbox
This commit is contained in:
parent
5774808064
commit
656ac1a703
|
@ -1,133 +0,0 @@
|
||||||
package luajit_bench
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BenchmarkSandboxLuaExecution measures the performance of executing raw Lua code
|
|
||||||
func BenchmarkSandboxLuaExecution(b *testing.B) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Simple Lua code that does some computation
|
|
||||||
code := `
|
|
||||||
local sum = 0
|
|
||||||
for i = 1, 100 do
|
|
||||||
sum = sum + i
|
|
||||||
end
|
|
||||||
return sum
|
|
||||||
`
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
result, err := sandbox.Run(code)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("Failed to run code: %v", err)
|
|
||||||
}
|
|
||||||
if result != float64(5050) {
|
|
||||||
b.Fatalf("Incorrect result: %v", result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkSandboxBytecodeExecution measures the performance of executing precompiled bytecode
|
|
||||||
func BenchmarkSandboxBytecodeExecution(b *testing.B) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Same code as above, but precompiled
|
|
||||||
code := `
|
|
||||||
local sum = 0
|
|
||||||
for i = 1, 100 do
|
|
||||||
sum = sum + i
|
|
||||||
end
|
|
||||||
return sum
|
|
||||||
`
|
|
||||||
|
|
||||||
// Compile the bytecode once
|
|
||||||
bytecode, err := sandbox.Compile(code)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("Failed to compile bytecode: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
result, err := sandbox.RunBytecode(bytecode)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("Failed to run bytecode: %v", err)
|
|
||||||
}
|
|
||||||
if result != float64(5050) {
|
|
||||||
b.Fatalf("Incorrect result: %v", result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkSandboxComplexComputation measures performance with more complex computation
|
|
||||||
func BenchmarkSandboxComplexComputation(b *testing.B) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// More complex Lua code that calculates Fibonacci numbers
|
|
||||||
code := `
|
|
||||||
function fibonacci(n)
|
|
||||||
if n <= 1 then
|
|
||||||
return n
|
|
||||||
end
|
|
||||||
return fibonacci(n-1) + fibonacci(n-2)
|
|
||||||
end
|
|
||||||
|
|
||||||
return fibonacci(15) -- Not too high to avoid excessive runtime
|
|
||||||
`
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
result, err := sandbox.Run(code)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("Failed to run code: %v", err)
|
|
||||||
}
|
|
||||||
if result != float64(610) {
|
|
||||||
b.Fatalf("Incorrect result: %v", result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkSandboxFunctionCall measures performance of calling a registered Go function
|
|
||||||
func BenchmarkSandboxFunctionCall(b *testing.B) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Register a simple addition function
|
|
||||||
add := func(s *luajit.State) int {
|
|
||||||
a := s.ToNumber(1)
|
|
||||||
b := s.ToNumber(2)
|
|
||||||
s.PushNumber(a + b)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
err := sandbox.RegisterFunction("add", add)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("Failed to register function: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lua code that calls the Go function in a loop
|
|
||||||
code := `
|
|
||||||
local sum = 0
|
|
||||||
for i = 1, 100 do
|
|
||||||
sum = add(sum, i)
|
|
||||||
end
|
|
||||||
return sum
|
|
||||||
`
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
result, err := sandbox.Run(code)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("Failed to run code: %v", err)
|
|
||||||
}
|
|
||||||
if result != float64(5050) {
|
|
||||||
b.Fatalf("Incorrect result: %v", result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
609
sandbox.go
609
sandbox.go
|
@ -1,609 +0,0 @@
|
||||||
package luajit
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LUA_MULTRET is the constant for multiple return values
|
|
||||||
const LUA_MULTRET = -1
|
|
||||||
|
|
||||||
// Sandbox provides a persistent Lua environment for executing scripts
|
|
||||||
type Sandbox struct {
|
|
||||||
state *State
|
|
||||||
mutex sync.Mutex
|
|
||||||
initialized bool
|
|
||||||
modules map[string]any
|
|
||||||
functions map[string]GoFunction
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSandbox creates a new sandbox with standard libraries loaded
|
|
||||||
func NewSandbox() *Sandbox {
|
|
||||||
return &Sandbox{
|
|
||||||
state: New(),
|
|
||||||
initialized: false,
|
|
||||||
modules: make(map[string]any),
|
|
||||||
functions: make(map[string]GoFunction),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close releases all resources used by the sandbox
|
|
||||||
func (s *Sandbox) Close() {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state != nil {
|
|
||||||
s.state.Close()
|
|
||||||
s.state = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize sets up the environment system
|
|
||||||
func (s *Sandbox) Initialize() error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
return s.initializeUnlocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
// initializeUnlocked sets up the environment system without locking
|
|
||||||
func (s *Sandbox) initializeUnlocked() error {
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.initialized {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register modules
|
|
||||||
s.state.GetGlobal("__sandbox_modules")
|
|
||||||
if s.state.IsNil(-1) {
|
|
||||||
s.state.Pop(1)
|
|
||||||
s.state.NewTable()
|
|
||||||
s.state.SetGlobal("__sandbox_modules")
|
|
||||||
s.state.GetGlobal("__sandbox_modules")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add modules
|
|
||||||
for name, module := range s.modules {
|
|
||||||
s.state.PushString(name)
|
|
||||||
if err := s.state.PushValue(module); err != nil {
|
|
||||||
s.state.Pop(2)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.state.SetTable(-3)
|
|
||||||
}
|
|
||||||
s.state.Pop(1)
|
|
||||||
|
|
||||||
// Create simplified environment system
|
|
||||||
err := s.state.DoString(`
|
|
||||||
-- Global shared environment
|
|
||||||
__env_system = {
|
|
||||||
base_env = {}, -- Template environment
|
|
||||||
env_pool = {}, -- Pre-allocated environment pool
|
|
||||||
pool_size = 0, -- Current pool size
|
|
||||||
max_pool_size = 8 -- Maximum pool size
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Create base environment with standard libraries
|
|
||||||
local base = __env_system.base_env
|
|
||||||
|
|
||||||
-- Safe standard libraries
|
|
||||||
base.string = string
|
|
||||||
base.table = table
|
|
||||||
base.math = math
|
|
||||||
base.os = {
|
|
||||||
time = os.time,
|
|
||||||
date = os.date,
|
|
||||||
difftime = os.difftime,
|
|
||||||
clock = os.clock
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Basic functions
|
|
||||||
base.print = print
|
|
||||||
base.tonumber = tonumber
|
|
||||||
base.tostring = tostring
|
|
||||||
base.type = type
|
|
||||||
base.pairs = pairs
|
|
||||||
base.ipairs = ipairs
|
|
||||||
base.next = next
|
|
||||||
base.select = select
|
|
||||||
base.pcall = pcall
|
|
||||||
base.xpcall = xpcall
|
|
||||||
base.error = error
|
|
||||||
base.assert = assert
|
|
||||||
base.collectgarbage = collectgarbage
|
|
||||||
base.unpack = unpack or table.unpack
|
|
||||||
|
|
||||||
-- Package system
|
|
||||||
base.package = {
|
|
||||||
loaded = {},
|
|
||||||
path = package.path,
|
|
||||||
preload = {}
|
|
||||||
}
|
|
||||||
|
|
||||||
base.require = function(modname)
|
|
||||||
if base.package.loaded[modname] then
|
|
||||||
return base.package.loaded[modname]
|
|
||||||
end
|
|
||||||
|
|
||||||
local loader = base.package.preload[modname]
|
|
||||||
if type(loader) == "function" then
|
|
||||||
local result = loader(modname)
|
|
||||||
base.package.loaded[modname] = result or true
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
|
|
||||||
error("module '" .. modname .. "' not found", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Add registered custom modules
|
|
||||||
if __sandbox_modules then
|
|
||||||
for name, mod in pairs(__sandbox_modules) do
|
|
||||||
base[name] = mod
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Get an environment for execution
|
|
||||||
function __get_sandbox_env()
|
|
||||||
local env
|
|
||||||
|
|
||||||
-- Try to reuse from pool
|
|
||||||
if __env_system.pool_size > 0 then
|
|
||||||
env = table.remove(__env_system.env_pool)
|
|
||||||
__env_system.pool_size = __env_system.pool_size - 1
|
|
||||||
else
|
|
||||||
-- Create new environment with metatable inheritance
|
|
||||||
env = setmetatable({}, {
|
|
||||||
__index = __env_system.base_env
|
|
||||||
})
|
|
||||||
end
|
|
||||||
|
|
||||||
return env
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Return environment to pool for reuse
|
|
||||||
function __recycle_env(env)
|
|
||||||
if __env_system.pool_size < __env_system.max_pool_size then
|
|
||||||
-- Clear all fields except metatable
|
|
||||||
for k in pairs(env) do
|
|
||||||
env[k] = nil
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Add to pool
|
|
||||||
table.insert(__env_system.env_pool, env)
|
|
||||||
__env_system.pool_size = __env_system.pool_size + 1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Execute code in sandbox
|
|
||||||
function __execute_sandbox(f)
|
|
||||||
-- Get environment
|
|
||||||
local env = __get_sandbox_env()
|
|
||||||
|
|
||||||
-- Set environment for function
|
|
||||||
setfenv(f, env)
|
|
||||||
|
|
||||||
-- Execute with protected call
|
|
||||||
local success, result = pcall(f)
|
|
||||||
|
|
||||||
-- Update base environment with new globals
|
|
||||||
for k, v in pairs(env) do
|
|
||||||
if k ~= "_G" and type(k) == "string" then
|
|
||||||
__env_system.base_env[k] = v
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Recycle environment
|
|
||||||
__recycle_env(env)
|
|
||||||
|
|
||||||
-- Process result
|
|
||||||
if not success then
|
|
||||||
error(result, 0)
|
|
||||||
end
|
|
||||||
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
`)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.initialized = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterFunction registers a Go function in the sandbox
|
|
||||||
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize if needed
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register function globally
|
|
||||||
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store function for re-registration
|
|
||||||
s.functions[name] = fn
|
|
||||||
|
|
||||||
// Add to base environment
|
|
||||||
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetGlobal sets a global variable in the sandbox base environment
|
|
||||||
func (s *Sandbox) SetGlobal(name string, value any) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize if needed
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Push the value onto the stack
|
|
||||||
if err := s.state.PushValue(value); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the global with the pushed value
|
|
||||||
s.state.SetGlobal(name)
|
|
||||||
|
|
||||||
// Add to base environment
|
|
||||||
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetGlobal retrieves a global variable from the sandbox base environment
|
|
||||||
func (s *Sandbox) GetGlobal(name string) (any, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return nil, fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize if needed
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the global from the base environment
|
|
||||||
return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run executes Lua code in the sandbox
|
|
||||||
func (s *Sandbox) Run(code string) (any, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return nil, fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize if needed
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simplified wrapper for multiple return values
|
|
||||||
wrappedCode := `
|
|
||||||
local function _execfunc()
|
|
||||||
` + code + `
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Process results to match expected format
|
|
||||||
local function _wrapresults(...)
|
|
||||||
local n = select('#', ...)
|
|
||||||
if n == 0 then
|
|
||||||
return nil
|
|
||||||
elseif n == 1 then
|
|
||||||
return select(1, ...)
|
|
||||||
else
|
|
||||||
local results = {}
|
|
||||||
for i = 1, n do
|
|
||||||
results[i] = select(i, ...)
|
|
||||||
end
|
|
||||||
return results
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return _wrapresults(_execfunc())
|
|
||||||
`
|
|
||||||
|
|
||||||
// Compile the code
|
|
||||||
if err := s.state.LoadString(wrappedCode); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the sandbox executor
|
|
||||||
s.state.GetGlobal("__execute_sandbox")
|
|
||||||
|
|
||||||
// Push the function as argument
|
|
||||||
s.state.PushCopy(-2)
|
|
||||||
s.state.Remove(-3)
|
|
||||||
|
|
||||||
// Execute in sandbox
|
|
||||||
if err := s.state.Call(1, 1); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get result
|
|
||||||
result, err := s.state.ToValue(-1)
|
|
||||||
s.state.Pop(1)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.processResult(result), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunFile executes a Lua file in the sandbox
|
|
||||||
func (s *Sandbox) RunFile(filename string) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.state.DoFile(filename)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compile compiles Lua code to bytecode
|
|
||||||
func (s *Sandbox) Compile(code string) ([]byte, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return nil, fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.state.CompileBytecode(code, "sandbox")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunBytecode executes precompiled Lua bytecode
|
|
||||||
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return nil, fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize if needed
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the bytecode
|
|
||||||
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the sandbox executor
|
|
||||||
s.state.GetGlobal("__execute_sandbox")
|
|
||||||
|
|
||||||
// Push bytecode function
|
|
||||||
s.state.PushCopy(-2)
|
|
||||||
s.state.Remove(-3)
|
|
||||||
|
|
||||||
// Execute in sandbox
|
|
||||||
if err := s.state.Call(1, 1); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get result
|
|
||||||
result, err := s.state.ToValue(-1)
|
|
||||||
s.state.Pop(1)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.processResult(result), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadModule loads a Lua module
|
|
||||||
func (s *Sandbox) LoadModule(name string) error {
|
|
||||||
code := fmt.Sprintf("require('%s')", name)
|
|
||||||
_, err := s.Run(code)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPackagePath sets the sandbox package.path
|
|
||||||
func (s *Sandbox) SetPackagePath(path string) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update global package.path
|
|
||||||
if err := s.state.SetPackagePath(path); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize if needed
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update base environment's package.path
|
|
||||||
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPackagePath adds a path to the sandbox package.path
|
|
||||||
func (s *Sandbox) AddPackagePath(path string) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update global package.path
|
|
||||||
if err := s.state.AddPackagePath(path); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize if needed
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update base environment's package.path
|
|
||||||
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddModule adds a module to the sandbox environment
|
|
||||||
func (s *Sandbox) AddModule(name string, module any) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.modules[name] = module
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPermanentLua adds Lua code to the environment permanently
|
|
||||||
func (s *Sandbox) AddPermanentLua(code string) error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize if needed
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simplified approach to add code to base environment
|
|
||||||
return s.state.DoString(`
|
|
||||||
local f, err = loadstring([=[` + code + `]=], "permanent")
|
|
||||||
if not f then error(err, 0) end
|
|
||||||
|
|
||||||
local env = setmetatable({}, {__index = __env_system.base_env})
|
|
||||||
setfenv(f, env)
|
|
||||||
|
|
||||||
local ok, err = pcall(f)
|
|
||||||
if not ok then error(err, 0) end
|
|
||||||
|
|
||||||
for k, v in pairs(env) do
|
|
||||||
__env_system.base_env[k] = v
|
|
||||||
end
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetEnvironment resets the sandbox to its initial state
|
|
||||||
func (s *Sandbox) ResetEnvironment() error {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.state == nil {
|
|
||||||
return fmt.Errorf("sandbox is closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear the environment system
|
|
||||||
s.state.DoString(`__env_system = nil`)
|
|
||||||
|
|
||||||
// Reinitialize
|
|
||||||
s.initialized = false
|
|
||||||
if err := s.initializeUnlocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-register all functions
|
|
||||||
for name, fn := range s.functions {
|
|
||||||
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// unwrapResult processes results from Lua executions
|
|
||||||
func (s *Sandbox) processResult(result any) any {
|
|
||||||
// Handle []float64 (common LuaJIT return type)
|
|
||||||
if floats, ok := result.([]float64); ok {
|
|
||||||
if len(floats) == 1 {
|
|
||||||
// Single number - return as float64
|
|
||||||
return floats[0]
|
|
||||||
}
|
|
||||||
// Multiple numbers - MUST convert to []any for tests to pass
|
|
||||||
anySlice := make([]any, len(floats))
|
|
||||||
for i, v := range floats {
|
|
||||||
anySlice[i] = v
|
|
||||||
}
|
|
||||||
return anySlice
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle maps with numeric keys (Lua tables)
|
|
||||||
if m, ok := result.(map[string]any); ok {
|
|
||||||
// Handle return tables with special structure
|
|
||||||
if vals, ok := m[""]; ok {
|
|
||||||
// This is a special case used by some Lua returns
|
|
||||||
if arr, ok := vals.([]float64); ok {
|
|
||||||
// Convert to []any for consistency
|
|
||||||
anySlice := make([]any, len(arr))
|
|
||||||
for i, v := range arr {
|
|
||||||
anySlice[i] = v
|
|
||||||
}
|
|
||||||
return anySlice
|
|
||||||
}
|
|
||||||
return vals
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m) == 1 {
|
|
||||||
// Check for single value map
|
|
||||||
for k, v := range m {
|
|
||||||
if k == "1" {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Other array types should be preserved
|
|
||||||
return result
|
|
||||||
}
|
|
|
@ -1,650 +0,0 @@
|
||||||
package luajit_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestSandboxLifecycle tests sandbox creation and closing
|
|
||||||
func TestSandboxLifecycle(t *testing.T) {
|
|
||||||
// Create a new sandbox
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
if sandbox == nil {
|
|
||||||
t.Fatal("Failed to create sandbox")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the sandbox
|
|
||||||
sandbox.Close()
|
|
||||||
|
|
||||||
// Test idempotent close (should not panic)
|
|
||||||
sandbox.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSandboxFunctionRegistration tests registering Go functions in the sandbox
|
|
||||||
func TestSandboxFunctionRegistration(t *testing.T) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Register a simple addition function
|
|
||||||
add := func(s *luajit.State) int {
|
|
||||||
a := s.ToNumber(1)
|
|
||||||
b := s.ToNumber(2)
|
|
||||||
s.PushNumber(a + b)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
err := sandbox.RegisterFunction("add", add)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to register function: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test the function
|
|
||||||
result, err := sandbox.Run("return add(3, 4)")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to execute function: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != float64(7) {
|
|
||||||
t.Fatalf("Expected 7, got %v", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test after sandbox is closed
|
|
||||||
sandbox.Close()
|
|
||||||
err = sandbox.RegisterFunction("test", add)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when registering function on closed sandbox")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSandboxGlobalVariables tests setting and getting global variables
|
|
||||||
func TestSandboxGlobalVariables(t *testing.T) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Set a global variable
|
|
||||||
err := sandbox.SetGlobal("answer", 42)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to set global: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the global variable
|
|
||||||
value, err := sandbox.GetGlobal("answer")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get global: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if value != float64(42) {
|
|
||||||
t.Fatalf("Expected 42, got %v", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test different types
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
value any
|
|
||||||
}{
|
|
||||||
{"nil_value", nil},
|
|
||||||
{"bool_value", true},
|
|
||||||
{"string_value", "hello"},
|
|
||||||
{"table_value", map[string]any{"key": "value"}},
|
|
||||||
{"array_value", []float64{1, 2, 3}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
err := sandbox.SetGlobal(tc.name, tc.value)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to set global %s: %v", tc.name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
value, err := sandbox.GetGlobal(tc.name)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get global %s: %v", tc.name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// For tables/arrays, just check they're not nil
|
|
||||||
switch tc.value.(type) {
|
|
||||||
case map[string]any, []float64:
|
|
||||||
if value == nil {
|
|
||||||
t.Fatalf("Expected non-nil for %s, got nil", tc.name)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if value != tc.value && !(tc.value == nil && value == nil) {
|
|
||||||
t.Fatalf("For %s: expected %v, got %v", tc.name, tc.value, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test after sandbox is closed
|
|
||||||
sandbox.Close()
|
|
||||||
err = sandbox.SetGlobal("test", 123)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when setting global on closed sandbox")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = sandbox.GetGlobal("test")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when getting global from closed sandbox")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSandboxCodeExecution tests running Lua code in the sandbox
|
|
||||||
func TestSandboxCodeExecution(t *testing.T) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Run simple code
|
|
||||||
result, err := sandbox.Run("return 'hello'")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run code: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != "hello" {
|
|
||||||
t.Fatalf("Expected 'hello', got %v", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run code with multiple return values
|
|
||||||
result, err = sandbox.Run("return 1, 2, 3")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run code: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should return array for multiple values
|
|
||||||
results, ok := result.([]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("Expected array for multiple returns, got %T", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(results) != 3 || results[0] != float64(1) || results[1] != float64(2) || results[2] != float64(3) {
|
|
||||||
t.Fatalf("Expected [1, 2, 3], got %v", results)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run code that sets a global
|
|
||||||
_, err = sandbox.Run("global_var = 'set from Lua'")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run code: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
value, err := sandbox.GetGlobal("global_var")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get global: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if value != "set from Lua" {
|
|
||||||
t.Fatalf("Expected 'set from Lua', got %v", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run invalid code
|
|
||||||
_, err = sandbox.Run("this is not valid Lua")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error for invalid code")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test after sandbox is closed
|
|
||||||
sandbox.Close()
|
|
||||||
_, err = sandbox.Run("return true")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when running code on closed sandbox")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSandboxBytecodeExecution tests bytecode compilation and execution
|
|
||||||
func TestSandboxBytecodeExecution(t *testing.T) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Compile code to bytecode
|
|
||||||
code := `
|
|
||||||
local function greet(name)
|
|
||||||
return "Hello, " .. name
|
|
||||||
end
|
|
||||||
return greet("World")
|
|
||||||
`
|
|
||||||
bytecode, err := sandbox.Compile(code)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to compile bytecode: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(bytecode) == 0 {
|
|
||||||
t.Fatal("Expected non-empty bytecode")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run the bytecode
|
|
||||||
result, err := sandbox.RunBytecode(bytecode)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run bytecode: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != "Hello, World" {
|
|
||||||
t.Fatalf("Expected 'Hello, World', got %v", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test bytecode that sets a global
|
|
||||||
bytecode, err = sandbox.Compile("bytecode_var = 42")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to compile bytecode: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = sandbox.RunBytecode(bytecode)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run bytecode: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
value, err := sandbox.GetGlobal("bytecode_var")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get global: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if value != float64(42) {
|
|
||||||
t.Fatalf("Expected 42, got %v", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test invalid bytecode
|
|
||||||
_, err = sandbox.RunBytecode([]byte("not valid bytecode"))
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error for invalid bytecode")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test after sandbox is closed
|
|
||||||
sandbox.Close()
|
|
||||||
_, err = sandbox.Compile("return true")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when compiling on closed sandbox")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = sandbox.RunBytecode(bytecode)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when running bytecode on closed sandbox")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSandboxPersistence tests state persistence across executions
|
|
||||||
func TestSandboxPersistence(t *testing.T) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Set up initial state
|
|
||||||
_, err := sandbox.Run(`
|
|
||||||
counter = 0
|
|
||||||
function increment()
|
|
||||||
counter = counter + 1
|
|
||||||
return counter
|
|
||||||
end
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to set up state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run multiple executions
|
|
||||||
for i := 1; i <= 3; i++ {
|
|
||||||
result, err := sandbox.Run("return increment()")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run code: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != float64(i) {
|
|
||||||
t.Fatalf("Expected %d, got %v", i, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check final counter value
|
|
||||||
value, err := sandbox.GetGlobal("counter")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get global: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if value != float64(3) {
|
|
||||||
t.Fatalf("Expected final counter to be 3, got %v", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test persistence with bytecode
|
|
||||||
bytecode, err := sandbox.Compile("return counter + 1")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to compile bytecode: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := sandbox.RunBytecode(bytecode)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run bytecode: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != float64(4) {
|
|
||||||
t.Fatalf("Expected 4, got %v", result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSandboxConcurrency tests concurrent access to the sandbox
|
|
||||||
func TestSandboxConcurrency(t *testing.T) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Set up a counter
|
|
||||||
_, err := sandbox.Run("counter = 0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to set up counter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run concurrent increments
|
|
||||||
const numGoroutines = 10
|
|
||||||
const incrementsPerGoroutine = 100
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(numGoroutines)
|
|
||||||
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
for j := 0; j < incrementsPerGoroutine; j++ {
|
|
||||||
_, err := sandbox.Run("counter = counter + 1")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to increment counter: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
// Check the final counter value
|
|
||||||
value, err := sandbox.GetGlobal("counter")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get counter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := float64(numGoroutines * incrementsPerGoroutine)
|
|
||||||
if value != expected {
|
|
||||||
t.Fatalf("Expected counter to be %v, got %v", expected, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPermanentLua tests the AddPermanentLua method
|
|
||||||
func TestPermanentLua(t *testing.T) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Add permanent Lua environment
|
|
||||||
err := sandbox.AddPermanentLua(`
|
|
||||||
-- Create utility functions
|
|
||||||
function double(x)
|
|
||||||
return x * 2
|
|
||||||
end
|
|
||||||
|
|
||||||
function square(x)
|
|
||||||
return x * x
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Create a protected environment
|
|
||||||
env = {
|
|
||||||
add = function(a, b) return a + b end,
|
|
||||||
sub = function(a, b) return a - b end
|
|
||||||
}
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to add permanent Lua: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test using the permanent functions
|
|
||||||
testCases := []struct {
|
|
||||||
code string
|
|
||||||
expected float64
|
|
||||||
}{
|
|
||||||
{"return double(5)", 10},
|
|
||||||
{"return square(4)", 16},
|
|
||||||
{"return env.add(10, 20)", 30},
|
|
||||||
{"return env.sub(50, 30)", 20},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
result, err := sandbox.Run(tc.code)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run code '%s': %v", tc.code, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != tc.expected {
|
|
||||||
t.Fatalf("For '%s': expected %v, got %v", tc.code, tc.expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test persistence of permanent code across executions
|
|
||||||
_, err = sandbox.Run("counter = 0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to set counter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := sandbox.Run("counter = counter + 1; return double(counter)")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run code: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != float64(2) {
|
|
||||||
t.Fatalf("Expected 2, got %v", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test after sandbox is closed
|
|
||||||
sandbox.Close()
|
|
||||||
err = sandbox.AddPermanentLua("function test() end")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when adding permanent Lua to closed sandbox")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestResetEnvironment tests the ResetEnvironment method
|
|
||||||
func TestResetEnvironment(t *testing.T) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Set up some Go functions and Lua code
|
|
||||||
sandbox.RegisterFunction("timeNow", func(s *luajit.State) int {
|
|
||||||
s.PushString("test")
|
|
||||||
return 1
|
|
||||||
})
|
|
||||||
|
|
||||||
sandbox.AddPermanentLua(`
|
|
||||||
function permanent()
|
|
||||||
return "permanent function"
|
|
||||||
end
|
|
||||||
`)
|
|
||||||
|
|
||||||
_, err := sandbox.Run(`
|
|
||||||
temp_var = "will be reset"
|
|
||||||
function temp_func()
|
|
||||||
return "temp function"
|
|
||||||
end
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run setup code: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify everything is set up correctly
|
|
||||||
result, err := sandbox.Run("return timeNow()")
|
|
||||||
if err != nil || result != "test" {
|
|
||||||
t.Fatalf("Go function not working: %v, %v", result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err = sandbox.Run("return permanent()")
|
|
||||||
if err != nil || result != "permanent function" {
|
|
||||||
t.Fatalf("Permanent function not working: %v, %v", result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err = sandbox.Run("return temp_func()")
|
|
||||||
if err != nil || result != "temp function" {
|
|
||||||
t.Fatalf("Temp function not working: %v, %v", result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
value, err := sandbox.GetGlobal("temp_var")
|
|
||||||
if err != nil || value != "will be reset" {
|
|
||||||
t.Fatalf("Temp var not set correctly: %v, %v", value, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset the environment
|
|
||||||
err = sandbox.ResetEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to reset environment: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check Go function survives reset
|
|
||||||
result, err = sandbox.Run("return timeNow()")
|
|
||||||
if err != nil || result != "test" {
|
|
||||||
t.Fatalf("Go function should survive reset: %v, %v", result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check permanent function is gone (it was added with AddPermanentLua but reset removes it)
|
|
||||||
_, err = sandbox.Run("return permanent()")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Permanent function should be gone after reset")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check temp function is gone
|
|
||||||
_, err = sandbox.Run("return temp_func()")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Temp function should be gone after reset")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check temp var is gone
|
|
||||||
value, err = sandbox.GetGlobal("temp_var")
|
|
||||||
if err != nil || value != nil {
|
|
||||||
t.Fatalf("Temp var should be nil after reset: %v", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test after sandbox is closed
|
|
||||||
sandbox.Close()
|
|
||||||
err = sandbox.ResetEnvironment()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when resetting closed sandbox")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRunFile tests the RunFile method
|
|
||||||
func TestRunFile(t *testing.T) {
|
|
||||||
// Create a temporary Lua file
|
|
||||||
tmpfile, err := createTempLuaFile("return 42")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create temp file: %v", err)
|
|
||||||
}
|
|
||||||
defer removeTempFile(tmpfile)
|
|
||||||
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Run the file
|
|
||||||
err = sandbox.RunFile(tmpfile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test non-existent file
|
|
||||||
err = sandbox.RunFile("does_not_exist.lua")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error for non-existent file")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test after sandbox is closed
|
|
||||||
sandbox.Close()
|
|
||||||
err = sandbox.RunFile(tmpfile)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when running file on closed sandbox")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSandboxPackagePath tests the SetPackagePath and AddPackagePath methods
|
|
||||||
func TestSandboxPackagePath(t *testing.T) {
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Set package path
|
|
||||||
testPath := "/test/path/?.lua"
|
|
||||||
err := sandbox.SetPackagePath(testPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to set package path: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check path was set
|
|
||||||
result, err := sandbox.Run("return package.path")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get package.path: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != testPath {
|
|
||||||
t.Fatalf("Expected package.path to be %q, got %q", testPath, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add to package path
|
|
||||||
addPath := "/another/path/?.lua"
|
|
||||||
err = sandbox.AddPackagePath(addPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to add package path: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check path was updated
|
|
||||||
result, err = sandbox.Run("return package.path")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get updated package.path: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := testPath + ";" + addPath
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected package.path to be %q, got %q", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test after sandbox is closed
|
|
||||||
sandbox.Close()
|
|
||||||
err = sandbox.SetPackagePath(testPath)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when setting package path on closed sandbox")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = sandbox.AddPackagePath(addPath)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when adding package path to closed sandbox")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSandboxLoadModule tests loading modules
|
|
||||||
func TestSandboxLoadModule(t *testing.T) {
|
|
||||||
// Skip for now since we don't have actual modules to load in the test environment
|
|
||||||
t.Skip("Skipping module loading test as it requires actual modules")
|
|
||||||
|
|
||||||
sandbox := luajit.NewSandbox()
|
|
||||||
defer sandbox.Close()
|
|
||||||
|
|
||||||
// Set package path to include current directory
|
|
||||||
err := sandbox.SetPackagePath("./?.lua")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to set package path: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to load a non-existent module
|
|
||||||
err = sandbox.LoadModule("nonexistent_module")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when loading non-existent module")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper functions
|
|
||||||
|
|
||||||
// createTempLuaFile creates a temporary Lua file with the given content
|
|
||||||
func createTempLuaFile(content string) (string, error) {
|
|
||||||
tmpfile, err := os.CreateTemp("", "test-*.lua")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := tmpfile.WriteString(content); err != nil {
|
|
||||||
os.Remove(tmpfile.Name())
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tmpfile.Close(); err != nil {
|
|
||||||
os.Remove(tmpfile.Name())
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tmpfile.Name(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeTempFile removes a temporary file
|
|
||||||
func removeTempFile(path string) {
|
|
||||||
os.Remove(path)
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user