package workers

import (
	"testing"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)

// TestModuleRegistration checks that an init function can register both a
// standalone global Go function ("add") and a table of Go functions exposed
// as a module ("math2"), and that scripts can call them individually and in
// combination.
func TestModuleRegistration(t *testing.T) {
	// Define an init function that registers a simple "math" module.
	mathInit := func(state *luajit.State) error {
		// Register the "add" function as a global.
		err := state.RegisterGoFunction("add", func(s *luajit.State) int {
			a := s.ToNumber(1)
			b := s.ToNumber(2)
			s.PushNumber(a + b)
			return 1 // Return one result.
		})
		if err != nil {
			return err
		}

		// Register a whole module.
		mathFuncs := map[string]luajit.GoFunction{
			"multiply": func(s *luajit.State) int {
				a := s.ToNumber(1)
				b := s.ToNumber(2)
				s.PushNumber(a * b)
				return 1
			},
			"subtract": func(s *luajit.State) int {
				a := s.ToNumber(1)
				b := s.ToNumber(2)
				s.PushNumber(a - b)
				return 1
			},
		}
		return RegisterModule(state, "math2", mathFuncs)
	}

	// Create a pool with our init function.
	pool, err := NewPoolWithInit(2, mathInit)
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	// Test the add function.
	bytecode1 := createTestBytecode(t, "return add(5, 7)")
	result1, err := pool.Submit(bytecode1, nil)
	if err != nil {
		t.Fatalf("Failed to call add function: %v", err)
	}
	num1, ok := result1.(float64)
	if !ok || num1 != 12 {
		t.Errorf("Expected add(5, 7) = 12, got %v", result1)
	}

	// Test the math2 module.
	bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
	result2, err := pool.Submit(bytecode2, nil)
	if err != nil {
		t.Fatalf("Failed to call math2.multiply: %v", err)
	}
	num2, ok := result2.(float64)
	if !ok || num2 != 48 {
		t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
	}

	// Test multiple operations.
	bytecode3 := createTestBytecode(t, `
		local a = add(10, 20)
		local b = math2.subtract(a, 5)
		return math2.multiply(b, 2)
	`)
	result3, err := pool.Submit(bytecode3, nil)
	if err != nil {
		t.Fatalf("Failed to execute combined operations: %v", err)
	}
	num3, ok := result3.(float64)
	if !ok || num3 != 50 {
		t.Errorf("Expected ((10 + 20) - 5) * 2 = 50, got %v", result3)
	}
}

// TestModuleInitFunc checks that ModuleInitFunc turns a map of module names to
// ModuleFunc constructors into an init function that exposes each module to
// scripts under its map key.
func TestModuleInitFunc(t *testing.T) {
	// Define math module functions.
	mathModule := func() map[string]luajit.GoFunction {
		return map[string]luajit.GoFunction{
			"add": func(s *luajit.State) int {
				a := s.ToNumber(1)
				b := s.ToNumber(2)
				s.PushNumber(a + b)
				return 1
			},
			"multiply": func(s *luajit.State) int {
				a := s.ToNumber(1)
				b := s.ToNumber(2)
				s.PushNumber(a * b)
				return 1
			},
		}
	}

	// Define string module functions.
	strModule := func() map[string]luajit.GoFunction {
		return map[string]luajit.GoFunction{
			"concat": func(s *luajit.State) int {
				a := s.ToString(1)
				b := s.ToString(2)
				s.PushString(a + b)
				return 1
			},
		}
	}

	// Create module map.
	modules := map[string]ModuleFunc{
		"math2": mathModule,
		"str":   strModule,
	}

	// Create pool with module init.
	pool, err := NewPoolWithInit(2, ModuleInitFunc(modules))
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	// Test math module.
	bytecode1 := createTestBytecode(t, "return math2.add(5, 7)")
	result1, err := pool.Submit(bytecode1, nil)
	if err != nil {
		t.Fatalf("Failed to call math2.add: %v", err)
	}
	num1, ok := result1.(float64)
	if !ok || num1 != 12 {
		t.Errorf("Expected math2.add(5, 7) = 12, got %v", result1)
	}

	// Test string module.
	bytecode2 := createTestBytecode(t, "return str.concat('hello', 'world')")
	result2, err := pool.Submit(bytecode2, nil)
	if err != nil {
		t.Fatalf("Failed to call str.concat: %v", err)
	}
	str2, ok := result2.(string)
	if !ok || str2 != "helloworld" {
		t.Errorf("Expected str.concat('hello', 'world') = 'helloworld', got %v", result2)
	}
}

// TestCombineInitFuncs checks that CombineInitFuncs applies every supplied
// init function, so globals registered by each are usable from one script.
func TestCombineInitFuncs(t *testing.T) {
	// First init function adds a function to get a constant value.
	init1 := func(state *luajit.State) error {
		return state.RegisterGoFunction("getAnswer", func(s *luajit.State) int {
			s.PushNumber(42)
			return 1
		})
	}

	// Second init function registers a function that multiplies a number by 2.
	init2 := func(state *luajit.State) error {
		return state.RegisterGoFunction("double", func(s *luajit.State) int {
			n := s.ToNumber(1)
			s.PushNumber(n * 2)
			return 1
		})
	}

	// Combine the init functions.
	combinedInit := CombineInitFuncs(init1, init2)

	// Create a pool with the combined init function.
	pool, err := NewPoolWithInit(1, combinedInit)
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	// Test using both functions together in a single script.
	bytecode := createTestBytecode(t, "return double(getAnswer())")
	result, err := pool.Submit(bytecode, nil)
	if err != nil {
		t.Fatalf("Failed to execute: %v", err)
	}
	num, ok := result.(float64)
	if !ok || num != 84 {
		t.Errorf("Expected double(getAnswer()) = 84, got %v", result)
	}
}

// TestSandboxIsolation checks that a global assigned by one script is not
// visible to a later script submitted to the same pool.
func TestSandboxIsolation(t *testing.T) {
	// Use a single worker so both scripts are handled by the same worker.
	// With more workers the second script could land on a different state and
	// the test would pass without exercising isolation at all.
	// NOTE(review): assumes one worker owns one Lua state — confirm against Pool.
	pool, err := NewPool(1)
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	// Create a script that tries to modify a global variable.
	bytecode1 := createTestBytecode(t, `
		-- Set a "global" variable
		my_global = "test value"
		return true
	`)
	_, err = pool.Submit(bytecode1, nil)
	if err != nil {
		t.Fatalf("Failed to execute first script: %v", err)
	}

	// Now try to access that variable from another script.
	bytecode2 := createTestBytecode(t, `
		-- Try to access the previously set global
		return my_global ~= nil
	`)
	result, err := pool.Submit(bytecode2, nil)
	if err != nil {
		t.Fatalf("Failed to execute second script: %v", err)
	}

	// Checked assertion: an unexpected result type must fail the test rather
	// than panic the whole test binary.
	accessible, ok := result.(bool)
	if !ok {
		t.Fatalf("Expected bool result, got %T", result)
	}

	// The variable should not be accessible (sandbox isolation).
	if accessible {
		t.Errorf("Expected sandbox isolation, but global variable was accessible")
	}
}

// TestContextInSandbox checks that values stored in a Context are readable
// from a script through the `ctx` table, including a Go slice iterated with
// ipairs, and that a Lua table result comes back as map[string]any.
func TestContextInSandbox(t *testing.T) {
	// Create a pool.
	pool, err := NewPool(2)
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	// Create a context with test data.
	ctx := NewContext()
	ctx.Set("name", "test")
	ctx.Set("value", 42.5)
	ctx.Set("items", []float64{1, 2, 3})

	bytecode := createTestBytecode(t, `
		-- Access and manipulate context values
		local sum = 0
		for i, v in ipairs(ctx.items) do
			sum = sum + v
		end

		return {
			name_length = string.len(ctx.name),
			value_doubled = ctx.value * 2,
			items_sum = sum
		}
	`)

	result, err := pool.Submit(bytecode, ctx)
	if err != nil {
		t.Fatalf("Failed to execute script with context: %v", err)
	}

	resultMap, ok := result.(map[string]any)
	if !ok {
		t.Fatalf("Expected map result, got %T", result)
	}

	// Check context values were correctly accessible. Checked assertions keep
	// a missing key or wrong type from panicking the test.
	if v, ok := resultMap["name_length"].(float64); !ok || v != 4 {
		t.Errorf("Expected name_length = 4, got %v", resultMap["name_length"])
	}

	if v, ok := resultMap["value_doubled"].(float64); !ok || v != 85 {
		t.Errorf("Expected value_doubled = 85, got %v", resultMap["value_doubled"])
	}

	if v, ok := resultMap["items_sum"].(float64); !ok || v != 6 {
		t.Errorf("Expected items_sum = 6, got %v", resultMap["items_sum"])
	}
}

// TestStandardLibsInSandbox checks that the string, math, and table libraries
// and os.time are reachable from sandboxed scripts.
func TestStandardLibsInSandbox(t *testing.T) {
	// Create a pool.
	pool, err := NewPool(2)
	if err != nil {
		t.Fatalf("Failed to create pool: %v", err)
	}
	defer pool.Shutdown()

	// Test access to standard libraries.
	bytecode := createTestBytecode(t, `
		local results = {}

		-- Test string library
		results.string_upper = string.upper("test")

		-- Test math library
		results.math_sqrt = math.sqrt(16)

		-- Test table library
		local tbl = {10, 20, 30}
		table.insert(tbl, 40)
		results.table_length = #tbl

		-- Test os library (limited functions)
		results.has_os_time = type(os.time) == "function"

		return results
	`)

	result, err := pool.Submit(bytecode, nil)
	if err != nil {
		t.Fatalf("Failed to execute script: %v", err)
	}

	resultMap, ok := result.(map[string]any)
	if !ok {
		t.Fatalf("Expected map result, got %T", result)
	}

	// Check standard library functions worked. Numeric checks use checked
	// assertions so a missing key fails the test instead of panicking.
	if resultMap["string_upper"] != "TEST" {
		t.Errorf("Expected string_upper = 'TEST', got %v", resultMap["string_upper"])
	}

	if v, ok := resultMap["math_sqrt"].(float64); !ok || v != 4 {
		t.Errorf("Expected math_sqrt = 4, got %v", resultMap["math_sqrt"])
	}

	if v, ok := resultMap["table_length"].(float64); !ok || v != 4 {
		t.Errorf("Expected table_length = 4, got %v", resultMap["table_length"])
	}

	if resultMap["has_os_time"] != true {
		t.Errorf("Expected has_os_time = true, got %v", resultMap["has_os_time"])
	}
}