runner cleanup 2
This commit is contained in:
parent
ec7dcce788
commit
8e511c5dc9
|
@ -3,8 +3,6 @@ package runner
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"maps"
|
|
||||||
|
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
@ -14,9 +12,6 @@ type Context struct {
|
||||||
// Values stores any context values (route params, HTTP request info, etc.)
|
// Values stores any context values (route params, HTTP request info, etc.)
|
||||||
Values map[string]any
|
Values map[string]any
|
||||||
|
|
||||||
// internal mutex for concurrent access
|
|
||||||
mu sync.RWMutex
|
|
||||||
|
|
||||||
// FastHTTP context if this was created from an HTTP request
|
// FastHTTP context if this was created from an HTTP request
|
||||||
RequestCtx *fasthttp.RequestCtx
|
RequestCtx *fasthttp.RequestCtx
|
||||||
|
|
||||||
|
@ -28,7 +23,7 @@ type Context struct {
|
||||||
var contextPool = sync.Pool{
|
var contextPool = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
return &Context{
|
return &Context{
|
||||||
Values: make(map[string]any, 16), // Pre-allocate with reasonable capacity
|
Values: make(map[string]any, 16),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -47,9 +42,6 @@ func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context {
|
||||||
|
|
||||||
// Release returns the context to the pool after clearing its values
|
// Release returns the context to the pool after clearing its values
|
||||||
func (c *Context) Release() {
|
func (c *Context) Release() {
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
// Clear all values to prevent data leakage
|
// Clear all values to prevent data leakage
|
||||||
for k := range c.Values {
|
for k := range c.Values {
|
||||||
delete(c.Values, k)
|
delete(c.Values, k)
|
||||||
|
@ -69,9 +61,6 @@ func (c *Context) Release() {
|
||||||
|
|
||||||
// GetBuffer returns a byte buffer for efficient string operations
|
// GetBuffer returns a byte buffer for efficient string operations
|
||||||
func (c *Context) GetBuffer() *bytebufferpool.ByteBuffer {
|
func (c *Context) GetBuffer() *bytebufferpool.ByteBuffer {
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
if c.buffer == nil {
|
if c.buffer == nil {
|
||||||
c.buffer = bytebufferpool.Get()
|
c.buffer = bytebufferpool.Get()
|
||||||
}
|
}
|
||||||
|
@ -80,49 +69,21 @@ func (c *Context) GetBuffer() *bytebufferpool.ByteBuffer {
|
||||||
|
|
||||||
// Set adds a value to the context
|
// Set adds a value to the context
|
||||||
func (c *Context) Set(key string, value any) {
|
func (c *Context) Set(key string, value any) {
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
c.Values[key] = value
|
c.Values[key] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a value from the context
|
// Get retrieves a value from the context
|
||||||
func (c *Context) Get(key string) any {
|
func (c *Context) Get(key string) any {
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
return c.Values[key]
|
return c.Values[key]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contains checks if a key exists in the context
|
// Contains checks if a key exists in the context
|
||||||
func (c *Context) Contains(key string) bool {
|
func (c *Context) Contains(key string) bool {
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
_, exists := c.Values[key]
|
_, exists := c.Values[key]
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a value from the context
|
// Delete removes a value from the context
|
||||||
func (c *Context) Delete(key string) {
|
func (c *Context) Delete(key string) {
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
delete(c.Values, key)
|
delete(c.Values, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// All returns a copy of all values in the context
|
|
||||||
func (c *Context) All() map[string]any {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
result := make(map[string]any, len(c.Values))
|
|
||||||
maps.Copy(result, c.Values)
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsHTTPRequest returns true if this context contains a fasthttp RequestCtx
|
|
||||||
func (c *Context) IsHTTPRequest() bool {
|
|
||||||
return c.RequestCtx != nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,77 +0,0 @@
|
||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// extractCookie grabs cookies from the Lua state
|
|
||||||
func extractCookie(state *luajit.State) *fasthttp.Cookie {
|
|
||||||
cookie := fasthttp.AcquireCookie()
|
|
||||||
|
|
||||||
// Get name
|
|
||||||
state.GetField(-1, "name")
|
|
||||||
if !state.IsString(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
fasthttp.ReleaseCookie(cookie)
|
|
||||||
return nil // Name is required
|
|
||||||
}
|
|
||||||
cookie.SetKey(state.ToString(-1))
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get value
|
|
||||||
state.GetField(-1, "value")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetValue(state.ToString(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get path
|
|
||||||
state.GetField(-1, "path")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetPath(state.ToString(-1))
|
|
||||||
} else {
|
|
||||||
cookie.SetPath("/") // Default path
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get domain
|
|
||||||
state.GetField(-1, "domain")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetDomain(state.ToString(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get expires
|
|
||||||
state.GetField(-1, "expires")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
expiry := int64(state.ToNumber(-1))
|
|
||||||
cookie.SetExpire(time.Unix(expiry, 0))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get max age
|
|
||||||
state.GetField(-1, "max_age")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
cookie.SetMaxAge(int(state.ToNumber(-1)))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get secure
|
|
||||||
state.GetField(-1, "secure")
|
|
||||||
if state.IsBoolean(-1) {
|
|
||||||
cookie.SetSecure(state.ToBoolean(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get http only
|
|
||||||
state.GetField(-1, "http_only")
|
|
||||||
if state.IsBoolean(-1) {
|
|
||||||
cookie.SetHTTPOnly(state.ToBoolean(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
return cookie
|
|
||||||
}
|
|
|
@ -10,12 +10,9 @@ import (
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StateInitFunc is a function that initializes a module in a Lua state
|
|
||||||
type StateInitFunc func(*luajit.State) error
|
|
||||||
|
|
||||||
// CoreModuleRegistry manages the initialization and reloading of core modules
|
// CoreModuleRegistry manages the initialization and reloading of core modules
|
||||||
type CoreModuleRegistry struct {
|
type CoreModuleRegistry struct {
|
||||||
modules map[string]StateInitFunc // Module initializers
|
modules map[string]sandbox.StateInitFunc // Module initializers
|
||||||
initOrder []string // Explicit initialization order
|
initOrder []string // Explicit initialization order
|
||||||
dependencies map[string][]string // Module dependencies
|
dependencies map[string][]string // Module dependencies
|
||||||
initializedFlag map[string]bool // Track which modules are initialized
|
initializedFlag map[string]bool // Track which modules are initialized
|
||||||
|
@ -26,7 +23,7 @@ type CoreModuleRegistry struct {
|
||||||
// NewCoreModuleRegistry creates a new core module registry
|
// NewCoreModuleRegistry creates a new core module registry
|
||||||
func NewCoreModuleRegistry() *CoreModuleRegistry {
|
func NewCoreModuleRegistry() *CoreModuleRegistry {
|
||||||
return &CoreModuleRegistry{
|
return &CoreModuleRegistry{
|
||||||
modules: make(map[string]StateInitFunc),
|
modules: make(map[string]sandbox.StateInitFunc),
|
||||||
initOrder: []string{},
|
initOrder: []string{},
|
||||||
dependencies: make(map[string][]string),
|
dependencies: make(map[string][]string),
|
||||||
initializedFlag: make(map[string]bool),
|
initializedFlag: make(map[string]bool),
|
||||||
|
@ -46,15 +43,8 @@ func (r *CoreModuleRegistry) debugLog(format string, args ...interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// debugLogCont prints continuation debug messages if enabled
|
|
||||||
func (r *CoreModuleRegistry) debugLogCont(format string, args ...interface{}) {
|
|
||||||
if r.debug {
|
|
||||||
logger.DebugCont(format, args...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register adds a module to the registry
|
// Register adds a module to the registry
|
||||||
func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) {
|
func (r *CoreModuleRegistry) Register(name string, initFunc sandbox.StateInitFunc) {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
@ -63,7 +53,7 @@ func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) {
|
||||||
// Add to initialization order if not already there
|
// Add to initialization order if not already there
|
||||||
for _, n := range r.initOrder {
|
for _, n := range r.initOrder {
|
||||||
if n == name {
|
if n == name {
|
||||||
return // Already registered, silently continue
|
return // Already registered
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,7 +62,7 @@ func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterWithDependencies registers a module with explicit dependencies
|
// RegisterWithDependencies registers a module with explicit dependencies
|
||||||
func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc StateInitFunc, dependencies []string) {
|
func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
@ -82,15 +72,12 @@ func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc Stat
|
||||||
// Add to initialization order if not already there
|
// Add to initialization order if not already there
|
||||||
for _, n := range r.initOrder {
|
for _, n := range r.initOrder {
|
||||||
if n == name {
|
if n == name {
|
||||||
return // Already registered, silently continue
|
return // Already registered
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.initOrder = append(r.initOrder, name)
|
r.initOrder = append(r.initOrder, name)
|
||||||
r.debugLog("registered module %s", name)
|
r.debugLog("registered module %s with dependencies: %v", name, dependencies)
|
||||||
if len(dependencies) > 0 {
|
|
||||||
r.debugLogCont("Dependencies: %v", dependencies)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInitOrder sets explicit initialization order
|
// SetInitOrder sets explicit initialization order
|
||||||
|
@ -103,34 +90,14 @@ func (r *CoreModuleRegistry) SetInitOrder(order []string) {
|
||||||
|
|
||||||
// First add all known modules that are in the specified order
|
// First add all known modules that are in the specified order
|
||||||
for _, name := range order {
|
for _, name := range order {
|
||||||
if _, exists := r.modules[name]; exists {
|
if _, exists := r.modules[name]; exists && !contains(newOrder, name) {
|
||||||
// Check for duplicates
|
|
||||||
isDuplicate := false
|
|
||||||
for _, existing := range newOrder {
|
|
||||||
if existing == name {
|
|
||||||
isDuplicate = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isDuplicate {
|
|
||||||
newOrder = append(newOrder, name)
|
newOrder = append(newOrder, name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Then add any modules not in the specified order
|
// Then add any modules not in the specified order
|
||||||
for name := range r.modules {
|
for name := range r.modules {
|
||||||
// Check if module already in the new order
|
if !contains(newOrder, name) {
|
||||||
found := false
|
|
||||||
for _, n := range newOrder {
|
|
||||||
if n == name {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
newOrder = append(newOrder, name)
|
newOrder = append(newOrder, name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -160,7 +127,7 @@ func (r *CoreModuleRegistry) Initialize(state *luajit.State, stateIndex int) err
|
||||||
}
|
}
|
||||||
|
|
||||||
if verbose {
|
if verbose {
|
||||||
r.debugLogCont("All modules initialized successfully")
|
r.debugLog("All modules initialized successfully")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -201,14 +168,14 @@ func (r *CoreModuleRegistry) initializeModule(state *luajit.State, name string,
|
||||||
err := initFunc(state)
|
err := initFunc(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Always log failures regardless of verbose setting
|
// Always log failures regardless of verbose setting
|
||||||
r.debugLogCont("Initializing module %s... failure: %v", name, err)
|
r.debugLog("Initializing module %s... failure: %v", name, err)
|
||||||
return fmt.Errorf("failed to initialize module %s: %w", name, err)
|
return fmt.Errorf("failed to initialize module %s: %w", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.initializedFlag[name] = true
|
r.initializedFlag[name] = true
|
||||||
|
|
||||||
if verbose {
|
if verbose {
|
||||||
r.debugLogCont("Initializing module %s... success", name)
|
r.debugLog("Initializing module %s... success", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -226,18 +193,6 @@ func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string)
|
||||||
return r.initializeModule(state, name, []string{}, true)
|
return r.initializeModule(state, name, []string{}, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModuleNames returns a list of all registered module names
|
|
||||||
func (r *CoreModuleRegistry) ModuleNames() []string {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
names := make([]string, 0, len(r.modules))
|
|
||||||
for name := range r.modules {
|
|
||||||
names = append(names, name)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
||||||
// MatchModuleName checks if a file path corresponds to a registered module
|
// MatchModuleName checks if a file path corresponds to a registered module
|
||||||
func (r *CoreModuleRegistry) MatchModuleName(modName string) (string, bool) {
|
func (r *CoreModuleRegistry) MatchModuleName(modName string) (string, bool) {
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
|
@ -266,7 +221,7 @@ func init() {
|
||||||
GlobalRegistry.EnableDebug() // Enable debugging by default
|
GlobalRegistry.EnableDebug() // Enable debugging by default
|
||||||
logger.Debug("[ModuleRegistry] Registering core modules...")
|
logger.Debug("[ModuleRegistry] Registering core modules...")
|
||||||
|
|
||||||
// Register core modules - these now point to the sandbox implementations
|
// Register core modules
|
||||||
GlobalRegistry.Register("util", func(state *luajit.State) error {
|
GlobalRegistry.Register("util", func(state *luajit.State) error {
|
||||||
return sandbox.UtilModuleInitFunc()(state)
|
return sandbox.UtilModuleInitFunc()(state)
|
||||||
})
|
})
|
||||||
|
@ -283,15 +238,25 @@ func init() {
|
||||||
"csrf", // Fourth: CSRF protection
|
"csrf", // Fourth: CSRF protection
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.DebugCont("Core modules registered successfully")
|
logger.Debug("Core modules registered successfully")
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterCoreModule is a helper to register a core module with the global registry
|
// RegisterCoreModule registers a core module with the global registry
|
||||||
func RegisterCoreModule(name string, initFunc StateInitFunc) {
|
func RegisterCoreModule(name string, initFunc sandbox.StateInitFunc) {
|
||||||
GlobalRegistry.Register(name, initFunc)
|
GlobalRegistry.Register(name, initFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterCoreModuleWithDependencies registers a module with dependencies
|
// RegisterCoreModuleWithDependencies registers a module with dependencies
|
||||||
func RegisterCoreModuleWithDependencies(name string, initFunc StateInitFunc, dependencies []string) {
|
func RegisterCoreModuleWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) {
|
||||||
GlobalRegistry.RegisterWithDependencies(name, initFunc, dependencies)
|
GlobalRegistry.RegisterWithDependencies(name, initFunc, dependencies)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
func contains(slice []string, item string) bool {
|
||||||
|
for _, s := range slice {
|
||||||
|
if s == item {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -1,86 +0,0 @@
|
||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ModuleFunc is a function that returns a map of module functions
|
|
||||||
type ModuleFunc func() map[string]luajit.GoFunction
|
|
||||||
|
|
||||||
// RegisterModule registers a map of functions as a Lua module
|
|
||||||
func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error {
|
|
||||||
// Create a new table for the module
|
|
||||||
state.NewTable()
|
|
||||||
|
|
||||||
// Add each function to the module table
|
|
||||||
for fname, f := range funcs {
|
|
||||||
// Push function name
|
|
||||||
state.PushString(fname)
|
|
||||||
|
|
||||||
// Push function
|
|
||||||
if err := state.PushGoFunction(f); err != nil {
|
|
||||||
state.Pop(1) // Pop table
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set table[fname] = f
|
|
||||||
state.SetTable(-3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register the module globally
|
|
||||||
state.SetGlobal(name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModuleInitFunc creates a state initializer that registers multiple modules
|
|
||||||
func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
for name, moduleFunc := range modules {
|
|
||||||
if err := RegisterModule(state, name, moduleFunc()); err != nil {
|
|
||||||
logger.Error("Failed to register module %s: %v", name, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CombineInitFuncs combines multiple state initializer functions into one
|
|
||||||
func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
for _, f := range funcs {
|
|
||||||
if f != nil {
|
|
||||||
if err := f(state); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterLuaCode registers a Lua code snippet as a module
|
|
||||||
func RegisterLuaCode(state *luajit.State, code string) error {
|
|
||||||
return state.DoString(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code
|
|
||||||
func RegisterLuaCodeInitFunc(code string) StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
return RegisterLuaCode(state, code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module
|
|
||||||
func RegisterLuaModuleInitFunc(name string, code string) StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
// Create name = {} global
|
|
||||||
state.NewTable()
|
|
||||||
state.SetGlobal(name)
|
|
||||||
|
|
||||||
// Then run the module code which will populate it
|
|
||||||
return state.DoString(code)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -3,15 +3,13 @@ package runner
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/panjf2000/ants/v2"
|
|
||||||
"github.com/valyala/bytebufferpool"
|
|
||||||
|
|
||||||
"Moonshark/core/runner/sandbox"
|
"Moonshark/core/runner/sandbox"
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
|
@ -35,7 +33,6 @@ type State struct {
|
||||||
sandbox *sandbox.Sandbox // Associated sandbox
|
sandbox *sandbox.Sandbox // Associated sandbox
|
||||||
index int // Index for debugging
|
index int // Index for debugging
|
||||||
inUse bool // Whether the state is currently in use
|
inUse bool // Whether the state is currently in use
|
||||||
initTime time.Time // When this state was initialized
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitHook runs before executing a script
|
// InitHook runs before executing a script
|
||||||
|
@ -44,20 +41,6 @@ type InitHook func(*luajit.State, *Context) error
|
||||||
// FinalizeHook runs after executing a script
|
// FinalizeHook runs after executing a script
|
||||||
type FinalizeHook func(*luajit.State, *Context, any) error
|
type FinalizeHook func(*luajit.State, *Context, any) error
|
||||||
|
|
||||||
// ExecuteTask represents a task in the execution goroutine pool
|
|
||||||
type ExecuteTask struct {
|
|
||||||
bytecode []byte
|
|
||||||
context *Context
|
|
||||||
scriptPath string
|
|
||||||
result chan<- taskResult
|
|
||||||
}
|
|
||||||
|
|
||||||
// taskResult holds the result of an execution task
|
|
||||||
type taskResult struct {
|
|
||||||
value any
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Runner runs Lua scripts using a pool of Lua states
|
// Runner runs Lua scripts using a pool of Lua states
|
||||||
type Runner struct {
|
type Runner struct {
|
||||||
states []*State // All states managed by this runner
|
states []*State // All states managed by this runner
|
||||||
|
@ -70,7 +53,6 @@ type Runner struct {
|
||||||
initHooks []InitHook // Hooks run before script execution
|
initHooks []InitHook // Hooks run before script execution
|
||||||
finalizeHooks []FinalizeHook // Hooks run after script execution
|
finalizeHooks []FinalizeHook // Hooks run after script execution
|
||||||
scriptDir string // Current script directory
|
scriptDir string // Current script directory
|
||||||
pool *ants.Pool // Goroutine pool for task execution
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithPoolSize sets the state pool size
|
// WithPoolSize sets the state pool size
|
||||||
|
@ -144,13 +126,6 @@ func NewRunner(options ...RunnerOption) (*Runner, error) {
|
||||||
runner.states = make([]*State, runner.poolSize)
|
runner.states = make([]*State, runner.poolSize)
|
||||||
runner.statePool = make(chan int, runner.poolSize)
|
runner.statePool = make(chan int, runner.poolSize)
|
||||||
|
|
||||||
// Create ants goroutine pool
|
|
||||||
var err error
|
|
||||||
runner.pool, err = ants.NewPool(runner.poolSize * 2)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create and initialize all states
|
// Create and initialize all states
|
||||||
if err := runner.initializeStates(); err != nil {
|
if err := runner.initializeStates(); err != nil {
|
||||||
runner.Close() // Clean up already created states
|
runner.Close() // Clean up already created states
|
||||||
|
@ -168,28 +143,12 @@ func (r *Runner) debugLog(format string, args ...interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Runner) debugLogCont(format string, args ...interface{}) {
|
|
||||||
if r.debug {
|
|
||||||
logger.DebugCont(format, args...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// initializeStates creates and initializes all states in the pool
|
// initializeStates creates and initializes all states in the pool
|
||||||
func (r *Runner) initializeStates() error {
|
func (r *Runner) initializeStates() error {
|
||||||
r.debugLog("is initializing %d states", r.poolSize)
|
r.debugLog("is initializing %d states", r.poolSize)
|
||||||
|
|
||||||
// Create main template state first with full logging
|
// Create all states
|
||||||
templateState, err := r.createState(0)
|
for i := 0; i < r.poolSize; i++ {
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
r.states[0] = templateState
|
|
||||||
r.statePool <- 0 // Add index to the pool
|
|
||||||
|
|
||||||
// Create remaining states with minimal logging
|
|
||||||
successCount := 1
|
|
||||||
for i := 1; i < r.poolSize; i++ {
|
|
||||||
state, err := r.createState(i)
|
state, err := r.createState(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -197,10 +156,8 @@ func (r *Runner) initializeStates() error {
|
||||||
|
|
||||||
r.states[i] = state
|
r.states[i] = state
|
||||||
r.statePool <- i // Add index to the pool
|
r.statePool <- i // Add index to the pool
|
||||||
successCount++
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r.debugLog("has built %d/%d states successfully", successCount, r.poolSize)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,16 +175,13 @@ func (r *Runner) createState(index int) (*State, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create sandbox
|
// Create sandbox
|
||||||
sandbox := sandbox.NewSandbox()
|
sb := sandbox.NewSandbox()
|
||||||
if r.debug && verbose {
|
if r.debug && verbose {
|
||||||
sandbox.EnableDebug()
|
sb.EnableDebug()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up require system
|
// Set up require system
|
||||||
if err := r.moduleLoader.SetupRequire(L); err != nil {
|
if err := r.moduleLoader.SetupRequire(L); err != nil {
|
||||||
if verbose {
|
|
||||||
r.debugLogCont("Failed to set up require for state %d: %v", index, err)
|
|
||||||
}
|
|
||||||
L.Cleanup()
|
L.Cleanup()
|
||||||
L.Close()
|
L.Close()
|
||||||
return nil, ErrInitFailed
|
return nil, ErrInitFailed
|
||||||
|
@ -235,19 +189,13 @@ func (r *Runner) createState(index int) (*State, error) {
|
||||||
|
|
||||||
// Initialize all core modules from the registry
|
// Initialize all core modules from the registry
|
||||||
if err := GlobalRegistry.Initialize(L, index); err != nil {
|
if err := GlobalRegistry.Initialize(L, index); err != nil {
|
||||||
if verbose {
|
|
||||||
r.debugLogCont("Failed to initialize core modules for state %d: %v", index, err)
|
|
||||||
}
|
|
||||||
L.Cleanup()
|
L.Cleanup()
|
||||||
L.Close()
|
L.Close()
|
||||||
return nil, ErrInitFailed
|
return nil, ErrInitFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up sandbox after core modules are initialized
|
// Set up sandbox after core modules are initialized
|
||||||
if err := sandbox.Setup(L, index); err != nil {
|
if err := sb.Setup(L, index); err != nil {
|
||||||
if verbose {
|
|
||||||
r.debugLogCont("Failed to set up sandbox for state %d: %v", index, err)
|
|
||||||
}
|
|
||||||
L.Cleanup()
|
L.Cleanup()
|
||||||
L.Close()
|
L.Close()
|
||||||
return nil, ErrInitFailed
|
return nil, ErrInitFailed
|
||||||
|
@ -255,63 +203,49 @@ func (r *Runner) createState(index int) (*State, error) {
|
||||||
|
|
||||||
// Preload all modules
|
// Preload all modules
|
||||||
if err := r.moduleLoader.PreloadModules(L); err != nil {
|
if err := r.moduleLoader.PreloadModules(L); err != nil {
|
||||||
if verbose {
|
|
||||||
r.debugLogCont("Failed to preload modules for state %d: %v", index, err)
|
|
||||||
}
|
|
||||||
L.Cleanup()
|
L.Cleanup()
|
||||||
L.Close()
|
L.Close()
|
||||||
return nil, errors.New("failed to preload modules")
|
return nil, errors.New("failed to preload modules")
|
||||||
}
|
}
|
||||||
|
|
||||||
state := &State{
|
return &State{
|
||||||
L: L,
|
L: L,
|
||||||
sandbox: sandbox,
|
sandbox: sb,
|
||||||
index: index,
|
index: index,
|
||||||
inUse: false,
|
inUse: false,
|
||||||
initTime: time.Now(),
|
}, nil
|
||||||
}
|
|
||||||
|
|
||||||
if verbose {
|
|
||||||
r.debugLog("State %d created successfully", index)
|
|
||||||
}
|
|
||||||
return state, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeTask is the worker function for the ants pool
|
// Execute runs a script with context
|
||||||
func (r *Runner) executeTask(i interface{}) {
|
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
||||||
task, ok := i.(*ExecuteTask)
|
if !r.isRunning.Load() {
|
||||||
if !ok {
|
return nil, ErrRunnerClosed
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set script directory if provided
|
// Set script directory if provided
|
||||||
if task.scriptPath != "" {
|
if scriptPath != "" {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
r.scriptDir = filepath.Dir(task.scriptPath)
|
r.scriptDir = filepath.Dir(scriptPath)
|
||||||
r.moduleLoader.SetScriptDir(r.scriptDir)
|
r.moduleLoader.SetScriptDir(r.scriptDir)
|
||||||
r.mu.Unlock()
|
r.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a state index from the pool
|
// Get a state from the pool
|
||||||
var stateIndex int
|
var stateIndex int
|
||||||
select {
|
select {
|
||||||
case stateIndex = <-r.statePool:
|
case stateIndex = <-r.statePool:
|
||||||
// Got a state
|
// Got a state
|
||||||
case <-time.After(5 * time.Second): // 5-second timeout
|
case <-ctx.Done():
|
||||||
// Timed out waiting for a state
|
return nil, ctx.Err()
|
||||||
task.result <- taskResult{nil, errors.New("server busy - timed out waiting for a Lua state")}
|
case <-time.After(5 * time.Second):
|
||||||
return
|
return nil, ErrTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the actual state
|
// Get the actual state
|
||||||
r.mu.RLock()
|
|
||||||
state := r.states[stateIndex]
|
state := r.states[stateIndex]
|
||||||
r.mu.RUnlock()
|
|
||||||
|
|
||||||
if state == nil {
|
if state == nil {
|
||||||
r.statePool <- stateIndex
|
r.statePool <- stateIndex
|
||||||
task.result <- taskResult{nil, ErrStateNotReady}
|
return nil, ErrStateNotReady
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark state as in use
|
// Mark state as in use
|
||||||
|
@ -325,75 +259,69 @@ func (r *Runner) executeTask(i interface{}) {
|
||||||
case r.statePool <- stateIndex:
|
case r.statePool <- stateIndex:
|
||||||
// State returned to pool
|
// State returned to pool
|
||||||
default:
|
default:
|
||||||
// Pool is full or closed (shouldn't happen)
|
// Pool is full or closed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Copy hooks to avoid holding lock during execution
|
|
||||||
r.mu.RLock()
|
|
||||||
initHooks := make([]InitHook, len(r.initHooks))
|
|
||||||
copy(initHooks, r.initHooks)
|
|
||||||
finalizeHooks := make([]FinalizeHook, len(r.finalizeHooks))
|
|
||||||
copy(finalizeHooks, r.finalizeHooks)
|
|
||||||
r.mu.RUnlock()
|
|
||||||
|
|
||||||
// Run init hooks
|
// Run init hooks
|
||||||
for _, hook := range initHooks {
|
for _, hook := range r.initHooks {
|
||||||
if err := hook(state.L, task.context); err != nil {
|
if err := hook(state.L, execCtx); err != nil {
|
||||||
task.result <- taskResult{nil, err}
|
return nil, err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare context values
|
// Get context values
|
||||||
var ctxValues map[string]any
|
var ctxValues map[string]any
|
||||||
if task.context != nil {
|
if execCtx != nil {
|
||||||
ctxValues = task.context.Values
|
ctxValues = execCtx.Values
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute in sandbox
|
// Execute in sandbox
|
||||||
result, err := state.sandbox.Execute(state.L, task.bytecode, ctxValues)
|
result, err := state.sandbox.Execute(state.L, bytecode, ctxValues)
|
||||||
|
if err != nil {
|
||||||
// Run finalize hooks
|
|
||||||
for _, hook := range finalizeHooks {
|
|
||||||
if hookErr := hook(state.L, task.context, result); hookErr != nil && err == nil {
|
|
||||||
err = hookErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
task.result <- taskResult{result, err}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute runs a script with context
|
|
||||||
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
|
||||||
if !r.isRunning.Load() {
|
|
||||||
return nil, ErrRunnerClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create result channel
|
|
||||||
resultChan := make(chan taskResult, 1)
|
|
||||||
|
|
||||||
// Create task
|
|
||||||
task := &ExecuteTask{
|
|
||||||
bytecode: bytecode,
|
|
||||||
context: execCtx,
|
|
||||||
scriptPath: scriptPath,
|
|
||||||
result: resultChan,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Submit task to pool
|
|
||||||
if err := r.pool.Submit(func() { r.executeTask(task) }); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for result with context timeout
|
// Run finalize hooks
|
||||||
select {
|
for _, hook := range r.finalizeHooks {
|
||||||
case result := <-resultChan:
|
if hookErr := hook(state.L, execCtx, result); hookErr != nil {
|
||||||
return result.value, result.err
|
return nil, hookErr
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for HTTP response
|
||||||
|
httpResp, hasResponse := sandbox.GetHTTPResponse(state.L)
|
||||||
|
if hasResponse {
|
||||||
|
// Set result as body if not already set
|
||||||
|
if httpResp.Body == nil {
|
||||||
|
httpResp.Body = result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply directly to request context if available
|
||||||
|
if execCtx != nil && execCtx.RequestCtx != nil {
|
||||||
|
sandbox.ApplyHTTPResponse(httpResp, execCtx.RequestCtx)
|
||||||
|
sandbox.ReleaseResponse(httpResp)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return httpResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle direct result if FastHTTP context is available
|
||||||
|
if execCtx != nil && execCtx.RequestCtx != nil && result != nil {
|
||||||
|
switch r := result.(type) {
|
||||||
|
case string:
|
||||||
|
execCtx.RequestCtx.SetBodyString(r)
|
||||||
|
case []byte:
|
||||||
|
execCtx.RequestCtx.SetBody(r)
|
||||||
|
default:
|
||||||
|
execCtx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run executes a Lua script (convenience wrapper)
|
// Run executes a Lua script (convenience wrapper)
|
||||||
|
@ -411,14 +339,19 @@ func (r *Runner) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
r.isRunning.Store(false)
|
r.isRunning.Store(false)
|
||||||
r.debugLog("Closing Runner and destroying all states")
|
|
||||||
|
|
||||||
// Shut down goroutine pool
|
|
||||||
r.pool.Release()
|
|
||||||
|
|
||||||
// Drain the state pool
|
// Drain the state pool
|
||||||
r.drainStatePool()
|
for {
|
||||||
|
select {
|
||||||
|
case <-r.statePool:
|
||||||
|
// Drain one state
|
||||||
|
default:
|
||||||
|
// Pool is empty
|
||||||
|
goto cleanup
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup:
|
||||||
// Clean up all states
|
// Clean up all states
|
||||||
for i, state := range r.states {
|
for i, state := range r.states {
|
||||||
if state != nil {
|
if state != nil {
|
||||||
|
@ -431,19 +364,6 @@ func (r *Runner) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// drainStatePool removes all states from the pool
|
|
||||||
func (r *Runner) drainStatePool() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-r.statePool:
|
|
||||||
// Drain one state
|
|
||||||
default:
|
|
||||||
// Pool is empty
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshStates rebuilds all states in the pool
|
// RefreshStates rebuilds all states in the pool
|
||||||
func (r *Runner) RefreshStates() error {
|
func (r *Runner) RefreshStates() error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
|
@ -453,11 +373,18 @@ func (r *Runner) RefreshStates() error {
|
||||||
return ErrRunnerClosed
|
return ErrRunnerClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
r.debugLog("Refreshing all Lua states")
|
|
||||||
|
|
||||||
// Drain all states from the pool
|
// Drain all states from the pool
|
||||||
r.drainStatePool()
|
for {
|
||||||
|
select {
|
||||||
|
case <-r.statePool:
|
||||||
|
// Drain one state
|
||||||
|
default:
|
||||||
|
// Pool is empty
|
||||||
|
goto cleanup
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup:
|
||||||
// Destroy all existing states
|
// Destroy all existing states
|
||||||
for i, state := range r.states {
|
for i, state := range r.states {
|
||||||
if state != nil {
|
if state != nil {
|
||||||
|
@ -479,7 +406,82 @@ func (r *Runner) RefreshStates() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotifyFileChanged handles file change notifications
|
// AddInitHook adds a hook to be called before script execution
|
||||||
|
func (r *Runner) AddInitHook(hook InitHook) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.initHooks = append(r.initHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFinalizeHook adds a hook to be called after script execution
|
||||||
|
func (r *Runner) AddFinalizeHook(hook FinalizeHook) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.finalizeHooks = append(r.finalizeHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStateCount returns the number of initialized states
|
||||||
|
func (r *Runner) GetStateCount() int {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for _, state := range r.states {
|
||||||
|
if state != nil {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveStateCount returns the number of states currently in use
|
||||||
|
func (r *Runner) GetActiveStateCount() int {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for _, state := range r.states {
|
||||||
|
if state != nil && state.inUse {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModuleCount returns the number of loaded modules in the first available state
|
||||||
|
func (r *Runner) GetModuleCount() int {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
if !r.isRunning.Load() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find first available state
|
||||||
|
for _, state := range r.states {
|
||||||
|
if state != nil && !state.inUse {
|
||||||
|
// Execute a Lua snippet to count modules
|
||||||
|
if res, err := state.L.ExecuteWithResult(`
|
||||||
|
local count = 0
|
||||||
|
for _ in pairs(package.loaded) do
|
||||||
|
count = count + 1
|
||||||
|
end
|
||||||
|
return count
|
||||||
|
`); err == nil {
|
||||||
|
if num, ok := res.(float64); ok {
|
||||||
|
return int(num)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotifyFileChanged alerts the runner about file changes
|
||||||
func (r *Runner) NotifyFileChanged(filePath string) bool {
|
func (r *Runner) NotifyFileChanged(filePath string) bool {
|
||||||
r.debugLog("File change detected: %s", filePath)
|
r.debugLog("File change detected: %s", filePath)
|
||||||
|
|
||||||
|
@ -515,21 +517,12 @@ func (r *Runner) RefreshModule(moduleName string) bool {
|
||||||
|
|
||||||
success := true
|
success := true
|
||||||
for _, state := range r.states {
|
for _, state := range r.states {
|
||||||
if state == nil {
|
if state == nil || state.inUse {
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip states that are in use
|
|
||||||
if state.inUse {
|
|
||||||
r.debugLog("Skipping refresh for state %d (in use)", state.index)
|
|
||||||
success = false
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invalidate module in Lua
|
// Invalidate module in Lua
|
||||||
if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil {
|
if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil {
|
||||||
r.debugLog("Failed to invalidate module %s in state %d: %v",
|
|
||||||
moduleName, state.index, err)
|
|
||||||
success = false
|
success = false
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -537,139 +530,10 @@ func (r *Runner) RefreshModule(moduleName string) bool {
|
||||||
// For core modules, reinitialize them
|
// For core modules, reinitialize them
|
||||||
if isCore {
|
if isCore {
|
||||||
if err := GlobalRegistry.InitializeModule(state.L, coreName); err != nil {
|
if err := GlobalRegistry.InitializeModule(state.L, coreName); err != nil {
|
||||||
r.debugLog("Failed to reinitialize core module %s in state %d: %v",
|
|
||||||
coreName, state.index, err)
|
|
||||||
success = false
|
success = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if success {
|
|
||||||
r.debugLog("Module %s refreshed successfully in all states", moduleName)
|
|
||||||
} else {
|
|
||||||
r.debugLog("Module %s refresh had some failures", moduleName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return success
|
return success
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddModule adds a module to all sandbox environments
|
|
||||||
func (r *Runner) AddModule(name string, module any) {
|
|
||||||
r.debugLog("Adding module %s to all sandboxes", name)
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
for _, state := range r.states {
|
|
||||||
if state != nil && state.sandbox != nil && !state.inUse {
|
|
||||||
state.sandbox.AddModule(name, module)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddInitHook adds a hook to be called before script execution
|
|
||||||
func (r *Runner) AddInitHook(hook InitHook) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.initHooks = append(r.initHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFinalizeHook adds a hook to be called after script execution
|
|
||||||
func (r *Runner) AddFinalizeHook(hook FinalizeHook) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.finalizeHooks = append(r.finalizeHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetModuleCache clears the module cache in all states
|
|
||||||
func (r *Runner) ResetModuleCache() {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
if !r.isRunning.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r.debugLog("Resetting module cache in all states")
|
|
||||||
|
|
||||||
for _, state := range r.states {
|
|
||||||
if state != nil && !state.inUse {
|
|
||||||
r.moduleLoader.ResetModules(state.L)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetStateCount returns the number of initialized states
|
|
||||||
func (r *Runner) GetStateCount() int {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
for _, state := range r.states {
|
|
||||||
if state != nil {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetActiveStateCount returns the number of states currently in use
|
|
||||||
func (r *Runner) GetActiveStateCount() int {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
for _, state := range r.states {
|
|
||||||
if state != nil && state.inUse {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetWorkerPoolStats returns statistics about the worker pool
|
|
||||||
func (r *Runner) GetWorkerPoolStats() (running, capacity int) {
|
|
||||||
return r.pool.Running(), r.pool.Cap()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetModuleCount returns the number of loaded modules in the first available state
|
|
||||||
func (r *Runner) GetModuleCount() int {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
if !r.isRunning.Load() {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find first available state
|
|
||||||
for _, state := range r.states {
|
|
||||||
if state != nil && !state.inUse {
|
|
||||||
// Execute a Lua snippet to count modules
|
|
||||||
if res, err := state.L.ExecuteWithResult(`
|
|
||||||
local count = 0
|
|
||||||
for _ in pairs(package.loaded) do
|
|
||||||
count = count + 1
|
|
||||||
end
|
|
||||||
return count
|
|
||||||
`); err == nil {
|
|
||||||
if num, ok := res.(float64); ok {
|
|
||||||
return int(num)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetBufferPool returns a buffer from the bytebufferpool
|
|
||||||
func GetBufferPool() *bytebufferpool.ByteBuffer {
|
|
||||||
return bytebufferpool.Get()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReleaseBufferPool returns a buffer to the bytebufferpool
|
|
||||||
func ReleaseBufferPool(buf *bytebufferpool.ByteBuffer) {
|
|
||||||
bytebufferpool.Put(buf)
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,13 +1,12 @@
|
||||||
package runner
|
package runner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
"Moonshark/core/runner/sandbox"
|
"Moonshark/core/runner/sandbox"
|
||||||
"Moonshark/core/sessions"
|
"Moonshark/core/sessions"
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionHandler handles session management for Lua scripts
|
// SessionHandler handles session management for Lua scripts
|
||||||
|
@ -29,30 +28,17 @@ func (h *SessionHandler) EnableDebug() {
|
||||||
h.debugLog = true
|
h.debugLog = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// debug logs a message if debug is enabled
|
|
||||||
func (h *SessionHandler) debug(format string, args ...interface{}) {
|
|
||||||
if h.debugLog {
|
|
||||||
logger.Debug("[SessionHandler] "+format, args...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithSessionManager creates a RunnerOption to add session support
|
// WithSessionManager creates a RunnerOption to add session support
|
||||||
func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
|
func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
|
||||||
return func(r *Runner) {
|
return func(r *Runner) {
|
||||||
handler := NewSessionHandler(manager)
|
handler := NewSessionHandler(manager)
|
||||||
|
|
||||||
// Add hooks to the runner
|
|
||||||
r.AddInitHook(handler.preRequestHook)
|
r.AddInitHook(handler.preRequestHook)
|
||||||
r.AddFinalizeHook(handler.postRequestHook)
|
r.AddFinalizeHook(handler.postRequestHook)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// preRequestHook is called before executing a request
|
// preRequestHook initializes session before script execution
|
||||||
func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error {
|
func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error {
|
||||||
h.debug("Running pre-request session hook")
|
|
||||||
|
|
||||||
// Check if we have cookie information in context
|
|
||||||
// Instead of raw request, we now look for the cookie map
|
|
||||||
if ctx == nil || ctx.Values["_request_cookies"] == nil {
|
if ctx == nil || ctx.Values["_request_cookies"] == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -71,43 +57,31 @@ func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error
|
||||||
if cookieValue, exists := cookies[cookieName]; exists {
|
if cookieValue, exists := cookies[cookieName]; exists {
|
||||||
if strValue, ok := cookieValue.(string); ok && strValue != "" {
|
if strValue, ok := cookieValue.(string); ok && strValue != "" {
|
||||||
sessionID = strValue
|
sessionID = strValue
|
||||||
h.debug("Found existing session ID: %s", sessionID)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no session ID found, create new session
|
// Create new session if needed
|
||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
// Create a new session
|
|
||||||
session := h.manager.CreateSession()
|
session := h.manager.CreateSession()
|
||||||
sessionID = session.ID
|
sessionID = session.ID
|
||||||
h.debug("Created new session with ID: %s", sessionID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the session ID in the context for later use
|
// Store the session ID in the context
|
||||||
ctx.Set("_session_id", sessionID)
|
ctx.Set("_session_id", sessionID)
|
||||||
|
|
||||||
// Get the session data
|
// Get session data
|
||||||
session := h.manager.GetSession(sessionID)
|
session := h.manager.GetSession(sessionID)
|
||||||
sessionData := session.GetAll()
|
sessionData := session.GetAll()
|
||||||
|
|
||||||
// Set session data in Lua state
|
// Set session data in Lua state
|
||||||
if err := SetSessionData(state, sessionID, sessionData); err != nil {
|
return SetSessionData(state, sessionID, sessionData)
|
||||||
h.debug("Failed to set session data: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.debug("Session data initialized successfully")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// postRequestHook is called after executing a request
|
// postRequestHook handles session after script execution
|
||||||
func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, result any) error {
|
func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, result any) error {
|
||||||
h.debug("Running post-request session hook")
|
|
||||||
|
|
||||||
// Check if session was modified
|
// Check if session was modified
|
||||||
modifiedID, modifiedData, modified := GetSessionData(state)
|
modifiedID, modifiedData, modified := GetSessionData(state)
|
||||||
if !modified {
|
if !modified {
|
||||||
h.debug("Session not modified, skipping")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,12 +99,9 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu
|
||||||
}
|
}
|
||||||
|
|
||||||
if modifiedID == "" {
|
if modifiedID == "" {
|
||||||
h.debug("No session ID found, cannot persist session data")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
h.debug("Persisting modified session data for ID: %s", modifiedID)
|
|
||||||
|
|
||||||
// Update session in manager
|
// Update session in manager
|
||||||
session := h.manager.GetSession(modifiedID)
|
session := h.manager.GetSession(modifiedID)
|
||||||
session.Clear() // clear to sync deleted values
|
session.Clear() // clear to sync deleted values
|
||||||
|
@ -145,7 +116,6 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu
|
||||||
h.addSessionCookie(httpResp, modifiedID)
|
h.addSessionCookie(httpResp, modifiedID)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.debug("Session data persisted successfully")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,13 +128,10 @@ func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID
|
||||||
cookieName := opts["name"].(string)
|
cookieName := opts["name"].(string)
|
||||||
for _, cookie := range resp.Cookies {
|
for _, cookie := range resp.Cookies {
|
||||||
if string(cookie.Key()) == cookieName {
|
if string(cookie.Key()) == cookieName {
|
||||||
h.debug("Session cookie already set in response")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.debug("Adding session cookie to response")
|
|
||||||
|
|
||||||
// Create and add cookie
|
// Create and add cookie
|
||||||
cookie := fasthttp.AcquireCookie()
|
cookie := fasthttp.AcquireCookie()
|
||||||
cookie.SetKey(cookieName)
|
cookie.SetKey(cookieName)
|
||||||
|
|
|
@ -31,8 +31,8 @@ var responsePool = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
return &HTTPResponse{
|
return &HTTPResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Headers: make(map[string]string, 8), // Pre-allocate with reasonable capacity
|
Headers: make(map[string]string, 8),
|
||||||
Cookies: make([]*fasthttp.Cookie, 0, 4), // Pre-allocate with reasonable capacity
|
Cookies: make([]*fasthttp.Cookie, 0, 4),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -48,14 +48,10 @@ var defaultFastClient fasthttp.Client = fasthttp.Client{
|
||||||
|
|
||||||
// HTTPClientConfig contains client settings
|
// HTTPClientConfig contains client settings
|
||||||
type HTTPClientConfig struct {
|
type HTTPClientConfig struct {
|
||||||
// Maximum timeout for requests (0 = no limit)
|
MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit)
|
||||||
MaxTimeout time.Duration
|
DefaultTimeout time.Duration // Default request timeout
|
||||||
// Default request timeout
|
MaxResponseSize int64 // Maximum response size in bytes (0 = no limit)
|
||||||
DefaultTimeout time.Duration
|
AllowRemote bool // Whether to allow remote connections
|
||||||
// Maximum response size in bytes (0 = no limit)
|
|
||||||
MaxResponseSize int64
|
|
||||||
// Whether to allow remote connections
|
|
||||||
AllowRemote bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultHTTPClientConfig provides sensible defaults
|
// DefaultHTTPClientConfig provides sensible defaults
|
||||||
|
@ -66,12 +62,12 @@ var DefaultHTTPClientConfig = HTTPClientConfig{
|
||||||
AllowRemote: true,
|
AllowRemote: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPResponse creates a default HTTP response, potentially reusing one from the pool
|
// NewHTTPResponse creates a default HTTP response from pool
|
||||||
func NewHTTPResponse() *HTTPResponse {
|
func NewHTTPResponse() *HTTPResponse {
|
||||||
return responsePool.Get().(*HTTPResponse)
|
return responsePool.Get().(*HTTPResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReleaseResponse returns the response to the pool after clearing its values
|
// ReleaseResponse returns the response to the pool
|
||||||
func ReleaseResponse(resp *HTTPResponse) {
|
func ReleaseResponse(resp *HTTPResponse) {
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return
|
return
|
||||||
|
@ -99,8 +95,7 @@ func HTTPModuleInitFunc() func(*luajit.State) error {
|
||||||
return func(state *luajit.State) error {
|
return func(state *luajit.State) error {
|
||||||
// Register the native Go function first
|
// Register the native Go function first
|
||||||
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
||||||
logger.Error("[HTTP Module] Failed to register __http_request function")
|
logger.Error("[HTTP Module] Failed to register __http_request function: %v", err)
|
||||||
logger.ErrorCont("%v", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,7 +106,7 @@ func HTTPModuleInitFunc() func(*luajit.State) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to set up HTTP client config
|
// setupHTTPClientConfig configures HTTP client in Lua
|
||||||
func setupHTTPClientConfig(state *luajit.State) {
|
func setupHTTPClientConfig(state *luajit.State) {
|
||||||
state.NewTable()
|
state.NewTable()
|
||||||
|
|
||||||
|
@ -138,7 +133,7 @@ func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
|
||||||
state.GetGlobal("__http_responses")
|
state.GetGlobal("__http_responses")
|
||||||
if state.IsNil(-1) {
|
if state.IsNil(-1) {
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
ReleaseResponse(response) // Return unused response to pool
|
ReleaseResponse(response)
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,7 +142,7 @@ func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
|
||||||
state.GetTable(-2)
|
state.GetTable(-2)
|
||||||
if state.IsNil(-1) {
|
if state.IsNil(-1) {
|
||||||
state.Pop(2)
|
state.Pop(2)
|
||||||
ReleaseResponse(response) // Return unused response to pool
|
ReleaseResponse(response)
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -340,11 +335,10 @@ func httpRequest(state *luajit.State) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get client configuration from registry (if available)
|
// Get client configuration
|
||||||
var config HTTPClientConfig = DefaultHTTPClientConfig
|
var config HTTPClientConfig = DefaultHTTPClientConfig
|
||||||
state.GetGlobal("__http_client_config")
|
state.GetGlobal("__http_client_config")
|
||||||
if !state.IsNil(-1) {
|
if !state.IsNil(-1) && state.IsTable(-1) {
|
||||||
if state.IsTable(-1) {
|
|
||||||
// Extract max timeout
|
// Extract max timeout
|
||||||
state.GetField(-1, "max_timeout")
|
state.GetField(-1, "max_timeout")
|
||||||
if state.IsNumber(-1) {
|
if state.IsNumber(-1) {
|
||||||
|
@ -373,7 +367,6 @@ func httpRequest(state *luajit.State) int {
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
// Check if remote connections are allowed
|
// Check if remote connections are allowed
|
||||||
|
|
|
@ -6,9 +6,12 @@ import (
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ModuleFunc is a function that returns a map of module functions
|
// ModuleFunc returns a map of module functions
|
||||||
type ModuleFunc func() map[string]luajit.GoFunction
|
type ModuleFunc func() map[string]luajit.GoFunction
|
||||||
|
|
||||||
|
// StateInitFunc initializes a module in a Lua state
|
||||||
|
type StateInitFunc func(*luajit.State) error
|
||||||
|
|
||||||
// RegisterModule registers a map of functions as a Lua module
|
// RegisterModule registers a map of functions as a Lua module
|
||||||
func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error {
|
func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error {
|
||||||
// Create a new table for the module
|
// Create a new table for the module
|
||||||
|
@ -16,16 +19,11 @@ func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.Go
|
||||||
|
|
||||||
// Add each function to the module table
|
// Add each function to the module table
|
||||||
for fname, f := range funcs {
|
for fname, f := range funcs {
|
||||||
// Push function name
|
|
||||||
state.PushString(fname)
|
state.PushString(fname)
|
||||||
|
|
||||||
// Push function
|
|
||||||
if err := state.PushGoFunction(f); err != nil {
|
if err := state.PushGoFunction(f); err != nil {
|
||||||
state.Pop(1) // Pop table
|
state.Pop(1) // Pop table
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set table[fname] = f
|
|
||||||
state.SetTable(-3)
|
state.SetTable(-3)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,8 +32,21 @@ func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.Go
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModuleInitFunc creates a state initializer that registers multiple modules
|
||||||
|
func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc {
|
||||||
|
return func(state *luajit.State) error {
|
||||||
|
for name, moduleFunc := range modules {
|
||||||
|
if err := RegisterModule(state, name, moduleFunc()); err != nil {
|
||||||
|
logger.Error("Failed to register module %s: %v", name, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CombineInitFuncs combines multiple state initializer functions into one
|
// CombineInitFuncs combines multiple state initializer functions into one
|
||||||
func CombineInitFuncs(funcs ...func(*luajit.State) error) func(*luajit.State) error {
|
func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc {
|
||||||
return func(state *luajit.State) error {
|
return func(state *luajit.State) error {
|
||||||
for _, f := range funcs {
|
for _, f := range funcs {
|
||||||
if f != nil {
|
if f != nil {
|
||||||
|
@ -48,33 +59,20 @@ func CombineInitFuncs(funcs ...func(*luajit.State) error) func(*luajit.State) er
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModuleInitFunc creates a state initializer that registers multiple modules
|
// RegisterLuaCode registers a Lua code snippet in a state
|
||||||
func ModuleInitFunc(modules map[string]ModuleFunc) func(*luajit.State) error {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
for name, moduleFunc := range modules {
|
|
||||||
if err := RegisterModule(state, name, moduleFunc()); err != nil {
|
|
||||||
logger.Error("Failed to register module %s: %v", name, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterLuaCode registers a Lua code snippet as a module
|
|
||||||
func RegisterLuaCode(state *luajit.State, code string) error {
|
func RegisterLuaCode(state *luajit.State, code string) error {
|
||||||
return state.DoString(code)
|
return state.DoString(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code
|
// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code
|
||||||
func RegisterLuaCodeInitFunc(code string) func(*luajit.State) error {
|
func RegisterLuaCodeInitFunc(code string) StateInitFunc {
|
||||||
return func(state *luajit.State) error {
|
return func(state *luajit.State) error {
|
||||||
return RegisterLuaCode(state, code)
|
return RegisterLuaCode(state, code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module
|
// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module
|
||||||
func RegisterLuaModuleInitFunc(name string, code string) func(*luajit.State) error {
|
func RegisterLuaModuleInitFunc(name string, code string) StateInitFunc {
|
||||||
return func(state *luajit.State) error {
|
return func(state *luajit.State) error {
|
||||||
// Create name = {} global
|
// Create name = {} global
|
||||||
state.NewTable()
|
state.NewTable()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user