Compare commits

..

No commits in common. "424b778f2804f4e2581274510e3b3ef49be79660" and "50e848da6e331dd2dfe90c77c9c33d05e28c0859" have entirely different histories.

9 changed files with 1150 additions and 1161 deletions

View File

@ -27,7 +27,6 @@ const (
LevelWarning
LevelError
LevelFatal
LevelServer
)
// Level names and colors
@ -40,7 +39,6 @@ var levelProps = map[int]struct {
LevelWarning: {"WARN", colorYellow},
LevelError: {" ERR", colorRed},
LevelFatal: {"FATL", colorPurple},
LevelServer: {"SRVR", colorGreen},
}
// Time format for log messages
@ -229,11 +227,6 @@ func (l *Logger) Fatal(format string, args ...any) {
// No need for os.Exit here as it's handled in log()
}
// Server logs a server message
func (l *Logger) Server(format string, args ...any) {
l.log(LevelServer, format, args...)
}
// Default global logger
var defaultLogger = New(LevelInfo, true)
@ -262,11 +255,6 @@ func Fatal(format string, args ...any) {
defaultLogger.Fatal(format, args...)
}
// Server logs a server message to the default logger
func Server(format string, args ...any) {
defaultLogger.Server(format, args...)
}
// LogRaw logs a raw message to the default logger
func LogRaw(format string, args ...any) {
defaultLogger.LogRaw(format, args...)

View File

@ -1,372 +0,0 @@
package workers
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Common errors
var (
ErrLoopClosed = errors.New("event loop is closed")
ErrExecutionTimeout = errors.New("script execution timed out")
)
// StateInitFunc is a function that initializes a Lua state
type StateInitFunc func(*luajit.State) error
// EventLoop represents a single-threaded Lua execution environment
type EventLoop struct {
state *luajit.State // Single Lua state for all executions
jobQueue chan job // Channel for receiving jobs
quit chan struct{} // Channel for shutdown signaling
wg sync.WaitGroup // WaitGroup for clean shutdown
isRunning atomic.Bool // Flag to track if loop is running
timeout time.Duration // Default timeout for script execution
stateInit StateInitFunc // Optional function to initialize Lua state
bufferSize int // Size of job queue buffer
}
// EventLoopConfig contains configuration options for creating an EventLoop
type EventLoopConfig struct {
// StateInit is a function to initialize the Lua state with custom modules and functions
StateInit StateInitFunc
// BufferSize is the size of the job queue buffer (default: 100)
BufferSize int
// Timeout is the default execution timeout (default: 30s, 0 means no timeout)
Timeout time.Duration
}
// NewEventLoop creates a new event loop with default configuration
func NewEventLoop() (*EventLoop, error) {
return NewEventLoopWithConfig(EventLoopConfig{})
}
// NewEventLoopWithInit creates a new event loop with a state initialization function
func NewEventLoopWithInit(init StateInitFunc) (*EventLoop, error) {
return NewEventLoopWithConfig(EventLoopConfig{
StateInit: init,
})
}
// NewEventLoopWithConfig creates a new event loop with custom configuration
func NewEventLoopWithConfig(config EventLoopConfig) (*EventLoop, error) {
// Set default values
bufferSize := config.BufferSize
if bufferSize <= 0 {
bufferSize = 100 // Default buffer size
}
timeout := config.Timeout
if timeout == 0 {
timeout = 30 * time.Second // Default timeout
}
// Initialize the Lua state
state := luajit.New()
if state == nil {
return nil, errors.New("failed to create Lua state")
}
// Create the event loop instance
el := &EventLoop{
state: state,
jobQueue: make(chan job, bufferSize),
quit: make(chan struct{}),
timeout: timeout,
stateInit: config.StateInit,
bufferSize: bufferSize,
}
el.isRunning.Store(true)
// Set up the sandbox environment
if err := setupSandbox(el.state); err != nil {
state.Close()
return nil, err
}
// Initialize the state if needed
if el.stateInit != nil {
if err := el.stateInit(el.state); err != nil {
state.Close()
return nil, err
}
}
// Start the event loop
el.wg.Add(1)
go el.run()
return el, nil
}
// run is the main event loop goroutine
func (el *EventLoop) run() {
defer el.wg.Done()
defer el.state.Close()
for {
select {
case job, ok := <-el.jobQueue:
if !ok {
// Job queue closed, exit
return
}
// Execute job with timeout if configured
if el.timeout > 0 {
el.executeJobWithTimeout(job)
} else {
// Execute without timeout
result := executeJobSandboxed(el.state, job)
job.Result <- result
}
case <-el.quit:
// Quit signal received, exit
return
}
}
}
// executeJobWithTimeout executes a job with a timeout
func (el *EventLoop) executeJobWithTimeout(j job) {
// Create a context with timeout
ctx, cancel := context.WithTimeout(context.Background(), el.timeout)
defer cancel()
// Create a channel for the result
resultCh := make(chan JobResult, 1)
// Execute the job in a separate goroutine
go func() {
result := executeJobSandboxed(el.state, j)
select {
case resultCh <- result:
// Result sent successfully
case <-ctx.Done():
// Context canceled, result no longer needed
}
}()
// Wait for result or timeout
select {
case result := <-resultCh:
// Send result to the original channel
j.Result <- result
case <-ctx.Done():
// Timeout occurred
j.Result <- JobResult{nil, ErrExecutionTimeout}
// NOTE: The Lua execution continues in the background until it completes,
// but the result is discarded. This is a compromise to avoid forcibly
// terminating Lua code which could corrupt the state.
}
}
// Submit sends a job to the event loop
func (el *EventLoop) Submit(bytecode []byte, execCtx *Context) (any, error) {
return el.SubmitWithContext(context.Background(), bytecode, execCtx)
}
// SubmitWithTimeout sends a job to the event loop with a specific timeout
func (el *EventLoop) SubmitWithTimeout(bytecode []byte, execCtx *Context, timeout time.Duration) (any, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return el.SubmitWithContext(ctx, bytecode, execCtx)
}
// SubmitWithContext sends a job to the event loop with context for cancellation
func (el *EventLoop) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) {
if !el.isRunning.Load() {
return nil, ErrLoopClosed
}
resultChan := make(chan JobResult, 1)
j := job{
Bytecode: bytecode,
Context: execCtx,
Result: resultChan,
}
// Submit job with context
select {
case el.jobQueue <- j:
// Job submitted
case <-ctx.Done():
return nil, ctx.Err()
}
// Wait for result with context
select {
case result := <-resultChan:
return result.Value, result.Error
case <-ctx.Done():
// Context canceled, but the job might still be processed
return nil, ctx.Err()
}
}
// SetTimeout updates the default timeout for script execution
func (el *EventLoop) SetTimeout(timeout time.Duration) {
el.timeout = timeout
}
// Shutdown gracefully shuts down the event loop
func (el *EventLoop) Shutdown() error {
if !el.isRunning.Load() {
return ErrLoopClosed
}
el.isRunning.Store(false)
// Signal event loop to quit
close(el.quit)
// Wait for event loop to finish
el.wg.Wait()
// Close job queue
close(el.jobQueue)
return nil
}
// setupSandbox initializes the sandbox environment in the Lua state
func setupSandbox(state *luajit.State) error {
// This is the Lua script that creates our sandbox function
setupScript := `
-- Create a function to run code in a sandbox environment
function __create_sandbox()
-- Create new environment table
local env = {}
-- Add standard library modules (can be restricted as needed)
env.string = string
env.table = table
env.math = math
env.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
env.tonumber = tonumber
env.tostring = tostring
env.type = type
env.pairs = pairs
env.ipairs = ipairs
env.next = next
env.select = select
env.unpack = unpack
env.pcall = pcall
env.xpcall = xpcall
env.error = error
env.assert = assert
-- Allow access to package.loaded for modules
env.require = function(name)
return package.loaded[name]
end
-- Create metatable to restrict access to _G
local mt = {
__index = function(t, k)
-- First check in env table
local v = rawget(env, k)
if v ~= nil then return v end
-- If not found, check for registered modules/functions
local moduleValue = _G[k]
if type(moduleValue) == "table" or
type(moduleValue) == "function" then
return moduleValue
end
return nil
end,
__newindex = function(t, k, v)
rawset(env, k, v)
end
}
setmetatable(env, mt)
return env
end
-- Create function to execute code with a sandbox
function __run_sandboxed(f, ctx)
local env = __create_sandbox()
-- Add context to the environment if provided
if ctx then
env.ctx = ctx
end
-- Set the environment and run the function
setfenv(f, env)
return f()
end
`
return state.DoString(setupScript)
}
// executeJobSandboxed runs a script in a sandbox environment
func executeJobSandboxed(state *luajit.State, j job) JobResult {
// Set up context if provided
if j.Context != nil {
// Push context table
state.NewTable()
// Add values to context table
for key, value := range j.Context.Values {
// Push key
state.PushString(key)
// Push value
if err := state.PushValue(value); err != nil {
state.Pop(1) // Pop table
return JobResult{nil, err}
}
// Set table[key] = value
state.SetTable(-3)
}
} else {
// Push nil if no context
state.PushNil()
}
// Load bytecode
if err := state.LoadBytecode(j.Bytecode, "script"); err != nil {
state.Pop(1) // Pop context
return JobResult{nil, err}
}
// Get the sandbox runner function
state.GetGlobal("__run_sandboxed")
// Push loaded function and context as arguments
state.PushCopy(-2) // Copy the loaded function
state.PushCopy(-4) // Copy the context table or nil
// Remove the original function and context
state.Remove(-5) // Remove original context
state.Remove(-4) // Remove original function
// Call the sandbox runner with 2 args (function and context), expecting 1 result
if err := state.Call(2, 1); err != nil {
return JobResult{nil, err}
}
// Get result
value, err := state.ToValue(-1)
state.Pop(1) // Pop result
return JobResult{value, err}
}

View File

@ -1,741 +0,0 @@
package workers
import (
"context"
"sync"
"testing"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Helper function to create bytecode for testing
func createTestBytecode(t *testing.T, code string) []byte {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("Failed to compile test bytecode: %v", err)
}
return bytecode
}
// Test creating a new event loop with default and custom configs
func TestNewEventLoop(t *testing.T) {
tests := []struct {
name string
config EventLoopConfig
expectError bool
}{
{
name: "Default config",
config: EventLoopConfig{},
expectError: false,
},
{
name: "Custom buffer size",
config: EventLoopConfig{
BufferSize: 200,
},
expectError: false,
},
{
name: "Custom timeout",
config: EventLoopConfig{
Timeout: 5 * time.Second,
},
expectError: false,
},
{
name: "With init function",
config: EventLoopConfig{
StateInit: func(state *luajit.State) error {
return nil
},
},
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
el, err := NewEventLoopWithConfig(tc.config)
if tc.expectError {
if err == nil {
t.Errorf("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if el == nil {
t.Errorf("Expected non-nil event loop")
} else {
el.Shutdown()
}
}
})
}
}
// Test basic job submission and execution
func TestEventLoopBasicSubmission(t *testing.T) {
el, err := NewEventLoop()
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
// Simple return a value
bytecode := createTestBytecode(t, "return 42")
result, err := el.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Failed to submit job: %v", err)
}
num, ok := result.(float64)
if !ok {
t.Fatalf("Expected float64 result, got %T", result)
}
if num != 42 {
t.Errorf("Expected 42, got %f", num)
}
// Test more complex Lua code
bytecode = createTestBytecode(t, `
local result = 0
for i = 1, 10 do
result = result + i
end
return result
`)
result, err = el.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Failed to submit job: %v", err)
}
num, ok = result.(float64)
if !ok {
t.Fatalf("Expected float64 result, got %T", result)
}
if num != 55 {
t.Errorf("Expected 55, got %f", num)
}
}
// Test context passing between Go and Lua
func TestEventLoopContext(t *testing.T) {
el, err := NewEventLoop()
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
bytecode := createTestBytecode(t, `
return {
num = ctx.number,
str = ctx.text,
flag = ctx.enabled,
list = {ctx.items[1], ctx.items[2], ctx.items[3]},
}
`)
execCtx := NewContext()
execCtx.Set("number", 42.5)
execCtx.Set("text", "hello")
execCtx.Set("enabled", true)
execCtx.Set("items", []float64{10, 20, 30})
result, err := el.Submit(bytecode, execCtx)
if err != nil {
t.Fatalf("Failed to submit job: %v", err)
}
// Result should be a map
resultMap, ok := result.(map[string]any)
if !ok {
t.Fatalf("Expected map result, got %T", result)
}
// Check values
if resultMap["num"] != 42.5 {
t.Errorf("Expected num=42.5, got %v", resultMap["num"])
}
if resultMap["str"] != "hello" {
t.Errorf("Expected str=hello, got %v", resultMap["str"])
}
if resultMap["flag"] != true {
t.Errorf("Expected flag=true, got %v", resultMap["flag"])
}
arr, ok := resultMap["list"].([]float64)
if !ok {
t.Fatalf("Expected []float64, got %T", resultMap["list"])
}
expected := []float64{10, 20, 30}
for i, v := range expected {
if arr[i] != v {
t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i])
}
}
// Test complex nested context
nestedCtx := NewContext()
nestedCtx.Set("user", map[string]any{
"id": 123,
"name": "test user",
"roles": []any{
"admin",
"editor",
},
})
bytecode = createTestBytecode(t, `
return {
id = ctx.user.id,
name = ctx.user.name,
role1 = ctx.user.roles[1],
role2 = ctx.user.roles[2],
}
`)
result, err = el.Submit(bytecode, nestedCtx)
if err != nil {
t.Fatalf("Failed to submit job with nested context: %v", err)
}
resultMap, ok = result.(map[string]any)
if !ok {
t.Fatalf("Expected map result, got %T", result)
}
if resultMap["id"] != float64(123) {
t.Errorf("Expected id=123, got %v", resultMap["id"])
}
if resultMap["name"] != "test user" {
t.Errorf("Expected name='test user', got %v", resultMap["name"])
}
if resultMap["role1"] != "admin" {
t.Errorf("Expected role1='admin', got %v", resultMap["role1"])
}
if resultMap["role2"] != "editor" {
t.Errorf("Expected role2='editor', got %v", resultMap["role2"])
}
}
// Test execution timeout
func TestEventLoopTimeout(t *testing.T) {
// Create event loop with short timeout
el, err := NewEventLoopWithConfig(EventLoopConfig{
Timeout: 100 * time.Millisecond,
})
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
// Create bytecode that runs for longer than the timeout
bytecode := createTestBytecode(t, `
-- Loop for 500ms
local start = os.time()
while os.difftime(os.time(), start) < 0.5 do end
return "done"
`)
// This should time out
_, err = el.Submit(bytecode, nil)
if err != ErrExecutionTimeout {
t.Errorf("Expected timeout error, got: %v", err)
}
// Now set a longer timeout and try again
el.SetTimeout(1 * time.Second)
// This should succeed
result, err := el.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Expected success with longer timeout, got: %v", err)
}
if result != "done" {
t.Errorf("Expected 'done', got %v", result)
}
// Test per-call timeout with SubmitWithTimeout
bytecode = createTestBytecode(t, `
-- Loop for 300ms
local start = os.time()
while os.difftime(os.time(), start) < 0.3 do end
return "done again"
`)
// This should time out with a custom timeout
_, err = el.SubmitWithTimeout(bytecode, nil, 50*time.Millisecond)
if err == nil {
t.Errorf("Expected timeout error, got success")
}
// This should succeed with a longer custom timeout
result, err = el.SubmitWithTimeout(bytecode, nil, 500*time.Millisecond)
if err != nil {
t.Fatalf("Expected success with custom timeout, got: %v", err)
}
if result != "done again" {
t.Errorf("Expected 'done again', got %v", result)
}
}
// Test module registration and execution
func TestEventLoopModules(t *testing.T) {
// Define an init function that registers a simple "math" module
mathInit := func(state *luajit.State) error {
// Register the "add" function directly
err := state.RegisterGoFunction("add", func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1 // Return one result
})
if err != nil {
return err
}
// Register a math module with multiple functions
mathFuncs := map[string]luajit.GoFunction{
"multiply": func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a * b)
return 1
},
"subtract": func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a - b)
return 1
},
}
return RegisterModule(state, "math2", mathFuncs)
}
// Create an event loop with our init function
el, err := NewEventLoopWithInit(mathInit)
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
// Test the add function
bytecode1 := createTestBytecode(t, "return add(5, 7)")
result1, err := el.Submit(bytecode1, nil)
if err != nil {
t.Fatalf("Failed to call add function: %v", err)
}
num1, ok := result1.(float64)
if !ok || num1 != 12 {
t.Errorf("Expected add(5, 7) = 12, got %v", result1)
}
// Test the math2 module
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
result2, err := el.Submit(bytecode2, nil)
if err != nil {
t.Fatalf("Failed to call math2.multiply: %v", err)
}
num2, ok := result2.(float64)
if !ok || num2 != 48 {
t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
}
// Test multiple operations
bytecode3 := createTestBytecode(t, `
local a = add(10, 20)
local b = math2.subtract(a, 5)
return math2.multiply(b, 2)
`)
result3, err := el.Submit(bytecode3, nil)
if err != nil {
t.Fatalf("Failed to execute combined operations: %v", err)
}
num3, ok := result3.(float64)
if !ok || num3 != 50 {
t.Errorf("Expected ((10 + 20) - 5) * 2 = 50, got %v", result3)
}
}
// Test combined module init functions
func TestEventLoopCombinedModules(t *testing.T) {
// First init function adds a function to get a constant value
init1 := func(state *luajit.State) error {
return state.RegisterGoFunction("getAnswer", func(s *luajit.State) int {
s.PushNumber(42)
return 1
})
}
// Second init function registers a function that multiplies a number by 2
init2 := func(state *luajit.State) error {
return state.RegisterGoFunction("double", func(s *luajit.State) int {
n := s.ToNumber(1)
s.PushNumber(n * 2)
return 1
})
}
// Combine the init functions
combinedInit := CombineInitFuncs(init1, init2)
// Create an event loop with the combined init function
el, err := NewEventLoopWithInit(combinedInit)
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
// Test using both functions together in a single script
bytecode := createTestBytecode(t, "return double(getAnswer())")
result, err := el.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Failed to execute: %v", err)
}
num, ok := result.(float64)
if !ok || num != 84 {
t.Errorf("Expected double(getAnswer()) = 84, got %v", result)
}
}
// Test sandbox isolation between executions
func TestEventLoopSandboxIsolation(t *testing.T) {
el, err := NewEventLoop()
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
// Create a script that tries to modify a global variable
bytecode1 := createTestBytecode(t, `
-- Set a "global" variable
my_global = "test value"
return true
`)
_, err = el.Submit(bytecode1, nil)
if err != nil {
t.Fatalf("Failed to execute first script: %v", err)
}
// Now try to access that variable from another script
bytecode2 := createTestBytecode(t, `
-- Try to access the previously set global
return my_global ~= nil
`)
result, err := el.Submit(bytecode2, nil)
if err != nil {
t.Fatalf("Failed to execute second script: %v", err)
}
// The variable should not be accessible (sandbox isolation)
if result.(bool) {
t.Errorf("Expected sandbox isolation, but global variable was accessible")
}
}
// Test error handling
func TestEventLoopErrorHandling(t *testing.T) {
el, err := NewEventLoop()
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
// Test invalid bytecode
_, err = el.Submit([]byte("not valid bytecode"), nil)
if err == nil {
t.Errorf("Expected error for invalid bytecode, got nil")
}
// Test Lua runtime error
bytecode := createTestBytecode(t, `
error("intentional error")
return true
`)
_, err = el.Submit(bytecode, nil)
if err == nil {
t.Errorf("Expected error from Lua error() call, got nil")
}
// Test with nil context (should work fine)
bytecode = createTestBytecode(t, "return ctx == nil")
result, err := el.Submit(bytecode, nil)
if err != nil {
t.Errorf("Unexpected error with nil context: %v", err)
}
if result.(bool) != true {
t.Errorf("Expected ctx to be nil in Lua, but it wasn't")
}
// Test access to restricted library
bytecode = createTestBytecode(t, `
-- Try to access io library directly
return io ~= nil
`)
result, err = el.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Failed to execute sandbox test: %v", err)
}
// io should not be directly accessible
if result.(bool) {
t.Errorf("Expected io library to be restricted, but it was accessible")
}
}
// Test concurrent job submission
func TestEventLoopConcurrency(t *testing.T) {
el, err := NewEventLoopWithConfig(EventLoopConfig{
BufferSize: 100, // Buffer for concurrent submissions
Timeout: 5 * time.Second,
})
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
// Create bytecode that returns its input value
bytecode := createTestBytecode(t, "return ctx.n")
// Submit multiple jobs concurrently
const jobCount = 50
var wg sync.WaitGroup
results := make([]int, jobCount)
wg.Add(jobCount)
for i := 0; i < jobCount; i++ {
i := i // Capture loop variable
go func() {
defer wg.Done()
// Create context with job number
ctx := NewContext()
ctx.Set("n", float64(i))
// Submit job
result, err := el.Submit(bytecode, ctx)
if err != nil {
t.Errorf("Job %d failed: %v", i, err)
return
}
// Verify result matches job number
num, ok := result.(float64)
if !ok {
t.Errorf("Job %d: expected float64, got %T", i, result)
return
}
results[i] = int(num)
}()
}
wg.Wait()
// Verify all results
for i, res := range results {
if res != i && res != 0 { // 0 means error already logged
t.Errorf("Expected result[%d] = %d, got %d", i, i, res)
}
}
}
// Test state consistency across multiple calls
func TestEventLoopStateConsistency(t *testing.T) {
// Create an event loop with a module that maintains count between calls
initFunc := func(state *luajit.State) error {
// Create a closure that increments a counter in upvalue
code := `
-- Create a counter with initial value 0
local counter = 0
-- Create a function that returns and increments the counter
function get_next_count()
local current = counter
counter = counter + 1
return current
end
`
return state.DoString(code)
}
el, err := NewEventLoopWithInit(initFunc)
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
// Now run multiple scripts that call the counter function
bytecode := createTestBytecode(t, "return get_next_count()")
// Each call should return an incremented value
for i := 0; i < 5; i++ {
result, err := el.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Call %d failed: %v", i, err)
}
num, ok := result.(float64)
if !ok {
t.Fatalf("Expected float64 result, got %T", result)
}
if int(num) != i {
t.Errorf("Expected count %d, got %d", i, int(num))
}
}
}
// Test shutdown and cleanup
func TestEventLoopShutdown(t *testing.T) {
el, err := NewEventLoop()
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
// Submit a job to verify it works
bytecode := createTestBytecode(t, "return 42")
_, err = el.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Failed to submit job: %v", err)
}
// Shutdown
if err := el.Shutdown(); err != nil {
t.Errorf("Shutdown failed: %v", err)
}
// Submit after shutdown should fail
_, err = el.Submit(bytecode, nil)
if err != ErrLoopClosed {
t.Errorf("Expected ErrLoopClosed, got %v", err)
}
// Second shutdown should return error
if err := el.Shutdown(); err != ErrLoopClosed {
t.Errorf("Expected ErrLoopClosed on second shutdown, got %v", err)
}
}
// Test high load with multiple sequential and concurrent jobs
func TestEventLoopHighLoad(t *testing.T) {
el, err := NewEventLoopWithConfig(EventLoopConfig{
BufferSize: 1000, // Large buffer for high load
Timeout: 5 * time.Second,
})
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
// Sequential load test
bytecode := createTestBytecode(t, `
-- Do some work
local result = 0
for i = 1, 1000 do
result = result + i
end
return result
`)
start := time.Now()
for i := 0; i < 100; i++ {
_, err := el.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Sequential job %d failed: %v", i, err)
}
}
seqDuration := time.Since(start)
t.Logf("Sequential load test: 100 jobs in %v", seqDuration)
// Concurrent load test
start = time.Now()
var wg sync.WaitGroup
wg.Add(100)
for i := 0; i < 100; i++ {
go func() {
defer wg.Done()
_, err := el.Submit(bytecode, nil)
if err != nil {
t.Errorf("Concurrent job failed: %v", err)
}
}()
}
wg.Wait()
concDuration := time.Since(start)
t.Logf("Concurrent load test: 100 jobs in %v", concDuration)
}
// Test context cancellation
func TestEventLoopCancel(t *testing.T) {
el, err := NewEventLoop()
if err != nil {
t.Fatalf("Failed to create event loop: %v", err)
}
defer el.Shutdown()
// Create a long-running script
bytecode := createTestBytecode(t, `
-- Sleep for 500ms
local start = os.time()
while os.difftime(os.time(), start) < 0.5 do end
return "done"
`)
// Create a context that we can cancel
ctx, cancel := context.WithCancel(context.Background())
// Start execution in a goroutine
resultCh := make(chan any, 1)
errCh := make(chan error, 1)
go func() {
res, err := el.SubmitWithContext(ctx, bytecode, nil)
if err != nil {
errCh <- err
} else {
resultCh <- res
}
}()
// Cancel quickly
time.Sleep(50 * time.Millisecond)
cancel()
// Should get cancellation error
select {
case err := <-errCh:
if ctx.Err() == nil || err == nil {
t.Errorf("Expected context cancellation error")
}
case res := <-resultCh:
t.Errorf("Expected cancellation, got result: %v", res)
case <-time.After(1 * time.Second):
t.Errorf("Timed out waiting for cancellation")
}
}

346
core/workers/init_test.go Normal file
View File

@ -0,0 +1,346 @@
package workers
import (
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func TestModuleRegistration(t *testing.T) {
// Define an init function that registers a simple "math" module
mathInit := func(state *luajit.State) error {
// Register the "add" function
err := state.RegisterGoFunction("add", func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1 // Return one result
})
if err != nil {
return err
}
// Register a whole module
mathFuncs := map[string]luajit.GoFunction{
"multiply": func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a * b)
return 1
},
"subtract": func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a - b)
return 1
},
}
return RegisterModule(state, "math2", mathFuncs)
}
// Create a pool with our init function
pool, err := NewPoolWithInit(2, mathInit)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// Test the add function
bytecode1 := createTestBytecode(t, "return add(5, 7)")
result1, err := pool.Submit(bytecode1, nil)
if err != nil {
t.Fatalf("Failed to call add function: %v", err)
}
num1, ok := result1.(float64)
if !ok || num1 != 12 {
t.Errorf("Expected add(5, 7) = 12, got %v", result1)
}
// Test the math2 module
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
result2, err := pool.Submit(bytecode2, nil)
if err != nil {
t.Fatalf("Failed to call math2.multiply: %v", err)
}
num2, ok := result2.(float64)
if !ok || num2 != 48 {
t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
}
// Test multiple operations
bytecode3 := createTestBytecode(t, `
local a = add(10, 20)
local b = math2.subtract(a, 5)
return math2.multiply(b, 2)
`)
result3, err := pool.Submit(bytecode3, nil)
if err != nil {
t.Fatalf("Failed to execute combined operations: %v", err)
}
num3, ok := result3.(float64)
if !ok || num3 != 50 {
t.Errorf("Expected ((10 + 20) - 5) * 2 = 50, got %v", result3)
}
}
func TestModuleInitFunc(t *testing.T) {
// Define math module functions
mathModule := func() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"add": func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1
},
"multiply": func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a * b)
return 1
},
}
}
// Define string module functions
strModule := func() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"concat": func(s *luajit.State) int {
a := s.ToString(1)
b := s.ToString(2)
s.PushString(a + b)
return 1
},
}
}
// Create module map
modules := map[string]ModuleFunc{
"math2": mathModule,
"str": strModule,
}
// Create pool with module init
pool, err := NewPoolWithInit(2, ModuleInitFunc(modules))
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// Test math module
bytecode1 := createTestBytecode(t, "return math2.add(5, 7)")
result1, err := pool.Submit(bytecode1, nil)
if err != nil {
t.Fatalf("Failed to call math2.add: %v", err)
}
num1, ok := result1.(float64)
if !ok || num1 != 12 {
t.Errorf("Expected math2.add(5, 7) = 12, got %v", result1)
}
// Test string module
bytecode2 := createTestBytecode(t, "return str.concat('hello', 'world')")
result2, err := pool.Submit(bytecode2, nil)
if err != nil {
t.Fatalf("Failed to call str.concat: %v", err)
}
str2, ok := result2.(string)
if !ok || str2 != "helloworld" {
t.Errorf("Expected str.concat('hello', 'world') = 'helloworld', got %v", result2)
}
}
func TestCombineInitFuncs(t *testing.T) {
// First init function adds a function to get a constant value
init1 := func(state *luajit.State) error {
return state.RegisterGoFunction("getAnswer", func(s *luajit.State) int {
s.PushNumber(42)
return 1
})
}
// Second init function registers a function that multiplies a number by 2
init2 := func(state *luajit.State) error {
return state.RegisterGoFunction("double", func(s *luajit.State) int {
n := s.ToNumber(1)
s.PushNumber(n * 2)
return 1
})
}
// Combine the init functions
combinedInit := CombineInitFuncs(init1, init2)
// Create a pool with the combined init function
pool, err := NewPoolWithInit(1, combinedInit)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// Test using both functions together in a single script
bytecode := createTestBytecode(t, "return double(getAnswer())")
result, err := pool.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Failed to execute: %v", err)
}
num, ok := result.(float64)
if !ok || num != 84 {
t.Errorf("Expected double(getAnswer()) = 84, got %v", result)
}
}
func TestSandboxIsolation(t *testing.T) {
// Create a pool
pool, err := NewPool(2)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// Create a script that tries to modify a global variable
bytecode1 := createTestBytecode(t, `
-- Set a "global" variable
my_global = "test value"
return true
`)
_, err = pool.Submit(bytecode1, nil)
if err != nil {
t.Fatalf("Failed to execute first script: %v", err)
}
// Now try to access that variable from another script
bytecode2 := createTestBytecode(t, `
-- Try to access the previously set global
return my_global ~= nil
`)
result, err := pool.Submit(bytecode2, nil)
if err != nil {
t.Fatalf("Failed to execute second script: %v", err)
}
// The variable should not be accessible (sandbox isolation)
if result.(bool) {
t.Errorf("Expected sandbox isolation, but global variable was accessible")
}
}
func TestContextInSandbox(t *testing.T) {
// Create a pool
pool, err := NewPool(2)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// Create a context with test data
ctx := NewContext()
ctx.Set("name", "test")
ctx.Set("value", 42.5)
ctx.Set("items", []float64{1, 2, 3})
bytecode := createTestBytecode(t, `
-- Access and manipulate context values
local sum = 0
for i, v in ipairs(ctx.items) do
sum = sum + v
end
return {
name_length = string.len(ctx.name),
value_doubled = ctx.value * 2,
items_sum = sum
}
`)
result, err := pool.Submit(bytecode, ctx)
if err != nil {
t.Fatalf("Failed to execute script with context: %v", err)
}
resultMap, ok := result.(map[string]any)
if !ok {
t.Fatalf("Expected map result, got %T", result)
}
// Check context values were correctly accessible
if resultMap["name_length"].(float64) != 4 {
t.Errorf("Expected name_length = 4, got %v", resultMap["name_length"])
}
if resultMap["value_doubled"].(float64) != 85 {
t.Errorf("Expected value_doubled = 85, got %v", resultMap["value_doubled"])
}
if resultMap["items_sum"].(float64) != 6 {
t.Errorf("Expected items_sum = 6, got %v", resultMap["items_sum"])
}
}
func TestStandardLibsInSandbox(t *testing.T) {
// Create a pool
pool, err := NewPool(2)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// Test access to standard libraries
bytecode := createTestBytecode(t, `
local results = {}
-- Test string library
results.string_upper = string.upper("test")
-- Test math library
results.math_sqrt = math.sqrt(16)
-- Test table library
local tbl = {10, 20, 30}
table.insert(tbl, 40)
results.table_length = #tbl
-- Test os library (limited functions)
results.has_os_time = type(os.time) == "function"
return results
`)
result, err := pool.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Failed to execute script: %v", err)
}
resultMap, ok := result.(map[string]any)
if !ok {
t.Fatalf("Expected map result, got %T", result)
}
// Check standard library functions worked
if resultMap["string_upper"] != "TEST" {
t.Errorf("Expected string_upper = 'TEST', got %v", resultMap["string_upper"])
}
if resultMap["math_sqrt"].(float64) != 4 {
t.Errorf("Expected math_sqrt = 4, got %v", resultMap["math_sqrt"])
}
if resultMap["table_length"].(float64) != 4 {
t.Errorf("Expected table_length = 4, got %v", resultMap["table_length"])
}
if resultMap["has_os_time"] != true {
t.Errorf("Expected has_os_time = true, got %v", resultMap["has_os_time"])
}
}

123
core/workers/pool.go Normal file
View File

@ -0,0 +1,123 @@
package workers
import (
"context"
"sync"
"sync/atomic"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// StateInitFunc is a function that initializes a Lua state
// It can be used to register custom functions and modules
type StateInitFunc func(*luajit.State) error
// Pool manages a pool of Lua worker goroutines
type Pool struct {
workers uint32 // Number of workers
jobs chan job // Channel to send jobs to workers
wg sync.WaitGroup // WaitGroup to track active workers
quit chan struct{} // Channel to signal shutdown
isRunning atomic.Bool // Flag to track if pool is running
stateInit StateInitFunc // Optional function to initialize Lua state
}
// NewPool creates a new worker pool with the specified number of workers
func NewPool(numWorkers int) (*Pool, error) {
return NewPoolWithInit(numWorkers, nil)
}
// NewPoolWithInit creates a new worker pool with the specified number of workers
// and a function to initialize each worker's Lua state
func NewPoolWithInit(numWorkers int, initFunc StateInitFunc) (*Pool, error) {
if numWorkers <= 0 {
return nil, ErrNoWorkers
}
p := &Pool{
workers: uint32(numWorkers),
jobs: make(chan job, numWorkers), // Buffer equal to worker count
quit: make(chan struct{}),
stateInit: initFunc,
}
p.isRunning.Store(true)
// Start workers
p.wg.Add(numWorkers)
for i := 0; i < numWorkers; i++ {
w := &worker{
pool: p,
id: uint32(i),
}
go w.run()
}
return p, nil
}
// RegisterGlobal is no longer needed with the sandbox approach
// but kept as a no-op for backward compatibility
func (p *Pool) RegisterGlobal(name string) {
// No-op in sandbox mode
}
// SubmitWithContext sends a job to the worker pool with context
func (p *Pool) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) {
if !p.isRunning.Load() {
return nil, ErrPoolClosed
}
resultChan := make(chan JobResult, 1)
j := job{
Bytecode: bytecode,
Context: execCtx,
Result: resultChan,
}
// Submit job with context
select {
case p.jobs <- j:
// Job submitted
case <-ctx.Done():
return nil, ctx.Err()
}
// Wait for result with context
select {
case result := <-resultChan:
return result.Value, result.Error
case <-ctx.Done():
// Note: The job will still be processed by a worker,
// but the result will be discarded
return nil, ctx.Err()
}
}
// Submit sends a job to the worker pool
func (p *Pool) Submit(bytecode []byte, execCtx *Context) (any, error) {
return p.SubmitWithContext(context.Background(), bytecode, execCtx)
}
// Shutdown gracefully shuts down the worker pool
func (p *Pool) Shutdown() error {
if !p.isRunning.Load() {
return ErrPoolClosed
}
p.isRunning.Store(false)
// Signal workers to quit
close(p.quit)
// Wait for workers to finish
p.wg.Wait()
// Close jobs channel
close(p.jobs)
return nil
}
// ActiveWorkers returns the number of active workers
func (p *Pool) ActiveWorkers() uint32 {
return atomic.LoadUint32(&p.workers)
}

144
core/workers/sandbox.go Normal file
View File

@ -0,0 +1,144 @@
package workers
// setupSandbox initializes the sandbox environment creation function
func (w *worker) setupSandbox() error {
// This is the Lua script that creates our sandbox function
setupScript := `
-- Create a function to run code in a sandbox environment
function __create_sandbox()
-- Create new environment table
local env = {}
-- Add standard library modules (can be restricted as needed)
env.string = string
env.table = table
env.math = math
env.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
env.tonumber = tonumber
env.tostring = tostring
env.type = type
env.pairs = pairs
env.ipairs = ipairs
env.next = next
env.select = select
env.unpack = unpack
env.pcall = pcall
env.xpcall = xpcall
env.error = error
env.assert = assert
-- Allow access to package.loaded for modules
env.require = function(name)
return package.loaded[name]
end
-- Create metatable to restrict access to _G
local mt = {
__index = function(t, k)
-- First check in env table
local v = rawget(env, k)
if v ~= nil then return v end
-- If not found, check for registered modules/functions
local moduleValue = _G[k]
if type(moduleValue) == "table" or
type(moduleValue) == "function" then
return moduleValue
end
return nil
end,
__newindex = function(t, k, v)
rawset(env, k, v)
end
}
setmetatable(env, mt)
return env
end
-- Create function to execute code with a sandbox
function __run_sandboxed(f, ctx)
local env = __create_sandbox()
-- Add context to the environment if provided
if ctx then
env.ctx = ctx
end
-- Set the environment and run the function
setfenv(f, env)
return f()
end
`
return w.state.DoString(setupScript)
}
// executeJobSandboxed runs a script in a sandbox environment
func (w *worker) executeJobSandboxed(j job) JobResult {
// No need to reset the state for each execution, since we're using a sandbox
// Re-run init function to register functions and modules if needed
if w.pool.stateInit != nil {
if err := w.pool.stateInit(w.state); err != nil {
return JobResult{nil, err}
}
}
// Set up context if provided
if j.Context != nil {
// Push context table
w.state.NewTable()
// Add values to context table
for key, value := range j.Context.Values {
// Push key
w.state.PushString(key)
// Push value
if err := w.state.PushValue(value); err != nil {
return JobResult{nil, err}
}
// Set table[key] = value
w.state.SetTable(-3)
}
} else {
// Push nil if no context
w.state.PushNil()
}
// Load bytecode
if err := w.state.LoadBytecode(j.Bytecode, "script"); err != nil {
w.state.Pop(1) // Pop context
return JobResult{nil, err}
}
// Get the sandbox runner function
w.state.GetGlobal("__run_sandboxed")
// Push loaded function and context as arguments
w.state.PushCopy(-2) // Copy the loaded function
w.state.PushCopy(-4) // Copy the context table or nil
// Remove the original function and context
w.state.Remove(-5) // Remove original context
w.state.Remove(-4) // Remove original function
// Call the sandbox runner with 2 args (function and context), expecting 1 result
if err := w.state.Call(2, 1); err != nil {
return JobResult{nil, err}
}
// Get result
value, err := w.state.ToValue(-1)
w.state.Pop(1) // Pop result
return JobResult{value, err}
}

71
core/workers/worker.go Normal file
View File

@ -0,0 +1,71 @@
package workers
import (
"errors"
"sync/atomic"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Common errors
var (
ErrPoolClosed = errors.New("worker pool is closed")
ErrNoWorkers = errors.New("no workers available")
ErrInitFailed = errors.New("worker initialization failed")
)
// worker represents a single Lua execution worker
type worker struct {
pool *Pool // Reference to the pool
state *luajit.State // Lua state
id uint32 // Worker ID
}
// run is the main worker function that processes jobs
func (w *worker) run() {
defer w.pool.wg.Done()
// Initialize Lua state
w.state = luajit.New()
if w.state == nil {
// Worker failed to initialize, decrement counter
atomic.AddUint32(&w.pool.workers, ^uint32(0))
return
}
defer w.state.Close()
// Set up sandbox environment
if err := w.setupSandbox(); err != nil {
// Worker failed to initialize sandbox, decrement counter
atomic.AddUint32(&w.pool.workers, ^uint32(0))
return
}
// Run init function if provided
if w.pool.stateInit != nil {
if err := w.pool.stateInit(w.state); err != nil {
// Worker failed to initialize with custom init function
atomic.AddUint32(&w.pool.workers, ^uint32(0))
return
}
}
// Main worker loop
for {
select {
case job, ok := <-w.pool.jobs:
if !ok {
// Jobs channel closed, exit
return
}
// Execute job
result := w.executeJobSandboxed(job)
job.Result <- result
case <-w.pool.quit:
// Quit signal received, exit
return
}
}
}

View File

@ -0,0 +1,445 @@
package workers
import (
"context"
"testing"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// This helper function creates real LuaJIT bytecode for our tests. Instead of using
// mocks, we compile actual Lua code into bytecode just like we would in production.
func createTestBytecode(t *testing.T, code string) []byte {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("Failed to compile test bytecode: %v", err)
}
return bytecode
}
// This test makes sure we can create a worker pool with a valid number of workers,
// and that we properly reject attempts to create a pool with zero or negative workers.
func TestNewPool(t *testing.T) {
tests := []struct {
name string
workers int
expectErr bool
}{
{"valid workers", 4, false},
{"zero workers", 0, true},
{"negative workers", -1, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pool, err := NewPool(tt.workers)
if tt.expectErr {
if err == nil {
t.Errorf("Expected error for %d workers, got nil", tt.workers)
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if pool == nil {
t.Errorf("Expected non-nil pool")
} else {
pool.Shutdown()
}
}
})
}
}
// Here we're testing the basic job submission flow. We run a simple Lua script
// that returns the number 42 and make sure we get that same value back from the worker pool.
func TestPoolSubmit(t *testing.T) {
pool, err := NewPool(2)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
bytecode := createTestBytecode(t, "return 42")
result, err := pool.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Failed to submit job: %v", err)
}
num, ok := result.(float64)
if !ok {
t.Fatalf("Expected float64 result, got %T", result)
}
if num != 42 {
t.Errorf("Expected 42, got %f", num)
}
}
// This test checks how our worker pool handles timeouts. We run a script that takes
// some time to complete and verify two scenarios: one where the timeout is long enough
// for successful completion, and another where we expect the operation to be canceled
// due to a short timeout.
func TestPoolSubmitWithContext(t *testing.T) {
pool, err := NewPool(2)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// Create bytecode that sleeps
bytecode := createTestBytecode(t, `
-- Sleep for 500ms
local start = os.time()
while os.difftime(os.time(), start) < 0.5 do end
return "done"
`)
// Test with timeout that should succeed
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
result, err := pool.SubmitWithContext(ctx, bytecode, nil)
if err != nil {
t.Fatalf("Unexpected error with sufficient timeout: %v", err)
}
if result != "done" {
t.Errorf("Expected 'done', got %v", result)
}
// Test with timeout that should fail
ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err = pool.SubmitWithContext(ctx, bytecode, nil)
if err == nil {
t.Errorf("Expected timeout error, got nil")
}
}
// We need to make sure we can pass different types of context values from Go to Lua and
// get them back properly. This test sends numbers, strings, booleans, and arrays to
// a Lua script and verifies they're all handled correctly in both directions.
func TestContextValues(t *testing.T) {
pool, err := NewPool(2)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
bytecode := createTestBytecode(t, `
return {
num = ctx.number,
str = ctx.text,
flag = ctx.enabled,
list = {ctx.table[1], ctx.table[2], ctx.table[3]},
}
`)
execCtx := NewContext()
execCtx.Set("number", 42.5)
execCtx.Set("text", "hello")
execCtx.Set("enabled", true)
execCtx.Set("table", []float64{10, 20, 30})
result, err := pool.Submit(bytecode, execCtx)
if err != nil {
t.Fatalf("Failed to submit job: %v", err)
}
// Result should be a map
resultMap, ok := result.(map[string]any)
if !ok {
t.Fatalf("Expected map result, got %T", result)
}
// Check values
if resultMap["num"] != 42.5 {
t.Errorf("Expected num=42.5, got %v", resultMap["num"])
}
if resultMap["str"] != "hello" {
t.Errorf("Expected str=hello, got %v", resultMap["str"])
}
if resultMap["flag"] != true {
t.Errorf("Expected flag=true, got %v", resultMap["flag"])
}
arr, ok := resultMap["list"].([]float64)
if !ok {
t.Fatalf("Expected []float64, got %T", resultMap["list"])
}
expected := []float64{10, 20, 30}
for i, v := range expected {
if arr[i] != v {
t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i])
}
}
}
// Test context with nested data structures
func TestNestedContext(t *testing.T) {
pool, err := NewPool(2)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
bytecode := createTestBytecode(t, `
return {
id = ctx.params.id,
name = ctx.params.name,
method = ctx.request.method,
path = ctx.request.path
}
`)
execCtx := NewContext()
// Set nested params
params := map[string]any{
"id": "123",
"name": "test",
}
execCtx.Set("params", params)
// Set nested request info
request := map[string]any{
"method": "GET",
"path": "/api/test",
}
execCtx.Set("request", request)
result, err := pool.Submit(bytecode, execCtx)
if err != nil {
t.Fatalf("Failed to submit job: %v", err)
}
// Result should be a map
resultMap, ok := result.(map[string]any)
if !ok {
t.Fatalf("Expected map result, got %T", result)
}
if resultMap["id"] != "123" {
t.Errorf("Expected id=123, got %v", resultMap["id"])
}
if resultMap["name"] != "test" {
t.Errorf("Expected name=test, got %v", resultMap["name"])
}
if resultMap["method"] != "GET" {
t.Errorf("Expected method=GET, got %v", resultMap["method"])
}
if resultMap["path"] != "/api/test" {
t.Errorf("Expected path=/api/test, got %v", resultMap["path"])
}
}
// A key requirement for our worker pool is that we don't leak state between executions.
// This test confirms that by setting a global variable in one job and then checking
// that it's been cleared before the next job runs on the same worker.
func TestStateReset(t *testing.T) {
pool, err := NewPool(1) // Use 1 worker to ensure same state is reused
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// First job sets a global
bytecode1 := createTestBytecode(t, `
global_var = "should be cleared"
return true
`)
// Second job checks if global exists
bytecode2 := createTestBytecode(t, `
return global_var ~= nil
`)
// Run first job
_, err = pool.Submit(bytecode1, nil)
if err != nil {
t.Fatalf("Failed to submit first job: %v", err)
}
// Run second job
result, err := pool.Submit(bytecode2, nil)
if err != nil {
t.Fatalf("Failed to submit second job: %v", err)
}
// Global should be cleared
if result.(bool) {
t.Errorf("Expected global_var to be cleared, but it still exists")
}
}
// Let's make sure our pool shuts down cleanly. This test confirms that jobs work
// before shutdown, that we get the right error when trying to submit after shutdown,
// and that we properly handle attempts to shut down an already closed pool.
func TestPoolShutdown(t *testing.T) {
pool, err := NewPool(2)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
// Submit a job to verify pool works
bytecode := createTestBytecode(t, "return 42")
_, err = pool.Submit(bytecode, nil)
if err != nil {
t.Fatalf("Failed to submit job: %v", err)
}
// Shutdown
if err := pool.Shutdown(); err != nil {
t.Errorf("Shutdown failed: %v", err)
}
// Submit after shutdown should fail
_, err = pool.Submit(bytecode, nil)
if err != ErrPoolClosed {
t.Errorf("Expected ErrPoolClosed, got %v", err)
}
// Second shutdown should return error
if err := pool.Shutdown(); err != ErrPoolClosed {
t.Errorf("Expected ErrPoolClosed on second shutdown, got %v", err)
}
}
// A robust worker pool needs to handle errors gracefully. This test checks various
// error scenarios: invalid bytecode, Lua runtime errors, nil context (which
// should work fine), and unsupported parameter types (which should properly error out).
func TestErrorHandling(t *testing.T) {
pool, err := NewPool(2)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// Test invalid bytecode
_, err = pool.Submit([]byte("not valid bytecode"), nil)
if err == nil {
t.Errorf("Expected error for invalid bytecode, got nil")
}
// Test Lua runtime error
bytecode := createTestBytecode(t, `
error("intentional error")
return true
`)
_, err = pool.Submit(bytecode, nil)
if err == nil {
t.Errorf("Expected error from Lua error() call, got nil")
}
// Test with nil context
bytecode = createTestBytecode(t, "return ctx == nil")
result, err := pool.Submit(bytecode, nil)
if err != nil {
t.Errorf("Unexpected error with nil context: %v", err)
}
if result.(bool) != true {
t.Errorf("Expected ctx to be nil in Lua, but it wasn't")
}
// Test invalid context value
execCtx := NewContext()
execCtx.Set("param", complex128(1+2i)) // Unsupported type
bytecode = createTestBytecode(t, "return ctx.param")
_, err = pool.Submit(bytecode, execCtx)
if err == nil {
t.Errorf("Expected error for unsupported context value type, got nil")
}
}
// The whole point of a worker pool is concurrent processing, so we need to verify
// it works under load. This test submits multiple jobs simultaneously and makes sure
// they all complete correctly with their own unique results.
func TestConcurrentExecution(t *testing.T) {
const workers = 4
const jobs = 20
pool, err := NewPool(workers)
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// Create bytecode that returns its input
bytecode := createTestBytecode(t, "return ctx.n")
// Run multiple jobs concurrently
results := make(chan int, jobs)
for i := 0; i < jobs; i++ {
i := i // Capture loop variable
go func() {
execCtx := NewContext()
execCtx.Set("n", float64(i))
result, err := pool.Submit(bytecode, execCtx)
if err != nil {
t.Errorf("Job %d failed: %v", i, err)
results <- -1
return
}
num, ok := result.(float64)
if !ok {
t.Errorf("Job %d: expected float64, got %T", i, result)
results <- -1
return
}
results <- int(num)
}()
}
// Collect results
counts := make(map[int]bool)
for i := 0; i < jobs; i++ {
result := <-results
if result != -1 {
counts[result] = true
}
}
// Verify all jobs were processed
if len(counts) != jobs {
t.Errorf("Expected %d unique results, got %d", jobs, len(counts))
}
}
// Test context operations
func TestContext(t *testing.T) {
ctx := NewContext()
// Test Set and Get
ctx.Set("key", "value")
if ctx.Get("key") != "value" {
t.Errorf("Expected value, got %v", ctx.Get("key"))
}
// Test overwriting
ctx.Set("key", 123)
if ctx.Get("key") != 123 {
t.Errorf("Expected 123, got %v", ctx.Get("key"))
}
// Test missing key
if ctx.Get("missing") != nil {
t.Errorf("Expected nil for missing key, got %v", ctx.Get("missing"))
}
}

View File

@ -47,9 +47,9 @@ func initRouters(routesDir, staticDir string, log *logger.Logger) (*routers.LuaR
func main() {
// Initialize logger
log := logger.New(logger.LevelDebug, true)
log := logger.New(logger.LevelInfo, true)
log.Server("Starting Moonshark server")
log.Info("Starting Moonshark server")
// Load configuration from config.lua
cfg, err := config.Load("config.lua")
@ -59,19 +59,6 @@ func main() {
cfg = config.New()
}
switch cfg.GetString("log_level", "info") {
case "debug":
log.SetLevel(logger.LevelDebug)
case "warn":
log.SetLevel(logger.LevelWarning)
case "error":
log.SetLevel(logger.LevelError)
case "fatal":
log.SetLevel(logger.LevelFatal)
default:
log.SetLevel(logger.LevelInfo)
}
// Get port from config or use default
port := cfg.GetInt("port", 3117)
@ -83,34 +70,32 @@ func main() {
log.Fatal("Router initialization failed: %v", err)
}
if cfg.GetBool("watchers", false) {
// Set up file watchers for automatic reloading
luaWatcher, err := watchers.WatchLuaRouter(luaRouter, routesDir, log)
if err != nil {
log.Warning("Failed to watch routes directory: %v", err)
} else {
defer luaWatcher.Close()
log.Info("File watcher active for Lua routes")
}
// Set up file watchers for automatic reloading
luaWatcher, err := watchers.WatchLuaRouter(luaRouter, routesDir, log)
if err != nil {
log.Warning("Failed to watch routes directory: %v", err)
} else {
defer luaWatcher.Close()
log.Info("File watcher active for Lua routes")
}
staticWatcher, err := watchers.WatchStaticRouter(staticRouter, staticDir, log)
if err != nil {
log.Warning("Failed to watch static directory: %v", err)
} else {
defer staticWatcher.Close()
log.Info("File watcher active for static files")
}
staticWatcher, err := watchers.WatchStaticRouter(staticRouter, staticDir, log)
if err != nil {
log.Warning("Failed to watch static directory: %v", err)
} else {
defer staticWatcher.Close()
log.Info("File watcher active for static files")
}
// Get worker pool size from config or use default
workerPoolSize := cfg.GetInt("pool_size", 4)
workerPoolSize := cfg.GetInt("worker_pool_size", 4)
// Initialize worker pool
pool, err := workers.NewPool(workerPoolSize)
if err != nil {
log.Fatal("Failed to initialize worker pool: %v", err)
}
log.Server("Worker pool initialized with %d workers", workerPoolSize)
log.Info("Worker pool initialized with %d workers", workerPoolSize)
defer pool.Shutdown()
// Create HTTP server
@ -129,11 +114,11 @@ func main() {
}
}()
log.Server("Server started on port %d", port)
log.Info("Server started on port %d", port)
// Wait for interrupt signal
<-stop
log.Server("Shutdown signal received")
log.Info("Shutdown signal received")
// Gracefully shut down the server
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@ -143,5 +128,5 @@ func main() {
log.Error("Server shutdown error: %v", err)
}
log.Server("Server stopped")
log.Info("Server stopped")
}