536 lines
13 KiB
Go
536 lines
13 KiB
Go
package runner
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"path/filepath"
	"runtime"
	"sync"
	"sync/atomic"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
	"git.sharkk.net/Sky/Moonshark/core/logger"
)
|
|
|
|
// Common errors
|
|
var (
|
|
ErrRunnerClosed = errors.New("lua runner is closed")
|
|
ErrInitFailed = errors.New("initialization failed")
|
|
)
|
|
|
|
// StateInitFunc is a function that initializes a Lua state
|
|
type StateInitFunc func(*luajit.State) error
|
|
|
|
// RunnerOption defines a functional option for configuring the LuaRunner
|
|
type RunnerOption func(*LuaRunner)
|
|
|
|
// JobResult represents the result of a Lua script execution
|
|
type JobResult struct {
|
|
Value any // Return value from Lua
|
|
Error error // Error if any
|
|
}
|
|
|
|
// StateWrapper wraps a Lua state with its sandbox
|
|
type StateWrapper struct {
|
|
state *luajit.State // The Lua state
|
|
sandbox *Sandbox // Associated sandbox
|
|
index int // Index for debugging
|
|
}
|
|
|
|
// InitHook is a function that runs before executing a script
|
|
type InitHook func(*luajit.State, *Context) error
|
|
|
|
// FinalizeHook is a function that runs after executing a script
|
|
type FinalizeHook func(*luajit.State, *Context, any) error
|
|
|
|
// LuaRunner runs Lua scripts using a pool of Lua states
|
|
type LuaRunner struct {
|
|
states []*StateWrapper // Pool of Lua states
|
|
stateSem chan int // Semaphore with state indexes
|
|
poolSize int // Size of the state pool
|
|
initFunc StateInitFunc // Optional function to initialize Lua states
|
|
moduleLoader *NativeModuleLoader // Native module loader for require
|
|
isRunning atomic.Bool // Flag indicating if the runner is active
|
|
mu sync.RWMutex // Mutex for thread safety
|
|
debug bool // Enable debug logging
|
|
initHooks []InitHook // Hooks to run before script execution
|
|
finalizeHooks []FinalizeHook // Hooks to run after script execution
|
|
}
|
|
|
|
// WithPoolSize sets the state pool size
|
|
func WithPoolSize(size int) RunnerOption {
|
|
return func(r *LuaRunner) {
|
|
if size > 0 {
|
|
r.poolSize = size
|
|
}
|
|
}
|
|
}
|
|
|
|
// WithInitFunc sets the init function for the Lua state
|
|
func WithInitFunc(initFunc StateInitFunc) RunnerOption {
|
|
return func(r *LuaRunner) {
|
|
r.initFunc = initFunc
|
|
}
|
|
}
|
|
|
|
// WithLibDirs sets additional library directories
|
|
func WithLibDirs(dirs ...string) RunnerOption {
|
|
return func(r *LuaRunner) {
|
|
if r.moduleLoader == nil || r.moduleLoader.config == nil {
|
|
r.moduleLoader = NewNativeModuleLoader(&RequireConfig{
|
|
LibDirs: dirs,
|
|
})
|
|
} else {
|
|
r.moduleLoader.config.LibDirs = dirs
|
|
}
|
|
}
|
|
}
|
|
|
|
// WithDebugEnabled enables debug output
|
|
func WithDebugEnabled() RunnerOption {
|
|
return func(r *LuaRunner) {
|
|
r.debug = true
|
|
}
|
|
}
|
|
|
|
// NewRunner creates a new LuaRunner with a pool of states
|
|
func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
|
// Default configuration
|
|
runner := &LuaRunner{
|
|
poolSize: runtime.GOMAXPROCS(0),
|
|
debug: false,
|
|
initHooks: make([]InitHook, 0),
|
|
finalizeHooks: make([]FinalizeHook, 0),
|
|
}
|
|
|
|
// Apply options
|
|
for _, opt := range options {
|
|
opt(runner)
|
|
}
|
|
|
|
// Set up module loader if not already initialized
|
|
if runner.moduleLoader == nil {
|
|
requireConfig := &RequireConfig{
|
|
ScriptDir: "",
|
|
LibDirs: []string{},
|
|
}
|
|
runner.moduleLoader = NewNativeModuleLoader(requireConfig)
|
|
}
|
|
|
|
// Initialize states and semaphore
|
|
runner.states = make([]*StateWrapper, runner.poolSize)
|
|
runner.stateSem = make(chan int, runner.poolSize)
|
|
|
|
// Create and initialize all states
|
|
for i := 0; i < runner.poolSize; i++ {
|
|
wrapper, err := runner.initState(i)
|
|
if err != nil {
|
|
runner.Close() // Clean up already created states
|
|
return nil, err
|
|
}
|
|
|
|
runner.states[i] = wrapper
|
|
runner.stateSem <- i // Add index to semaphore
|
|
}
|
|
|
|
runner.isRunning.Store(true)
|
|
|
|
return runner, nil
|
|
}
|
|
|
|
// debugLog logs a message if debug mode is enabled
|
|
func (r *LuaRunner) debugLog(format string, args ...interface{}) {
|
|
if r.debug {
|
|
logger.Debug("[LuaRunner] "+format, args...)
|
|
}
|
|
}
|
|
|
|
// initState creates and initializes a new state
|
|
func (r *LuaRunner) initState(index int) (*StateWrapper, error) {
|
|
r.debugLog("Initializing Lua state %d", index)
|
|
|
|
// Create a new state
|
|
state := luajit.New()
|
|
if state == nil {
|
|
return nil, errors.New("failed to create Lua state")
|
|
}
|
|
r.debugLog("Created new Lua state %d", index)
|
|
|
|
// Create sandbox
|
|
sandbox := NewSandbox()
|
|
if r.debug {
|
|
sandbox.EnableDebug()
|
|
}
|
|
|
|
// Set up require paths and mechanism
|
|
if err := r.moduleLoader.SetupRequire(state); err != nil {
|
|
r.debugLog("Failed to set up require for state %d: %v", index, err)
|
|
state.Cleanup()
|
|
state.Close()
|
|
return nil, ErrInitFailed
|
|
}
|
|
r.debugLog("Require system initialized for state %d", index)
|
|
|
|
// Initialize all core modules from the registry
|
|
if err := GlobalRegistry.Initialize(state); err != nil {
|
|
r.debugLog("Failed to initialize core modules for state %d: %v", index, err)
|
|
state.Cleanup()
|
|
state.Close()
|
|
return nil, ErrInitFailed
|
|
}
|
|
r.debugLog("Core modules initialized for state %d", index)
|
|
|
|
// Set up sandbox after core modules are initialized
|
|
if err := sandbox.Setup(state); err != nil {
|
|
r.debugLog("Failed to set up sandbox for state %d: %v", index, err)
|
|
state.Cleanup()
|
|
state.Close()
|
|
return nil, ErrInitFailed
|
|
}
|
|
r.debugLog("Sandbox environment set up for state %d", index)
|
|
|
|
// Preload all modules into package.loaded
|
|
if err := r.moduleLoader.PreloadAllModules(state); err != nil {
|
|
r.debugLog("Failed to preload modules for state %d: %v", index, err)
|
|
state.Cleanup()
|
|
state.Close()
|
|
return nil, errors.New("failed to preload modules")
|
|
}
|
|
r.debugLog("All modules preloaded for state %d", index)
|
|
|
|
// Run init function if provided
|
|
if r.initFunc != nil {
|
|
if err := r.initFunc(state); err != nil {
|
|
r.debugLog("Custom init function failed for state %d: %v", index, err)
|
|
state.Cleanup()
|
|
state.Close()
|
|
return nil, ErrInitFailed
|
|
}
|
|
r.debugLog("Custom init function completed for state %d", index)
|
|
}
|
|
|
|
r.debugLog("State %d initialization complete", index)
|
|
|
|
return &StateWrapper{
|
|
state: state,
|
|
sandbox: sandbox,
|
|
index: index,
|
|
}, nil
|
|
}
|
|
|
|
// AddInitHook adds a hook to be called before script execution
|
|
func (r *LuaRunner) AddInitHook(hook InitHook) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.initHooks = append(r.initHooks, hook)
|
|
}
|
|
|
|
// AddFinalizeHook adds a hook to be called after script execution
|
|
func (r *LuaRunner) AddFinalizeHook(hook FinalizeHook) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.finalizeHooks = append(r.finalizeHooks, hook)
|
|
}
|
|
|
|
// RunWithContext executes a Lua script with context and timeout
|
|
func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
|
if !r.isRunning.Load() {
|
|
return nil, ErrRunnerClosed
|
|
}
|
|
|
|
// Create a result channel
|
|
resultChan := make(chan JobResult, 1)
|
|
|
|
// Get a state index with timeout
|
|
var stateIndex int
|
|
select {
|
|
case stateIndex = <-r.stateSem:
|
|
// Got a state
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
// Launch a goroutine to execute the job
|
|
go func() {
|
|
// Make sure to return the state to the pool when done
|
|
defer func() {
|
|
// Only return if runner is still open
|
|
if r.isRunning.Load() {
|
|
select {
|
|
case r.stateSem <- stateIndex:
|
|
// State returned to pool
|
|
default:
|
|
// Pool is full or closed (shouldn't happen)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Execute the job
|
|
var result JobResult
|
|
|
|
r.mu.RLock()
|
|
state := r.states[stateIndex]
|
|
// Copy hooks to ensure we don't hold the lock during execution
|
|
initHooks := make([]InitHook, len(r.initHooks))
|
|
copy(initHooks, r.initHooks)
|
|
finalizeHooks := make([]FinalizeHook, len(r.finalizeHooks))
|
|
copy(finalizeHooks, r.finalizeHooks)
|
|
r.mu.RUnlock()
|
|
|
|
if state == nil {
|
|
result = JobResult{nil, errors.New("state is not initialized")}
|
|
} else {
|
|
// Set script directory for module resolution
|
|
if scriptPath != "" {
|
|
r.mu.Lock()
|
|
r.moduleLoader.config.ScriptDir = filepath.Dir(scriptPath)
|
|
r.mu.Unlock()
|
|
}
|
|
|
|
// Run init hooks
|
|
for _, hook := range initHooks {
|
|
if err := hook(state.state, execCtx); err != nil {
|
|
result = JobResult{nil, err}
|
|
// Send result and return early
|
|
select {
|
|
case resultChan <- result:
|
|
default:
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
// Convert context
|
|
var ctxMap map[string]any
|
|
if execCtx != nil {
|
|
ctxMap = execCtx.Values
|
|
}
|
|
|
|
// Execute in sandbox
|
|
value, err := state.sandbox.Execute(state.state, bytecode, ctxMap)
|
|
|
|
// Run finalize hooks
|
|
for _, hook := range finalizeHooks {
|
|
hookErr := hook(state.state, execCtx, value)
|
|
if hookErr != nil && err == nil {
|
|
// Only override nil errors
|
|
err = hookErr
|
|
}
|
|
}
|
|
|
|
result = JobResult{value, err}
|
|
}
|
|
|
|
// Send result
|
|
select {
|
|
case resultChan <- result:
|
|
// Result sent
|
|
default:
|
|
// Result channel closed or full (shouldn't happen with buffered channel)
|
|
}
|
|
}()
|
|
|
|
// Wait for result with context
|
|
select {
|
|
case result := <-resultChan:
|
|
return result.Value, result.Error
|
|
case <-ctx.Done():
|
|
// Note: we can't cancel the Lua execution, but we can stop waiting for it
|
|
// The state will be returned to the pool when the goroutine completes
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
// Run executes a Lua script
|
|
func (r *LuaRunner) Run(bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
|
return r.RunWithContext(context.Background(), bytecode, execCtx, scriptPath)
|
|
}
|
|
|
|
// Close gracefully shuts down the LuaRunner
|
|
func (r *LuaRunner) Close() error {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
if !r.isRunning.Load() {
|
|
return ErrRunnerClosed
|
|
}
|
|
|
|
r.isRunning.Store(false)
|
|
|
|
// Drain the semaphore (non-blocking)
|
|
for {
|
|
select {
|
|
case <-r.stateSem:
|
|
// Drained one slot
|
|
default:
|
|
// Empty
|
|
goto drained
|
|
}
|
|
}
|
|
drained:
|
|
|
|
// Clean up all states
|
|
for i := 0; i < len(r.states); i++ {
|
|
if r.states[i] != nil {
|
|
r.states[i].state.Cleanup()
|
|
r.states[i].state.Close()
|
|
r.states[i] = nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// NotifyFileChanged handles file change notifications from watchers
|
|
func (r *LuaRunner) NotifyFileChanged(filePath string) bool {
|
|
r.debugLog("File change detected: %s", filePath)
|
|
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
// Check if runner is closed
|
|
if !r.isRunning.Load() {
|
|
return false
|
|
}
|
|
|
|
// Create a new semaphore
|
|
newSem := make(chan int, cap(r.stateSem))
|
|
|
|
// Drain the current semaphore (non-blocking)
|
|
for {
|
|
select {
|
|
case <-r.stateSem:
|
|
// Drained one slot
|
|
default:
|
|
// Empty
|
|
goto drained
|
|
}
|
|
}
|
|
drained:
|
|
|
|
r.stateSem = newSem
|
|
|
|
// Reinitialize all states
|
|
success := true
|
|
for i := 0; i < len(r.states); i++ {
|
|
// Clean up old state
|
|
if r.states[i] != nil {
|
|
r.states[i].state.Cleanup()
|
|
r.states[i].state.Close()
|
|
}
|
|
|
|
// Initialize new state
|
|
wrapper, err := r.initState(i)
|
|
if err != nil {
|
|
r.debugLog("Failed to reinitialize state %d: %v", i, err)
|
|
success = false
|
|
r.states[i] = nil
|
|
continue
|
|
}
|
|
|
|
r.states[i] = wrapper
|
|
|
|
// Add to semaphore
|
|
select {
|
|
case newSem <- i:
|
|
// Added to semaphore
|
|
default:
|
|
// Semaphore full (shouldn't happen)
|
|
}
|
|
}
|
|
|
|
return success
|
|
}
|
|
|
|
// ResetModuleCache clears non-core modules from package.loaded in all states
|
|
func (r *LuaRunner) ResetModuleCache() {
|
|
if r.moduleLoader != nil {
|
|
r.debugLog("Resetting module cache in all states")
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
for i := 0; i < len(r.states); i++ {
|
|
if r.states[i] != nil && r.states[i].state != nil {
|
|
r.moduleLoader.ResetModules(r.states[i].state)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ReloadAllModules reloads all modules into package.loaded in all states
|
|
func (r *LuaRunner) ReloadAllModules() error {
|
|
if r.moduleLoader != nil {
|
|
r.debugLog("Reloading all modules in all states")
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
for i := 0; i < len(r.states); i++ {
|
|
if r.states[i] != nil && r.states[i].state != nil {
|
|
if err := r.moduleLoader.PreloadAllModules(r.states[i].state); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RefreshModuleByName invalidates a specific module in package.loaded in all states
|
|
func (r *LuaRunner) RefreshModuleByName(modName string) bool {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
success := true
|
|
for i := 0; i < len(r.states); i++ {
|
|
if r.states[i] != nil && r.states[i].state != nil {
|
|
r.debugLog("Refreshing module %s in state %d", modName, i)
|
|
if err := r.states[i].state.DoString(`package.loaded["` + modName + `"] = nil`); err != nil {
|
|
success = false
|
|
}
|
|
}
|
|
}
|
|
return success
|
|
}
|
|
|
|
// AddModule adds a module to all sandbox environments
|
|
func (r *LuaRunner) AddModule(name string, module any) {
|
|
r.debugLog("Adding module %s to all sandboxes", name)
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
for i := 0; i < len(r.states); i++ {
|
|
if r.states[i] != nil && r.states[i].sandbox != nil {
|
|
r.states[i].sandbox.AddModule(name, module)
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetModuleCount returns the number of loaded modules in the first state
|
|
func (r *LuaRunner) GetModuleCount() int {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
count := 0
|
|
|
|
// Get count from the first available state
|
|
for i := 0; i < len(r.states); i++ {
|
|
if r.states[i] != nil && r.states[i].state != nil {
|
|
// Execute a Lua snippet to count modules
|
|
if res, err := r.states[i].state.ExecuteWithResult(`
|
|
local count = 0
|
|
for _ in pairs(package.loaded) do
|
|
count = count + 1
|
|
end
|
|
return count
|
|
`); err == nil {
|
|
if num, ok := res.(float64); ok {
|
|
count = int(num)
|
|
}
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
return count
|
|
}
|