228 lines
6.2 KiB
Go

package lualibs
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"time"
"github.com/goccy/go-json"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// defaultFastClient is the single shared fasthttp client that httpRequest
// sends every request through. Header names are passed through as written
// (normalization disabled), and socket reads and writes each get a 30s timeout.
var defaultFastClient = fasthttp.Client{
	DisableHeaderNamesNormalizing: true,
	MaxConnsPerHost:               1024,
	MaxIdleConnDuration:           60 * time.Second,
	ReadTimeout:                   30 * time.Second,
	WriteTimeout:                  30 * time.Second,
}
// HTTPClientConfig groups the limits applied to outgoing HTTP requests made
// from Lua.
type HTTPClientConfig struct {
	// MaxTimeout is the ceiling for a per-request timeout; zero means uncapped.
	MaxTimeout time.Duration
	// DefaultTimeout applies when a request does not specify its own timeout.
	DefaultTimeout time.Duration
	// MaxResponseSize is the response size limit in bytes; zero means unlimited.
	MaxResponseSize int64
	// AllowRemote controls whether hosts other than the local machine may be contacted.
	AllowRemote bool
}
// DefaultHTTPClientConfig is the configuration httpRequest uses: requests
// default to a 30s timeout, may ask for at most 60s, have response bodies
// limited to 10 MiB, and may reach remote hosts.
var DefaultHTTPClientConfig = HTTPClientConfig{
	DefaultTimeout:  30 * time.Second,
	MaxTimeout:      time.Minute,
	MaxResponseSize: 10 << 20, // 10 MiB
	AllowRemote:     true,
}
// RegisterHttpFunctions makes the HTTP client callable from Lua by
// registering httpRequest on state under the name "__http_request".
// It returns whatever error the registration reports.
func RegisterHttpFunctions(state *luajit.State) error {
	return state.RegisterGoFunction("__http_request", httpRequest)
}
// httpRequest implements the Lua function
// __http_request(method, url [, body [, options]]).
//
// Lua arguments:
//
//	1 method  string; upper-cased before it is sent.
//	2 url     string; parsed with net/url to validate it and, when remote
//	          connections are disabled, to check its host.
//	3 body    optional; a string is sent verbatim, a table is JSON-encoded
//	          and Content-Type is set to application/json. Any other
//	          non-nil value is an error.
//	4 options optional table:
//	            headers  name -> value; string, integer, number and boolean
//	                     values are sent, other value types are skipped.
//	            timeout  seconds (fractions allowed), capped at
//	                     DefaultHTTPClientConfig.MaxTimeout; defaults to
//	                     DefaultHTTPClientConfig.DefaultTimeout.
//	            query    name -> value pairs added to the URL query string
//	                     (same value types as headers).
//
// A default User-Agent is set before option headers are applied. On success
// it builds one result table with the fields status, status_text, body
// (truncated to MaxResponseSize bytes), json (present only when the response
// Content-Type contains "application/json" and the body decodes), headers
// and ok (true for 2xx statuses), and returns 1. Every failure is reported
// through state.PushError.
func httpRequest(state *luajit.State) int {
	if err := state.CheckMinArgs(2); err != nil {
		return state.PushError("http.client.request: %v", err)
	}

	method, err := state.SafeToString(1)
	if err != nil {
		return state.PushError("http.client.request: method must be string")
	}
	method = strings.ToUpper(method)

	urlStr, err := state.SafeToString(2)
	if err != nil {
		return state.PushError("http.client.request: url must be string")
	}

	parsedURL, err := url.Parse(urlStr)
	if err != nil {
		return state.PushError("Invalid URL: %v", err)
	}

	config := DefaultHTTPClientConfig
	if !config.AllowRemote && !httpClientIsLoopback(parsedURL.Hostname()) {
		return state.PushError("Remote connections are not allowed")
	}

	// Request/response objects come from fasthttp's pools and go back on return.
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	req.Header.SetMethod(method)
	req.SetRequestURI(urlStr)
	req.Header.Set("User-Agent", "Moonshark/1.0")

	// Optional body (argument 3).
	if state.GetTop() >= 3 && !state.IsNil(3) {
		switch {
		case state.IsString(3):
			// IsString just succeeded, so the conversion error is not checked.
			bodyStr, _ := state.SafeToString(3)
			req.SetBodyString(bodyStr)
		case state.IsTable(3):
			luaTable, err := state.ToTable(3)
			if err != nil {
				return state.PushError("Failed to parse body table: %v", err)
			}
			buf := bytebufferpool.Get()
			defer bytebufferpool.Put(buf)
			if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
				return state.PushError("Failed to convert body to JSON: %v", err)
			}
			req.SetBody(buf.Bytes())
			req.Header.SetContentType("application/json")
		default:
			return state.PushError("Body must be a string or table")
		}
	}

	// Optional options table (argument 4): headers, timeout, query.
	timeout := config.DefaultTimeout
	if state.GetTop() >= 4 && state.IsTable(4) {
		if headers, ok := state.GetFieldTable(4, "headers"); ok {
			if headerMap, ok := headers.(map[string]any); ok {
				for name, value := range headerMap {
					if s, ok := httpClientValueString(value); ok {
						req.Header.Set(name, s)
					}
				}
			}
		}

		if seconds := state.GetFieldNumber(4, "timeout", 0); seconds > 0 {
			// A value too small to represent (0 ns) keeps the default rather
			// than producing an instant timeout.
			if d := httpClientTimeout(float64(seconds), config.MaxTimeout); d > 0 {
				timeout = d
			}
		}

		if query, ok := state.GetFieldTable(4, "query"); ok {
			if queryMap, ok := query.(map[string]any); ok {
				args := req.URI().QueryArgs()
				for name, value := range queryMap {
					if s, ok := httpClientValueString(value); ok {
						args.Add(name, s)
					}
				}
			}
		}
	}

	if err := defaultFastClient.DoTimeout(req, resp, timeout); err != nil {
		// Dial timeouts and deadline errors from custom dialers are timeouts too.
		if errors.Is(err, fasthttp.ErrTimeout) ||
			errors.Is(err, fasthttp.ErrDialTimeout) ||
			errors.Is(err, context.DeadlineExceeded) {
			return state.PushError("%s", "Request timed out after "+timeout.String())
		}
		return state.PushError("%s", "Request failed: "+err.Error())
	}

	status := resp.StatusCode()
	builder := state.NewTableBuilder()
	builder.SetNumber("status", float64(status))
	builder.SetString("status_text", fasthttp.StatusMessage(status))

	// Truncate oversized bodies; string() below copies, so a re-slice suffices.
	respBody := resp.Body()
	if config.MaxResponseSize > 0 && int64(len(respBody)) > config.MaxResponseSize {
		respBody = respBody[:config.MaxResponseSize]
	}
	builder.SetString("body", string(respBody))

	// Decode JSON bodies; a decode failure (e.g. a truncated body) just omits "json".
	if strings.Contains(string(resp.Header.ContentType()), "application/json") {
		var jsonData any
		if err := json.Unmarshal(respBody, &jsonData); err == nil {
			builder.SetTable("json", jsonData)
		}
	}

	headers := make(map[string]string)
	resp.Header.VisitAll(func(key, value []byte) {
		// A repeated header name (e.g. Set-Cookie) keeps a single value:
		// each later one overwrites the previous map entry.
		headers[string(key)] = string(value)
	})
	builder.SetTable("headers", headers)

	builder.SetBool("ok", status >= 200 && status < 300)
	builder.Build()
	return 1
}

// httpClientValueString renders a scalar Lua value for use as an HTTP header
// or query-parameter value. ok is false for unsupported types, which callers
// skip.
func httpClientValueString(value any) (s string, ok bool) {
	switch v := value.(type) {
	case string:
		return v, true
	case int:
		return fmt.Sprintf("%d", v), true
	case float64:
		// Rounded to 6 decimal places, then trailing zeros and a bare "."
		// are dropped: 1.5 -> "1.5", 2 -> "2".
		return strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", v), "0"), "."), true
	case bool:
		if v {
			return "true", true
		}
		return "false", true
	}
	return "", false
}

// httpClientTimeout converts a timeout in (possibly fractional) seconds to a
// time.Duration. The result is capped at maxTimeout when maxTimeout > 0, and
// at the largest representable Duration otherwise, so huge inputs cannot
// overflow. Inputs below one nanosecond yield 0.
func httpClientTimeout(seconds float64, maxTimeout time.Duration) time.Duration {
	const maxDuration = time.Duration(1<<63 - 1)
	ns := seconds * float64(time.Second)
	switch {
	case maxTimeout > 0 && ns > float64(maxTimeout):
		return maxTimeout
	case ns >= float64(maxDuration):
		return maxDuration
	}
	return time.Duration(ns)
}

// httpClientIsLoopback reports whether host (as returned by url.URL.Hostname,
// i.e. without port or IPv6 brackets) is one of the accepted local-machine
// names: "localhost" in any letter case, "127.0.0.1", or "::1".
func httpClientIsLoopback(host string) bool {
	return strings.EqualFold(host, "localhost") || host == "127.0.0.1" || host == "::1"
}