sandbox 2
This commit is contained in:
parent
f106dfd9ea
commit
faab0a2d08
181
sandbox.go
181
sandbox.go
@ -14,6 +14,7 @@ type Sandbox struct {
|
|||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
initialized bool
|
initialized bool
|
||||||
modules map[string]any
|
modules map[string]any
|
||||||
|
functions map[string]GoFunction
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSandbox creates a new sandbox with standard libraries loaded
|
// NewSandbox creates a new sandbox with standard libraries loaded
|
||||||
@ -22,6 +23,7 @@ func NewSandbox() *Sandbox {
|
|||||||
state: New(),
|
state: New(),
|
||||||
initialized: false,
|
initialized: false,
|
||||||
modules: make(map[string]any),
|
modules: make(map[string]any),
|
||||||
|
functions: make(map[string]GoFunction),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,6 +59,9 @@ func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store function for re-registration
|
||||||
|
s.functions[name] = fn
|
||||||
|
|
||||||
// Add to base environment
|
// Add to base environment
|
||||||
return s.state.DoString(`
|
return s.state.DoString(`
|
||||||
-- Add the function to base environment
|
-- Add the function to base environment
|
||||||
@ -172,12 +177,7 @@ func (s *Sandbox) Run(code string) (any, error) {
|
|||||||
result, err := s.state.ToValue(-1)
|
result, err := s.state.ToValue(-1)
|
||||||
s.state.Pop(1)
|
s.state.Pop(1)
|
||||||
|
|
||||||
// Handle multiple return values
|
return s.unwrapResult(result, err)
|
||||||
if results, ok := result.([]any); ok && len(results) == 1 {
|
|
||||||
return results[0], err
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunFile executes a Lua file in the sandbox
|
// RunFile executes a Lua file in the sandbox
|
||||||
@ -225,46 +225,13 @@ func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add wrapper for multiple return values
|
// Get the sandbox executor
|
||||||
if err := s.state.DoString(`
|
s.state.GetGlobal("__execute_sandbox")
|
||||||
__wrap_bytecode = function(f)
|
|
||||||
local function _wrapresults(...)
|
|
||||||
local results = {n = select('#', ...)}
|
|
||||||
for i = 1, results.n do
|
|
||||||
results[i] = select(i, ...)
|
|
||||||
end
|
|
||||||
return results
|
|
||||||
end
|
|
||||||
|
|
||||||
return function()
|
|
||||||
return _wrapresults(f())
|
|
||||||
end
|
|
||||||
end
|
|
||||||
`); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get wrapper function
|
|
||||||
s.state.GetGlobal("__wrap_bytecode")
|
|
||||||
|
|
||||||
// Push bytecode function
|
// Push bytecode function
|
||||||
s.state.PushCopy(-2)
|
s.state.PushCopy(-2)
|
||||||
|
|
||||||
// Call wrapper to create wrapped function
|
|
||||||
if err := s.state.Call(1, 1); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove original bytecode function
|
// Remove original bytecode function
|
||||||
s.state.Remove(-2)
|
|
||||||
|
|
||||||
// Get the sandbox executor
|
|
||||||
s.state.GetGlobal("__execute_sandbox")
|
|
||||||
|
|
||||||
// Push wrapped function
|
|
||||||
s.state.PushCopy(-2)
|
|
||||||
|
|
||||||
// Remove the wrapped function
|
|
||||||
s.state.Remove(-3)
|
s.state.Remove(-3)
|
||||||
|
|
||||||
// Execute in sandbox
|
// Execute in sandbox
|
||||||
@ -276,12 +243,7 @@ func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
|||||||
result, err := s.state.ToValue(-1)
|
result, err := s.state.ToValue(-1)
|
||||||
s.state.Pop(1)
|
s.state.Pop(1)
|
||||||
|
|
||||||
// Handle multiple return values
|
return s.unwrapResult(result, err)
|
||||||
if results, ok := result.([]any); ok && len(results) == 1 {
|
|
||||||
return results[0], err
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getResults collects results from the stack (must be called with mutex locked)
|
// getResults collects results from the stack (must be called with mutex locked)
|
||||||
@ -326,7 +288,22 @@ func (s *Sandbox) SetPackagePath(path string) error {
|
|||||||
return fmt.Errorf("sandbox is closed")
|
return fmt.Errorf("sandbox is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.state.SetPackagePath(path)
|
// Update global package.path
|
||||||
|
if err := s.state.SetPackagePath(path); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure sandbox is initialized
|
||||||
|
if !s.initialized {
|
||||||
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update base environment's package.path
|
||||||
|
return s.state.DoString(`
|
||||||
|
__env_system.base_env.package.path = package.path
|
||||||
|
`)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPackagePath adds a path to the sandbox package.path
|
// AddPackagePath adds a path to the sandbox package.path
|
||||||
@ -338,7 +315,22 @@ func (s *Sandbox) AddPackagePath(path string) error {
|
|||||||
return fmt.Errorf("sandbox is closed")
|
return fmt.Errorf("sandbox is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.state.AddPackagePath(path)
|
// Update global package.path
|
||||||
|
if err := s.state.AddPackagePath(path); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure sandbox is initialized
|
||||||
|
if !s.initialized {
|
||||||
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update base environment's package.path
|
||||||
|
return s.state.DoString(`
|
||||||
|
__env_system.base_env.package.path = package.path
|
||||||
|
`)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddModule adds a module to the sandbox environment
|
// AddModule adds a module to the sandbox environment
|
||||||
@ -484,7 +476,7 @@ func (s *Sandbox) initializeUnlocked() error {
|
|||||||
else
|
else
|
||||||
-- Create new environment with metatable inheritance
|
-- Create new environment with metatable inheritance
|
||||||
env = setmetatable({}, {
|
env = setmetatable({}, {
|
||||||
__index = _G -- Use global environment as fallback
|
__index = __env_system.base_env -- Use base env instead of _G
|
||||||
})
|
})
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -520,6 +512,13 @@ func (s *Sandbox) initializeUnlocked() error {
|
|||||||
-- Execute with protected call
|
-- Execute with protected call
|
||||||
local success, result = pcall(f)
|
local success, result = pcall(f)
|
||||||
|
|
||||||
|
-- Copy all globals to base environment
|
||||||
|
for k, v in pairs(env) do
|
||||||
|
if k ~= "_G" and type(k) == "string" then
|
||||||
|
__env_system.base_env[k] = v
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
-- Recycle environment
|
-- Recycle environment
|
||||||
__recycle_env(env)
|
__recycle_env(env)
|
||||||
|
|
||||||
@ -600,7 +599,87 @@ func (s *Sandbox) ResetEnvironment() error {
|
|||||||
return fmt.Errorf("sandbox is closed")
|
return fmt.Errorf("sandbox is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear the environment system completely
|
||||||
|
err := s.state.DoString(`
|
||||||
|
-- Reset environment system
|
||||||
|
__env_system = nil
|
||||||
|
__wrap_bytecode = nil
|
||||||
|
__last_env = nil
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Reinitialize the environment system
|
// Reinitialize the environment system
|
||||||
s.initialized = false
|
s.initialized = false
|
||||||
return s.Initialize()
|
if err := s.initializeUnlocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-register all functions
|
||||||
|
for name, fn := range s.functions {
|
||||||
|
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.state.DoString(`
|
||||||
|
-- Add the function to base environment
|
||||||
|
__env_system.base_env["` + name + `"] = ` + name + `
|
||||||
|
`); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// unwrapResult processes the raw result value from Lua
|
||||||
|
// and unwraps single values from special map structures
|
||||||
|
func (s *Sandbox) unwrapResult(result any, err error) (any, error) {
|
||||||
|
// Unwrap array stored in map with empty key
|
||||||
|
if m, ok := result.(map[string]any); ok {
|
||||||
|
// Check for special array format
|
||||||
|
if arr, ok := m[""]; ok {
|
||||||
|
// If the array has only one element, return that element
|
||||||
|
if slice, ok := arr.([]float64); ok {
|
||||||
|
if len(slice) == 1 {
|
||||||
|
return slice[0], err
|
||||||
|
}
|
||||||
|
// Convert []float64 to []any for consistency
|
||||||
|
anySlice := make([]any, len(slice))
|
||||||
|
for i, v := range slice {
|
||||||
|
anySlice[i] = v
|
||||||
|
}
|
||||||
|
return anySlice, err
|
||||||
|
}
|
||||||
|
if slice, ok := arr.([]any); ok && len(slice) == 1 {
|
||||||
|
return slice[0], err
|
||||||
|
}
|
||||||
|
result = arr
|
||||||
|
} else if len(m) == 1 {
|
||||||
|
// When there's exactly one item, return its value directly
|
||||||
|
for _, v := range m {
|
||||||
|
return v, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert []float64 to []any for consistency with multiple returns
|
||||||
|
if slice, ok := result.([]float64); ok {
|
||||||
|
if len(slice) == 1 {
|
||||||
|
return slice[0], err
|
||||||
|
}
|
||||||
|
anySlice := make([]any, len(slice))
|
||||||
|
for i, v := range slice {
|
||||||
|
anySlice[i] = v
|
||||||
|
}
|
||||||
|
return anySlice, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle multiple return values
|
||||||
|
if results, ok := result.([]any); ok && len(results) == 1 {
|
||||||
|
return results[0], err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
@ -648,3 +648,129 @@ func createTempLuaFile(content string) (string, error) {
|
|||||||
func removeTempFile(path string) {
|
func removeTempFile(path string) {
|
||||||
os.Remove(path)
|
os.Remove(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BenchmarkSandboxLuaExecution measures the performance of executing raw Lua code
|
||||||
|
func BenchmarkSandboxLuaExecution(b *testing.B) {
|
||||||
|
sandbox := luajit.NewSandbox()
|
||||||
|
defer sandbox.Close()
|
||||||
|
|
||||||
|
// Simple Lua code that does some computation
|
||||||
|
code := `
|
||||||
|
local sum = 0
|
||||||
|
for i = 1, 100 do
|
||||||
|
sum = sum + i
|
||||||
|
end
|
||||||
|
return sum
|
||||||
|
`
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result, err := sandbox.Run(code)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to run code: %v", err)
|
||||||
|
}
|
||||||
|
if result != float64(5050) {
|
||||||
|
b.Fatalf("Incorrect result: %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkSandboxBytecodeExecution measures the performance of executing precompiled bytecode
|
||||||
|
func BenchmarkSandboxBytecodeExecution(b *testing.B) {
|
||||||
|
sandbox := luajit.NewSandbox()
|
||||||
|
defer sandbox.Close()
|
||||||
|
|
||||||
|
// Same code as above, but precompiled
|
||||||
|
code := `
|
||||||
|
local sum = 0
|
||||||
|
for i = 1, 100 do
|
||||||
|
sum = sum + i
|
||||||
|
end
|
||||||
|
return sum
|
||||||
|
`
|
||||||
|
|
||||||
|
// Compile the bytecode once
|
||||||
|
bytecode, err := sandbox.Compile(code)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to compile bytecode: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result, err := sandbox.RunBytecode(bytecode)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to run bytecode: %v", err)
|
||||||
|
}
|
||||||
|
if result != float64(5050) {
|
||||||
|
b.Fatalf("Incorrect result: %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkSandboxComplexComputation measures performance with more complex computation
|
||||||
|
func BenchmarkSandboxComplexComputation(b *testing.B) {
|
||||||
|
sandbox := luajit.NewSandbox()
|
||||||
|
defer sandbox.Close()
|
||||||
|
|
||||||
|
// More complex Lua code that calculates Fibonacci numbers
|
||||||
|
code := `
|
||||||
|
function fibonacci(n)
|
||||||
|
if n <= 1 then
|
||||||
|
return n
|
||||||
|
end
|
||||||
|
return fibonacci(n-1) + fibonacci(n-2)
|
||||||
|
end
|
||||||
|
|
||||||
|
return fibonacci(15) -- Not too high to avoid excessive runtime
|
||||||
|
`
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result, err := sandbox.Run(code)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to run code: %v", err)
|
||||||
|
}
|
||||||
|
if result != float64(610) {
|
||||||
|
b.Fatalf("Incorrect result: %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkSandboxFunctionCall measures performance of calling a registered Go function
|
||||||
|
func BenchmarkSandboxFunctionCall(b *testing.B) {
|
||||||
|
sandbox := luajit.NewSandbox()
|
||||||
|
defer sandbox.Close()
|
||||||
|
|
||||||
|
// Register a simple addition function
|
||||||
|
add := func(s *luajit.State) int {
|
||||||
|
a := s.ToNumber(1)
|
||||||
|
b := s.ToNumber(2)
|
||||||
|
s.PushNumber(a + b)
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sandbox.RegisterFunction("add", add)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to register function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lua code that calls the Go function in a loop
|
||||||
|
code := `
|
||||||
|
local sum = 0
|
||||||
|
for i = 1, 100 do
|
||||||
|
sum = add(sum, i)
|
||||||
|
end
|
||||||
|
return sum
|
||||||
|
`
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result, err := sandbox.Run(code)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to run code: %v", err)
|
||||||
|
}
|
||||||
|
if result != float64(5050) {
|
||||||
|
b.Fatalf("Incorrect result: %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user