Moonshark/core/runner/luarunner.go

445 lines
10 KiB
Go

package runner
import (
"context"
"errors"
"sync"
"sync/atomic"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// ErrRunnerClosed is returned when a job is submitted to, or Close is called
// on, a LuaRunner that has already been shut down.
var ErrRunnerClosed = errors.New("lua runner is closed")

// ErrInitFailed is returned by NewRunner when preparing the Lua state fails.
var ErrInitFailed = errors.New("initialization failed")
// StateInitFunc prepares the runner's Lua state, for example by registering
// globals. NewRunner invokes it once after setting up the state, and the
// runner invokes it again before every job. A non-nil error aborts
// construction or the job, respectively.
type StateInitFunc func(state *luajit.State) error
// LuaRunner executes Lua bytecode on a single Lua state. Jobs are submitted
// through jobQueue and executed one at a time by the event-loop goroutine.
type LuaRunner struct {
	// Lua execution.
	state     *luajit.State // shared Lua state used by every job
	jobQueue  chan job      // pending jobs, consumed by the event loop
	isRunning atomic.Bool   // true until Close is called

	// Synchronization.
	mu sync.RWMutex   // guards shutdown and require-path updates
	wg sync.WaitGroup // tracks the event-loop goroutine for shutdown

	// Configuration.
	initFunc   StateInitFunc // optional hook applied to the Lua state
	bufferSize int           // capacity of jobQueue

	// Module loading.
	requireCache *RequireCache     // compiled-module cache used by the loader
	requireCfg   *RequireConfig    // search paths consulted by require
	scriptDir    string            // base directory for scripts
	libDirs      []string          // additional library directories
	loaderFunc   luajit.GoFunction // held so the registered loader stays referenced (GC)
}
// NewRunner creates a LuaRunner and starts its event-loop goroutine.
//
// Options are applied first, before the Lua state exists. The state is then
// prepared in this order:
//  1. The Go module loader is registered as the Lua global __go_load_module.
//  2. __setup_secure_require is defined.
//  3. The sandbox helpers are defined via setupSandbox.
//  4. The optional init function runs.
//
// If any step fails, the Lua state is closed and ErrInitFailed is returned.
// The underlying cause is not propagated. If the Lua state cannot be
// created at all, a separate error is returned.
func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
	// Default configuration. requireCfg must be non-nil here because
	// WithScriptDir/WithLibDirs write through it while options are applied.
	runner := &LuaRunner{
		bufferSize:   10, // Default buffer size
		requireCache: NewRequireCache(),
		requireCfg: &RequireConfig{
			LibDirs: []string{},
		},
	}

	for _, opt := range options {
		opt(runner)
	}

	state := luajit.New()
	if state == nil {
		return nil, errors.New("failed to create Lua state")
	}
	runner.state = state

	// initFailed releases the Lua state and reports a setup failure.
	initFailed := func() (*LuaRunner, error) {
		state.Close()
		return nil, ErrInitFailed
	}

	runner.jobQueue = make(chan job, runner.bufferSize)
	runner.isRunning.Store(true)

	// Rebuild the require configuration from the option-populated fields.
	// executeJob updates this shared config per request, and the loader below
	// dereferences it on every call.
	runner.requireCfg = &RequireConfig{
		ScriptDir: runner.scriptDir,
		LibDirs:   runner.libDirs,
	}

	// moduleLoader backs __go_load_module. It takes a module name (Lua
	// argument 1), compiles or fetches the module through runner.requireCache,
	// and pushes the loaded chunk. Failures push a message string and return
	// -1, which presumably signals an error to the LuaJIT-to-Go binding
	// (confirm against the library).
	moduleLoader := func(s *luajit.State) int {
		modName := s.ToString(1)
		if modName == "" {
			s.PushString("module name required")
			return -1
		}

		bytecode, err := findAndCompileModule(s, runner.requireCache, *runner.requireCfg, modName)
		if err != nil {
			// errors.Is also matches if the lookup wraps ErrModuleNotFound.
			if errors.Is(err, ErrModuleNotFound) {
				s.PushString("module '" + modName + "' not found")
			} else {
				s.PushString("error loading module: " + err.Error())
			}
			return -1
		}

		if err := s.LoadBytecode(bytecode, modName); err != nil {
			s.PushString("error loading bytecode: " + err.Error())
			return -1
		}
		return 1 // the loaded chunk is the single return value
	}

	// Keep a Go reference to the loader for as long as the runner lives.
	runner.loaderFunc = moduleLoader
	if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil {
		return initFailed()
	}

	// __setup_secure_require(env) installs env.require, which loads modules
	// through __go_load_module, runs them with env as their environment, and
	// caches the result.
	// NOTE(review): __setup_secure_require is defined with the global
	// environment, so `package.loaded` inside it is the global table, not the
	// per-sandbox env.package.loaded. A module is therefore executed once per
	// Lua state and stays bound to the env (including env.ctx) of the first
	// sandbox that required it. Confirm whether per-request isolation is
	// intended.
	setupRequireScript := `
-- Create a secure require function for sandboxed environments
function __setup_secure_require(env)
-- Replace env.require with our secure version
env.require = function(modname)
-- Check if already loaded in package.loaded
if package.loaded[modname] then
return package.loaded[modname]
end
-- Try to load the module using our Go loader
local loader = __go_load_module
-- Load the module
local f, err = loader(modname)
if not f then
error(err or "failed to load module: " .. modname)
end
-- Set the environment for the module
setfenv(f, env)
-- Execute the module
local result = f()
-- If module didn't return a value, use true
if result == nil then
result = true
end
-- Cache the result
package.loaded[modname] = result
return result
end
return env
end
`
	if err := state.DoString(setupRequireScript); err != nil {
		return initFailed()
	}

	if err := runner.setupSandbox(); err != nil {
		return initFailed()
	}

	if runner.initFunc != nil {
		if err := runner.initFunc(state); err != nil {
			return initFailed()
		}
	}

	runner.wg.Add(1)
	go runner.eventLoop()

	return runner, nil
}
// RunnerOption customizes a LuaRunner inside NewRunner. Options run before
// the Lua state is created.
type RunnerOption func(r *LuaRunner)
// WithBufferSize sets the capacity of the runner's job queue. A size of zero
// or less is ignored, and the default is kept.
func WithBufferSize(size int) RunnerOption {
	return func(r *LuaRunner) {
		if size <= 0 {
			return
		}
		r.bufferSize = size
	}
}
// WithInitFunc registers a hook that is applied to the runner's Lua state
// during construction and again before each job.
func WithInitFunc(initFunc StateInitFunc) RunnerOption {
	setInit := func(r *LuaRunner) {
		r.initFunc = initFunc
	}
	return setInit
}
// WithScriptDir sets the base directory used to resolve scripts. The value
// is stored both on the runner and in its require configuration.
func WithScriptDir(dir string) RunnerOption {
	apply := func(r *LuaRunner) {
		r.scriptDir = dir
		cfg := r.requireCfg
		cfg.ScriptDir = dir
	}
	return apply
}
// WithLibDirs sets additional directories searched by require. The value is
// stored both on the runner and in its require configuration.
//
// The directories are copied, so callers that pass a slice
// (WithLibDirs(dirs...)) and later modify it do not change the runner's
// configuration.
func WithLibDirs(dirs ...string) RunnerOption {
	owned := append([]string(nil), dirs...)
	return func(r *LuaRunner) {
		r.libDirs = owned
		r.requireCfg.LibDirs = owned
	}
}
// setupSandbox initializes the sandbox environment
//
// It runs a Lua script on r.state that defines two globals:
//   - __create_sandbox() returns a fresh environment table. The table
//     exposes string, table, math, a reduced os (time/date/difftime/clock),
//     basic functions (tonumber, pairs, pcall, error, ...), and a private
//     package.loaded table. It also includes __go_load_module and the
//     require installed by __setup_secure_require. A metatable sends
//     unknown reads to _G and keeps writes inside the environment.
//   - __run_sandboxed(f, ctx) builds such an environment, stores ctx as
//     env.ctx when non-nil, sets it as f's environment, and returns f().
//
// The script references __go_load_module and __setup_secure_require, which
// NewRunner defines before calling this. Returns the error from DoString,
// if any.
//
// NOTE(review): the __index fallback returns any global table or function
// from _G. That includes _G itself, plus io, debug, load/loadstring,
// dofile, getfenv/setfenv, etc. (if those libraries are loaded in the
// state). If scripts are untrusted, this is not an isolation boundary;
// consider an allow-list or deny-list for the fallback.
func (r *LuaRunner) setupSandbox() error {
// This is the Lua script that creates our sandbox function
setupScript := `
-- Create a function to run code in a sandbox environment
function __create_sandbox()
-- Create new environment table
local env = {}
-- Add standard library modules (can be restricted as needed)
env.string = string
env.table = table
env.math = math
env.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
env.tonumber = tonumber
env.tostring = tostring
env.type = type
env.pairs = pairs
env.ipairs = ipairs
env.next = next
env.select = select
env.unpack = unpack
env.pcall = pcall
env.xpcall = xpcall
env.error = error
env.assert = assert
-- Set up the standard library package table
env.package = {
loaded = {} -- Table to store loaded modules
}
-- Explicitly expose the module loader function
env.__go_load_module = __go_load_module
-- Set up secure require function
env = __setup_secure_require(env)
-- Create metatable to restrict access to _G
local mt = {
__index = function(t, k)
-- First check in env table
local v = rawget(env, k)
if v ~= nil then return v end
-- If not found, check for registered modules/functions
local moduleValue = _G[k]
if type(moduleValue) == "table" or
type(moduleValue) == "function" then
return moduleValue
end
return nil
end,
__newindex = function(t, k, v)
rawset(env, k, v)
end
}
setmetatable(env, mt)
return env
end
-- Create function to execute code with a sandbox
function __run_sandboxed(f, ctx)
local env = __create_sandbox()
-- Add context to the environment if provided
if ctx then
env.ctx = ctx
end
-- Set the environment and run the function
setfenv(f, env)
return f()
end
`
// Executes the definitions above once; the helpers then live as globals.
return r.state.DoString(setupScript)
}
// eventLoop is the goroutine that executes queued jobs against r.state.
// It runs until jobQueue is closed and fully drained. It then closes the
// Lua state and marks r.wg done.
func (r *LuaRunner) eventLoop() {
	defer r.wg.Done()
	defer r.state.Close()

	for {
		j, ok := <-r.jobQueue
		if !ok {
			return
		}

		res := r.executeJob(j)

		// Non-blocking hand-off. RunWithContext allocates Result with a buffer
		// of one, so this send normally succeeds even after the caller stops
		// waiting. The default case drops the result only when the channel
		// has no free slot. A send on a closed channel would still panic,
		// so callers must never close Result.
		select {
		case j.Result <- res:
		default:
		}
	}
}
// executeJob runs one job's bytecode through __run_sandboxed on r.state and
// returns the first result converted to a Go value, or the first error hit.
// It is called from eventLoop, one job at a time.
//
// Every error path leaves the Lua stack at the depth it had on entry, so a
// failing job cannot grow the stack of the shared state. (The Call error
// path is an exception if the binding leaves an error object behind; see
// the note below.)
func (r *LuaRunner) executeJob(j job) JobResult {
	// Point require at the directories relevant to this script. The module
	// loader dereferences r.requireCfg on each call.
	if j.ScriptPath != "" {
		r.mu.Lock()
		UpdateRequirePaths(r.requireCfg, j.ScriptPath)
		r.mu.Unlock()
	}

	// The init function is re-applied before every job, not only at startup.
	if r.initFunc != nil {
		if err := r.initFunc(r.state); err != nil {
			return JobResult{nil, err}
		}
	}

	// Push the context argument: a table of j.Context.Values, or nil.
	if j.Context != nil {
		r.state.NewTable()
		for key, value := range j.Context.Values {
			r.state.PushString(key)
			if err := r.state.PushValue(value); err != nil {
				// Drop the pending key and the context table so the failed
				// conversion does not leak slots into later jobs. This
				// assumes PushValue pushes nothing when it fails (confirm
				// against LuaJIT-to-Go).
				r.state.Pop(2)
				return JobResult{nil, err}
			}
			r.state.SetTable(-3) // ctxTable[key] = value
		}
	} else {
		r.state.PushNil()
	}

	// stack: ctx, f
	if err := r.state.LoadBytecode(j.Bytecode, j.ScriptPath); err != nil {
		r.state.Pop(1) // pop ctx
		return JobResult{nil, err}
	}

	// Rearrange into: __run_sandboxed, f, ctx
	r.state.GetGlobal("__run_sandboxed") // ctx, f, run
	r.state.PushCopy(-2)                 // ctx, f, run, f
	r.state.PushCopy(-4)                 // ctx, f, run, f, ctx
	r.state.Remove(-5)                   // f, run, f, ctx
	r.state.Remove(-4)                   // run, f, ctx

	// __run_sandboxed(f, ctx) with one result.
	// NOTE(review): whether Call leaves an error object on the stack when it
	// fails depends on the binding; confirm so this path cannot leak slots.
	if err := r.state.Call(2, 1); err != nil {
		return JobResult{nil, err}
	}

	value, err := r.state.ToValue(-1)
	r.state.Pop(1) // pop the result regardless of conversion outcome
	return JobResult{value, err}
}
// RunWithContext queues bytecode for execution on the runner and waits for
// its result. execCtx, when non-nil, is exposed to the script as `ctx`.
// scriptPath is used for require path resolution and as the chunk name.
//
// The call returns:
//   - ErrRunnerClosed if the runner is closed, including when Close races
//     with submission.
//   - ctx.Err() if ctx is done before the job is queued or before its
//     result arrives. In that case a queued job may still execute, and its
//     result is discarded.
func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
	r.mu.RLock()
	running := r.isRunning.Load()
	r.mu.RUnlock()
	if !running {
		return nil, ErrRunnerClosed
	}

	// Buffered so the event loop can always deliver the result, even if we
	// stop waiting.
	resultChan := make(chan JobResult, 1)
	j := job{
		Bytecode:   bytecode,
		Context:    execCtx,
		ScriptPath: scriptPath,
		Result:     resultChan,
	}

	// Close may close jobQueue between the isRunning check above and this
	// send. A send on a closed channel panics, so that window is converted
	// into ErrRunnerClosed.
	submit := func() (err error) {
		defer func() {
			if recover() != nil {
				err = ErrRunnerClosed
			}
		}()
		select {
		case r.jobQueue <- j:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	if err := submit(); err != nil {
		return nil, err
	}

	select {
	case result := <-resultChan:
		return result.Value, result.Error
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Run executes bytecode like RunWithContext, but with a background context,
// so the call has no deadline and cannot be cancelled.
func (r *LuaRunner) Run(bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
	background := context.Background()
	return r.RunWithContext(background, bytecode, execCtx, scriptPath)
}
// Close stops accepting jobs, lets the event loop finish every job already
// queued, and waits for it to close the Lua state. It returns
// ErrRunnerClosed if the runner was already closed. A second, concurrent
// call may return ErrRunnerClosed before the first call's wait completes.
func (r *LuaRunner) Close() error {
	r.mu.Lock()
	if !r.isRunning.Load() {
		r.mu.Unlock()
		return ErrRunnerClosed
	}
	r.isRunning.Store(false)
	close(r.jobQueue)
	r.mu.Unlock()

	// Wait without holding r.mu. The event loop drains the queued jobs, and
	// executeJob locks r.mu for any job that carries a script path, so
	// waiting under the lock would deadlock.
	r.wg.Wait()
	return nil
}
// ClearRequireCache discards cached module data by replacing the runner's
// RequireCache with a fresh, empty one. The module loader reads
// r.requireCache through the runner on each call, so it sees the new
// cache.
//
// NOTE(review): this write is not synchronized with the event loop, which
// reads r.requireCache inside the module loader. Calling it while jobs run
// is a data race. Modules already stored in Lua's package.loaded by the
// secure require are also not affected by this call. Confirm the intended
// semantics.
func (r *LuaRunner) ClearRequireCache() {
	fresh := NewRequireCache()
	r.requireCache = fresh
}