LuaJIT-to-Go/bytecode_test.go

179 lines
3.7 KiB
Go

package luajit
import (
"fmt"
"testing"
)
// TestBytecodeCompilation feeds a table of Lua snippets to CompileBytecode on
// a fresh state from every entry in factories. For each snippet it checks
// that an error is reported exactly when one is expected, and that a
// successful compilation yields a non-empty result.
func TestBytecodeCompilation(t *testing.T) {
	type compileCase struct {
		name    string
		code    string
		wantErr bool
	}

	cases := []compileCase{
		{name: "simple assignment", code: "x = 42"},
		{name: "function definition", code: "function add(a,b) return a+b end"},
		{name: "syntax error", code: "function bad syntax", wantErr: true},
	}

	for _, factory := range factories {
		for _, tc := range cases {
			t.Run(factory.name+"/"+tc.name, func(t *testing.T) {
				state := factory.new()
				if state == nil {
					t.Fatal("Failed to create Lua state")
				}
				defer state.Close()

				compiled, err := state.CompileBytecode(tc.code, "test")
				gotErr := err != nil
				if gotErr != tc.wantErr {
					t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tc.wantErr)
					return
				}

				// An expected failure has nothing further to inspect.
				if tc.wantErr {
					return
				}
				if len(compiled) == 0 {
					t.Error("CompileBytecode() returned empty bytecode")
				}
			})
		}
	}
}
// TestBytecodeExecution compiles a short Lua chunk with CompileBytecode, hands
// the result to LoadBytecode on the same state, and then reads the global
// "result" to confirm it holds 42. The scenario is repeated for every entry in
// factories.
func TestBytecodeExecution(t *testing.T) {
	// Lua chunk under test: defines add and assigns add(40, 2) to the global
	// "result".
	const source = `
function add(a, b)
return a + b
end
result = add(40, 2)
`

	for _, factory := range factories {
		t.Run(factory.name, func(t *testing.T) {
			state := factory.new()
			if state == nil {
				t.Fatal("Failed to create Lua state")
			}
			defer state.Close()

			compiled, compileErr := state.CompileBytecode(source, "test")
			if compileErr != nil {
				t.Fatalf("CompileBytecode() error = %v", compileErr)
			}

			// The global is read right after this call, so LoadBytecode is
			// relied upon to run the chunk, not merely load it.
			loadErr := state.LoadBytecode(compiled, "test")
			if loadErr != nil {
				t.Fatalf("LoadBytecode() error = %v", loadErr)
			}

			state.GetGlobal("result")
			got := state.ToNumber(-1)
			if got != 42 {
				t.Errorf("got result = %v, want 42", got)
			}
		})
	}
}
// TestInvalidBytecode passes a byte slice that is plain text rather than
// compiled bytecode to LoadBytecode and requires the call to return an error.
// It runs once for every entry in factories.
func TestInvalidBytecode(t *testing.T) {
	for _, factory := range factories {
		t.Run(factory.name, func(t *testing.T) {
			state := factory.new()
			if state == nil {
				t.Fatal("Failed to create Lua state")
			}
			defer state.Close()

			garbage := []byte("this is not valid bytecode")
			err := state.LoadBytecode(garbage, "test")
			if err != nil {
				// Rejection is the expected outcome.
				return
			}
			t.Error("LoadBytecode() expected error with invalid bytecode")
		})
	}
}
// TestBytecodeRoundTrip compiles Lua source in one state, loads the resulting
// bytecode into a second, separately created state, and then runs a per-case
// check against that second state. Every case is exercised for every entry in
// factories.
func TestBytecodeRoundTrip(t *testing.T) {
	// expectGlobalX verifies that the global "x" reads back as 42.
	expectGlobalX := func(st *State) error {
		st.GetGlobal("x")
		got := st.ToNumber(-1)
		if got == 42 {
			return nil
		}
		return fmt.Errorf("got x = %v, want 42", got)
	}

	// expectTestReturnsHello calls the Lua function test() through DoString
	// and verifies that the value it stores in "result" is "hello".
	expectTestReturnsHello := func(st *State) error {
		if err := st.DoString("result = test()"); err != nil {
			return err
		}
		st.GetGlobal("result")
		got := st.ToString(-1)
		if got == "hello" {
			return nil
		}
		return fmt.Errorf("got result = %q, want 'hello'", got)
	}

	cases := []struct {
		name  string
		code  string
		check func(*State) error
	}{
		{name: "global variable", code: "x = 42", check: expectGlobalX},
		{name: "function definition", code: "function test() return 'hello' end", check: expectTestReturnsHello},
	}

	for _, factory := range factories {
		for _, tc := range cases {
			t.Run(factory.name+"/"+tc.name, func(t *testing.T) {
				// Producer state: only used to compile the source.
				compiler := factory.new()
				if compiler == nil {
					t.Fatal("Failed to create first Lua state")
				}
				defer compiler.Close()

				compiled, compileErr := compiler.CompileBytecode(tc.code, "test")
				if compileErr != nil {
					t.Fatalf("CompileBytecode() error = %v", compileErr)
				}

				// Consumer state: receives the bytecode and is inspected.
				runner := factory.new()
				if runner == nil {
					t.Fatal("Failed to create second Lua state")
				}
				defer runner.Close()

				loadErr := runner.LoadBytecode(compiled, "test")
				if loadErr != nil {
					t.Fatalf("LoadBytecode() error = %v", loadErr)
				}

				checkErr := tc.check(runner)
				if checkErr != nil {
					t.Errorf("check failed: %v", checkErr)
				}
			})
		}
	}
}