// Moonshark/core/runner/LuaRunner.go
//
// 484 lines
// 12 KiB
// Go

package runner
import (
"context"
"errors"
"path/filepath"
"runtime"
"sync"
"sync/atomic"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"git.sharkk.net/Sky/Moonshark/core/logger"
)
// ErrRunnerClosed is returned by RunWithContext (and therefore Run) once the
// runner has been closed, and by Close when the runner is already closed.
var ErrRunnerClosed = errors.New("lua runner is closed")

// ErrInitFailed is returned when a step of Lua state initialization fails.
var ErrInitFailed = errors.New("initialization failed")
type (
	// StateInitFunc is a caller-supplied hook run on each freshly built Lua
	// state; a non-nil error aborts initialization of that state.
	StateInitFunc func(*luajit.State) error

	// RunnerOption is a functional option that NewRunner applies to a
	// LuaRunner before the state pool is built.
	RunnerOption func(*LuaRunner)
)
// JobResult carries the outcome of one script execution from the worker
// goroutine back to the waiting caller.
type JobResult struct {
	// Value is whatever the Lua script returned.
	Value any
	// Error is the execution error, or nil on success.
	Error error
}
// StateWrapper bundles one pooled Lua state with the sandbox that executes
// scripts in it.
type StateWrapper struct {
	// state is the underlying LuaJIT state.
	state *luajit.State
	// sandbox executes bytecode inside state.
	sandbox *Sandbox
	// index is the pool slot this wrapper was built for, kept for debugging.
	index int
}
// LuaRunner executes Lua bytecode on a fixed-size pool of Lua states, handing
// each job an idle state taken from a semaphore channel.
type LuaRunner struct {
	// states holds one wrapper per pool slot. An entry is nil after Close, or
	// when re-initializing that slot failed.
	states []*StateWrapper
	// stateSem is a buffered channel of idle slot indexes: receiving checks a
	// state out, sending it back returns the state to the pool.
	stateSem chan int
	// poolSize is the number of slots in the pool.
	poolSize int
	// initFunc, when non-nil, is run as the last step of state initialization.
	initFunc StateInitFunc
	// moduleLoader provides the require mechanism shared by every state.
	moduleLoader *NativeModuleLoader
	// isRunning is true between a successful NewRunner and Close.
	isRunning atomic.Bool
	// mu guards states, stateSem and writes to moduleLoader.config.ScriptDir.
	mu sync.RWMutex
	// debug enables debug logging for the runner and its sandboxes.
	debug bool
}
// WithPoolSize returns an option that sets how many Lua states the pool holds.
// Sizes of zero or less are ignored, so the default stays in effect.
func WithPoolSize(size int) RunnerOption {
	return func(r *LuaRunner) {
		if size <= 0 {
			return
		}
		r.poolSize = size
	}
}
// WithInitFunc returns an option that installs initFunc as the per-state
// initialization hook. It runs after the sandbox and module setup of every
// state; passing nil disables it.
func WithInitFunc(initFunc StateInitFunc) RunnerOption {
	return func(r *LuaRunner) { r.initFunc = initFunc }
}
// WithLibDirs returns an option that sets the library directories used by the
// module loader. If the runner has no module loader with a config yet, a new
// loader is created with these LibDirs; otherwise the existing LibDirs are
// replaced.
//
// dirs is copied when the option is created, so later changes to the caller's
// slice (for example one passed as WithLibDirs(paths...)) do not leak into the
// runner. A nil dirs stays nil and an empty one stays empty.
func WithLibDirs(dirs ...string) RunnerOption {
	var libDirs []string
	if dirs != nil {
		libDirs = make([]string, len(dirs))
		copy(libDirs, dirs)
	}

	return func(r *LuaRunner) {
		if r.moduleLoader == nil || r.moduleLoader.config == nil {
			r.moduleLoader = NewNativeModuleLoader(&RequireConfig{
				LibDirs: libDirs,
			})
			return
		}
		r.moduleLoader.config.LibDirs = libDirs
	}
}
// WithDebugEnabled returns an option that turns on debug logging for the
// runner and for the sandboxes it creates.
func WithDebugEnabled() RunnerOption {
	return func(r *LuaRunner) { r.debug = true }
}
// NewRunner creates a LuaRunner and fills its pool with initialized Lua states.
//
// The pool size defaults to runtime.GOMAXPROCS(0), and options are applied in
// order. If no option installed a module loader, one is created with an empty
// script directory and no library directories. Each slot is built with
// initState and marked idle in the semaphore.
//
// If any state fails to initialize, the states built so far are cleaned up and
// closed, and that state's error is returned.
func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
	runner := &LuaRunner{
		poolSize: runtime.GOMAXPROCS(0),
		debug:    false,
	}
	for _, opt := range options {
		opt(runner)
	}

	if runner.moduleLoader == nil {
		runner.moduleLoader = NewNativeModuleLoader(&RequireConfig{
			ScriptDir: "",
			LibDirs:   []string{},
		})
	}

	runner.states = make([]*StateWrapper, runner.poolSize)
	runner.stateSem = make(chan int, runner.poolSize)

	for i := 0; i < runner.poolSize; i++ {
		wrapper, err := runner.initState(i)
		if err != nil {
			// runner.Close() would return ErrRunnerClosed without releasing
			// anything here, because isRunning is only set once the pool is
			// complete. Release the states built so far directly.
			for j, built := range runner.states[:i] {
				built.state.Cleanup()
				built.state.Close()
				runner.states[j] = nil
			}
			return nil, err
		}
		runner.states[i] = wrapper
		runner.stateSem <- i // slot i is now idle
	}

	runner.isRunning.Store(true)
	return runner, nil
}
// debugLog forwards a printf-style message, prefixed with "[LuaRunner] ", to
// logger.Debug when debug mode is enabled; otherwise it does nothing.
func (r *LuaRunner) debugLog(format string, args ...any) {
	if !r.debug {
		return
	}
	logger.Debug("[LuaRunner] "+format, args...)
}
// initState builds the Lua state for pool slot index.
//
// It creates the state and a sandbox, then in order:
//   - installs the require mechanism,
//   - initializes the core modules from GlobalRegistry,
//   - sets up the sandbox,
//   - preloads all modules,
//   - runs the optional custom init function.
//
// If a step fails, the state is cleaned up and closed. The function then
// returns ErrInitFailed, except for module preloading, which returns a
// separate "failed to preload modules" error. The underlying cause is only
// reported through debug logging.
func (r *LuaRunner) initState(index int) (*StateWrapper, error) {
	r.debugLog("Initializing Lua state %d", index)

	ls := luajit.New()
	if ls == nil {
		return nil, errors.New("failed to create Lua state")
	}
	r.debugLog("Created new Lua state %d", index)

	sb := NewSandbox()
	if r.debug {
		sb.EnableDebug()
	}

	// abort logs the failed step, releases the half-built state and returns ret.
	abort := func(format string, cause, ret error) (*StateWrapper, error) {
		r.debugLog(format, index, cause)
		ls.Cleanup()
		ls.Close()
		return nil, ret
	}

	if err := r.moduleLoader.SetupRequire(ls); err != nil {
		return abort("Failed to set up require for state %d: %v", err, ErrInitFailed)
	}
	r.debugLog("Require system initialized for state %d", index)

	if err := GlobalRegistry.Initialize(ls); err != nil {
		return abort("Failed to initialize core modules for state %d: %v", err, ErrInitFailed)
	}
	r.debugLog("Core modules initialized for state %d", index)

	// The sandbox is configured only after the core modules exist.
	if err := sb.Setup(ls); err != nil {
		return abort("Failed to set up sandbox for state %d: %v", err, ErrInitFailed)
	}
	r.debugLog("Sandbox environment set up for state %d", index)

	if err := r.moduleLoader.PreloadAllModules(ls); err != nil {
		return abort("Failed to preload modules for state %d: %v", err, errors.New("failed to preload modules"))
	}
	r.debugLog("All modules preloaded for state %d", index)

	if hook := r.initFunc; hook != nil {
		if err := hook(ls); err != nil {
			return abort("Custom init function failed for state %d: %v", err, ErrInitFailed)
		}
		r.debugLog("Custom init function completed for state %d", index)
	}

	r.debugLog("State %d initialization complete", index)
	return &StateWrapper{state: ls, sandbox: sb, index: index}, nil
}
// RunWithContext executes precompiled Lua bytecode in one of the pooled states.
//
// Behavior:
//   - It waits for an idle state until ctx is done, then runs the bytecode in
//     that state's sandbox on a separate goroutine.
//   - When execCtx is non-nil, its Values map is passed to the sandbox.
//   - A non-empty scriptPath sets the module loader's ScriptDir to
//     filepath.Dir(scriptPath) before execution.
//
// It returns ErrRunnerClosed if the runner is closed, and ctx.Err() if ctx is
// already done or becomes done before the script finishes. In that case the
// script is not interrupted: it keeps running, and its state returns to the
// pool when it completes.
//
// NOTE(review): moduleLoader.config.ScriptDir is a single field shared by all
// states. If the loader reads it while a script runs, concurrent runs with
// different scriptPaths can see each other's directory.
func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
	if !r.isRunning.Load() {
		return nil, ErrRunnerClosed
	}
	// Don't check a state out for a caller that has already given up.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	// r.stateSem may be replaced under r.mu, so read the field under the lock.
	r.mu.RLock()
	sem := r.stateSem
	r.mu.RUnlock()

	var stateIndex int
	select {
	case stateIndex = <-sem:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	// Buffered so the worker can always deliver its single result, even after
	// the caller has stopped waiting.
	resultChan := make(chan JobResult, 1)

	go func() {
		// Return the state to the pool when done, unless the runner was closed
		// (Close has drained the semaphore by then).
		defer func() {
			if !r.isRunning.Load() {
				return
			}
			r.mu.RLock()
			current := r.stateSem
			r.mu.RUnlock()
			select {
			case current <- stateIndex:
			default:
				// Semaphore already full; should not happen.
			}
		}()

		r.mu.RLock()
		wrapper := r.states[stateIndex]
		r.mu.RUnlock()

		if wrapper == nil {
			resultChan <- JobResult{Error: errors.New("state is not initialized")}
			return
		}

		// Point module resolution at the script's directory.
		if scriptPath != "" {
			r.mu.Lock()
			r.moduleLoader.config.ScriptDir = filepath.Dir(scriptPath)
			r.mu.Unlock()
		}

		var ctxMap map[string]any
		if execCtx != nil {
			ctxMap = execCtx.Values
		}

		value, err := wrapper.sandbox.Execute(wrapper.state, bytecode, ctxMap)
		resultChan <- JobResult{Value: value, Error: err}
	}()

	select {
	case result := <-resultChan:
		return result.Value, result.Error
	case <-ctx.Done():
		// Lua execution cannot be cancelled; only the wait is abandoned.
		return nil, ctx.Err()
	}
}
// Run executes Lua bytecode like RunWithContext, but with
// context.Background(). It therefore waits as long as needed for an idle state
// and for the script to finish.
func (r *LuaRunner) Run(bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
	ctx := context.Background()
	return r.RunWithContext(ctx, bytecode, execCtx, scriptPath)
}
// Close shuts the runner down. It marks the runner closed, discards any idle
// state indexes still queued in the semaphore, then cleans up and closes every
// non-nil Lua state and clears its slot. Calling Close on a runner that is
// already closed returns ErrRunnerClosed.
//
// NOTE(review): the drain is non-blocking and the cleanup loop visits every
// slot. States currently checked out by RunWithContext are therefore closed
// too, possibly while a script is still executing in them.
func (r *LuaRunner) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if !r.isRunning.Load() {
		return ErrRunnerClosed
	}
	r.isRunning.Store(false)

	// Empty the semaphore without blocking.
	for drained := false; !drained; {
		select {
		case <-r.stateSem:
		default:
			drained = true
		}
	}

	for i, w := range r.states {
		if w == nil {
			continue
		}
		w.state.Cleanup()
		w.state.Close()
		r.states[i] = nil
	}
	return nil
}
// notifyReloadMu serializes NotifyFileChanged calls (across all runners). A
// reload must check out every live state index before rebuilding the pool. Two
// reloads running at once could each hold part of the indexes and wait on each
// other forever.
var notifyReloadMu sync.Mutex

// notifyReloadPoll is how long NotifyFileChanged blocks on the semaphore before
// re-checking whether the runner has been closed. It is an untyped constant
// passed to context.WithTimeout, i.e. a time.Duration in nanoseconds (10ms).
const notifyReloadPoll = 10 * 1000 * 1000

// NotifyFileChanged handles a file change notification from a watcher by
// tearing down and re-creating every Lua state in the pool.
//
// Before touching any state, it checks out the index of every live (non-nil)
// state from the semaphore. This way no script can be executing in a state
// while that state is closed and rebuilt, which means the call blocks until
// in-flight RunWithContext jobs finish. The semaphore itself is kept rather
// than replaced, so callers already waiting on it are served once the reload
// completes.
//
// A state that fails to re-initialize is left nil and its index is withheld
// from the pool; nil states are retried on every reload.
//
// It returns false if the runner is closed (before or during the reload) or if
// any state failed to re-initialize, and true otherwise.
func (r *LuaRunner) NotifyFileChanged(filePath string) bool {
	r.debugLog("File change detected: %s", filePath)

	notifyReloadMu.Lock()
	defer notifyReloadMu.Unlock()

	if !r.isRunning.Load() {
		return false
	}

	// Exactly one index circulates through the semaphore per non-nil state.
	r.mu.RLock()
	sem := r.stateSem
	live := 0
	for _, w := range r.states {
		if w != nil {
			live++
		}
	}
	r.mu.RUnlock()

	// Check out every live index. r.mu must not be held while waiting: a job
	// that holds an index still takes r.mu before it finishes and returns it.
	for held := 0; held < live; {
		waitCtx, cancel := context.WithTimeout(context.Background(), notifyReloadPoll)
		select {
		case <-sem:
			held++
		case <-waitCtx.Done():
		}
		cancel()
		// After Close, the semaphore is discarded and finishing jobs no longer
		// return indexes, so stop waiting.
		if !r.isRunning.Load() {
			return false
		}
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	// Close may have run between collecting the last index and locking.
	if !r.isRunning.Load() {
		return false
	}

	success := true
	for i, old := range r.states {
		if old != nil {
			old.state.Cleanup()
			old.state.Close()
		}

		wrapper, err := r.initState(i)
		if err != nil {
			r.debugLog("Failed to reinitialize state %d: %v", i, err)
			success = false
			r.states[i] = nil
			continue
		}
		r.states[i] = wrapper

		// Every index is checked out at this point, so the buffered semaphore
		// has room for the fresh state.
		select {
		case sem <- i:
		default:
		}
	}
	return success
}
// ResetModuleCache asks the module loader to reset the modules of every live
// Lua state; per its original description, this clears non-core entries from
// package.loaded. It does nothing when no module loader is configured.
//
// NOTE(review): only r.mu.RLock is held here, while RunWithContext executes
// scripts after releasing the lock. A state may therefore be running a script
// while it is reset.
func (r *LuaRunner) ResetModuleCache() {
	if r.moduleLoader == nil {
		return
	}
	r.debugLog("Resetting module cache in all states")

	r.mu.RLock()
	defer r.mu.RUnlock()
	for _, w := range r.states {
		if w == nil || w.state == nil {
			continue
		}
		r.moduleLoader.ResetModules(w.state)
	}
}
// ReloadAllModules preloads all modules into package.loaded in every live Lua
// state via the module loader. It does nothing and returns nil when no module
// loader is configured.
//
// Every live state is attempted even if an earlier one fails, so a single
// failure does not leave the rest of the pool un-reloaded. The first error
// encountered is returned.
//
// NOTE(review): only r.mu.RLock is held, so a state may be executing a script
// (RunWithContext runs scripts outside the lock) while it is reloaded.
func (r *LuaRunner) ReloadAllModules() error {
	if r.moduleLoader == nil {
		return nil
	}
	r.debugLog("Reloading all modules in all states")

	r.mu.RLock()
	defer r.mu.RUnlock()

	var firstErr error
	for i, w := range r.states {
		if w == nil || w.state == nil {
			continue
		}
		if err := r.moduleLoader.PreloadAllModules(w.state); err != nil {
			r.debugLog("Failed to reload modules for state %d: %v", i, err)
			if firstErr == nil {
				firstErr = err
			}
		}
	}
	return firstErr
}
// RefreshModuleByName sets package.loaded[modName] to nil in every live Lua
// state. It returns false if the Lua chunk failed in any state; the remaining
// states are still processed.
//
// modName is embedded in the chunk as an escaped Lua string literal, so names
// containing quotes, backslashes, brackets or control characters cannot break
// out of the literal or inject Lua code.
//
// NOTE(review): only r.mu.RLock is held, so a state may be executing a script
// (RunWithContext runs scripts outside the lock) while this chunk runs in it.
func (r *LuaRunner) RefreshModuleByName(modName string) bool {
	chunk := "package.loaded[" + quoteLuaString(modName) + "] = nil"

	r.mu.RLock()
	defer r.mu.RUnlock()

	success := true
	for i, w := range r.states {
		if w == nil || w.state == nil {
			continue
		}
		r.debugLog("Refreshing module %s in state %d", modName, i)
		if err := w.state.DoString(chunk); err != nil {
			success = false
		}
	}
	return success
}

// quoteLuaString returns s as a double-quoted Lua string literal that
// evaluates to exactly the bytes of s.
//
// Printable ASCII other than '"' and '\\' is copied as-is, and those two are
// backslash-escaped. Every other byte is written as a three-digit decimal
// escape (\ddd), which Lua 5.1 and LuaJIT accept; the fixed width cannot be
// misread when a digit follows.
func quoteLuaString(s string) string {
	buf := make([]byte, 0, len(s)+2)
	buf = append(buf, '"')
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '"' || c == '\\':
			buf = append(buf, '\\', c)
		case c >= 0x20 && c < 0x7f:
			buf = append(buf, c)
		default:
			buf = append(buf, '\\', '0'+c/100, '0'+c/10%10, '0'+c%10)
		}
	}
	buf = append(buf, '"')
	return string(buf)
}
// AddModule passes name and module to the AddModule method of the sandbox of
// every pool slot that currently has one.
func (r *LuaRunner) AddModule(name string, module any) {
	r.debugLog("Adding module %s to all sandboxes", name)

	r.mu.RLock()
	defer r.mu.RUnlock()
	for _, w := range r.states {
		if w == nil || w.sandbox == nil {
			continue
		}
		w.sandbox.AddModule(name, module)
	}
}
// GetModuleCount reports how many entries package.loaded holds in the first
// live Lua state. Only that first state is consulted. The result is 0 when no
// state is live, when the counting chunk fails, or when its result is not a
// float64.
//
// NOTE(review): only r.mu.RLock is held, so the chunk may run in a state that
// is executing a script (RunWithContext runs scripts outside the lock).
func (r *LuaRunner) GetModuleCount() int {
	r.mu.RLock()
	defer r.mu.RUnlock()

	for _, w := range r.states {
		if w == nil || w.state == nil {
			continue
		}
		res, err := w.state.ExecuteWithResult(`
local count = 0
for _ in pairs(package.loaded) do
count = count + 1
end
return count
`)
		if err != nil {
			return 0
		}
		if num, ok := res.(float64); ok {
			return int(num)
		}
		return 0
	}
	return 0
}