sandbox 2
This commit is contained in:
parent
f106dfd9ea
commit
faab0a2d08
181
sandbox.go
181
sandbox.go
@ -14,6 +14,7 @@ type Sandbox struct {
|
||||
mutex sync.Mutex
|
||||
initialized bool
|
||||
modules map[string]any
|
||||
functions map[string]GoFunction
|
||||
}
|
||||
|
||||
// NewSandbox creates a new sandbox with standard libraries loaded
|
||||
@ -22,6 +23,7 @@ func NewSandbox() *Sandbox {
|
||||
state: New(),
|
||||
initialized: false,
|
||||
modules: make(map[string]any),
|
||||
functions: make(map[string]GoFunction),
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,6 +59,9 @@ func (s *Sandbox) RegisterFunction(name string, fn GoFunction) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store function for re-registration
|
||||
s.functions[name] = fn
|
||||
|
||||
// Add to base environment
|
||||
return s.state.DoString(`
|
||||
-- Add the function to base environment
|
||||
@ -172,12 +177,7 @@ func (s *Sandbox) Run(code string) (any, error) {
|
||||
result, err := s.state.ToValue(-1)
|
||||
s.state.Pop(1)
|
||||
|
||||
// Handle multiple return values
|
||||
if results, ok := result.([]any); ok && len(results) == 1 {
|
||||
return results[0], err
|
||||
}
|
||||
|
||||
return result, err
|
||||
return s.unwrapResult(result, err)
|
||||
}
|
||||
|
||||
// RunFile executes a Lua file in the sandbox
|
||||
@ -225,46 +225,13 @@ func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add wrapper for multiple return values
|
||||
if err := s.state.DoString(`
|
||||
__wrap_bytecode = function(f)
|
||||
local function _wrapresults(...)
|
||||
local results = {n = select('#', ...)}
|
||||
for i = 1, results.n do
|
||||
results[i] = select(i, ...)
|
||||
end
|
||||
return results
|
||||
end
|
||||
|
||||
return function()
|
||||
return _wrapresults(f())
|
||||
end
|
||||
end
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get wrapper function
|
||||
s.state.GetGlobal("__wrap_bytecode")
|
||||
// Get the sandbox executor
|
||||
s.state.GetGlobal("__execute_sandbox")
|
||||
|
||||
// Push bytecode function
|
||||
s.state.PushCopy(-2)
|
||||
|
||||
// Call wrapper to create wrapped function
|
||||
if err := s.state.Call(1, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove original bytecode function
|
||||
s.state.Remove(-2)
|
||||
|
||||
// Get the sandbox executor
|
||||
s.state.GetGlobal("__execute_sandbox")
|
||||
|
||||
// Push wrapped function
|
||||
s.state.PushCopy(-2)
|
||||
|
||||
// Remove the wrapped function
|
||||
s.state.Remove(-3)
|
||||
|
||||
// Execute in sandbox
|
||||
@ -276,12 +243,7 @@ func (s *Sandbox) RunBytecode(bytecode []byte) (any, error) {
|
||||
result, err := s.state.ToValue(-1)
|
||||
s.state.Pop(1)
|
||||
|
||||
// Handle multiple return values
|
||||
if results, ok := result.([]any); ok && len(results) == 1 {
|
||||
return results[0], err
|
||||
}
|
||||
|
||||
return result, err
|
||||
return s.unwrapResult(result, err)
|
||||
}
|
||||
|
||||
// getResults collects results from the stack (must be called with mutex locked)
|
||||
@ -326,7 +288,22 @@ func (s *Sandbox) SetPackagePath(path string) error {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
return s.state.SetPackagePath(path)
|
||||
// Update global package.path
|
||||
if err := s.state.SetPackagePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure sandbox is initialized
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update base environment's package.path
|
||||
return s.state.DoString(`
|
||||
__env_system.base_env.package.path = package.path
|
||||
`)
|
||||
}
|
||||
|
||||
// AddPackagePath adds a path to the sandbox package.path
|
||||
@ -338,7 +315,22 @@ func (s *Sandbox) AddPackagePath(path string) error {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
return s.state.AddPackagePath(path)
|
||||
// Update global package.path
|
||||
if err := s.state.AddPackagePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure sandbox is initialized
|
||||
if !s.initialized {
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update base environment's package.path
|
||||
return s.state.DoString(`
|
||||
__env_system.base_env.package.path = package.path
|
||||
`)
|
||||
}
|
||||
|
||||
// AddModule adds a module to the sandbox environment
|
||||
@ -484,7 +476,7 @@ func (s *Sandbox) initializeUnlocked() error {
|
||||
else
|
||||
-- Create new environment with metatable inheritance
|
||||
env = setmetatable({}, {
|
||||
__index = _G -- Use global environment as fallback
|
||||
__index = __env_system.base_env -- Use base env instead of _G
|
||||
})
|
||||
end
|
||||
|
||||
@ -520,6 +512,13 @@ func (s *Sandbox) initializeUnlocked() error {
|
||||
-- Execute with protected call
|
||||
local success, result = pcall(f)
|
||||
|
||||
-- Copy all globals to base environment
|
||||
for k, v in pairs(env) do
|
||||
if k ~= "_G" and type(k) == "string" then
|
||||
__env_system.base_env[k] = v
|
||||
end
|
||||
end
|
||||
|
||||
-- Recycle environment
|
||||
__recycle_env(env)
|
||||
|
||||
@ -600,7 +599,87 @@ func (s *Sandbox) ResetEnvironment() error {
|
||||
return fmt.Errorf("sandbox is closed")
|
||||
}
|
||||
|
||||
// Clear the environment system completely
|
||||
err := s.state.DoString(`
|
||||
-- Reset environment system
|
||||
__env_system = nil
|
||||
__wrap_bytecode = nil
|
||||
__last_env = nil
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Reinitialize the environment system
|
||||
s.initialized = false
|
||||
return s.Initialize()
|
||||
if err := s.initializeUnlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Re-register all functions
|
||||
for name, fn := range s.functions {
|
||||
if err := s.state.RegisterGoFunction(name, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.state.DoString(`
|
||||
-- Add the function to base environment
|
||||
__env_system.base_env["` + name + `"] = ` + name + `
|
||||
`); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// unwrapResult processes the raw result value from Lua
|
||||
// and unwraps single values from special map structures
|
||||
func (s *Sandbox) unwrapResult(result any, err error) (any, error) {
|
||||
// Unwrap array stored in map with empty key
|
||||
if m, ok := result.(map[string]any); ok {
|
||||
// Check for special array format
|
||||
if arr, ok := m[""]; ok {
|
||||
// If the array has only one element, return that element
|
||||
if slice, ok := arr.([]float64); ok {
|
||||
if len(slice) == 1 {
|
||||
return slice[0], err
|
||||
}
|
||||
// Convert []float64 to []any for consistency
|
||||
anySlice := make([]any, len(slice))
|
||||
for i, v := range slice {
|
||||
anySlice[i] = v
|
||||
}
|
||||
return anySlice, err
|
||||
}
|
||||
if slice, ok := arr.([]any); ok && len(slice) == 1 {
|
||||
return slice[0], err
|
||||
}
|
||||
result = arr
|
||||
} else if len(m) == 1 {
|
||||
// When there's exactly one item, return its value directly
|
||||
for _, v := range m {
|
||||
return v, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert []float64 to []any for consistency with multiple returns
|
||||
if slice, ok := result.([]float64); ok {
|
||||
if len(slice) == 1 {
|
||||
return slice[0], err
|
||||
}
|
||||
anySlice := make([]any, len(slice))
|
||||
for i, v := range slice {
|
||||
anySlice[i] = v
|
||||
}
|
||||
return anySlice, err
|
||||
}
|
||||
|
||||
// Handle multiple return values
|
||||
if results, ok := result.([]any); ok && len(results) == 1 {
|
||||
return results[0], err
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
@ -648,3 +648,129 @@ func createTempLuaFile(content string) (string, error) {
|
||||
func removeTempFile(path string) {
|
||||
os.Remove(path)
|
||||
}
|
||||
|
||||
// BenchmarkSandboxLuaExecution measures the performance of executing raw Lua code
|
||||
func BenchmarkSandboxLuaExecution(b *testing.B) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Simple Lua code that does some computation
|
||||
code := `
|
||||
local sum = 0
|
||||
for i = 1, 100 do
|
||||
sum = sum + i
|
||||
end
|
||||
return sum
|
||||
`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
result, err := sandbox.Run(code)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
if result != float64(5050) {
|
||||
b.Fatalf("Incorrect result: %v", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSandboxBytecodeExecution measures the performance of executing precompiled bytecode
|
||||
func BenchmarkSandboxBytecodeExecution(b *testing.B) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Same code as above, but precompiled
|
||||
code := `
|
||||
local sum = 0
|
||||
for i = 1, 100 do
|
||||
sum = sum + i
|
||||
end
|
||||
return sum
|
||||
`
|
||||
|
||||
// Compile the bytecode once
|
||||
bytecode, err := sandbox.Compile(code)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to compile bytecode: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
result, err := sandbox.RunBytecode(bytecode)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to run bytecode: %v", err)
|
||||
}
|
||||
if result != float64(5050) {
|
||||
b.Fatalf("Incorrect result: %v", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSandboxComplexComputation measures performance with more complex computation
|
||||
func BenchmarkSandboxComplexComputation(b *testing.B) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// More complex Lua code that calculates Fibonacci numbers
|
||||
code := `
|
||||
function fibonacci(n)
|
||||
if n <= 1 then
|
||||
return n
|
||||
end
|
||||
return fibonacci(n-1) + fibonacci(n-2)
|
||||
end
|
||||
|
||||
return fibonacci(15) -- Not too high to avoid excessive runtime
|
||||
`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
result, err := sandbox.Run(code)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
if result != float64(610) {
|
||||
b.Fatalf("Incorrect result: %v", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSandboxFunctionCall measures performance of calling a registered Go function
|
||||
func BenchmarkSandboxFunctionCall(b *testing.B) {
|
||||
sandbox := luajit.NewSandbox()
|
||||
defer sandbox.Close()
|
||||
|
||||
// Register a simple addition function
|
||||
add := func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a + b)
|
||||
return 1
|
||||
}
|
||||
|
||||
err := sandbox.RegisterFunction("add", add)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to register function: %v", err)
|
||||
}
|
||||
|
||||
// Lua code that calls the Go function in a loop
|
||||
code := `
|
||||
local sum = 0
|
||||
for i = 1, 100 do
|
||||
sum = add(sum, i)
|
||||
end
|
||||
return sum
|
||||
`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
result, err := sandbox.Run(code)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to run code: %v", err)
|
||||
}
|
||||
if result != float64(5050) {
|
||||
b.Fatalf("Incorrect result: %v", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user