package luajit

import "testing"

// TestGoFunctions exercises Go callbacks registered on a Lua state, once per
// entry in factories. It covers:
//   - a single-result call;
//   - a multi-result call;
//   - a callback that signals an error by returning -1;
//   - calling a function after it has been unregistered.
func TestGoFunctions(t *testing.T) {
	for _, f := range factories {
		t.Run(f.name, func(t *testing.T) {
			L := f.new()
			if L == nil {
				t.Fatal("Failed to create Lua state")
			}
			// Deferred calls run LIFO, so Cleanup runs before Close.
			defer L.Close()
			defer L.Cleanup()

			addFunc := func(s *State) int {
				s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
				return 1
			}
			if err := L.RegisterGoFunction("add", addFunc); err != nil {
				t.Fatalf("Failed to register function: %v", err)
			}

			// Test basic function call.
			if err := L.DoString("result = add(40, 2)"); err != nil {
				t.Fatalf("Failed to call function: %v", err)
			}
			L.GetGlobal("result")
			if result := L.ToNumber(-1); result != 42 {
				t.Errorf("got %v, want 42", result)
			}
			L.Pop(1)

			// Test multiple return values: the Go function pushes three values
			// and returns 3, so Lua should receive all of them in order.
			multiFunc := func(s *State) int {
				s.PushString("hello")
				s.PushNumber(42)
				s.PushBoolean(true)
				return 3
			}
			if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
				t.Fatalf("Failed to register multi function: %v", err)
			}
			code := `
				a, b, c = multi()
				result = (a == "hello" and b == 42 and c == true)
			`
			if err := L.DoString(code); err != nil {
				t.Fatalf("Failed to call multi function: %v", err)
			}
			L.GetGlobal("result")
			if !L.ToBoolean(-1) {
				t.Error("Multiple return values test failed")
			}
			L.Pop(1)

			// Test error handling. A Go function returning -1 is expected to
			// make the Lua call fail. The pushed string is presumably used as
			// the error message (not verified here).
			errFunc := func(s *State) int {
				s.PushString("test error")
				return -1
			}
			if err := L.RegisterGoFunction("err", errFunc); err != nil {
				t.Fatalf("Failed to register error function: %v", err)
			}
			if err := L.DoString("err()"); err == nil {
				t.Error("Expected error from error function")
			}

			// Test unregistering: calling a removed function must fail.
			L.UnregisterGoFunction("add")
			if err := L.DoString("add(1, 2)"); err == nil {
				t.Error("Expected error calling unregistered function")
			}
		})
	}
}

// TestStackSafety checks that a Go function pushing far more values than the
// Lua C API guarantees without an explicit stack check does not corrupt a
// NewSafe state.
//
// Two outcomes are accepted:
//   - the call fails with an error; or
//   - the call succeeds and delivers every result unchanged.
//
// In both cases the state must remain usable afterwards. The original version
// only logged the error, so it could never fail.
func TestStackSafety(t *testing.T) {
	L := NewSafe()
	if L == nil {
		t.Fatal("Failed to create Lua state")
	}
	// Deferred calls run LIFO, so Cleanup runs before Close.
	defer L.Close()
	defer L.Cleanup()

	// Test stack overflow protection: push 100 numbers plus a sentinel
	// string and return all 101 of them.
	overflowFunc := func(s *State) int {
		for i := 0; i < 100; i++ {
			s.PushNumber(float64(i))
		}
		s.PushString("done")
		return 101
	}
	if err := L.RegisterGoFunction("overflow", overflowFunc); err != nil {
		t.Fatal(err)
	}

	// Collect every result into a table. All of them are non-nil, so #r is
	// well defined and must equal the number of values returned.
	code := `
		local r = {overflow()}
		intact = (#r == 101 and r[1] == 0 and r[100] == 99 and r[101] == "done")
	`
	if err := L.DoString(code); err != nil {
		t.Logf("Got expected error: %v", err)
	} else {
		// The call claimed success, so no result may be lost or corrupted.
		L.GetGlobal("intact")
		if !L.ToBoolean(-1) {
			t.Error("overflow() succeeded but did not return all 101 values intact")
		}
		L.Pop(1)
	}

	// Whichever path was taken, the state must still execute code correctly.
	if err := L.DoString("after = 1 + 1"); err != nil {
		t.Fatalf("Lua state unusable after overflow(): %v", err)
	}
	L.GetGlobal("after")
	if got := L.ToNumber(-1); got != 2 {
		t.Errorf("after overflow(): got %v, want 2", got)
	}
	L.Pop(1)
}