hyper op 1
This commit is contained in:
parent
95eae40357
commit
f6c260a525
|
@ -1,6 +1,11 @@
|
||||||
package runner
|
package runner
|
||||||
|
|
||||||
import "sync"
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/valyala/bytebufferpool"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
// Context represents execution context for a Lua script
|
// Context represents execution context for a Lua script
|
||||||
type Context struct {
|
type Context struct {
|
||||||
|
@ -9,6 +14,12 @@ type Context struct {
|
||||||
|
|
||||||
// internal mutex for concurrent access
|
// internal mutex for concurrent access
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// FastHTTP context if this was created from an HTTP request
|
||||||
|
RequestCtx *fasthttp.RequestCtx
|
||||||
|
|
||||||
|
// Buffer for efficient string operations
|
||||||
|
buffer *bytebufferpool.ByteBuffer
|
||||||
}
|
}
|
||||||
|
|
||||||
// Context pool to reduce allocations
|
// Context pool to reduce allocations
|
||||||
|
@ -25,6 +36,13 @@ func NewContext() *Context {
|
||||||
return contextPool.Get().(*Context)
|
return contextPool.Get().(*Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewHTTPContext creates a new context from a fasthttp RequestCtx
|
||||||
|
func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context {
|
||||||
|
ctx := NewContext()
|
||||||
|
ctx.RequestCtx = requestCtx
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
// Release returns the context to the pool after clearing its values
|
// Release returns the context to the pool after clearing its values
|
||||||
func (c *Context) Release() {
|
func (c *Context) Release() {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
|
@ -35,9 +53,29 @@ func (c *Context) Release() {
|
||||||
delete(c.Values, k)
|
delete(c.Values, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset request context
|
||||||
|
c.RequestCtx = nil
|
||||||
|
|
||||||
|
// Return buffer to pool if we have one
|
||||||
|
if c.buffer != nil {
|
||||||
|
bytebufferpool.Put(c.buffer)
|
||||||
|
c.buffer = nil
|
||||||
|
}
|
||||||
|
|
||||||
contextPool.Put(c)
|
contextPool.Put(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBuffer returns a byte buffer for efficient string operations
|
||||||
|
func (c *Context) GetBuffer() *bytebufferpool.ByteBuffer {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.buffer == nil {
|
||||||
|
c.buffer = bytebufferpool.Get()
|
||||||
|
}
|
||||||
|
return c.buffer
|
||||||
|
}
|
||||||
|
|
||||||
// Set adds a value to the context
|
// Set adds a value to the context
|
||||||
func (c *Context) Set(key string, value any) {
|
func (c *Context) Set(key string, value any) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
|
@ -83,3 +121,8 @@ func (c *Context) All() map[string]any {
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsHTTPRequest returns true if this context contains a fasthttp RequestCtx
|
||||||
|
func (c *Context) IsHTTPRequest() bool {
|
||||||
|
return c.RequestCtx != nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,16 +1,19 @@
|
||||||
package runner
|
package runner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"errors"
|
||||||
"io"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"github.com/valyala/bytebufferpool"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||||
)
|
)
|
||||||
|
@ -65,8 +68,12 @@ func ReleaseResponse(resp *HTTPResponse) {
|
||||||
// ---------- HTTP CLIENT FUNCTIONALITY ----------
|
// ---------- HTTP CLIENT FUNCTIONALITY ----------
|
||||||
|
|
||||||
// Default HTTP client with sensible timeout
|
// Default HTTP client with sensible timeout
|
||||||
var defaultClient = &http.Client{
|
var defaultFastClient fasthttp.Client = fasthttp.Client{
|
||||||
Timeout: 30 * time.Second,
|
MaxConnsPerHost: 1024,
|
||||||
|
MaxIdleConnDuration: time.Minute,
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
DisableHeaderNamesNormalizing: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPClientConfig contains client settings
|
// HTTPClientConfig contains client settings
|
||||||
|
@ -157,14 +164,22 @@ func httpRequest(state *luajit.State) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get body (optional)
|
// Use bytebufferpool for request and response
|
||||||
var bodyReader io.Reader
|
req := fasthttp.AcquireRequest()
|
||||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
resp := fasthttp.AcquireResponse()
|
||||||
var body []byte
|
defer fasthttp.ReleaseRequest(req)
|
||||||
|
defer fasthttp.ReleaseResponse(resp)
|
||||||
|
|
||||||
|
// Set up request
|
||||||
|
req.Header.SetMethod(method)
|
||||||
|
req.SetRequestURI(urlStr)
|
||||||
|
req.Header.Set("User-Agent", "Moonshark/1.0")
|
||||||
|
|
||||||
|
// Get body (optional)
|
||||||
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||||
if state.IsString(3) {
|
if state.IsString(3) {
|
||||||
// String body
|
// String body
|
||||||
body = []byte(state.ToString(3))
|
req.SetBodyString(state.ToString(3))
|
||||||
} else if state.IsTable(3) {
|
} else if state.IsTable(3) {
|
||||||
// Table body - convert to JSON
|
// Table body - convert to JSON
|
||||||
luaTable, err := state.ToTable(3)
|
luaTable, err := state.ToTable(3)
|
||||||
|
@ -173,29 +188,22 @@ func httpRequest(state *luajit.State) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err = json.Marshal(luaTable)
|
// Use bytebufferpool for JSON serialization
|
||||||
if err != nil {
|
buf := bytebufferpool.Get()
|
||||||
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
|
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
|
||||||
state.PushString("Failed to convert body to JSON: " + err.Error())
|
state.PushString("Failed to convert body to JSON: " + err.Error())
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req.SetBody(buf.Bytes())
|
||||||
} else {
|
} else {
|
||||||
state.PushString("Body must be a string or table")
|
state.PushString("Body must be a string or table")
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyReader = bytes.NewReader(body)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create request
|
|
||||||
req, err := http.NewRequest(method, urlStr, bodyReader)
|
|
||||||
if err != nil {
|
|
||||||
state.PushString("Failed to create request: " + err.Error())
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set default headers
|
|
||||||
req.Header.Set("User-Agent", "Moonshark/1.0")
|
|
||||||
|
|
||||||
// Process options (headers, timeout, etc.)
|
// Process options (headers, timeout, etc.)
|
||||||
timeout := config.DefaultTimeout
|
timeout := config.DefaultTimeout
|
||||||
if state.GetTop() >= 4 && !state.IsNil(4) {
|
if state.GetTop() >= 4 && !state.IsNil(4) {
|
||||||
|
@ -236,7 +244,7 @@ func httpRequest(state *luajit.State) int {
|
||||||
state.Pop(1) // Pop timeout
|
state.Pop(1) // Pop timeout
|
||||||
|
|
||||||
// Set content type for POST/PUT if body is present and content-type not manually set
|
// Set content type for POST/PUT if body is present and content-type not manually set
|
||||||
if (method == "POST" || method == "PUT") && bodyReader != nil && req.Header.Get("Content-Type") == "" {
|
if (method == "POST" || method == "PUT") && req.Body() != nil && req.Header.Peek("Content-Type") == nil {
|
||||||
// Check if options specify content type
|
// Check if options specify content type
|
||||||
state.GetField(4, "content_type")
|
state.GetField(4, "content_type")
|
||||||
if state.IsString(-1) {
|
if state.IsString(-1) {
|
||||||
|
@ -255,7 +263,8 @@ func httpRequest(state *luajit.State) int {
|
||||||
// Process query parameters
|
// Process query parameters
|
||||||
state.GetField(4, "query")
|
state.GetField(4, "query")
|
||||||
if state.IsTable(-1) {
|
if state.IsTable(-1) {
|
||||||
q := req.URL.Query()
|
// Create URL args
|
||||||
|
args := req.URI().QueryArgs()
|
||||||
|
|
||||||
// Iterate through query params
|
// Iterate through query params
|
||||||
state.PushNil() // Start iteration
|
state.PushNil() // Start iteration
|
||||||
|
@ -266,52 +275,36 @@ func httpRequest(state *luajit.State) int {
|
||||||
|
|
||||||
// Handle different value types
|
// Handle different value types
|
||||||
if state.IsString(-1) {
|
if state.IsString(-1) {
|
||||||
q.Add(paramName, state.ToString(-1))
|
args.Add(paramName, state.ToString(-1))
|
||||||
} else if state.IsNumber(-1) {
|
} else if state.IsNumber(-1) {
|
||||||
q.Add(paramName, strings.TrimRight(strings.TrimRight(
|
args.Add(paramName, strings.TrimRight(strings.TrimRight(
|
||||||
state.ToString(-1), "0"), "."))
|
state.ToString(-1), "0"), "."))
|
||||||
} else if state.IsBoolean(-1) {
|
} else if state.IsBoolean(-1) {
|
||||||
if state.ToBoolean(-1) {
|
if state.ToBoolean(-1) {
|
||||||
q.Add(paramName, "true")
|
args.Add(paramName, "true")
|
||||||
} else {
|
} else {
|
||||||
q.Add(paramName, "false")
|
args.Add(paramName, "false")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.Pop(1) // Pop value, leave key for next iteration
|
state.Pop(1) // Pop value, leave key for next iteration
|
||||||
}
|
}
|
||||||
|
|
||||||
req.URL.RawQuery = q.Encode()
|
|
||||||
}
|
}
|
||||||
state.Pop(1) // Pop query table
|
state.Pop(1) // Pop query table
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create context with timeout
|
// Create context with timeout
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
_, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Use context with request
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
|
|
||||||
// Execute request
|
// Execute request
|
||||||
resp, err := defaultClient.Do(req)
|
err = defaultFastClient.DoTimeout(req, resp, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("Request failed: " + err.Error())
|
errStr := "Request failed: " + err.Error()
|
||||||
return -1
|
if errors.Is(err, fasthttp.ErrTimeout) {
|
||||||
|
errStr = "Request timed out after " + timeout.String()
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
state.PushString(errStr)
|
||||||
|
|
||||||
// Apply size limits to response
|
|
||||||
var respBody []byte
|
|
||||||
if config.MaxResponseSize > 0 {
|
|
||||||
// Limit the response body size
|
|
||||||
respBody, err = io.ReadAll(io.LimitReader(resp.Body, config.MaxResponseSize))
|
|
||||||
} else {
|
|
||||||
respBody, err = io.ReadAll(resp.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
state.PushString("Failed to read response: " + err.Error())
|
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -319,19 +312,32 @@ func httpRequest(state *luajit.State) int {
|
||||||
state.NewTable()
|
state.NewTable()
|
||||||
|
|
||||||
// Set status code
|
// Set status code
|
||||||
state.PushNumber(float64(resp.StatusCode))
|
state.PushNumber(float64(resp.StatusCode()))
|
||||||
state.SetField(-2, "status")
|
state.SetField(-2, "status")
|
||||||
|
|
||||||
// Set status text
|
// Set status text
|
||||||
state.PushString(resp.Status)
|
statusText := fasthttp.StatusMessage(resp.StatusCode())
|
||||||
|
state.PushString(statusText)
|
||||||
state.SetField(-2, "status_text")
|
state.SetField(-2, "status_text")
|
||||||
|
|
||||||
// Set body
|
// Set body
|
||||||
|
var respBody []byte
|
||||||
|
|
||||||
|
// Apply size limits to response
|
||||||
|
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
|
||||||
|
// Make a limited copy
|
||||||
|
respBody = make([]byte, config.MaxResponseSize)
|
||||||
|
copy(respBody, resp.Body())
|
||||||
|
} else {
|
||||||
|
respBody = resp.Body()
|
||||||
|
}
|
||||||
|
|
||||||
state.PushString(string(respBody))
|
state.PushString(string(respBody))
|
||||||
state.SetField(-2, "body")
|
state.SetField(-2, "body")
|
||||||
|
|
||||||
// Parse body as JSON if content type is application/json
|
// Parse body as JSON if content type is application/json
|
||||||
if strings.Contains(resp.Header.Get("Content-Type"), "application/json") {
|
contentType := string(resp.Header.ContentType())
|
||||||
|
if strings.Contains(contentType, "application/json") {
|
||||||
var jsonData any
|
var jsonData any
|
||||||
if err := json.Unmarshal(respBody, &jsonData); err == nil {
|
if err := json.Unmarshal(respBody, &jsonData); err == nil {
|
||||||
if err := state.PushValue(jsonData); err == nil {
|
if err := state.PushValue(jsonData); err == nil {
|
||||||
|
@ -342,24 +348,14 @@ func httpRequest(state *luajit.State) int {
|
||||||
|
|
||||||
// Set headers
|
// Set headers
|
||||||
state.NewTable()
|
state.NewTable()
|
||||||
for name, values := range resp.Header {
|
resp.Header.VisitAll(func(key, value []byte) {
|
||||||
if len(values) == 1 {
|
state.PushString(string(value))
|
||||||
state.PushString(values[0])
|
state.SetField(-2, string(key))
|
||||||
} else {
|
})
|
||||||
// Create array of values
|
|
||||||
state.NewTable()
|
|
||||||
for i, v := range values {
|
|
||||||
state.PushNumber(float64(i + 1))
|
|
||||||
state.PushString(v)
|
|
||||||
state.SetTable(-3)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.SetField(-2, name)
|
|
||||||
}
|
|
||||||
state.SetField(-2, "headers")
|
state.SetField(-2, "headers")
|
||||||
|
|
||||||
// Create ok field (true if status code is 2xx)
|
// Create ok field (true if status code is 2xx)
|
||||||
state.PushBoolean(resp.StatusCode >= 200 && resp.StatusCode < 300)
|
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300)
|
||||||
state.SetField(-2, "ok")
|
state.SetField(-2, "ok")
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
|
@ -483,6 +479,71 @@ func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
|
||||||
return response, true
|
return response, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
|
||||||
|
func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
|
||||||
|
// Set status code
|
||||||
|
ctx.SetStatusCode(httpResp.Status)
|
||||||
|
|
||||||
|
// Set headers
|
||||||
|
for name, value := range httpResp.Headers {
|
||||||
|
ctx.Response.Header.Set(name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set cookies
|
||||||
|
for _, cookie := range httpResp.Cookies {
|
||||||
|
// Convert net/http cookie to fasthttp cookie
|
||||||
|
var c fasthttp.Cookie
|
||||||
|
c.SetKey(cookie.Name)
|
||||||
|
c.SetValue(cookie.Value)
|
||||||
|
if cookie.Path != "" {
|
||||||
|
c.SetPath(cookie.Path)
|
||||||
|
}
|
||||||
|
if cookie.Domain != "" {
|
||||||
|
c.SetDomain(cookie.Domain)
|
||||||
|
}
|
||||||
|
if cookie.MaxAge > 0 {
|
||||||
|
c.SetMaxAge(cookie.MaxAge)
|
||||||
|
}
|
||||||
|
if cookie.Expires.After(time.Time{}) {
|
||||||
|
c.SetExpire(cookie.Expires)
|
||||||
|
}
|
||||||
|
c.SetSecure(cookie.Secure)
|
||||||
|
c.SetHTTPOnly(cookie.HttpOnly)
|
||||||
|
ctx.Response.Header.SetCookie(&c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the body based on its type
|
||||||
|
if httpResp.Body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set body based on type
|
||||||
|
switch body := httpResp.Body.(type) {
|
||||||
|
case string:
|
||||||
|
ctx.SetBodyString(body)
|
||||||
|
case []byte:
|
||||||
|
ctx.SetBody(body)
|
||||||
|
case map[string]any, []any, []float64, []string, []int:
|
||||||
|
// Marshal JSON using a buffer from the pool
|
||||||
|
buf := bytebufferpool.Get()
|
||||||
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
|
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
||||||
|
// Set content type if not already set
|
||||||
|
if len(ctx.Response.Header.ContentType()) == 0 {
|
||||||
|
ctx.Response.Header.SetContentType("application/json")
|
||||||
|
}
|
||||||
|
ctx.SetBody(buf.Bytes())
|
||||||
|
} else {
|
||||||
|
// Fallback
|
||||||
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Default to string representation
|
||||||
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithHTTPClientConfig creates a runner option to configure the HTTP client
|
// WithHTTPClientConfig creates a runner option to configure the HTTP client
|
||||||
func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption {
|
func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption {
|
||||||
return func(r *Runner) {
|
return func(r *Runner) {
|
||||||
|
|
|
@ -9,6 +9,9 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/panjf2000/ants/v2"
|
||||||
|
"github.com/valyala/bytebufferpool"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||||
)
|
)
|
||||||
|
@ -39,6 +42,20 @@ type InitHook func(*luajit.State, *Context) error
|
||||||
// FinalizeHook runs after executing a script
|
// FinalizeHook runs after executing a script
|
||||||
type FinalizeHook func(*luajit.State, *Context, any) error
|
type FinalizeHook func(*luajit.State, *Context, any) error
|
||||||
|
|
||||||
|
// ExecuteTask represents a task in the execution goroutine pool
|
||||||
|
type ExecuteTask struct {
|
||||||
|
bytecode []byte
|
||||||
|
context *Context
|
||||||
|
scriptPath string
|
||||||
|
result chan<- taskResult
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskResult holds the result of an execution task
|
||||||
|
type taskResult struct {
|
||||||
|
value any
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
// Runner runs Lua scripts using a pool of Lua states
|
// Runner runs Lua scripts using a pool of Lua states
|
||||||
type Runner struct {
|
type Runner struct {
|
||||||
states []*State // All states managed by this runner
|
states []*State // All states managed by this runner
|
||||||
|
@ -51,6 +68,7 @@ type Runner struct {
|
||||||
initHooks []InitHook // Hooks run before script execution
|
initHooks []InitHook // Hooks run before script execution
|
||||||
finalizeHooks []FinalizeHook // Hooks run after script execution
|
finalizeHooks []FinalizeHook // Hooks run after script execution
|
||||||
scriptDir string // Current script directory
|
scriptDir string // Current script directory
|
||||||
|
pool *ants.Pool // Goroutine pool for task execution
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithPoolSize sets the state pool size
|
// WithPoolSize sets the state pool size
|
||||||
|
@ -124,6 +142,13 @@ func NewRunner(options ...RunnerOption) (*Runner, error) {
|
||||||
runner.states = make([]*State, runner.poolSize)
|
runner.states = make([]*State, runner.poolSize)
|
||||||
runner.statePool = make(chan int, runner.poolSize)
|
runner.statePool = make(chan int, runner.poolSize)
|
||||||
|
|
||||||
|
// Create ants goroutine pool
|
||||||
|
var err error
|
||||||
|
runner.pool, err = ants.NewPool(runner.poolSize * 2)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Create and initialize all states
|
// Create and initialize all states
|
||||||
if err := runner.initializeStates(); err != nil {
|
if err := runner.initializeStates(); err != nil {
|
||||||
runner.Close() // Clean up already created states
|
runner.Close() // Clean up already created states
|
||||||
|
@ -250,27 +275,30 @@ func (r *Runner) createState(index int) (*State, error) {
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute runs a script with context
|
// executeTask is the worker function for the ants pool
|
||||||
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
func (r *Runner) executeTask(i interface{}) {
|
||||||
if !r.isRunning.Load() {
|
task, ok := i.(*ExecuteTask)
|
||||||
return nil, ErrRunnerClosed
|
if !ok {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set script directory if provided
|
// Set script directory if provided
|
||||||
if scriptPath != "" {
|
if task.scriptPath != "" {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
r.scriptDir = filepath.Dir(scriptPath)
|
r.scriptDir = filepath.Dir(task.scriptPath)
|
||||||
r.moduleLoader.SetScriptDir(r.scriptDir)
|
r.moduleLoader.SetScriptDir(r.scriptDir)
|
||||||
r.mu.Unlock()
|
r.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a state index from the pool with timeout
|
// Get a state index from the pool
|
||||||
var stateIndex int
|
var stateIndex int
|
||||||
select {
|
select {
|
||||||
case stateIndex = <-r.statePool:
|
case stateIndex = <-r.statePool:
|
||||||
// Got a state
|
// Got a state
|
||||||
case <-ctx.Done():
|
default:
|
||||||
return nil, ctx.Err()
|
// No state available - this shouldn't happen since we limit tasks
|
||||||
|
task.result <- taskResult{nil, errors.New("no states available")}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the actual state
|
// Get the actual state
|
||||||
|
@ -279,9 +307,9 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context,
|
||||||
r.mu.RUnlock()
|
r.mu.RUnlock()
|
||||||
|
|
||||||
if state == nil {
|
if state == nil {
|
||||||
// This shouldn't happen, but recover gracefully
|
|
||||||
r.statePool <- stateIndex
|
r.statePool <- stateIndex
|
||||||
return nil, ErrStateNotReady
|
task.result <- taskResult{nil, ErrStateNotReady}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark state as in use
|
// Mark state as in use
|
||||||
|
@ -310,28 +338,60 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context,
|
||||||
|
|
||||||
// Run init hooks
|
// Run init hooks
|
||||||
for _, hook := range initHooks {
|
for _, hook := range initHooks {
|
||||||
if err := hook(state.L, execCtx); err != nil {
|
if err := hook(state.L, task.context); err != nil {
|
||||||
return nil, err
|
task.result <- taskResult{nil, err}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare context values
|
// Prepare context values
|
||||||
var ctxValues map[string]any
|
var ctxValues map[string]any
|
||||||
if execCtx != nil {
|
if task.context != nil {
|
||||||
ctxValues = execCtx.Values
|
ctxValues = task.context.Values
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute in sandbox
|
// Execute in sandbox
|
||||||
result, err := state.sandbox.Execute(state.L, bytecode, ctxValues)
|
result, err := state.sandbox.Execute(state.L, task.bytecode, ctxValues)
|
||||||
|
|
||||||
// Run finalize hooks
|
// Run finalize hooks
|
||||||
for _, hook := range finalizeHooks {
|
for _, hook := range finalizeHooks {
|
||||||
if hookErr := hook(state.L, execCtx, result); hookErr != nil && err == nil {
|
if hookErr := hook(state.L, task.context, result); hookErr != nil && err == nil {
|
||||||
err = hookErr
|
err = hookErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, err
|
task.result <- taskResult{result, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs a script with context
|
||||||
|
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
||||||
|
if !r.isRunning.Load() {
|
||||||
|
return nil, ErrRunnerClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create result channel
|
||||||
|
resultChan := make(chan taskResult, 1)
|
||||||
|
|
||||||
|
// Create task
|
||||||
|
task := &ExecuteTask{
|
||||||
|
bytecode: bytecode,
|
||||||
|
context: execCtx,
|
||||||
|
scriptPath: scriptPath,
|
||||||
|
result: resultChan,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Submit task to pool
|
||||||
|
if err := r.pool.Submit(func() { r.executeTask(task) }); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for result with context timeout
|
||||||
|
select {
|
||||||
|
case result := <-resultChan:
|
||||||
|
return result.value, result.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run executes a Lua script (convenience wrapper)
|
// Run executes a Lua script (convenience wrapper)
|
||||||
|
@ -351,6 +411,9 @@ func (r *Runner) Close() error {
|
||||||
r.isRunning.Store(false)
|
r.isRunning.Store(false)
|
||||||
r.debugLog("Closing Runner and destroying all states")
|
r.debugLog("Closing Runner and destroying all states")
|
||||||
|
|
||||||
|
// Shut down goroutine pool
|
||||||
|
r.pool.Release()
|
||||||
|
|
||||||
// Drain the state pool
|
// Drain the state pool
|
||||||
r.drainStatePool()
|
r.drainStatePool()
|
||||||
|
|
||||||
|
@ -563,6 +626,11 @@ func (r *Runner) GetActiveStateCount() int {
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetWorkerPoolStats returns statistics about the worker pool
|
||||||
|
func (r *Runner) GetWorkerPoolStats() (running, capacity int) {
|
||||||
|
return r.pool.Running(), r.pool.Cap()
|
||||||
|
}
|
||||||
|
|
||||||
// GetModuleCount returns the number of loaded modules in the first available state
|
// GetModuleCount returns the number of loaded modules in the first available state
|
||||||
func (r *Runner) GetModuleCount() int {
|
func (r *Runner) GetModuleCount() int {
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
|
@ -593,3 +661,13 @@ func (r *Runner) GetModuleCount() int {
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBufferPool returns a buffer from the bytebufferpool
|
||||||
|
func GetBufferPool() *bytebufferpool.ByteBuffer {
|
||||||
|
return bytebufferpool.Get()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReleaseBufferPool returns a buffer to the bytebufferpool
|
||||||
|
func ReleaseBufferPool(buf *bytebufferpool.ByteBuffer) {
|
||||||
|
bytebufferpool.Put(buf)
|
||||||
|
}
|
||||||
|
|
|
@ -4,6 +4,9 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"github.com/valyala/bytebufferpool"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||||
)
|
)
|
||||||
|
@ -136,16 +139,42 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
|
||||||
|
|
||||||
// Execute runs bytecode in the sandbox
|
// Execute runs bytecode in the sandbox
|
||||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
|
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
|
||||||
|
// Create a temporary context if we only have a map
|
||||||
|
if ctx != nil {
|
||||||
|
tempCtx := &Context{
|
||||||
|
Values: ctx,
|
||||||
|
}
|
||||||
|
return s.OptimizedExecute(state, bytecode, tempCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Just pass nil through if we have no context
|
||||||
|
return s.OptimizedExecute(state, bytecode, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OptimizedExecute runs bytecode with a fasthttp context if available
|
||||||
|
func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Context) (any, error) {
|
||||||
|
// Use a buffer from the pool for any string operations
|
||||||
|
buf := bytebufferpool.Get()
|
||||||
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
// Load bytecode
|
// Load bytecode
|
||||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||||
s.debugLog("Failed to load bytecode: %v", err)
|
s.debugLog("Failed to load bytecode: %v", err)
|
||||||
return nil, fmt.Errorf("failed to load script: %w", err)
|
return nil, fmt.Errorf("failed to load script: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare context
|
// Prepare context values
|
||||||
|
var ctxValues map[string]any
|
||||||
if ctx != nil {
|
if ctx != nil {
|
||||||
state.CreateTable(0, len(ctx))
|
ctxValues = ctx.Values
|
||||||
for k, v := range ctx {
|
} else {
|
||||||
|
ctxValues = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare context table
|
||||||
|
if ctxValues != nil {
|
||||||
|
state.CreateTable(0, len(ctxValues))
|
||||||
|
for k, v := range ctxValues {
|
||||||
state.PushString(k)
|
state.PushString(k)
|
||||||
if err := state.PushValue(v); err != nil {
|
if err := state.PushValue(v); err != nil {
|
||||||
state.Pop(2) // Pop key and table
|
state.Pop(2) // Pop key and table
|
||||||
|
@ -189,8 +218,39 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
|
||||||
if hasResponse {
|
if hasResponse {
|
||||||
// Add the script result as the response body
|
// Add the script result as the response body
|
||||||
httpResponse.Body = result
|
httpResponse.Body = result
|
||||||
|
|
||||||
|
// If we have a fasthttp context, apply the response directly
|
||||||
|
if ctx != nil && ctx.RequestCtx != nil {
|
||||||
|
ApplyHTTPResponse(httpResponse, ctx.RequestCtx)
|
||||||
|
ReleaseResponse(httpResponse)
|
||||||
|
return nil, nil // No need to return response object
|
||||||
|
}
|
||||||
|
|
||||||
return httpResponse, nil
|
return httpResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we have a fasthttp context and the result needs to be written directly
|
||||||
|
if ctx != nil && ctx.RequestCtx != nil && (result != nil) {
|
||||||
|
// For direct HTTP responses
|
||||||
|
switch r := result.(type) {
|
||||||
|
case string:
|
||||||
|
ctx.RequestCtx.SetBodyString(r)
|
||||||
|
case []byte:
|
||||||
|
ctx.RequestCtx.SetBody(r)
|
||||||
|
case map[string]any, []any:
|
||||||
|
// JSON response
|
||||||
|
ctx.RequestCtx.Response.Header.SetContentType("application/json")
|
||||||
|
if err := json.NewEncoder(buf).Encode(r); err == nil {
|
||||||
|
ctx.RequestCtx.SetBody(buf.Bytes())
|
||||||
|
} else {
|
||||||
|
ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Default string conversion
|
||||||
|
ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -6,9 +6,12 @@ require git.sharkk.net/Sky/LuaJIT-to-Go v0.0.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/andybalholm/brotli v1.1.1 // indirect
|
github.com/andybalholm/brotli v1.1.1 // indirect
|
||||||
|
github.com/goccy/go-json v0.10.5 // indirect
|
||||||
github.com/klauspost/compress v1.18.0 // indirect
|
github.com/klauspost/compress v1.18.0 // indirect
|
||||||
|
github.com/panjf2000/ants/v2 v2.11.2 // indirect
|
||||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
github.com/valyala/fasthttp v1.60.0 // indirect
|
github.com/valyala/fasthttp v1.60.0 // indirect
|
||||||
|
golang.org/x/sync v0.12.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace git.sharkk.net/Sky/LuaJIT-to-Go => ./luajit
|
replace git.sharkk.net/Sky/LuaJIT-to-Go => ./luajit
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -1,9 +1,15 @@
|
||||||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||||
|
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||||
|
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||||
|
github.com/panjf2000/ants/v2 v2.11.2 h1:AVGpMSePxUNpcLaBO34xuIgM1ZdKOiGnpxLXixLi5Jo=
|
||||||
|
github.com/panjf2000/ants/v2 v2.11.2/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek=
|
||||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw=
|
github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw=
|
||||||
github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc=
|
github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc=
|
||||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||||
|
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||||
|
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
|
|
Loading…
Reference in New Issue
Block a user