316 lines
8.4 KiB
Go
316 lines
8.4 KiB
Go
package runner
|
|
|
|
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/goccy/go-json"
	"github.com/valyala/bytebufferpool"
	"github.com/valyala/fasthttp"

	"Moonshark/utils/logger"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
|
|
|
|
// Default HTTP client with sensible timeout
|
|
var defaultFastClient = fasthttp.Client{
|
|
MaxConnsPerHost: 1024,
|
|
MaxIdleConnDuration: time.Minute,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
// HTTPClientConfig holds the settings that govern outgoing HTTP requests.
// Field order is part of the type's API (positional literals depend on it).
type HTTPClientConfig struct {
	// MaxTimeout is the largest per-request timeout allowed; 0 means no limit.
	MaxTimeout time.Duration
	// DefaultTimeout is used when a request does not specify its own timeout.
	DefaultTimeout time.Duration
	// MaxResponseSize is the maximum response size in bytes; 0 means no limit.
	MaxResponseSize int64
	// AllowRemote reports whether connections to remote hosts are permitted.
	AllowRemote bool
}
|
|
|
|
// DefaultHTTPClientConfig provides sensible defaults
|
|
var DefaultHTTPClientConfig = HTTPClientConfig{
|
|
MaxTimeout: 60 * time.Second,
|
|
DefaultTimeout: 30 * time.Second,
|
|
MaxResponseSize: 10 * 1024 * 1024, // 10MB
|
|
AllowRemote: true,
|
|
}
|
|
|
|
// ApplyResponse applies a Response to a fasthttp.RequestCtx
|
|
func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
|
|
// Set status code
|
|
ctx.SetStatusCode(resp.Status)
|
|
|
|
// Set headers
|
|
for name, value := range resp.Headers {
|
|
ctx.Response.Header.Set(name, value)
|
|
}
|
|
|
|
// Set cookies
|
|
for _, cookie := range resp.Cookies {
|
|
ctx.Response.Header.SetCookie(cookie)
|
|
}
|
|
|
|
// Process the body based on its type
|
|
if resp.Body == nil {
|
|
return
|
|
}
|
|
|
|
// Get a buffer from the pool
|
|
buf := bytebufferpool.Get()
|
|
defer bytebufferpool.Put(buf)
|
|
|
|
// Set body based on type
|
|
switch body := resp.Body.(type) {
|
|
case string:
|
|
ctx.SetBodyString(body)
|
|
case []byte:
|
|
ctx.SetBody(body)
|
|
case map[string]any, []any, []float64, []string, []int:
|
|
// Marshal JSON
|
|
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
|
// Set content type if not already set
|
|
if len(ctx.Response.Header.ContentType()) == 0 {
|
|
ctx.Response.Header.SetContentType("application/json")
|
|
}
|
|
ctx.SetBody(buf.Bytes())
|
|
} else {
|
|
// Fallback
|
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
|
}
|
|
default:
|
|
// Default to string representation
|
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
|
}
|
|
}
|
|
|
|
// httpRequest makes an HTTP request and returns the result to Lua
|
|
func httpRequest(state *luajit.State) int {
|
|
if err := state.CheckMinArgs(2); err != nil {
|
|
return state.PushError("http.client.request: %v", err)
|
|
}
|
|
|
|
// Get method and URL
|
|
method, err := state.SafeToString(1)
|
|
if err != nil {
|
|
return state.PushError("http.client.request: method must be string")
|
|
}
|
|
method = strings.ToUpper(method)
|
|
|
|
urlStr, err := state.SafeToString(2)
|
|
if err != nil {
|
|
return state.PushError("http.client.request: url must be string")
|
|
}
|
|
|
|
// Parse URL to check if it's valid
|
|
parsedURL, err := url.Parse(urlStr)
|
|
if err != nil {
|
|
return state.PushError("Invalid URL: %v", err)
|
|
}
|
|
|
|
// Get client configuration
|
|
config := DefaultHTTPClientConfig
|
|
|
|
// Check if remote connections are allowed
|
|
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
|
|
return state.PushError("Remote connections are not allowed")
|
|
}
|
|
|
|
// Use bytebufferpool for request and response
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set up request
|
|
req.Header.SetMethod(method)
|
|
req.SetRequestURI(urlStr)
|
|
req.Header.Set("User-Agent", "Moonshark/1.0")
|
|
|
|
// Get body (optional)
|
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
|
if state.IsString(3) {
|
|
// String body
|
|
bodyStr, _ := state.SafeToString(3)
|
|
req.SetBodyString(bodyStr)
|
|
} else if state.IsTable(3) {
|
|
// Table body - convert to JSON
|
|
luaTable, err := state.SafeToTable(3)
|
|
if err != nil {
|
|
return state.PushError("Failed to parse body table: %v", err)
|
|
}
|
|
|
|
// Use bytebufferpool for JSON serialization
|
|
buf := bytebufferpool.Get()
|
|
defer bytebufferpool.Put(buf)
|
|
|
|
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
|
|
return state.PushError("Failed to convert body to JSON: %v", err)
|
|
}
|
|
|
|
req.SetBody(buf.Bytes())
|
|
req.Header.SetContentType("application/json")
|
|
} else {
|
|
return state.PushError("Body must be a string or table")
|
|
}
|
|
}
|
|
|
|
// Process options (headers, timeout, etc.)
|
|
timeout := config.DefaultTimeout
|
|
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
|
|
// Process headers using ForEachTableKV
|
|
if headers, ok := state.GetFieldTable(4, "headers"); ok {
|
|
if headerMap, ok := headers.(map[string]string); ok {
|
|
for name, value := range headerMap {
|
|
req.Header.Set(name, value)
|
|
}
|
|
} else if headerMapAny, ok := headers.(map[string]any); ok {
|
|
for name, value := range headerMapAny {
|
|
if valueStr, ok := value.(string); ok {
|
|
req.Header.Set(name, valueStr)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Get timeout
|
|
if timeoutVal := state.GetFieldNumber(4, "timeout", 0); timeoutVal > 0 {
|
|
requestTimeout := time.Duration(timeoutVal) * time.Second
|
|
|
|
// Apply max timeout if configured
|
|
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
|
|
timeout = config.MaxTimeout
|
|
} else {
|
|
timeout = requestTimeout
|
|
}
|
|
}
|
|
|
|
// Process query parameters
|
|
if query, ok := state.GetFieldTable(4, "query"); ok {
|
|
args := req.URI().QueryArgs()
|
|
|
|
if queryMap, ok := query.(map[string]string); ok {
|
|
for name, value := range queryMap {
|
|
args.Add(name, value)
|
|
}
|
|
} else if queryMapAny, ok := query.(map[string]any); ok {
|
|
for name, value := range queryMapAny {
|
|
switch v := value.(type) {
|
|
case string:
|
|
args.Add(name, v)
|
|
case int:
|
|
args.Add(name, fmt.Sprintf("%d", v))
|
|
case float64:
|
|
args.Add(name, strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", v), "0"), "."))
|
|
case bool:
|
|
if v {
|
|
args.Add(name, "true")
|
|
} else {
|
|
args.Add(name, "false")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create context with timeout
|
|
_, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
|
|
// Execute request
|
|
err = defaultFastClient.DoTimeout(req, resp, timeout)
|
|
if err != nil {
|
|
errStr := "Request failed: " + err.Error()
|
|
if errors.Is(err, fasthttp.ErrTimeout) {
|
|
errStr = "Request timed out after " + timeout.String()
|
|
}
|
|
return state.PushError("%s", errStr)
|
|
}
|
|
|
|
// Create response using TableBuilder
|
|
builder := state.NewTableBuilder()
|
|
|
|
// Set status code and text
|
|
builder.SetNumber("status", float64(resp.StatusCode()))
|
|
builder.SetString("status_text", fasthttp.StatusMessage(resp.StatusCode()))
|
|
|
|
// Set body
|
|
var respBody []byte
|
|
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
|
|
// Make a limited copy
|
|
respBody = make([]byte, config.MaxResponseSize)
|
|
copy(respBody, resp.Body())
|
|
} else {
|
|
respBody = resp.Body()
|
|
}
|
|
|
|
builder.SetString("body", string(respBody))
|
|
|
|
// Parse body as JSON if content type is application/json
|
|
contentType := string(resp.Header.ContentType())
|
|
if strings.Contains(contentType, "application/json") {
|
|
var jsonData any
|
|
if err := json.Unmarshal(respBody, &jsonData); err == nil {
|
|
builder.SetTable("json", jsonData)
|
|
}
|
|
}
|
|
|
|
// Set headers
|
|
headers := make(map[string]string)
|
|
resp.Header.VisitAll(func(key, value []byte) {
|
|
headers[string(key)] = string(value)
|
|
})
|
|
builder.SetTable("headers", headers)
|
|
|
|
// Create ok field (true if status code is 2xx)
|
|
builder.SetBool("ok", resp.StatusCode() >= 200 && resp.StatusCode() < 300)
|
|
|
|
builder.Build()
|
|
return 1
|
|
}
|
|
|
|
// generateToken creates a cryptographically secure random token
|
|
func generateToken(state *luajit.State) int {
|
|
// Get the length from the Lua arguments (default to 32)
|
|
length := 32
|
|
if state.GetTop() >= 1 {
|
|
if lengthVal, err := state.SafeToNumber(1); err == nil {
|
|
length = int(lengthVal)
|
|
}
|
|
}
|
|
|
|
// Enforce minimum length for security
|
|
if length < 16 {
|
|
length = 16
|
|
}
|
|
|
|
// Generate secure random bytes
|
|
tokenBytes := make([]byte, length)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
logger.Errorf("Failed to generate secure token: %v", err)
|
|
state.PushString("")
|
|
return 1 // Return empty string on error
|
|
}
|
|
|
|
// Encode as base64
|
|
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
|
|
|
|
// Trim to requested length (base64 might be longer)
|
|
if len(token) > length {
|
|
token = token[:length]
|
|
}
|
|
|
|
// Push the token to the Lua stack
|
|
state.PushString(token)
|
|
return 1 // One return value
|
|
}
|