From b336ce5efa3ebda55bdd749c6df76ff82c3e9e2d Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 27 Mar 2025 15:49:07 -0500 Subject: [PATCH] consolidate http modules --- core/runner/CoreModules.go | 1 - core/runner/Http.go | 489 +++++++++++++++++++++++++++++++++++- core/runner/HttpClient.go | 499 ------------------------------------- core/runner/Sandbox.go | 2 +- 4 files changed, 483 insertions(+), 508 deletions(-) delete mode 100644 core/runner/HttpClient.go diff --git a/core/runner/CoreModules.go b/core/runner/CoreModules.go index f04d01e..9a1f709 100644 --- a/core/runner/CoreModules.go +++ b/core/runner/CoreModules.go @@ -98,7 +98,6 @@ var GlobalRegistry = NewCoreModuleRegistry() func init() { GlobalRegistry.Register("http", HTTPModuleInitFunc()) GlobalRegistry.Register("cookie", CookieModuleInitFunc()) - GlobalRegistry.Register("http_client", HTTPClientModuleInitFunc()) } // RegisterCoreModule is a helper to register a core module diff --git a/core/runner/Http.go b/core/runner/Http.go index 00890cc..6b7e7ee 100644 --- a/core/runner/Http.go +++ b/core/runner/Http.go @@ -1,8 +1,15 @@ package runner import ( + "bytes" + "context" + "encoding/json" + "io" "net/http" + "net/url" + "strings" "sync" + "time" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) @@ -54,6 +61,309 @@ func ReleaseResponse(resp *HTTPResponse) { responsePool.Put(resp) } +// ---------- HTTP CLIENT FUNCTIONALITY ---------- + +// Default HTTP client with sensible timeout +var defaultClient = &http.Client{ + Timeout: 30 * time.Second, +} + +// HTTPClientConfig contains client settings +type HTTPClientConfig struct { + // Maximum timeout for requests (0 = no limit) + MaxTimeout time.Duration + // Default request timeout + DefaultTimeout time.Duration + // Maximum response size in bytes (0 = no limit) + MaxResponseSize int64 + // Whether to allow remote connections + AllowRemote bool +} + +// DefaultHTTPClientConfig provides sensible defaults +var DefaultHTTPClientConfig = HTTPClientConfig{ + MaxTimeout: 60 * time.Second, + DefaultTimeout: 30 * time.Second, + MaxResponseSize: 10 * 1024 * 1024, // 10MB + AllowRemote: true, +} + +// Function name constant to ensure consistency +const httpRequestFuncName = "__http_request" + +// httpRequest makes an HTTP request and returns the result to Lua +func httpRequest(state *luajit.State) int { + // Get method (required) + if !state.IsString(1) { + state.PushString("http.client.request: method must be a string") + return -1 + } + method := strings.ToUpper(state.ToString(1)) + + // Get URL (required) + if !state.IsString(2) { + state.PushString("http.client.request: url must be a string") + return -1 + } + urlStr := state.ToString(2) + + // Parse URL to check if it's valid and if it's allowed + parsedURL, err := url.Parse(urlStr) + if err != nil { + state.PushString("Invalid URL: " + err.Error()) + return -1 + } + + // Get client configuration from registry (if available) + var config HTTPClientConfig = DefaultHTTPClientConfig + state.GetGlobal("__http_client_config") + if !state.IsNil(-1) { + if state.IsTable(-1) { + // Extract max timeout + state.GetField(-1, "max_timeout") + if state.IsNumber(-1) { + config.MaxTimeout = time.Duration(state.ToNumber(-1)) * time.Second + } + state.Pop(1) + + // Extract default timeout + state.GetField(-1, "default_timeout") + if state.IsNumber(-1) { + config.DefaultTimeout = time.Duration(state.ToNumber(-1)) * time.Second + } + state.Pop(1) + + // Extract max response size + state.GetField(-1, "max_response_size") + if state.IsNumber(-1) { + config.MaxResponseSize = int64(state.ToNumber(-1)) + } + state.Pop(1) + + // Extract allow remote + state.GetField(-1, "allow_remote") + if state.IsBoolean(-1) { + config.AllowRemote = state.ToBoolean(-1) + } + state.Pop(1) + } + } + state.Pop(1) + + // Check if remote connections are allowed + if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") { + state.PushString("Remote connections are not allowed") + return -1 + } + + // Get body (optional) + var bodyReader io.Reader + if state.GetTop() >= 3 && !state.IsNil(3) { + var body []byte + + if state.IsString(3) { + // String body + body = []byte(state.ToString(3)) + } else if state.IsTable(3) { + // Table body - convert to JSON + luaTable, err := state.ToTable(3) + if err != nil { + state.PushString("Failed to parse body table: " + err.Error()) + return -1 + } + + body, err = json.Marshal(luaTable) + if err != nil { + state.PushString("Failed to convert body to JSON: " + err.Error()) + return -1 + } + } else { + state.PushString("Body must be a string or table") + return -1 + } + + bodyReader = bytes.NewReader(body) + } + + // Create request + req, err := http.NewRequest(method, urlStr, bodyReader) + if err != nil { + state.PushString("Failed to create request: " + err.Error()) + return -1 + } + + // Set default headers + req.Header.Set("User-Agent", "Moonshark/1.0") + + // Process options (headers, timeout, etc.) + timeout := config.DefaultTimeout + if state.GetTop() >= 4 && !state.IsNil(4) { + if !state.IsTable(4) { + state.PushString("Options must be a table") + return -1 + } + + // Process headers + state.GetField(4, "headers") + if state.IsTable(-1) { + // Iterate through headers + state.PushNil() // Start iteration + for state.Next(-2) { + // Stack now has key at -2 and value at -1 + if state.IsString(-2) && state.IsString(-1) { + headerName := state.ToString(-2) + headerValue := state.ToString(-1) + req.Header.Set(headerName, headerValue) + } + state.Pop(1) // Pop value, leave key for next iteration + } + } + state.Pop(1) // Pop headers table + + // Get timeout + state.GetField(4, "timeout") + if state.IsNumber(-1) { + requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second + + // Apply max timeout if configured + if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout { + timeout = config.MaxTimeout + } else { + timeout = requestTimeout + } + } + state.Pop(1) // Pop timeout + + // Set content type for POST/PUT if body is present and content-type not manually set + if (method == "POST" || method == "PUT") && bodyReader != nil && req.Header.Get("Content-Type") == "" { + // Check if options specify content type + state.GetField(4, "content_type") + if state.IsString(-1) { + req.Header.Set("Content-Type", state.ToString(-1)) + } else { + // Default to JSON if body is a table, otherwise plain text + if state.IsTable(3) { + req.Header.Set("Content-Type", "application/json") + } else { + req.Header.Set("Content-Type", "text/plain") + } + } + state.Pop(1) // Pop content_type + } + + // Process query parameters + state.GetField(4, "query") + if state.IsTable(-1) { + q := req.URL.Query() + + // Iterate through query params + state.PushNil() // Start iteration + for state.Next(-2) { + // Stack now has key at -2 and value at -1 + if state.IsString(-2) { + paramName := state.ToString(-2) + + // Handle different value types + if state.IsString(-1) { + q.Add(paramName, state.ToString(-1)) + } else if state.IsNumber(-1) { + q.Add(paramName, strings.TrimRight(strings.TrimRight( + state.ToString(-1), "0"), ".")) + } else if state.IsBoolean(-1) { + if state.ToBoolean(-1) { + q.Add(paramName, "true") + } else { + q.Add(paramName, "false") + } + } + } + state.Pop(1) // Pop value, leave key for next iteration + } + + req.URL.RawQuery = q.Encode() + } + state.Pop(1) // Pop query table + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Use context with request + req = req.WithContext(ctx) + + // Execute request + resp, err := defaultClient.Do(req) + if err != nil { + state.PushString("Request failed: " + err.Error()) + return -1 + } + defer resp.Body.Close() + + // Apply size limits to response + var respBody []byte + if config.MaxResponseSize > 0 { + // Limit the response body size + respBody, err = io.ReadAll(io.LimitReader(resp.Body, config.MaxResponseSize)) + } else { + respBody, err = io.ReadAll(resp.Body) + } + + if err != nil { + state.PushString("Failed to read response: " + err.Error()) + return -1 + } + + // Create response table + state.NewTable() + + // Set status code + state.PushNumber(float64(resp.StatusCode)) + state.SetField(-2, "status") + + // Set status text + state.PushString(resp.Status) + state.SetField(-2, "status_text") + + // Set body + state.PushString(string(respBody)) + state.SetField(-2, "body") + + // Parse body as JSON if content type is application/json + if strings.Contains(resp.Header.Get("Content-Type"), "application/json") { + var jsonData any + if err := json.Unmarshal(respBody, &jsonData); err == nil { + if err := state.PushValue(jsonData); err == nil { + state.SetField(-2, "json") + } + } + } + + // Set headers + state.NewTable() + for name, values := range resp.Header { + if len(values) == 1 { + state.PushString(values[0]) + } else { + // Create array of values + state.NewTable() + for i, v := range values { + state.PushNumber(float64(i + 1)) + state.PushString(v) + state.SetTable(-3) + } + } + state.SetField(-2, name) + } + state.SetField(-2, "headers") + + // Create ok field (true if status code is 2xx) + state.PushBoolean(resp.StatusCode >= 200 && resp.StatusCode < 300) + state.SetField(-2, "ok") + + return 1 +} + // LuaHTTPModule is the pure Lua implementation of the HTTP module const LuaHTTPModule = ` -- Table to store response data @@ -82,13 +392,109 @@ local http = { resp.headers = resp.headers or {} resp.headers[name] = value __http_responses[1] = resp - end -} + end, --- Set content type; set_header helper -http.set_content_type = function(content_type) - http.set_header("Content-Type", content_type) -end + -- Set content type; set_header helper + set_content_type = function(content_type) + http.set_header("Content-Type", content_type) + end, + + -- HTTP client submodule + client = { + -- Generic request function + request = function(method, url, body, options) + if type(method) ~= "string" then + error("http.client.request: method must be a string", 2) + end + if type(url) ~= "string" then + error("http.client.request: url must be a string", 2) + end + + -- Call native implementation + return __http_request(method, url, body, options) + end, + + -- Simple GET request + get = function(url, options) + return http.client.request("GET", url, nil, options) + end, + + -- Simple POST request with automatic content-type + post = function(url, body, options) + options = options or {} + return http.client.request("POST", url, body, options) + end, + + -- Simple PUT request with automatic content-type + put = function(url, body, options) + options = options or {} + return http.client.request("PUT", url, body, options) + end, + + -- Simple DELETE request + delete = function(url, options) + return http.client.request("DELETE", url, nil, options) + end, + + -- Simple PATCH request + patch = function(url, body, options) + options = options or {} + return http.client.request("PATCH", url, body, options) + end, + + -- Simple HEAD request + head = function(url, options) + options = options or {} + local old_options = options + options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query} + local response = http.client.request("HEAD", url, nil, options) + return response + end, + + -- Simple OPTIONS request + options = function(url, options) + return http.client.request("OPTIONS", url, nil, options) + end, + + -- Shorthand function to directly get JSON + get_json = function(url, options) + options = options or {} + local response = http.client.get(url, options) + if response.ok and response.json then + return response.json + end + return nil, response + end, + + -- Utility to build a URL with query parameters + build_url = function(base_url, params) + if not params or type(params) ~= "table" then + return base_url + end + + local query = {} + for k, v in pairs(params) do + if type(v) == "table" then + for _, item in ipairs(v) do + table.insert(query, k .. "=" .. tostring(item)) + end + else + table.insert(query, k .. "=" .. tostring(v)) + end + end + + if #query > 0 then + if base_url:find("?") then + return base_url .. "&" .. table.concat(query, "&") + else + return base_url .. "?" .. table.concat(query, "&") + end + end + + return base_url + end + } +} -- Install HTTP module _G.http = http @@ -115,8 +521,53 @@ end // HTTPModuleInitFunc returns an initializer function for the HTTP module func HTTPModuleInitFunc() StateInitFunc { return func(state *luajit.State) error { + // First, unregister any existing function to prevent registry leaks + state.UnregisterGoFunction(httpRequestFuncName) + + // Register the native __http_request function + if err := state.RegisterGoFunction(httpRequestFuncName, httpRequest); err != nil { + return err + } + // Initialize pure Lua HTTP module - return state.DoString(LuaHTTPModule) + if err := state.DoString(LuaHTTPModule); err != nil { + return err + } + + // Check for existing config (in sandbox modules) + state.GetGlobal("__sandbox_modules") + if !state.IsNil(-1) && state.IsTable(-1) { + state.PushString("__http_client_config") + state.GetTable(-2) + + if !state.IsNil(-1) && state.IsTable(-1) { + // Use the config from sandbox modules + state.SetGlobal("__http_client_config") + state.Pop(1) // Pop the sandbox modules table + return nil + } + state.Pop(1) // Pop the nil or non-table value + } + state.Pop(1) // Pop the nil or sandbox modules table + + // Setup default configuration if no custom config exists + state.NewTable() + + state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second)) + state.SetField(-2, "max_timeout") + + state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second)) + state.SetField(-2, "default_timeout") + + state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize)) + state.SetField(-2, "max_response_size") + + state.PushBoolean(DefaultHTTPClientConfig.AllowRemote) + state.SetField(-2, "allow_remote") + + state.SetGlobal("__http_client_config") + + return nil } } @@ -190,3 +641,27 @@ func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) { return response, true } + +// WithHTTPClientConfig creates a runner option to configure the HTTP client +func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption { + return func(r *LuaRunner) { + // Store the config to be applied during initialization + r.AddModule("__http_client_config", map[string]any{ + "max_timeout": float64(config.MaxTimeout / time.Second), + "default_timeout": float64(config.DefaultTimeout / time.Second), + "max_response_size": float64(config.MaxResponseSize), + "allow_remote": config.AllowRemote, + }) + } +} + +// RestrictHTTPToLocalhost is a convenience function to restrict HTTP client +// to localhost connections only +func RestrictHTTPToLocalhost() RunnerOption { + return WithHTTPClientConfig(HTTPClientConfig{ + MaxTimeout: DefaultHTTPClientConfig.MaxTimeout, + DefaultTimeout: DefaultHTTPClientConfig.DefaultTimeout, + MaxResponseSize: DefaultHTTPClientConfig.MaxResponseSize, + AllowRemote: false, + }) +} diff --git a/core/runner/HttpClient.go b/core/runner/HttpClient.go deleted file mode 100644 index 66d919a..0000000 --- a/core/runner/HttpClient.go +++ /dev/null @@ -1,499 +0,0 @@ -package runner - -import ( - "bytes" - "context" - "encoding/json" - "io" - "net/http" - "net/url" - "strings" - "time" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// Default HTTP client with sensible timeout -var defaultClient = &http.Client{ - Timeout: 30 * time.Second, -} - -// HTTPClientConfig contains client settings -type HTTPClientConfig struct { - // Maximum timeout for requests (0 = no limit) - MaxTimeout time.Duration - // Default request timeout - DefaultTimeout time.Duration - // Maximum response size in bytes (0 = no limit) - MaxResponseSize int64 - // Whether to allow remote connections - AllowRemote bool -} - -// DefaultHTTPClientConfig provides sensible defaults -var DefaultHTTPClientConfig = HTTPClientConfig{ - MaxTimeout: 60 * time.Second, - DefaultTimeout: 30 * time.Second, - MaxResponseSize: 10 * 1024 * 1024, // 10MB - AllowRemote: true, -} - -// Function name constant to ensure consistency -const httpRequestFuncName = "__http_request" - -// httpRequest makes an HTTP request and returns the result to Lua -func httpRequest(state *luajit.State) int { - // Get method (required) - if !state.IsString(1) { - state.PushString("http_client.request: method must be a string") - return -1 - } - method := strings.ToUpper(state.ToString(1)) - - // Get URL (required) - if !state.IsString(2) { - state.PushString("http_client.request: url must be a string") - return -1 - } - urlStr := state.ToString(2) - - // Parse URL to check if it's valid and if it's allowed - parsedURL, err := url.Parse(urlStr) - if err != nil { - state.PushString("Invalid URL: " + err.Error()) - return -1 - } - - // Get client configuration from registry (if available) - var config HTTPClientConfig = DefaultHTTPClientConfig - state.GetGlobal("__http_client_config") - if !state.IsNil(-1) { - if state.IsTable(-1) { - // Extract max timeout - state.GetField(-1, "max_timeout") - if state.IsNumber(-1) { - config.MaxTimeout = time.Duration(state.ToNumber(-1)) * time.Second - } - state.Pop(1) - - // Extract default timeout - state.GetField(-1, "default_timeout") - if state.IsNumber(-1) { - config.DefaultTimeout = time.Duration(state.ToNumber(-1)) * time.Second - } - state.Pop(1) - - // Extract max response size - state.GetField(-1, "max_response_size") - if state.IsNumber(-1) { - config.MaxResponseSize = int64(state.ToNumber(-1)) - } - state.Pop(1) - - // Extract allow remote - state.GetField(-1, "allow_remote") - if state.IsBoolean(-1) { - config.AllowRemote = state.ToBoolean(-1) - } - state.Pop(1) - } - } - state.Pop(1) - - // Check if remote connections are allowed - if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") { - state.PushString("Remote connections are not allowed") - return -1 - } - - // Get body (optional) - var bodyReader io.Reader - if state.GetTop() >= 3 && !state.IsNil(3) { - var body []byte - - if state.IsString(3) { - // String body - body = []byte(state.ToString(3)) - } else if state.IsTable(3) { - // Table body - convert to JSON - luaTable, err := state.ToTable(3) - if err != nil { - state.PushString("Failed to parse body table: " + err.Error()) - return -1 - } - - body, err = json.Marshal(luaTable) - if err != nil { - state.PushString("Failed to convert body to JSON: " + err.Error()) - return -1 - } - } else { - state.PushString("Body must be a string or table") - return -1 - } - - bodyReader = bytes.NewReader(body) - } - - // Create request - req, err := http.NewRequest(method, urlStr, bodyReader) - if err != nil { - state.PushString("Failed to create request: " + err.Error()) - return -1 - } - - // Set default headers - req.Header.Set("User-Agent", "Moonshark/1.0") - - // Process options (headers, timeout, etc.) - timeout := config.DefaultTimeout - if state.GetTop() >= 4 && !state.IsNil(4) { - if !state.IsTable(4) { - state.PushString("Options must be a table") - return -1 - } - - // Process headers - state.GetField(4, "headers") - if state.IsTable(-1) { - // Iterate through headers - state.PushNil() // Start iteration - for state.Next(-2) { - // Stack now has key at -2 and value at -1 - if state.IsString(-2) && state.IsString(-1) { - headerName := state.ToString(-2) - headerValue := state.ToString(-1) - req.Header.Set(headerName, headerValue) - } - state.Pop(1) // Pop value, leave key for next iteration - } - } - state.Pop(1) // Pop headers table - - // Get timeout - state.GetField(4, "timeout") - if state.IsNumber(-1) { - requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second - - // Apply max timeout if configured - if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout { - timeout = config.MaxTimeout - } else { - timeout = requestTimeout - } - } - state.Pop(1) // Pop timeout - - // Set content type for POST/PUT if body is present and content-type not manually set - if (method == "POST" || method == "PUT") && bodyReader != nil && req.Header.Get("Content-Type") == "" { - // Check if options specify content type - state.GetField(4, "content_type") - if state.IsString(-1) { - req.Header.Set("Content-Type", state.ToString(-1)) - } else { - // Default to JSON if body is a table, otherwise plain text - if state.IsTable(3) { - req.Header.Set("Content-Type", "application/json") - } else { - req.Header.Set("Content-Type", "text/plain") - } - } - state.Pop(1) // Pop content_type - } - - // Process query parameters - state.GetField(4, "query") - if state.IsTable(-1) { - q := req.URL.Query() - - // Iterate through query params - state.PushNil() // Start iteration - for state.Next(-2) { - // Stack now has key at -2 and value at -1 - if state.IsString(-2) { - paramName := state.ToString(-2) - - // Handle different value types - if state.IsString(-1) { - q.Add(paramName, state.ToString(-1)) - } else if state.IsNumber(-1) { - q.Add(paramName, strings.TrimRight(strings.TrimRight( - state.ToString(-1), "0"), ".")) - } else if state.IsBoolean(-1) { - if state.ToBoolean(-1) { - q.Add(paramName, "true") - } else { - q.Add(paramName, "false") - } - } - } - state.Pop(1) // Pop value, leave key for next iteration - } - - req.URL.RawQuery = q.Encode() - } - state.Pop(1) // Pop query table - } - - // Create context with timeout - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - // Use context with request - req = req.WithContext(ctx) - - // Execute request - resp, err := defaultClient.Do(req) - if err != nil { - state.PushString("Request failed: " + err.Error()) - return -1 - } - defer resp.Body.Close() - - // Apply size limits to response - var respBody []byte - if config.MaxResponseSize > 0 { - // Limit the response body size - respBody, err = io.ReadAll(io.LimitReader(resp.Body, config.MaxResponseSize)) - } else { - respBody, err = io.ReadAll(resp.Body) - } - - if err != nil { - state.PushString("Failed to read response: " + err.Error()) - return -1 - } - - // Create response table - state.NewTable() - - // Set status code - state.PushNumber(float64(resp.StatusCode)) - state.SetField(-2, "status") - - // Set status text - state.PushString(resp.Status) - state.SetField(-2, "status_text") - - // Set body - state.PushString(string(respBody)) - state.SetField(-2, "body") - - // Parse body as JSON if content type is application/json - if strings.Contains(resp.Header.Get("Content-Type"), "application/json") { - var jsonData any - if err := json.Unmarshal(respBody, &jsonData); err == nil { - if err := state.PushValue(jsonData); err == nil { - state.SetField(-2, "json") - } - } - } - - // Set headers - state.NewTable() - for name, values := range resp.Header { - if len(values) == 1 { - state.PushString(values[0]) - } else { - // Create array of values - state.NewTable() - for i, v := range values { - state.PushNumber(float64(i + 1)) - state.PushString(v) - state.SetTable(-3) - } - } - state.SetField(-2, name) - } - state.SetField(-2, "headers") - - // Create ok field (true if status code is 2xx) - state.PushBoolean(resp.StatusCode >= 200 && resp.StatusCode < 300) - state.SetField(-2, "ok") - - return 1 -} - -// LuaHTTPClientModule defines the HTTP client module in Lua -const LuaHTTPClientModule = ` --- HTTP client module implementation -local http_client = { - -- Generic request function - request = function(method, url, body, options) - if type(method) ~= "string" then - error("http_client.request: method must be a string", 2) - end - if type(url) ~= "string" then - error("http_client.request: url must be a string", 2) - end - - -- Call native implementation - return __http_request(method, url, body, options) - end, - - -- Simple GET request - get = function(url, options) - return http_client.request("GET", url, nil, options) - end, - - -- Simple POST request with automatic content-type - post = function(url, body, options) - options = options or {} - return http_client.request("POST", url, body, options) - end, - - -- Simple PUT request with automatic content-type - put = function(url, body, options) - options = options or {} - return http_client.request("PUT", url, body, options) - end, - - -- Simple DELETE request - delete = function(url, options) - return http_client.request("DELETE", url, nil, options) - end, - - -- Simple PATCH request - patch = function(url, body, options) - options = options or {} - return http_client.request("PATCH", url, body, options) - end, - - -- Simple HEAD request - head = function(url, options) - options = options or {} - local old_options = options - options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query} - local response = http_client.request("HEAD", url, nil, options) - return response - end, - - -- Simple OPTIONS request - options = function(url, options) - return http_client.request("OPTIONS", url, nil, options) - end, - - -- Shorthand function to directly get JSON - get_json = function(url, options) - options = options or {} - local response = http_client.get(url, options) - if response.ok and response.json then - return response.json - end - return nil, response - end, - - -- Utility to build a URL with query parameters - build_url = function(base_url, params) - if not params or type(params) ~= "table" then - return base_url - end - - local query = {} - for k, v in pairs(params) do - if type(v) == "table" then - for _, item in ipairs(v) do - table.insert(query, k .. "=" .. tostring(item)) - end - else - table.insert(query, k .. "=" .. tostring(v)) - end - end - - if #query > 0 then - if base_url:find("?") then - return base_url .. "&" .. table.concat(query, "&") - else - return base_url .. "?" .. table.concat(query, "&") - end - end - - return base_url - end -} - --- Install HTTP client module -_G.http_client = http_client - --- Add to sandbox environment -if __env_system and __env_system.base_env then - __env_system.base_env.http_client = http_client -end -` - -// HTTPClientModuleInitFunc returns an initializer for the HTTP client module -func HTTPClientModuleInitFunc() StateInitFunc { - return func(state *luajit.State) error { - // First, unregister any existing function to prevent registry leaks - state.UnregisterGoFunction(httpRequestFuncName) - - // Register the native __http_request function - if err := state.RegisterGoFunction(httpRequestFuncName, httpRequest); err != nil { - return err - } - - // Install the pure Lua module - if err := state.DoString(LuaHTTPClientModule); err != nil { - return err - } - - // Check for existing config (in sandbox modules) - state.GetGlobal("__sandbox_modules") - if !state.IsNil(-1) && state.IsTable(-1) { - state.PushString("__http_client_config") - state.GetTable(-2) - - if !state.IsNil(-1) && state.IsTable(-1) { - // Use the config from sandbox modules - state.SetGlobal("__http_client_config") - state.Pop(1) // Pop the sandbox modules table - return nil - } - state.Pop(1) // Pop the nil or non-table value - } - state.Pop(1) // Pop the nil or sandbox modules table - - // Setup default configuration if no custom config exists - state.NewTable() - - state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second)) - state.SetField(-2, "max_timeout") - - state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second)) - state.SetField(-2, "default_timeout") - - state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize)) - state.SetField(-2, "max_response_size") - - state.PushBoolean(DefaultHTTPClientConfig.AllowRemote) - state.SetField(-2, "allow_remote") - - state.SetGlobal("__http_client_config") - - return nil - } -} - -// WithHTTPClientConfig creates a runner option to configure the HTTP client -func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption { - return func(r *LuaRunner) { - // Store the config to be applied during initialization - r.AddModule("__http_client_config", map[string]any{ - "max_timeout": float64(config.MaxTimeout / time.Second), - "default_timeout": float64(config.DefaultTimeout / time.Second), - "max_response_size": float64(config.MaxResponseSize), - "allow_remote": config.AllowRemote, - }) - } -} - -// RestrictHTTPToLocalhost is a convenience function to restrict HTTP client -// to localhost connections only -func RestrictHTTPToLocalhost() RunnerOption { - return WithHTTPClientConfig(HTTPClientConfig{ - MaxTimeout: DefaultHTTPClientConfig.MaxTimeout, - DefaultTimeout: DefaultHTTPClientConfig.DefaultTimeout, - MaxResponseSize: DefaultHTTPClientConfig.MaxResponseSize, - AllowRemote: false, - }) -} diff --git a/core/runner/Sandbox.go b/core/runner/Sandbox.go index 263ffbc..106fac3 100644 --- a/core/runner/Sandbox.go +++ b/core/runner/Sandbox.go @@ -83,7 +83,7 @@ func (s *Sandbox) Setup(state *luajit.State) error { base.http = http base.cookie = cookie - base.http_client = http_client + -- http_client module is now part of http.client -- Add registered custom modules if __sandbox_modules then