From ce8132677e1de2409d734f206e1ca1042ecbf7d0 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 27 Mar 2025 13:47:10 -0500 Subject: [PATCH] modify file names, add http client --- core/runner/{context.go => Context.go} | 0 core/runner/Cookies.go | 187 ++++++++ core/runner/CoreModules.go | 111 +++++ core/runner/{http.go => Http.go} | 0 core/runner/HttpClient.go | 499 +++++++++++++++++++++ core/runner/{job.go => Job.go} | 0 core/runner/{luarunner.go => LuaRunner.go} | 13 +- core/runner/{modules.go => Modules.go} | 0 core/runner/{require.go => Require.go} | 53 ++- core/runner/{sandbox.go => Sandbox.go} | 1 + core/runner/cookies.go | 172 ------- 11 files changed, 826 insertions(+), 210 deletions(-) rename core/runner/{context.go => Context.go} (100%) create mode 100644 core/runner/Cookies.go create mode 100644 core/runner/CoreModules.go rename core/runner/{http.go => Http.go} (100%) create mode 100644 core/runner/HttpClient.go rename core/runner/{job.go => Job.go} (100%) rename core/runner/{luarunner.go => LuaRunner.go} (95%) rename core/runner/{modules.go => Modules.go} (100%) rename core/runner/{require.go => Require.go} (92%) rename core/runner/{sandbox.go => Sandbox.go} (99%) delete mode 100644 core/runner/cookies.go diff --git a/core/runner/context.go b/core/runner/Context.go similarity index 100% rename from core/runner/context.go rename to core/runner/Context.go diff --git a/core/runner/Cookies.go b/core/runner/Cookies.go new file mode 100644 index 0000000..d4ab0b0 --- /dev/null +++ b/core/runner/Cookies.go @@ -0,0 +1,187 @@ +package runner + +import ( + "net/http" + "time" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// extractCookie grabs cookies from the Lua state +func extractCookie(state *luajit.State) *http.Cookie { + cookie := &http.Cookie{ + Path: "/", // Default path + } + + // Get name + state.GetField(-1, "name") + if !state.IsString(-1) { + state.Pop(1) + return nil // Name is required + } + cookie.Name = state.ToString(-1) + state.Pop(1) + + // Get value + state.GetField(-1, "value") + if state.IsString(-1) { + cookie.Value = state.ToString(-1) + } + state.Pop(1) + + // Get path + state.GetField(-1, "path") + if state.IsString(-1) { + cookie.Path = state.ToString(-1) + } + state.Pop(1) + + // Get domain + state.GetField(-1, "domain") + if state.IsString(-1) { + cookie.Domain = state.ToString(-1) + } + state.Pop(1) + + // Get expires + state.GetField(-1, "expires") + if state.IsNumber(-1) { + expiry := int64(state.ToNumber(-1)) + cookie.Expires = time.Unix(expiry, 0) + } + state.Pop(1) + + // Get max age + state.GetField(-1, "max_age") + if state.IsNumber(-1) { + cookie.MaxAge = int(state.ToNumber(-1)) + } + state.Pop(1) + + // Get secure + state.GetField(-1, "secure") + if state.IsBoolean(-1) { + cookie.Secure = state.ToBoolean(-1) + } + state.Pop(1) + + // Get http only + state.GetField(-1, "http_only") + if state.IsBoolean(-1) { + cookie.HttpOnly = state.ToBoolean(-1) + } + state.Pop(1) + + return cookie +} + +// LuaCookieModule provides cookie functionality to Lua scripts +const LuaCookieModule = ` +-- Cookie module implementation +local cookie = { + -- Set a cookie + set = function(name, value, options, ...) + if type(name) ~= "string" then + error("cookie.set: name must be a string", 2) + end + + -- Get or create response + local resp = __http_responses[1] or {} + resp.cookies = resp.cookies or {} + __http_responses[1] = resp + + -- Handle options as table or legacy params + local opts = {} + if type(options) == "table" then + opts = options + elseif options ~= nil then + -- Legacy support: options is actually 'expires' + opts.expires = options + -- Check for other legacy params (4th-7th args) + local args = {...} + if args[1] then opts.path = args[1] end + if args[2] then opts.domain = args[2] end + if args[3] then opts.secure = args[3] end + if args[4] ~= nil then opts.http_only = args[4] end + end + + -- Create cookie table + local cookie = { + name = name, + value = value or "", + path = opts.path or "/", + domain = opts.domain + } + + -- Handle expiry + if opts.expires then + if type(opts.expires) == "number" then + if opts.expires > 0 then + -- Add seconds to current time + cookie.max_age = opts.expires + local now = os.time() + cookie.expires = now + opts.expires + elseif opts.expires < 0 then + -- Session cookie (default) + else + -- Expire immediately + cookie.expires = 0 + cookie.max_age = 0 + end + end + end + + -- Set flags (http_only defaults to true) + cookie.secure = opts.secure or false + cookie.http_only = (opts.http_only ~= false) -- Default to true unless explicitly set to false + + -- Store in cookies table + local n = #resp.cookies + 1 + resp.cookies[n] = cookie + + return true + end, + + -- Get a cookie value + get = function(name) + if type(name) ~= "string" then + error("cookie.get: name must be a string", 2) + end + + -- Access values directly from current environment + local env = getfenv(1) + + -- Check if context exists and has cookies + if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then + return tostring(env.ctx.cookies[name]) + end + + return nil + end, + + -- Remove a cookie + remove = function(name, path, domain) + if type(name) ~= "string" then + error("cookie.remove: name must be a string", 2) + end + + -- Create an expired cookie + return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain}) + end +} + +-- Install cookie module +_G.cookie = cookie + +-- Make sure the cookie module is accessible in sandbox +if __env_system and __env_system.base_env then + __env_system.base_env.cookie = cookie +end +` + +// CookieModuleInitFunc returns an initializer for the cookie module +func CookieModuleInitFunc() StateInitFunc { + return func(state *luajit.State) error { + return state.DoString(LuaCookieModule) + } +} diff --git a/core/runner/CoreModules.go b/core/runner/CoreModules.go new file mode 100644 index 0000000..f04d01e --- /dev/null +++ b/core/runner/CoreModules.go @@ -0,0 +1,111 @@ +package runner + +import ( + "strings" + "sync" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// CoreModuleRegistry manages the initialization and reloading of core modules +type CoreModuleRegistry struct { + modules map[string]StateInitFunc + mu sync.RWMutex +} + +// NewCoreModuleRegistry creates a new core module registry +func NewCoreModuleRegistry() *CoreModuleRegistry { + return &CoreModuleRegistry{ + modules: make(map[string]StateInitFunc), + } +} + +// Register adds a module to the registry +func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) { + r.mu.Lock() + defer r.mu.Unlock() + r.modules[name] = initFunc +} + +// Initialize initializes all registered modules +func (r *CoreModuleRegistry) Initialize(state *luajit.State) error { + r.mu.RLock() + defer r.mu.RUnlock() + + // Convert to StateInitFunc + initFunc := CombineInitFuncs(r.getInitFuncs()...) + return initFunc(state) +} + +// getInitFuncs returns all module init functions +func (r *CoreModuleRegistry) getInitFuncs() []StateInitFunc { + funcs := make([]StateInitFunc, 0, len(r.modules)) + for _, initFunc := range r.modules { + funcs = append(funcs, initFunc) + } + return funcs +} + +// InitializeModule initializes a specific module +func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string) error { + r.mu.RLock() + defer r.mu.RUnlock() + + initFunc, ok := r.modules[name] + if !ok { + return nil // Module not found, no error + } + + return initFunc(state) +} + +// ModuleNames returns a list of all registered module names +func (r *CoreModuleRegistry) ModuleNames() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + names := make([]string, 0, len(r.modules)) + for name := range r.modules { + names = append(names, name) + } + return names +} + +// MatchModuleName checks if a file path corresponds to a registered module +func (r *CoreModuleRegistry) MatchModuleName(modName string) (string, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + // Exact match + if _, ok := r.modules[modName]; ok { + return modName, true + } + + // Check if the module name ends with a registered module + for name := range r.modules { + if strings.HasSuffix(modName, "."+name) { + return name, true + } + } + + return "", false +} + +// Global registry instance +var GlobalRegistry = NewCoreModuleRegistry() + +// Initialize global registry with core modules +func init() { + GlobalRegistry.Register("http", HTTPModuleInitFunc()) + GlobalRegistry.Register("cookie", CookieModuleInitFunc()) + GlobalRegistry.Register("http_client", HTTPClientModuleInitFunc()) +} + +// RegisterCoreModule is a helper to register a core module +// with the global registry +func RegisterCoreModule(name string, initFunc StateInitFunc) { + GlobalRegistry.Register(name, initFunc) +} + +// To add a new module, simply call: +// RegisterCoreModule("new_module_name", NewModuleInitFunc()) diff --git a/core/runner/http.go b/core/runner/Http.go similarity index 100% rename from core/runner/http.go rename to core/runner/Http.go diff --git a/core/runner/HttpClient.go b/core/runner/HttpClient.go new file mode 100644 index 0000000..66d919a --- /dev/null +++ b/core/runner/HttpClient.go @@ -0,0 +1,499 @@ +package runner + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" + "time" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// Default HTTP client with sensible timeout +var defaultClient = &http.Client{ + Timeout: 30 * time.Second, +} + +// HTTPClientConfig contains client settings +type HTTPClientConfig struct { + // Maximum timeout for requests (0 = no limit) + MaxTimeout time.Duration + // Default request timeout + DefaultTimeout time.Duration + // Maximum response size in bytes (0 = no limit) + MaxResponseSize int64 + // Whether to allow remote connections + AllowRemote bool +} + +// DefaultHTTPClientConfig provides sensible defaults +var DefaultHTTPClientConfig = HTTPClientConfig{ + MaxTimeout: 60 * time.Second, + DefaultTimeout: 30 * time.Second, + MaxResponseSize: 10 * 1024 * 1024, // 10MB + AllowRemote: true, +} + +// Function name constant to ensure consistency +const httpRequestFuncName = "__http_request" + +// httpRequest makes an HTTP request and returns the result to Lua +func httpRequest(state *luajit.State) int { + // Get method (required) + if !state.IsString(1) { + state.PushString("http_client.request: method must be a string") + return -1 + } + method := strings.ToUpper(state.ToString(1)) + + // Get URL (required) + if !state.IsString(2) { + state.PushString("http_client.request: url must be a string") + return -1 + } + urlStr := state.ToString(2) + + // Parse URL to check if it's valid and if it's allowed + parsedURL, err := url.Parse(urlStr) + if err != nil { + state.PushString("Invalid URL: " + err.Error()) + return -1 + } + + // Get client configuration from registry (if available) + var config HTTPClientConfig = DefaultHTTPClientConfig + state.GetGlobal("__http_client_config") + if !state.IsNil(-1) { + if state.IsTable(-1) { + // Extract max timeout + state.GetField(-1, "max_timeout") + if state.IsNumber(-1) { + config.MaxTimeout = time.Duration(state.ToNumber(-1)) * time.Second + } + state.Pop(1) + + // Extract default timeout + state.GetField(-1, "default_timeout") + if state.IsNumber(-1) { + config.DefaultTimeout = time.Duration(state.ToNumber(-1)) * time.Second + } + state.Pop(1) + + // Extract max response size + state.GetField(-1, "max_response_size") + if state.IsNumber(-1) { + config.MaxResponseSize = int64(state.ToNumber(-1)) + } + state.Pop(1) + + // Extract allow remote + state.GetField(-1, "allow_remote") + if state.IsBoolean(-1) { + config.AllowRemote = state.ToBoolean(-1) + } + state.Pop(1) + } + } + state.Pop(1) + + // Check if remote connections are allowed + if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") { + state.PushString("Remote connections are not allowed") + return -1 + } + + // Get body (optional) + var bodyReader io.Reader + if state.GetTop() >= 3 && !state.IsNil(3) { + var body []byte + + if state.IsString(3) { + // String body + body = []byte(state.ToString(3)) + } else if state.IsTable(3) { + // Table body - convert to JSON + luaTable, err := state.ToTable(3) + if err != nil { + state.PushString("Failed to parse body table: " + err.Error()) + return -1 + } + + body, err = json.Marshal(luaTable) + if err != nil { + state.PushString("Failed to convert body to JSON: " + err.Error()) + return -1 + } + } else { + state.PushString("Body must be a string or table") + return -1 + } + + bodyReader = bytes.NewReader(body) + } + + // Create request + req, err := http.NewRequest(method, urlStr, bodyReader) + if err != nil { + state.PushString("Failed to create request: " + err.Error()) + return -1 + } + + // Set default headers + req.Header.Set("User-Agent", "Moonshark/1.0") + + // Process options (headers, timeout, etc.) + timeout := config.DefaultTimeout + if state.GetTop() >= 4 && !state.IsNil(4) { + if !state.IsTable(4) { + state.PushString("Options must be a table") + return -1 + } + + // Process headers + state.GetField(4, "headers") + if state.IsTable(-1) { + // Iterate through headers + state.PushNil() // Start iteration + for state.Next(-2) { + // Stack now has key at -2 and value at -1 + if state.IsString(-2) && state.IsString(-1) { + headerName := state.ToString(-2) + headerValue := state.ToString(-1) + req.Header.Set(headerName, headerValue) + } + state.Pop(1) // Pop value, leave key for next iteration + } + } + state.Pop(1) // Pop headers table + + // Get timeout + state.GetField(4, "timeout") + if state.IsNumber(-1) { + requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second + + // Apply max timeout if configured + if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout { + timeout = config.MaxTimeout + } else { + timeout = requestTimeout + } + } + state.Pop(1) // Pop timeout + + // Set content type for POST/PUT if body is present and content-type not manually set + if (method == "POST" || method == "PUT") && bodyReader != nil && req.Header.Get("Content-Type") == "" { + // Check if options specify content type + state.GetField(4, "content_type") + if state.IsString(-1) { + req.Header.Set("Content-Type", state.ToString(-1)) + } else { + // Default to JSON if body is a table, otherwise plain text + if state.IsTable(3) { + req.Header.Set("Content-Type", "application/json") + } else { + req.Header.Set("Content-Type", "text/plain") + } + } + state.Pop(1) // Pop content_type + } + + // Process query parameters + state.GetField(4, "query") + if state.IsTable(-1) { + q := req.URL.Query() + + // Iterate through query params + state.PushNil() // Start iteration + for state.Next(-2) { + // Stack now has key at -2 and value at -1 + if state.IsString(-2) { + paramName := state.ToString(-2) + + // Handle different value types + if state.IsString(-1) { + q.Add(paramName, state.ToString(-1)) + } else if state.IsNumber(-1) { + q.Add(paramName, strings.TrimRight(strings.TrimRight( + state.ToString(-1), "0"), ".")) + } else if state.IsBoolean(-1) { + if state.ToBoolean(-1) { + q.Add(paramName, "true") + } else { + q.Add(paramName, "false") + } + } + } + state.Pop(1) // Pop value, leave key for next iteration + } + + req.URL.RawQuery = q.Encode() + } + state.Pop(1) // Pop query table + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Use context with request + req = req.WithContext(ctx) + + // Execute request + resp, err := defaultClient.Do(req) + if err != nil { + state.PushString("Request failed: " + err.Error()) + return -1 + } + defer resp.Body.Close() + + // Apply size limits to response + var respBody []byte + if config.MaxResponseSize > 0 { + // Limit the response body size + respBody, err = io.ReadAll(io.LimitReader(resp.Body, config.MaxResponseSize)) + } else { + respBody, err = io.ReadAll(resp.Body) + } + + if err != nil { + state.PushString("Failed to read response: " + err.Error()) + return -1 + } + + // Create response table + state.NewTable() + + // Set status code + state.PushNumber(float64(resp.StatusCode)) + state.SetField(-2, "status") + + // Set status text + state.PushString(resp.Status) + state.SetField(-2, "status_text") + + // Set body + state.PushString(string(respBody)) + state.SetField(-2, "body") + + // Parse body as JSON if content type is application/json + if strings.Contains(resp.Header.Get("Content-Type"), "application/json") { + var jsonData any + if err := json.Unmarshal(respBody, &jsonData); err == nil { + if err := state.PushValue(jsonData); err == nil { + state.SetField(-2, "json") + } + } + } + + // Set headers + state.NewTable() + for name, values := range resp.Header { + if len(values) == 1 { + state.PushString(values[0]) + } else { + // Create array of values + state.NewTable() + for i, v := range values { + state.PushNumber(float64(i + 1)) + state.PushString(v) + state.SetTable(-3) + } + } + state.SetField(-2, name) + } + state.SetField(-2, "headers") + + // Create ok field (true if status code is 2xx) + state.PushBoolean(resp.StatusCode >= 200 && resp.StatusCode < 300) + state.SetField(-2, "ok") + + return 1 +} + +// LuaHTTPClientModule defines the HTTP client module in Lua +const LuaHTTPClientModule = ` +-- HTTP client module implementation +local http_client = { + -- Generic request function + request = function(method, url, body, options) + if type(method) ~= "string" then + error("http_client.request: method must be a string", 2) + end + if type(url) ~= "string" then + error("http_client.request: url must be a string", 2) + end + + -- Call native implementation + return __http_request(method, url, body, options) + end, + + -- Simple GET request + get = function(url, options) + return http_client.request("GET", url, nil, options) + end, + + -- Simple POST request with automatic content-type + post = function(url, body, options) + options = options or {} + return http_client.request("POST", url, body, options) + end, + + -- Simple PUT request with automatic content-type + put = function(url, body, options) + options = options or {} + return http_client.request("PUT", url, body, options) + end, + + -- Simple DELETE request + delete = function(url, options) + return http_client.request("DELETE", url, nil, options) + end, + + -- Simple PATCH request + patch = function(url, body, options) + options = options or {} + return http_client.request("PATCH", url, body, options) + end, + + -- Simple HEAD request + head = function(url, options) + options = options or {} + local old_options = options + options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query} + local response = http_client.request("HEAD", url, nil, options) + return response + end, + + -- Simple OPTIONS request + options = function(url, options) + return http_client.request("OPTIONS", url, nil, options) + end, + + -- Shorthand function to directly get JSON + get_json = function(url, options) + options = options or {} + local response = http_client.get(url, options) + if response.ok and response.json then + return response.json + end + return nil, response + end, + + -- Utility to build a URL with query parameters + build_url = function(base_url, params) + if not params or type(params) ~= "table" then + return base_url + end + + local query = {} + for k, v in pairs(params) do + if type(v) == "table" then + for _, item in ipairs(v) do + table.insert(query, k .. "=" .. tostring(item)) + end + else + table.insert(query, k .. "=" .. tostring(v)) + end + end + + if #query > 0 then + if base_url:find("?") then + return base_url .. "&" .. table.concat(query, "&") + else + return base_url .. "?" .. table.concat(query, "&") + end + end + + return base_url + end +} + +-- Install HTTP client module +_G.http_client = http_client + +-- Add to sandbox environment +if __env_system and __env_system.base_env then + __env_system.base_env.http_client = http_client +end +` + +// HTTPClientModuleInitFunc returns an initializer for the HTTP client module +func HTTPClientModuleInitFunc() StateInitFunc { + return func(state *luajit.State) error { + // First, unregister any existing function to prevent registry leaks + state.UnregisterGoFunction(httpRequestFuncName) + + // Register the native __http_request function + if err := state.RegisterGoFunction(httpRequestFuncName, httpRequest); err != nil { + return err + } + + // Install the pure Lua module + if err := state.DoString(LuaHTTPClientModule); err != nil { + return err + } + + // Check for existing config (in sandbox modules) + state.GetGlobal("__sandbox_modules") + if !state.IsNil(-1) && state.IsTable(-1) { + state.PushString("__http_client_config") + state.GetTable(-2) + + if !state.IsNil(-1) && state.IsTable(-1) { + // Use the config from sandbox modules + state.SetGlobal("__http_client_config") + state.Pop(1) // Pop the sandbox modules table + return nil + } + state.Pop(1) // Pop the nil or non-table value + } + state.Pop(1) // Pop the nil or sandbox modules table + + // Setup default configuration if no custom config exists + state.NewTable() + + state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second)) + state.SetField(-2, "max_timeout") + + state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second)) + state.SetField(-2, "default_timeout") + + state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize)) + state.SetField(-2, "max_response_size") + + state.PushBoolean(DefaultHTTPClientConfig.AllowRemote) + state.SetField(-2, "allow_remote") + + state.SetGlobal("__http_client_config") + + return nil + } +} + +// WithHTTPClientConfig creates a runner option to configure the HTTP client +func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption { + return func(r *LuaRunner) { + // Store the config to be applied during initialization + r.AddModule("__http_client_config", map[string]any{ + "max_timeout": float64(config.MaxTimeout / time.Second), + "default_timeout": float64(config.DefaultTimeout / time.Second), + "max_response_size": float64(config.MaxResponseSize), + "allow_remote": config.AllowRemote, + }) + } +} + +// RestrictHTTPToLocalhost is a convenience function to restrict HTTP client +// to localhost connections only +func RestrictHTTPToLocalhost() RunnerOption { + return WithHTTPClientConfig(HTTPClientConfig{ + MaxTimeout: DefaultHTTPClientConfig.MaxTimeout, + DefaultTimeout: DefaultHTTPClientConfig.DefaultTimeout, + MaxResponseSize: DefaultHTTPClientConfig.MaxResponseSize, + AllowRemote: false, + }) +} diff --git a/core/runner/job.go b/core/runner/Job.go similarity index 100% rename from core/runner/job.go rename to core/runner/Job.go diff --git a/core/runner/luarunner.go b/core/runner/LuaRunner.go similarity index 95% rename from core/runner/luarunner.go rename to core/runner/LuaRunner.go index 00633cf..b6f5a75 100644 --- a/core/runner/luarunner.go +++ b/core/runner/LuaRunner.go @@ -110,9 +110,8 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) { return nil, ErrInitFailed } - // Preload core modules - moduleInits := CombineInitFuncs(HTTPModuleInitFunc(), CookieModuleInitFunc()) - if err := moduleInits(state); err != nil { + // Initialize all core modules from the registry + if err := GlobalRegistry.Initialize(state); err != nil { state.Close() return nil, ErrInitFailed } @@ -144,14 +143,6 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) { return runner, nil } -// libDirs returns the current library directories -func (r *LuaRunner) libDirs() []string { - if r.moduleLoader != nil && r.moduleLoader.config != nil { - return r.moduleLoader.config.LibDirs - } - return nil -} - // processJobs handles the job queue func (r *LuaRunner) processJobs() { defer r.wg.Done() diff --git a/core/runner/modules.go b/core/runner/Modules.go similarity index 100% rename from core/runner/modules.go rename to core/runner/Modules.go diff --git a/core/runner/require.go b/core/runner/Require.go similarity index 92% rename from core/runner/require.go rename to core/runner/Require.go index 1067add..89df7d2 100644 --- a/core/runner/require.go +++ b/core/runner/Require.go @@ -351,18 +351,34 @@ func (l *NativeModuleLoader) NotifyFileChanged(state *luajit.State, path string) return false } - // Update bytecode and invalidate caches + // Check if this is a core module + coreName, isCoreModule := GlobalRegistry.MatchModuleName(modName) + + // Invalidate module in Lua + escapedName := escapeLuaString(modName) + invalidateCode := ` + package.loaded["` + escapedName + `"] = nil + __ready_modules["` + escapedName + `"] = nil + if package.preload then + package.preload["` + escapedName + `"] = nil + end + ` + if err := state.DoString(invalidateCode); err != nil { + return false + } + + // For core modules, reinitialize them + if isCoreModule { + if err := GlobalRegistry.InitializeModule(state, coreName); err != nil { + return false + } + return true + } + + // For regular modules, update bytecode if the file still exists content, err := os.ReadFile(path) if err != nil { // File might have been deleted - just invalidate - escapedName := escapeLuaString(modName) - state.DoString(` - package.loaded["` + escapedName + `"] = nil - __ready_modules["` + escapedName + `"] = nil - if package.preload then - package.preload["` + escapedName + `"] = nil - end - `) return true } @@ -370,33 +386,16 @@ func (l *NativeModuleLoader) NotifyFileChanged(state *luajit.State, path string) bytecode, err := state.CompileBytecode(string(content), path) if err != nil { // Invalid Lua - just invalidate - escapedName := escapeLuaString(modName) - state.DoString(` - package.loaded["` + escapedName + `"] = nil - __ready_modules["` + escapedName + `"] = nil - if package.preload then - package.preload["` + escapedName + `"] = nil - end - `) return true } // Load bytecode if err := state.LoadBytecode(bytecode, path); err != nil { // Unable to load - just invalidate - escapedName := escapeLuaString(modName) - state.DoString(` - package.loaded["` + escapedName + `"] = nil - __ready_modules["` + escapedName + `"] = nil - if package.preload then - package.preload["` + escapedName + `"] = nil - end - `) return true } - // Update preload with new chunk - escapedName := escapeLuaString(modName) + // Update preload with new chunk for regular modules luaCode := ` -- Update module in package.preload and clear loaded package.loaded["` + escapedName + `"] = nil diff --git a/core/runner/sandbox.go b/core/runner/Sandbox.go similarity index 99% rename from core/runner/sandbox.go rename to core/runner/Sandbox.go index d1b9800..263ffbc 100644 --- a/core/runner/sandbox.go +++ b/core/runner/Sandbox.go @@ -83,6 +83,7 @@ func (s *Sandbox) Setup(state *luajit.State) error { base.http = http base.cookie = cookie + base.http_client = http_client -- Add registered custom modules if __sandbox_modules then diff --git a/core/runner/cookies.go b/core/runner/cookies.go deleted file mode 100644 index dbca114..0000000 --- a/core/runner/cookies.go +++ /dev/null @@ -1,172 +0,0 @@ -package runner - -import ( - "net/http" - "time" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// extractCookie grabs cookies from the Lua state -func extractCookie(state *luajit.State) *http.Cookie { - cookie := &http.Cookie{ - Path: "/", // Default path - } - - // Get name - state.GetField(-1, "name") - if !state.IsString(-1) { - state.Pop(1) - return nil // Name is required - } - cookie.Name = state.ToString(-1) - state.Pop(1) - - // Get value - state.GetField(-1, "value") - if state.IsString(-1) { - cookie.Value = state.ToString(-1) - } - state.Pop(1) - - // Get path - state.GetField(-1, "path") - if state.IsString(-1) { - cookie.Path = state.ToString(-1) - } - state.Pop(1) - - // Get domain - state.GetField(-1, "domain") - if state.IsString(-1) { - cookie.Domain = state.ToString(-1) - } - state.Pop(1) - - // Get expires - state.GetField(-1, "expires") - if state.IsNumber(-1) { - expiry := int64(state.ToNumber(-1)) - cookie.Expires = time.Unix(expiry, 0) - } - state.Pop(1) - - // Get max age - state.GetField(-1, "max_age") - if state.IsNumber(-1) { - cookie.MaxAge = int(state.ToNumber(-1)) - } - state.Pop(1) - - // Get secure - state.GetField(-1, "secure") - if state.IsBoolean(-1) { - cookie.Secure = state.ToBoolean(-1) - } - state.Pop(1) - - // Get http only - state.GetField(-1, "http_only") - if state.IsBoolean(-1) { - cookie.HttpOnly = state.ToBoolean(-1) - } - state.Pop(1) - - return cookie -} - -// LuaCookieModule provides cookie functionality to Lua scripts -const LuaCookieModule = ` --- Cookie module implementation -local cookie = { - -- Set a cookie - set = function(name, value, expires, path, domain, secure, http_only) - if type(name) ~= "string" then - error("cookie.set: name must be a string", 2) - end - - -- Get or create response - local resp = __http_responses[1] or {} - resp.cookies = resp.cookies or {} - __http_responses[1] = resp - - -- Create cookie table - local cookie = { - name = name, - value = value or "", - path = path or "/", - domain = domain - } - - -- Handle expiry - if expires then - if type(expires) == "number" then - if expires > 0 then - -- Add seconds to current time - cookie.max_age = expires - local now = os.time() - cookie.expires = now + expires - elseif expires < 0 then - -- Session cookie (default) - else - -- Expire immediately - cookie.expires = 0 - cookie.max_age = 0 - end - end - end - - -- Set flags - cookie.secure = secure or false - cookie.http_only = http_only or false - - -- Store in cookies table - local n = #resp.cookies + 1 - resp.cookies[n] = cookie - - return true - end, - - -- Get a cookie value - get = function(name) - if type(name) ~= "string" then - error("cookie.get: name must be a string", 2) - end - - -- Access values directly from current environment - local env = getfenv(1) - - -- Check if context exists and has cookies - if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then - return tostring(env.ctx.cookies[name]) - end - - return nil - end, - - -- Remove a cookie - remove = function(name, path, domain) - if type(name) ~= "string" then - error("cookie.remove: name must be a string", 2) - end - - -- Create an expired cookie - return cookie.set(name, "", 0, path or "/", domain, false, false) - end -} - --- Install cookie module -_G.cookie = cookie - --- Make sure the cookie module is accessible in sandbox -if __env_system and __env_system.base_env then - __env_system.base_env.cookie = cookie -end -` - -// CookieModuleInitFunc returns an initializer for the cookie module -func CookieModuleInitFunc() StateInitFunc { - return func(state *luajit.State) error { - return state.DoString(LuaCookieModule) - } -}