http client fix and sandbox optimization
This commit is contained in:
parent
08a532f11a
commit
6154b5303c
|
@ -1,22 +1,38 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
)
|
||||
|
||||
// CoreModuleRegistry manages the initialization and reloading of core modules
|
||||
type CoreModuleRegistry struct {
|
||||
modules map[string]StateInitFunc
|
||||
mu sync.RWMutex
|
||||
debug bool
|
||||
}
|
||||
|
||||
// NewCoreModuleRegistry creates a new core module registry
|
||||
func NewCoreModuleRegistry() *CoreModuleRegistry {
|
||||
return &CoreModuleRegistry{
|
||||
modules: make(map[string]StateInitFunc),
|
||||
debug: false,
|
||||
}
|
||||
}
|
||||
|
||||
// EnableDebug turns on debug logging
|
||||
func (r *CoreModuleRegistry) EnableDebug() {
|
||||
r.debug = true
|
||||
}
|
||||
|
||||
// debugLog prints debug messages if enabled
|
||||
func (r *CoreModuleRegistry) debugLog(format string, args ...interface{}) {
|
||||
if r.debug {
|
||||
logger.Debug("[CoreModuleRegistry] "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -25,6 +41,7 @@ func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) {
|
|||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.modules[name] = initFunc
|
||||
r.debugLog("Registered module: %s", name)
|
||||
}
|
||||
|
||||
// Initialize initializes all registered modules
|
||||
|
@ -32,16 +49,30 @@ func (r *CoreModuleRegistry) Initialize(state *luajit.State) error {
|
|||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Convert to StateInitFunc
|
||||
initFunc := CombineInitFuncs(r.getInitFuncs()...)
|
||||
return initFunc(state)
|
||||
r.debugLog("Initializing all modules...")
|
||||
|
||||
// Get all module init functions
|
||||
initFuncs := r.getInitFuncs()
|
||||
|
||||
// Initialize modules one by one to better track issues
|
||||
for name, initFunc := range initFuncs {
|
||||
r.debugLog("Initializing module: %s", name)
|
||||
if err := initFunc(state); err != nil {
|
||||
r.debugLog("Failed to initialize module %s: %v", name, err)
|
||||
return fmt.Errorf("failed to initialize module %s: %w", name, err)
|
||||
}
|
||||
r.debugLog("Module %s initialized successfully", name)
|
||||
}
|
||||
|
||||
r.debugLog("All modules initialized successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInitFuncs returns all module init functions
|
||||
func (r *CoreModuleRegistry) getInitFuncs() []StateInitFunc {
|
||||
funcs := make([]StateInitFunc, 0, len(r.modules))
|
||||
for _, initFunc := range r.modules {
|
||||
funcs = append(funcs, initFunc)
|
||||
func (r *CoreModuleRegistry) getInitFuncs() map[string]StateInitFunc {
|
||||
funcs := make(map[string]StateInitFunc, len(r.modules))
|
||||
for name, initFunc := range r.modules {
|
||||
funcs[name] = initFunc
|
||||
}
|
||||
return funcs
|
||||
}
|
||||
|
@ -53,9 +84,11 @@ func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string)
|
|||
|
||||
initFunc, ok := r.modules[name]
|
||||
if !ok {
|
||||
r.debugLog("Module not found: %s", name)
|
||||
return nil // Module not found, no error
|
||||
}
|
||||
|
||||
r.debugLog("Reinitializing module: %s", name)
|
||||
return initFunc(state)
|
||||
}
|
||||
|
||||
|
@ -96,8 +129,10 @@ var GlobalRegistry = NewCoreModuleRegistry()
|
|||
|
||||
// Initialize global registry with core modules
|
||||
func init() {
|
||||
GlobalRegistry.EnableDebug() // Enable debugging by default
|
||||
GlobalRegistry.Register("http", HTTPModuleInitFunc())
|
||||
GlobalRegistry.Register("cookie", CookieModuleInitFunc())
|
||||
logger.Debug("[CoreModuleRegistry] Core modules registered in init()")
|
||||
}
|
||||
|
||||
// RegisterCoreModule is a helper to register a core module
|
||||
|
@ -105,6 +140,3 @@ func init() {
|
|||
func RegisterCoreModule(name string, initFunc StateInitFunc) {
|
||||
GlobalRegistry.Register(name, initFunc)
|
||||
}
|
||||
|
||||
// To add a new module, simply call:
|
||||
// RegisterCoreModule("new_module_name", NewModuleInitFunc())
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
)
|
||||
|
||||
// HTTPResponse represents an HTTP response from Lua
|
||||
|
@ -364,190 +365,34 @@ func httpRequest(state *luajit.State) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
// LuaHTTPModule is the pure Lua implementation of the HTTP module
|
||||
const LuaHTTPModule = `
|
||||
-- Table to store response data
|
||||
__http_responses = {}
|
||||
|
||||
-- HTTP module implementation
|
||||
local http = {
|
||||
-- Set HTTP status code
|
||||
set_status = function(code)
|
||||
if type(code) ~= "number" then
|
||||
error("http.set_status: status code must be a number", 2)
|
||||
end
|
||||
|
||||
local resp = __http_responses[1] or {}
|
||||
resp.status = code
|
||||
__http_responses[1] = resp
|
||||
end,
|
||||
|
||||
-- Set HTTP header
|
||||
set_header = function(name, value)
|
||||
if type(name) ~= "string" or type(value) ~= "string" then
|
||||
error("http.set_header: name and value must be strings", 2)
|
||||
end
|
||||
|
||||
local resp = __http_responses[1] or {}
|
||||
resp.headers = resp.headers or {}
|
||||
resp.headers[name] = value
|
||||
__http_responses[1] = resp
|
||||
end,
|
||||
|
||||
-- Set content type; set_header helper
|
||||
set_content_type = function(content_type)
|
||||
http.set_header("Content-Type", content_type)
|
||||
end,
|
||||
|
||||
-- HTTP client submodule
|
||||
client = {
|
||||
-- Generic request function
|
||||
request = function(method, url, body, options)
|
||||
if type(method) ~= "string" then
|
||||
error("http.client.request: method must be a string", 2)
|
||||
end
|
||||
if type(url) ~= "string" then
|
||||
error("http.client.request: url must be a string", 2)
|
||||
end
|
||||
|
||||
-- Call native implementation
|
||||
return __http_request(method, url, body, options)
|
||||
end,
|
||||
|
||||
-- Simple GET request
|
||||
get = function(url, options)
|
||||
return http.client.request("GET", url, nil, options)
|
||||
end,
|
||||
|
||||
-- Simple POST request with automatic content-type
|
||||
post = function(url, body, options)
|
||||
options = options or {}
|
||||
return http.client.request("POST", url, body, options)
|
||||
end,
|
||||
|
||||
-- Simple PUT request with automatic content-type
|
||||
put = function(url, body, options)
|
||||
options = options or {}
|
||||
return http.client.request("PUT", url, body, options)
|
||||
end,
|
||||
|
||||
-- Simple DELETE request
|
||||
delete = function(url, options)
|
||||
return http.client.request("DELETE", url, nil, options)
|
||||
end,
|
||||
|
||||
-- Simple PATCH request
|
||||
patch = function(url, body, options)
|
||||
options = options or {}
|
||||
return http.client.request("PATCH", url, body, options)
|
||||
end,
|
||||
|
||||
-- Simple HEAD request
|
||||
head = function(url, options)
|
||||
options = options or {}
|
||||
local old_options = options
|
||||
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
|
||||
local response = http.client.request("HEAD", url, nil, options)
|
||||
return response
|
||||
end,
|
||||
|
||||
-- Simple OPTIONS request
|
||||
options = function(url, options)
|
||||
return http.client.request("OPTIONS", url, nil, options)
|
||||
end,
|
||||
|
||||
-- Shorthand function to directly get JSON
|
||||
get_json = function(url, options)
|
||||
options = options or {}
|
||||
local response = http.client.get(url, options)
|
||||
if response.ok and response.json then
|
||||
return response.json
|
||||
end
|
||||
return nil, response
|
||||
end,
|
||||
|
||||
-- Utility to build a URL with query parameters
|
||||
build_url = function(base_url, params)
|
||||
if not params or type(params) ~= "table" then
|
||||
return base_url
|
||||
end
|
||||
|
||||
local query = {}
|
||||
for k, v in pairs(params) do
|
||||
if type(v) == "table" then
|
||||
for _, item in ipairs(v) do
|
||||
table.insert(query, k .. "=" .. tostring(item))
|
||||
end
|
||||
else
|
||||
table.insert(query, k .. "=" .. tostring(v))
|
||||
end
|
||||
end
|
||||
|
||||
if #query > 0 then
|
||||
if base_url:find("?") then
|
||||
return base_url .. "&" .. table.concat(query, "&")
|
||||
else
|
||||
return base_url .. "?" .. table.concat(query, "&")
|
||||
end
|
||||
end
|
||||
|
||||
return base_url
|
||||
end
|
||||
}
|
||||
}
|
||||
|
||||
-- Install HTTP module
|
||||
_G.http = http
|
||||
|
||||
-- Override sandbox executor to clear HTTP responses
|
||||
local old_execute_sandbox = __execute_sandbox
|
||||
__execute_sandbox = function(bytecode, ctx)
|
||||
-- Clear previous response for this thread
|
||||
__http_responses[1] = nil
|
||||
|
||||
-- Execute the original function
|
||||
local result = old_execute_sandbox(bytecode, ctx)
|
||||
|
||||
-- Return the result unchanged
|
||||
return result
|
||||
end
|
||||
|
||||
-- Make sure the HTTP module is accessible in sandbox
|
||||
if __env_system and __env_system.base_env then
|
||||
__env_system.base_env.http = http
|
||||
end
|
||||
`
|
||||
|
||||
// HTTPModuleInitFunc returns an initializer function for the HTTP module
|
||||
func HTTPModuleInitFunc() StateInitFunc {
|
||||
return func(state *luajit.State) error {
|
||||
// The important fix: register the Go function directly to the global environment
|
||||
// CRITICAL: Register the native Go function first
|
||||
// This must be done BEFORE any Lua code that references it
|
||||
if err := state.RegisterGoFunction(httpRequestFuncName, httpRequest); err != nil {
|
||||
logger.Error("[HTTP Module] Failed to register __http_request function: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize pure Lua HTTP module
|
||||
// Set up default HTTP client configuration
|
||||
setupHTTPClientConfig(state)
|
||||
|
||||
// Initialize Lua HTTP module
|
||||
if err := state.DoString(LuaHTTPModule); err != nil {
|
||||
logger.Error("[HTTP Module] Failed to initialize HTTP module Lua code: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for existing config (in sandbox modules)
|
||||
state.GetGlobal("__sandbox_modules")
|
||||
if !state.IsNil(-1) && state.IsTable(-1) {
|
||||
state.PushString("__http_client_config")
|
||||
state.GetTable(-2)
|
||||
// Verify HTTP client functions are available
|
||||
verifyHTTPClient(state)
|
||||
|
||||
if !state.IsNil(-1) && state.IsTable(-1) {
|
||||
// Use the config from sandbox modules
|
||||
state.SetGlobal("__http_client_config")
|
||||
state.Pop(1) // Pop the sandbox modules table
|
||||
return nil
|
||||
}
|
||||
state.Pop(1) // Pop the nil or non-table value
|
||||
}
|
||||
state.Pop(1) // Pop the nil or sandbox modules table
|
||||
}
|
||||
|
||||
// Setup default configuration if no custom config exists
|
||||
// Helper to set up HTTP client config
|
||||
func setupHTTPClientConfig(state *luajit.State) {
|
||||
state.NewTable()
|
||||
|
||||
state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
|
||||
|
@ -563,16 +408,6 @@ func HTTPModuleInitFunc() StateInitFunc {
|
|||
state.SetField(-2, "allow_remote")
|
||||
|
||||
state.SetGlobal("__http_client_config")
|
||||
|
||||
// Ensure the Go function is registered with the base environment
|
||||
// This is critical to make it persist across reloads
|
||||
return state.DoString(`
|
||||
-- Make the __http_request function available in the base environment
|
||||
if __env_system and __env_system.base_env then
|
||||
__env_system.base_env.__http_request = __http_request
|
||||
end
|
||||
`)
|
||||
}
|
||||
}
|
||||
|
||||
// GetHTTPResponse extracts the HTTP response from Lua state
|
||||
|
@ -669,3 +504,187 @@ func RestrictHTTPToLocalhost() RunnerOption {
|
|||
AllowRemote: false,
|
||||
})
|
||||
}
|
||||
|
||||
// Verify that HTTP client is properly set up
|
||||
func verifyHTTPClient(state *luajit.State) {
|
||||
// Get the client table
|
||||
state.GetGlobal("http")
|
||||
if !state.IsTable(-1) {
|
||||
logger.Warning("[HTTP Module] 'http' is not a table\n")
|
||||
state.Pop(1)
|
||||
return
|
||||
}
|
||||
|
||||
state.GetField(-1, "client")
|
||||
if !state.IsTable(-1) {
|
||||
logger.Warning("[HTTP Module] 'http.client' is not a table\n")
|
||||
state.Pop(2)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for get function
|
||||
state.GetField(-1, "get")
|
||||
if !state.IsFunction(-1) {
|
||||
logger.Warning("[HTTP Module] 'http.client.get' is not a function\n")
|
||||
} else {
|
||||
logger.Debug("[HTTP Module] 'http.client.get' is properly registered\n")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Check for the request function
|
||||
state.GetField(-1, "request")
|
||||
if !state.IsFunction(-1) {
|
||||
logger.Warning("[HTTP Module] 'http.client.request' is not a function\n")
|
||||
} else {
|
||||
logger.Debug("[HTTP Module] 'http.client.request' is properly registered\n")
|
||||
}
|
||||
state.Pop(3) // Pop request, client, http
|
||||
}
|
||||
|
||||
const LuaHTTPModule = `
|
||||
-- Table to store response data
|
||||
__http_responses = {}
|
||||
|
||||
-- HTTP module implementation
|
||||
local http = {
|
||||
-- Set HTTP status code
|
||||
set_status = function(code)
|
||||
if type(code) ~= "number" then
|
||||
error("http.set_status: status code must be a number", 2)
|
||||
end
|
||||
|
||||
local resp = __http_responses[1] or {}
|
||||
resp.status = code
|
||||
__http_responses[1] = resp
|
||||
end,
|
||||
|
||||
-- Set HTTP header
|
||||
set_header = function(name, value)
|
||||
if type(name) ~= "string" or type(value) ~= "string" then
|
||||
error("http.set_header: name and value must be strings", 2)
|
||||
end
|
||||
|
||||
local resp = __http_responses[1] or {}
|
||||
resp.headers = resp.headers or {}
|
||||
resp.headers[name] = value
|
||||
__http_responses[1] = resp
|
||||
end,
|
||||
|
||||
-- Set content type; set_header helper
|
||||
set_content_type = function(content_type)
|
||||
http.set_header("Content-Type", content_type)
|
||||
end,
|
||||
|
||||
-- HTTP client submodule
|
||||
client = {
|
||||
-- Generic request function
|
||||
request = function(method, url, body, options)
|
||||
if type(method) ~= "string" then
|
||||
error("http.client.request: method must be a string", 2)
|
||||
end
|
||||
if type(url) ~= "string" then
|
||||
error("http.client.request: url must be a string", 2)
|
||||
end
|
||||
|
||||
-- Call native implementation (this is the critical part)
|
||||
local result = __http_request(method, url, body, options)
|
||||
return result
|
||||
end,
|
||||
|
||||
-- Simple GET request
|
||||
get = function(url, options)
|
||||
return http.client.request("GET", url, nil, options)
|
||||
end,
|
||||
|
||||
-- Simple POST request with automatic content-type
|
||||
post = function(url, body, options)
|
||||
options = options or {}
|
||||
return http.client.request("POST", url, body, options)
|
||||
end,
|
||||
|
||||
-- Simple PUT request with automatic content-type
|
||||
put = function(url, body, options)
|
||||
options = options or {}
|
||||
return http.client.request("PUT", url, body, options)
|
||||
end,
|
||||
|
||||
-- Simple DELETE request
|
||||
delete = function(url, options)
|
||||
return http.client.request("DELETE", url, nil, options)
|
||||
end,
|
||||
|
||||
-- Simple PATCH request
|
||||
patch = function(url, body, options)
|
||||
options = options or {}
|
||||
return http.client.request("PATCH", url, body, options)
|
||||
end,
|
||||
|
||||
-- Simple HEAD request
|
||||
head = function(url, options)
|
||||
options = options or {}
|
||||
local old_options = options
|
||||
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
|
||||
local response = http.client.request("HEAD", url, nil, options)
|
||||
return response
|
||||
end,
|
||||
|
||||
-- Simple OPTIONS request
|
||||
options = function(url, options)
|
||||
return http.client.request("OPTIONS", url, nil, options)
|
||||
end,
|
||||
|
||||
-- Shorthand function to directly get JSON
|
||||
get_json = function(url, options)
|
||||
options = options or {}
|
||||
local response = http.client.get(url, options)
|
||||
if response.ok and response.json then
|
||||
return response.json
|
||||
end
|
||||
return nil, response
|
||||
end,
|
||||
|
||||
-- Utility to build a URL with query parameters
|
||||
build_url = function(base_url, params)
|
||||
if not params or type(params) ~= "table" then
|
||||
return base_url
|
||||
end
|
||||
|
||||
local query = {}
|
||||
for k, v in pairs(params) do
|
||||
if type(v) == "table" then
|
||||
for _, item in ipairs(v) do
|
||||
table.insert(query, k .. "=" .. tostring(item))
|
||||
end
|
||||
else
|
||||
table.insert(query, k .. "=" .. tostring(v))
|
||||
end
|
||||
end
|
||||
|
||||
if #query > 0 then
|
||||
if base_url:find("?") then
|
||||
return base_url .. "&" .. table.concat(query, "&")
|
||||
else
|
||||
return base_url .. "?" .. table.concat(query, "&")
|
||||
end
|
||||
end
|
||||
|
||||
return base_url
|
||||
end
|
||||
}
|
||||
}
|
||||
|
||||
-- Install HTTP module
|
||||
_G.http = http
|
||||
|
||||
-- Clear previous responses when executing scripts
|
||||
local old_execute_script = __execute_script
|
||||
if old_execute_script then
|
||||
__execute_script = function(fn, ctx)
|
||||
-- Clear previous response
|
||||
__http_responses[1] = nil
|
||||
|
||||
-- Execute original function
|
||||
return old_execute_script(fn, ctx)
|
||||
end
|
||||
end
|
||||
`
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"sync/atomic"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
|
@ -40,6 +41,7 @@ type LuaRunner struct {
|
|||
bufferSize int // Size of the job queue buffer
|
||||
moduleLoader *NativeModuleLoader // Native module loader for require
|
||||
sandbox *Sandbox // The sandbox environment
|
||||
debug bool // Enable debug logging
|
||||
}
|
||||
|
||||
// WithBufferSize sets the job queue buffer size
|
||||
|
@ -71,12 +73,20 @@ func WithLibDirs(dirs ...string) RunnerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithDebugEnabled enables debug output
|
||||
func WithDebugEnabled() RunnerOption {
|
||||
return func(r *LuaRunner) {
|
||||
r.debug = true
|
||||
}
|
||||
}
|
||||
|
||||
// NewRunner creates a new LuaRunner
|
||||
func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
||||
// Default configuration
|
||||
runner := &LuaRunner{
|
||||
bufferSize: 10, // Default buffer size
|
||||
sandbox: NewSandbox(),
|
||||
debug: false,
|
||||
}
|
||||
|
||||
// Apply options
|
||||
|
@ -84,13 +94,6 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
|||
opt(runner)
|
||||
}
|
||||
|
||||
// Initialize Lua state
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
return nil, errors.New("failed to create Lua state")
|
||||
}
|
||||
runner.state = state
|
||||
|
||||
// Create job queue
|
||||
runner.jobQueue = make(chan job, runner.bufferSize)
|
||||
runner.isRunning.Store(true)
|
||||
|
@ -104,36 +107,9 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
|||
runner.moduleLoader = NewNativeModuleLoader(requireConfig)
|
||||
}
|
||||
|
||||
// Set up require paths and mechanism
|
||||
if err := runner.moduleLoader.SetupRequire(state); err != nil {
|
||||
state.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
|
||||
// Initialize all core modules from the registry
|
||||
if err := GlobalRegistry.Initialize(state); err != nil {
|
||||
state.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
|
||||
// Set up sandbox after core modules are initialized
|
||||
if err := runner.sandbox.Setup(state); err != nil {
|
||||
state.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
|
||||
// Preload all modules into package.loaded
|
||||
if err := runner.moduleLoader.PreloadAllModules(state); err != nil {
|
||||
state.Close()
|
||||
return nil, errors.New("failed to preload modules")
|
||||
}
|
||||
|
||||
// Run init function if provided
|
||||
if runner.initFunc != nil {
|
||||
if err := runner.initFunc(state); err != nil {
|
||||
state.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
// Initialize Lua state
|
||||
if err := runner.initState(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start the event loop
|
||||
|
@ -143,10 +119,130 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
|||
return runner, nil
|
||||
}
|
||||
|
||||
// debugLog logs a message if debug mode is enabled
|
||||
func (r *LuaRunner) debugLog(format string, args ...interface{}) {
|
||||
if r.debug {
|
||||
logger.Debug("[LuaRunner] "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// initState initializes or reinitializes the Lua state
|
||||
func (r *LuaRunner) initState(initial bool) error {
|
||||
r.debugLog("Initializing Lua state (initial=%v)", initial)
|
||||
|
||||
// Clean up existing state if there is one
|
||||
if r.state != nil {
|
||||
r.debugLog("Cleaning up existing state")
|
||||
// Always call Cleanup before Close to properly free function pointers
|
||||
r.state.Cleanup()
|
||||
r.state.Close()
|
||||
r.state = nil
|
||||
}
|
||||
|
||||
// Create fresh state
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
return errors.New("failed to create Lua state")
|
||||
}
|
||||
r.debugLog("Created new Lua state")
|
||||
|
||||
// Set up require paths and mechanism
|
||||
if err := r.moduleLoader.SetupRequire(state); err != nil {
|
||||
r.debugLog("Failed to set up require: %v", err)
|
||||
state.Cleanup()
|
||||
state.Close()
|
||||
return ErrInitFailed
|
||||
}
|
||||
r.debugLog("Require system initialized")
|
||||
|
||||
// Initialize all core modules from the registry
|
||||
if err := GlobalRegistry.Initialize(state); err != nil {
|
||||
r.debugLog("Failed to initialize core modules: %v", err)
|
||||
state.Cleanup()
|
||||
state.Close()
|
||||
return ErrInitFailed
|
||||
}
|
||||
r.debugLog("Core modules initialized")
|
||||
|
||||
// Check if http module is properly registered
|
||||
testResult, err := state.ExecuteWithResult(`
|
||||
if type(http) == "table" and type(http.client) == "table" and
|
||||
type(http.client.get) == "function" then
|
||||
return true
|
||||
else
|
||||
return false
|
||||
end
|
||||
`)
|
||||
if err != nil || testResult != true {
|
||||
r.debugLog("HTTP module verification failed: %v, result: %v", err, testResult)
|
||||
} else {
|
||||
r.debugLog("HTTP module verified OK")
|
||||
}
|
||||
|
||||
// Verify __http_request function
|
||||
testResult, _ = state.ExecuteWithResult(`return type(__http_request)`)
|
||||
r.debugLog("__http_request function is of type: %v", testResult)
|
||||
|
||||
// Set up sandbox after core modules are initialized
|
||||
if err := r.sandbox.Setup(state); err != nil {
|
||||
r.debugLog("Failed to set up sandbox: %v", err)
|
||||
state.Cleanup()
|
||||
state.Close()
|
||||
return ErrInitFailed
|
||||
}
|
||||
r.debugLog("Sandbox environment set up")
|
||||
|
||||
// Preload all modules into package.loaded
|
||||
if err := r.moduleLoader.PreloadAllModules(state); err != nil {
|
||||
r.debugLog("Failed to preload modules: %v", err)
|
||||
state.Cleanup()
|
||||
state.Close()
|
||||
return errors.New("failed to preload modules")
|
||||
}
|
||||
r.debugLog("All modules preloaded")
|
||||
|
||||
// Run init function if provided
|
||||
if r.initFunc != nil {
|
||||
if err := r.initFunc(state); err != nil {
|
||||
r.debugLog("Custom init function failed: %v", err)
|
||||
state.Cleanup()
|
||||
state.Close()
|
||||
return ErrInitFailed
|
||||
}
|
||||
r.debugLog("Custom init function completed")
|
||||
}
|
||||
|
||||
// Test for HTTP module again after full initialization
|
||||
testResult, err = state.ExecuteWithResult(`
|
||||
if type(http) == "table" and type(http.client) == "table" and
|
||||
type(http.client.get) == "function" then
|
||||
return true
|
||||
else
|
||||
return false
|
||||
end
|
||||
`)
|
||||
if err != nil || testResult != true {
|
||||
r.debugLog("Final HTTP module verification failed: %v, result: %v", err, testResult)
|
||||
} else {
|
||||
r.debugLog("Final HTTP module verification OK")
|
||||
}
|
||||
|
||||
r.state = state
|
||||
r.debugLog("State initialization complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// processJobs handles the job queue
|
||||
func (r *LuaRunner) processJobs() {
|
||||
defer r.wg.Done()
|
||||
defer r.state.Close()
|
||||
defer func() {
|
||||
if r.state != nil {
|
||||
r.debugLog("Cleaning up Lua state in processJobs")
|
||||
r.state.Cleanup()
|
||||
r.state.Close()
|
||||
r.state = nil
|
||||
}
|
||||
}()
|
||||
|
||||
for job := range r.jobQueue {
|
||||
// Execute the job and send result
|
||||
|
@ -175,6 +271,13 @@ func (r *LuaRunner) executeJob(j job) JobResult {
|
|||
ctx = j.Context.Values
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if r.state == nil {
|
||||
return JobResult{nil, errors.New("lua state is not initialized")}
|
||||
}
|
||||
|
||||
// Execute in sandbox
|
||||
value, err := r.sandbox.Execute(r.state, j.Bytecode, ctx)
|
||||
return JobResult{value, err}
|
||||
|
@ -260,15 +363,26 @@ func (r *LuaRunner) Close() error {
|
|||
|
||||
// NotifyFileChanged handles file change notifications from watchers
|
||||
func (r *LuaRunner) NotifyFileChanged(filePath string) bool {
|
||||
if r.moduleLoader != nil {
|
||||
return r.moduleLoader.NotifyFileChanged(r.state, filePath)
|
||||
}
|
||||
r.debugLog("File change detected: %s", filePath)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Reset the entire state on file changes
|
||||
err := r.initState(false)
|
||||
if err != nil {
|
||||
r.debugLog("Failed to reinitialize state: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
r.debugLog("State successfully reinitialized")
|
||||
return true
|
||||
}
|
||||
|
||||
// ResetModuleCache clears non-core modules from package.loaded
|
||||
func (r *LuaRunner) ResetModuleCache() {
|
||||
if r.moduleLoader != nil {
|
||||
r.debugLog("Resetting module cache")
|
||||
r.moduleLoader.ResetModules(r.state)
|
||||
}
|
||||
}
|
||||
|
@ -276,6 +390,7 @@ func (r *LuaRunner) ResetModuleCache() {
|
|||
// ReloadAllModules reloads all modules into package.loaded
|
||||
func (r *LuaRunner) ReloadAllModules() error {
|
||||
if r.moduleLoader != nil {
|
||||
r.debugLog("Reloading all modules")
|
||||
return r.moduleLoader.PreloadAllModules(r.state)
|
||||
}
|
||||
return nil
|
||||
|
@ -284,6 +399,7 @@ func (r *LuaRunner) ReloadAllModules() error {
|
|||
// RefreshModuleByName invalidates a specific module in package.loaded
|
||||
func (r *LuaRunner) RefreshModuleByName(modName string) bool {
|
||||
if r.state != nil {
|
||||
r.debugLog("Refreshing module: %s", modName)
|
||||
if err := r.state.DoString(`package.loaded["` + modName + `"] = nil`); err != nil {
|
||||
return false
|
||||
}
|
||||
|
@ -294,6 +410,7 @@ func (r *LuaRunner) RefreshModuleByName(modName string) bool {
|
|||
|
||||
// AddModule adds a module to the sandbox environment
|
||||
func (r *LuaRunner) AddModule(name string, module any) {
|
||||
r.debugLog("Adding module: %s", name)
|
||||
r.sandbox.AddModule(name, module)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,23 +1,29 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
)
|
||||
|
||||
// Sandbox manages a sandboxed Lua environment
|
||||
// Sandbox manages a simplified Lua environment
|
||||
type Sandbox struct {
|
||||
modules map[string]any // Custom modules for environment
|
||||
initialized bool // Whether base environment is initialized
|
||||
debug bool // Enable debug output
|
||||
}
|
||||
|
||||
// NewSandbox creates a new sandbox
|
||||
func NewSandbox() *Sandbox {
|
||||
s := &Sandbox{
|
||||
return &Sandbox{
|
||||
modules: make(map[string]any),
|
||||
initialized: false,
|
||||
debug: false,
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
// EnableDebug turns on debug output
|
||||
func (s *Sandbox) EnableDebug() {
|
||||
s.debug = true
|
||||
}
|
||||
|
||||
// AddModule adds a module to the sandbox environment
|
||||
|
@ -25,228 +31,109 @@ func (s *Sandbox) AddModule(name string, module any) {
|
|||
s.modules[name] = module
|
||||
}
|
||||
|
||||
// debugLog prints debug messages if debug is enabled
|
||||
func (s *Sandbox) debugLog(format string, args ...interface{}) {
|
||||
if s.debug {
|
||||
logger.Debug("[Sandbox Debug] "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Setup initializes the sandbox in a Lua state
|
||||
func (s *Sandbox) Setup(state *luajit.State) error {
|
||||
// Register modules
|
||||
if err := s.registerModules(state); err != nil {
|
||||
s.debugLog("Setting up sandbox environment")
|
||||
|
||||
// Register modules in the global environment
|
||||
for name, module := range s.modules {
|
||||
s.debugLog("Registering module: %s", name)
|
||||
if err := state.PushValue(module); err != nil {
|
||||
s.debugLog("Failed to register module %s: %v", name, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create high-performance persistent environment
|
||||
return state.DoString(`
|
||||
-- Global shared environment (created once)
|
||||
__env_system = __env_system or {
|
||||
base_env = nil, -- Template environment
|
||||
initialized = false, -- Initialization flag
|
||||
env_pool = {}, -- Pre-allocated environment pool
|
||||
pool_size = 0, -- Current pool size
|
||||
max_pool_size = 8 -- Maximum pool size
|
||||
state.SetGlobal(name)
|
||||
}
|
||||
|
||||
-- Initialize base environment once
|
||||
if not __env_system.initialized then
|
||||
-- Create base environment with all standard libraries
|
||||
local base = {}
|
||||
// Initialize simple environment setup
|
||||
err := state.DoString(`
|
||||
-- Global tables for response handling
|
||||
__http_responses = __http_responses or {}
|
||||
|
||||
-- Safe standard libraries
|
||||
base.string = string
|
||||
base.table = table
|
||||
base.math = math
|
||||
base.os = {
|
||||
time = os.time,
|
||||
date = os.date,
|
||||
difftime = os.difftime,
|
||||
clock = os.clock
|
||||
}
|
||||
-- Simple environment creation
|
||||
function __create_env(ctx)
|
||||
-- Create environment inheriting from _G
|
||||
local env = setmetatable({}, {__index = _G})
|
||||
|
||||
-- Basic functions
|
||||
base.print = print
|
||||
base.tonumber = tonumber
|
||||
base.tostring = tostring
|
||||
base.type = type
|
||||
base.pairs = pairs
|
||||
base.ipairs = ipairs
|
||||
base.next = next
|
||||
base.select = select
|
||||
base.unpack = unpack
|
||||
base.pcall = pcall
|
||||
base.xpcall = xpcall
|
||||
base.error = error
|
||||
base.assert = assert
|
||||
|
||||
-- Package system is shared for performance
|
||||
base.package = {
|
||||
loaded = package.loaded,
|
||||
path = package.path,
|
||||
preload = package.preload
|
||||
}
|
||||
|
||||
base.http = http
|
||||
base.cookie = cookie
|
||||
-- http_client module is now part of http.client
|
||||
|
||||
-- Add registered custom modules
|
||||
if __sandbox_modules then
|
||||
for name, mod in pairs(__sandbox_modules) do
|
||||
base[name] = mod
|
||||
end
|
||||
end
|
||||
|
||||
-- Store base environment
|
||||
__env_system.base_env = base
|
||||
__env_system.initialized = true
|
||||
end
|
||||
|
||||
-- Global variable for tracking current environment
|
||||
__last_env = nil
|
||||
|
||||
-- Fast environment creation with pre-allocation
|
||||
function __get_sandbox_env(ctx)
|
||||
local env
|
||||
|
||||
-- Try to reuse from pool
|
||||
if __env_system.pool_size > 0 then
|
||||
env = table.remove(__env_system.env_pool)
|
||||
__env_system.pool_size = __env_system.pool_size - 1
|
||||
|
||||
-- Clear any previous context
|
||||
env.ctx = ctx or nil
|
||||
-- Clear any previous response
|
||||
env._response = nil
|
||||
else
|
||||
-- Create new environment with metatable inheritance
|
||||
env = setmetatable({}, {
|
||||
__index = __env_system.base_env
|
||||
})
|
||||
|
||||
-- Set context if provided
|
||||
-- Add context if provided
|
||||
if ctx then
|
||||
env.ctx = ctx
|
||||
end
|
||||
|
||||
-- Install the fast require implementation
|
||||
env.require = function(modname)
|
||||
return __fast_require(env, modname)
|
||||
end
|
||||
|
||||
-- Install cookie module methods directly into environment
|
||||
env.cookie = {
|
||||
get = function(name)
|
||||
if type(name) ~= "string" then
|
||||
error("cookie.get: name must be a string", 2)
|
||||
end
|
||||
|
||||
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
|
||||
return tostring(env.ctx.cookies[name])
|
||||
end
|
||||
|
||||
return nil
|
||||
end,
|
||||
|
||||
set = cookie.set,
|
||||
remove = cookie.remove
|
||||
}
|
||||
end
|
||||
|
||||
-- Store reference to current environment
|
||||
__last_env = env
|
||||
|
||||
return env
|
||||
end
|
||||
|
||||
-- Return environment to pool for reuse
|
||||
function __recycle_env(env)
|
||||
-- Only recycle if pool isn't full
|
||||
if __env_system.pool_size < __env_system.max_pool_size then
|
||||
-- Clear context reference to avoid memory leaks
|
||||
env.ctx = nil
|
||||
-- Don't clear response data - we need it for extraction
|
||||
-- Execute script with clean environment
|
||||
function __execute_script(fn, ctx)
|
||||
-- Clear previous responses
|
||||
__http_responses[1] = nil
|
||||
|
||||
-- Add to pool
|
||||
table.insert(__env_system.env_pool, env)
|
||||
__env_system.pool_size = __env_system.pool_size + 1
|
||||
end
|
||||
end
|
||||
-- Create environment
|
||||
local env = __create_env(ctx)
|
||||
|
||||
-- Hyper-optimized sandbox executor
|
||||
function __execute_sandbox(bytecode, ctx)
|
||||
-- Get environment (from pool if available)
|
||||
local env = __get_sandbox_env(ctx)
|
||||
|
||||
-- Set environment for bytecode
|
||||
setfenv(bytecode, env)
|
||||
-- Set environment for function
|
||||
setfenv(fn, env)
|
||||
|
||||
-- Execute with protected call
|
||||
local success, result = pcall(bytecode)
|
||||
|
||||
-- Recycle environment for future use
|
||||
__recycle_env(env)
|
||||
|
||||
-- Process result
|
||||
if not success then
|
||||
local ok, result = pcall(fn)
|
||||
if not ok then
|
||||
error(result, 0)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
-- Run minimal GC for overall health
|
||||
collectgarbage("step", 10)
|
||||
`)
|
||||
}
|
||||
|
||||
// registerModules registers custom modules in the Lua state
|
||||
func (s *Sandbox) registerModules(state *luajit.State) error {
|
||||
// Create or get module registry table
|
||||
state.GetGlobal("__sandbox_modules")
|
||||
if state.IsNil(-1) {
|
||||
// Table doesn't exist, create it
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__sandbox_modules")
|
||||
state.GetGlobal("__sandbox_modules")
|
||||
}
|
||||
|
||||
// Add modules to registry
|
||||
for name, module := range s.modules {
|
||||
state.PushString(name)
|
||||
if err := state.PushValue(module); err != nil {
|
||||
state.Pop(2)
|
||||
if err != nil {
|
||||
s.debugLog("Failed to set up sandbox: %v", err)
|
||||
return err
|
||||
}
|
||||
state.SetTable(-3)
|
||||
}
|
||||
|
||||
// Pop module table
|
||||
state.Pop(1)
|
||||
s.debugLog("Sandbox setup complete")
|
||||
|
||||
// Verify HTTP module is accessible
|
||||
httpResult, _ := state.ExecuteWithResult(`
|
||||
if type(http) == "table" and
|
||||
type(http.client) == "table" and
|
||||
type(http.client.get) == "function" then
|
||||
return "HTTP module verified OK"
|
||||
else
|
||||
local status = {
|
||||
http = type(http),
|
||||
client = type(http) == "table" and type(http.client) or "N/A",
|
||||
get = type(http) == "table" and type(http.client) == "table" and type(http.client.get) or "N/A"
|
||||
}
|
||||
return status
|
||||
end
|
||||
`)
|
||||
|
||||
s.debugLog("HTTP verification result: %v", httpResult)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute runs bytecode in the sandbox
|
||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
|
||||
// Update custom modules if needed
|
||||
if !s.initialized {
|
||||
if err := s.registerModules(state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.initialized = true
|
||||
}
|
||||
|
||||
// Load bytecode
|
||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||
s.debugLog("Failed to load bytecode: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create context table if provided
|
||||
if len(ctx) > 0 {
|
||||
// Preallocate table with appropriate size
|
||||
// Prepare context
|
||||
if ctx != nil {
|
||||
state.CreateTable(0, len(ctx))
|
||||
|
||||
// Add context entries
|
||||
for k, v := range ctx {
|
||||
state.PushString(k)
|
||||
if err := state.PushValue(v); err != nil {
|
||||
state.Pop(2)
|
||||
s.debugLog("Failed to push context value %s: %v", k, err)
|
||||
return nil, err
|
||||
}
|
||||
state.SetTable(-3)
|
||||
|
@ -255,31 +142,37 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
|
|||
state.PushNil() // No context
|
||||
}
|
||||
|
||||
// Get optimized sandbox executor
|
||||
state.GetGlobal("__execute_sandbox")
|
||||
// Get execution function
|
||||
state.GetGlobal("__execute_script")
|
||||
if !state.IsFunction(-1) {
|
||||
state.Pop(2) // Pop nil and non-function
|
||||
s.debugLog("__execute_script is not a function")
|
||||
return nil, fmt.Errorf("sandbox execution function not found")
|
||||
}
|
||||
|
||||
// Setup call with correct argument order
|
||||
state.PushCopy(-3) // Copy bytecode function
|
||||
state.PushCopy(-3) // Copy context
|
||||
// Push arguments
|
||||
state.PushCopy(-3) // bytecode function
|
||||
state.PushCopy(-3) // context
|
||||
|
||||
// Clean up stack
|
||||
state.Remove(-5) // Remove original bytecode
|
||||
state.Remove(-4) // Remove original context
|
||||
state.Remove(-5) // original bytecode
|
||||
state.Remove(-4) // original context
|
||||
|
||||
// Call optimized sandbox executor
|
||||
// Call with 2 args, 1 result
|
||||
if err := state.Call(2, 1); err != nil {
|
||||
s.debugLog("Execution failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get result
|
||||
result, err := state.ToValue(-1)
|
||||
state.Pop(1) // Pop result
|
||||
state.Pop(1)
|
||||
|
||||
// Check if HTTP response was set
|
||||
httpResponse, hasHTTPResponse := GetHTTPResponse(state)
|
||||
if hasHTTPResponse {
|
||||
// Check for HTTP response
|
||||
httpResponse, hasResponse := GetHTTPResponse(state)
|
||||
if hasResponse {
|
||||
httpResponse.Body = result
|
||||
return httpResponse, err
|
||||
return httpResponse, nil
|
||||
}
|
||||
|
||||
return result, err
|
||||
|
|
|
@ -56,13 +56,23 @@ func (w *Watcher) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// WatchLuaRouter sets up a watcher for a LuaRouter's routes directory
|
||||
func WatchLuaRouter(router *routers.LuaRouter, routesDir string, log *logger.Logger) (*Watcher, error) {
|
||||
// WatchLuaRouter sets up a watcher for a LuaRouter's routes directory; also updates
|
||||
// the LuaRunner so that the state can be rebuilt
|
||||
func WatchLuaRouter(router *routers.LuaRouter, runner *runner.LuaRunner, routesDir string, log *logger.Logger) (*Watcher, error) {
|
||||
manager := GetWatcherManager(log, true) // Use adaptive polling
|
||||
|
||||
// Create LuaRunner refresh callback
|
||||
runnerRefresh := func() error {
|
||||
log.Debug("Refreshing LuaRunner state due to file change")
|
||||
runner.NotifyFileChanged("")
|
||||
return nil
|
||||
}
|
||||
|
||||
combinedCallback := combineCallbacks(router.Refresh, runnerRefresh)
|
||||
|
||||
config := DirectoryWatcherConfig{
|
||||
Dir: routesDir,
|
||||
Callback: router.Refresh,
|
||||
Callback: combinedCallback,
|
||||
Log: log,
|
||||
Recursive: true,
|
||||
}
|
||||
|
@ -146,3 +156,15 @@ func ShutdownWatcherManager() {
|
|||
globalManager = nil
|
||||
}
|
||||
}
|
||||
|
||||
// combineCallbacks creates a single callback function from multiple callbacks
|
||||
func combineCallbacks(callbacks ...func() error) func() error {
|
||||
return func() error {
|
||||
for _, callback := range callbacks {
|
||||
if err := callback(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,7 +76,7 @@ func setupWatchers(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRou
|
|||
|
||||
// Set up watcher for Lua routes
|
||||
if config.Routes {
|
||||
luaRouterWatcher, err := watchers.WatchLuaRouter(luaRouter, routesDir, log)
|
||||
luaRouterWatcher, err := watchers.WatchLuaRouter(luaRouter, luaRunner, routesDir, log)
|
||||
if err != nil {
|
||||
log.Warning("Failed to watch routes directory: %v", err)
|
||||
} else {
|
||||
|
@ -180,6 +180,7 @@ func main() {
|
|||
luaRunner, err := runner.NewRunner(
|
||||
runner.WithBufferSize(bufferSize),
|
||||
runner.WithLibDirs(libDirs...),
|
||||
runner.WithDebugEnabled(),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to initialize Lua runner: %v", err)
|
||||
|
|
Loading…
Reference in New Issue
Block a user