hyper op 1
This commit is contained in:
parent
95eae40357
commit
f6c260a525
|
@ -1,6 +1,11 @@
|
|||
package runner
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Context represents execution context for a Lua script
|
||||
type Context struct {
|
||||
|
@ -9,6 +14,12 @@ type Context struct {
|
|||
|
||||
// internal mutex for concurrent access
|
||||
mu sync.RWMutex
|
||||
|
||||
// FastHTTP context if this was created from an HTTP request
|
||||
RequestCtx *fasthttp.RequestCtx
|
||||
|
||||
// Buffer for efficient string operations
|
||||
buffer *bytebufferpool.ByteBuffer
|
||||
}
|
||||
|
||||
// Context pool to reduce allocations
|
||||
|
@ -25,6 +36,13 @@ func NewContext() *Context {
|
|||
return contextPool.Get().(*Context)
|
||||
}
|
||||
|
||||
// NewHTTPContext creates a new context from a fasthttp RequestCtx
|
||||
func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context {
|
||||
ctx := NewContext()
|
||||
ctx.RequestCtx = requestCtx
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Release returns the context to the pool after clearing its values
|
||||
func (c *Context) Release() {
|
||||
c.mu.Lock()
|
||||
|
@ -35,9 +53,29 @@ func (c *Context) Release() {
|
|||
delete(c.Values, k)
|
||||
}
|
||||
|
||||
// Reset request context
|
||||
c.RequestCtx = nil
|
||||
|
||||
// Return buffer to pool if we have one
|
||||
if c.buffer != nil {
|
||||
bytebufferpool.Put(c.buffer)
|
||||
c.buffer = nil
|
||||
}
|
||||
|
||||
contextPool.Put(c)
|
||||
}
|
||||
|
||||
// GetBuffer returns a byte buffer for efficient string operations
|
||||
func (c *Context) GetBuffer() *bytebufferpool.ByteBuffer {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.buffer == nil {
|
||||
c.buffer = bytebufferpool.Get()
|
||||
}
|
||||
return c.buffer
|
||||
}
|
||||
|
||||
// Set adds a value to the context
|
||||
func (c *Context) Set(key string, value any) {
|
||||
c.mu.Lock()
|
||||
|
@ -83,3 +121,8 @@ func (c *Context) All() map[string]any {
|
|||
|
||||
return result
|
||||
}
|
||||
|
||||
// IsHTTPRequest returns true if this context contains a fasthttp RequestCtx
|
||||
func (c *Context) IsHTTPRequest() bool {
|
||||
return c.RequestCtx != nil
|
||||
}
|
||||
|
|
|
@ -1,16 +1,19 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
)
|
||||
|
@ -65,8 +68,12 @@ func ReleaseResponse(resp *HTTPResponse) {
|
|||
// ---------- HTTP CLIENT FUNCTIONALITY ----------
|
||||
|
||||
// Default HTTP client with sensible timeout
|
||||
var defaultClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
var defaultFastClient fasthttp.Client = fasthttp.Client{
|
||||
MaxConnsPerHost: 1024,
|
||||
MaxIdleConnDuration: time.Minute,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
DisableHeaderNamesNormalizing: true,
|
||||
}
|
||||
|
||||
// HTTPClientConfig contains client settings
|
||||
|
@ -157,14 +164,22 @@ func httpRequest(state *luajit.State) int {
|
|||
return -1
|
||||
}
|
||||
|
||||
// Get body (optional)
|
||||
var bodyReader io.Reader
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||
var body []byte
|
||||
// Use bytebufferpool for request and response
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Set up request
|
||||
req.Header.SetMethod(method)
|
||||
req.SetRequestURI(urlStr)
|
||||
req.Header.Set("User-Agent", "Moonshark/1.0")
|
||||
|
||||
// Get body (optional)
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||
if state.IsString(3) {
|
||||
// String body
|
||||
body = []byte(state.ToString(3))
|
||||
req.SetBodyString(state.ToString(3))
|
||||
} else if state.IsTable(3) {
|
||||
// Table body - convert to JSON
|
||||
luaTable, err := state.ToTable(3)
|
||||
|
@ -173,29 +188,22 @@ func httpRequest(state *luajit.State) int {
|
|||
return -1
|
||||
}
|
||||
|
||||
body, err = json.Marshal(luaTable)
|
||||
if err != nil {
|
||||
// Use bytebufferpool for JSON serialization
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
|
||||
state.PushString("Failed to convert body to JSON: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
req.SetBody(buf.Bytes())
|
||||
} else {
|
||||
state.PushString("Body must be a string or table")
|
||||
return -1
|
||||
}
|
||||
|
||||
bodyReader = bytes.NewReader(body)
|
||||
}
|
||||
|
||||
// Create request
|
||||
req, err := http.NewRequest(method, urlStr, bodyReader)
|
||||
if err != nil {
|
||||
state.PushString("Failed to create request: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
// Set default headers
|
||||
req.Header.Set("User-Agent", "Moonshark/1.0")
|
||||
|
||||
// Process options (headers, timeout, etc.)
|
||||
timeout := config.DefaultTimeout
|
||||
if state.GetTop() >= 4 && !state.IsNil(4) {
|
||||
|
@ -236,7 +244,7 @@ func httpRequest(state *luajit.State) int {
|
|||
state.Pop(1) // Pop timeout
|
||||
|
||||
// Set content type for POST/PUT if body is present and content-type not manually set
|
||||
if (method == "POST" || method == "PUT") && bodyReader != nil && req.Header.Get("Content-Type") == "" {
|
||||
if (method == "POST" || method == "PUT") && req.Body() != nil && req.Header.Peek("Content-Type") == nil {
|
||||
// Check if options specify content type
|
||||
state.GetField(4, "content_type")
|
||||
if state.IsString(-1) {
|
||||
|
@ -255,7 +263,8 @@ func httpRequest(state *luajit.State) int {
|
|||
// Process query parameters
|
||||
state.GetField(4, "query")
|
||||
if state.IsTable(-1) {
|
||||
q := req.URL.Query()
|
||||
// Create URL args
|
||||
args := req.URI().QueryArgs()
|
||||
|
||||
// Iterate through query params
|
||||
state.PushNil() // Start iteration
|
||||
|
@ -266,52 +275,36 @@ func httpRequest(state *luajit.State) int {
|
|||
|
||||
// Handle different value types
|
||||
if state.IsString(-1) {
|
||||
q.Add(paramName, state.ToString(-1))
|
||||
args.Add(paramName, state.ToString(-1))
|
||||
} else if state.IsNumber(-1) {
|
||||
q.Add(paramName, strings.TrimRight(strings.TrimRight(
|
||||
args.Add(paramName, strings.TrimRight(strings.TrimRight(
|
||||
state.ToString(-1), "0"), "."))
|
||||
} else if state.IsBoolean(-1) {
|
||||
if state.ToBoolean(-1) {
|
||||
q.Add(paramName, "true")
|
||||
args.Add(paramName, "true")
|
||||
} else {
|
||||
q.Add(paramName, "false")
|
||||
args.Add(paramName, "false")
|
||||
}
|
||||
}
|
||||
}
|
||||
state.Pop(1) // Pop value, leave key for next iteration
|
||||
}
|
||||
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
state.Pop(1) // Pop query table
|
||||
}
|
||||
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
_, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Use context with request
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// Execute request
|
||||
resp, err := defaultClient.Do(req)
|
||||
err = defaultFastClient.DoTimeout(req, resp, timeout)
|
||||
if err != nil {
|
||||
state.PushString("Request failed: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Apply size limits to response
|
||||
var respBody []byte
|
||||
if config.MaxResponseSize > 0 {
|
||||
// Limit the response body size
|
||||
respBody, err = io.ReadAll(io.LimitReader(resp.Body, config.MaxResponseSize))
|
||||
} else {
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
state.PushString("Failed to read response: " + err.Error())
|
||||
errStr := "Request failed: " + err.Error()
|
||||
if errors.Is(err, fasthttp.ErrTimeout) {
|
||||
errStr = "Request timed out after " + timeout.String()
|
||||
}
|
||||
state.PushString(errStr)
|
||||
return -1
|
||||
}
|
||||
|
||||
|
@ -319,19 +312,32 @@ func httpRequest(state *luajit.State) int {
|
|||
state.NewTable()
|
||||
|
||||
// Set status code
|
||||
state.PushNumber(float64(resp.StatusCode))
|
||||
state.PushNumber(float64(resp.StatusCode()))
|
||||
state.SetField(-2, "status")
|
||||
|
||||
// Set status text
|
||||
state.PushString(resp.Status)
|
||||
statusText := fasthttp.StatusMessage(resp.StatusCode())
|
||||
state.PushString(statusText)
|
||||
state.SetField(-2, "status_text")
|
||||
|
||||
// Set body
|
||||
var respBody []byte
|
||||
|
||||
// Apply size limits to response
|
||||
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
|
||||
// Make a limited copy
|
||||
respBody = make([]byte, config.MaxResponseSize)
|
||||
copy(respBody, resp.Body())
|
||||
} else {
|
||||
respBody = resp.Body()
|
||||
}
|
||||
|
||||
state.PushString(string(respBody))
|
||||
state.SetField(-2, "body")
|
||||
|
||||
// Parse body as JSON if content type is application/json
|
||||
if strings.Contains(resp.Header.Get("Content-Type"), "application/json") {
|
||||
contentType := string(resp.Header.ContentType())
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
var jsonData any
|
||||
if err := json.Unmarshal(respBody, &jsonData); err == nil {
|
||||
if err := state.PushValue(jsonData); err == nil {
|
||||
|
@ -342,24 +348,14 @@ func httpRequest(state *luajit.State) int {
|
|||
|
||||
// Set headers
|
||||
state.NewTable()
|
||||
for name, values := range resp.Header {
|
||||
if len(values) == 1 {
|
||||
state.PushString(values[0])
|
||||
} else {
|
||||
// Create array of values
|
||||
state.NewTable()
|
||||
for i, v := range values {
|
||||
state.PushNumber(float64(i + 1))
|
||||
state.PushString(v)
|
||||
state.SetTable(-3)
|
||||
}
|
||||
}
|
||||
state.SetField(-2, name)
|
||||
}
|
||||
resp.Header.VisitAll(func(key, value []byte) {
|
||||
state.PushString(string(value))
|
||||
state.SetField(-2, string(key))
|
||||
})
|
||||
state.SetField(-2, "headers")
|
||||
|
||||
// Create ok field (true if status code is 2xx)
|
||||
state.PushBoolean(resp.StatusCode >= 200 && resp.StatusCode < 300)
|
||||
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300)
|
||||
state.SetField(-2, "ok")
|
||||
|
||||
return 1
|
||||
|
@ -483,6 +479,71 @@ func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
|
|||
return response, true
|
||||
}
|
||||
|
||||
// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
|
||||
func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
|
||||
// Set status code
|
||||
ctx.SetStatusCode(httpResp.Status)
|
||||
|
||||
// Set headers
|
||||
for name, value := range httpResp.Headers {
|
||||
ctx.Response.Header.Set(name, value)
|
||||
}
|
||||
|
||||
// Set cookies
|
||||
for _, cookie := range httpResp.Cookies {
|
||||
// Convert net/http cookie to fasthttp cookie
|
||||
var c fasthttp.Cookie
|
||||
c.SetKey(cookie.Name)
|
||||
c.SetValue(cookie.Value)
|
||||
if cookie.Path != "" {
|
||||
c.SetPath(cookie.Path)
|
||||
}
|
||||
if cookie.Domain != "" {
|
||||
c.SetDomain(cookie.Domain)
|
||||
}
|
||||
if cookie.MaxAge > 0 {
|
||||
c.SetMaxAge(cookie.MaxAge)
|
||||
}
|
||||
if cookie.Expires.After(time.Time{}) {
|
||||
c.SetExpire(cookie.Expires)
|
||||
}
|
||||
c.SetSecure(cookie.Secure)
|
||||
c.SetHTTPOnly(cookie.HttpOnly)
|
||||
ctx.Response.Header.SetCookie(&c)
|
||||
}
|
||||
|
||||
// Process the body based on its type
|
||||
if httpResp.Body == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Set body based on type
|
||||
switch body := httpResp.Body.(type) {
|
||||
case string:
|
||||
ctx.SetBodyString(body)
|
||||
case []byte:
|
||||
ctx.SetBody(body)
|
||||
case map[string]any, []any, []float64, []string, []int:
|
||||
// Marshal JSON using a buffer from the pool
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
||||
// Set content type if not already set
|
||||
if len(ctx.Response.Header.ContentType()) == 0 {
|
||||
ctx.Response.Header.SetContentType("application/json")
|
||||
}
|
||||
ctx.SetBody(buf.Bytes())
|
||||
} else {
|
||||
// Fallback
|
||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||
}
|
||||
default:
|
||||
// Default to string representation
|
||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClientConfig creates a runner option to configure the HTTP client
|
||||
func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption {
|
||||
return func(r *Runner) {
|
||||
|
|
|
@ -9,6 +9,9 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
)
|
||||
|
@ -39,6 +42,20 @@ type InitHook func(*luajit.State, *Context) error
|
|||
// FinalizeHook runs after executing a script
|
||||
type FinalizeHook func(*luajit.State, *Context, any) error
|
||||
|
||||
// ExecuteTask represents a task in the execution goroutine pool
|
||||
type ExecuteTask struct {
|
||||
bytecode []byte
|
||||
context *Context
|
||||
scriptPath string
|
||||
result chan<- taskResult
|
||||
}
|
||||
|
||||
// taskResult holds the result of an execution task
|
||||
type taskResult struct {
|
||||
value any
|
||||
err error
|
||||
}
|
||||
|
||||
// Runner runs Lua scripts using a pool of Lua states
|
||||
type Runner struct {
|
||||
states []*State // All states managed by this runner
|
||||
|
@ -51,6 +68,7 @@ type Runner struct {
|
|||
initHooks []InitHook // Hooks run before script execution
|
||||
finalizeHooks []FinalizeHook // Hooks run after script execution
|
||||
scriptDir string // Current script directory
|
||||
pool *ants.Pool // Goroutine pool for task execution
|
||||
}
|
||||
|
||||
// WithPoolSize sets the state pool size
|
||||
|
@ -124,6 +142,13 @@ func NewRunner(options ...RunnerOption) (*Runner, error) {
|
|||
runner.states = make([]*State, runner.poolSize)
|
||||
runner.statePool = make(chan int, runner.poolSize)
|
||||
|
||||
// Create ants goroutine pool
|
||||
var err error
|
||||
runner.pool, err = ants.NewPool(runner.poolSize * 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create and initialize all states
|
||||
if err := runner.initializeStates(); err != nil {
|
||||
runner.Close() // Clean up already created states
|
||||
|
@ -250,27 +275,30 @@ func (r *Runner) createState(index int) (*State, error) {
|
|||
return state, nil
|
||||
}
|
||||
|
||||
// Execute runs a script with context
|
||||
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
||||
if !r.isRunning.Load() {
|
||||
return nil, ErrRunnerClosed
|
||||
// executeTask is the worker function for the ants pool
|
||||
func (r *Runner) executeTask(i interface{}) {
|
||||
task, ok := i.(*ExecuteTask)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Set script directory if provided
|
||||
if scriptPath != "" {
|
||||
if task.scriptPath != "" {
|
||||
r.mu.Lock()
|
||||
r.scriptDir = filepath.Dir(scriptPath)
|
||||
r.scriptDir = filepath.Dir(task.scriptPath)
|
||||
r.moduleLoader.SetScriptDir(r.scriptDir)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// Get a state index from the pool with timeout
|
||||
// Get a state index from the pool
|
||||
var stateIndex int
|
||||
select {
|
||||
case stateIndex = <-r.statePool:
|
||||
// Got a state
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
// No state available - this shouldn't happen since we limit tasks
|
||||
task.result <- taskResult{nil, errors.New("no states available")}
|
||||
return
|
||||
}
|
||||
|
||||
// Get the actual state
|
||||
|
@ -279,9 +307,9 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context,
|
|||
r.mu.RUnlock()
|
||||
|
||||
if state == nil {
|
||||
// This shouldn't happen, but recover gracefully
|
||||
r.statePool <- stateIndex
|
||||
return nil, ErrStateNotReady
|
||||
task.result <- taskResult{nil, ErrStateNotReady}
|
||||
return
|
||||
}
|
||||
|
||||
// Mark state as in use
|
||||
|
@ -310,28 +338,60 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context,
|
|||
|
||||
// Run init hooks
|
||||
for _, hook := range initHooks {
|
||||
if err := hook(state.L, execCtx); err != nil {
|
||||
return nil, err
|
||||
if err := hook(state.L, task.context); err != nil {
|
||||
task.result <- taskResult{nil, err}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare context values
|
||||
var ctxValues map[string]any
|
||||
if execCtx != nil {
|
||||
ctxValues = execCtx.Values
|
||||
if task.context != nil {
|
||||
ctxValues = task.context.Values
|
||||
}
|
||||
|
||||
// Execute in sandbox
|
||||
result, err := state.sandbox.Execute(state.L, bytecode, ctxValues)
|
||||
result, err := state.sandbox.Execute(state.L, task.bytecode, ctxValues)
|
||||
|
||||
// Run finalize hooks
|
||||
for _, hook := range finalizeHooks {
|
||||
if hookErr := hook(state.L, execCtx, result); hookErr != nil && err == nil {
|
||||
if hookErr := hook(state.L, task.context, result); hookErr != nil && err == nil {
|
||||
err = hookErr
|
||||
}
|
||||
}
|
||||
|
||||
return result, err
|
||||
task.result <- taskResult{result, err}
|
||||
}
|
||||
|
||||
// Execute runs a script with context
|
||||
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
||||
if !r.isRunning.Load() {
|
||||
return nil, ErrRunnerClosed
|
||||
}
|
||||
|
||||
// Create result channel
|
||||
resultChan := make(chan taskResult, 1)
|
||||
|
||||
// Create task
|
||||
task := &ExecuteTask{
|
||||
bytecode: bytecode,
|
||||
context: execCtx,
|
||||
scriptPath: scriptPath,
|
||||
result: resultChan,
|
||||
}
|
||||
|
||||
// Submit task to pool
|
||||
if err := r.pool.Submit(func() { r.executeTask(task) }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait for result with context timeout
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
return result.value, result.err
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Run executes a Lua script (convenience wrapper)
|
||||
|
@ -351,6 +411,9 @@ func (r *Runner) Close() error {
|
|||
r.isRunning.Store(false)
|
||||
r.debugLog("Closing Runner and destroying all states")
|
||||
|
||||
// Shut down goroutine pool
|
||||
r.pool.Release()
|
||||
|
||||
// Drain the state pool
|
||||
r.drainStatePool()
|
||||
|
||||
|
@ -563,6 +626,11 @@ func (r *Runner) GetActiveStateCount() int {
|
|||
return count
|
||||
}
|
||||
|
||||
// GetWorkerPoolStats returns statistics about the worker pool
|
||||
func (r *Runner) GetWorkerPoolStats() (running, capacity int) {
|
||||
return r.pool.Running(), r.pool.Cap()
|
||||
}
|
||||
|
||||
// GetModuleCount returns the number of loaded modules in the first available state
|
||||
func (r *Runner) GetModuleCount() int {
|
||||
r.mu.RLock()
|
||||
|
@ -593,3 +661,13 @@ func (r *Runner) GetModuleCount() int {
|
|||
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetBufferPool returns a buffer from the bytebufferpool
|
||||
func GetBufferPool() *bytebufferpool.ByteBuffer {
|
||||
return bytebufferpool.Get()
|
||||
}
|
||||
|
||||
// ReleaseBufferPool returns a buffer to the bytebufferpool
|
||||
func ReleaseBufferPool(buf *bytebufferpool.ByteBuffer) {
|
||||
bytebufferpool.Put(buf)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,9 @@ import (
|
|||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
)
|
||||
|
@ -136,16 +139,42 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
|
|||
|
||||
// Execute runs bytecode in the sandbox
|
||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
|
||||
// Create a temporary context if we only have a map
|
||||
if ctx != nil {
|
||||
tempCtx := &Context{
|
||||
Values: ctx,
|
||||
}
|
||||
return s.OptimizedExecute(state, bytecode, tempCtx)
|
||||
}
|
||||
|
||||
// Just pass nil through if we have no context
|
||||
return s.OptimizedExecute(state, bytecode, nil)
|
||||
}
|
||||
|
||||
// OptimizedExecute runs bytecode with a fasthttp context if available
|
||||
func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Context) (any, error) {
|
||||
// Use a buffer from the pool for any string operations
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
// Load bytecode
|
||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||
s.debugLog("Failed to load bytecode: %v", err)
|
||||
return nil, fmt.Errorf("failed to load script: %w", err)
|
||||
}
|
||||
|
||||
// Prepare context
|
||||
// Prepare context values
|
||||
var ctxValues map[string]any
|
||||
if ctx != nil {
|
||||
state.CreateTable(0, len(ctx))
|
||||
for k, v := range ctx {
|
||||
ctxValues = ctx.Values
|
||||
} else {
|
||||
ctxValues = nil
|
||||
}
|
||||
|
||||
// Prepare context table
|
||||
if ctxValues != nil {
|
||||
state.CreateTable(0, len(ctxValues))
|
||||
for k, v := range ctxValues {
|
||||
state.PushString(k)
|
||||
if err := state.PushValue(v); err != nil {
|
||||
state.Pop(2) // Pop key and table
|
||||
|
@ -189,8 +218,39 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
|
|||
if hasResponse {
|
||||
// Add the script result as the response body
|
||||
httpResponse.Body = result
|
||||
|
||||
// If we have a fasthttp context, apply the response directly
|
||||
if ctx != nil && ctx.RequestCtx != nil {
|
||||
ApplyHTTPResponse(httpResponse, ctx.RequestCtx)
|
||||
ReleaseResponse(httpResponse)
|
||||
return nil, nil // No need to return response object
|
||||
}
|
||||
|
||||
return httpResponse, nil
|
||||
}
|
||||
|
||||
// If we have a fasthttp context and the result needs to be written directly
|
||||
if ctx != nil && ctx.RequestCtx != nil && (result != nil) {
|
||||
// For direct HTTP responses
|
||||
switch r := result.(type) {
|
||||
case string:
|
||||
ctx.RequestCtx.SetBodyString(r)
|
||||
case []byte:
|
||||
ctx.RequestCtx.SetBody(r)
|
||||
case map[string]any, []any:
|
||||
// JSON response
|
||||
ctx.RequestCtx.Response.Header.SetContentType("application/json")
|
||||
if err := json.NewEncoder(buf).Encode(r); err == nil {
|
||||
ctx.RequestCtx.SetBody(buf.Bytes())
|
||||
} else {
|
||||
ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
|
||||
}
|
||||
default:
|
||||
// Default string conversion
|
||||
ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
|
3
go.mod
3
go.mod
|
@ -6,9 +6,12 @@ require git.sharkk.net/Sky/LuaJIT-to-Go v0.0.0
|
|||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.1.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/panjf2000/ants/v2 v2.11.2 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.60.0 // indirect
|
||||
golang.org/x/sync v0.12.0 // indirect
|
||||
)
|
||||
|
||||
replace git.sharkk.net/Sky/LuaJIT-to-Go => ./luajit
|
||||
|
|
6
go.sum
6
go.sum
|
@ -1,9 +1,15 @@
|
|||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/panjf2000/ants/v2 v2.11.2 h1:AVGpMSePxUNpcLaBO34xuIgM1ZdKOiGnpxLXixLi5Jo=
|
||||
github.com/panjf2000/ants/v2 v2.11.2/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw=
|
||||
github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
|
|
Loading…
Reference in New Issue
Block a user