parent
52e34d4fdc
commit
f965ddc7c7
|
@ -1,372 +0,0 @@
|
|||
package workers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrLoopClosed = errors.New("event loop is closed")
|
||||
ErrExecutionTimeout = errors.New("script execution timed out")
|
||||
)
|
||||
|
||||
// StateInitFunc is a function that initializes a Lua state
|
||||
type StateInitFunc func(*luajit.State) error
|
||||
|
||||
// EventLoop represents a single-threaded Lua execution environment
|
||||
type EventLoop struct {
|
||||
state *luajit.State // Single Lua state for all executions
|
||||
jobQueue chan job // Channel for receiving jobs
|
||||
quit chan struct{} // Channel for shutdown signaling
|
||||
wg sync.WaitGroup // WaitGroup for clean shutdown
|
||||
isRunning atomic.Bool // Flag to track if loop is running
|
||||
timeout time.Duration // Default timeout for script execution
|
||||
stateInit StateInitFunc // Optional function to initialize Lua state
|
||||
bufferSize int // Size of job queue buffer
|
||||
}
|
||||
|
||||
// EventLoopConfig contains configuration options for creating an EventLoop
|
||||
type EventLoopConfig struct {
|
||||
// StateInit is a function to initialize the Lua state with custom modules and functions
|
||||
StateInit StateInitFunc
|
||||
|
||||
// BufferSize is the size of the job queue buffer (default: 100)
|
||||
BufferSize int
|
||||
|
||||
// Timeout is the default execution timeout (default: 30s, 0 means no timeout)
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// NewEventLoop creates a new event loop with default configuration
|
||||
func NewEventLoop() (*EventLoop, error) {
|
||||
return NewEventLoopWithConfig(EventLoopConfig{})
|
||||
}
|
||||
|
||||
// NewEventLoopWithInit creates a new event loop with a state initialization function
|
||||
func NewEventLoopWithInit(init StateInitFunc) (*EventLoop, error) {
|
||||
return NewEventLoopWithConfig(EventLoopConfig{
|
||||
StateInit: init,
|
||||
})
|
||||
}
|
||||
|
||||
// NewEventLoopWithConfig creates a new event loop with custom configuration
|
||||
func NewEventLoopWithConfig(config EventLoopConfig) (*EventLoop, error) {
|
||||
// Set default values
|
||||
bufferSize := config.BufferSize
|
||||
if bufferSize <= 0 {
|
||||
bufferSize = 100 // Default buffer size
|
||||
}
|
||||
|
||||
timeout := config.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second // Default timeout
|
||||
}
|
||||
|
||||
// Initialize the Lua state
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
return nil, errors.New("failed to create Lua state")
|
||||
}
|
||||
|
||||
// Create the event loop instance
|
||||
el := &EventLoop{
|
||||
state: state,
|
||||
jobQueue: make(chan job, bufferSize),
|
||||
quit: make(chan struct{}),
|
||||
timeout: timeout,
|
||||
stateInit: config.StateInit,
|
||||
bufferSize: bufferSize,
|
||||
}
|
||||
el.isRunning.Store(true)
|
||||
|
||||
// Set up the sandbox environment
|
||||
if err := setupSandbox(el.state); err != nil {
|
||||
state.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize the state if needed
|
||||
if el.stateInit != nil {
|
||||
if err := el.stateInit(el.state); err != nil {
|
||||
state.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Start the event loop
|
||||
el.wg.Add(1)
|
||||
go el.run()
|
||||
|
||||
return el, nil
|
||||
}
|
||||
|
||||
// run is the main event loop goroutine
|
||||
func (el *EventLoop) run() {
|
||||
defer el.wg.Done()
|
||||
defer el.state.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case job, ok := <-el.jobQueue:
|
||||
if !ok {
|
||||
// Job queue closed, exit
|
||||
return
|
||||
}
|
||||
|
||||
// Execute job with timeout if configured
|
||||
if el.timeout > 0 {
|
||||
el.executeJobWithTimeout(job)
|
||||
} else {
|
||||
// Execute without timeout
|
||||
result := executeJobSandboxed(el.state, job)
|
||||
job.Result <- result
|
||||
}
|
||||
|
||||
case <-el.quit:
|
||||
// Quit signal received, exit
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeJobWithTimeout executes a job with a timeout
|
||||
func (el *EventLoop) executeJobWithTimeout(j job) {
|
||||
// Create a context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), el.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Create a channel for the result
|
||||
resultCh := make(chan JobResult, 1)
|
||||
|
||||
// Execute the job in a separate goroutine
|
||||
go func() {
|
||||
result := executeJobSandboxed(el.state, j)
|
||||
select {
|
||||
case resultCh <- result:
|
||||
// Result sent successfully
|
||||
case <-ctx.Done():
|
||||
// Context canceled, result no longer needed
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for result or timeout
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
// Send result to the original channel
|
||||
j.Result <- result
|
||||
case <-ctx.Done():
|
||||
// Timeout occurred
|
||||
j.Result <- JobResult{nil, ErrExecutionTimeout}
|
||||
// NOTE: The Lua execution continues in the background until it completes,
|
||||
// but the result is discarded. This is a compromise to avoid forcibly
|
||||
// terminating Lua code which could corrupt the state.
|
||||
}
|
||||
}
|
||||
|
||||
// Submit sends a job to the event loop
|
||||
func (el *EventLoop) Submit(bytecode []byte, execCtx *Context) (any, error) {
|
||||
return el.SubmitWithContext(context.Background(), bytecode, execCtx)
|
||||
}
|
||||
|
||||
// SubmitWithTimeout sends a job to the event loop with a specific timeout
|
||||
func (el *EventLoop) SubmitWithTimeout(bytecode []byte, execCtx *Context, timeout time.Duration) (any, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
return el.SubmitWithContext(ctx, bytecode, execCtx)
|
||||
}
|
||||
|
||||
// SubmitWithContext sends a job to the event loop with context for cancellation
|
||||
func (el *EventLoop) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) {
|
||||
if !el.isRunning.Load() {
|
||||
return nil, ErrLoopClosed
|
||||
}
|
||||
|
||||
resultChan := make(chan JobResult, 1)
|
||||
j := job{
|
||||
Bytecode: bytecode,
|
||||
Context: execCtx,
|
||||
Result: resultChan,
|
||||
}
|
||||
|
||||
// Submit job with context
|
||||
select {
|
||||
case el.jobQueue <- j:
|
||||
// Job submitted
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// Wait for result with context
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
return result.Value, result.Error
|
||||
case <-ctx.Done():
|
||||
// Context canceled, but the job might still be processed
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimeout updates the default timeout for script execution
|
||||
func (el *EventLoop) SetTimeout(timeout time.Duration) {
|
||||
el.timeout = timeout
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the event loop
|
||||
func (el *EventLoop) Shutdown() error {
|
||||
if !el.isRunning.Load() {
|
||||
return ErrLoopClosed
|
||||
}
|
||||
el.isRunning.Store(false)
|
||||
|
||||
// Signal event loop to quit
|
||||
close(el.quit)
|
||||
|
||||
// Wait for event loop to finish
|
||||
el.wg.Wait()
|
||||
|
||||
// Close job queue
|
||||
close(el.jobQueue)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupSandbox initializes the sandbox environment in the Lua state
|
||||
func setupSandbox(state *luajit.State) error {
|
||||
// This is the Lua script that creates our sandbox function
|
||||
setupScript := `
|
||||
-- Create a function to run code in a sandbox environment
|
||||
function __create_sandbox()
|
||||
-- Create new environment table
|
||||
local env = {}
|
||||
|
||||
-- Add standard library modules (can be restricted as needed)
|
||||
env.string = string
|
||||
env.table = table
|
||||
env.math = math
|
||||
env.os = {
|
||||
time = os.time,
|
||||
date = os.date,
|
||||
difftime = os.difftime,
|
||||
clock = os.clock
|
||||
}
|
||||
env.tonumber = tonumber
|
||||
env.tostring = tostring
|
||||
env.type = type
|
||||
env.pairs = pairs
|
||||
env.ipairs = ipairs
|
||||
env.next = next
|
||||
env.select = select
|
||||
env.unpack = unpack
|
||||
env.pcall = pcall
|
||||
env.xpcall = xpcall
|
||||
env.error = error
|
||||
env.assert = assert
|
||||
|
||||
-- Allow access to package.loaded for modules
|
||||
env.require = function(name)
|
||||
return package.loaded[name]
|
||||
end
|
||||
|
||||
-- Create metatable to restrict access to _G
|
||||
local mt = {
|
||||
__index = function(t, k)
|
||||
-- First check in env table
|
||||
local v = rawget(env, k)
|
||||
if v ~= nil then return v end
|
||||
|
||||
-- If not found, check for registered modules/functions
|
||||
local moduleValue = _G[k]
|
||||
if type(moduleValue) == "table" or
|
||||
type(moduleValue) == "function" then
|
||||
return moduleValue
|
||||
end
|
||||
|
||||
return nil
|
||||
end,
|
||||
__newindex = function(t, k, v)
|
||||
rawset(env, k, v)
|
||||
end
|
||||
}
|
||||
|
||||
setmetatable(env, mt)
|
||||
return env
|
||||
end
|
||||
|
||||
-- Create function to execute code with a sandbox
|
||||
function __run_sandboxed(f, ctx)
|
||||
local env = __create_sandbox()
|
||||
|
||||
-- Add context to the environment if provided
|
||||
if ctx then
|
||||
env.ctx = ctx
|
||||
end
|
||||
|
||||
-- Set the environment and run the function
|
||||
setfenv(f, env)
|
||||
return f()
|
||||
end
|
||||
`
|
||||
|
||||
return state.DoString(setupScript)
|
||||
}
|
||||
|
||||
// executeJobSandboxed runs a script in a sandbox environment
|
||||
func executeJobSandboxed(state *luajit.State, j job) JobResult {
|
||||
// Set up context if provided
|
||||
if j.Context != nil {
|
||||
// Push context table
|
||||
state.NewTable()
|
||||
|
||||
// Add values to context table
|
||||
for key, value := range j.Context.Values {
|
||||
// Push key
|
||||
state.PushString(key)
|
||||
|
||||
// Push value
|
||||
if err := state.PushValue(value); err != nil {
|
||||
state.Pop(1) // Pop table
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
|
||||
// Set table[key] = value
|
||||
state.SetTable(-3)
|
||||
}
|
||||
} else {
|
||||
// Push nil if no context
|
||||
state.PushNil()
|
||||
}
|
||||
|
||||
// Load bytecode
|
||||
if err := state.LoadBytecode(j.Bytecode, "script"); err != nil {
|
||||
state.Pop(1) // Pop context
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
|
||||
// Get the sandbox runner function
|
||||
state.GetGlobal("__run_sandboxed")
|
||||
|
||||
// Push loaded function and context as arguments
|
||||
state.PushCopy(-2) // Copy the loaded function
|
||||
state.PushCopy(-4) // Copy the context table or nil
|
||||
|
||||
// Remove the original function and context
|
||||
state.Remove(-5) // Remove original context
|
||||
state.Remove(-4) // Remove original function
|
||||
|
||||
// Call the sandbox runner with 2 args (function and context), expecting 1 result
|
||||
if err := state.Call(2, 1); err != nil {
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
|
||||
// Get result
|
||||
value, err := state.ToValue(-1)
|
||||
state.Pop(1) // Pop result
|
||||
|
||||
return JobResult{value, err}
|
||||
}
|
|
@ -1,741 +0,0 @@
|
|||
package workers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Helper function to create bytecode for testing
|
||||
func createTestBytecode(t *testing.T, code string) []byte {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
bytecode, err := state.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile test bytecode: %v", err)
|
||||
}
|
||||
|
||||
return bytecode
|
||||
}
|
||||
|
||||
// Test creating a new event loop with default and custom configs
|
||||
func TestNewEventLoop(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config EventLoopConfig
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Default config",
|
||||
config: EventLoopConfig{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Custom buffer size",
|
||||
config: EventLoopConfig{
|
||||
BufferSize: 200,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Custom timeout",
|
||||
config: EventLoopConfig{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "With init function",
|
||||
config: EventLoopConfig{
|
||||
StateInit: func(state *luajit.State) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
el, err := NewEventLoopWithConfig(tc.config)
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if el == nil {
|
||||
t.Errorf("Expected non-nil event loop")
|
||||
} else {
|
||||
el.Shutdown()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test basic job submission and execution
|
||||
func TestEventLoopBasicSubmission(t *testing.T) {
|
||||
el, err := NewEventLoop()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
// Simple return a value
|
||||
bytecode := createTestBytecode(t, "return 42")
|
||||
|
||||
result, err := el.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
num, ok := result.(float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected float64 result, got %T", result)
|
||||
}
|
||||
|
||||
if num != 42 {
|
||||
t.Errorf("Expected 42, got %f", num)
|
||||
}
|
||||
|
||||
// Test more complex Lua code
|
||||
bytecode = createTestBytecode(t, `
|
||||
local result = 0
|
||||
for i = 1, 10 do
|
||||
result = result + i
|
||||
end
|
||||
return result
|
||||
`)
|
||||
|
||||
result, err = el.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
num, ok = result.(float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected float64 result, got %T", result)
|
||||
}
|
||||
|
||||
if num != 55 {
|
||||
t.Errorf("Expected 55, got %f", num)
|
||||
}
|
||||
}
|
||||
|
||||
// Test context passing between Go and Lua
|
||||
func TestEventLoopContext(t *testing.T) {
|
||||
el, err := NewEventLoop()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
bytecode := createTestBytecode(t, `
|
||||
return {
|
||||
num = ctx.number,
|
||||
str = ctx.text,
|
||||
flag = ctx.enabled,
|
||||
list = {ctx.items[1], ctx.items[2], ctx.items[3]},
|
||||
}
|
||||
`)
|
||||
|
||||
execCtx := NewContext()
|
||||
execCtx.Set("number", 42.5)
|
||||
execCtx.Set("text", "hello")
|
||||
execCtx.Set("enabled", true)
|
||||
execCtx.Set("items", []float64{10, 20, 30})
|
||||
|
||||
result, err := el.Submit(bytecode, execCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
// Result should be a map
|
||||
resultMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
// Check values
|
||||
if resultMap["num"] != 42.5 {
|
||||
t.Errorf("Expected num=42.5, got %v", resultMap["num"])
|
||||
}
|
||||
if resultMap["str"] != "hello" {
|
||||
t.Errorf("Expected str=hello, got %v", resultMap["str"])
|
||||
}
|
||||
if resultMap["flag"] != true {
|
||||
t.Errorf("Expected flag=true, got %v", resultMap["flag"])
|
||||
}
|
||||
|
||||
arr, ok := resultMap["list"].([]float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected []float64, got %T", resultMap["list"])
|
||||
}
|
||||
|
||||
expected := []float64{10, 20, 30}
|
||||
for i, v := range expected {
|
||||
if arr[i] != v {
|
||||
t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Test complex nested context
|
||||
nestedCtx := NewContext()
|
||||
nestedCtx.Set("user", map[string]any{
|
||||
"id": 123,
|
||||
"name": "test user",
|
||||
"roles": []any{
|
||||
"admin",
|
||||
"editor",
|
||||
},
|
||||
})
|
||||
|
||||
bytecode = createTestBytecode(t, `
|
||||
return {
|
||||
id = ctx.user.id,
|
||||
name = ctx.user.name,
|
||||
role1 = ctx.user.roles[1],
|
||||
role2 = ctx.user.roles[2],
|
||||
}
|
||||
`)
|
||||
|
||||
result, err = el.Submit(bytecode, nestedCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job with nested context: %v", err)
|
||||
}
|
||||
|
||||
resultMap, ok = result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
if resultMap["id"] != float64(123) {
|
||||
t.Errorf("Expected id=123, got %v", resultMap["id"])
|
||||
}
|
||||
if resultMap["name"] != "test user" {
|
||||
t.Errorf("Expected name='test user', got %v", resultMap["name"])
|
||||
}
|
||||
if resultMap["role1"] != "admin" {
|
||||
t.Errorf("Expected role1='admin', got %v", resultMap["role1"])
|
||||
}
|
||||
if resultMap["role2"] != "editor" {
|
||||
t.Errorf("Expected role2='editor', got %v", resultMap["role2"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test execution timeout
|
||||
func TestEventLoopTimeout(t *testing.T) {
|
||||
// Create event loop with short timeout
|
||||
el, err := NewEventLoopWithConfig(EventLoopConfig{
|
||||
Timeout: 100 * time.Millisecond,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
// Create bytecode that runs for longer than the timeout
|
||||
bytecode := createTestBytecode(t, `
|
||||
-- Loop for 500ms
|
||||
local start = os.time()
|
||||
while os.difftime(os.time(), start) < 0.5 do end
|
||||
return "done"
|
||||
`)
|
||||
|
||||
// This should time out
|
||||
_, err = el.Submit(bytecode, nil)
|
||||
if err != ErrExecutionTimeout {
|
||||
t.Errorf("Expected timeout error, got: %v", err)
|
||||
}
|
||||
|
||||
// Now set a longer timeout and try again
|
||||
el.SetTimeout(1 * time.Second)
|
||||
|
||||
// This should succeed
|
||||
result, err := el.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected success with longer timeout, got: %v", err)
|
||||
}
|
||||
if result != "done" {
|
||||
t.Errorf("Expected 'done', got %v", result)
|
||||
}
|
||||
|
||||
// Test per-call timeout with SubmitWithTimeout
|
||||
bytecode = createTestBytecode(t, `
|
||||
-- Loop for 300ms
|
||||
local start = os.time()
|
||||
while os.difftime(os.time(), start) < 0.3 do end
|
||||
return "done again"
|
||||
`)
|
||||
|
||||
// This should time out with a custom timeout
|
||||
_, err = el.SubmitWithTimeout(bytecode, nil, 50*time.Millisecond)
|
||||
if err == nil {
|
||||
t.Errorf("Expected timeout error, got success")
|
||||
}
|
||||
|
||||
// This should succeed with a longer custom timeout
|
||||
result, err = el.SubmitWithTimeout(bytecode, nil, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected success with custom timeout, got: %v", err)
|
||||
}
|
||||
if result != "done again" {
|
||||
t.Errorf("Expected 'done again', got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Test module registration and execution
|
||||
func TestEventLoopModules(t *testing.T) {
|
||||
// Define an init function that registers a simple "math" module
|
||||
mathInit := func(state *luajit.State) error {
|
||||
// Register the "add" function directly
|
||||
err := state.RegisterGoFunction("add", func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a + b)
|
||||
return 1 // Return one result
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register a math module with multiple functions
|
||||
mathFuncs := map[string]luajit.GoFunction{
|
||||
"multiply": func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a * b)
|
||||
return 1
|
||||
},
|
||||
"subtract": func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a - b)
|
||||
return 1
|
||||
},
|
||||
}
|
||||
|
||||
return RegisterModule(state, "math2", mathFuncs)
|
||||
}
|
||||
|
||||
// Create an event loop with our init function
|
||||
el, err := NewEventLoopWithInit(mathInit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
// Test the add function
|
||||
bytecode1 := createTestBytecode(t, "return add(5, 7)")
|
||||
result1, err := el.Submit(bytecode1, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call add function: %v", err)
|
||||
}
|
||||
|
||||
num1, ok := result1.(float64)
|
||||
if !ok || num1 != 12 {
|
||||
t.Errorf("Expected add(5, 7) = 12, got %v", result1)
|
||||
}
|
||||
|
||||
// Test the math2 module
|
||||
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
|
||||
result2, err := el.Submit(bytecode2, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call math2.multiply: %v", err)
|
||||
}
|
||||
|
||||
num2, ok := result2.(float64)
|
||||
if !ok || num2 != 48 {
|
||||
t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
|
||||
}
|
||||
|
||||
// Test multiple operations
|
||||
bytecode3 := createTestBytecode(t, `
|
||||
local a = add(10, 20)
|
||||
local b = math2.subtract(a, 5)
|
||||
return math2.multiply(b, 2)
|
||||
`)
|
||||
|
||||
result3, err := el.Submit(bytecode3, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute combined operations: %v", err)
|
||||
}
|
||||
|
||||
num3, ok := result3.(float64)
|
||||
if !ok || num3 != 50 {
|
||||
t.Errorf("Expected ((10 + 20) - 5) * 2 = 50, got %v", result3)
|
||||
}
|
||||
}
|
||||
|
||||
// Test combined module init functions
|
||||
func TestEventLoopCombinedModules(t *testing.T) {
|
||||
// First init function adds a function to get a constant value
|
||||
init1 := func(state *luajit.State) error {
|
||||
return state.RegisterGoFunction("getAnswer", func(s *luajit.State) int {
|
||||
s.PushNumber(42)
|
||||
return 1
|
||||
})
|
||||
}
|
||||
|
||||
// Second init function registers a function that multiplies a number by 2
|
||||
init2 := func(state *luajit.State) error {
|
||||
return state.RegisterGoFunction("double", func(s *luajit.State) int {
|
||||
n := s.ToNumber(1)
|
||||
s.PushNumber(n * 2)
|
||||
return 1
|
||||
})
|
||||
}
|
||||
|
||||
// Combine the init functions
|
||||
combinedInit := CombineInitFuncs(init1, init2)
|
||||
|
||||
// Create an event loop with the combined init function
|
||||
el, err := NewEventLoopWithInit(combinedInit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
// Test using both functions together in a single script
|
||||
bytecode := createTestBytecode(t, "return double(getAnswer())")
|
||||
result, err := el.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute: %v", err)
|
||||
}
|
||||
|
||||
num, ok := result.(float64)
|
||||
if !ok || num != 84 {
|
||||
t.Errorf("Expected double(getAnswer()) = 84, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Test sandbox isolation between executions
|
||||
func TestEventLoopSandboxIsolation(t *testing.T) {
|
||||
el, err := NewEventLoop()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
// Create a script that tries to modify a global variable
|
||||
bytecode1 := createTestBytecode(t, `
|
||||
-- Set a "global" variable
|
||||
my_global = "test value"
|
||||
return true
|
||||
`)
|
||||
|
||||
_, err = el.Submit(bytecode1, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute first script: %v", err)
|
||||
}
|
||||
|
||||
// Now try to access that variable from another script
|
||||
bytecode2 := createTestBytecode(t, `
|
||||
-- Try to access the previously set global
|
||||
return my_global ~= nil
|
||||
`)
|
||||
|
||||
result, err := el.Submit(bytecode2, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute second script: %v", err)
|
||||
}
|
||||
|
||||
// The variable should not be accessible (sandbox isolation)
|
||||
if result.(bool) {
|
||||
t.Errorf("Expected sandbox isolation, but global variable was accessible")
|
||||
}
|
||||
}
|
||||
|
||||
// Test error handling
|
||||
func TestEventLoopErrorHandling(t *testing.T) {
|
||||
el, err := NewEventLoop()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
// Test invalid bytecode
|
||||
_, err = el.Submit([]byte("not valid bytecode"), nil)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid bytecode, got nil")
|
||||
}
|
||||
|
||||
// Test Lua runtime error
|
||||
bytecode := createTestBytecode(t, `
|
||||
error("intentional error")
|
||||
return true
|
||||
`)
|
||||
|
||||
_, err = el.Submit(bytecode, nil)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from Lua error() call, got nil")
|
||||
}
|
||||
|
||||
// Test with nil context (should work fine)
|
||||
bytecode = createTestBytecode(t, "return ctx == nil")
|
||||
result, err := el.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error with nil context: %v", err)
|
||||
}
|
||||
if result.(bool) != true {
|
||||
t.Errorf("Expected ctx to be nil in Lua, but it wasn't")
|
||||
}
|
||||
|
||||
// Test access to restricted library
|
||||
bytecode = createTestBytecode(t, `
|
||||
-- Try to access io library directly
|
||||
return io ~= nil
|
||||
`)
|
||||
|
||||
result, err = el.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute sandbox test: %v", err)
|
||||
}
|
||||
|
||||
// io should not be directly accessible
|
||||
if result.(bool) {
|
||||
t.Errorf("Expected io library to be restricted, but it was accessible")
|
||||
}
|
||||
}
|
||||
|
||||
// Test concurrent job submission
|
||||
func TestEventLoopConcurrency(t *testing.T) {
|
||||
el, err := NewEventLoopWithConfig(EventLoopConfig{
|
||||
BufferSize: 100, // Buffer for concurrent submissions
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
// Create bytecode that returns its input value
|
||||
bytecode := createTestBytecode(t, "return ctx.n")
|
||||
|
||||
// Submit multiple jobs concurrently
|
||||
const jobCount = 50
|
||||
var wg sync.WaitGroup
|
||||
results := make([]int, jobCount)
|
||||
|
||||
wg.Add(jobCount)
|
||||
for i := 0; i < jobCount; i++ {
|
||||
i := i // Capture loop variable
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Create context with job number
|
||||
ctx := NewContext()
|
||||
ctx.Set("n", float64(i))
|
||||
|
||||
// Submit job
|
||||
result, err := el.Submit(bytecode, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Job %d failed: %v", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify result matches job number
|
||||
num, ok := result.(float64)
|
||||
if !ok {
|
||||
t.Errorf("Job %d: expected float64, got %T", i, result)
|
||||
return
|
||||
}
|
||||
|
||||
results[i] = int(num)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all results
|
||||
for i, res := range results {
|
||||
if res != i && res != 0 { // 0 means error already logged
|
||||
t.Errorf("Expected result[%d] = %d, got %d", i, i, res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test state consistency across multiple calls
|
||||
func TestEventLoopStateConsistency(t *testing.T) {
|
||||
// Create an event loop with a module that maintains count between calls
|
||||
initFunc := func(state *luajit.State) error {
|
||||
// Create a closure that increments a counter in upvalue
|
||||
code := `
|
||||
-- Create a counter with initial value 0
|
||||
local counter = 0
|
||||
|
||||
-- Create a function that returns and increments the counter
|
||||
function get_next_count()
|
||||
local current = counter
|
||||
counter = counter + 1
|
||||
return current
|
||||
end
|
||||
`
|
||||
return state.DoString(code)
|
||||
}
|
||||
|
||||
el, err := NewEventLoopWithInit(initFunc)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
// Now run multiple scripts that call the counter function
|
||||
bytecode := createTestBytecode(t, "return get_next_count()")
|
||||
|
||||
// Each call should return an incremented value
|
||||
for i := 0; i < 5; i++ {
|
||||
result, err := el.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Call %d failed: %v", i, err)
|
||||
}
|
||||
|
||||
num, ok := result.(float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected float64 result, got %T", result)
|
||||
}
|
||||
|
||||
if int(num) != i {
|
||||
t.Errorf("Expected count %d, got %d", i, int(num))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test shutdown and cleanup
|
||||
func TestEventLoopShutdown(t *testing.T) {
|
||||
el, err := NewEventLoop()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
|
||||
// Submit a job to verify it works
|
||||
bytecode := createTestBytecode(t, "return 42")
|
||||
_, err = el.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
// Shutdown
|
||||
if err := el.Shutdown(); err != nil {
|
||||
t.Errorf("Shutdown failed: %v", err)
|
||||
}
|
||||
|
||||
// Submit after shutdown should fail
|
||||
_, err = el.Submit(bytecode, nil)
|
||||
if err != ErrLoopClosed {
|
||||
t.Errorf("Expected ErrLoopClosed, got %v", err)
|
||||
}
|
||||
|
||||
// Second shutdown should return error
|
||||
if err := el.Shutdown(); err != ErrLoopClosed {
|
||||
t.Errorf("Expected ErrLoopClosed on second shutdown, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test high load with multiple sequential and concurrent jobs
|
||||
func TestEventLoopHighLoad(t *testing.T) {
|
||||
el, err := NewEventLoopWithConfig(EventLoopConfig{
|
||||
BufferSize: 1000, // Large buffer for high load
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
// Sequential load test
|
||||
bytecode := createTestBytecode(t, `
|
||||
-- Do some work
|
||||
local result = 0
|
||||
for i = 1, 1000 do
|
||||
result = result + i
|
||||
end
|
||||
return result
|
||||
`)
|
||||
|
||||
start := time.Now()
|
||||
for i := 0; i < 100; i++ {
|
||||
_, err := el.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Sequential job %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
seqDuration := time.Since(start)
|
||||
t.Logf("Sequential load test: 100 jobs in %v", seqDuration)
|
||||
|
||||
// Concurrent load test
|
||||
start = time.Now()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(100)
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := el.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Concurrent job failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
concDuration := time.Since(start)
|
||||
t.Logf("Concurrent load test: 100 jobs in %v", concDuration)
|
||||
}
|
||||
|
||||
// Test context cancellation
|
||||
func TestEventLoopCancel(t *testing.T) {
|
||||
el, err := NewEventLoop()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event loop: %v", err)
|
||||
}
|
||||
defer el.Shutdown()
|
||||
|
||||
// Create a long-running script
|
||||
bytecode := createTestBytecode(t, `
|
||||
-- Sleep for 500ms
|
||||
local start = os.time()
|
||||
while os.difftime(os.time(), start) < 0.5 do end
|
||||
return "done"
|
||||
`)
|
||||
|
||||
// Create a context that we can cancel
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Start execution in a goroutine
|
||||
resultCh := make(chan any, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
res, err := el.SubmitWithContext(ctx, bytecode, nil)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
} else {
|
||||
resultCh <- res
|
||||
}
|
||||
}()
|
||||
|
||||
// Cancel quickly
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
// Should get cancellation error
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if ctx.Err() == nil || err == nil {
|
||||
t.Errorf("Expected context cancellation error")
|
||||
}
|
||||
case res := <-resultCh:
|
||||
t.Errorf("Expected cancellation, got result: %v", res)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Errorf("Timed out waiting for cancellation")
|
||||
}
|
||||
}
|
346
core/workers/init_test.go
Normal file
346
core/workers/init_test.go
Normal file
|
@ -0,0 +1,346 @@
|
|||
package workers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
func TestModuleRegistration(t *testing.T) {
|
||||
// Define an init function that registers a simple "math" module
|
||||
mathInit := func(state *luajit.State) error {
|
||||
// Register the "add" function
|
||||
err := state.RegisterGoFunction("add", func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a + b)
|
||||
return 1 // Return one result
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register a whole module
|
||||
mathFuncs := map[string]luajit.GoFunction{
|
||||
"multiply": func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a * b)
|
||||
return 1
|
||||
},
|
||||
"subtract": func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a - b)
|
||||
return 1
|
||||
},
|
||||
}
|
||||
|
||||
return RegisterModule(state, "math2", mathFuncs)
|
||||
}
|
||||
|
||||
// Create a pool with our init function
|
||||
pool, err := NewPoolWithInit(2, mathInit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Test the add function
|
||||
bytecode1 := createTestBytecode(t, "return add(5, 7)")
|
||||
result1, err := pool.Submit(bytecode1, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call add function: %v", err)
|
||||
}
|
||||
|
||||
num1, ok := result1.(float64)
|
||||
if !ok || num1 != 12 {
|
||||
t.Errorf("Expected add(5, 7) = 12, got %v", result1)
|
||||
}
|
||||
|
||||
// Test the math2 module
|
||||
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
|
||||
result2, err := pool.Submit(bytecode2, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call math2.multiply: %v", err)
|
||||
}
|
||||
|
||||
num2, ok := result2.(float64)
|
||||
if !ok || num2 != 48 {
|
||||
t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
|
||||
}
|
||||
|
||||
// Test multiple operations
|
||||
bytecode3 := createTestBytecode(t, `
|
||||
local a = add(10, 20)
|
||||
local b = math2.subtract(a, 5)
|
||||
return math2.multiply(b, 2)
|
||||
`)
|
||||
|
||||
result3, err := pool.Submit(bytecode3, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute combined operations: %v", err)
|
||||
}
|
||||
|
||||
num3, ok := result3.(float64)
|
||||
if !ok || num3 != 50 {
|
||||
t.Errorf("Expected ((10 + 20) - 5) * 2 = 50, got %v", result3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModuleInitFunc(t *testing.T) {
|
||||
// Define math module functions
|
||||
mathModule := func() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"add": func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a + b)
|
||||
return 1
|
||||
},
|
||||
"multiply": func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a * b)
|
||||
return 1
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Define string module functions
|
||||
strModule := func() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"concat": func(s *luajit.State) int {
|
||||
a := s.ToString(1)
|
||||
b := s.ToString(2)
|
||||
s.PushString(a + b)
|
||||
return 1
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Create module map
|
||||
modules := map[string]ModuleFunc{
|
||||
"math2": mathModule,
|
||||
"str": strModule,
|
||||
}
|
||||
|
||||
// Create pool with module init
|
||||
pool, err := NewPoolWithInit(2, ModuleInitFunc(modules))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Test math module
|
||||
bytecode1 := createTestBytecode(t, "return math2.add(5, 7)")
|
||||
result1, err := pool.Submit(bytecode1, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call math2.add: %v", err)
|
||||
}
|
||||
|
||||
num1, ok := result1.(float64)
|
||||
if !ok || num1 != 12 {
|
||||
t.Errorf("Expected math2.add(5, 7) = 12, got %v", result1)
|
||||
}
|
||||
|
||||
// Test string module
|
||||
bytecode2 := createTestBytecode(t, "return str.concat('hello', 'world')")
|
||||
result2, err := pool.Submit(bytecode2, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call str.concat: %v", err)
|
||||
}
|
||||
|
||||
str2, ok := result2.(string)
|
||||
if !ok || str2 != "helloworld" {
|
||||
t.Errorf("Expected str.concat('hello', 'world') = 'helloworld', got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineInitFuncs(t *testing.T) {
|
||||
// First init function adds a function to get a constant value
|
||||
init1 := func(state *luajit.State) error {
|
||||
return state.RegisterGoFunction("getAnswer", func(s *luajit.State) int {
|
||||
s.PushNumber(42)
|
||||
return 1
|
||||
})
|
||||
}
|
||||
|
||||
// Second init function registers a function that multiplies a number by 2
|
||||
init2 := func(state *luajit.State) error {
|
||||
return state.RegisterGoFunction("double", func(s *luajit.State) int {
|
||||
n := s.ToNumber(1)
|
||||
s.PushNumber(n * 2)
|
||||
return 1
|
||||
})
|
||||
}
|
||||
|
||||
// Combine the init functions
|
||||
combinedInit := CombineInitFuncs(init1, init2)
|
||||
|
||||
// Create a pool with the combined init function
|
||||
pool, err := NewPoolWithInit(1, combinedInit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Test using both functions together in a single script
|
||||
bytecode := createTestBytecode(t, "return double(getAnswer())")
|
||||
result, err := pool.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute: %v", err)
|
||||
}
|
||||
|
||||
num, ok := result.(float64)
|
||||
if !ok || num != 84 {
|
||||
t.Errorf("Expected double(getAnswer()) = 84, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSandboxIsolation(t *testing.T) {
|
||||
// Create a pool
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Create a script that tries to modify a global variable
|
||||
bytecode1 := createTestBytecode(t, `
|
||||
-- Set a "global" variable
|
||||
my_global = "test value"
|
||||
return true
|
||||
`)
|
||||
|
||||
_, err = pool.Submit(bytecode1, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute first script: %v", err)
|
||||
}
|
||||
|
||||
// Now try to access that variable from another script
|
||||
bytecode2 := createTestBytecode(t, `
|
||||
-- Try to access the previously set global
|
||||
return my_global ~= nil
|
||||
`)
|
||||
|
||||
result, err := pool.Submit(bytecode2, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute second script: %v", err)
|
||||
}
|
||||
|
||||
// The variable should not be accessible (sandbox isolation)
|
||||
if result.(bool) {
|
||||
t.Errorf("Expected sandbox isolation, but global variable was accessible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextInSandbox(t *testing.T) {
|
||||
// Create a pool
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Create a context with test data
|
||||
ctx := NewContext()
|
||||
ctx.Set("name", "test")
|
||||
ctx.Set("value", 42.5)
|
||||
ctx.Set("items", []float64{1, 2, 3})
|
||||
|
||||
bytecode := createTestBytecode(t, `
|
||||
-- Access and manipulate context values
|
||||
local sum = 0
|
||||
for i, v in ipairs(ctx.items) do
|
||||
sum = sum + v
|
||||
end
|
||||
|
||||
return {
|
||||
name_length = string.len(ctx.name),
|
||||
value_doubled = ctx.value * 2,
|
||||
items_sum = sum
|
||||
}
|
||||
`)
|
||||
|
||||
result, err := pool.Submit(bytecode, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute script with context: %v", err)
|
||||
}
|
||||
|
||||
resultMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
// Check context values were correctly accessible
|
||||
if resultMap["name_length"].(float64) != 4 {
|
||||
t.Errorf("Expected name_length = 4, got %v", resultMap["name_length"])
|
||||
}
|
||||
|
||||
if resultMap["value_doubled"].(float64) != 85 {
|
||||
t.Errorf("Expected value_doubled = 85, got %v", resultMap["value_doubled"])
|
||||
}
|
||||
|
||||
if resultMap["items_sum"].(float64) != 6 {
|
||||
t.Errorf("Expected items_sum = 6, got %v", resultMap["items_sum"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStandardLibsInSandbox(t *testing.T) {
|
||||
// Create a pool
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Test access to standard libraries
|
||||
bytecode := createTestBytecode(t, `
|
||||
local results = {}
|
||||
|
||||
-- Test string library
|
||||
results.string_upper = string.upper("test")
|
||||
|
||||
-- Test math library
|
||||
results.math_sqrt = math.sqrt(16)
|
||||
|
||||
-- Test table library
|
||||
local tbl = {10, 20, 30}
|
||||
table.insert(tbl, 40)
|
||||
results.table_length = #tbl
|
||||
|
||||
-- Test os library (limited functions)
|
||||
results.has_os_time = type(os.time) == "function"
|
||||
|
||||
return results
|
||||
`)
|
||||
|
||||
result, err := pool.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute script: %v", err)
|
||||
}
|
||||
|
||||
resultMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
// Check standard library functions worked
|
||||
if resultMap["string_upper"] != "TEST" {
|
||||
t.Errorf("Expected string_upper = 'TEST', got %v", resultMap["string_upper"])
|
||||
}
|
||||
|
||||
if resultMap["math_sqrt"].(float64) != 4 {
|
||||
t.Errorf("Expected math_sqrt = 4, got %v", resultMap["math_sqrt"])
|
||||
}
|
||||
|
||||
if resultMap["table_length"].(float64) != 4 {
|
||||
t.Errorf("Expected table_length = 4, got %v", resultMap["table_length"])
|
||||
}
|
||||
|
||||
if resultMap["has_os_time"] != true {
|
||||
t.Errorf("Expected has_os_time = true, got %v", resultMap["has_os_time"])
|
||||
}
|
||||
}
|
123
core/workers/pool.go
Normal file
123
core/workers/pool.go
Normal file
|
@ -0,0 +1,123 @@
|
|||
package workers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// StateInitFunc is a function that initializes a Lua state
|
||||
// It can be used to register custom functions and modules
|
||||
type StateInitFunc func(*luajit.State) error
|
||||
|
||||
// Pool manages a pool of Lua worker goroutines
|
||||
type Pool struct {
|
||||
workers uint32 // Number of workers
|
||||
jobs chan job // Channel to send jobs to workers
|
||||
wg sync.WaitGroup // WaitGroup to track active workers
|
||||
quit chan struct{} // Channel to signal shutdown
|
||||
isRunning atomic.Bool // Flag to track if pool is running
|
||||
stateInit StateInitFunc // Optional function to initialize Lua state
|
||||
}
|
||||
|
||||
// NewPool creates a new worker pool with the specified number of workers
|
||||
func NewPool(numWorkers int) (*Pool, error) {
|
||||
return NewPoolWithInit(numWorkers, nil)
|
||||
}
|
||||
|
||||
// NewPoolWithInit creates a new worker pool with the specified number of workers
|
||||
// and a function to initialize each worker's Lua state
|
||||
func NewPoolWithInit(numWorkers int, initFunc StateInitFunc) (*Pool, error) {
|
||||
if numWorkers <= 0 {
|
||||
return nil, ErrNoWorkers
|
||||
}
|
||||
|
||||
p := &Pool{
|
||||
workers: uint32(numWorkers),
|
||||
jobs: make(chan job, numWorkers), // Buffer equal to worker count
|
||||
quit: make(chan struct{}),
|
||||
stateInit: initFunc,
|
||||
}
|
||||
p.isRunning.Store(true)
|
||||
|
||||
// Start workers
|
||||
p.wg.Add(numWorkers)
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
w := &worker{
|
||||
pool: p,
|
||||
id: uint32(i),
|
||||
}
|
||||
go w.run()
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// RegisterGlobal is no longer needed with the sandbox approach
|
||||
// but kept as a no-op for backward compatibility
|
||||
func (p *Pool) RegisterGlobal(name string) {
|
||||
// No-op in sandbox mode
|
||||
}
|
||||
|
||||
// SubmitWithContext sends a job to the worker pool with context
|
||||
func (p *Pool) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) {
|
||||
if !p.isRunning.Load() {
|
||||
return nil, ErrPoolClosed
|
||||
}
|
||||
|
||||
resultChan := make(chan JobResult, 1)
|
||||
j := job{
|
||||
Bytecode: bytecode,
|
||||
Context: execCtx,
|
||||
Result: resultChan,
|
||||
}
|
||||
|
||||
// Submit job with context
|
||||
select {
|
||||
case p.jobs <- j:
|
||||
// Job submitted
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// Wait for result with context
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
return result.Value, result.Error
|
||||
case <-ctx.Done():
|
||||
// Note: The job will still be processed by a worker,
|
||||
// but the result will be discarded
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Submit sends a job to the worker pool
|
||||
func (p *Pool) Submit(bytecode []byte, execCtx *Context) (any, error) {
|
||||
return p.SubmitWithContext(context.Background(), bytecode, execCtx)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the worker pool
|
||||
func (p *Pool) Shutdown() error {
|
||||
if !p.isRunning.Load() {
|
||||
return ErrPoolClosed
|
||||
}
|
||||
p.isRunning.Store(false)
|
||||
|
||||
// Signal workers to quit
|
||||
close(p.quit)
|
||||
|
||||
// Wait for workers to finish
|
||||
p.wg.Wait()
|
||||
|
||||
// Close jobs channel
|
||||
close(p.jobs)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ActiveWorkers returns the number of active workers
|
||||
func (p *Pool) ActiveWorkers() uint32 {
|
||||
return atomic.LoadUint32(&p.workers)
|
||||
}
|
144
core/workers/sandbox.go
Normal file
144
core/workers/sandbox.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package workers
|
||||
|
||||
// setupSandbox initializes the sandbox environment creation function
|
||||
func (w *worker) setupSandbox() error {
|
||||
// This is the Lua script that creates our sandbox function
|
||||
setupScript := `
|
||||
-- Create a function to run code in a sandbox environment
|
||||
function __create_sandbox()
|
||||
-- Create new environment table
|
||||
local env = {}
|
||||
|
||||
-- Add standard library modules (can be restricted as needed)
|
||||
env.string = string
|
||||
env.table = table
|
||||
env.math = math
|
||||
env.os = {
|
||||
time = os.time,
|
||||
date = os.date,
|
||||
difftime = os.difftime,
|
||||
clock = os.clock
|
||||
}
|
||||
env.tonumber = tonumber
|
||||
env.tostring = tostring
|
||||
env.type = type
|
||||
env.pairs = pairs
|
||||
env.ipairs = ipairs
|
||||
env.next = next
|
||||
env.select = select
|
||||
env.unpack = unpack
|
||||
env.pcall = pcall
|
||||
env.xpcall = xpcall
|
||||
env.error = error
|
||||
env.assert = assert
|
||||
|
||||
-- Allow access to package.loaded for modules
|
||||
env.require = function(name)
|
||||
return package.loaded[name]
|
||||
end
|
||||
|
||||
-- Create metatable to restrict access to _G
|
||||
local mt = {
|
||||
__index = function(t, k)
|
||||
-- First check in env table
|
||||
local v = rawget(env, k)
|
||||
if v ~= nil then return v end
|
||||
|
||||
-- If not found, check for registered modules/functions
|
||||
local moduleValue = _G[k]
|
||||
if type(moduleValue) == "table" or
|
||||
type(moduleValue) == "function" then
|
||||
return moduleValue
|
||||
end
|
||||
|
||||
return nil
|
||||
end,
|
||||
__newindex = function(t, k, v)
|
||||
rawset(env, k, v)
|
||||
end
|
||||
}
|
||||
|
||||
setmetatable(env, mt)
|
||||
return env
|
||||
end
|
||||
|
||||
-- Create function to execute code with a sandbox
|
||||
function __run_sandboxed(f, ctx)
|
||||
local env = __create_sandbox()
|
||||
|
||||
-- Add context to the environment if provided
|
||||
if ctx then
|
||||
env.ctx = ctx
|
||||
end
|
||||
|
||||
-- Set the environment and run the function
|
||||
setfenv(f, env)
|
||||
return f()
|
||||
end
|
||||
`
|
||||
|
||||
return w.state.DoString(setupScript)
|
||||
}
|
||||
|
||||
// executeJobSandboxed runs a script in a sandbox environment
|
||||
func (w *worker) executeJobSandboxed(j job) JobResult {
|
||||
// No need to reset the state for each execution, since we're using a sandbox
|
||||
|
||||
// Re-run init function to register functions and modules if needed
|
||||
if w.pool.stateInit != nil {
|
||||
if err := w.pool.stateInit(w.state); err != nil {
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
}
|
||||
|
||||
// Set up context if provided
|
||||
if j.Context != nil {
|
||||
// Push context table
|
||||
w.state.NewTable()
|
||||
|
||||
// Add values to context table
|
||||
for key, value := range j.Context.Values {
|
||||
// Push key
|
||||
w.state.PushString(key)
|
||||
|
||||
// Push value
|
||||
if err := w.state.PushValue(value); err != nil {
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
|
||||
// Set table[key] = value
|
||||
w.state.SetTable(-3)
|
||||
}
|
||||
} else {
|
||||
// Push nil if no context
|
||||
w.state.PushNil()
|
||||
}
|
||||
|
||||
// Load bytecode
|
||||
if err := w.state.LoadBytecode(j.Bytecode, "script"); err != nil {
|
||||
w.state.Pop(1) // Pop context
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
|
||||
// Get the sandbox runner function
|
||||
w.state.GetGlobal("__run_sandboxed")
|
||||
|
||||
// Push loaded function and context as arguments
|
||||
w.state.PushCopy(-2) // Copy the loaded function
|
||||
w.state.PushCopy(-4) // Copy the context table or nil
|
||||
|
||||
// Remove the original function and context
|
||||
w.state.Remove(-5) // Remove original context
|
||||
w.state.Remove(-4) // Remove original function
|
||||
|
||||
// Call the sandbox runner with 2 args (function and context), expecting 1 result
|
||||
if err := w.state.Call(2, 1); err != nil {
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
|
||||
// Get result
|
||||
value, err := w.state.ToValue(-1)
|
||||
w.state.Pop(1) // Pop result
|
||||
|
||||
return JobResult{value, err}
|
||||
}
|
71
core/workers/worker.go
Normal file
71
core/workers/worker.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package workers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrPoolClosed = errors.New("worker pool is closed")
|
||||
ErrNoWorkers = errors.New("no workers available")
|
||||
ErrInitFailed = errors.New("worker initialization failed")
|
||||
)
|
||||
|
||||
// worker represents a single Lua execution worker
|
||||
type worker struct {
|
||||
pool *Pool // Reference to the pool
|
||||
state *luajit.State // Lua state
|
||||
id uint32 // Worker ID
|
||||
}
|
||||
|
||||
// run is the main worker function that processes jobs
|
||||
func (w *worker) run() {
|
||||
defer w.pool.wg.Done()
|
||||
|
||||
// Initialize Lua state
|
||||
w.state = luajit.New()
|
||||
if w.state == nil {
|
||||
// Worker failed to initialize, decrement counter
|
||||
atomic.AddUint32(&w.pool.workers, ^uint32(0))
|
||||
return
|
||||
}
|
||||
defer w.state.Close()
|
||||
|
||||
// Set up sandbox environment
|
||||
if err := w.setupSandbox(); err != nil {
|
||||
// Worker failed to initialize sandbox, decrement counter
|
||||
atomic.AddUint32(&w.pool.workers, ^uint32(0))
|
||||
return
|
||||
}
|
||||
|
||||
// Run init function if provided
|
||||
if w.pool.stateInit != nil {
|
||||
if err := w.pool.stateInit(w.state); err != nil {
|
||||
// Worker failed to initialize with custom init function
|
||||
atomic.AddUint32(&w.pool.workers, ^uint32(0))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Main worker loop
|
||||
for {
|
||||
select {
|
||||
case job, ok := <-w.pool.jobs:
|
||||
if !ok {
|
||||
// Jobs channel closed, exit
|
||||
return
|
||||
}
|
||||
|
||||
// Execute job
|
||||
result := w.executeJobSandboxed(job)
|
||||
job.Result <- result
|
||||
|
||||
case <-w.pool.quit:
|
||||
// Quit signal received, exit
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
445
core/workers/workers_test.go
Normal file
445
core/workers/workers_test.go
Normal file
|
@ -0,0 +1,445 @@
|
|||
package workers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// This helper function creates real LuaJIT bytecode for our tests. Instead of using
|
||||
// mocks, we compile actual Lua code into bytecode just like we would in production.
|
||||
func createTestBytecode(t *testing.T, code string) []byte {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
bytecode, err := state.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile test bytecode: %v", err)
|
||||
}
|
||||
|
||||
return bytecode
|
||||
}
|
||||
|
||||
// This test makes sure we can create a worker pool with a valid number of workers,
|
||||
// and that we properly reject attempts to create a pool with zero or negative workers.
|
||||
func TestNewPool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workers int
|
||||
expectErr bool
|
||||
}{
|
||||
{"valid workers", 4, false},
|
||||
{"zero workers", 0, true},
|
||||
{"negative workers", -1, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pool, err := NewPool(tt.workers)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for %d workers, got nil", tt.workers)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Errorf("Expected non-nil pool")
|
||||
} else {
|
||||
pool.Shutdown()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Here we're testing the basic job submission flow. We run a simple Lua script
|
||||
// that returns the number 42 and make sure we get that same value back from the worker pool.
|
||||
func TestPoolSubmit(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
bytecode := createTestBytecode(t, "return 42")
|
||||
|
||||
result, err := pool.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
num, ok := result.(float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected float64 result, got %T", result)
|
||||
}
|
||||
|
||||
if num != 42 {
|
||||
t.Errorf("Expected 42, got %f", num)
|
||||
}
|
||||
}
|
||||
|
||||
// This test checks how our worker pool handles timeouts. We run a script that takes
|
||||
// some time to complete and verify two scenarios: one where the timeout is long enough
|
||||
// for successful completion, and another where we expect the operation to be canceled
|
||||
// due to a short timeout.
|
||||
func TestPoolSubmitWithContext(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Create bytecode that sleeps
|
||||
bytecode := createTestBytecode(t, `
|
||||
-- Sleep for 500ms
|
||||
local start = os.time()
|
||||
while os.difftime(os.time(), start) < 0.5 do end
|
||||
return "done"
|
||||
`)
|
||||
|
||||
// Test with timeout that should succeed
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := pool.SubmitWithContext(ctx, bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error with sufficient timeout: %v", err)
|
||||
}
|
||||
if result != "done" {
|
||||
t.Errorf("Expected 'done', got %v", result)
|
||||
}
|
||||
|
||||
// Test with timeout that should fail
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = pool.SubmitWithContext(ctx, bytecode, nil)
|
||||
if err == nil {
|
||||
t.Errorf("Expected timeout error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// We need to make sure we can pass different types of context values from Go to Lua and
|
||||
// get them back properly. This test sends numbers, strings, booleans, and arrays to
|
||||
// a Lua script and verifies they're all handled correctly in both directions.
|
||||
func TestContextValues(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
bytecode := createTestBytecode(t, `
|
||||
return {
|
||||
num = ctx.number,
|
||||
str = ctx.text,
|
||||
flag = ctx.enabled,
|
||||
list = {ctx.table[1], ctx.table[2], ctx.table[3]},
|
||||
}
|
||||
`)
|
||||
|
||||
execCtx := NewContext()
|
||||
execCtx.Set("number", 42.5)
|
||||
execCtx.Set("text", "hello")
|
||||
execCtx.Set("enabled", true)
|
||||
execCtx.Set("table", []float64{10, 20, 30})
|
||||
|
||||
result, err := pool.Submit(bytecode, execCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
// Result should be a map
|
||||
resultMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
// Check values
|
||||
if resultMap["num"] != 42.5 {
|
||||
t.Errorf("Expected num=42.5, got %v", resultMap["num"])
|
||||
}
|
||||
if resultMap["str"] != "hello" {
|
||||
t.Errorf("Expected str=hello, got %v", resultMap["str"])
|
||||
}
|
||||
if resultMap["flag"] != true {
|
||||
t.Errorf("Expected flag=true, got %v", resultMap["flag"])
|
||||
}
|
||||
|
||||
arr, ok := resultMap["list"].([]float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected []float64, got %T", resultMap["list"])
|
||||
}
|
||||
|
||||
expected := []float64{10, 20, 30}
|
||||
for i, v := range expected {
|
||||
if arr[i] != v {
|
||||
t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test context with nested data structures
|
||||
func TestNestedContext(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
bytecode := createTestBytecode(t, `
|
||||
return {
|
||||
id = ctx.params.id,
|
||||
name = ctx.params.name,
|
||||
method = ctx.request.method,
|
||||
path = ctx.request.path
|
||||
}
|
||||
`)
|
||||
|
||||
execCtx := NewContext()
|
||||
|
||||
// Set nested params
|
||||
params := map[string]any{
|
||||
"id": "123",
|
||||
"name": "test",
|
||||
}
|
||||
execCtx.Set("params", params)
|
||||
|
||||
// Set nested request info
|
||||
request := map[string]any{
|
||||
"method": "GET",
|
||||
"path": "/api/test",
|
||||
}
|
||||
execCtx.Set("request", request)
|
||||
|
||||
result, err := pool.Submit(bytecode, execCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
// Result should be a map
|
||||
resultMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
if resultMap["id"] != "123" {
|
||||
t.Errorf("Expected id=123, got %v", resultMap["id"])
|
||||
}
|
||||
if resultMap["name"] != "test" {
|
||||
t.Errorf("Expected name=test, got %v", resultMap["name"])
|
||||
}
|
||||
if resultMap["method"] != "GET" {
|
||||
t.Errorf("Expected method=GET, got %v", resultMap["method"])
|
||||
}
|
||||
if resultMap["path"] != "/api/test" {
|
||||
t.Errorf("Expected path=/api/test, got %v", resultMap["path"])
|
||||
}
|
||||
}
|
||||
|
||||
// A key requirement for our worker pool is that we don't leak state between executions.
|
||||
// This test confirms that by setting a global variable in one job and then checking
|
||||
// that it's been cleared before the next job runs on the same worker.
|
||||
func TestStateReset(t *testing.T) {
|
||||
pool, err := NewPool(1) // Use 1 worker to ensure same state is reused
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// First job sets a global
|
||||
bytecode1 := createTestBytecode(t, `
|
||||
global_var = "should be cleared"
|
||||
return true
|
||||
`)
|
||||
|
||||
// Second job checks if global exists
|
||||
bytecode2 := createTestBytecode(t, `
|
||||
return global_var ~= nil
|
||||
`)
|
||||
|
||||
// Run first job
|
||||
_, err = pool.Submit(bytecode1, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit first job: %v", err)
|
||||
}
|
||||
|
||||
// Run second job
|
||||
result, err := pool.Submit(bytecode2, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit second job: %v", err)
|
||||
}
|
||||
|
||||
// Global should be cleared
|
||||
if result.(bool) {
|
||||
t.Errorf("Expected global_var to be cleared, but it still exists")
|
||||
}
|
||||
}
|
||||
|
||||
// Let's make sure our pool shuts down cleanly. This test confirms that jobs work
|
||||
// before shutdown, that we get the right error when trying to submit after shutdown,
|
||||
// and that we properly handle attempts to shut down an already closed pool.
|
||||
func TestPoolShutdown(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
|
||||
// Submit a job to verify pool works
|
||||
bytecode := createTestBytecode(t, "return 42")
|
||||
_, err = pool.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
// Shutdown
|
||||
if err := pool.Shutdown(); err != nil {
|
||||
t.Errorf("Shutdown failed: %v", err)
|
||||
}
|
||||
|
||||
// Submit after shutdown should fail
|
||||
_, err = pool.Submit(bytecode, nil)
|
||||
if err != ErrPoolClosed {
|
||||
t.Errorf("Expected ErrPoolClosed, got %v", err)
|
||||
}
|
||||
|
||||
// Second shutdown should return error
|
||||
if err := pool.Shutdown(); err != ErrPoolClosed {
|
||||
t.Errorf("Expected ErrPoolClosed on second shutdown, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// A robust worker pool needs to handle errors gracefully. This test checks various
|
||||
// error scenarios: invalid bytecode, Lua runtime errors, nil context (which
|
||||
// should work fine), and unsupported parameter types (which should properly error out).
|
||||
func TestErrorHandling(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Test invalid bytecode
|
||||
_, err = pool.Submit([]byte("not valid bytecode"), nil)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid bytecode, got nil")
|
||||
}
|
||||
|
||||
// Test Lua runtime error
|
||||
bytecode := createTestBytecode(t, `
|
||||
error("intentional error")
|
||||
return true
|
||||
`)
|
||||
|
||||
_, err = pool.Submit(bytecode, nil)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from Lua error() call, got nil")
|
||||
}
|
||||
|
||||
// Test with nil context
|
||||
bytecode = createTestBytecode(t, "return ctx == nil")
|
||||
result, err := pool.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error with nil context: %v", err)
|
||||
}
|
||||
if result.(bool) != true {
|
||||
t.Errorf("Expected ctx to be nil in Lua, but it wasn't")
|
||||
}
|
||||
|
||||
// Test invalid context value
|
||||
execCtx := NewContext()
|
||||
execCtx.Set("param", complex128(1+2i)) // Unsupported type
|
||||
|
||||
bytecode = createTestBytecode(t, "return ctx.param")
|
||||
_, err = pool.Submit(bytecode, execCtx)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for unsupported context value type, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// The whole point of a worker pool is concurrent processing, so we need to verify
|
||||
// it works under load. This test submits multiple jobs simultaneously and makes sure
|
||||
// they all complete correctly with their own unique results.
|
||||
func TestConcurrentExecution(t *testing.T) {
|
||||
const workers = 4
|
||||
const jobs = 20
|
||||
|
||||
pool, err := NewPool(workers)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Create bytecode that returns its input
|
||||
bytecode := createTestBytecode(t, "return ctx.n")
|
||||
|
||||
// Run multiple jobs concurrently
|
||||
results := make(chan int, jobs)
|
||||
for i := 0; i < jobs; i++ {
|
||||
i := i // Capture loop variable
|
||||
go func() {
|
||||
execCtx := NewContext()
|
||||
execCtx.Set("n", float64(i))
|
||||
|
||||
result, err := pool.Submit(bytecode, execCtx)
|
||||
if err != nil {
|
||||
t.Errorf("Job %d failed: %v", i, err)
|
||||
results <- -1
|
||||
return
|
||||
}
|
||||
|
||||
num, ok := result.(float64)
|
||||
if !ok {
|
||||
t.Errorf("Job %d: expected float64, got %T", i, result)
|
||||
results <- -1
|
||||
return
|
||||
}
|
||||
|
||||
results <- int(num)
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect results
|
||||
counts := make(map[int]bool)
|
||||
for i := 0; i < jobs; i++ {
|
||||
result := <-results
|
||||
if result != -1 {
|
||||
counts[result] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all jobs were processed
|
||||
if len(counts) != jobs {
|
||||
t.Errorf("Expected %d unique results, got %d", jobs, len(counts))
|
||||
}
|
||||
}
|
||||
|
||||
// Test context operations
|
||||
func TestContext(t *testing.T) {
|
||||
ctx := NewContext()
|
||||
|
||||
// Test Set and Get
|
||||
ctx.Set("key", "value")
|
||||
if ctx.Get("key") != "value" {
|
||||
t.Errorf("Expected value, got %v", ctx.Get("key"))
|
||||
}
|
||||
|
||||
// Test overwriting
|
||||
ctx.Set("key", 123)
|
||||
if ctx.Get("key") != 123 {
|
||||
t.Errorf("Expected 123, got %v", ctx.Get("key"))
|
||||
}
|
||||
|
||||
// Test missing key
|
||||
if ctx.Get("missing") != nil {
|
||||
t.Errorf("Expected nil for missing key, got %v", ctx.Get("missing"))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user