diff --git a/bench/bench_test.go b/bench/bench_test.go index b3f7483..356de95 100644 --- a/bench/bench_test.go +++ b/bench/bench_test.go @@ -321,37 +321,37 @@ func BenchmarkComplexScript(b *testing.B) { -- Define a simple class local Class = {} Class.__index = Class - + function Class.new(x, y) local self = setmetatable({}, Class) self.x = x or 0 self.y = y or 0 return self end - + function Class:move(dx, dy) self.x = self.x + dx self.y = self.y + dy return self end - + function Class:getPosition() return self.x, self.y end - + -- Create instances and operate on them local instances = {} for i = 1, 50 do instances[i] = Class.new(i, i*2) end - + local result = 0 for i, obj in ipairs(instances) do obj:move(i, -i) local x, y = obj:getPosition() result = result + x + y end - + return result ` b.ResetTimer() @@ -377,37 +377,37 @@ func BenchmarkComplexScriptPrecompiled(b *testing.B) { -- Define a simple class local Class = {} Class.__index = Class - + function Class.new(x, y) local self = setmetatable({}, Class) self.x = x or 0 self.y = y or 0 return self end - + function Class:move(dx, dy) self.x = self.x + dx self.y = self.y + dy return self end - + function Class:getPosition() return self.x, self.y end - + -- Create instances and operate on them local instances = {} for i = 1, 50 do instances[i] = Class.new(i, i*2) end - + local result = 0 for i, obj in ipairs(instances) do obj:move(i, -i) local x, y = obj:getPosition() result = result + x + y end - + return result ` bytecode, err := state.CompileBytecode(code, "complex") diff --git a/builder.go b/builder.go new file mode 100644 index 0000000..ed8fa01 --- /dev/null +++ b/builder.go @@ -0,0 +1,72 @@ +package luajit + +// TableBuilder provides a fluent interface for building Lua tables +type TableBuilder struct { + state *State + index int +} + +// NewTableBuilder creates a new table and returns a builder +func (s *State) NewTableBuilder() *TableBuilder { + s.NewTable() + return &TableBuilder{ + state: s, + index: s.GetTop(), + } +} + +// SetString sets a string field +func (tb *TableBuilder) SetString(key, value string) *TableBuilder { + tb.state.PushString(value) + tb.state.SetField(tb.index, key) + return tb +} + +// SetNumber sets a number field +func (tb *TableBuilder) SetNumber(key string, value float64) *TableBuilder { + tb.state.PushNumber(value) + tb.state.SetField(tb.index, key) + return tb +} + +// SetBool sets a boolean field +func (tb *TableBuilder) SetBool(key string, value bool) *TableBuilder { + tb.state.PushBoolean(value) + tb.state.SetField(tb.index, key) + return tb +} + +// SetNil sets a nil field +func (tb *TableBuilder) SetNil(key string) *TableBuilder { + tb.state.PushNil() + tb.state.SetField(tb.index, key) + return tb +} + +// SetTable sets a table field +func (tb *TableBuilder) SetTable(key string, value any) *TableBuilder { + if err := tb.state.PushValue(value); err == nil { + tb.state.SetField(tb.index, key) + } + return tb +} + +// SetArray sets an array field +func (tb *TableBuilder) SetArray(key string, values []any) *TableBuilder { + tb.state.CreateTable(len(values), 0) + for i, v := range values { + tb.state.PushNumber(float64(i + 1)) + if err := tb.state.PushValue(v); err == nil { + tb.state.SetTable(-3) + } else { + tb.state.Pop(1) + } + } + tb.state.SetField(tb.index, key) + return tb +} + +// Build finalizes the table (no-op, table is already on stack) +func (tb *TableBuilder) Build() { + // Table is already on the stack at tb.index +} diff --git a/bytecode.go b/bytecode.go index b28a30d..bd3e6c7 100644 --- a/bytecode.go +++ b/bytecode.go @@ -12,6 +12,12 @@ typedef struct { const char *name; } BytecodeReader; +typedef struct { + unsigned char *buf; + size_t size; + size_t capacity; +} BytecodeBuffer; + const char *bytecode_reader(lua_State *L, void *ud, size_t *size) { BytecodeReader *r = (BytecodeReader *)ud; (void)L; // unused @@ -26,16 +32,24 @@ int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char return lua_load(L, bytecode_reader, &reader, name); } -// Direct bytecode dumping without intermediate buffer - more efficient -int direct_bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) { - void **data = (void **)ud; - size_t current_size = (size_t)data[1]; - void *newbuf = realloc(data[0], current_size + sz); - if (newbuf == NULL) return 1; +// Optimized bytecode writer with pre-allocated buffer +int buffered_bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) { + BytecodeBuffer *buf = (BytecodeBuffer *)ud; - memcpy((unsigned char*)newbuf + current_size, p, sz); - data[0] = newbuf; - data[1] = (void*)(current_size + sz); + // Grow buffer if needed (double size to avoid frequent reallocs) + if (buf->size + sz > buf->capacity) { + size_t new_capacity = buf->capacity; + while (new_capacity < buf->size + sz) { + new_capacity *= 2; + } + unsigned char *newbuf = realloc(buf->buf, new_capacity); + if (newbuf == NULL) return 1; + buf->buf = newbuf; + buf->capacity = new_capacity; + } + + memcpy(buf->buf + buf->size, p, sz); + buf->size += sz; return 0; } @@ -52,36 +66,56 @@ int load_and_run_bytecode(lua_State *L, const unsigned char *buf, size_t len, import "C" import ( "fmt" + "sync" "unsafe" ) +// bytecodeBuffer wraps []byte to avoid boxing allocations in sync.Pool +type bytecodeBuffer struct { + data []byte +} + +// Buffer pool for bytecode generation +var bytecodeBufferPool = sync.Pool{ + New: func() any { + return &bytecodeBuffer{data: make([]byte, 0, 1024)} + }, +} + // CompileBytecode compiles a Lua chunk to bytecode without executing it func (s *State) CompileBytecode(code string, name string) ([]byte, error) { if err := s.LoadString(code); err != nil { return nil, fmt.Errorf("failed to load string: %w", err) } - // Use a simpler direct writer with just two pointers - data := [2]unsafe.Pointer{nil, nil} + // Always use C memory for dump operation to avoid cgo pointer issues + cbuf := C.BytecodeBuffer{ + buf: (*C.uchar)(C.malloc(1024)), + size: 0, + capacity: 1024, + } + if cbuf.buf == nil { + return nil, fmt.Errorf("failed to allocate initial buffer") + } // Dump the function to bytecode - status := C.lua_dump(s.L, (*[0]byte)(unsafe.Pointer(C.direct_bytecode_writer)), unsafe.Pointer(&data)) - if status != 0 { - return nil, fmt.Errorf("failed to dump bytecode: status %d", status) - } - - // Get result - var bytecode []byte - if data[0] != nil { - // Create Go slice that references the C memory - length := uintptr(data[1]) - bytecode = C.GoBytes(data[0], C.int(length)) - C.free(data[0]) - } + status := C.lua_dump(s.L, (*[0]byte)(unsafe.Pointer(C.buffered_bytecode_writer)), unsafe.Pointer(&cbuf)) s.Pop(1) // Remove the function from stack - return bytecode, nil + if status != 0 { + C.free(unsafe.Pointer(cbuf.buf)) + return nil, fmt.Errorf("failed to dump bytecode: status %d", status) + } + + // Copy to Go memory and free C buffer + var result []byte + if cbuf.size > 0 { + result = C.GoBytes(unsafe.Pointer(cbuf.buf), C.int(cbuf.size)) + } + C.free(unsafe.Pointer(cbuf.buf)) + + return result, nil } // LoadBytecode loads precompiled bytecode without executing it @@ -116,7 +150,6 @@ func (s *State) RunBytecode() error { } // RunBytecodeWithResults executes bytecode and keeps nresults on the stack -// Use LUA_MULTRET (-1) to keep all results func (s *State) RunBytecodeWithResults(nresults int) error { status := C.lua_pcall(s.L, 0, C.int(nresults), 0) if status != 0 { @@ -136,13 +169,12 @@ func (s *State) LoadAndRunBytecode(bytecode []byte, name string) error { cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) - // Use combined load and run function status := C.load_and_run_bytecode( s.L, (*C.uchar)(unsafe.Pointer(&bytecode[0])), C.size_t(len(bytecode)), cname, - 0, // No results + 0, ) if status != 0 { @@ -163,7 +195,6 @@ func (s *State) LoadAndRunBytecodeWithResults(bytecode []byte, name string, nres cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) - // Use combined load and run function status := C.load_and_run_bytecode( s.L, (*C.uchar)(unsafe.Pointer(&bytecode[0])), diff --git a/functions.go b/functions.go index 0d396a8..bf550f3 100644 --- a/functions.go +++ b/functions.go @@ -9,7 +9,7 @@ extern int goFunctionWrapper(lua_State* L); // Helper function to access upvalues static int get_upvalue_index(int i) { - return lua_upvalueindex(i); + return lua_upvalueindex(i); } */ import "C" @@ -34,11 +34,20 @@ var ( }{ funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize), } + + // statePool reuses State structs to avoid allocations + statePool = sync.Pool{ + New: func() any { + return &State{} + }, + } ) //export goFunctionWrapper func goFunctionWrapper(L *C.lua_State) C.int { - state := &State{L: L} + state := statePool.Get().(*State) + state.L = L + defer statePool.Put(state) ptr := C.lua_touserdata(L, C.get_upvalue_index(1)) if ptr == nil { @@ -51,8 +60,6 @@ func goFunctionWrapper(L *C.lua_State) C.int { functionRegistry.RUnlock() if !ok { - // Debug logging - fmt.Printf("Function not found for pointer %p\n", ptr) state.PushString("error: function not found in registry") return -1 } diff --git a/stack.go b/stack.go index eba38e1..b5625ed 100644 --- a/stack.go +++ b/stack.go @@ -46,14 +46,6 @@ func (e *LuaError) Error() string { return result } -// Stack management constants from lua.h -const ( - LUA_MINSTACK = 20 // Minimum Lua stack size - LUA_MAXSTACK = 1000000 // Maximum Lua stack size - LUA_REGISTRYINDEX = -10000 // Pseudo-index for the Lua registry - LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table -) - // GetStackTrace returns the current Lua stack trace func (s *State) GetStackTrace() string { s.GetGlobal("debug") @@ -64,13 +56,13 @@ func (s *State) GetStackTrace() string { s.GetField(-1, "traceback") if !s.IsFunction(-1) { - s.Pop(2) // Remove debug table and non-function + s.Pop(2) return "debug.traceback not available" } s.Call(0, 1) trace := s.ToString(-1) - s.Pop(1) // Remove the trace + s.Pop(1) return trace } @@ -97,13 +89,11 @@ func (s *State) GetErrorInfo(context string) *LuaError { if secondColonPos := strings.Index(afterColon, ":"); secondColonPos > 0 { file = beforeColon if n, err := fmt.Sscanf(afterColon[:secondColonPos], "%d", &line); n == 1 && err == nil { - // Strip the file:line part from message for cleaner display message = strings.TrimSpace(afterColon[secondColonPos+1:]) } } } - // Get stack trace stackTrace := s.GetStackTrace() return &LuaError{ @@ -121,3 +111,9 @@ func (s *State) CreateLuaError(code int, context string) *LuaError { err.Code = code return err } + +// PushError pushes an error string and returns -1 +func (s *State) PushError(format string, args ...any) int { + s.PushString(fmt.Sprintf(format, args...)) + return -1 +} diff --git a/table.go b/table.go deleted file mode 100644 index 9486595..0000000 --- a/table.go +++ /dev/null @@ -1,164 +0,0 @@ -package luajit - -/* -#include -#include -#include -#include - -// Simple direct length check -size_t get_table_length(lua_State *L, int index) { - return lua_objlen(L, index); -} -*/ -import "C" -import ( - "fmt" - "strconv" -) - -// GetTableLength returns the length of a table at the given index -func (s *State) GetTableLength(index int) int { - return int(C.get_table_length(s.L, C.int(index))) -} - -// PushTable pushes a Go map onto the Lua stack as a table -func (s *State) PushTable(table map[string]any) error { - // Fast path for array tables - if arr, ok := table[""]; ok { - if floatArr, ok := arr.([]float64); ok { - s.CreateTable(len(floatArr), 0) - for i, v := range floatArr { - s.PushNumber(float64(i + 1)) - s.PushNumber(v) - s.SetTable(-3) - } - return nil - } else if anyArr, ok := arr.([]any); ok { - s.CreateTable(len(anyArr), 0) - for i, v := range anyArr { - s.PushNumber(float64(i + 1)) - if err := s.PushValue(v); err != nil { - return err - } - s.SetTable(-3) - } - return nil - } - } - - // Regular table case - optimize capacity hint - s.CreateTable(0, len(table)) - - // Add each key-value pair directly - for k, v := range table { - s.PushString(k) - if err := s.PushValue(v); err != nil { - return err - } - s.SetTable(-3) - } - - return nil -} - -// ToTable converts a Lua table at the given index to a Go map -func (s *State) ToTable(index int) (map[string]any, error) { - absIdx := s.absIndex(index) - if !s.IsTable(absIdx) { - return nil, fmt.Errorf("value at index %d is not a table", index) - } - - // Try to detect array-like tables first - length := s.GetTableLength(absIdx) - if length > 0 { - // Fast path for common array case - allNumbers := true - - // Sample first few values to check if it's likely an array of numbers - for i := 1; i <= min(length, 5); i++ { - s.PushNumber(float64(i)) - s.GetTable(absIdx) - - if !s.IsNumber(-1) { - allNumbers = false - s.Pop(1) - break - } - s.Pop(1) - } - - if allNumbers { - // Efficiently extract array values - array := make([]float64, length) - for i := 1; i <= length; i++ { - s.PushNumber(float64(i)) - s.GetTable(absIdx) - array[i-1] = s.ToNumber(-1) - s.Pop(1) - } - - // Return array as a special table with empty key - result := make(map[string]any, 1) - result[""] = array - return result, nil - } - } - - // Handle regular table with pre-allocated capacity - table := make(map[string]any, max(length, 8)) - - // Iterate through all key-value pairs - s.PushNil() // Start iteration with nil key - for s.Next(absIdx) { - // Stack now has key at -2 and value at -1 - - // Convert key to string - var key string - keyType := s.GetType(-2) - switch keyType { - case TypeString: - key = s.ToString(-2) - case TypeNumber: - key = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64) - default: - // Skip non-string/non-number keys - s.Pop(1) // Pop value, leave key for next iteration - continue - } - - // Convert and store the value - value, err := s.ToValue(-1) - if err != nil { - s.Pop(2) // Pop both key and value - return nil, err - } - - // Unwrap nested array tables - if m, ok := value.(map[string]any); ok { - if arr, ok := m[""]; ok { - value = arr - } - } - - table[key] = value - s.Pop(1) // Pop value, leave key for next iteration - } - - return table, nil -} - -// Helper functions for min/max operations -func min(a, b int) int { - if a < b { - return a - } - return b -} - -func max(a, b int) int { - if a > b { - return a - } - return b -} diff --git a/tests/table_test.go b/tests/table_test.go index ca0f9d7..c3703a1 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -19,7 +19,6 @@ func TestGetTableLength(t *testing.T) { t.Fatalf("Failed to create test table: %v", err) } - // Get the table state.GetGlobal("t") length := state.GetTableLength(-1) if length != 5 { @@ -32,7 +31,6 @@ func TestGetTableLength(t *testing.T) { t.Fatalf("Failed to create test table: %v", err) } - // Get the table state.GetGlobal("t2") length = state.GetTableLength(-1) if length != 0 { @@ -41,206 +39,234 @@ func TestGetTableLength(t *testing.T) { state.Pop(1) } -func TestPushTable(t *testing.T) { +func TestPushTypedArrays(t *testing.T) { state := luajit.New() if state == nil { t.Fatal("Failed to create Lua state") } defer state.Close() - // Create a test table - testTable := map[string]any{ - "int": 42, - "float": 3.14, - "string": "hello", - "boolean": true, - "nil": nil, + // Test []int + intArr := []int{1, 2, 3, 4, 5} + if err := state.PushValue(intArr); err != nil { + t.Fatalf("Failed to push int array: %v", err) } + state.SetGlobal("int_arr") - // Push the table onto the stack - if err := state.PushTable(testTable); err != nil { - t.Fatalf("Failed to push table: %v", err) + // Test []string + stringArr := []string{"hello", "world", "test"} + if err := state.PushValue(stringArr); err != nil { + t.Fatalf("Failed to push string array: %v", err) } + state.SetGlobal("string_arr") - // Execute Lua code to test the table contents + // Test []bool + boolArr := []bool{true, false, true} + if err := state.PushValue(boolArr); err != nil { + t.Fatalf("Failed to push bool array: %v", err) + } + state.SetGlobal("bool_arr") + + // Test []float64 + floatArr := []float64{1.1, 2.2, 3.3} + if err := state.PushValue(floatArr); err != nil { + t.Fatalf("Failed to push float array: %v", err) + } + state.SetGlobal("float_arr") + + // Verify arrays in Lua if err := state.DoString(` - function validate_table(t) - return t.int == 42 and - math.abs(t.float - 3.14) < 0.0001 and - t.string == "hello" and - t.boolean == true and - t["nil"] == nil - end + assert(int_arr[1] == 1 and int_arr[5] == 5) + assert(string_arr[1] == "hello" and string_arr[3] == "test") + assert(bool_arr[1] == true and bool_arr[2] == false) + assert(math.abs(float_arr[1] - 1.1) < 0.0001) `); err != nil { - t.Fatalf("Failed to create validation function: %v", err) + t.Fatalf("Array verification failed: %v", err) } - - // Call the validation function - state.GetGlobal("validate_table") - state.PushCopy(-2) // Copy the table to the top - if err := state.Call(1, 1); err != nil { - t.Fatalf("Failed to call validation function: %v", err) - } - - if !state.ToBoolean(-1) { - t.Fatalf("Table validation failed") - } - state.Pop(2) // Pop the result and the table } -func TestToTable(t *testing.T) { +func TestPushTypedMaps(t *testing.T) { state := luajit.New() if state == nil { t.Fatal("Failed to create Lua state") } defer state.Close() - // Test regular table conversion - if err := state.DoString(`t = {a=1, b=2.5, c="test", d=true, e=nil}`); err != nil { - t.Fatalf("Failed to create test table: %v", err) + // Test map[string]string + stringMap := map[string]string{"name": "John", "city": "NYC"} + if err := state.PushValue(stringMap); err != nil { + t.Fatalf("Failed to push string map: %v", err) } + state.SetGlobal("string_map") - state.GetGlobal("t") - table, err := state.ToTable(-1) - if err != nil { - t.Fatalf("Failed to convert table: %v", err) + // Test map[string]int + intMap := map[string]int{"age": 25, "score": 100} + if err := state.PushValue(intMap); err != nil { + t.Fatalf("Failed to push int map: %v", err) } - state.Pop(1) + state.SetGlobal("int_map") - expected := map[string]any{ - "a": float64(1), - "b": 2.5, - "c": "test", - "d": true, + // Test map[int]any + intKeyMap := map[int]any{1: "first", 2: 42, 3: true} + if err := state.PushValue(intKeyMap); err != nil { + t.Fatalf("Failed to push int key map: %v", err) } + state.SetGlobal("int_key_map") - for k, v := range expected { - if table[k] != v { - t.Fatalf("Expected table[%s] = %v, got %v", k, v, table[k]) - } - } - - // Test array-like table conversion - if err := state.DoString(`arr = {10, 20, 30, 40, 50}`); err != nil { - t.Fatalf("Failed to create test array: %v", err) - } - - state.GetGlobal("arr") - table, err = state.ToTable(-1) - if err != nil { - t.Fatalf("Failed to convert array table: %v", err) - } - state.Pop(1) - - // For array tables, we should get a special format with an empty key - // and the array as the value - expectedArray := []float64{10, 20, 30, 40, 50} - if arr, ok := table[""].([]float64); !ok { - t.Fatalf("Expected array table to be converted with empty key, got: %v", table) - } else if !reflect.DeepEqual(arr, expectedArray) { - t.Fatalf("Expected %v, got %v", expectedArray, arr) - } - - // Test invalid table index - _, err = state.ToTable(100) - if err == nil { - t.Fatalf("Expected error for invalid table index, got nil") - } - - // Test non-table value - state.PushNumber(123) - _, err = state.ToTable(-1) - if err == nil { - t.Fatalf("Expected error for non-table value, got nil") - } - state.Pop(1) - - // Test mixed array with non-numeric values - if err := state.DoString(`mixed = {10, 20, key="value", 30}`); err != nil { - t.Fatalf("Failed to create mixed table: %v", err) - } - - state.GetGlobal("mixed") - table, err = state.ToTable(-1) - if err != nil { - t.Fatalf("Failed to convert mixed table: %v", err) - } - - // Let's print the table for debugging - t.Logf("Table contents: %v", table) - - state.Pop(1) - - // Check if the array part is detected and stored with empty key - if arr, ok := table[""]; !ok { - t.Fatalf("Expected array-like part to be detected, got: %v", table) - } else { - // Verify the array contains the expected values - expectedArr := []float64{10, 20, 30} - actualArr := arr.([]float64) - if len(actualArr) != len(expectedArr) { - t.Fatalf("Expected array length %d, got %d", len(expectedArr), len(actualArr)) - } - - for i, v := range expectedArr { - if actualArr[i] != v { - t.Fatalf("Expected array[%d] = %v, got %v", i, v, actualArr[i]) - } - } - } - - // Based on the implementation, we need to create a separate test for string keys - if err := state.DoString(`dict = {foo="bar", baz="qux"}`); err != nil { - t.Fatalf("Failed to create dict table: %v", err) - } - - state.GetGlobal("dict") - dictTable, err := state.ToTable(-1) - if err != nil { - t.Fatalf("Failed to convert dict table: %v", err) - } - state.Pop(1) - - // Check the string keys - if val, ok := dictTable["foo"]; !ok || val != "bar" { - t.Fatalf("Expected dictTable[\"foo\"] = \"bar\", got: %v", val) - } - if val, ok := dictTable["baz"]; !ok || val != "qux" { - t.Fatalf("Expected dictTable[\"baz\"] = \"qux\", got: %v", val) + // Verify maps in Lua + if err := state.DoString(` + assert(string_map.name == "John" and string_map.city == "NYC") + assert(int_map.age == 25 and int_map.score == 100) + assert(int_key_map[1] == "first" and int_key_map[2] == 42 and int_key_map[3] == true) + `); err != nil { + t.Fatalf("Map verification failed: %v", err) } } -func TestTablePooling(t *testing.T) { +func TestToTableTypedArrays(t *testing.T) { state := luajit.New() if state == nil { t.Fatal("Failed to create Lua state") } defer state.Close() - // Create a Lua table and push it onto the stack - if err := state.DoString(`t = {a=1, b=2}`); err != nil { - t.Fatalf("Failed to create test table: %v", err) + // Test integer array detection + if err := state.DoString("int_arr = {10, 20, 30}"); err != nil { + t.Fatalf("Failed to create int array: %v", err) } - - state.GetGlobal("t") - - // First conversion - should get a table from the pool - table1, err := state.ToTable(-1) + state.GetGlobal("int_arr") + result, err := state.ToValue(-1) if err != nil { - t.Fatalf("Failed to convert table (1): %v", err) + t.Fatalf("Failed to convert int array: %v", err) } + intArr, ok := result.([]int) + if !ok { + t.Fatalf("Expected []int, got %T", result) + } + expected := []int{10, 20, 30} + if !reflect.DeepEqual(intArr, expected) { + t.Fatalf("Expected %v, got %v", expected, intArr) + } + state.Pop(1) - // Second conversion - should get another table from the pool - table2, err := state.ToTable(-1) + // Test float array detection + if err := state.DoString("float_arr = {1.5, 2.7, 3.9}"); err != nil { + t.Fatalf("Failed to create float array: %v", err) + } + state.GetGlobal("float_arr") + result, err = state.ToValue(-1) if err != nil { - t.Fatalf("Failed to convert table (2): %v", err) + t.Fatalf("Failed to convert float array: %v", err) } - - // Both tables should have the same content - if !reflect.DeepEqual(table1, table2) { - t.Fatalf("Tables should have the same content: %v vs %v", table1, table2) + floatArr, ok := result.([]float64) + if !ok { + t.Fatalf("Expected []float64, got %T", result) } + expectedFloat := []float64{1.5, 2.7, 3.9} + if !reflect.DeepEqual(floatArr, expectedFloat) { + t.Fatalf("Expected %v, got %v", expectedFloat, floatArr) + } + state.Pop(1) - // Clean up + // Test string array detection + if err := state.DoString(`string_arr = {"hello", "world"}`); err != nil { + t.Fatalf("Failed to create string array: %v", err) + } + state.GetGlobal("string_arr") + result, err = state.ToValue(-1) + if err != nil { + t.Fatalf("Failed to convert string array: %v", err) + } + stringArr, ok := result.([]string) + if !ok { + t.Fatalf("Expected []string, got %T", result) + } + expectedString := []string{"hello", "world"} + if !reflect.DeepEqual(stringArr, expectedString) { + t.Fatalf("Expected %v, got %v", expectedString, stringArr) + } + state.Pop(1) + + // Test bool array detection + if err := state.DoString("bool_arr = {true, false, true}"); err != nil { + t.Fatalf("Failed to create bool array: %v", err) + } + state.GetGlobal("bool_arr") + result, err = state.ToValue(-1) + if err != nil { + t.Fatalf("Failed to convert bool array: %v", err) + } + boolArr, ok := result.([]bool) + if !ok { + t.Fatalf("Expected []bool, got %T", result) + } + expectedBool := []bool{true, false, true} + if !reflect.DeepEqual(boolArr, expectedBool) { + t.Fatalf("Expected %v, got %v", expectedBool, boolArr) + } + state.Pop(1) +} + +func TestToTableTypedMaps(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Test string map detection + if err := state.DoString(`string_map = {name="John", city="NYC"}`); err != nil { + t.Fatalf("Failed to create string map: %v", err) + } + state.GetGlobal("string_map") + result, err := state.ToValue(-1) + if err != nil { + t.Fatalf("Failed to convert string map: %v", err) + } + stringMap, ok := result.(map[string]string) + if !ok { + t.Fatalf("Expected map[string]string, got %T", result) + } + expectedStringMap := map[string]string{"name": "John", "city": "NYC"} + if !reflect.DeepEqual(stringMap, expectedStringMap) { + t.Fatalf("Expected %v, got %v", expectedStringMap, stringMap) + } + state.Pop(1) + + // Test int map detection + if err := state.DoString("int_map = {age=25, score=100}"); err != nil { + t.Fatalf("Failed to create int map: %v", err) + } + state.GetGlobal("int_map") + result, err = state.ToValue(-1) + if err != nil { + t.Fatalf("Failed to convert int map: %v", err) + } + intMap, ok := result.(map[string]int) + if !ok { + t.Fatalf("Expected map[string]int, got %T", result) + } + expectedIntMap := map[string]int{"age": 25, "score": 100} + if !reflect.DeepEqual(intMap, expectedIntMap) { + t.Fatalf("Expected %v, got %v", expectedIntMap, intMap) + } + state.Pop(1) + + // Test mixed map (should fallback to map[string]any) + if err := state.DoString(`mixed_map = {name="John", age=25, active=true}`); err != nil { + t.Fatalf("Failed to create mixed map: %v", err) + } + state.GetGlobal("mixed_map") + result, err = state.ToValue(-1) + if err != nil { + t.Fatalf("Failed to convert mixed map: %v", err) + } + mixedMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map[string]any, got %T", result) + } + if mixedMap["name"] != "John" || mixedMap["age"] != 25 || mixedMap["active"] != true { + t.Fatalf("Mixed map conversion failed: %v", mixedMap) + } state.Pop(1) } diff --git a/tests/wrapper_test.go b/tests/wrapper_test.go index 8f5fc6e..0430698 100644 --- a/tests/wrapper_test.go +++ b/tests/wrapper_test.go @@ -9,70 +9,50 @@ import ( ) func TestStateLifecycle(t *testing.T) { - // Test creation state := luajit.New() if state == nil { t.Fatal("Failed to create Lua state") } - - // Test close - state.Close() - - // Test close is idempotent (doesn't crash) state.Close() + state.Close() // Test idempotent close } -func TestStackManipulation(t *testing.T) { +func TestStackOperations(t *testing.T) { state := luajit.New() if state == nil { t.Fatal("Failed to create Lua state") } defer state.Close() - // Test initial stack size + // Test stack manipulation if state.GetTop() != 0 { - t.Fatalf("Expected empty stack, got %d elements", state.GetTop()) + t.Fatalf("Expected empty stack, got %d", state.GetTop()) } - // Push values state.PushNil() state.PushBoolean(true) state.PushNumber(42) state.PushString("hello") - // Check stack size if state.GetTop() != 4 { t.Fatalf("Expected 4 elements, got %d", state.GetTop()) } - // Test SetTop state.SetTop(2) if state.GetTop() != 2 { t.Fatalf("Expected 2 elements after SetTop, got %d", state.GetTop()) } - // Test PushCopy - state.PushCopy(2) // Copy the boolean + state.PushCopy(2) if !state.IsBoolean(-1) { - t.Fatalf("Expected boolean at top of stack") + t.Fatal("Expected boolean at top") } - // Test Pop state.Pop(1) - if state.GetTop() != 2 { - t.Fatalf("Expected 2 elements after Pop, got %d", state.GetTop()) - } - - // Test Remove state.PushNumber(99) - state.Remove(1) // Remove the first element (nil) - if state.GetTop() != 2 { - t.Fatalf("Expected 2 elements after Remove, got %d", state.GetTop()) - } - - // Verify first element is now boolean + state.Remove(1) if !state.IsBoolean(1) { - t.Fatalf("Expected boolean at index 1 after Remove") + t.Fatal("Expected boolean at index 1 after Remove") } } @@ -83,52 +63,33 @@ func TestTypeChecking(t *testing.T) { } defer state.Close() - // Push values of different types - state.PushNil() - state.PushBoolean(true) - state.PushNumber(42) - state.PushString("hello") - state.NewTable() - - // Check types with GetType - if state.GetType(1) != luajit.TypeNil { - t.Fatalf("Expected nil type at index 1, got %s", state.GetType(1)) - } - if state.GetType(2) != luajit.TypeBoolean { - t.Fatalf("Expected boolean type at index 2, got %s", state.GetType(2)) - } - if state.GetType(3) != luajit.TypeNumber { - t.Fatalf("Expected number type at index 3, got %s", state.GetType(3)) - } - if state.GetType(4) != luajit.TypeString { - t.Fatalf("Expected string type at index 4, got %s", state.GetType(4)) - } - if state.GetType(5) != luajit.TypeTable { - t.Fatalf("Expected table type at index 5, got %s", state.GetType(5)) + values := []struct { + push func() + luaType luajit.LuaType + checkFn func(int) bool + }{ + {state.PushNil, luajit.TypeNil, state.IsNil}, + {func() { state.PushBoolean(true) }, luajit.TypeBoolean, state.IsBoolean}, + {func() { state.PushNumber(42) }, luajit.TypeNumber, state.IsNumber}, + {func() { state.PushString("test") }, luajit.TypeString, state.IsString}, + {state.NewTable, luajit.TypeTable, state.IsTable}, } - // Test individual type checking functions - if !state.IsNil(1) { - t.Fatalf("IsNil failed for nil value") - } - if !state.IsBoolean(2) { - t.Fatalf("IsBoolean failed for boolean value") - } - if !state.IsNumber(3) { - t.Fatalf("IsNumber failed for number value") - } - if !state.IsString(4) { - t.Fatalf("IsString failed for string value") - } - if !state.IsTable(5) { - t.Fatalf("IsTable failed for table value") + for i, v := range values { + v.push() + idx := i + 1 + if state.GetType(idx) != v.luaType { + t.Fatalf("Type mismatch at %d: expected %s, got %s", idx, v.luaType, state.GetType(idx)) + } + if !v.checkFn(idx) { + t.Fatalf("Type check failed at %d", idx) + } } - // Function test state.DoString("function test() return true end") state.GetGlobal("test") if !state.IsFunction(-1) { - t.Fatalf("IsFunction failed for function value") + t.Fatal("IsFunction failed") } } @@ -139,20 +100,18 @@ func TestValueConversion(t *testing.T) { } defer state.Close() - // Push values state.PushBoolean(true) state.PushNumber(42.5) state.PushString("hello") - // Test conversion if !state.ToBoolean(1) { - t.Fatalf("ToBoolean failed") + t.Fatal("ToBoolean failed") } if state.ToNumber(2) != 42.5 { - t.Fatalf("ToNumber failed, expected 42.5, got %f", state.ToNumber(2)) + t.Fatalf("ToNumber failed: expected 42.5, got %f", state.ToNumber(2)) } if state.ToString(3) != "hello" { - t.Fatalf("ToString failed, expected 'hello', got '%s'", state.ToString(3)) + t.Fatalf("ToString failed: expected 'hello', got '%s'", state.ToString(3)) } } @@ -163,46 +122,34 @@ func TestTableOperations(t *testing.T) { } defer state.Close() - // Test CreateTable state.CreateTable(0, 3) - // Add fields using SetField + // Set fields state.PushNumber(42) state.SetField(-2, "answer") - state.PushString("hello") state.SetField(-2, "greeting") - state.PushBoolean(true) state.SetField(-2, "flag") - // Test GetField + // Get fields state.GetField(-1, "answer") if state.ToNumber(-1) != 42 { - t.Fatalf("GetField for 'answer' failed") + t.Fatal("GetField failed for 'answer'") } state.Pop(1) - state.GetField(-1, "greeting") - if state.ToString(-1) != "hello" { - t.Fatalf("GetField for 'greeting' failed") - } - state.Pop(1) - - // Test Next for iteration - state.PushNil() // Start iteration + // Test iteration + state.PushNil() count := 0 for state.Next(-2) { count++ - state.Pop(1) // Pop value, leave key for next iteration + state.Pop(1) } - if count != 3 { - t.Fatalf("Expected 3 table entries, found %d", count) + t.Fatalf("Expected 3 entries, found %d", count) } - - // Clean up - state.Pop(1) // Pop the table + state.Pop(1) } func TestGlobalOperations(t *testing.T) { @@ -212,21 +159,18 @@ func TestGlobalOperations(t *testing.T) { } defer state.Close() - // Set a global value state.PushNumber(42) state.SetGlobal("answer") - // Get the global value state.GetGlobal("answer") if state.ToNumber(-1) != 42 { - t.Fatalf("GetGlobal failed, expected 42, got %f", state.ToNumber(-1)) + t.Fatalf("GetGlobal failed: expected 42, got %f", state.ToNumber(-1)) } state.Pop(1) - // Test non-existent global (should be nil) state.GetGlobal("nonexistent") if !state.IsNil(-1) { - t.Fatalf("Expected nil for non-existent global") + t.Fatal("Expected nil for non-existent global") } state.Pop(1) } @@ -238,18 +182,15 @@ func TestCodeExecution(t *testing.T) { } defer state.Close() - // Test LoadString + // Test LoadString and Call if err := state.LoadString("return 42"); err != nil { t.Fatalf("LoadString failed: %v", err) } - - // Test Call if err := state.Call(0, 1); err != nil { t.Fatalf("Call failed: %v", err) } - if state.ToNumber(-1) != 42 { - t.Fatalf("Call result incorrect, expected 42, got %f", state.ToNumber(-1)) + t.Fatalf("Call result incorrect: expected 42, got %f", state.ToNumber(-1)) } state.Pop(1) @@ -257,10 +198,9 @@ func TestCodeExecution(t *testing.T) { if err := state.DoString("answer = 42 + 1"); err != nil { t.Fatalf("DoString failed: %v", err) } - state.GetGlobal("answer") if state.ToNumber(-1) != 43 { - t.Fatalf("DoString execution incorrect, expected 43, got %f", state.ToNumber(-1)) + t.Fatalf("DoString result incorrect: expected 43, got %f", state.ToNumber(-1)) } state.Pop(1) @@ -269,13 +209,11 @@ func TestCodeExecution(t *testing.T) { if err != nil { t.Fatalf("Execute failed: %v", err) } - if nresults != 3 { t.Fatalf("Execute returned %d results, expected 3", nresults) } - if state.ToNumber(-3) != 5 || state.ToNumber(-2) != 10 || state.ToNumber(-1) != 15 { - t.Fatalf("Execute results incorrect") + t.Fatal("Execute results incorrect") } state.Pop(3) @@ -284,26 +222,24 @@ func TestCodeExecution(t *testing.T) { if err != nil { t.Fatalf("ExecuteWithResult failed: %v", err) } - if result != "hello" { t.Fatalf("ExecuteWithResult returned %v, expected 'hello'", result) } // Test error handling - err = state.DoString("this is not valid lua code") - if err == nil { - t.Fatalf("Expected error for invalid code, got nil") + if err := state.DoString("invalid lua code"); err == nil { + t.Fatal("Expected error for invalid code") } } -func TestDoFile(t *testing.T) { +func TestFileOperations(t *testing.T) { state := luajit.New() if state == nil { t.Fatal("Failed to create Lua state") } defer state.Close() - // Create a temporary Lua file + // Create temp file content := []byte("answer = 42") tmpfile, err := os.CreateTemp("", "test-*.lua") if err != nil { @@ -312,40 +248,17 @@ func TestDoFile(t *testing.T) { defer os.Remove(tmpfile.Name()) if _, err := tmpfile.Write(content); err != nil { - t.Fatalf("Failed to write to temp file: %v", err) - } - if err := tmpfile.Close(); err != nil { - t.Fatalf("Failed to close temp file: %v", err) - } - - // Test LoadFile and DoFile - if err := state.LoadFile(tmpfile.Name()); err != nil { - t.Fatalf("LoadFile failed: %v", err) - } - - if err := state.Call(0, 0); err != nil { - t.Fatalf("Call failed after LoadFile: %v", err) - } - - state.GetGlobal("answer") - if state.ToNumber(-1) != 42 { - t.Fatalf("Incorrect result after LoadFile, expected 42, got %f", state.ToNumber(-1)) - } - state.Pop(1) - - // Reset global - if err := state.DoString("answer = nil"); err != nil { - t.Fatalf("Failed to reset answer: %v", err) + t.Fatalf("Failed to write temp file: %v", err) } + tmpfile.Close() // Test DoFile if err := state.DoFile(tmpfile.Name()); err != nil { t.Fatalf("DoFile failed: %v", err) } - state.GetGlobal("answer") if state.ToNumber(-1) != 42 { - t.Fatalf("Incorrect result after DoFile, expected 42, got %f", state.ToNumber(-1)) + t.Fatalf("DoFile result incorrect: expected 42, got %f", state.ToNumber(-1)) } state.Pop(1) } @@ -357,7 +270,6 @@ func TestPackagePath(t *testing.T) { } defer state.Close() - // Test SetPackagePath testPath := "/test/path/?.lua" if err := state.SetPackagePath(testPath); err != nil { t.Fatalf("SetPackagePath failed: %v", err) @@ -367,12 +279,10 @@ func TestPackagePath(t *testing.T) { if err != nil { t.Fatalf("Failed to get package.path: %v", err) } - if result != testPath { - t.Fatalf("Expected package.path to be '%s', got '%s'", testPath, result) + t.Fatalf("SetPackagePath failed: expected '%s', got '%s'", testPath, result) } - // Test AddPackagePath addPath := "/another/path/?.lua" if err := state.AddPackagePath(addPath); err != nil { t.Fatalf("AddPackagePath failed: %v", err) @@ -382,92 +292,134 @@ func TestPackagePath(t *testing.T) { if err != nil { t.Fatalf("Failed to get package.path: %v", err) } - expected := testPath + ";" + addPath if result != expected { - t.Fatalf("Expected package.path to be '%s', got '%s'", expected, result) + t.Fatalf("AddPackagePath failed: expected '%s', got '%s'", expected, result) } } -func TestPushValueAndToValue(t *testing.T) { +func TestEnhancedTypes(t *testing.T) { state := luajit.New() if state == nil { t.Fatal("Failed to create Lua state") } defer state.Close() + // Test typed arrays testCases := []struct { - value any + input any + expected any }{ - {nil}, - {true}, - {false}, - {42}, - {42.5}, - {"hello"}, - {[]float64{1, 2, 3, 4, 5}}, - {[]any{1, "test", true}}, - {map[string]any{"a": 1, "b": "test", "c": true}}, + // Primitive types + {nil, nil}, + {true, true}, + {42, 42}, // Should preserve as int + {42.5, 42.5}, // Should be float64 + {"hello", "hello"}, + + // Typed arrays + {[]int{1, 2, 3}, []int{1, 2, 3}}, + {[]string{"a", "b"}, []string{"a", "b"}}, + {[]bool{true, false}, []bool{true, false}}, + {[]float64{1.1, 2.2}, []float64{1.1, 2.2}}, + + // Typed maps + {map[string]string{"name": "John"}, map[string]string{"name": "John"}}, + {map[string]int{"age": 25}, map[string]int{"age": 25}}, + {map[int]any{10: "first", 20: 42}, map[string]any{"10": "first", "20": 42}}, } for i, tc := range testCases { - // Push value - err := state.PushValue(tc.value) - if err != nil { - t.Fatalf("PushValue failed for testCase %d: %v", i, err) + // Push and retrieve value + if err := state.PushValue(tc.input); err != nil { + t.Fatalf("Case %d: PushValue failed: %v", i, err) } - // Check stack - if state.GetTop() != i+1 { - t.Fatalf("Stack size incorrect after push, expected %d, got %d", i+1, state.GetTop()) + result, err := state.ToValue(-1) + if err != nil { + t.Fatalf("Case %d: ToValue failed: %v", i, err) } + + if !reflect.DeepEqual(result, tc.expected) { + t.Fatalf("Case %d: expected %v (%T), got %v (%T)", + i, tc.expected, tc.expected, result, result) + } + state.Pop(1) } - // Test conversion back to Go - for i := range testCases { - index := len(testCases) - i - value, err := state.ToValue(index) - if err != nil { - t.Fatalf("ToValue failed for index %d: %v", index, err) - } - - // For tables, we need special handling due to how Go types are stored - switch expected := testCases[index-1].value.(type) { - case []float64: - // Arrays come back as map[string]any with empty key - if m, ok := value.(map[string]any); ok { - if arr, ok := m[""].([]float64); ok { - if !reflect.DeepEqual(arr, expected) { - t.Fatalf("Value mismatch for testCase %d: expected %v, got %v", index-1, expected, arr) - } - } else { - t.Fatalf("Invalid array conversion for testCase %d", index-1) - } - } else { - t.Fatalf("Expected map for array value in testCase %d, got %T", index-1, value) - } - case int: - if num, ok := value.(float64); ok { - if float64(expected) == num { - continue // Values match after type conversion - } - } - case []any: - // Skip detailed comparison for mixed arrays - case map[string]any: - // Skip detailed comparison for maps - default: - if !reflect.DeepEqual(value, testCases[index-1].value) { - t.Fatalf("Value mismatch for testCase %d: expected %v, got %v", - index-1, testCases[index-1].value, value) - } - } + // Test mixed array (should become []any) + state.DoString("mixed = {1, 'hello', true}") + state.GetGlobal("mixed") + result, err := state.ToValue(-1) + if err != nil { + t.Fatalf("Mixed array conversion failed: %v", err) } + if _, ok := result.([]any); !ok { + t.Fatalf("Expected []any for mixed array, got %T", result) + } + state.Pop(1) + + // Test mixed map (should become map[string]any) + state.DoString("mixedMap = {name='John', age=25, active=true}") + state.GetGlobal("mixedMap") + result, err = state.ToValue(-1) + if err != nil { + t.Fatalf("Mixed map conversion failed: %v", err) + } + if _, ok := result.(map[string]any); !ok { + t.Fatalf("Expected map[string]any for mixed map, got %T", result) + } + state.Pop(1) +} + +func TestIntegerPreservation(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + // Test that integers are preserved + state.DoString("num = 42") + state.GetGlobal("num") + result, err := state.ToValue(-1) + if err != nil { + t.Fatalf("Integer conversion failed: %v", err) + } + if val, ok := result.(int); !ok || val != 42 { + t.Fatalf("Expected int 42, got %T %v", result, result) + } + state.Pop(1) + + // Test that floats remain floats + state.DoString("fnum = 42.5") + state.GetGlobal("fnum") + result, err = state.ToValue(-1) + if err != nil { + t.Fatalf("Float conversion failed: %v", err) + } + if val, ok := result.(float64); !ok || val != 42.5 { + t.Fatalf("Expected float64 42.5, got %T %v", result, result) + } + state.Pop(1) +} + +func TestErrorHandling(t *testing.T) { + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() // Test unsupported type - complex := complex(1, 2) - err := state.PushValue(complex) + type customStruct struct{ Field int } + if err := state.PushValue(customStruct{Field: 42}); err == nil { + t.Fatal("Expected error for unsupported type") + } + + // Test invalid stack index + _, err := state.ToValue(100) if err == nil { - t.Fatalf("Expected error for unsupported type") + t.Fatal("Expected error for invalid index") } } diff --git a/types.go b/types.go index 1498059..78ea256 100644 --- a/types.go +++ b/types.go @@ -13,7 +13,6 @@ import ( type LuaType int const ( - // These constants match lua.h's LUA_T* values TypeNone LuaType = -1 TypeNil LuaType = 0 TypeBoolean LuaType = 1 @@ -26,7 +25,6 @@ const ( TypeThread LuaType = 8 ) -// String returns the string representation of the Lua type func (t LuaType) String() string { switch t { case TypeNone: @@ -54,92 +52,309 @@ func (t LuaType) String() string { } } -// ConvertValue converts a value to the requested type with proper type conversion +// ConvertValue converts a value to the requested type with comprehensive type conversion func ConvertValue[T any](value any) (T, bool) { var zero T - // Handle nil case if value == nil { return zero, false } - // Try direct type assertion first if result, ok := value.(T); ok { return result, true } - // Type-specific conversions switch any(zero).(type) { case string: - switch v := value.(type) { - case float64: - return any(fmt.Sprintf("%g", v)).(T), true - case int: - return any(strconv.Itoa(v)).(T), true - case bool: - if v { - return any("true").(T), true - } - return any("false").(T), true - } + return convertToString[T](value) case int: - switch v := value.(type) { - case float64: - return any(int(v)).(T), true - case string: - if i, err := strconv.Atoi(v); err == nil { - return any(i).(T), true - } - case bool: - if v { - return any(1).(T), true - } - return any(0).(T), true - } + return convertToInt[T](value) case float64: - switch v := value.(type) { - case int: - return any(float64(v)).(T), true - case string: - if f, err := strconv.ParseFloat(v, 64); err == nil { - return any(f).(T), true - } - case bool: - if v { - return any(1.0).(T), true - } - return any(0.0).(T), true - } + return convertToFloat[T](value) case bool: - switch v := value.(type) { - case string: - switch v { - case "true", "yes", "1": - return any(true).(T), true - case "false", "no", "0": - return any(false).(T), true - } - case int: - return any(v != 0).(T), true - case float64: - return any(v != 0).(T), true - } + return convertToBool[T](value) + case []int: + return convertToIntSlice[T](value) + case []string: + return convertToStringSlice[T](value) + case []bool: + return convertToBoolSlice[T](value) + case []float64: + return convertToFloatSlice[T](value) + case []any: + return convertToAnySlice[T](value) + case map[string]string: + return convertToStringMap[T](value) + case map[string]int: + return convertToIntMap[T](value) + case map[int]any: + return convertToIntKeyMap[T](value) + case map[string]any: + return convertToAnyMap[T](value) } return zero, false } +func convertToString[T any](value any) (T, bool) { + var zero T + switch v := value.(type) { + case float64: + if v == float64(int(v)) { + return any(strconv.Itoa(int(v))).(T), true + } + return any(fmt.Sprintf("%g", v)).(T), true + case int: + return any(strconv.Itoa(v)).(T), true + case bool: + return any(strconv.FormatBool(v)).(T), true + } + return zero, false +} + +func convertToInt[T any](value any) (T, bool) { + var zero T + switch v := value.(type) { + case float64: + return any(int(v)).(T), true + case string: + if i, err := strconv.Atoi(v); err == nil { + return any(i).(T), true + } + case bool: + if v { + return any(1).(T), true + } + return any(0).(T), true + } + return zero, false +} + +func convertToFloat[T any](value any) (T, bool) { + var zero T + switch v := value.(type) { + case int: + return any(float64(v)).(T), true + case string: + if f, err := strconv.ParseFloat(v, 64); err == nil { + return any(f).(T), true + } + case bool: + if v { + return any(1.0).(T), true + } + return any(0.0).(T), true + } + return zero, false +} + +func convertToBool[T any](value any) (T, bool) { + var zero T + switch v := value.(type) { + case string: + switch v { + case "true", "yes", "1": + return any(true).(T), true + case "false", "no", "0": + return any(false).(T), true + } + case int: + return any(v != 0).(T), true + case float64: + return any(v != 0).(T), true + } + return zero, false +} + +func convertToIntSlice[T any](value any) (T, bool) { + var zero T + switch v := value.(type) { + case []float64: + result := make([]int, len(v)) + for i, f := range v { + result[i] = int(f) + } + return any(result).(T), true + case []any: + result := make([]int, 0, len(v)) + for _, item := range v { + if i, ok := ConvertValue[int](item); ok { + result = append(result, i) + } else { + return zero, false + } + } + return any(result).(T), true + } + return zero, false +} + +func convertToStringSlice[T any](value any) (T, bool) { + var zero T + if v, ok := value.([]any); ok { + result := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := ConvertValue[string](item); ok { + result = append(result, s) + } else { + return zero, false + } + } + return any(result).(T), true + } + return zero, false +} + +func convertToBoolSlice[T any](value any) (T, bool) { + var zero T + if v, ok := value.([]any); ok { + result := make([]bool, 0, len(v)) + for _, item := range v { + if b, ok := ConvertValue[bool](item); ok { + result = append(result, b) + } else { + return zero, false + } + } + return any(result).(T), true + } + return zero, false +} + +func convertToFloatSlice[T any](value any) (T, bool) { + var zero T + switch v := value.(type) { + case []int: + result := make([]float64, len(v)) + for i, n := range v { + result[i] = float64(n) + } + return any(result).(T), true + case []any: + result := make([]float64, 0, len(v)) + for _, item := range v { + if f, ok := ConvertValue[float64](item); ok { + result = append(result, f) + } else { + return zero, false + } + } + return any(result).(T), true + } + return zero, false +} + +func convertToAnySlice[T any](value any) (T, bool) { + var zero T + switch v := value.(type) { + case []int: + result := make([]any, len(v)) + for i, n := range v { + result[i] = n + } + return any(result).(T), true + case []string: + result := make([]any, len(v)) + for i, s := range v { + result[i] = s + } + return any(result).(T), true + case []bool: + result := make([]any, len(v)) + for i, b := range v { + result[i] = b + } + return any(result).(T), true + case []float64: + result := make([]any, len(v)) + for i, f := range v { + result[i] = f + } + return any(result).(T), true + } + return zero, false +} + +func convertToStringMap[T any](value any) (T, bool) { + var zero T + if v, ok := value.(map[string]any); ok { + result := make(map[string]string, len(v)) + for k, val := range v { + if s, ok := ConvertValue[string](val); ok { + result[k] = s + } else { + return zero, false + } + } + return any(result).(T), true + } + return zero, false +} + +func convertToIntMap[T any](value any) (T, bool) { + var zero T + if v, ok := value.(map[string]any); ok { + result := make(map[string]int, len(v)) + for k, val := range v { + if i, ok := ConvertValue[int](val); ok { + result[k] = i + } else { + return zero, false + } + } + return any(result).(T), true + } + return zero, false +} + +func convertToIntKeyMap[T any](value any) (T, bool) { + var zero T + if v, ok := value.(map[string]any); ok { + result := make(map[int]any, len(v)) + for k, val := range v { + if i, err := strconv.Atoi(k); err == nil { + result[i] = val + } else { + return zero, false + } + } + return any(result).(T), true + } + return zero, false +} + +func convertToAnyMap[T any](value any) (T, bool) { + var zero T + switch v := value.(type) { + case map[string]string: + result := make(map[string]any, len(v)) + for k, s := range v { + result[k] = s + } + return any(result).(T), true + case map[string]int: + result := make(map[string]any, len(v)) + for k, i := range v { + result[k] = i + } + return any(result).(T), true + case map[int]any: + result := make(map[string]any, len(v)) + for k, val := range v { + result[strconv.Itoa(k)] = val + } + return any(result).(T), true + } + return zero, false +} + // GetTypedValue gets a value from the state with type conversion func GetTypedValue[T any](s *State, index int) (T, bool) { - var zero T - - // Get the value as any type value, err := s.ToValue(index) if err != nil { + var zero T return zero, false } - - // Convert it to the requested type return ConvertValue[T](value) } @@ -147,6 +362,5 @@ func GetTypedValue[T any](s *State, index int) (T, bool) { func GetGlobalTyped[T any](s *State, name string) (T, bool) { s.GetGlobal(name) defer s.Pop(1) - return GetTypedValue[T](s, -1) } diff --git a/validation.go b/validation.go new file mode 100644 index 0000000..901c7c4 --- /dev/null +++ b/validation.go @@ -0,0 +1,59 @@ +package luajit + +import "fmt" + +// ArgSpec defines an argument specification for validation +type ArgSpec struct { + Name string + Type string + Required bool + Check func(*State, int) bool +} + +// Common argument checkers +var ( + CheckString = func(s *State, i int) bool { return s.IsString(i) } + CheckNumber = func(s *State, i int) bool { return s.IsNumber(i) } + CheckBool = func(s *State, i int) bool { return s.IsBoolean(i) } + CheckTable = func(s *State, i int) bool { return s.IsTable(i) } + CheckFunc = func(s *State, i int) bool { return s.IsFunction(i) } + CheckAny = func(s *State, i int) bool { return true } +) + +// CheckArgs validates function arguments against specifications +func (s *State) CheckArgs(specs ...ArgSpec) error { + for i, spec := range specs { + argIdx := i + 1 + if argIdx > s.GetTop() { + if spec.Required { + return fmt.Errorf("missing argument %d: %s", argIdx, spec.Name) + } + break + } + + if s.IsNil(argIdx) && !spec.Required { + continue + } + + if !spec.Check(s, argIdx) { + return fmt.Errorf("argument %d (%s) must be %s", argIdx, spec.Name, spec.Type) + } + } + return nil +} + +// CheckMinArgs checks for minimum number of arguments +func (s *State) CheckMinArgs(min int) error { + if s.GetTop() < min { + return fmt.Errorf("expected at least %d arguments, got %d", min, s.GetTop()) + } + return nil +} + +// CheckExactArgs checks for exact number of arguments +func (s *State) CheckExactArgs(count int) error { + if s.GetTop() != count { + return fmt.Errorf("expected exactly %d arguments, got %d", count, s.GetTop()) + } + return nil +} diff --git a/wrapper.go b/wrapper.go index 8aa2d90..a5afdf6 100644 --- a/wrapper.go +++ b/wrapper.go @@ -11,7 +11,6 @@ package luajit #include #include -// Direct execution helpers to minimize CGO transitions static int do_string(lua_State *L, const char *s) { int status = luaL_loadstring(L, s); if (status == 0) { @@ -34,31 +33,98 @@ static int execute_with_results(lua_State *L, const char *code, int store_result return lua_pcall(L, 0, store_results ? LUA_MULTRET : 0, 0); } -static int has_metatable(lua_State *L, int index) { - return lua_getmetatable(L, index); +static size_t get_table_length(lua_State *L, int index) { + return lua_objlen(L, index); +} + +static int is_integer(lua_State *L, int index) { + if (!lua_isnumber(L, index)) return 0; + lua_Number n = lua_tonumber(L, index); + return n == (lua_Number)(lua_Integer)n; +} + +static int sample_array_type(lua_State *L, int index, int count) { + int all_numbers = 1; + int all_integers = 1; + int all_strings = 1; + int all_bools = 1; + + for (int i = 1; i <= count && i <= 5; i++) { + lua_pushnumber(L, i); + lua_gettable(L, index); + + int type = lua_type(L, -1); + if (type != LUA_TNUMBER) all_numbers = all_integers = 0; + if (type != LUA_TSTRING) all_strings = 0; + if (type != LUA_TBOOLEAN) all_bools = 0; + + if (all_numbers && !is_integer(L, -1)) all_integers = 0; + + lua_pop(L, 1); + + if (!all_numbers && !all_strings && !all_bools) break; + } + + if (all_integers) return 1; + if (all_numbers) return 2; + if (all_strings) return 3; + if (all_bools) return 4; + return 0; +} + +static int sample_map_type(lua_State *L, int index) { + int all_string_vals = 1; + int all_int_vals = 1; + int all_int_keys = 1; + int count = 0; + + lua_pushnil(L); + while (lua_next(L, index) && count < 5) { + if (lua_type(L, -2) != LUA_TSTRING) { + all_int_keys = 0; + } else { + const char *key = lua_tostring(L, -2); + char *endptr; + strtol(key, &endptr, 10); + if (*endptr != '\0') all_int_keys = 0; + } + + int val_type = lua_type(L, -1); + if (val_type != LUA_TSTRING) all_string_vals = 0; + if (val_type != LUA_TNUMBER || !is_integer(L, -1)) all_int_vals = 0; + + lua_pop(L, 1); + count++; + + if (!all_string_vals && !all_int_vals && !all_int_keys) break; + } + + if (all_int_keys) return 4; + if (all_string_vals) return 1; + if (all_int_vals) return 2; + return 3; } */ import "C" import ( "fmt" + "strconv" "strings" - "sync" "unsafe" ) -// Type pool for common objects to reduce GC pressure -var stringBufferPool = sync.Pool{ - New: func() any { - return new(strings.Builder) - }, -} +// Stack management constants +const ( + LUA_MINSTACK = 20 + LUA_MAXSTACK = 1000000 + LUA_REGISTRYINDEX = -10000 + LUA_GLOBALSINDEX = -10002 +) -// State represents a Lua state type State struct { L *C.lua_State } -// New creates a new Lua state with optional standard libraries; true if not specified func New(openLibs ...bool) *State { L := C.luaL_newstate() if L == nil { @@ -72,7 +138,6 @@ func New(openLibs ...bool) *State { return &State{L: L} } -// Close closes the Lua state and frees resources func (s *State) Close() { if s.L != nil { C.lua_close(s.L) @@ -80,34 +145,13 @@ func (s *State) Close() { } } -// Stack manipulation methods +// Stack operations +func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) } +func (s *State) SetTop(index int) { C.lua_settop(s.L, C.int(index)) } +func (s *State) PushCopy(index int) { C.lua_pushvalue(s.L, C.int(index)) } +func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) } +func (s *State) Remove(index int) { C.lua_remove(s.L, C.int(index)) } -// GetTop returns the index of the top element in the stack -func (s *State) GetTop() int { - return int(C.lua_gettop(s.L)) -} - -// SetTop sets the stack top to a specific index -func (s *State) SetTop(index int) { - C.lua_settop(s.L, C.int(index)) -} - -// PushCopy pushes a copy of the value at the given index onto the stack -func (s *State) PushCopy(index int) { - C.lua_pushvalue(s.L, C.int(index)) -} - -// Pop pops n elements from the stack -func (s *State) Pop(n int) { - C.lua_settop(s.L, C.int(-n-1)) -} - -// Remove removes the element at the given valid index -func (s *State) Remove(index int) { - C.lua_remove(s.L, C.int(index)) -} - -// absIndex converts a possibly negative index to its absolute position func (s *State) absIndex(index int) int { if index > 0 || index <= LUA_REGISTRYINDEX { return index @@ -115,56 +159,19 @@ func (s *State) absIndex(index int) int { return s.GetTop() + index + 1 } -// Type checking methods +// Type checking +func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) } +func (s *State) IsNil(index int) bool { return s.GetType(index) == TypeNil } +func (s *State) IsBoolean(index int) bool { return s.GetType(index) == TypeBoolean } +func (s *State) IsNumber(index int) bool { return C.lua_isnumber(s.L, C.int(index)) != 0 } +func (s *State) IsString(index int) bool { return C.lua_isstring(s.L, C.int(index)) != 0 } +func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable } +func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction } -// GetType returns the type of the value at the given index -func (s *State) GetType(index int) LuaType { - return LuaType(C.lua_type(s.L, C.int(index))) -} +// Value conversion +func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 } +func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) } -// IsNil checks if the value at the given index is nil -func (s *State) IsNil(index int) bool { - return s.GetType(index) == TypeNil -} - -// IsBoolean checks if the value at the given index is a boolean -func (s *State) IsBoolean(index int) bool { - return s.GetType(index) == TypeBoolean -} - -// IsNumber checks if the value at the given index is a number -func (s *State) IsNumber(index int) bool { - return C.lua_isnumber(s.L, C.int(index)) != 0 -} - -// IsString checks if the value at the given index is a string -func (s *State) IsString(index int) bool { - return C.lua_isstring(s.L, C.int(index)) != 0 -} - -// IsTable checks if the value at the given index is a table -func (s *State) IsTable(index int) bool { - return s.GetType(index) == TypeTable -} - -// IsFunction checks if the value at the given index is a function -func (s *State) IsFunction(index int) bool { - return s.GetType(index) == TypeFunction -} - -// Value conversion methods - -// ToBoolean returns the value at the given index as a boolean -func (s *State) ToBoolean(index int) bool { - return C.lua_toboolean(s.L, C.int(index)) != 0 -} - -// ToNumber returns the value at the given index as a number -func (s *State) ToNumber(index int) float64 { - return float64(C.lua_tonumber(s.L, C.int(index))) -} - -// ToString returns the value at the given index as a string func (s *State) ToString(index int) string { var length C.size_t cstr := C.lua_tolstring(s.L, C.int(index), &length) @@ -175,170 +182,384 @@ func (s *State) ToString(index int) string { } // Push methods +func (s *State) PushNil() { C.lua_pushnil(s.L) } +func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, boolToInt(b)) } +func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.lua_Number(n)) } -// PushNil pushes a nil value onto the stack -func (s *State) PushNil() { - C.lua_pushnil(s.L) -} - -// PushBoolean pushes a boolean value onto the stack -func (s *State) PushBoolean(b bool) { - var value C.int - if b { - value = 1 - } - C.lua_pushboolean(s.L, value) -} - -// PushNumber pushes a number value onto the stack -func (s *State) PushNumber(n float64) { - C.lua_pushnumber(s.L, C.lua_Number(n)) -} - -// PushString pushes a string value onto the stack func (s *State) PushString(str string) { - // Use direct C string for short strings (avoid allocations) if len(str) < 128 { cstr := C.CString(str) defer C.free(unsafe.Pointer(cstr)) C.lua_pushlstring(s.L, cstr, C.size_t(len(str))) - return + } else { + header := (*struct { + p unsafe.Pointer + len int + cap int + })(unsafe.Pointer(&str)) + C.lua_pushlstring(s.L, (*C.char)(header.p), C.size_t(len(str))) } - - // For longer strings, avoid double copy by using unsafe pointer - header := (*struct { - p unsafe.Pointer - len int - cap int - })(unsafe.Pointer(&str)) - - C.lua_pushlstring(s.L, (*C.char)(header.p), C.size_t(len(str))) } // Table operations +func (s *State) CreateTable(narr, nrec int) { C.lua_createtable(s.L, C.int(narr), C.int(nrec)) } +func (s *State) NewTable() { C.lua_createtable(s.L, 0, 0) } +func (s *State) GetTable(index int) { C.lua_gettable(s.L, C.int(index)) } +func (s *State) SetTable(index int) { C.lua_settable(s.L, C.int(index)) } +func (s *State) Next(index int) bool { return C.lua_next(s.L, C.int(index)) != 0 } -// CreateTable creates a new table and pushes it onto the stack -func (s *State) CreateTable(narr, nrec int) { - C.lua_createtable(s.L, C.int(narr), C.int(nrec)) -} - -// NewTable creates a new empty table and pushes it onto the stack -func (s *State) NewTable() { - C.lua_createtable(s.L, 0, 0) -} - -// GetTable gets a table field (t[k]) where t is at the given index and k is at the top of the stack -func (s *State) GetTable(index int) { - C.lua_gettable(s.L, C.int(index)) -} - -// SetTable sets a table field (t[k] = v) where t is at the given index, k is at -2, and v is at -1 -func (s *State) SetTable(index int) { - C.lua_settable(s.L, C.int(index)) -} - -// GetField gets a table field t[k] and pushes it onto the stack func (s *State) GetField(index int, key string) { ckey := C.CString(key) defer C.free(unsafe.Pointer(ckey)) C.lua_getfield(s.L, C.int(index), ckey) } -// SetField sets a table field t[k] = v, where v is the value at the top of the stack func (s *State) SetField(index int, key string) { ckey := C.CString(key) defer C.free(unsafe.Pointer(ckey)) C.lua_setfield(s.L, C.int(index), ckey) } -// Next pops a key from the stack and pushes the next key-value pair from the table at the given index -func (s *State) Next(index int) bool { - return C.lua_next(s.L, C.int(index)) != 0 +func (s *State) GetTableLength(index int) int { + return int(C.get_table_length(s.L, C.int(index))) } -// PushValue pushes a Go value onto the stack with proper type conversion +// Enhanced PushValue with comprehensive type support func (s *State) PushValue(v any) error { - switch v := v.(type) { + switch val := v.(type) { case nil: s.PushNil() case bool: - s.PushBoolean(v) + s.PushBoolean(val) case int: - s.PushNumber(float64(v)) + s.PushNumber(float64(val)) case int64: - s.PushNumber(float64(v)) + s.PushNumber(float64(val)) case float64: - s.PushNumber(v) + s.PushNumber(val) case string: - s.PushString(v) - case map[string]any: - // Special case: handle array stored in map - if arr, ok := v[""].([]float64); ok { - s.CreateTable(len(arr), 0) - for i, elem := range arr { - s.PushNumber(float64(i + 1)) - s.PushNumber(elem) - s.SetTable(-3) - } - return nil - } - return s.PushTable(v) + s.PushString(val) + case []int: + return s.pushIntSlice(val) + case []string: + return s.pushStringSlice(val) + case []bool: + return s.pushBoolSlice(val) case []float64: - s.CreateTable(len(v), 0) - for i, elem := range v { - s.PushNumber(float64(i + 1)) - s.PushNumber(elem) - s.SetTable(-3) - } + return s.pushFloatSlice(val) case []any: - s.CreateTable(len(v), 0) - for i, elem := range v { - s.PushNumber(float64(i + 1)) - if err := s.PushValue(elem); err != nil { - return err - } - s.SetTable(-3) - } + return s.pushAnySlice(val) + case map[string]string: + return s.pushStringMap(val) + case map[string]int: + return s.pushIntMap(val) + case map[int]any: + return s.pushIntKeyMap(val) + case map[string]any: + return s.pushAnyMap(val) default: return fmt.Errorf("unsupported type: %T", v) } return nil } -// ToValue converts a Lua value at the given index to a Go value +func (s *State) pushIntSlice(arr []int) error { + s.CreateTable(len(arr), 0) + for i, v := range arr { + s.PushNumber(float64(i + 1)) + s.PushNumber(float64(v)) + s.SetTable(-3) + } + return nil +} + +func (s *State) pushStringSlice(arr []string) error { + s.CreateTable(len(arr), 0) + for i, v := range arr { + s.PushNumber(float64(i + 1)) + s.PushString(v) + s.SetTable(-3) + } + return nil +} + +func (s *State) pushBoolSlice(arr []bool) error { + s.CreateTable(len(arr), 0) + for i, v := range arr { + s.PushNumber(float64(i + 1)) + s.PushBoolean(v) + s.SetTable(-3) + } + return nil +} + +func (s *State) pushFloatSlice(arr []float64) error { + s.CreateTable(len(arr), 0) + for i, v := range arr { + s.PushNumber(float64(i + 1)) + s.PushNumber(v) + s.SetTable(-3) + } + return nil +} + +func (s *State) pushAnySlice(arr []any) error { + s.CreateTable(len(arr), 0) + for i, v := range arr { + s.PushNumber(float64(i + 1)) + if err := s.PushValue(v); err != nil { + return err + } + s.SetTable(-3) + } + return nil +} + +func (s *State) pushStringMap(m map[string]string) error { + s.CreateTable(0, len(m)) + for k, v := range m { + s.PushString(k) + s.PushString(v) + s.SetTable(-3) + } + return nil +} + +func (s *State) pushIntMap(m map[string]int) error { + s.CreateTable(0, len(m)) + for k, v := range m { + s.PushString(k) + s.PushNumber(float64(v)) + s.SetTable(-3) + } + return nil +} + +func (s *State) pushIntKeyMap(m map[int]any) error { + s.CreateTable(0, len(m)) + for k, v := range m { + s.PushNumber(float64(k)) + if err := s.PushValue(v); err != nil { + return err + } + s.SetTable(-3) + } + return nil +} + +func (s *State) pushAnyMap(m map[string]any) error { + s.CreateTable(0, len(m)) + for k, v := range m { + s.PushString(k) + if err := s.PushValue(v); err != nil { + return err + } + s.SetTable(-3) + } + return nil +} + +// Enhanced ToValue with automatic type detection func (s *State) ToValue(index int) (any, error) { - luaType := s.GetType(index) - switch luaType { + switch s.GetType(index) { case TypeNil: return nil, nil case TypeBoolean: return s.ToBoolean(index), nil case TypeNumber: - return s.ToNumber(index), nil + num := s.ToNumber(index) + if num == float64(int(num)) && num >= -2147483648 && num <= 2147483647 { + return int(num), nil + } + return num, nil case TypeString: return s.ToString(index), nil case TypeTable: return s.ToTable(index) default: - return nil, fmt.Errorf("unsupported type: %s", luaType) + return nil, fmt.Errorf("unsupported type: %s", s.GetType(index)) } } +// ToTable converts a Lua table to optimal Go type +func (s *State) ToTable(index int) (any, error) { + absIdx := s.absIndex(index) + if !s.IsTable(absIdx) { + return nil, fmt.Errorf("value at index %d is not a table", index) + } + + length := s.GetTableLength(absIdx) + + if length > 0 { + arrayType := int(C.sample_array_type(s.L, C.int(absIdx), C.int(length))) + switch arrayType { + case 1: // int array + return s.extractIntArray(absIdx, length), nil + case 2: // float array + return s.extractFloatArray(absIdx, length), nil + case 3: // string array + return s.extractStringArray(absIdx, length), nil + case 4: // bool array + return s.extractBoolArray(absIdx, length), nil + default: // mixed array + return s.extractAnyArray(absIdx, length), nil + } + } + + mapType := int(C.sample_map_type(s.L, C.int(absIdx))) + switch mapType { + case 1: // map[string]string + return s.extractStringMap(absIdx) + case 2: // map[string]int + return s.extractIntMap(absIdx) + case 4: // map[int]any + return s.extractIntKeyMap(absIdx) + default: // map[string]any + return s.extractAnyMap(absIdx) + } +} + +func (s *State) extractIntArray(index, length int) []int { + result := make([]int, length) + for i := 1; i <= length; i++ { + s.PushNumber(float64(i)) + s.GetTable(index) + result[i-1] = int(s.ToNumber(-1)) + s.Pop(1) + } + return result +} + +func (s *State) extractFloatArray(index, length int) []float64 { + result := make([]float64, length) + for i := 1; i <= length; i++ { + s.PushNumber(float64(i)) + s.GetTable(index) + result[i-1] = s.ToNumber(-1) + s.Pop(1) + } + return result +} + +func (s *State) extractStringArray(index, length int) []string { + result := make([]string, length) + for i := 1; i <= length; i++ { + s.PushNumber(float64(i)) + s.GetTable(index) + result[i-1] = s.ToString(-1) + s.Pop(1) + } + return result +} + +func (s *State) extractBoolArray(index, length int) []bool { + result := make([]bool, length) + for i := 1; i <= length; i++ { + s.PushNumber(float64(i)) + s.GetTable(index) + result[i-1] = s.ToBoolean(-1) + s.Pop(1) + } + return result +} + +func (s *State) extractAnyArray(index, length int) []any { + result := make([]any, length) + for i := 1; i <= length; i++ { + s.PushNumber(float64(i)) + s.GetTable(index) + if val, err := s.ToValue(-1); err == nil { + result[i-1] = val + } + s.Pop(1) + } + return result +} + +func (s *State) extractStringMap(index int) (map[string]string, error) { + result := make(map[string]string) + s.PushNil() + for s.Next(index) { + if s.GetType(-2) == TypeString { + key := s.ToString(-2) + value := s.ToString(-1) + result[key] = value + } + s.Pop(1) + } + return result, nil +} + +func (s *State) extractIntMap(index int) (map[string]int, error) { + result := make(map[string]int) + s.PushNil() + for s.Next(index) { + if s.GetType(-2) == TypeString { + key := s.ToString(-2) + value := int(s.ToNumber(-1)) + result[key] = value + } + s.Pop(1) + } + return result, nil +} + +func (s *State) extractIntKeyMap(index int) (map[int]any, error) { + result := make(map[int]any) + s.PushNil() + for s.Next(index) { + var key int + switch s.GetType(-2) { + case TypeString: + if k, err := strconv.Atoi(s.ToString(-2)); err == nil { + key = k + } else { + s.Pop(1) + continue + } + case TypeNumber: + key = int(s.ToNumber(-2)) + default: + s.Pop(1) + continue + } + + if value, err := s.ToValue(-1); err == nil { + result[key] = value + } + s.Pop(1) + } + return result, nil +} + +func (s *State) extractAnyMap(index int) (map[string]any, error) { + result := make(map[string]any) + s.PushNil() + for s.Next(index) { + var key string + switch s.GetType(-2) { + case TypeString: + key = s.ToString(-2) + case TypeNumber: + key = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64) + default: + s.Pop(1) + continue + } + + if value, err := s.ToValue(-1); err == nil { + result[key] = value + } + s.Pop(1) + } + return result, nil +} + // Global operations +func (s *State) GetGlobal(name string) { s.GetField(LUA_GLOBALSINDEX, name) } +func (s *State) SetGlobal(name string) { s.SetField(LUA_GLOBALSINDEX, name) } -// GetGlobal pushes the global variable with the given name onto the stack -func (s *State) GetGlobal(name string) { - s.GetField(LUA_GLOBALSINDEX, name) -} - -// SetGlobal sets the global variable with the given name to the value at the top of the stack -func (s *State) SetGlobal(name string) { - s.SetField(LUA_GLOBALSINDEX, name) -} - -// Code execution methods - -// LoadString loads a Lua chunk from a string without executing it +// Code execution func (s *State) LoadString(code string) error { ccode := C.CString(code) defer C.free(unsafe.Pointer(ccode)) @@ -346,13 +567,12 @@ func (s *State) LoadString(code string) error { status := C.luaL_loadstring(s.L, ccode) if status != 0 { err := s.CreateLuaError(int(status), "LoadString") - s.Pop(1) // Remove error message + s.Pop(1) return err } return nil } -// LoadFile loads a Lua chunk from a file without executing it func (s *State) LoadFile(filename string) error { cfilename := C.CString(filename) defer C.free(unsafe.Pointer(cfilename)) @@ -360,24 +580,22 @@ func (s *State) LoadFile(filename string) error { status := C.luaL_loadfile(s.L, cfilename) if status != 0 { err := s.CreateLuaError(int(status), fmt.Sprintf("LoadFile(%s)", filename)) - s.Pop(1) // Remove error message + s.Pop(1) return err } return nil } -// Call calls a function with the given number of arguments and results func (s *State) Call(nargs, nresults int) error { status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0) if status != 0 { err := s.CreateLuaError(int(status), fmt.Sprintf("Call(%d,%d)", nargs, nresults)) - s.Pop(1) // Remove error message + s.Pop(1) return err } return nil } -// DoString executes a Lua string and cleans up the stack func (s *State) DoString(code string) error { ccode := C.CString(code) defer C.free(unsafe.Pointer(ccode)) @@ -385,13 +603,12 @@ func (s *State) DoString(code string) error { status := C.do_string(s.L, ccode) if status != 0 { err := s.CreateLuaError(int(status), "DoString") - s.Pop(1) // Remove error message + s.Pop(1) return err } return nil } -// DoFile executes a Lua file and cleans up the stack func (s *State) DoFile(filename string) error { cfilename := C.CString(filename) defer C.free(unsafe.Pointer(cfilename)) @@ -399,39 +616,35 @@ func (s *State) DoFile(filename string) error { status := C.do_file(s.L, cfilename) if status != 0 { err := s.CreateLuaError(int(status), fmt.Sprintf("DoFile(%s)", filename)) - s.Pop(1) // Remove error message + s.Pop(1) return err } return nil } -// Execute executes a Lua string and returns the number of results left on the stack func (s *State) Execute(code string) (int, error) { baseTop := s.GetTop() - ccode := C.CString(code) defer C.free(unsafe.Pointer(ccode)) - status := C.execute_with_results(s.L, ccode, 1) // store_results=true + status := C.execute_with_results(s.L, ccode, 1) if status != 0 { err := s.CreateLuaError(int(status), "Execute") - s.Pop(1) // Remove error message + s.Pop(1) return 0, err } return s.GetTop() - baseTop, nil } -// ExecuteWithResult executes a Lua string and returns the first result func (s *State) ExecuteWithResult(code string) (any, error) { top := s.GetTop() - defer s.SetTop(top) // Restore stack when done + defer s.SetTop(top) nresults, err := s.Execute(code) if err != nil { return nil, err } - if nresults == 0 { return nil, nil } @@ -439,42 +652,161 @@ func (s *State) ExecuteWithResult(code string) (any, error) { return s.ToValue(-nresults) } -// BatchExecute executes multiple statements with a single CGO transition func (s *State) BatchExecute(statements []string) error { - // Join statements with semicolons - combinedCode := "" - for i, stmt := range statements { - combinedCode += stmt - if i < len(statements)-1 { - combinedCode += "; " - } - } - - return s.DoString(combinedCode) + return s.DoString(strings.Join(statements, "; ")) } // Package path operations - -// SetPackagePath sets the Lua package.path func (s *State) SetPackagePath(path string) error { - path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths - code := fmt.Sprintf(`package.path = %q`, path) - return s.DoString(code) + path = strings.ReplaceAll(path, "\\", "/") + return s.DoString(fmt.Sprintf(`package.path = %q`, path)) } -// AddPackagePath adds a path to package.path func (s *State) AddPackagePath(path string) error { - path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths - code := fmt.Sprintf(`package.path = package.path .. ";%s"`, path) - return s.DoString(code) + path = strings.ReplaceAll(path, "\\", "/") + return s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)) } -// SetMetatable sets the metatable for the value at the given index -func (s *State) SetMetatable(index int) { - C.lua_setmetatable(s.L, C.int(index)) +// Metatable operations +func (s *State) SetMetatable(index int) { C.lua_setmetatable(s.L, C.int(index)) } +func (s *State) GetMetatable(index int) bool { return C.lua_getmetatable(s.L, C.int(index)) != 0 } + +// Helper functions +func boolToInt(b bool) C.int { + if b { + return 1 + } + return 0 } -// GetMetatable gets the metatable of the value at the given index -func (s *State) GetMetatable(index int) bool { - return C.lua_getmetatable(s.L, C.int(index)) != 0 +// GetFieldString gets a string field from a table with default +func (s *State) GetFieldString(index int, key string, defaultVal string) string { + s.GetField(index, key) + defer s.Pop(1) + if s.IsString(-1) { + return s.ToString(-1) + } + return defaultVal +} + +// GetFieldNumber gets a number field from a table with default +func (s *State) GetFieldNumber(index int, key string, defaultVal float64) float64 { + s.GetField(index, key) + defer s.Pop(1) + if s.IsNumber(-1) { + return s.ToNumber(-1) + } + return defaultVal +} + +// GetFieldBool gets a boolean field from a table with default +func (s *State) GetFieldBool(index int, key string, defaultVal bool) bool { + s.GetField(index, key) + defer s.Pop(1) + if s.IsBoolean(-1) { + return s.ToBoolean(-1) + } + return defaultVal +} + +// GetFieldTable gets a table field from a table +func (s *State) GetFieldTable(index int, key string) (any, bool) { + s.GetField(index, key) + defer s.Pop(1) + if s.IsTable(-1) { + val, err := s.ToTable(-1) + return val, err == nil + } + return nil, false +} + +// ForEachTableKV iterates over string key-value pairs in a table +func (s *State) ForEachTableKV(index int, fn func(key, value string) bool) { + absIdx := s.absIndex(index) + s.PushNil() + for s.Next(absIdx) { + if s.IsString(-2) && s.IsString(-1) { + if !fn(s.ToString(-2), s.ToString(-1)) { + s.Pop(2) + return + } + } + s.Pop(1) + } +} + +// ForEachArray iterates over array elements +func (s *State) ForEachArray(index int, fn func(i int, state *State) bool) { + absIdx := s.absIndex(index) + length := s.GetTableLength(absIdx) + for i := 1; i <= length; i++ { + s.PushNumber(float64(i)) + s.GetTable(absIdx) + if !fn(i, s) { + s.Pop(1) + return + } + s.Pop(1) + } +} + +// SafeToString safely converts value to string with error +func (s *State) SafeToString(index int) (string, error) { + if !s.IsString(index) && !s.IsNumber(index) { + return "", fmt.Errorf("value at index %d is not a string", index) + } + return s.ToString(index), nil +} + +// SafeToNumber safely converts value to number with error +func (s *State) SafeToNumber(index int) (float64, error) { + if !s.IsNumber(index) { + return 0, fmt.Errorf("value at index %d is not a number", index) + } + return s.ToNumber(index), nil +} + +// SafeToTable safely converts value to table with error +func (s *State) SafeToTable(index int) (any, error) { + if !s.IsTable(index) { + return nil, fmt.Errorf("value at index %d is not a table", index) + } + return s.ToTable(index) +} + +// CallGlobal calls a global function with arguments +func (s *State) CallGlobal(name string, args ...any) ([]any, error) { + s.GetGlobal(name) + if !s.IsFunction(-1) { + s.Pop(1) + return nil, fmt.Errorf("global '%s' is not a function", name) + } + + for i, arg := range args { + if err := s.PushValue(arg); err != nil { + s.Pop(i + 1) + return nil, fmt.Errorf("failed to push argument %d: %w", i+1, err) + } + } + + baseTop := s.GetTop() - len(args) - 1 + if err := s.Call(len(args), C.LUA_MULTRET); err != nil { + return nil, err + } + + newTop := s.GetTop() + nresults := newTop - baseTop + + results := make([]any, nresults) + for i := 0; i < nresults; i++ { + val, err := s.ToValue(baseTop + i + 1) + if err != nil { + results[i] = nil + } else { + results[i] = val + } + } + + s.SetTop(baseTop) + return results, nil }