Moonshark/core/runner/luarunner.go
2025-03-21 22:25:05 -05:00

415 lines
10 KiB
Go

package runner
import (
"context"
"errors"
"fmt"
"path/filepath"
"strings"
"sync"
"sync/atomic"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Sentinel errors reported by LuaRunner.
var (
	// ErrRunnerClosed is returned when a job is submitted to, or Close is
	// called on, a runner that has already been shut down.
	ErrRunnerClosed = errors.New("lua runner is closed")

	// ErrInitFailed is returned by NewRunner when preparing the Lua state
	// (module loader, require setup, sandbox or init function) fails.
	ErrInitFailed = errors.New("initialization failed")
)
// StateInitFunc prepares a Lua state for use. NewRunner calls it once while
// setting the state up, and executeJob calls it again before every job; a
// non-nil error fails runner creation or the job, respectively.
type StateInitFunc func(*luajit.State) error
// LuaRunner executes Lua bytecode on a single Lua state. Jobs are consumed
// from jobQueue by one worker goroutine, so they run one at a time.
type LuaRunner struct {
	// Execution.
	state    *luajit.State // the single Lua state shared by all jobs
	jobQueue chan job      // pending jobs; closed by Close

	// Lifecycle.
	isRunning atomic.Bool    // true from NewRunner until Close
	mu        sync.RWMutex   // serializes Close with submissions and require-config updates
	wg        sync.WaitGroup // tracks the worker goroutine

	// Configuration.
	initFunc   StateInitFunc // optional; run at startup and before each job
	bufferSize int           // capacity of jobQueue

	// Module loading and sandboxing.
	requireCache *RequireCache     // cache for required modules
	requireCfg   *RequireConfig    // require search paths, updated per job
	moduleLoader luajit.GoFunction // held so the registered loader is not garbage collected
	sandbox      *Sandbox          // environment scripts execute in
}
// NewRunner creates a LuaRunner configured by options and starts its worker
// goroutine.
//
// It creates a fresh Lua state and prepares it in order:
//   - registers the Go module loader as the global __go_load_module;
//   - defines the secure require helper;
//   - sets up the sandbox;
//   - finally runs the optional init function.
//
// If any step fails, the state is closed and ErrInitFailed is returned. If
// the state cannot be created at all, a distinct error is returned.
func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
	runner := &LuaRunner{
		bufferSize:   10, // default job queue capacity
		requireCache: NewRequireCache(),
		requireCfg: &RequireConfig{
			LibDirs: []string{},
		},
		sandbox: NewSandbox(),
	}
	for _, opt := range options {
		opt(runner)
	}

	state := luajit.New()
	if state == nil {
		return nil, errors.New("failed to create Lua state")
	}
	runner.state = state

	runner.jobQueue = make(chan job, runner.bufferSize)
	runner.isRunning.Store(true)

	// Rebuild the require config from what the options set. executeJob
	// updates this config in place per request, and the module loader below
	// dereferences the pointer on every call, so it always sees current paths.
	runner.requireCfg = &RequireConfig{
		ScriptDir: runner.scriptDir(),
		LibDirs:   runner.libDirs(),
	}

	// Go side of require(): on success returns 1 with the loaded chunk pushed;
	// on failure pushes an error message and returns -1.
	moduleLoader := func(s *luajit.State) int {
		modName := s.ToString(1)
		if modName == "" {
			s.PushString("module name required")
			return -1
		}

		bytecode, err := findAndCompileModule(s, runner.requireCache, *runner.requireCfg, modName)
		if err != nil {
			// errors.Is also matches a not-found error that was wrapped.
			if errors.Is(err, ErrModuleNotFound) {
				s.PushString("module '" + modName + "' not found")
			} else {
				s.PushString("error loading module: " + err.Error())
			}
			return -1
		}

		if err := s.LoadBytecode(bytecode, modName); err != nil {
			s.PushString("error loading bytecode: " + err.Error())
			return -1
		}
		return 1 // the loaded function
	}
	// Keep a Go reference so the loader is not garbage collected while Lua
	// can still call it.
	runner.moduleLoader = moduleLoader

	// Initialization steps, run in order; the first failure aborts creation.
	initSteps := []func(*luajit.State) error{
		func(s *luajit.State) error { return s.RegisterGoFunction("__go_load_module", moduleLoader) },
		setupRequireFunction,
		func(s *luajit.State) error { return runner.sandbox.Setup(s) },
	}
	if runner.initFunc != nil {
		initSteps = append(initSteps, runner.initFunc)
	}
	for _, step := range initSteps {
		if err := step(state); err != nil {
			state.Close()
			return nil, ErrInitFailed
		}
	}

	runner.wg.Add(1)
	go runner.processJobs()
	return runner, nil
}
// setupRequireFunction defines the global Lua function
// __setup_secure_require(env) in state.
//
// That function replaces env.require with a version that:
//   - serves modules already in env.package.loaded;
//   - otherwise loads them through the Go-registered __go_load_module;
//   - runs each loaded module with env as its environment;
//   - caches the result (true if the module returned nil) in
//     env.package.loaded.
//
// Any error from compiling or running the definition is returned unchanged.
func setupRequireFunction(state *luajit.State) error {
	const secureRequireLua = `
function __setup_secure_require(env)
-- Replace env.require with our secure version
env.require = function(modname)
-- Check if already loaded in this environment's package.loaded
if env.package.loaded[modname] then
return env.package.loaded[modname]
end
-- Try to load the module using our Go loader
local loader = __go_load_module
-- Load the module
local f, err = loader(modname)
if not f then
error(err or "failed to load module: " .. modname)
end
-- Set the environment for the module
setfenv(f, env)
-- Execute the module
local result = f()
-- If module didn't return a value, use true
if result == nil then
result = true
end
-- Cache the result in this environment only
env.package.loaded[modname] = result
return result
end
return env
end
`
	return state.DoString(secureRequireLua)
}
// RunnerOption customizes a LuaRunner. NewRunner applies the options in the
// order given, before the Lua state is created.
type RunnerOption func(*LuaRunner)
// WithBufferSize sets the capacity of the job queue. Non-positive sizes are
// ignored, leaving the previously configured size (10 by default).
func WithBufferSize(size int) RunnerOption {
	return func(r *LuaRunner) {
		if size <= 0 {
			return
		}
		r.bufferSize = size
	}
}
// WithInitFunc installs initFunc as the state initializer. NewRunner calls it
// once during setup, and it is called again before each job executes.
func WithInitFunc(initFunc StateInitFunc) RunnerOption {
	setInit := func(r *LuaRunner) {
		r.initFunc = initFunc
	}
	return setInit
}
// WithScriptDir sets RequireConfig.ScriptDir, the base directory for scripts.
func WithScriptDir(dir string) RunnerOption {
	return func(r *LuaRunner) {
		cfg := r.requireCfg
		cfg.ScriptDir = dir
	}
}
// WithLibDirs sets the additional library directories (RequireConfig.LibDirs).
//
// The directories are copied when the option is applied. The runner
// therefore never shares a backing array with the caller: when called as
// WithLibDirs(s...), dirs aliases s, so later edits to s would otherwise
// leak into the runner's require configuration.
func WithLibDirs(dirs ...string) RunnerOption {
	return func(r *LuaRunner) {
		r.requireCfg.LibDirs = append([]string(nil), dirs...)
	}
}
// scriptDir reports RequireConfig.ScriptDir, or "" when no require config
// has been set.
func (r *LuaRunner) scriptDir() string {
	cfg := r.requireCfg
	if cfg == nil {
		return ""
	}
	return cfg.ScriptDir
}
// libDirs reports RequireConfig.LibDirs, or nil when no require config has
// been set.
func (r *LuaRunner) libDirs() []string {
	cfg := r.requireCfg
	if cfg == nil {
		return nil
	}
	return cfg.LibDirs
}
// processJobs is the worker loop started by NewRunner. It executes queued
// jobs one at a time on this goroutine until the queue is closed and
// drained. It then closes the Lua state and marks the WaitGroup done.
func (r *LuaRunner) processJobs() {
	defer r.wg.Done()
	defer r.state.Close()

	for {
		j, ok := <-r.jobQueue
		if !ok {
			return
		}
		res := r.executeJob(j)

		// Non-blocking delivery: if the result channel has no free buffer
		// space (and no waiting receiver), the result is dropped rather than
		// stalling the worker.
		select {
		case j.Result <- res:
		default:
		}
	}
}
// executeJob runs one job on the worker goroutine and packages the outcome.
//
// Before running the bytecode, it does two things:
//   - if the job carries a script path, it updates the shared require
//     configuration via UpdateRequirePaths under r.mu;
//   - it re-applies the optional init function to the Lua state.
//
// Per-request values from j.Context, if any, are handed to the sandbox.
func (r *LuaRunner) executeJob(j job) JobResult {
	if path := j.ScriptPath; path != "" {
		r.mu.Lock()
		UpdateRequirePaths(r.requireCfg, path)
		r.mu.Unlock()
	}

	if initFn := r.initFunc; initFn != nil {
		if err := initFn(r.state); err != nil {
			return JobResult{Value: nil, Error: err}
		}
	}

	var values map[string]any
	if j.Context != nil {
		values = j.Context.Values
	}

	value, err := r.sandbox.Execute(r.state, j.Bytecode, values)
	return JobResult{Value: value, Error: err}
}
// RunWithContext submits bytecode to the runner's worker goroutine and waits
// for its result.
//
// execCtx supplies per-request values to the sandbox. scriptPath, when
// non-empty, is used by the worker to update the require configuration
// before the script runs.
//
// It returns:
//   - ErrRunnerClosed if the runner is (or becomes) closed before the job
//     is queued;
//   - ctx.Err() if ctx is done while queuing or waiting.
//
// A job queued before ctx was done still runs; its result is discarded.
func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
	// Fast path only: Close may still win the race after this check, which
	// enqueue handles.
	if !r.isRunning.Load() {
		return nil, ErrRunnerClosed
	}

	// Capacity 1 lets the worker's non-blocking send succeed even if this
	// goroutine is not yet receiving.
	resultChan := make(chan JobResult, 1)
	j := job{
		Bytecode:   bytecode,
		Context:    execCtx,
		ScriptPath: scriptPath,
		Result:     resultChan,
	}

	if err := r.enqueue(ctx, j); err != nil {
		return nil, err
	}

	select {
	case result := <-resultChan:
		return result.Value, result.Error
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// enqueue sends j to the job queue, honoring ctx. Close can close the queue
// between RunWithContext's isRunning check and this send. The resulting
// "send on closed channel" panic is recovered here and reported as
// ErrRunnerClosed instead of crashing the caller.
func (r *LuaRunner) enqueue(ctx context.Context, j job) (err error) {
	defer func() {
		if recover() != nil {
			err = ErrRunnerClosed
		}
	}()
	select {
	case r.jobQueue <- j:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// Run is RunWithContext without cancellation or deadline. It blocks until the
// job's result is available or the runner is found to be closed.
func (r *LuaRunner) Run(bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
	ctx := context.Background()
	return r.RunWithContext(ctx, bytecode, execCtx, scriptPath)
}
// Close shuts the runner down. After it starts, new submissions are rejected
// with ErrRunnerClosed. Jobs already in the queue are still executed. Close
// blocks until the worker goroutine has drained the queue and closed the
// Lua state.
//
// It returns ErrRunnerClosed if the runner was already closed.
func (r *LuaRunner) Close() error {
	r.mu.Lock()
	if !r.isRunning.Load() {
		r.mu.Unlock()
		return ErrRunnerClosed
	}
	r.isRunning.Store(false)
	close(r.jobQueue)
	// Release the lock before waiting. executeJob takes r.mu for jobs that
	// carry a script path, so waiting while holding it would deadlock if such
	// a job is still queued or running.
	r.mu.Unlock()

	r.wg.Wait()
	return nil
}
// RequireCache exposes the runner's module cache so other components (for
// example a file watcher) can inspect or refresh it.
func (r *LuaRunner) RequireCache() *RequireCache {
	cache := r.requireCache
	return cache
}
// ClearRequireCache empties the runner's module cache by delegating to
// RequireCache.Clear.
func (r *LuaRunner) ClearRequireCache() {
	cache := r.requireCache
	cache.Clear()
}
// AddModule registers module under name with the runner's sandbox by
// delegating to Sandbox.AddModule.
func (r *LuaRunner) AddModule(name string, module any) {
	sb := r.sandbox
	sb.AddModule(name, module)
}
// RefreshRequireCache asks the module cache to refresh its entries and
// returns the count reported by RequireCache.RefreshAll.
func (r *LuaRunner) RefreshRequireCache() int {
	return r.requireCache.RefreshAll()
}
// ResetPackageLoaded removes non-core entries from the global Lua
// package.loaded table, so a later require() in the global environment
// reloads those modules.
//
// It keeps the standard libraries (_G, string, table, math, os, package, io,
// coroutine, debug) and LuaJIT's built-ins (bit, jit, jit.*, ffi).
//
// It returns:
//   - nil without doing anything if the runner has no Lua state;
//   - ErrRunnerClosed after Close;
//   - otherwise, the error from executing the Lua snippet.
//
// NOTE(review): this runs on the caller's goroutine while the worker may be
// executing a job on the same Lua state. A Lua state is not safe for
// concurrent use, so callers must ensure no job is running. Consider routing
// this through the job queue.
func (r *LuaRunner) ResetPackageLoaded() error {
	if r.state == nil {
		return nil
	}
	if !r.isRunning.Load() {
		return ErrRunnerClosed
	}
	return r.state.DoString(`
-- Modules that must survive a reset. LuaJIT registers bit, jit and jit.opt
-- directly in package.loaded with no package.preload entry, so clearing them
-- makes a later require() fail; ffi must not be re-opened.
local keep = {
	_G = true, string = true, table = true, math = true, os = true,
	package = true, io = true, coroutine = true, debug = true,
	bit = true, jit = true, ffi = true,
}
local to_unload = {}
for name in pairs(package.loaded) do
	local builtin = keep[name] or (type(name) == "string" and name:sub(1, 4) == "jit.")
	if not builtin then
		to_unload[#to_unload + 1] = name
	end
end
for _, name in ipairs(to_unload) do
	package.loaded[name] = nil
end
`)
}
// RefreshModuleByName removes modName from the global Lua package.loaded
// table.
//
// It returns:
//   - nil if the runner has no Lua state;
//   - ErrRunnerClosed after Close;
//   - otherwise, the error from executing the Lua statement.
//
// NOTE(review): this touches the Lua state from the caller's goroutine,
// concurrently with any job the worker goroutine is running on it.
func (r *LuaRunner) RefreshModuleByName(modName string) error {
	if r.state == nil {
		return nil
	}
	if !r.isRunning.Load() {
		return ErrRunnerClosed
	}
	// modName is typically derived from a file path. Quote it so characters
	// such as '"', '\' or newlines cannot break out of the string literal and
	// inject Lua code.
	return r.state.DoString("package.loaded[" + luaQuote(modName) + "] = nil")
}

// luaQuote returns s as a double-quoted Lua string literal. Printable ASCII
// is copied as-is. Every other byte, plus '"' and '\', is written as a
// three-digit decimal escape (\ddd); Lua 5.1 and LuaJIT decode these back to
// the original byte. Always using three digits prevents a following literal
// digit from being absorbed into the escape.
func luaQuote(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('"')
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c >= ' ' && c <= '~' && c != '"' && c != '\\' {
			b.WriteByte(c)
			continue
		}
		fmt.Fprintf(&b, "\\%03d", c)
	}
	b.WriteByte('"')
	return b.String()
}
// ProcessModuleChange reacts to a file-change notification from a watcher.
//
// It always marks the require cache as needing a refresh. For a .lua file
// located under one of the configured library directories, it also derives
// the dotted module name (<libDir>/foo/bar.lua -> "foo.bar"). It then clears
// that entry from Lua's package.loaded in a background goroutine.
func (r *LuaRunner) ProcessModuleChange(filePath string) {
	r.requireCache.MarkNeedsRefresh()

	ext := filepath.Ext(filePath)
	if ext != ".lua" {
		return
	}

	// Read the library directories under r.mu, which executeJob holds while
	// updating the shared require config on the worker goroutine.
	r.mu.RLock()
	libDirs := r.requireCfg.LibDirs
	r.mu.RUnlock()

	var modName string
	for _, libDir := range libDirs {
		rel, err := filepath.Rel(libDir, filePath)
		// Only a leading ".." path element means "outside libDir". A file
		// such as "..util.lua" inside libDir is a valid module.
		if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			continue
		}
		modName = strings.ReplaceAll(strings.TrimSuffix(rel, ext), string(filepath.Separator), ".")
		break
	}
	if modName == "" {
		return
	}

	// Fire-and-forget so the watcher is not blocked; the returned error is
	// dropped.
	// NOTE(review): this goroutine is untracked and uses the Lua state outside
	// the worker goroutine. Consider routing it through the job queue.
	go r.RefreshModuleByName(modName)
}