179 lines
3.7 KiB
Go
179 lines
3.7 KiB
Go
package luajit
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
)
|
|
|
|
func TestBytecodeCompilation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
code string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "simple assignment",
|
|
code: "x = 42",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "function definition",
|
|
code: "function add(a,b) return a+b end",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "syntax error",
|
|
code: "function bad syntax",
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, f := range factories {
|
|
for _, tt := range tests {
|
|
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
|
L := f.new()
|
|
if L == nil {
|
|
t.Fatal("Failed to create Lua state")
|
|
}
|
|
defer L.Close()
|
|
|
|
bytecode, err := L.CompileBytecode(tt.code, "test")
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("CompileBytecode() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
|
|
if !tt.wantErr {
|
|
if len(bytecode) == 0 {
|
|
t.Error("CompileBytecode() returned empty bytecode")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestBytecodeExecution(t *testing.T) {
|
|
for _, f := range factories {
|
|
t.Run(f.name, func(t *testing.T) {
|
|
L := f.new()
|
|
if L == nil {
|
|
t.Fatal("Failed to create Lua state")
|
|
}
|
|
defer L.Close()
|
|
|
|
// Compile some test code
|
|
code := `
|
|
function add(a, b)
|
|
return a + b
|
|
end
|
|
result = add(40, 2)
|
|
`
|
|
|
|
bytecode, err := L.CompileBytecode(code, "test")
|
|
if err != nil {
|
|
t.Fatalf("CompileBytecode() error = %v", err)
|
|
}
|
|
|
|
// Load and execute the bytecode
|
|
if err := L.LoadBytecode(bytecode, "test"); err != nil {
|
|
t.Fatalf("LoadBytecode() error = %v", err)
|
|
}
|
|
|
|
// Verify the result
|
|
L.GetGlobal("result")
|
|
if result := L.ToNumber(-1); result != 42 {
|
|
t.Errorf("got result = %v, want 42", result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestInvalidBytecode(t *testing.T) {
|
|
for _, f := range factories {
|
|
t.Run(f.name, func(t *testing.T) {
|
|
L := f.new()
|
|
if L == nil {
|
|
t.Fatal("Failed to create Lua state")
|
|
}
|
|
defer L.Close()
|
|
|
|
// Test with invalid bytecode
|
|
invalidBytecode := []byte("this is not valid bytecode")
|
|
if err := L.LoadBytecode(invalidBytecode, "test"); err == nil {
|
|
t.Error("LoadBytecode() expected error with invalid bytecode")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBytecodeRoundTrip(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
code string
|
|
check func(*State) error
|
|
}{
|
|
{
|
|
name: "global variable",
|
|
code: "x = 42",
|
|
check: func(L *State) error {
|
|
L.GetGlobal("x")
|
|
if x := L.ToNumber(-1); x != 42 {
|
|
return fmt.Errorf("got x = %v, want 42", x)
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
name: "function definition",
|
|
code: "function test() return 'hello' end",
|
|
check: func(L *State) error {
|
|
if err := L.DoString("result = test()"); err != nil {
|
|
return err
|
|
}
|
|
L.GetGlobal("result")
|
|
if s := L.ToString(-1); s != "hello" {
|
|
return fmt.Errorf("got result = %q, want 'hello'", s)
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, f := range factories {
|
|
for _, tt := range tests {
|
|
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
|
// First state for compilation
|
|
L1 := f.new()
|
|
if L1 == nil {
|
|
t.Fatal("Failed to create first Lua state")
|
|
}
|
|
defer L1.Close()
|
|
|
|
// Compile the code
|
|
bytecode, err := L1.CompileBytecode(tt.code, "test")
|
|
if err != nil {
|
|
t.Fatalf("CompileBytecode() error = %v", err)
|
|
}
|
|
|
|
// Second state for execution
|
|
L2 := f.new()
|
|
if L2 == nil {
|
|
t.Fatal("Failed to create second Lua state")
|
|
}
|
|
defer L2.Close()
|
|
|
|
// Load and execute the bytecode
|
|
if err := L2.LoadBytecode(bytecode, "test"); err != nil {
|
|
t.Fatalf("LoadBytecode() error = %v", err)
|
|
}
|
|
|
|
// Run the check function
|
|
if err := tt.check(L2); err != nil {
|
|
t.Errorf("check failed: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|