remove sandbox
This commit is contained in:
parent
5774808064
commit
656ac1a703
|
@ -1,133 +0,0 @@
|
|||
package luajit_bench
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// BenchmarkSandboxLuaExecution measures the performance of executing raw Lua code
|
||||
func BenchmarkSandboxLuaExecution(b *testing.B) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Simple Lua code that does some computation
|
||||
code := `
|
||||
local sum = 0
|
||||
for i = 1, 100 do
|
||||
sum = sum + i
|
||||
end
|
||||
return sum
|
||||
`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
result, err := sandbox.Run(code)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
if result != float64(5050) {
|
||||
b.Fatalf("Incorrect result: %v", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSandboxBytecodeExecution measures the performance of executing precompiled bytecode
|
||||
func BenchmarkSandboxBytecodeExecution(b *testing.B) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Same code as above, but precompiled
|
||||
code := `
|
||||
local sum = 0
|
||||
for i = 1, 100 do
|
||||
sum = sum + i
|
||||
end
|
||||
return sum
|
||||
`
|
||||
|
||||
// Compile the bytecode once
|
||||
bytecode, err := sandbox.Compile(code)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to compile bytecode: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
result, err := sandbox.RunBytecode(bytecode)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to run bytecode: %v", err)
|
||||
}
|
||||
if result != float64(5050) {
|
||||
b.Fatalf("Incorrect result: %v", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSandboxComplexComputation measures performance with more complex computation
|
||||
func BenchmarkSandboxComplexComputation(b *testing.B) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// More complex Lua code that calculates Fibonacci numbers
|
||||
code := `
|
||||
function fibonacci(n)
|
||||
if n <= 1 then
|
||||
return n
|
||||
end
|
||||
return fibonacci(n-1) + fibonacci(n-2)
|
||||
end
|
||||
|
||||
return fibonacci(15) -- Not too high to avoid excessive runtime
|
||||
`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
result, err := sandbox.Run(code)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
if result != float64(610) {
|
||||
b.Fatalf("Incorrect result: %v", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSandboxFunctionCall measures performance of calling a registered Go function
|
||||
func BenchmarkSandboxFunctionCall(b *testing.B) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Register a simple addition function
|
||||
add := func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a + b)
|
||||
return 1
|
||||
}
|
||||
|
||||
err := sandbox.RegisterFunction("add", add)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to register function: %v", err)
|
||||
}
|
||||
|
||||
// Lua code that calls the Go function in a loop
|
||||
code := `
|
||||
local sum = 0
|
||||
for i = 1, 100 do
|
||||
sum = add(sum, i)
|
||||
end
|
||||
return sum
|
||||
`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
result, err := sandbox.Run(code)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
if result != float64(5050) {
|
||||
b.Fatalf("Incorrect result: %v", result)
|
||||
}
|
||||
}
|
||||
}
|
609
sandbox.go
609
sandbox.go
|
@ -1,609 +0,0 @@
|
|||
package luajit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// LUA_MULTRET is the constant for multiple return values
|
||||
const LUA_MULTRET = -1
|
||||
|
||||
// Sandbox provides a persistent Lua environment for executing scripts
|
||||
type Sandbox struct {
|
||||
state *State
|
||||
mutex sync.Mutex
|
||||
initialized bool
|
||||
modules map[string]any
|
||||
functions map[string]GoFunction
|
||||
}
|
||||
|
||||
// NewSandbox creates a new sandbox with standard libraries loaded
|
||||
func NewSandbox() *Sandbox {
|
||||
return &Sandbox{
|
||||
state: New(),
|
||||
initialized: false,
|
||||
modules: make(map[string]any),
|
||||
functions: make(map[string]GoFunction),
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases all resources used by the sandbox
|
||||
func (s *Sandbox) Close() {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state != nil {
|
||||
s.state.Close()
|
||||
s.state = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize sets up the environment system
|
||||
func (s *Sandbox) Initialize() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
return s.initializeUnlocked()
|
||||
}
|
||||
|
||||
// initializeUnlocked sets up the environment system without locking
|
||||
func (s *Sandbox) initializeUnlocked() error {
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
if s.initialized {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register modules
|
||||
s.state.GetGlobal("__sandbox_modules")
|
||||
if s.state.IsNil(-1) {
|
||||
s.state.Pop(1)
|
||||
s.state.NewTable()
|
||||
s.state.SetGlobal("__sandbox_modules")
|
||||
s.state.GetGlobal("__sandbox_modules")
|
||||
}
|
||||
|
||||
// Add modules
|
||||
for name, module := range s.modules {
|
||||
s.state.PushString(name)
|
||||
if err := s.state.PushValue(module); err != nil {
|
||||
s.state.Pop(2)
|
||||
return err
|
||||
}
|
||||
s.state.SetTable(-3)
|
||||
}
|
||||
s.state.Pop(1)
|
||||
|
||||
// Create simplified environment system
|
||||
err := s.state.DoString(`
|
||||
-- Global shared environment
|
||||
__env_system = {
|
||||
base_env = {}, -- Template environment
|
||||
env_pool = {}, -- Pre-allocated environment pool
|
||||
pool_size = 0, -- Current pool size
|
||||
max_pool_size = 8 -- Maximum pool size
|
||||
}
|
||||
|
||||
-- Create base environment with standard libraries
|
||||
local base = __env_system.base_env
|
||||
|
||||
-- Safe standard libraries
|
||||
base.string = string
|
||||
base.table = table
|
||||
base.math = math
|
||||
base.os = {
|
||||
time = os.time,
|
||||
date = os.date,
|
||||
difftime = os.difftime,
|
||||
clock = os.clock
|
||||
}
|
||||
|
||||
-- Basic functions
|
||||
base.print = print
|
||||
base.tonumber = tonumber
|
||||
base.tostring = tostring
|
||||
base.type = type
|
||||
base.pairs = pairs
|
||||
base.ipairs = ipairs
|
||||
base.next = next
|
||||
base.select = select
|
||||
base.pcall = pcall
|
||||
base.xpcall = xpcall
|
||||
base.error = error
|
||||
base.assert = assert
|
||||
base.collectgarbage = collectgarbage
|
||||
base.unpack = unpack or table.unpack
|
||||
|
||||
-- Package system
|
||||
base.package = {
|
||||
loaded = {},
|
||||
path = package.path,
|
||||
preload = {}
|
||||
}
|
||||
|
||||
base.require = function(modname)
|
||||
if base.package.loaded[modname] then
|
||||
return base.package.loaded[modname]
|
||||
end
|
||||
|
||||
local loader = base.package.preload[modname]
|
||||
if type(loader) == "function" then
|
||||
local result = loader(modname)
|
||||
base.package.loaded[modname] = result or true
|
||||
return result
|
||||
end
|
||||
|
||||
error("module '" .. modname .. "' not found", 2)
|
||||
end
|
||||
|
||||
-- Add registered custom modules
|
||||
if __sandbox_modules then
|
||||
for name, mod in pairs(__sandbox_modules) do
|
||||
base[name] = mod
|
||||
end
|
||||
end
|
||||
|
||||
-- Get an environment for execution
|
||||
function __get_sandbox_env()
|
||||
local env
|
||||
|
||||
-- Try to reuse from pool
|
||||
if __env_system.pool_size > 0 then
|
||||
env = table.remove(__env_system.env_pool)
|
||||
__env_system.pool_size = __env_system.pool_size - 1
|
||||
else
|
||||
-- Create new environment with metatable inheritance
|
||||
env = setmetatable({}, {
|
||||
__index = __env_system.base_env
|
||||
})
|
||||
end
|
||||
|
||||
return env
|
||||
end
|
||||
|
||||
-- Return environment to pool for reuse
|
||||
function __recycle_env(env)
|
||||
if __env_system.pool_size < __env_system.max_pool_size then
|
||||
-- Clear all fields except metatable
|
||||
for k in pairs(env) do
|
||||
env[k] = nil
|
||||
end
|
||||
|
||||
-- Add to pool
|
||||
table.insert(__env_system.env_pool, env)
|
||||
__env_system.pool_size = __env_system.pool_size + 1
|
||||
end
|
||||
end
|
||||
|
||||
-- Execute code in sandbox
|
||||
function __execute_sandbox(f)
|
||||
-- Get environment
|
||||
local env = __get_sandbox_env()
|
||||
|
||||
-- Set environment for function
|
||||
setfenv(f, env)
|
||||
|
||||
-- Execute with protected call
|
||||
local success, result = pcall(f)
|
||||
|
||||
-- Update base environment with new globals
|
||||
for k, v in pairs(env) do
|
||||
if k ~= "_G" and type(k) == "string" then
|
||||
__env_system.base_env[k] = v
|
||||
end
|
||||
end
|
||||
|
||||
-- Recycle environment
|
||||
__recycle_env(env)
|
||||
|
||||
-- Process result
|
||||
if not success then
|
||||
error(result, 0)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.initialized = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterFunction registers a Go function in the sandbox
|
||||
func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Register function globally
|
||||
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store function for re-registration
|
||||
s.functions[name] = fn
|
||||
|
||||
// Add to base environment
|
||||
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
|
||||
}
|
||||
|
||||
// SetGlobal sets a global variable in the sandbox base environment
|
||||
func (s *Sandbox) SetGlobal(name string, value any) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Push the value onto the stack
|
||||
if err := s.state.PushValue(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the global with the pushed value
|
||||
s.state.SetGlobal(name)
|
||||
|
||||
// Add to base environment
|
||||
return s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name)
|
||||
}
|
||||
|
||||
// GetGlobal retrieves a global variable from the sandbox base environment
|
||||
func (s *Sandbox) GetGlobal(name string) (any, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get the global from the base environment
|
||||
return s.state.ExecuteWithResult(`return __env_system.base_env["` + name + `"]`)
|
||||
}
|
||||
|
||||
// Run executes Lua code in the sandbox
|
||||
func (s *Sandbox) Run(code string) (any, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Simplified wrapper for multiple return values
|
||||
wrappedCode := `
|
||||
local function _execfunc()
|
||||
` + code + `
|
||||
end
|
||||
|
||||
-- Process results to match expected format
|
||||
local function _wrapresults(...)
|
||||
local n = select('#', ...)
|
||||
if n == 0 then
|
||||
return nil
|
||||
elseif n == 1 then
|
||||
return select(1, ...)
|
||||
else
|
||||
local results = {}
|
||||
for i = 1, n do
|
||||
results[i] = select(i, ...)
|
||||
end
|
||||
return results
|
||||
end
|
||||
end
|
||||
|
||||
return _wrapresults(_execfunc())
|
||||
`
|
||||
|
||||
// Compile the code
|
||||
if err := s.state.LoadString(wrappedCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the sandbox executor
|
||||
s.state.GetGlobal("__execute_sandbox")
|
||||
|
||||
// Push the function as argument
|
||||
s.state.PushCopy(-2)
|
||||
s.state.Remove(-3)
|
||||
|
||||
// Execute in sandbox
|
||||
if err := s.state.Call(1, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get result
|
||||
result, err := s.state.ToValue(-1)
|
||||
s.state.Pop(1)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.processResult(result), nil
|
||||
}
|
||||
|
||||
// RunFile executes a Lua file in the sandbox
|
||||
func (s *Sandbox) RunFile(filename string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
return s.state.DoFile(filename)
|
||||
}
|
||||
|
||||
// Compile compiles Lua code to bytecode
|
||||
func (s *Sandbox) Compile(code string) ([]byte, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
return s.state.CompileBytecode(code, "sandbox")
|
||||
}
|
||||
|
||||
// RunBytecode executes precompiled Lua bytecode
|
||||
func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return nil, fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Load the bytecode
|
||||
if err := s.state.LoadBytecode(bytecode, "sandbox"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the sandbox executor
|
||||
s.state.GetGlobal("__execute_sandbox")
|
||||
|
||||
// Push bytecode function
|
||||
s.state.PushCopy(-2)
|
||||
s.state.Remove(-3)
|
||||
|
||||
// Execute in sandbox
|
||||
if err := s.state.Call(1, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get result
|
||||
result, err := s.state.ToValue(-1)
|
||||
s.state.Pop(1)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.processResult(result), nil
|
||||
}
|
||||
|
||||
// LoadModule loads a Lua module
|
||||
func (s *Sandbox) LoadModule(name string) error {
|
||||
code := fmt.Sprintf("require('%s')", name)
|
||||
_, err := s.Run(code)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetPackagePath sets the sandbox package.path
|
||||
func (s *Sandbox) SetPackagePath(path string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Update global package.path
|
||||
if err := s.state.SetPackagePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update base environment's package.path
|
||||
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
|
||||
}
|
||||
|
||||
// AddPackagePath adds a path to the sandbox package.path
|
||||
func (s *Sandbox) AddPackagePath(path string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Update global package.path
|
||||
if err := s.state.AddPackagePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update base environment's package.path
|
||||
return s.state.DoString(`__env_system.base_env.package.path = package.path`)
|
||||
}
|
||||
|
||||
// AddModule adds a module to the sandbox environment
|
||||
func (s *Sandbox) AddModule(name string, module any) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
s.modules[name] = module
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPermanentLua adds Lua code to the environment permanently
|
||||
func (s *Sandbox) AddPermanentLua(code string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Initialize if needed
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Simplified approach to add code to base environment
|
||||
return s.state.DoString(`
|
||||
local f, err = loadstring([=[` + code + `]=], "permanent")
|
||||
if not f then error(err, 0) end
|
||||
|
||||
local env = setmetatable({}, {__index = __env_system.base_env})
|
||||
setfenv(f, env)
|
||||
|
||||
local ok, err = pcall(f)
|
||||
if not ok then error(err, 0) end
|
||||
|
||||
for k, v in pairs(env) do
|
||||
__env_system.base_env[k] = v
|
||||
end
|
||||
`)
|
||||
}
|
||||
|
||||
// ResetEnvironment resets the sandbox to its initial state
|
||||
func (s *Sandbox) ResetEnvironment() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.state == nil {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Clear the environment system
|
||||
s.state.DoString(`__env_system = nil`)
|
||||
|
||||
// Reinitialize
|
||||
s.initialized = false
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Re-register all functions
|
||||
for name, fn := range s.functions {
|
||||
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.state.DoString(`__env_system.base_env["` + name + `"] = ` + name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// unwrapResult processes results from Lua executions
|
||||
func (s *Sandbox) processResult(result any) any {
|
||||
// Handle []float64 (common LuaJIT return type)
|
||||
if floats, ok := result.([]float64); ok {
|
||||
if len(floats) == 1 {
|
||||
// Single number - return as float64
|
||||
return floats[0]
|
||||
}
|
||||
// Multiple numbers - MUST convert to []any for tests to pass
|
||||
anySlice := make([]any, len(floats))
|
||||
for i, v := range floats {
|
||||
anySlice[i] = v
|
||||
}
|
||||
return anySlice
|
||||
}
|
||||
|
||||
// Handle maps with numeric keys (Lua tables)
|
||||
if m, ok := result.(map[string]any); ok {
|
||||
// Handle return tables with special structure
|
||||
if vals, ok := m[""]; ok {
|
||||
// This is a special case used by some Lua returns
|
||||
if arr, ok := vals.([]float64); ok {
|
||||
// Convert to []any for consistency
|
||||
anySlice := make([]any, len(arr))
|
||||
for i, v := range arr {
|
||||
anySlice[i] = v
|
||||
}
|
||||
return anySlice
|
||||
}
|
||||
return vals
|
||||
}
|
||||
|
||||
if len(m) == 1 {
|
||||
// Check for single value map
|
||||
for k, v := range m {
|
||||
if k == "1" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Other array types should be preserved
|
||||
return result
|
||||
}
|
|
@ -1,650 +0,0 @@
|
|||
package luajit_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// TestSandboxLifecycle tests sandbox creation and closing
|
||||
func TestSandboxLifecycle(t *testing.T) {
|
||||
// Create a new sandbox
|
||||
sandbox := luajit.NewSandbox()
|
||||
if sandbox == nil {
|
||||
t.Fatal("Failed to create sandbox")
|
||||
}
|
||||
|
||||
// Close the sandbox
|
||||
sandbox.Close()
|
||||
|
||||
// Test idempotent close (should not panic)
|
||||
sandbox.Close()
|
||||
}
|
||||
|
||||
// TestSandboxFunctionRegistration tests registering Go functions in the sandbox
|
||||
func TestSandboxFunctionRegistration(t *testing.T) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Register a simple addition function
|
||||
add := func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a + b)
|
||||
return 1
|
||||
}
|
||||
|
||||
err := sandbox.RegisterFunction("add", add)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register function: %v", err)
|
||||
}
|
||||
|
||||
// Test the function
|
||||
result, err := sandbox.Run("return add(3, 4)")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute function: %v", err)
|
||||
}
|
||||
|
||||
if result != float64(7) {
|
||||
t.Fatalf("Expected 7, got %v", result)
|
||||
}
|
||||
|
||||
// Test after sandbox is closed
|
||||
sandbox.Close()
|
||||
err = sandbox.RegisterFunction("test", add)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when registering function on closed sandbox")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSandboxGlobalVariables tests setting and getting global variables
|
||||
func TestSandboxGlobalVariables(t *testing.T) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Set a global variable
|
||||
err := sandbox.SetGlobal("answer", 42)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set global: %v", err)
|
||||
}
|
||||
|
||||
// Get the global variable
|
||||
value, err := sandbox.GetGlobal("answer")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get global: %v", err)
|
||||
}
|
||||
|
||||
if value != float64(42) {
|
||||
t.Fatalf("Expected 42, got %v", value)
|
||||
}
|
||||
|
||||
// Test different types
|
||||
testCases := []struct {
|
||||
name string
|
||||
value any
|
||||
}{
|
||||
{"nil_value", nil},
|
||||
{"bool_value", true},
|
||||
{"string_value", "hello"},
|
||||
{"table_value", map[string]any{"key": "value"}},
|
||||
{"array_value", []float64{1, 2, 3}},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
err := sandbox.SetGlobal(tc.name, tc.value)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set global %s: %v", tc.name, err)
|
||||
}
|
||||
|
||||
value, err := sandbox.GetGlobal(tc.name)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get global %s: %v", tc.name, err)
|
||||
}
|
||||
|
||||
// For tables/arrays, just check they're not nil
|
||||
switch tc.value.(type) {
|
||||
case map[string]any, []float64:
|
||||
if value == nil {
|
||||
t.Fatalf("Expected non-nil for %s, got nil", tc.name)
|
||||
}
|
||||
default:
|
||||
if value != tc.value && !(tc.value == nil && value == nil) {
|
||||
t.Fatalf("For %s: expected %v, got %v", tc.name, tc.value, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test after sandbox is closed
|
||||
sandbox.Close()
|
||||
err = sandbox.SetGlobal("test", 123)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when setting global on closed sandbox")
|
||||
}
|
||||
|
||||
_, err = sandbox.GetGlobal("test")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when getting global from closed sandbox")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSandboxCodeExecution tests running Lua code in the sandbox
|
||||
func TestSandboxCodeExecution(t *testing.T) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Run simple code
|
||||
result, err := sandbox.Run("return 'hello'")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
|
||||
if result != "hello" {
|
||||
t.Fatalf("Expected 'hello', got %v", result)
|
||||
}
|
||||
|
||||
// Run code with multiple return values
|
||||
result, err = sandbox.Run("return 1, 2, 3")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
|
||||
// Should return array for multiple values
|
||||
results, ok := result.([]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected array for multiple returns, got %T", result)
|
||||
}
|
||||
|
||||
if len(results) != 3 || results[0] != float64(1) || results[1] != float64(2) || results[2] != float64(3) {
|
||||
t.Fatalf("Expected [1, 2, 3], got %v", results)
|
||||
}
|
||||
|
||||
// Run code that sets a global
|
||||
_, err = sandbox.Run("global_var = 'set from Lua'")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
|
||||
value, err := sandbox.GetGlobal("global_var")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get global: %v", err)
|
||||
}
|
||||
|
||||
if value != "set from Lua" {
|
||||
t.Fatalf("Expected 'set from Lua', got %v", value)
|
||||
}
|
||||
|
||||
// Run invalid code
|
||||
_, err = sandbox.Run("this is not valid Lua")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid code")
|
||||
}
|
||||
|
||||
// Test after sandbox is closed
|
||||
sandbox.Close()
|
||||
_, err = sandbox.Run("return true")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when running code on closed sandbox")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSandboxBytecodeExecution tests bytecode compilation and execution
|
||||
func TestSandboxBytecodeExecution(t *testing.T) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Compile code to bytecode
|
||||
code := `
|
||||
local function greet(name)
|
||||
return "Hello, " .. name
|
||||
end
|
||||
return greet("World")
|
||||
`
|
||||
bytecode, err := sandbox.Compile(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile bytecode: %v", err)
|
||||
}
|
||||
|
||||
if len(bytecode) == 0 {
|
||||
t.Fatal("Expected non-empty bytecode")
|
||||
}
|
||||
|
||||
// Run the bytecode
|
||||
result, err := sandbox.RunBytecode(bytecode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run bytecode: %v", err)
|
||||
}
|
||||
|
||||
if result != "Hello, World" {
|
||||
t.Fatalf("Expected 'Hello, World', got %v", result)
|
||||
}
|
||||
|
||||
// Test bytecode that sets a global
|
||||
bytecode, err = sandbox.Compile("bytecode_var = 42")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile bytecode: %v", err)
|
||||
}
|
||||
|
||||
_, err = sandbox.RunBytecode(bytecode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run bytecode: %v", err)
|
||||
}
|
||||
|
||||
value, err := sandbox.GetGlobal("bytecode_var")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get global: %v", err)
|
||||
}
|
||||
|
||||
if value != float64(42) {
|
||||
t.Fatalf("Expected 42, got %v", value)
|
||||
}
|
||||
|
||||
// Test invalid bytecode
|
||||
_, err = sandbox.RunBytecode([]byte("not valid bytecode"))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid bytecode")
|
||||
}
|
||||
|
||||
// Test after sandbox is closed
|
||||
sandbox.Close()
|
||||
_, err = sandbox.Compile("return true")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when compiling on closed sandbox")
|
||||
}
|
||||
|
||||
_, err = sandbox.RunBytecode(bytecode)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when running bytecode on closed sandbox")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSandboxPersistence tests state persistence across executions
|
||||
func TestSandboxPersistence(t *testing.T) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Set up initial state
|
||||
_, err := sandbox.Run(`
|
||||
counter = 0
|
||||
function increment()
|
||||
counter = counter + 1
|
||||
return counter
|
||||
end
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set up state: %v", err)
|
||||
}
|
||||
|
||||
// Run multiple executions
|
||||
for i := 1; i <= 3; i++ {
|
||||
result, err := sandbox.Run("return increment()")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
|
||||
if result != float64(i) {
|
||||
t.Fatalf("Expected %d, got %v", i, result)
|
||||
}
|
||||
}
|
||||
|
||||
// Check final counter value
|
||||
value, err := sandbox.GetGlobal("counter")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get global: %v", err)
|
||||
}
|
||||
|
||||
if value != float64(3) {
|
||||
t.Fatalf("Expected final counter to be 3, got %v", value)
|
||||
}
|
||||
|
||||
// Test persistence with bytecode
|
||||
bytecode, err := sandbox.Compile("return counter + 1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile bytecode: %v", err)
|
||||
}
|
||||
|
||||
result, err := sandbox.RunBytecode(bytecode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run bytecode: %v", err)
|
||||
}
|
||||
|
||||
if result != float64(4) {
|
||||
t.Fatalf("Expected 4, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSandboxConcurrency tests concurrent access to the sandbox
|
||||
func TestSandboxConcurrency(t *testing.T) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Set up a counter
|
||||
_, err := sandbox.Run("counter = 0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set up counter: %v", err)
|
||||
}
|
||||
|
||||
// Run concurrent increments
|
||||
const numGoroutines = 10
|
||||
const incrementsPerGoroutine = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < incrementsPerGoroutine; j++ {
|
||||
_, err := sandbox.Run("counter = counter + 1")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to increment counter: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check the final counter value
|
||||
value, err := sandbox.GetGlobal("counter")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get counter: %v", err)
|
||||
}
|
||||
|
||||
expected := float64(numGoroutines * incrementsPerGoroutine)
|
||||
if value != expected {
|
||||
t.Fatalf("Expected counter to be %v, got %v", expected, value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPermanentLua tests the AddPermanentLua method
|
||||
func TestPermanentLua(t *testing.T) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Add permanent Lua environment
|
||||
err := sandbox.AddPermanentLua(`
|
||||
-- Create utility functions
|
||||
function double(x)
|
||||
return x * 2
|
||||
end
|
||||
|
||||
function square(x)
|
||||
return x * x
|
||||
end
|
||||
|
||||
-- Create a protected environment
|
||||
env = {
|
||||
add = function(a, b) return a + b end,
|
||||
sub = function(a, b) return a - b end
|
||||
}
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add permanent Lua: %v", err)
|
||||
}
|
||||
|
||||
// Test using the permanent functions
|
||||
testCases := []struct {
|
||||
code string
|
||||
expected float64
|
||||
}{
|
||||
{"return double(5)", 10},
|
||||
{"return square(4)", 16},
|
||||
{"return env.add(10, 20)", 30},
|
||||
{"return env.sub(50, 30)", 20},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
result, err := sandbox.Run(tc.code)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run code '%s': %v", tc.code, err)
|
||||
}
|
||||
|
||||
if result != tc.expected {
|
||||
t.Fatalf("For '%s': expected %v, got %v", tc.code, tc.expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// Test persistence of permanent code across executions
|
||||
_, err = sandbox.Run("counter = 0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set counter: %v", err)
|
||||
}
|
||||
|
||||
result, err := sandbox.Run("counter = counter + 1; return double(counter)")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
|
||||
if result != float64(2) {
|
||||
t.Fatalf("Expected 2, got %v", result)
|
||||
}
|
||||
|
||||
// Test after sandbox is closed
|
||||
sandbox.Close()
|
||||
err = sandbox.AddPermanentLua("function test() end")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when adding permanent Lua to closed sandbox")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResetEnvironment tests the ResetEnvironment method
|
||||
func TestResetEnvironment(t *testing.T) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Set up some Go functions and Lua code
|
||||
sandbox.RegisterFunction("timeNow", func(s *luajit.State) int {
|
||||
s.PushString("test")
|
||||
return 1
|
||||
})
|
||||
|
||||
sandbox.AddPermanentLua(`
|
||||
function permanent()
|
||||
return "permanent function"
|
||||
end
|
||||
`)
|
||||
|
||||
_, err := sandbox.Run(`
|
||||
temp_var = "will be reset"
|
||||
function temp_func()
|
||||
return "temp function"
|
||||
end
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run setup code: %v", err)
|
||||
}
|
||||
|
||||
// Verify everything is set up correctly
|
||||
result, err := sandbox.Run("return timeNow()")
|
||||
if err != nil || result != "test" {
|
||||
t.Fatalf("Go function not working: %v, %v", result, err)
|
||||
}
|
||||
|
||||
result, err = sandbox.Run("return permanent()")
|
||||
if err != nil || result != "permanent function" {
|
||||
t.Fatalf("Permanent function not working: %v, %v", result, err)
|
||||
}
|
||||
|
||||
result, err = sandbox.Run("return temp_func()")
|
||||
if err != nil || result != "temp function" {
|
||||
t.Fatalf("Temp function not working: %v, %v", result, err)
|
||||
}
|
||||
|
||||
value, err := sandbox.GetGlobal("temp_var")
|
||||
if err != nil || value != "will be reset" {
|
||||
t.Fatalf("Temp var not set correctly: %v, %v", value, err)
|
||||
}
|
||||
|
||||
// Reset the environment
|
||||
err = sandbox.ResetEnvironment()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to reset environment: %v", err)
|
||||
}
|
||||
|
||||
// Check Go function survives reset
|
||||
result, err = sandbox.Run("return timeNow()")
|
||||
if err != nil || result != "test" {
|
||||
t.Fatalf("Go function should survive reset: %v, %v", result, err)
|
||||
}
|
||||
|
||||
// Check permanent function is gone (it was added with AddPermanentLua but reset removes it)
|
||||
_, err = sandbox.Run("return permanent()")
|
||||
if err == nil {
|
||||
t.Fatal("Permanent function should be gone after reset")
|
||||
}
|
||||
|
||||
// Check temp function is gone
|
||||
_, err = sandbox.Run("return temp_func()")
|
||||
if err == nil {
|
||||
t.Fatal("Temp function should be gone after reset")
|
||||
}
|
||||
|
||||
// Check temp var is gone
|
||||
value, err = sandbox.GetGlobal("temp_var")
|
||||
if err != nil || value != nil {
|
||||
t.Fatalf("Temp var should be nil after reset: %v", value)
|
||||
}
|
||||
|
||||
// Test after sandbox is closed
|
||||
sandbox.Close()
|
||||
err = sandbox.ResetEnvironment()
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when resetting closed sandbox")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunFile tests the RunFile method
|
||||
func TestRunFile(t *testing.T) {
|
||||
// Create a temporary Lua file
|
||||
tmpfile, err := createTempLuaFile("return 42")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
defer removeTempFile(tmpfile)
|
||||
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Run the file
|
||||
err = sandbox.RunFile(tmpfile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run file: %v", err)
|
||||
}
|
||||
|
||||
// Test non-existent file
|
||||
err = sandbox.RunFile("does_not_exist.lua")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for non-existent file")
|
||||
}
|
||||
|
||||
// Test after sandbox is closed
|
||||
sandbox.Close()
|
||||
err = sandbox.RunFile(tmpfile)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when running file on closed sandbox")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSandboxPackagePath tests the SetPackagePath and AddPackagePath methods
|
||||
func TestSandboxPackagePath(t *testing.T) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Set package path
|
||||
testPath := "/test/path/?.lua"
|
||||
err := sandbox.SetPackagePath(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set package path: %v", err)
|
||||
}
|
||||
|
||||
// Check path was set
|
||||
result, err := sandbox.Run("return package.path")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get package.path: %v", err)
|
||||
}
|
||||
|
||||
if result != testPath {
|
||||
t.Fatalf("Expected package.path to be %q, got %q", testPath, result)
|
||||
}
|
||||
|
||||
// Add to package path
|
||||
addPath := "/another/path/?.lua"
|
||||
err = sandbox.AddPackagePath(addPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add package path: %v", err)
|
||||
}
|
||||
|
||||
// Check path was updated
|
||||
result, err = sandbox.Run("return package.path")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated package.path: %v", err)
|
||||
}
|
||||
|
||||
expected := testPath + ";" + addPath
|
||||
if result != expected {
|
||||
t.Fatalf("Expected package.path to be %q, got %q", expected, result)
|
||||
}
|
||||
|
||||
// Test after sandbox is closed
|
||||
sandbox.Close()
|
||||
err = sandbox.SetPackagePath(testPath)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when setting package path on closed sandbox")
|
||||
}
|
||||
|
||||
err = sandbox.AddPackagePath(addPath)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when adding package path to closed sandbox")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSandboxLoadModule tests loading modules
|
||||
func TestSandboxLoadModule(t *testing.T) {
|
||||
// Skip for now since we don't have actual modules to load in the test environment
|
||||
t.Skip("Skipping module loading test as it requires actual modules")
|
||||
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Set package path to include current directory
|
||||
err := sandbox.SetPackagePath("./?.lua")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set package path: %v", err)
|
||||
}
|
||||
|
||||
// Try to load a non-existent module
|
||||
err = sandbox.LoadModule("nonexistent_module")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when loading non-existent module")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// createTempLuaFile creates a temporary Lua file with the given content
|
||||
func createTempLuaFile(content string) (string, error) {
|
||||
tmpfile, err := os.CreateTemp("", "test-*.lua")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if _, err := tmpfile.WriteString(content); err != nil {
|
||||
os.Remove(tmpfile.Name())
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
os.Remove(tmpfile.Name())
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tmpfile.Name(), nil
|
||||
}
|
||||
|
||||
// removeTempFile removes a temporary file
|
||||
func removeTempFile(path string) {
|
||||
os.Remove(path)
|
||||
}
|
Loading…
Reference in New Issue
Block a user