package lualibs import ( "context" "errors" "fmt" "net/url" "strings" "time" "github.com/goccy/go-json" "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) // Default HTTP client with sensible timeout var defaultFastClient = fasthttp.Client{ MaxConnsPerHost: 1024, MaxIdleConnDuration: time.Minute, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, DisableHeaderNamesNormalizing: true, } // HTTPClientConfig contains client settings type HTTPClientConfig struct { MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit) DefaultTimeout time.Duration // Default request timeout MaxResponseSize int64 // Maximum response size in bytes (0 = no limit) AllowRemote bool // Whether to allow remote connections } // DefaultHTTPClientConfig provides sensible defaults var DefaultHTTPClientConfig = HTTPClientConfig{ MaxTimeout: 60 * time.Second, DefaultTimeout: 30 * time.Second, MaxResponseSize: 10 * 1024 * 1024, // 10MB AllowRemote: true, } // RegisterHttpFunctions registers HTTP functions with the Lua state func RegisterHttpFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil { return err } return nil } // httpRequest makes an HTTP request and returns the result to Lua func httpRequest(state *luajit.State) int { if err := state.CheckMinArgs(2); err != nil { return state.PushError("http.client.request: %v", err) } // Get method and URL method, err := state.SafeToString(1) if err != nil { return state.PushError("http.client.request: method must be string") } method = strings.ToUpper(method) urlStr, err := state.SafeToString(2) if err != nil { return state.PushError("http.client.request: url must be string") } // Parse URL to check if it's valid parsedURL, err := url.Parse(urlStr) if err != nil { return state.PushError("Invalid URL: %v", err) } // Get client configuration config := DefaultHTTPClientConfig // Check if remote connections are allowed if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") { return state.PushError("Remote connections are not allowed") } // Use bytebufferpool for request and response req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) // Set up request req.Header.SetMethod(method) req.SetRequestURI(urlStr) req.Header.Set("User-Agent", "Moonshark/1.0") // Get body (optional) if state.GetTop() >= 3 && !state.IsNil(3) { if state.IsString(3) { // String body bodyStr, _ := state.SafeToString(3) req.SetBodyString(bodyStr) } else if state.IsTable(3) { // Table body - convert to JSON luaTable, err := state.SafeToTable(3) if err != nil { return state.PushError("Failed to parse body table: %v", err) } // Use bytebufferpool for JSON serialization buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) if err := json.NewEncoder(buf).Encode(luaTable); err != nil { return state.PushError("Failed to convert body to JSON: %v", err) } req.SetBody(buf.Bytes()) req.Header.SetContentType("application/json") } else { return state.PushError("Body must be a string or table") } } // Process options (headers, timeout, etc.) timeout := config.DefaultTimeout if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) { // Process headers using ForEachTableKV if headers, ok := state.GetFieldTable(4, "headers"); ok { if headerMap, ok := headers.(map[string]string); ok { for name, value := range headerMap { req.Header.Set(name, value) } } else if headerMapAny, ok := headers.(map[string]any); ok { for name, value := range headerMapAny { if valueStr, ok := value.(string); ok { req.Header.Set(name, valueStr) } } } } // Get timeout if timeoutVal := state.GetFieldNumber(4, "timeout", 0); timeoutVal > 0 { requestTimeout := time.Duration(timeoutVal) * time.Second // Apply max timeout if configured if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout { timeout = config.MaxTimeout } else { timeout = requestTimeout } } // Process query parameters if query, ok := state.GetFieldTable(4, "query"); ok { args := req.URI().QueryArgs() if queryMap, ok := query.(map[string]string); ok { for name, value := range queryMap { args.Add(name, value) } } else if queryMapAny, ok := query.(map[string]any); ok { for name, value := range queryMapAny { switch v := value.(type) { case string: args.Add(name, v) case int: args.Add(name, fmt.Sprintf("%d", v)) case float64: args.Add(name, strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", v), "0"), ".")) case bool: if v { args.Add(name, "true") } else { args.Add(name, "false") } } } } } } // Create context with timeout _, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() // Execute request err = defaultFastClient.DoTimeout(req, resp, timeout) if err != nil { errStr := "Request failed: " + err.Error() if errors.Is(err, fasthttp.ErrTimeout) { errStr = "Request timed out after " + timeout.String() } return state.PushError("%s", errStr) } // Create response using TableBuilder builder := state.NewTableBuilder() // Set status code and text builder.SetNumber("status", float64(resp.StatusCode())) builder.SetString("status_text", fasthttp.StatusMessage(resp.StatusCode())) // Set body var respBody []byte if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize { // Make a limited copy respBody = make([]byte, config.MaxResponseSize) copy(respBody, resp.Body()) } else { respBody = resp.Body() } builder.SetString("body", string(respBody)) // Parse body as JSON if content type is application/json contentType := string(resp.Header.ContentType()) if strings.Contains(contentType, "application/json") { var jsonData any if err := json.Unmarshal(respBody, &jsonData); err == nil { builder.SetTable("json", jsonData) } } // Set headers headers := make(map[string]string) resp.Header.VisitAll(func(key, value []byte) { headers[string(key)] = string(value) }) builder.SetTable("headers", headers) // Create ok field (true if status code is 2xx) builder.SetBool("ok", resp.StatusCode() >= 200 && resp.StatusCode() < 300) builder.Build() return 1 }