542 lines
12 KiB
Go
542 lines
12 KiB
Go
package runner
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"Moonshark/router"
|
|
"Moonshark/runner/lualibs"
|
|
"Moonshark/sessions"
|
|
"Moonshark/utils/logger"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
"github.com/goccy/go-json"
|
|
"github.com/valyala/bytebufferpool"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// emptyMap is a single shared, zero-length map used as a placeholder value
// wherever a request context needs a map but there is nothing to put in it.
// Every user sees the same instance, so it should not be written to.
var emptyMap = map[string]any{}
|
|
|
|
// Sentinel errors returned by Runner operations.
var (
	// ErrRunnerClosed is returned when an operation is attempted on a
	// runner that is not (or is no longer) running.
	ErrRunnerClosed = errors.New("lua runner is closed")

	// ErrTimeout is returned when no Lua state could be taken from the
	// pool before the wait expired.
	ErrTimeout = errors.New("operation timed out")

	// ErrStateNotReady is returned when the pool slot that was acquired
	// holds no Lua state.
	ErrStateNotReady = errors.New("lua state not ready")
)
|
|
|
|
type State struct {
|
|
L *luajit.State
|
|
sandbox *Sandbox
|
|
index int
|
|
inUse atomic.Bool
|
|
}
|
|
|
|
type Runner struct {
|
|
states []*State
|
|
statePool chan int
|
|
poolSize int
|
|
moduleLoader *ModuleLoader
|
|
isRunning atomic.Bool
|
|
mu sync.RWMutex
|
|
scriptDir string
|
|
|
|
// Pre-allocated pools for HTTP processing
|
|
ctxPool sync.Pool
|
|
paramsPool sync.Pool
|
|
}
|
|
|
|
func NewRunner(poolSize int, dataDir, fsDir string, libDirs []string) (*Runner, error) {
|
|
if poolSize <= 0 {
|
|
poolSize = runtime.GOMAXPROCS(0)
|
|
}
|
|
|
|
// Configure module loader with lib directories
|
|
moduleConfig := &ModuleConfig{
|
|
LibDirs: libDirs,
|
|
}
|
|
|
|
r := &Runner{
|
|
poolSize: poolSize,
|
|
moduleLoader: NewModuleLoader(moduleConfig),
|
|
ctxPool: sync.Pool{
|
|
New: func() any { return make(map[string]any, 8) },
|
|
},
|
|
paramsPool: sync.Pool{
|
|
New: func() any { return make(map[string]any, 4) },
|
|
},
|
|
}
|
|
|
|
lualibs.InitSQLite(dataDir)
|
|
lualibs.InitFS(fsDir)
|
|
lualibs.SetSQLitePoolSize(poolSize)
|
|
|
|
r.states = make([]*State, poolSize)
|
|
r.statePool = make(chan int, poolSize)
|
|
|
|
if err := r.initStates(); err != nil {
|
|
lualibs.CleanupSQLite()
|
|
return nil, err
|
|
}
|
|
|
|
r.isRunning.Store(true)
|
|
return r, nil
|
|
}
|
|
|
|
// Single entry point for HTTP execution
|
|
func (r *Runner) ExecuteHTTP(bytecode []byte, httpCtx *fasthttp.RequestCtx,
|
|
params *router.Params, session *sessions.Session) (*Response, error) {
|
|
|
|
if !r.isRunning.Load() {
|
|
return nil, ErrRunnerClosed
|
|
}
|
|
|
|
// Get state with timeout
|
|
var stateIndex int
|
|
select {
|
|
case stateIndex = <-r.statePool:
|
|
case <-time.After(time.Second):
|
|
return nil, ErrTimeout
|
|
}
|
|
|
|
state := r.states[stateIndex]
|
|
state.inUse.Store(true)
|
|
|
|
defer func() {
|
|
state.inUse.Store(false)
|
|
if r.isRunning.Load() {
|
|
select {
|
|
case r.statePool <- stateIndex:
|
|
default:
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Build Lua context directly from HTTP request
|
|
luaCtx := r.buildHTTPContext(httpCtx, params, session)
|
|
defer r.releaseHTTPContext(luaCtx)
|
|
|
|
return state.sandbox.Execute(state.L, bytecode, luaCtx)
|
|
}
|
|
|
|
// Build Lua context from HTTP request
|
|
func (r *Runner) buildHTTPContext(ctx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) *Context {
|
|
luaCtx := NewContext()
|
|
|
|
// Basic request info
|
|
luaCtx.Set("method", string(ctx.Method()))
|
|
luaCtx.Set("path", string(ctx.Path()))
|
|
luaCtx.Set("host", string(ctx.Host()))
|
|
|
|
// Headers
|
|
headers := r.ctxPool.Get().(map[string]any)
|
|
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
|
headers[string(key)] = string(value)
|
|
})
|
|
luaCtx.Set("headers", headers)
|
|
|
|
// Route parameters
|
|
if params != nil && len(params.Keys) > 0 {
|
|
paramMap := r.paramsPool.Get().(map[string]any)
|
|
for i, key := range params.Keys {
|
|
if i < len(params.Values) {
|
|
paramMap[key] = params.Values[i]
|
|
}
|
|
}
|
|
luaCtx.Set("params", paramMap)
|
|
} else {
|
|
luaCtx.Set("params", emptyMap)
|
|
}
|
|
|
|
// Form data for POST/PUT/PATCH
|
|
method := ctx.Method()
|
|
if string(method) == "POST" || string(method) == "PUT" || string(method) == "PATCH" {
|
|
if formData := parseForm(ctx); formData != nil {
|
|
luaCtx.Set("form", formData)
|
|
} else {
|
|
luaCtx.Set("form", emptyMap)
|
|
}
|
|
} else {
|
|
luaCtx.Set("form", emptyMap)
|
|
}
|
|
|
|
// Session data
|
|
sessionMap := r.ctxPool.Get().(map[string]any)
|
|
session.AdvanceFlash()
|
|
sessionMap["id"] = session.ID
|
|
|
|
if !session.IsEmpty() {
|
|
sessionMap["data"] = session.GetAll()
|
|
sessionMap["flash"] = session.GetAllFlash()
|
|
} else {
|
|
sessionMap["data"] = emptyMap
|
|
sessionMap["flash"] = emptyMap
|
|
}
|
|
luaCtx.Set("session", sessionMap)
|
|
|
|
// Environment variables
|
|
if envMgr := lualibs.GetGlobalEnvManager(); envMgr != nil {
|
|
luaCtx.Set("env", envMgr.GetAll())
|
|
}
|
|
|
|
return luaCtx
|
|
}
|
|
|
|
func (r *Runner) releaseHTTPContext(luaCtx *Context) {
|
|
// Return pooled maps
|
|
if headers, ok := luaCtx.Get("headers").(map[string]any); ok {
|
|
for k := range headers {
|
|
delete(headers, k)
|
|
}
|
|
r.ctxPool.Put(headers)
|
|
}
|
|
|
|
if params, ok := luaCtx.Get("params").(map[string]any); ok && len(params) > 0 {
|
|
for k := range params {
|
|
delete(params, k)
|
|
}
|
|
r.paramsPool.Put(params)
|
|
}
|
|
|
|
if sessionMap, ok := luaCtx.Get("session").(map[string]any); ok {
|
|
for k := range sessionMap {
|
|
delete(sessionMap, k)
|
|
}
|
|
r.ctxPool.Put(sessionMap)
|
|
}
|
|
|
|
luaCtx.Release()
|
|
}
|
|
|
|
func (r *Runner) initStates() error {
|
|
logger.Infof("[LuaRunner] Creating %d states...", r.poolSize)
|
|
|
|
for i := range r.poolSize {
|
|
state, err := r.createState(i)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.states[i] = state
|
|
r.statePool <- i
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *Runner) createState(index int) (*State, error) {
|
|
L := luajit.New(true)
|
|
if L == nil {
|
|
return nil, errors.New("failed to create Lua state")
|
|
}
|
|
|
|
sb := NewSandbox()
|
|
if err := sb.Setup(L, index == 0); err != nil {
|
|
L.Cleanup()
|
|
L.Close()
|
|
return nil, err
|
|
}
|
|
|
|
if err := r.moduleLoader.SetupRequire(L); err != nil {
|
|
L.Cleanup()
|
|
L.Close()
|
|
return nil, err
|
|
}
|
|
|
|
if err := r.moduleLoader.PreloadModules(L); err != nil {
|
|
L.Cleanup()
|
|
L.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return &State{L: L, sandbox: sb, index: index}, nil
|
|
}
|
|
|
|
func (r *Runner) Close() error {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
if !r.isRunning.Load() {
|
|
return ErrRunnerClosed
|
|
}
|
|
r.isRunning.Store(false)
|
|
|
|
// Drain pool
|
|
for {
|
|
select {
|
|
case <-r.statePool:
|
|
default:
|
|
goto cleanup
|
|
}
|
|
}
|
|
|
|
cleanup:
|
|
// Wait for states to finish
|
|
timeout := time.Now().Add(10 * time.Second)
|
|
for time.Now().Before(timeout) {
|
|
allIdle := true
|
|
for _, state := range r.states {
|
|
if state != nil && state.inUse.Load() {
|
|
allIdle = false
|
|
break
|
|
}
|
|
}
|
|
if allIdle {
|
|
break
|
|
}
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
|
|
// Close states
|
|
for i, state := range r.states {
|
|
if state != nil {
|
|
state.L.Cleanup()
|
|
state.L.Close()
|
|
r.states[i] = nil
|
|
}
|
|
}
|
|
|
|
lualibs.CleanupFS()
|
|
lualibs.CleanupSQLite()
|
|
return nil
|
|
}
|
|
|
|
// parseForm extracts form data from HTTP request
|
|
func parseForm(ctx *fasthttp.RequestCtx) map[string]any {
|
|
form := make(map[string]any)
|
|
|
|
// Parse POST form data
|
|
ctx.PostArgs().VisitAll(func(key, value []byte) {
|
|
form[string(key)] = string(value)
|
|
})
|
|
|
|
// Parse multipart form if present
|
|
if multipartForm, err := ctx.MultipartForm(); err == nil {
|
|
for key, values := range multipartForm.Value {
|
|
if len(values) == 1 {
|
|
form[key] = values[0]
|
|
} else {
|
|
form[key] = values
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(form) == 0 {
|
|
return nil
|
|
}
|
|
return form
|
|
}
|
|
|
|
// NotifyFileChanged alerts the runner about file changes
|
|
func (r *Runner) NotifyFileChanged(filePath string) bool {
|
|
logger.Debugf("Runner notified of file change: %s", filePath)
|
|
|
|
module, isModule := r.moduleLoader.GetModuleByPath(filePath)
|
|
if isModule {
|
|
logger.Debugf("Refreshing module: %s", module)
|
|
return r.RefreshModule(module)
|
|
}
|
|
|
|
logger.Debugf("File change noted but no refresh needed: %s", filePath)
|
|
return true
|
|
}
|
|
|
|
// RefreshModule refreshes a specific module across all states
|
|
func (r *Runner) RefreshModule(moduleName string) bool {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
if !r.isRunning.Load() {
|
|
return false
|
|
}
|
|
|
|
logger.Debugf("Refreshing module: %s", moduleName)
|
|
|
|
success := true
|
|
for _, state := range r.states {
|
|
if state == nil || state.inUse.Load() {
|
|
continue
|
|
}
|
|
|
|
if err := r.moduleLoader.RefreshModule(state.L, moduleName); err != nil {
|
|
success = false
|
|
logger.Debugf("Failed to refresh module %s in state %d: %v", moduleName, state.index, err)
|
|
}
|
|
}
|
|
|
|
if success {
|
|
logger.Debugf("Successfully refreshed module: %s", moduleName)
|
|
}
|
|
|
|
return success
|
|
}
|
|
|
|
// RunScriptFile loads, compiles and executes a Lua script file
|
|
func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
|
|
if !r.isRunning.Load() {
|
|
return nil, ErrRunnerClosed
|
|
}
|
|
|
|
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
|
return nil, fmt.Errorf("script file not found: %s", filePath)
|
|
}
|
|
|
|
content, err := os.ReadFile(filePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read file: %w", err)
|
|
}
|
|
|
|
absPath, err := filepath.Abs(filePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get absolute path: %w", err)
|
|
}
|
|
scriptDir := filepath.Dir(absPath)
|
|
|
|
r.mu.Lock()
|
|
prevScriptDir := r.scriptDir
|
|
r.scriptDir = scriptDir
|
|
r.moduleLoader.SetScriptDir(scriptDir)
|
|
r.mu.Unlock()
|
|
|
|
defer func() {
|
|
r.mu.Lock()
|
|
r.scriptDir = prevScriptDir
|
|
r.moduleLoader.SetScriptDir(prevScriptDir)
|
|
r.mu.Unlock()
|
|
}()
|
|
|
|
// Get state from pool
|
|
var stateIndex int
|
|
select {
|
|
case stateIndex = <-r.statePool:
|
|
case <-time.After(5 * time.Second):
|
|
return nil, ErrTimeout
|
|
}
|
|
|
|
state := r.states[stateIndex]
|
|
if state == nil {
|
|
r.statePool <- stateIndex
|
|
return nil, ErrStateNotReady
|
|
}
|
|
|
|
state.inUse.Store(true)
|
|
|
|
defer func() {
|
|
state.inUse.Store(false)
|
|
if r.isRunning.Load() {
|
|
select {
|
|
case r.statePool <- stateIndex:
|
|
default:
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Compile script
|
|
bytecode, err := state.L.CompileBytecode(string(content), filepath.Base(absPath))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("compilation error: %w", err)
|
|
}
|
|
|
|
// Create simple context for script execution
|
|
ctx := NewContext()
|
|
defer ctx.Release()
|
|
|
|
ctx.Set("_script_path", absPath)
|
|
ctx.Set("_script_dir", scriptDir)
|
|
|
|
// Execute script
|
|
response, err := state.sandbox.Execute(state.L, bytecode, ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("execution error: %w", err)
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ApplyResponse applies a Response to a fasthttp.RequestCtx
|
|
func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
|
|
// Set status code
|
|
ctx.SetStatusCode(resp.Status)
|
|
|
|
// Set headers
|
|
for name, value := range resp.Headers {
|
|
ctx.Response.Header.Set(name, value)
|
|
}
|
|
|
|
// Set cookies
|
|
for _, cookie := range resp.Cookies {
|
|
ctx.Response.Header.SetCookie(cookie)
|
|
}
|
|
|
|
// Process the body based on its type
|
|
if resp.Body == nil {
|
|
return
|
|
}
|
|
|
|
// Check if Content-Type was manually set
|
|
contentTypeSet := false
|
|
for name := range resp.Headers {
|
|
if strings.ToLower(name) == "content-type" {
|
|
contentTypeSet = true
|
|
break
|
|
}
|
|
}
|
|
|
|
// Get a buffer from the pool
|
|
buf := bytebufferpool.Get()
|
|
defer bytebufferpool.Put(buf)
|
|
|
|
// Set body based on type
|
|
switch body := resp.Body.(type) {
|
|
case string:
|
|
if !contentTypeSet {
|
|
ctx.Response.Header.SetContentType("text/plain; charset=utf-8")
|
|
}
|
|
ctx.SetBodyString(body)
|
|
case []byte:
|
|
if !contentTypeSet {
|
|
ctx.Response.Header.SetContentType("text/plain; charset=utf-8")
|
|
}
|
|
ctx.SetBody(body)
|
|
case map[string]any, map[any]any, []any, []float64, []string, []int, []map[string]any:
|
|
// Marshal JSON
|
|
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
|
if !contentTypeSet {
|
|
ctx.Response.Header.SetContentType("application/json")
|
|
}
|
|
ctx.SetBody(buf.Bytes())
|
|
} else {
|
|
// Fallback to string representation
|
|
if !contentTypeSet {
|
|
ctx.Response.Header.SetContentType("text/plain; charset=utf-8")
|
|
}
|
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
|
}
|
|
default:
|
|
// Check if it's any other map or slice type
|
|
typeStr := fmt.Sprintf("%T", body)
|
|
if typeStr[0] == '[' || (len(typeStr) > 3 && typeStr[:3] == "map") {
|
|
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
|
if !contentTypeSet {
|
|
ctx.Response.Header.SetContentType("application/json")
|
|
}
|
|
ctx.SetBody(buf.Bytes())
|
|
} else {
|
|
if !contentTypeSet {
|
|
ctx.Response.Header.SetContentType("text/plain; charset=utf-8")
|
|
}
|
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
|
}
|
|
} else {
|
|
// Default to string representation
|
|
if !contentTypeSet {
|
|
ctx.Response.Header.SetContentType("text/plain; charset=utf-8")
|
|
}
|
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
|
}
|
|
}
|
|
}
|