package runner import ( "errors" "fmt" "os" "path/filepath" "runtime" "strings" "sync" "sync/atomic" "time" "Moonshark/router" "Moonshark/runner/lualibs" "Moonshark/runner/sqlite" "Moonshark/sessions" "Moonshark/utils/logger" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" "github.com/goccy/go-json" "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" ) var emptyMap = make(map[string]any) var ( ErrRunnerClosed = errors.New("lua runner is closed") ErrTimeout = errors.New("operation timed out") ErrStateNotReady = errors.New("lua state not ready") ) type State struct { L *luajit.State sandbox *Sandbox index int inUse atomic.Bool } type Runner struct { states []*State statePool chan int poolSize int moduleLoader *ModuleLoader isRunning atomic.Bool mu sync.RWMutex scriptDir string // Pre-allocated pools for HTTP processing ctxPool sync.Pool paramsPool sync.Pool } func NewRunner(poolSize int, dataDir, fsDir string, libDirs []string) (*Runner, error) { if poolSize <= 0 { poolSize = runtime.GOMAXPROCS(0) } // Configure module loader with lib directories moduleConfig := &ModuleConfig{ LibDirs: libDirs, } r := &Runner{ poolSize: poolSize, moduleLoader: NewModuleLoader(moduleConfig), ctxPool: sync.Pool{ New: func() any { return make(map[string]any, 8) }, }, paramsPool: sync.Pool{ New: func() any { return make(map[string]any, 4) }, }, } sqlite.InitSQLite(dataDir) lualibs.InitFS(fsDir) sqlite.SetSQLitePoolSize(poolSize) r.states = make([]*State, poolSize) r.statePool = make(chan int, poolSize) if err := r.initStates(); err != nil { sqlite.CleanupSQLite() return nil, err } r.isRunning.Store(true) return r, nil } // Single entry point for HTTP execution func (r *Runner) ExecuteHTTP(bytecode []byte, httpCtx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) (*Response, error) { if !r.isRunning.Load() { return nil, ErrRunnerClosed } // Get state with timeout var stateIndex int select { case stateIndex = <-r.statePool: case <-time.After(time.Second): return nil, ErrTimeout } state := r.states[stateIndex] state.inUse.Store(true) defer func() { state.inUse.Store(false) if r.isRunning.Load() { select { case r.statePool <- stateIndex: default: } } }() // Build Lua context directly from HTTP request luaCtx := r.buildHTTPContext(httpCtx, params, session) defer r.releaseHTTPContext(luaCtx) return state.sandbox.Execute(state.L, bytecode, luaCtx) } // Build Lua context from HTTP request func (r *Runner) buildHTTPContext(ctx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) *Context { luaCtx := NewContext() // Basic request info luaCtx.Set("method", string(ctx.Method())) luaCtx.Set("path", string(ctx.Path())) luaCtx.Set("host", string(ctx.Host())) // Headers headers := r.ctxPool.Get().(map[string]any) ctx.Request.Header.VisitAll(func(key, value []byte) { headers[string(key)] = string(value) }) luaCtx.Set("headers", headers) // Cookies cookies := r.ctxPool.Get().(map[string]any) ctx.Request.Header.VisitAllCookie(func(key, value []byte) { cookies[string(key)] = string(value) }) luaCtx.Set("cookies", cookies) // Route parameters if params != nil && len(params.Keys) > 0 { paramMap := r.paramsPool.Get().(map[string]any) for i, key := range params.Keys { if i < len(params.Values) { paramMap[key] = params.Values[i] } } luaCtx.Set("params", paramMap) } else { luaCtx.Set("params", emptyMap) } // Form data for POST/PUT/PATCH method := ctx.Method() if string(method) == "POST" || string(method) == "PUT" || string(method) == "PATCH" { if formData := parseForm(ctx); formData != nil { luaCtx.Set("form", formData) } else { luaCtx.Set("form", emptyMap) } } else { luaCtx.Set("form", emptyMap) } // Session data sessionMap := r.ctxPool.Get().(map[string]any) session.AdvanceFlash() sessionMap["id"] = session.ID if !session.IsEmpty() { sessionMap["data"] = session.GetAll() sessionMap["flash"] = session.GetAllFlash() } else { sessionMap["data"] = emptyMap sessionMap["flash"] = emptyMap } luaCtx.Set("session", sessionMap) // Environment variables if envMgr := lualibs.GetGlobalEnvManager(); envMgr != nil { luaCtx.Set("env", envMgr.GetAll()) } return luaCtx } // Releases the HTTP context's maps back to their pool func (r *Runner) releaseHTTPContext(luaCtx *Context) { if headers, ok := luaCtx.Get("headers").(map[string]any); ok { for k := range headers { delete(headers, k) } r.ctxPool.Put(headers) } if cookies, ok := luaCtx.Get("cookies").(map[string]any); ok { for k := range cookies { delete(cookies, k) } r.ctxPool.Put(cookies) } if params, ok := luaCtx.Get("params").(map[string]any); ok && len(params) > 0 { for k := range params { delete(params, k) } r.paramsPool.Put(params) } if sessionMap, ok := luaCtx.Get("session").(map[string]any); ok { for k := range sessionMap { delete(sessionMap, k) } r.ctxPool.Put(sessionMap) } luaCtx.Release() } func (r *Runner) initStates() error { logger.Infof("[LuaRunner] Creating %d states...", r.poolSize) for i := range r.poolSize { state, err := r.createState(i) if err != nil { return err } r.states[i] = state r.statePool <- i } return nil } func (r *Runner) createState(index int) (*State, error) { L := luajit.New(true) if L == nil { return nil, errors.New("failed to create Lua state") } sb := NewSandbox() if err := sb.Setup(L, index, index == 0); err != nil { L.Cleanup() L.Close() return nil, err } if err := r.moduleLoader.SetupRequire(L); err != nil { L.Cleanup() L.Close() return nil, err } if err := r.moduleLoader.PreloadModules(L); err != nil { L.Cleanup() L.Close() return nil, err } return &State{L: L, sandbox: sb, index: index}, nil } func (r *Runner) Close() error { r.mu.Lock() defer r.mu.Unlock() if !r.isRunning.Load() { return ErrRunnerClosed } r.isRunning.Store(false) // Drain pool for { select { case <-r.statePool: default: goto cleanup } } cleanup: // Wait for states to finish timeout := time.Now().Add(10 * time.Second) for time.Now().Before(timeout) { allIdle := true for _, state := range r.states { if state != nil && state.inUse.Load() { allIdle = false break } } if allIdle { break } time.Sleep(10 * time.Millisecond) } // Close states for i, state := range r.states { if state != nil { state.L.Cleanup() state.L.Close() r.states[i] = nil } } lualibs.CleanupFS() sqlite.CleanupSQLite() return nil } // parseForm extracts form data from HTTP request func parseForm(ctx *fasthttp.RequestCtx) map[string]any { form := make(map[string]any) // Parse POST form data ctx.PostArgs().VisitAll(func(key, value []byte) { form[string(key)] = string(value) }) // Parse multipart form if present if multipartForm, err := ctx.MultipartForm(); err == nil { for key, values := range multipartForm.Value { if len(values) == 1 { form[key] = values[0] } else { form[key] = values } } } if len(form) == 0 { return nil } return form } // NotifyFileChanged alerts the runner about file changes func (r *Runner) NotifyFileChanged(filePath string) bool { logger.Debugf("Runner notified of file change: %s", filePath) module, isModule := r.moduleLoader.GetModuleByPath(filePath) if isModule { logger.Debugf("Refreshing module: %s", module) return r.RefreshModule(module) } logger.Debugf("File change noted but no refresh needed: %s", filePath) return true } // RefreshModule refreshes a specific module across all states func (r *Runner) RefreshModule(moduleName string) bool { r.mu.RLock() defer r.mu.RUnlock() if !r.isRunning.Load() { return false } logger.Debugf("Refreshing module: %s", moduleName) success := true for _, state := range r.states { if state == nil || state.inUse.Load() { continue } if err := r.moduleLoader.RefreshModule(state.L, moduleName); err != nil { success = false logger.Debugf("Failed to refresh module %s in state %d: %v", moduleName, state.index, err) } } if success { logger.Debugf("Successfully refreshed module: %s", moduleName) } return success } // RunScriptFile loads, compiles and executes a Lua script file func (r *Runner) RunScriptFile(filePath string) (*Response, error) { if !r.isRunning.Load() { return nil, ErrRunnerClosed } if _, err := os.Stat(filePath); os.IsNotExist(err) { return nil, fmt.Errorf("script file not found: %s", filePath) } content, err := os.ReadFile(filePath) if err != nil { return nil, fmt.Errorf("failed to read file: %w", err) } absPath, err := filepath.Abs(filePath) if err != nil { return nil, fmt.Errorf("failed to get absolute path: %w", err) } scriptDir := filepath.Dir(absPath) r.mu.Lock() prevScriptDir := r.scriptDir r.scriptDir = scriptDir r.moduleLoader.SetScriptDir(scriptDir) r.mu.Unlock() defer func() { r.mu.Lock() r.scriptDir = prevScriptDir r.moduleLoader.SetScriptDir(prevScriptDir) r.mu.Unlock() }() // Get state from pool var stateIndex int select { case stateIndex = <-r.statePool: case <-time.After(5 * time.Second): return nil, ErrTimeout } state := r.states[stateIndex] if state == nil { r.statePool <- stateIndex return nil, ErrStateNotReady } state.inUse.Store(true) defer func() { state.inUse.Store(false) if r.isRunning.Load() { select { case r.statePool <- stateIndex: default: } } }() // Compile script bytecode, err := state.L.CompileBytecode(string(content), filepath.Base(absPath)) if err != nil { return nil, fmt.Errorf("compilation error: %w", err) } // Create simple context for script execution ctx := NewContext() defer ctx.Release() ctx.Set("_script_path", absPath) ctx.Set("_script_dir", scriptDir) // Execute script response, err := state.sandbox.Execute(state.L, bytecode, ctx) if err != nil { return nil, fmt.Errorf("execution error: %w", err) } return response, nil } // ApplyResponse applies a Response to a fasthttp.RequestCtx func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) { // Set status code ctx.SetStatusCode(resp.Status) // Set headers for name, value := range resp.Headers { ctx.Response.Header.Set(name, value) } // Set cookies for _, cookie := range resp.Cookies { ctx.Response.Header.SetCookie(cookie) } // Process the body based on its type if resp.Body == nil { return } // Check if Content-Type was manually set contentTypeSet := false for name := range resp.Headers { if strings.ToLower(name) == "content-type" { contentTypeSet = true break } } // Get a buffer from the pool buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) // Set body based on type switch body := resp.Body.(type) { case string: if !contentTypeSet { ctx.Response.Header.SetContentType("text/plain; charset=utf-8") } ctx.SetBodyString(body) case []byte: if !contentTypeSet { ctx.Response.Header.SetContentType("text/plain; charset=utf-8") } ctx.SetBody(body) case map[string]any, map[any]any, []any, []float64, []string, []int, []map[string]any: // Marshal JSON if err := json.NewEncoder(buf).Encode(body); err == nil { if !contentTypeSet { ctx.Response.Header.SetContentType("application/json") } ctx.SetBody(buf.Bytes()) } else { // Fallback to string representation if !contentTypeSet { ctx.Response.Header.SetContentType("text/plain; charset=utf-8") } ctx.SetBodyString(fmt.Sprintf("%v", body)) } default: // Check if it's any other map or slice type typeStr := fmt.Sprintf("%T", body) if typeStr[0] == '[' || (len(typeStr) > 3 && typeStr[:3] == "map") { if err := json.NewEncoder(buf).Encode(body); err == nil { if !contentTypeSet { ctx.Response.Header.SetContentType("application/json") } ctx.SetBody(buf.Bytes()) } else { if !contentTypeSet { ctx.Response.Header.SetContentType("text/plain; charset=utf-8") } ctx.SetBodyString(fmt.Sprintf("%v", body)) } } else { // Default to string representation if !contentTypeSet { ctx.Response.Header.SetContentType("text/plain; charset=utf-8") } ctx.SetBodyString(fmt.Sprintf("%v", body)) } } }