diff --git a/core/runner/luarunner.go b/core/runner/luarunner.go index 2abdf96..e7305cc 100644 --- a/core/runner/luarunner.go +++ b/core/runner/luarunner.go @@ -29,9 +29,8 @@ type LuaRunner struct { bufferSize int // Size of the job queue buffer requireCache *RequireCache // Cache for required modules requireCfg *RequireConfig // Configuration for require paths - scriptDir string // Base directory for scripts - libDirs []string // Additional library directories - loaderFunc luajit.GoFunction // Keep reference to prevent GC + moduleLoader luajit.GoFunction // Keep reference to prevent GC + sandbox *Sandbox // The sandbox environment } // NewRunner creates a new LuaRunner @@ -43,6 +42,7 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) { requireCfg: &RequireConfig{ LibDirs: []string{}, }, + sandbox: NewSandbox(), } // Apply options @@ -63,12 +63,11 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) { // Create a shared config pointer that will be updated per request runner.requireCfg = &RequireConfig{ - ScriptDir: runner.scriptDir, - LibDirs: runner.libDirs, + ScriptDir: runner.scriptDir(), + LibDirs: runner.libDirs(), } - // Set up require functionality ONCE - // Create and register the module loader function + // Set up require functionality moduleLoader := func(s *luajit.State) int { // Get module name modName := s.ToString(1) @@ -99,7 +98,7 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) { } // Store reference to prevent garbage collection - runner.loaderFunc = moduleLoader + runner.moduleLoader = moduleLoader // Register with Lua state if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil { @@ -108,53 +107,13 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) { } // Set up the require mechanism - setupRequireScript := ` - -- Create a secure require function for sandboxed environments - function __setup_secure_require(env) - -- Replace env.require with our secure version - env.require = function(modname) - -- Check if already loaded in package.loaded - if package.loaded[modname] then - return package.loaded[modname] - end - - -- Try to load the module using our Go loader - local loader = __go_load_module - - -- Load the module - local f, err = loader(modname) - if not f then - error(err or "failed to load module: " .. modname) - end - - -- Set the environment for the module - setfenv(f, env) - - -- Execute the module - local result = f() - - -- If module didn't return a value, use true - if result == nil then - result = true - end - - -- Cache the result - package.loaded[modname] = result - - return result - end - - return env - end - ` - - if err := state.DoString(setupRequireScript); err != nil { + if err := setupRequireFunction(state); err != nil { state.Close() return nil, ErrInitFailed } // Set up sandbox - if err := runner.setupSandbox(); err != nil { + if err := runner.sandbox.Setup(state); err != nil { state.Close() return nil, ErrInitFailed } @@ -169,11 +128,53 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) { // Start the event loop runner.wg.Add(1) - go runner.eventLoop() + go runner.processJobs() return runner, nil } +// setupRequireFunction adds the secure require implementation +func setupRequireFunction(state *luajit.State) error { + return state.DoString(` + function __setup_secure_require(env) + -- Replace env.require with our secure version + env.require = function(modname) + -- Check if already loaded in package.loaded + if package.loaded[modname] then + return package.loaded[modname] + end + + -- Try to load the module using our Go loader + local loader = __go_load_module + + -- Load the module + local f, err = loader(modname) + if not f then + error(err or "failed to load module: " .. modname) + end + + -- Set the environment for the module + setfenv(f, env) + + -- Execute the module + local result = f() + + -- If module didn't return a value, use true + if result == nil then + result = true + end + + -- Cache the result + package.loaded[modname] = result + + return result + end + + return env + end + `) +} + // RunnerOption defines a functional option for configuring the LuaRunner type RunnerOption func(*LuaRunner) @@ -196,7 +197,6 @@ func WithInitFunc(initFunc StateInitFunc) RunnerOption { // WithScriptDir sets the base directory for scripts func WithScriptDir(dir string) RunnerOption { return func(r *LuaRunner) { - r.scriptDir = dir r.requireCfg.ScriptDir = dir } } @@ -204,103 +204,31 @@ func WithScriptDir(dir string) RunnerOption { // WithLibDirs sets additional library directories func WithLibDirs(dirs ...string) RunnerOption { return func(r *LuaRunner) { - r.libDirs = dirs r.requireCfg.LibDirs = dirs } } -// setupSandbox initializes the sandbox environment -func (r *LuaRunner) setupSandbox() error { - // This is the Lua script that creates our sandbox function - setupScript := ` - -- Create a function to run code in a sandbox environment - function __create_sandbox() - -- Create new environment table - local env = {} - - -- Add standard library modules (can be restricted as needed) - env.string = string - env.table = table - env.math = math - env.os = { - time = os.time, - date = os.date, - difftime = os.difftime, - clock = os.clock - } - env.tonumber = tonumber - env.tostring = tostring - env.type = type - env.pairs = pairs - env.ipairs = ipairs - env.next = next - env.select = select - env.unpack = unpack - env.pcall = pcall - env.xpcall = xpcall - env.error = error - env.assert = assert - - -- Set up the standard library package table - env.package = { - loaded = {} -- Table to store loaded modules - } - - -- Explicitly expose the module loader function - env.__go_load_module = __go_load_module - - -- Set up secure require function - env = __setup_secure_require(env) - - -- Create metatable to restrict access to _G - local mt = { - __index = function(t, k) - -- First check in env table - local v = rawget(env, k) - if v ~= nil then return v end - - -- If not found, check for registered modules/functions - local moduleValue = _G[k] - if type(moduleValue) == "table" or - type(moduleValue) == "function" then - return moduleValue - end - - return nil - end, - __newindex = function(t, k, v) - rawset(env, k, v) - end - } - - setmetatable(env, mt) - return env - end - - -- Create function to execute code with a sandbox - function __run_sandboxed(f, ctx) - local env = __create_sandbox() - - -- Add context to the environment if provided - if ctx then - env.ctx = ctx - end - - -- Set the environment and run the function - setfenv(f, env) - return f() - end - ` - - return r.state.DoString(setupScript) +// scriptDir returns the current script directory +func (r *LuaRunner) scriptDir() string { + if r.requireCfg != nil { + return r.requireCfg.ScriptDir + } + return "" } -// eventLoop processes jobs from the queue -func (r *LuaRunner) eventLoop() { +// libDirs returns the current library directories +func (r *LuaRunner) libDirs() []string { + if r.requireCfg != nil { + return r.requireCfg.LibDirs + } + return nil +} + +// processJobs handles the job queue +func (r *LuaRunner) processJobs() { defer r.wg.Done() defer r.state.Close() - // Process jobs until closure for job := range r.jobQueue { // Execute the job and send result result := r.executeJob(job) @@ -329,55 +257,14 @@ func (r *LuaRunner) executeJob(j job) JobResult { } } - // Set up context if provided + // Convert context for sandbox + var ctx map[string]any if j.Context != nil { - // Push context table - r.state.NewTable() - - // Add values to context table - for key, value := range j.Context.Values { - // Push key - r.state.PushString(key) - - // Push value - if err := r.state.PushValue(value); err != nil { - return JobResult{nil, err} - } - - // Set table[key] = value - r.state.SetTable(-3) - } - } else { - // Push nil if no context - r.state.PushNil() + ctx = j.Context.Values } - // Load bytecode - if err := r.state.LoadBytecode(j.Bytecode, j.ScriptPath); err != nil { - r.state.Pop(1) // Pop context - return JobResult{nil, err} - } - - // Get the sandbox runner function - r.state.GetGlobal("__run_sandboxed") - - // Push loaded function and context as arguments - r.state.PushCopy(-2) // Copy the loaded function - r.state.PushCopy(-4) // Copy the context table or nil - - // Remove the original function and context - r.state.Remove(-5) // Remove original context - r.state.Remove(-4) // Remove original function - - // Call the sandbox runner with 2 args (function and context), expecting 1 result - if err := r.state.Call(2, 1); err != nil { - return JobResult{nil, err} - } - - // Get result - value, err := r.state.ToValue(-1) - r.state.Pop(1) // Pop result - + // Execute in sandbox + value, err := r.sandbox.Execute(r.state, j.Bytecode, ctx) return JobResult{value, err} } @@ -442,3 +329,8 @@ func (r *LuaRunner) Close() error { func (r *LuaRunner) ClearRequireCache() { r.requireCache = NewRequireCache() } + +// AddModule adds a module to the sandbox environment +func (r *LuaRunner) AddModule(name string, module any) { + r.sandbox.AddModule(name, module) +} diff --git a/core/runner/sandbox.go b/core/runner/sandbox.go new file mode 100644 index 0000000..931776e --- /dev/null +++ b/core/runner/sandbox.go @@ -0,0 +1,228 @@ +package runner + +import ( + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// Sandbox manages a sandboxed Lua environment +type Sandbox struct { + modules map[string]any // Custom modules for environment + initialized bool // Whether base environment is initialized +} + +// NewSandbox creates a new sandbox +func NewSandbox() *Sandbox { + return &Sandbox{ + modules: make(map[string]any), + initialized: false, + } +} + +// AddModule adds a module to the sandbox environment +func (s *Sandbox) AddModule(name string, module any) { + s.modules[name] = module +} + +// Setup initializes the sandbox in a Lua state +func (s *Sandbox) Setup(state *luajit.State) error { + // Register modules + if err := s.registerModules(state); err != nil { + return err + } + + // Setup the sandbox creation logic with base environment reuse + return state.DoString(` + -- Create the base environment once (static parts) + local __base_env = nil + + -- Create function to initialize base environment + function __init_base_env() + if __base_env then return end + + local env = {} + + -- Add standard library modules (restricted) + env.string = string + env.table = table + env.math = math + env.os = { + time = os.time, + date = os.date, + difftime = os.difftime, + clock = os.clock + } + env.tonumber = tonumber + env.tostring = tostring + env.type = type + env.pairs = pairs + env.ipairs = ipairs + env.next = next + env.select = select + env.unpack = unpack + env.pcall = pcall + env.xpcall = xpcall + env.error = error + env.assert = assert + + -- Add module loader + env.__go_load_module = __go_load_module + + -- Add custom modules from sandbox registry + if __sandbox_modules then + for name, module in pairs(__sandbox_modules) do + env[name] = module + end + end + + -- Copy custom global functions + for k, v in pairs(_G) do + if (type(v) == "function" or type(v) == "table") and + k ~= "__sandbox_modules" and + k ~= "__base_env" and + k ~= "__init_base_env" and + k ~= "__create_sandbox_env" and + k ~= "__run_sandboxed" and + k ~= "__setup_secure_require" and + k ~= "__go_load_module" and + k ~= "string" and k ~= "table" and k ~= "math" and + k ~= "os" and k ~= "io" and k ~= "debug" and + k ~= "package" and k ~= "bit" and k ~= "jit" and + k ~= "coroutine" and k ~= "_G" and k ~= "_VERSION" then + env[k] = v + end + end + + __base_env = env + end + + -- Create function that builds sandbox from base env + function __create_sandbox_env(ctx) + -- Initialize base env if needed + __init_base_env() + + -- Create new environment using base as prototype + local env = {} + + -- Copy from base environment + for k, v in pairs(__base_env) do + env[k] = v + end + + -- Add isolated package.loaded table + env.package = { + loaded = {} + } + + -- Add context if provided + if ctx then + env.ctx = ctx + end + + -- Setup require function + env = __setup_secure_require(env) + + -- Create metatable for isolation + local mt = { + __index = function(t, k) + return rawget(env, k) + end, + __newindex = function(t, k, v) + rawset(env, k, v) + end + } + + setmetatable(env, mt) + return env + end + + -- Function to run code in sandbox + function __run_sandboxed(bytecode, ctx) + -- Create environment for this request + local env = __create_sandbox_env(ctx) + + -- Set environment and execute + setfenv(bytecode, env) + return bytecode() + end + `) +} + +// registerModules registers custom modules in the Lua state +func (s *Sandbox) registerModules(state *luajit.State) error { + // Create or get module registry table + state.GetGlobal("__sandbox_modules") + if state.IsNil(-1) { + // Table doesn't exist, create it + state.Pop(1) + state.NewTable() + state.SetGlobal("__sandbox_modules") + state.GetGlobal("__sandbox_modules") + } + + // Add modules to registry + for name, module := range s.modules { + state.PushString(name) + if err := state.PushValue(module); err != nil { + state.Pop(2) + return err + } + state.SetTable(-3) + } + + // Pop module table + state.Pop(1) + return nil +} + +// Execute runs bytecode in the sandbox +func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) { + // Update modules if needed + if !s.initialized { + if err := s.registerModules(state); err != nil { + return nil, err + } + s.initialized = true + } + + // Load bytecode + if err := state.LoadBytecode(bytecode, "script"); err != nil { + return nil, err + } + + // Create context table if provided + if len(ctx) > 0 { + state.NewTable() + for k, v := range ctx { + state.PushString(k) + if err := state.PushValue(v); err != nil { + state.Pop(3) + return nil, err + } + state.SetTable(-3) + } + } else { + state.PushNil() // No context + } + + // Get sandbox function + state.GetGlobal("__run_sandboxed") + + // Setup call with correct argument order + state.PushCopy(-3) // Copy bytecode function + state.PushCopy(-3) // Copy context + + // Clean up stack + state.Remove(-5) // Remove original bytecode + state.Remove(-4) // Remove original context + + // Call sandbox function + if err := state.Call(2, 1); err != nil { + return nil, err + } + + // Get result + result, err := state.ToValue(-1) + state.Pop(1) + + return result, err +}