347 lines
8.3 KiB
Go
347 lines
8.3 KiB
Go
package workers
|
|
|
|
import (
|
|
"testing"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
func TestModuleRegistration(t *testing.T) {
|
|
// Define an init function that registers a simple "math" module
|
|
mathInit := func(state *luajit.State) error {
|
|
// Register the "add" function
|
|
err := state.RegisterGoFunction("add", func(s *luajit.State) int {
|
|
a := s.ToNumber(1)
|
|
b := s.ToNumber(2)
|
|
s.PushNumber(a + b)
|
|
return 1 // Return one result
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Register a whole module
|
|
mathFuncs := map[string]luajit.GoFunction{
|
|
"multiply": func(s *luajit.State) int {
|
|
a := s.ToNumber(1)
|
|
b := s.ToNumber(2)
|
|
s.PushNumber(a * b)
|
|
return 1
|
|
},
|
|
"subtract": func(s *luajit.State) int {
|
|
a := s.ToNumber(1)
|
|
b := s.ToNumber(2)
|
|
s.PushNumber(a - b)
|
|
return 1
|
|
},
|
|
}
|
|
|
|
return RegisterModule(state, "math2", mathFuncs)
|
|
}
|
|
|
|
// Create a pool with our init function
|
|
pool, err := NewPoolWithInit(2, mathInit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create pool: %v", err)
|
|
}
|
|
defer pool.Shutdown()
|
|
|
|
// Test the add function
|
|
bytecode1 := createTestBytecode(t, "return add(5, 7)")
|
|
result1, err := pool.Submit(bytecode1, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to call add function: %v", err)
|
|
}
|
|
|
|
num1, ok := result1.(float64)
|
|
if !ok || num1 != 12 {
|
|
t.Errorf("Expected add(5, 7) = 12, got %v", result1)
|
|
}
|
|
|
|
// Test the math2 module
|
|
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
|
|
result2, err := pool.Submit(bytecode2, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to call math2.multiply: %v", err)
|
|
}
|
|
|
|
num2, ok := result2.(float64)
|
|
if !ok || num2 != 48 {
|
|
t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
|
|
}
|
|
|
|
// Test multiple operations
|
|
bytecode3 := createTestBytecode(t, `
|
|
local a = add(10, 20)
|
|
local b = math2.subtract(a, 5)
|
|
return math2.multiply(b, 2)
|
|
`)
|
|
|
|
result3, err := pool.Submit(bytecode3, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute combined operations: %v", err)
|
|
}
|
|
|
|
num3, ok := result3.(float64)
|
|
if !ok || num3 != 50 {
|
|
t.Errorf("Expected ((10 + 20) - 5) * 2 = 50, got %v", result3)
|
|
}
|
|
}
|
|
|
|
func TestModuleInitFunc(t *testing.T) {
|
|
// Define math module functions
|
|
mathModule := func() map[string]luajit.GoFunction {
|
|
return map[string]luajit.GoFunction{
|
|
"add": func(s *luajit.State) int {
|
|
a := s.ToNumber(1)
|
|
b := s.ToNumber(2)
|
|
s.PushNumber(a + b)
|
|
return 1
|
|
},
|
|
"multiply": func(s *luajit.State) int {
|
|
a := s.ToNumber(1)
|
|
b := s.ToNumber(2)
|
|
s.PushNumber(a * b)
|
|
return 1
|
|
},
|
|
}
|
|
}
|
|
|
|
// Define string module functions
|
|
strModule := func() map[string]luajit.GoFunction {
|
|
return map[string]luajit.GoFunction{
|
|
"concat": func(s *luajit.State) int {
|
|
a := s.ToString(1)
|
|
b := s.ToString(2)
|
|
s.PushString(a + b)
|
|
return 1
|
|
},
|
|
}
|
|
}
|
|
|
|
// Create module map
|
|
modules := map[string]ModuleFunc{
|
|
"math2": mathModule,
|
|
"str": strModule,
|
|
}
|
|
|
|
// Create pool with module init
|
|
pool, err := NewPoolWithInit(2, ModuleInitFunc(modules))
|
|
if err != nil {
|
|
t.Fatalf("Failed to create pool: %v", err)
|
|
}
|
|
defer pool.Shutdown()
|
|
|
|
// Test math module
|
|
bytecode1 := createTestBytecode(t, "return math2.add(5, 7)")
|
|
result1, err := pool.Submit(bytecode1, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to call math2.add: %v", err)
|
|
}
|
|
|
|
num1, ok := result1.(float64)
|
|
if !ok || num1 != 12 {
|
|
t.Errorf("Expected math2.add(5, 7) = 12, got %v", result1)
|
|
}
|
|
|
|
// Test string module
|
|
bytecode2 := createTestBytecode(t, "return str.concat('hello', 'world')")
|
|
result2, err := pool.Submit(bytecode2, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to call str.concat: %v", err)
|
|
}
|
|
|
|
str2, ok := result2.(string)
|
|
if !ok || str2 != "helloworld" {
|
|
t.Errorf("Expected str.concat('hello', 'world') = 'helloworld', got %v", result2)
|
|
}
|
|
}
|
|
|
|
func TestCombineInitFuncs(t *testing.T) {
|
|
// First init function adds a function to get a constant value
|
|
init1 := func(state *luajit.State) error {
|
|
return state.RegisterGoFunction("getAnswer", func(s *luajit.State) int {
|
|
s.PushNumber(42)
|
|
return 1
|
|
})
|
|
}
|
|
|
|
// Second init function registers a function that multiplies a number by 2
|
|
init2 := func(state *luajit.State) error {
|
|
return state.RegisterGoFunction("double", func(s *luajit.State) int {
|
|
n := s.ToNumber(1)
|
|
s.PushNumber(n * 2)
|
|
return 1
|
|
})
|
|
}
|
|
|
|
// Combine the init functions
|
|
combinedInit := CombineInitFuncs(init1, init2)
|
|
|
|
// Create a pool with the combined init function
|
|
pool, err := NewPoolWithInit(1, combinedInit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create pool: %v", err)
|
|
}
|
|
defer pool.Shutdown()
|
|
|
|
// Test using both functions together in a single script
|
|
bytecode := createTestBytecode(t, "return double(getAnswer())")
|
|
result, err := pool.Submit(bytecode, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute: %v", err)
|
|
}
|
|
|
|
num, ok := result.(float64)
|
|
if !ok || num != 84 {
|
|
t.Errorf("Expected double(getAnswer()) = 84, got %v", result)
|
|
}
|
|
}
|
|
|
|
func TestSandboxIsolation(t *testing.T) {
|
|
// Create a pool
|
|
pool, err := NewPool(2)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create pool: %v", err)
|
|
}
|
|
defer pool.Shutdown()
|
|
|
|
// Create a script that tries to modify a global variable
|
|
bytecode1 := createTestBytecode(t, `
|
|
-- Set a "global" variable
|
|
my_global = "test value"
|
|
return true
|
|
`)
|
|
|
|
_, err = pool.Submit(bytecode1, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute first script: %v", err)
|
|
}
|
|
|
|
// Now try to access that variable from another script
|
|
bytecode2 := createTestBytecode(t, `
|
|
-- Try to access the previously set global
|
|
return my_global ~= nil
|
|
`)
|
|
|
|
result, err := pool.Submit(bytecode2, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute second script: %v", err)
|
|
}
|
|
|
|
// The variable should not be accessible (sandbox isolation)
|
|
if result.(bool) {
|
|
t.Errorf("Expected sandbox isolation, but global variable was accessible")
|
|
}
|
|
}
|
|
|
|
func TestContextInSandbox(t *testing.T) {
|
|
// Create a pool
|
|
pool, err := NewPool(2)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create pool: %v", err)
|
|
}
|
|
defer pool.Shutdown()
|
|
|
|
// Create a context with test data
|
|
ctx := NewContext()
|
|
ctx.Set("name", "test")
|
|
ctx.Set("value", 42.5)
|
|
ctx.Set("items", []float64{1, 2, 3})
|
|
|
|
bytecode := createTestBytecode(t, `
|
|
-- Access and manipulate context values
|
|
local sum = 0
|
|
for i, v in ipairs(ctx.items) do
|
|
sum = sum + v
|
|
end
|
|
|
|
return {
|
|
name_length = string.len(ctx.name),
|
|
value_doubled = ctx.value * 2,
|
|
items_sum = sum
|
|
}
|
|
`)
|
|
|
|
result, err := pool.Submit(bytecode, ctx)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute script with context: %v", err)
|
|
}
|
|
|
|
resultMap, ok := result.(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("Expected map result, got %T", result)
|
|
}
|
|
|
|
// Check context values were correctly accessible
|
|
if resultMap["name_length"].(float64) != 4 {
|
|
t.Errorf("Expected name_length = 4, got %v", resultMap["name_length"])
|
|
}
|
|
|
|
if resultMap["value_doubled"].(float64) != 85 {
|
|
t.Errorf("Expected value_doubled = 85, got %v", resultMap["value_doubled"])
|
|
}
|
|
|
|
if resultMap["items_sum"].(float64) != 6 {
|
|
t.Errorf("Expected items_sum = 6, got %v", resultMap["items_sum"])
|
|
}
|
|
}
|
|
|
|
func TestStandardLibsInSandbox(t *testing.T) {
|
|
// Create a pool
|
|
pool, err := NewPool(2)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create pool: %v", err)
|
|
}
|
|
defer pool.Shutdown()
|
|
|
|
// Test access to standard libraries
|
|
bytecode := createTestBytecode(t, `
|
|
local results = {}
|
|
|
|
-- Test string library
|
|
results.string_upper = string.upper("test")
|
|
|
|
-- Test math library
|
|
results.math_sqrt = math.sqrt(16)
|
|
|
|
-- Test table library
|
|
local tbl = {10, 20, 30}
|
|
table.insert(tbl, 40)
|
|
results.table_length = #tbl
|
|
|
|
-- Test os library (limited functions)
|
|
results.has_os_time = type(os.time) == "function"
|
|
|
|
return results
|
|
`)
|
|
|
|
result, err := pool.Submit(bytecode, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute script: %v", err)
|
|
}
|
|
|
|
resultMap, ok := result.(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("Expected map result, got %T", result)
|
|
}
|
|
|
|
// Check standard library functions worked
|
|
if resultMap["string_upper"] != "TEST" {
|
|
t.Errorf("Expected string_upper = 'TEST', got %v", resultMap["string_upper"])
|
|
}
|
|
|
|
if resultMap["math_sqrt"].(float64) != 4 {
|
|
t.Errorf("Expected math_sqrt = 4, got %v", resultMap["math_sqrt"])
|
|
}
|
|
|
|
if resultMap["table_length"].(float64) != 4 {
|
|
t.Errorf("Expected table_length = 4, got %v", resultMap["table_length"])
|
|
}
|
|
|
|
if resultMap["has_os_time"] != true {
|
|
t.Errorf("Expected has_os_time = true, got %v", resultMap["has_os_time"])
|
|
}
|
|
}
|