Moonshark/core/workers/init_test.go

347 lines
8.3 KiB
Go

package workers
import (
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// TestModuleRegistration verifies that a pool init function can expose both a
// single global Go function ("add") and a table of Go functions registered via
// RegisterModule ("math2"), and that scripts submitted to the pool can call
// them individually and in combination.
func TestModuleRegistration(t *testing.T) {
	registerMath := func(ls *luajit.State) error {
		// Global function callable from Lua as add(a, b).
		if err := ls.RegisterGoFunction("add", func(s *luajit.State) int {
			s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
			return 1 // one result left on the Lua stack
		}); err != nil {
			return err
		}

		// Module table exposed to Lua as math2.multiply / math2.subtract.
		return RegisterModule(ls, "math2", map[string]luajit.GoFunction{
			"multiply": func(s *luajit.State) int {
				s.PushNumber(s.ToNumber(1) * s.ToNumber(2))
				return 1
			},
			"subtract": func(s *luajit.State) int {
				s.PushNumber(s.ToNumber(1) - s.ToNumber(2))
				return 1
			},
		})
	}

	pool, err := NewPoolWithInit(2, registerMath)
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	// Each case is run in order; a Submit error aborts the test, a wrong
	// value is reported and the remaining cases still run.
	cases := []struct {
		script string
		desc   string  // prefix of the fatal message when Submit fails
		expr   string  // expression echoed in the mismatch message
		want   float64 // expected numeric result
	}{
		{"return add(5, 7)", "Failed to call add function", "add(5, 7)", 12},
		{"return math2.multiply(6, 8)", "Failed to call math2.multiply", "math2.multiply(6, 8)", 48},
		{`
local a = add(10, 20)
local b = math2.subtract(a, 5)
return math2.multiply(b, 2)
`, "Failed to execute combined operations", "((10 + 20) - 5) * 2", 50},
	}

	for _, tc := range cases {
		out, err := pool.Submit(createTestBytecode(t, tc.script), nil)
		if err != nil {
			t.Fatalf("%s: %v", tc.desc, err)
		}
		if got, ok := out.(float64); !ok || got != tc.want {
			t.Errorf("Expected %s = %v, got %v", tc.expr, tc.want, out)
		}
	}
}
func TestModuleInitFunc(t *testing.T) {
// Define math module functions
mathModule := func() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"add": func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1
},
"multiply": func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a * b)
return 1
},
}
}
// Define string module functions
strModule := func() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"concat": func(s *luajit.State) int {
a := s.ToString(1)
b := s.ToString(2)
s.PushString(a + b)
return 1
},
}
}
// Create module map
modules := map[string]ModuleFunc{
"math2": mathModule,
"str": strModule,
}
// Create pool with module init
pool, err := NewPoolWithInit(2, ModuleInitFunc(modules))
if err != nil {
t.Fatalf("Failed to create pool: %v", err)
}
defer pool.Shutdown()
// Test math module
bytecode1 := createTestBytecode(t, "return math2.add(5, 7)")
result1, err := pool.Submit(bytecode1, nil)
if err != nil {
t.Fatalf("Failed to call math2.add: %v", err)
}
num1, ok := result1.(float64)
if !ok || num1 != 12 {
t.Errorf("Expected math2.add(5, 7) = 12, got %v", result1)
}
// Test string module
bytecode2 := createTestBytecode(t, "return str.concat('hello', 'world')")
result2, err := pool.Submit(bytecode2, nil)
if err != nil {
t.Fatalf("Failed to call str.concat: %v", err)
}
str2, ok := result2.(string)
if !ok || str2 != "helloworld" {
t.Errorf("Expected str.concat('hello', 'world') = 'helloworld', got %v", result2)
}
}
// TestCombineInitFuncs ensures that CombineInitFuncs applies every supplied
// init function, so globals registered by each one are usable together from a
// single script.
func TestCombineInitFuncs(t *testing.T) {
	// Registers getAnswer(), which always returns 42.
	answerInit := func(ls *luajit.State) error {
		return ls.RegisterGoFunction("getAnswer", func(s *luajit.State) int {
			s.PushNumber(42)
			return 1
		})
	}

	// Registers double(n), which returns n * 2.
	doubleInit := func(ls *luajit.State) error {
		return ls.RegisterGoFunction("double", func(s *luajit.State) int {
			s.PushNumber(s.ToNumber(1) * 2)
			return 1
		})
	}

	pool, err := NewPoolWithInit(1, CombineInitFuncs(answerInit, doubleInit))
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	out, err := pool.Submit(createTestBytecode(t, "return double(getAnswer())"), nil)
	if err != nil {
		t.Fatalf("Failed to execute: %v", err)
	}
	if got, ok := out.(float64); !ok || got != 84 {
		t.Errorf("Expected double(getAnswer()) = 84, got %v", out)
	}
}
// TestSandboxIsolation checks that a global variable assigned by one script is
// not visible to a later script submitted to the same pool.
//
// NOTE(review): with two workers the second script may run on either worker;
// the test therefore relies on the pool isolating globals per execution, not
// merely per worker — confirm this is the intended guarantee.
func TestSandboxIsolation(t *testing.T) {
	pool, err := NewPool(2)
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	// First script assigns an undeclared (global) variable.
	setGlobal := createTestBytecode(t, `
-- Set a "global" variable
my_global = "test value"
return true
`)
	if _, err := pool.Submit(setGlobal, nil); err != nil {
		t.Fatalf("Failed to execute first script: %v", err)
	}

	// Second script reports whether that global is still reachable.
	readGlobal := createTestBytecode(t, `
-- Try to access the previously set global
return my_global ~= nil
`)
	result, err := pool.Submit(readGlobal, nil)
	if err != nil {
		t.Fatalf("Failed to execute second script: %v", err)
	}

	// Comma-ok assertion: an unexpected result type (e.g. nil) fails the test
	// with a message instead of panicking the whole test binary.
	leaked, ok := result.(bool)
	if !ok {
		t.Fatalf("Expected bool result, got %T (%v)", result, result)
	}
	if leaked {
		t.Errorf("Expected sandbox isolation, but global variable was accessible")
	}
}
// TestContextInSandbox checks that values stored in a Context (a string, a
// number and a []float64) are readable from a script through the `ctx` global
// and can be used with standard Lua operations (string.len, arithmetic,
// ipairs). The script returns a table, which is expected to arrive in Go as a
// map[string]any with numeric fields as float64.
func TestContextInSandbox(t *testing.T) {
	pool, err := NewPool(2)
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	ctx := NewContext()
	ctx.Set("name", "test")
	ctx.Set("value", 42.5)
	ctx.Set("items", []float64{1, 2, 3})

	bytecode := createTestBytecode(t, `
-- Access and manipulate context values
local sum = 0
for i, v in ipairs(ctx.items) do
sum = sum + v
end
return {
name_length = string.len(ctx.name),
value_doubled = ctx.value * 2,
items_sum = sum
}
`)
	result, err := pool.Submit(bytecode, ctx)
	if err != nil {
		t.Fatalf("Failed to execute script with context: %v", err)
	}
	resultMap, ok := result.(map[string]any)
	if !ok {
		t.Fatalf("Expected map result, got %T", result)
	}

	// Comma-ok assertions: a missing key or non-numeric value is reported as
	// a test failure rather than panicking on an unchecked type assertion.
	if v, ok := resultMap["name_length"].(float64); !ok || v != 4 {
		t.Errorf("Expected name_length = 4, got %v", resultMap["name_length"])
	}
	if v, ok := resultMap["value_doubled"].(float64); !ok || v != 85 {
		t.Errorf("Expected value_doubled = 85, got %v", resultMap["value_doubled"])
	}
	if v, ok := resultMap["items_sum"].(float64); !ok || v != 6 {
		t.Errorf("Expected items_sum = 6, got %v", resultMap["items_sum"])
	}
}
// TestStandardLibsInSandbox checks that scripts run by the pool can use the
// string, math and table libraries, and that os.time is available as a
// function. Results are returned as a Lua table and inspected as a
// map[string]any on the Go side.
func TestStandardLibsInSandbox(t *testing.T) {
	pool, err := NewPool(2)
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	bytecode := createTestBytecode(t, `
local results = {}
-- Test string library
results.string_upper = string.upper("test")
-- Test math library
results.math_sqrt = math.sqrt(16)
-- Test table library
local tbl = {10, 20, 30}
table.insert(tbl, 40)
results.table_length = #tbl
-- Test os library (limited functions)
results.has_os_time = type(os.time) == "function"
return results
`)
	result, err := pool.Submit(bytecode, nil)
	if err != nil {
		t.Fatalf("Failed to execute script: %v", err)
	}
	resultMap, ok := result.(map[string]any)
	if !ok {
		t.Fatalf("Expected map result, got %T", result)
	}

	// Interface comparisons are safe for missing keys (nil != "TEST").
	if resultMap["string_upper"] != "TEST" {
		t.Errorf("Expected string_upper = 'TEST', got %v", resultMap["string_upper"])
	}
	// Comma-ok assertions so a missing or non-numeric value fails the test
	// instead of panicking.
	if v, ok := resultMap["math_sqrt"].(float64); !ok || v != 4 {
		t.Errorf("Expected math_sqrt = 4, got %v", resultMap["math_sqrt"])
	}
	if v, ok := resultMap["table_length"].(float64); !ok || v != 4 {
		t.Errorf("Expected table_length = 4, got %v", resultMap["table_length"])
	}
	if resultMap["has_os_time"] != true {
		t.Errorf("Expected has_os_time = true, got %v", resultMap["has_os_time"])
	}
}