// Moonshark/core/runner/sandbox.go
// 2025-03-21 22:25:05 -05:00
//
// 229 lines
// 5.1 KiB
// Go

package runner
import (
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Sandbox manages the Lua environments that scripts run in.
//
// Modules added with AddModule are copied into the Lua global table
// __sandbox_modules by Setup (and by the first Execute). The zero value is
// ready to use; NewSandbox returns an equivalent, pre-allocated instance.
type Sandbox struct {
	modules     map[string]any // custom modules keyed by the global name they get in the sandbox
	initialized bool           // set once Execute has pushed the modules into a Lua state
}

// NewSandbox creates a Sandbox with no custom modules.
func NewSandbox() *Sandbox {
	return &Sandbox{modules: make(map[string]any)}
}

// AddModule records module under name, replacing any module previously added
// under the same name.
//
// The module is only pushed into a Lua state by Setup or the first Execute.
// Modules added after those calls are not seen by a state until Setup runs
// on it again.
func (s *Sandbox) AddModule(name string, module any) {
	// Lazily allocate so that a zero-value Sandbox does not panic on a nil map.
	if s.modules == nil {
		s.modules = make(map[string]any)
	}
	s.modules[name] = module
}
// Setup installs the sandbox machinery into state.
//
// It first publishes the registered modules into the Lua table
// __sandbox_modules. It then defines three Lua globals:
//
//   - __init_base_env builds, on first use, a base environment. The base
//     environment holds a restricted subset of the standard library, the
//     registered modules, and the remaining global functions/tables of state.
//     The result is cached, so globals defined after the first sandbox is
//     created are not picked up until Setup runs again.
//   - __create_sandbox_env(ctx) shallow-copies the base environment and adds
//     a private package.loaded table and the optional ctx. It then passes the
//     result through __setup_secure_require, which is not defined here and
//     must exist in state before the first sandbox is created.
//   - __run_sandboxed(fn, ctx) runs fn with a fresh sandbox environment and
//     returns its results.
//
// Functions that hand out or replace the real global environment are never
// copied into the base environment: getfenv, setfenv, load, loadstring,
// loadfile, dofile and module. getfenv(0) returns the global table, and
// chunks compiled by load/loadstring/loadfile/dofile run in it. Either way
// sandboxed code could reach io, debug and the unrestricted os library.
//
// The returned error comes from module registration or from running the
// Lua chunk.
func (s *Sandbox) Setup(state *luajit.State) error {
	// Publish the Go-side modules first so __init_base_env can see them.
	if err := s.registerModules(state); err != nil {
		return err
	}

	return state.DoString(`
-- Base environment shared (via shallow copies) by every sandbox; built lazily
-- by __init_base_env and cached until this chunk runs again.
local __base_env = nil

-- Globals that must never be copied from _G into the base environment.
local __excluded_globals = {
	-- Sandbox machinery
	__sandbox_modules = true,
	__base_env = true,
	__init_base_env = true,
	__create_sandbox_env = true,
	__run_sandboxed = true,
	__setup_secure_require = true,
	__go_load_module = true,
	-- Libraries exposed in restricted form below, or withheld entirely
	string = true, table = true, math = true,
	os = true, io = true, debug = true,
	package = true, bit = true, jit = true,
	coroutine = true, _G = true, _VERSION = true,
	-- Functions that reach or replace the real global environment:
	-- getfenv(0) returns it, and chunks compiled by load/loadstring/
	-- loadfile/dofile use it as their environment.
	getfenv = true,
	setfenv = true,
	load = true,
	loadstring = true,
	loadfile = true,
	dofile = true,
	module = true,
}

function __init_base_env()
	if __base_env then return end

	local env = {}

	-- Standard library modules (restricted)
	env.string = string
	env.table = table
	env.math = math
	env.os = {
		time = os.time,
		date = os.date,
		difftime = os.difftime,
		clock = os.clock
	}

	env.tonumber = tonumber
	env.tostring = tostring
	env.type = type
	env.pairs = pairs
	env.ipairs = ipairs
	env.next = next
	env.select = select
	env.unpack = unpack
	env.pcall = pcall
	env.xpcall = xpcall
	env.error = error
	env.assert = assert

	-- Module loader provided by the host
	env.__go_load_module = __go_load_module

	-- Custom modules from the sandbox registry
	if __sandbox_modules then
		for name, module in pairs(__sandbox_modules) do
			env[name] = module
		end
	end

	-- Remaining custom global functions and tables
	for k, v in pairs(_G) do
		local t = type(v)
		if (t == "function" or t == "table") and not __excluded_globals[k] then
			env[k] = v
		end
	end

	__base_env = env
end

function __create_sandbox_env(ctx)
	__init_base_env()

	-- Shallow copy: globals a script assigns stay private to this sandbox,
	-- but library tables such as string and math are still shared.
	local env = {}
	for k, v in pairs(__base_env) do
		env[k] = v
	end

	-- Each sandbox gets its own package.loaded cache.
	env.package = {
		loaded = {}
	}

	if ctx then
		env.ctx = ctx
	end

	env = __setup_secure_require(env)
	return env
end

function __run_sandboxed(bytecode, ctx)
	-- Fresh environment for every run.
	local env = __create_sandbox_env(ctx)
	setfenv(bytecode, env)
	return bytecode()
end
`)
}
// registerModules stores every module from s.modules in the Lua global table
// __sandbox_modules, creating that table first if the global is nil.
//
// If a module value cannot be pushed, the stack is restored to its state on
// entry and the PushValue error is returned. Modules stored before the failure
// remain in the table.
func (s *Sandbox) registerModules(state *luajit.State) error {
	const registry = "__sandbox_modules"

	state.GetGlobal(registry)
	if state.IsNil(-1) {
		// Swap the nil for a new table, publish it, and read it back so the
		// registry is on top of the stack in both branches.
		state.Pop(1)
		state.NewTable()
		state.SetGlobal(registry)
		state.GetGlobal(registry)
	}
	// The registry table is at index -1 from here on; drop it on every exit.
	defer state.Pop(1)

	for modName, modValue := range s.modules {
		state.PushString(modName)
		if err := state.PushValue(modValue); err != nil {
			state.Pop(1) // discard the key; the deferred Pop removes the registry
			return err
		}
		// Stack: [registry, name, value] -> registry[name] = value
		state.SetTable(-3)
	}
	return nil
}
// Execute runs precompiled Lua bytecode through the __run_sandboxed function
// installed by Setup and returns its first result as converted by ToValue.
//
// The first call also pushes the registered modules into state. A non-empty
// ctx is converted into a Lua table and handed to the sandbox; a nil or empty
// ctx is passed as nil. Errors from module registration, bytecode loading,
// value conversion or the Lua call itself are returned unchanged.
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
	// Modules are registered lazily, only on the first execution.
	if !s.initialized {
		if err := s.registerModules(state); err != nil {
			return nil, err
		}
		s.initialized = true
	}

	// Stack: [chunk]
	if err := state.LoadBytecode(bytecode, "script"); err != nil {
		return nil, err
	}

	// Stack: [chunk, ctx-or-nil]
	if len(ctx) == 0 {
		state.PushNil()
	} else {
		state.NewTable()
		for key, val := range ctx {
			state.PushString(key)
			if err := state.PushValue(val); err != nil {
				state.Pop(3) // key, ctx table, chunk
				return nil, err
			}
			state.SetTable(-3)
		}
	}

	// Stack: [chunk, ctx, runner]
	state.GetGlobal("__run_sandboxed")

	// Rotate the two arguments above the runner, one at a time:
	// [chunk, ctx, runner] -> [ctx, runner, chunk] -> [runner, chunk, ctx]
	state.PushCopy(-3)
	state.Remove(-4)
	state.PushCopy(-3)
	state.Remove(-4)

	if err := state.Call(2, 1); err != nil {
		return nil, err
	}

	// Stack: [result]; convert it, then pop it once the value is read.
	defer state.Pop(1)
	return state.ToValue(-1)
}