// Moonshark/runner/runner.go

package runner
import (
"Moonshark/runner/lualibs"
"Moonshark/utils/color"
"Moonshark/utils/logger"
"context"
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/goccy/go-json"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
// Common errors
var (
ErrRunnerClosed = errors.New("lua runner is closed")
ErrInitFailed = errors.New("initialization failed")
ErrStateNotReady = errors.New("lua state not ready")
ErrTimeout = errors.New("operation timed out")
)
type (
	// RunnerOption is a functional option that NewRunner applies to a freshly
	// defaulted Runner before any states are created.
	RunnerOption func(*Runner)

	// State pairs one Lua VM with the sandbox used to execute code on it.
	State struct {
		L       *luajit.State // underlying Lua VM
		sandbox *Sandbox      // sandbox that runs bytecode on L
		index   int           // slot in Runner.states; also used in log output
		inUse   atomic.Bool   // set while an execution holds this state
	}

	// Runner executes Lua bytecode on a fixed-size pool of States.
	Runner struct {
		states       []*State      // every state owned by the runner, indexed by slot
		statePool    chan int      // slot indexes of idle states
		poolSize     int           // number of states in the pool
		moduleLoader *ModuleLoader // resolves require() for all states
		dataDir      string        // directory for SQLite databases
		fsDir        string        // root of the virtual filesystem
		isRunning    atomic.Bool   // false before init completes and after Close
		mu           sync.RWMutex  // guards scriptDir and state lifecycle changes
		scriptDir    string        // directory of the most recently run script
	}
)
// WithPoolSize overrides how many Lua states the runner keeps in its pool.
// Non-positive sizes are ignored and the existing value is kept.
func WithPoolSize(size int) RunnerOption {
	return func(r *Runner) {
		if size <= 0 {
			return
		}
		r.poolSize = size
	}
}
// WithLibDirs sets the extra directories searched for Lua libraries.
// If the runner already has a module loader, its LibDirs are replaced in place;
// otherwise a new loader is created whose config carries only these dirs.
func WithLibDirs(dirs ...string) RunnerOption {
	return func(r *Runner) {
		if r.moduleLoader != nil {
			r.moduleLoader.config.LibDirs = dirs
			return
		}
		r.moduleLoader = NewModuleLoader(&ModuleConfig{LibDirs: dirs})
	}
}
// WithDataDir chooses the directory used for SQLite databases.
// An empty string leaves the current value untouched.
func WithDataDir(dataDir string) RunnerOption {
	return func(r *Runner) {
		if dataDir == "" {
			return
		}
		r.dataDir = dataDir
	}
}
// WithFsDir chooses the root directory of the virtual filesystem.
// An empty string leaves the current value untouched.
func WithFsDir(fsDir string) RunnerOption {
	return func(r *Runner) {
		if fsDir == "" {
			return
		}
		r.fsDir = fsDir
	}
}
// NewRunner builds a Runner, applies the given options, initialises the SQLite
// and FS helper libraries, and creates the pool of Lua states.
//
// Defaults before options are applied: pool size = runtime.GOMAXPROCS(0),
// data directory "data", virtual filesystem directory "fs". If no option
// installed a module loader, one with an empty script dir and no lib dirs is
// created.
//
// If any state fails to initialise, every resource acquired so far (already
// created Lua states, the FS and SQLite libraries) is released and the
// initialisation error is returned.
func NewRunner(options ...RunnerOption) (*Runner, error) {
	runner := &Runner{
		poolSize: runtime.GOMAXPROCS(0),
		dataDir:  "data",
		fsDir:    "fs",
	}

	for _, opt := range options {
		opt(runner)
	}

	if runner.moduleLoader == nil {
		runner.moduleLoader = NewModuleLoader(&ModuleConfig{
			ScriptDir: "",
			LibDirs:   []string{},
		})
	}

	lualibs.InitSQLite(runner.dataDir)
	lualibs.InitFS(runner.fsDir)
	lualibs.SetSQLitePoolSize(runner.poolSize)

	runner.states = make([]*State, runner.poolSize)
	runner.statePool = make(chan int, runner.poolSize)

	if err := runner.initializeStates(); err != nil {
		// runner.Close() cannot be used for cleanup here: isRunning is still
		// false, so Close would return ErrRunnerClosed without releasing
		// anything. Tear down the states created before the failure and the
		// helper libraries directly (same order as Close).
		for i, state := range runner.states {
			if state != nil {
				state.L.Cleanup()
				state.L.Close()
				runner.states[i] = nil
			}
		}
		lualibs.CleanupFS()
		lualibs.CleanupSQLite()
		return nil, err
	}

	runner.isRunning.Store(true)
	return runner, nil
}
// initializeStates creates r.poolSize Lua states. Each one is stored in
// r.states at its slot and the slot index is published on r.statePool.
// It stops at the first failure and returns that error; states created before
// the failure remain stored and published.
func (r *Runner) initializeStates() error {
	count := strconv.Itoa(r.poolSize)
	logger.Infof("[LuaRunner] Creating %s states...", color.Yellow(count))

	for slot := 0; slot < r.poolSize; slot++ {
		st, err := r.createState(slot)
		if err != nil {
			return err
		}
		r.states[slot] = st
		r.statePool <- slot
	}
	return nil
}
// createState builds one Lua state for pool slot index. It opens the standard
// libraries, installs the sandbox, wires up require() through the module
// loader and preloads modules. Only slot 0 logs its progress.
//
// On failure the partially built Lua state is cleaned up and closed. Sandbox
// and require() setup failures wrap both ErrInitFailed and the underlying
// cause, so errors.Is(err, ErrInitFailed) still holds. Preload failures wrap
// the loader's error.
func (r *Runner) createState(index int) (*State, error) {
	verbose := index == 0
	if verbose {
		logger.Debugf("Creating Lua state %d", index)
	}

	L := luajit.New(true) // Explicitly open standard libraries
	if L == nil {
		return nil, errors.New("failed to create Lua state")
	}

	// fail releases the half-built VM and returns err to the caller.
	fail := func(err error) (*State, error) {
		L.Cleanup()
		L.Close()
		return nil, err
	}

	sb := NewSandbox()
	if err := sb.Setup(L, verbose); err != nil {
		return fail(fmt.Errorf("%w: sandbox setup: %w", ErrInitFailed, err))
	}

	if err := r.moduleLoader.SetupRequire(L); err != nil {
		return fail(fmt.Errorf("%w: require setup: %w", ErrInitFailed, err))
	}

	if err := r.moduleLoader.PreloadModules(L); err != nil {
		return fail(fmt.Errorf("failed to preload modules: %w", err))
	}

	if verbose {
		logger.Debugf("Lua state %d initialized successfully", index)
	}

	return &State{
		L:       L,
		sandbox: sb,
		index:   index,
	}, nil
}
// Execute runs precompiled bytecode in the sandbox of a pooled Lua state.
//
// A non-empty scriptPath first points the module loader at that script's
// directory. The call then waits for an idle state. It gives up with
// ctx.Err() if ctx ends first, or with ErrTimeout after one second. Once the
// script finishes, the state is handed back to the pool, unless the runner has
// been closed in the meantime.
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
	if !r.isRunning.Load() {
		return nil, ErrRunnerClosed
	}

	if scriptPath != "" {
		dir := filepath.Dir(scriptPath)
		r.mu.Lock()
		r.scriptDir = dir
		r.moduleLoader.SetScriptDir(dir)
		r.mu.Unlock()
	}

	wait := time.NewTimer(1 * time.Second)
	defer wait.Stop()

	var slot int
	select {
	case slot = <-r.statePool:
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-wait.C:
		return nil, ErrTimeout
	}

	st := r.states[slot]
	if st == nil {
		r.statePool <- slot
		return nil, ErrStateNotReady
	}

	st.inUse.Store(true)
	defer func() {
		st.inUse.Store(false)
		if !r.isRunning.Load() {
			// Runner is shutting down; Close() cleans the state up.
			return
		}
		select {
		case r.statePool <- slot:
		default:
			// Pool already full; drop the index rather than block.
		}
	}()

	resp, err := st.sandbox.Execute(st.L, bytecode, execCtx)
	if err != nil {
		return nil, err
	}
	return resp, nil
}
// Run is Execute with context.Background(): the caller cannot cancel the wait
// for a state, so only Execute's own acquisition timeout applies.
func (r *Runner) Run(bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
	bg := context.Background()
	return r.Execute(bg, bytecode, execCtx, scriptPath)
}
// Close shuts the runner down. It marks the runner as stopped and empties the
// pool so no new execution can acquire a state. It then polls for up to ten
// seconds for in-flight executions to finish. Finally it destroys every Lua
// state (logging any that are still busy) and releases the FS and SQLite
// libraries. Calling Close on a runner that is not running returns
// ErrRunnerClosed.
func (r *Runner) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if !r.isRunning.Load() {
		return ErrRunnerClosed
	}
	r.isRunning.Store(false)

	// Empty the pool of idle states.
drain:
	for {
		select {
		case <-r.statePool:
		default:
			break drain
		}
	}

	// Wait for busy states to become idle, bounded by a deadline.
	deadline := time.Now().Add(10 * time.Second)
	for {
		busy := false
		for _, st := range r.states {
			if st != nil && st.inUse.Load() {
				busy = true
				break
			}
		}
		if !busy {
			break
		}
		if time.Now().After(deadline) {
			logger.Warnf("Timeout waiting for states to finish during shutdown, forcing close")
			break
		}
		time.Sleep(10 * time.Millisecond)
	}

	// Destroy every remaining state.
	for slot, st := range r.states {
		if st == nil {
			continue
		}
		if st.inUse.Load() {
			logger.Warnf("Force closing state %d that is still in use", slot)
		}
		st.L.Cleanup()
		st.L.Close()
		r.states[slot] = nil
	}

	lualibs.CleanupFS()
	lualibs.CleanupSQLite()

	logger.Debugf("Runner closed")
	return nil
}
// RefreshStates discards every Lua state and builds a fresh pool.
//
// While holding the runner's write lock it empties the pool so no new
// execution can start. It waits up to ten seconds for running executions,
// closes all states (logging any that are still busy), and re-creates them.
//
// Returns ErrRunnerClosed if the runner is not running. Otherwise it returns
// the first error from re-initialisation, in which case the pool may be only
// partially filled.
func (r *Runner) RefreshStates() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if !r.isRunning.Load() {
		return ErrRunnerClosed
	}

	logger.Infof("Runner is refreshing all states...")

	// drainPool removes every index currently sitting in the pool.
	drainPool := func() {
		for {
			select {
			case <-r.statePool:
			default:
				return
			}
		}
	}

	// Take all idle states out of circulation.
	drainPool()

	// Wait for in-use states to finish (with timeout).
	timeout := time.Now().Add(10 * time.Second)
	for {
		allIdle := true
		for _, state := range r.states {
			if state != nil && state.inUse.Load() {
				allIdle = false
				break
			}
		}
		if allIdle {
			break
		}
		if time.Now().After(timeout) {
			logger.Warnf("Timeout waiting for states to finish, forcing refresh")
			break
		}
		time.Sleep(10 * time.Millisecond)
	}

	// isRunning stays true during a refresh, so executions that finished while
	// we were waiting have pushed their index back into the pool. Drain again.
	// Otherwise initializeStates' blocking sends could stall while r.mu is held,
	// and the stale index would duplicate a fresh one, letting two executions
	// share a single Lua state.
	drainPool()

	// Now safely destroy all states.
	for i, state := range r.states {
		if state != nil {
			if state.inUse.Load() {
				logger.Warnf("Force closing state %d that is still in use", i)
			}
			state.L.Cleanup()
			state.L.Close()
			r.states[i] = nil
		}
	}

	// Reinitialize all states.
	if err := r.initializeStates(); err != nil {
		return err
	}

	logger.Debugf("All states refreshed successfully")
	return nil
}
// NotifyFileChanged handles a file-change notification. If the module loader
// maps filePath to a module, the result of RefreshModule for that module is
// returned. Any other path is just logged and reported as handled (true).
func (r *Runner) NotifyFileChanged(filePath string) bool {
	logger.Debugf("Runner notified of file change: %s", filePath)

	module, known := r.moduleLoader.GetModuleByPath(filePath)
	if !known {
		logger.Debugf("File change noted but no refresh needed: %s", filePath)
		return true
	}

	logger.Debugf("Refreshing module: %s", module)
	return r.RefreshModule(module)
}
// RefreshModule reloads moduleName in every Lua state that is idle when the
// call starts. States that are busy executing are skipped and keep their
// current copy of the module.
//
// It returns false if the runner is not running or if the reload failed in any
// state, and true otherwise.
func (r *Runner) RefreshModule(moduleName string) bool {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if !r.isRunning.Load() {
		return false
	}

	logger.Debugf("Refreshing module: %s", moduleName)

	// Claim the idle states by pulling their indexes out of the pool. Holding
	// an index gives exclusive access to that Lua state, so Execute cannot start
	// running on it while the module is being reloaded. An inUse check is not
	// enough: Execute can acquire the state between the check and the reload.
	// Executions arriving meanwhile wait on the pool under their own timeout.
	claimed := make([]int, 0, len(r.states))
	for drained := false; !drained; {
		select {
		case idx := <-r.statePool:
			claimed = append(claimed, idx)
		default:
			drained = true
		}
	}
	// Hand every claimed state back once the reload is done.
	defer func() {
		for _, idx := range claimed {
			r.statePool <- idx
		}
	}()

	success := true
	for _, idx := range claimed {
		state := r.states[idx]
		if state == nil {
			continue
		}
		if err := r.moduleLoader.RefreshModule(state.L, moduleName); err != nil {
			success = false
			logger.Debugf("Failed to refresh module %s in state %d: %v", moduleName, state.index, err)
		}
	}

	if success {
		logger.Debugf("Successfully refreshed module: %s", moduleName)
	}

	return success
}
// RefreshModuleByPath reloads the module backed by filePath in every Lua state
// that is idle when the call starts. Busy states are skipped.
//
// It returns false if the runner is not running or if the reload failed in any
// state, and true otherwise.
func (r *Runner) RefreshModuleByPath(filePath string) bool {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if !r.isRunning.Load() {
		return false
	}

	logger.Debugf("Refreshing module by path: %s", filePath)

	// Claim idle states exclusively by taking their indexes from the pool so
	// Execute cannot run on a state while it is being reloaded. An inUse check
	// alone is racy: Execute can acquire the state right after the check.
	claimed := make([]int, 0, len(r.states))
	for drained := false; !drained; {
		select {
		case idx := <-r.statePool:
			claimed = append(claimed, idx)
		default:
			drained = true
		}
	}
	// Return every claimed state to the pool when finished.
	defer func() {
		for _, idx := range claimed {
			r.statePool <- idx
		}
	}()

	success := true
	for _, idx := range claimed {
		state := r.states[idx]
		if state == nil {
			continue
		}
		if err := r.moduleLoader.RefreshModuleByPath(state.L, filePath); err != nil {
			success = false
			logger.Debugf("Failed to refresh module at %s in state %d: %v", filePath, state.index, err)
		}
	}

	return success
}
// GetStateCount reports how many pool slots currently hold a Lua state.
func (r *Runner) GetStateCount() int {
	r.mu.RLock()
	defer r.mu.RUnlock()

	n := 0
	for i := range r.states {
		if r.states[i] != nil {
			n++
		}
	}
	return n
}
// GetActiveStateCount reports how many existing states are flagged as in use
// by an execution at the moment of the call.
func (r *Runner) GetActiveStateCount() int {
	r.mu.RLock()
	defer r.mu.RUnlock()

	busy := 0
	for i := range r.states {
		st := r.states[i]
		if st == nil || !st.inUse.Load() {
			continue
		}
		busy++
	}
	return busy
}
// RunScriptFile reads the Lua source at filePath, compiles it to bytecode on a
// pooled state and executes it in that state's sandbox.
//
// For the duration of the call the module loader's script directory is set to
// the file's directory, and the previous value is restored afterwards. The
// execution context receives "_script_path" (absolute path) and "_script_dir".
//
// Errors: ErrRunnerClosed if the runner is not running; "script file not
// found" if the file does not exist; ErrTimeout if no state becomes free
// within 5 seconds; ErrStateNotReady if the acquired slot is empty; otherwise a
// wrapped read, path, compile or execution error.
func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
	if !r.isRunning.Load() {
		return nil, ErrRunnerClosed
	}

	// ReadFile reports a missing file on its own. A separate Stat beforehand
	// would cost an extra syscall and could disagree with the read (TOCTOU).
	content, err := os.ReadFile(filePath)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, fmt.Errorf("script file not found: %s", filePath)
		}
		return nil, fmt.Errorf("failed to read file: %w", err)
	}

	absPath, err := filepath.Abs(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to get absolute path: %w", err)
	}
	scriptDir := filepath.Dir(absPath)

	// Point module resolution at the script's directory; restore on return.
	r.mu.Lock()
	prevScriptDir := r.scriptDir
	r.scriptDir = scriptDir
	r.moduleLoader.SetScriptDir(scriptDir)
	r.mu.Unlock()
	defer func() {
		r.mu.Lock()
		r.scriptDir = prevScriptDir
		r.moduleLoader.SetScriptDir(prevScriptDir)
		r.mu.Unlock()
	}()

	// A stoppable timer avoids leaving a 5s timer armed once a state is
	// obtained.
	timer := time.NewTimer(5 * time.Second)
	defer timer.Stop()

	var stateIndex int
	select {
	case stateIndex = <-r.statePool:
		// Got a state
	case <-timer.C:
		return nil, ErrTimeout
	}

	state := r.states[stateIndex]
	if state == nil {
		r.statePool <- stateIndex
		return nil, ErrStateNotReady
	}

	state.inUse.Store(true)
	defer func() {
		state.inUse.Store(false)
		if r.isRunning.Load() {
			select {
			case r.statePool <- stateIndex:
				// State returned to pool
			default:
				// Pool is full or closed
			}
		}
	}()

	bytecode, err := state.L.CompileBytecode(string(content), filepath.Base(absPath))
	if err != nil {
		return nil, fmt.Errorf("compilation error: %w", err)
	}

	// Named execCtx (not ctx) since it is a script Context, not a context.Context.
	execCtx := NewContext()
	defer execCtx.Release()
	execCtx.Set("_script_path", absPath)
	execCtx.Set("_script_dir", scriptDir)

	response, err := state.sandbox.Execute(state.L, bytecode, execCtx)
	if err != nil {
		return nil, fmt.Errorf("execution error: %w", err)
	}

	return response, nil
}
// ApplyResponse writes resp onto the fasthttp request context.
//
// It applies the status code, headers and cookies, then the body. String and
// []byte bodies are sent as-is. Map and slice bodies are JSON-encoded, falling
// back to their %v text if encoding fails. Anything else is sent as its %v
// text. A default Content-Type (text/plain or application/json) is set only
// when resp.Headers does not already contain one (case-insensitive).
func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
	ctx.SetStatusCode(resp.Status)

	for name, value := range resp.Headers {
		ctx.Response.Header.Set(name, value)
	}
	for _, cookie := range resp.Cookies {
		ctx.Response.Header.SetCookie(cookie)
	}

	if resp.Body == nil {
		return
	}

	explicitType := false
	for name := range resp.Headers {
		if strings.ToLower(name) == "content-type" {
			explicitType = true
			break
		}
	}

	const plainText = "text/plain; charset=utf-8"

	// defaultType sets mime unless the script chose a Content-Type itself.
	defaultType := func(mime string) {
		if !explicitType {
			ctx.Response.Header.SetContentType(mime)
		}
	}

	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	// sendText writes v's %v representation as plain text.
	sendText := func(v any) {
		defaultType(plainText)
		ctx.SetBodyString(fmt.Sprintf("%v", v))
	}

	// sendJSON encodes v as JSON, or falls back to plain text on failure.
	sendJSON := func(v any) {
		if err := json.NewEncoder(buf).Encode(v); err != nil {
			sendText(v)
			return
		}
		defaultType("application/json")
		ctx.SetBody(buf.Bytes())
	}

	switch v := resp.Body.(type) {
	case string:
		defaultType(plainText)
		ctx.SetBodyString(v)
	case []byte:
		defaultType(plainText)
		ctx.SetBody(v)
	case map[string]any, map[any]any, []any, []float64, []string, []int, []map[string]any:
		sendJSON(v)
	default:
		// Other map/slice/array types are detected from their printed type name.
		typeName := fmt.Sprintf("%T", v)
		if typeName[0] == '[' || (len(typeName) > 3 && typeName[:3] == "map") {
			sendJSON(v)
		} else {
			sendText(v)
		}
	}
}