massive rewrite
fix go func mallocs add helper utils
This commit is contained in:
parent
fc7958312d
commit
f4bfff470f
72
builder.go
Normal file
72
builder.go
Normal file
@ -0,0 +1,72 @@
|
||||
package luajit
|
||||
|
||||
// TableBuilder provides a fluent interface for building Lua tables
|
||||
type TableBuilder struct {
|
||||
state *State
|
||||
index int
|
||||
}
|
||||
|
||||
// NewTableBuilder creates a new table and returns a builder
|
||||
func (s *State) NewTableBuilder() *TableBuilder {
|
||||
s.NewTable()
|
||||
return &TableBuilder{
|
||||
state: s,
|
||||
index: s.GetTop(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetString sets a string field
|
||||
func (tb *TableBuilder) SetString(key, value string) *TableBuilder {
|
||||
tb.state.PushString(value)
|
||||
tb.state.SetField(tb.index, key)
|
||||
return tb
|
||||
}
|
||||
|
||||
// SetNumber sets a number field
|
||||
func (tb *TableBuilder) SetNumber(key string, value float64) *TableBuilder {
|
||||
tb.state.PushNumber(value)
|
||||
tb.state.SetField(tb.index, key)
|
||||
return tb
|
||||
}
|
||||
|
||||
// SetBool sets a boolean field
|
||||
func (tb *TableBuilder) SetBool(key string, value bool) *TableBuilder {
|
||||
tb.state.PushBoolean(value)
|
||||
tb.state.SetField(tb.index, key)
|
||||
return tb
|
||||
}
|
||||
|
||||
// SetNil sets a nil field
|
||||
func (tb *TableBuilder) SetNil(key string) *TableBuilder {
|
||||
tb.state.PushNil()
|
||||
tb.state.SetField(tb.index, key)
|
||||
return tb
|
||||
}
|
||||
|
||||
// SetTable sets a table field
|
||||
func (tb *TableBuilder) SetTable(key string, value any) *TableBuilder {
|
||||
if err := tb.state.PushValue(value); err == nil {
|
||||
tb.state.SetField(tb.index, key)
|
||||
}
|
||||
return tb
|
||||
}
|
||||
|
||||
// SetArray sets an array field
|
||||
func (tb *TableBuilder) SetArray(key string, values []any) *TableBuilder {
|
||||
tb.state.CreateTable(len(values), 0)
|
||||
for i, v := range values {
|
||||
tb.state.PushNumber(float64(i + 1))
|
||||
if err := tb.state.PushValue(v); err == nil {
|
||||
tb.state.SetTable(-3)
|
||||
} else {
|
||||
tb.state.Pop(1)
|
||||
}
|
||||
}
|
||||
tb.state.SetField(tb.index, key)
|
||||
return tb
|
||||
}
|
||||
|
||||
// Build finalizes the table (no-op, table is already on stack)
|
||||
func (tb *TableBuilder) Build() {
|
||||
// Table is already on the stack at tb.index
|
||||
}
|
89
bytecode.go
89
bytecode.go
@ -12,6 +12,12 @@ typedef struct {
|
||||
const char *name;
|
||||
} BytecodeReader;
|
||||
|
||||
typedef struct {
|
||||
unsigned char *buf;
|
||||
size_t size;
|
||||
size_t capacity;
|
||||
} BytecodeBuffer;
|
||||
|
||||
const char *bytecode_reader(lua_State *L, void *ud, size_t *size) {
|
||||
BytecodeReader *r = (BytecodeReader *)ud;
|
||||
(void)L; // unused
|
||||
@ -26,16 +32,24 @@ int load_bytecode(lua_State *L, const unsigned char *buf, size_t len, const char
|
||||
return lua_load(L, bytecode_reader, &reader, name);
|
||||
}
|
||||
|
||||
// Direct bytecode dumping without intermediate buffer - more efficient
|
||||
int direct_bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
|
||||
void **data = (void **)ud;
|
||||
size_t current_size = (size_t)data[1];
|
||||
void *newbuf = realloc(data[0], current_size + sz);
|
||||
if (newbuf == NULL) return 1;
|
||||
// Optimized bytecode writer with pre-allocated buffer
|
||||
int buffered_bytecode_writer(lua_State *L, const void *p, size_t sz, void *ud) {
|
||||
BytecodeBuffer *buf = (BytecodeBuffer *)ud;
|
||||
|
||||
memcpy((unsigned char*)newbuf + current_size, p, sz);
|
||||
data[0] = newbuf;
|
||||
data[1] = (void*)(current_size + sz);
|
||||
// Grow buffer if needed (double size to avoid frequent reallocs)
|
||||
if (buf->size + sz > buf->capacity) {
|
||||
size_t new_capacity = buf->capacity;
|
||||
while (new_capacity < buf->size + sz) {
|
||||
new_capacity *= 2;
|
||||
}
|
||||
unsigned char *newbuf = realloc(buf->buf, new_capacity);
|
||||
if (newbuf == NULL) return 1;
|
||||
buf->buf = newbuf;
|
||||
buf->capacity = new_capacity;
|
||||
}
|
||||
|
||||
memcpy(buf->buf + buf->size, p, sz);
|
||||
buf->size += sz;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -52,36 +66,56 @@ int load_and_run_bytecode(lua_State *L, const unsigned char *buf, size_t len,
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// bytecodeBuffer wraps []byte to avoid boxing allocations in sync.Pool
|
||||
type bytecodeBuffer struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
// Buffer pool for bytecode generation
|
||||
var bytecodeBufferPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &bytecodeBuffer{data: make([]byte, 0, 1024)}
|
||||
},
|
||||
}
|
||||
|
||||
// CompileBytecode compiles a Lua chunk to bytecode without executing it
|
||||
func (s *State) CompileBytecode(code string, name string) ([]byte, error) {
|
||||
if err := s.LoadString(code); err != nil {
|
||||
return nil, fmt.Errorf("failed to load string: %w", err)
|
||||
}
|
||||
|
||||
// Use a simpler direct writer with just two pointers
|
||||
data := [2]unsafe.Pointer{nil, nil}
|
||||
// Always use C memory for dump operation to avoid cgo pointer issues
|
||||
cbuf := C.BytecodeBuffer{
|
||||
buf: (*C.uchar)(C.malloc(1024)),
|
||||
size: 0,
|
||||
capacity: 1024,
|
||||
}
|
||||
if cbuf.buf == nil {
|
||||
return nil, fmt.Errorf("failed to allocate initial buffer")
|
||||
}
|
||||
|
||||
// Dump the function to bytecode
|
||||
status := C.lua_dump(s.L, (*[0]byte)(unsafe.Pointer(C.direct_bytecode_writer)), unsafe.Pointer(&data))
|
||||
if status != 0 {
|
||||
return nil, fmt.Errorf("failed to dump bytecode: status %d", status)
|
||||
}
|
||||
|
||||
// Get result
|
||||
var bytecode []byte
|
||||
if data[0] != nil {
|
||||
// Create Go slice that references the C memory
|
||||
length := uintptr(data[1])
|
||||
bytecode = C.GoBytes(data[0], C.int(length))
|
||||
C.free(data[0])
|
||||
}
|
||||
status := C.lua_dump(s.L, (*[0]byte)(unsafe.Pointer(C.buffered_bytecode_writer)), unsafe.Pointer(&cbuf))
|
||||
|
||||
s.Pop(1) // Remove the function from stack
|
||||
|
||||
return bytecode, nil
|
||||
if status != 0 {
|
||||
C.free(unsafe.Pointer(cbuf.buf))
|
||||
return nil, fmt.Errorf("failed to dump bytecode: status %d", status)
|
||||
}
|
||||
|
||||
// Copy to Go memory and free C buffer
|
||||
var result []byte
|
||||
if cbuf.size > 0 {
|
||||
result = C.GoBytes(unsafe.Pointer(cbuf.buf), C.int(cbuf.size))
|
||||
}
|
||||
C.free(unsafe.Pointer(cbuf.buf))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// LoadBytecode loads precompiled bytecode without executing it
|
||||
@ -116,7 +150,6 @@ func (s *State) RunBytecode() error {
|
||||
}
|
||||
|
||||
// RunBytecodeWithResults executes bytecode and keeps nresults on the stack
|
||||
// Use LUA_MULTRET (-1) to keep all results
|
||||
func (s *State) RunBytecodeWithResults(nresults int) error {
|
||||
status := C.lua_pcall(s.L, 0, C.int(nresults), 0)
|
||||
if status != 0 {
|
||||
@ -136,13 +169,12 @@ func (s *State) LoadAndRunBytecode(bytecode []byte, name string) error {
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
|
||||
// Use combined load and run function
|
||||
status := C.load_and_run_bytecode(
|
||||
s.L,
|
||||
(*C.uchar)(unsafe.Pointer(&bytecode[0])),
|
||||
C.size_t(len(bytecode)),
|
||||
cname,
|
||||
0, // No results
|
||||
0,
|
||||
)
|
||||
|
||||
if status != 0 {
|
||||
@ -163,7 +195,6 @@ func (s *State) LoadAndRunBytecodeWithResults(bytecode []byte, name string, nres
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
|
||||
// Use combined load and run function
|
||||
status := C.load_and_run_bytecode(
|
||||
s.L,
|
||||
(*C.uchar)(unsafe.Pointer(&bytecode[0])),
|
||||
|
15
functions.go
15
functions.go
@ -9,7 +9,7 @@ extern int goFunctionWrapper(lua_State* L);
|
||||
|
||||
// Helper function to access upvalues
|
||||
static int get_upvalue_index(int i) {
|
||||
return lua_upvalueindex(i);
|
||||
return lua_upvalueindex(i);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
@ -34,11 +34,20 @@ var (
|
||||
}{
|
||||
funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize),
|
||||
}
|
||||
|
||||
// statePool reuses State structs to avoid allocations
|
||||
statePool = sync.Pool{
|
||||
New: func() any {
|
||||
return &State{}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
//export goFunctionWrapper
|
||||
func goFunctionWrapper(L *C.lua_State) C.int {
|
||||
state := &State{L: L}
|
||||
state := statePool.Get().(*State)
|
||||
state.L = L
|
||||
defer statePool.Put(state)
|
||||
|
||||
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
|
||||
if ptr == nil {
|
||||
@ -51,8 +60,6 @@ func goFunctionWrapper(L *C.lua_State) C.int {
|
||||
functionRegistry.RUnlock()
|
||||
|
||||
if !ok {
|
||||
// Debug logging
|
||||
fmt.Printf("Function not found for pointer %p\n", ptr)
|
||||
state.PushString("error: function not found in registry")
|
||||
return -1
|
||||
}
|
||||
|
20
stack.go
20
stack.go
@ -46,14 +46,6 @@ func (e *LuaError) Error() string {
|
||||
return result
|
||||
}
|
||||
|
||||
// Stack management constants from lua.h
|
||||
const (
|
||||
LUA_MINSTACK = 20 // Minimum Lua stack size
|
||||
LUA_MAXSTACK = 1000000 // Maximum Lua stack size
|
||||
LUA_REGISTRYINDEX = -10000 // Pseudo-index for the Lua registry
|
||||
LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table
|
||||
)
|
||||
|
||||
// GetStackTrace returns the current Lua stack trace
|
||||
func (s *State) GetStackTrace() string {
|
||||
s.GetGlobal("debug")
|
||||
@ -64,13 +56,13 @@ func (s *State) GetStackTrace() string {
|
||||
|
||||
s.GetField(-1, "traceback")
|
||||
if !s.IsFunction(-1) {
|
||||
s.Pop(2) // Remove debug table and non-function
|
||||
s.Pop(2)
|
||||
return "debug.traceback not available"
|
||||
}
|
||||
|
||||
s.Call(0, 1)
|
||||
trace := s.ToString(-1)
|
||||
s.Pop(1) // Remove the trace
|
||||
s.Pop(1)
|
||||
|
||||
return trace
|
||||
}
|
||||
@ -97,13 +89,11 @@ func (s *State) GetErrorInfo(context string) *LuaError {
|
||||
if secondColonPos := strings.Index(afterColon, ":"); secondColonPos > 0 {
|
||||
file = beforeColon
|
||||
if n, err := fmt.Sscanf(afterColon[:secondColonPos], "%d", &line); n == 1 && err == nil {
|
||||
// Strip the file:line part from message for cleaner display
|
||||
message = strings.TrimSpace(afterColon[secondColonPos+1:])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get stack trace
|
||||
stackTrace := s.GetStackTrace()
|
||||
|
||||
return &LuaError{
|
||||
@ -121,3 +111,9 @@ func (s *State) CreateLuaError(code int, context string) *LuaError {
|
||||
err.Code = code
|
||||
return err
|
||||
}
|
||||
|
||||
// PushError pushes an error string and returns -1
|
||||
func (s *State) PushError(format string, args ...any) int {
|
||||
s.PushString(fmt.Sprintf(format, args...))
|
||||
return -1
|
||||
}
|
||||
|
164
table.go
164
table.go
@ -1,164 +0,0 @@
|
||||
package luajit
|
||||
|
||||
/*
|
||||
#include <lua.h>
|
||||
#include <lualib.h>
|
||||
#include <lauxlib.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// Simple direct length check
|
||||
size_t get_table_length(lua_State *L, int index) {
|
||||
return lua_objlen(L, index);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// GetTableLength returns the length of a table at the given index
|
||||
func (s *State) GetTableLength(index int) int {
|
||||
return int(C.get_table_length(s.L, C.int(index)))
|
||||
}
|
||||
|
||||
// PushTable pushes a Go map onto the Lua stack as a table
|
||||
func (s *State) PushTable(table map[string]any) error {
|
||||
// Fast path for array tables
|
||||
if arr, ok := table[""]; ok {
|
||||
if floatArr, ok := arr.([]float64); ok {
|
||||
s.CreateTable(len(floatArr), 0)
|
||||
for i, v := range floatArr {
|
||||
s.PushNumber(float64(i + 1))
|
||||
s.PushNumber(v)
|
||||
s.SetTable(-3)
|
||||
}
|
||||
return nil
|
||||
} else if anyArr, ok := arr.([]any); ok {
|
||||
s.CreateTable(len(anyArr), 0)
|
||||
for i, v := range anyArr {
|
||||
s.PushNumber(float64(i + 1))
|
||||
if err := s.PushValue(v); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetTable(-3)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Regular table case - optimize capacity hint
|
||||
s.CreateTable(0, len(table))
|
||||
|
||||
// Add each key-value pair directly
|
||||
for k, v := range table {
|
||||
s.PushString(k)
|
||||
if err := s.PushValue(v); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetTable(-3)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToTable converts a Lua table at the given index to a Go map
|
||||
func (s *State) ToTable(index int) (map[string]any, error) {
|
||||
absIdx := s.absIndex(index)
|
||||
if !s.IsTable(absIdx) {
|
||||
return nil, fmt.Errorf("value at index %d is not a table", index)
|
||||
}
|
||||
|
||||
// Try to detect array-like tables first
|
||||
length := s.GetTableLength(absIdx)
|
||||
if length > 0 {
|
||||
// Fast path for common array case
|
||||
allNumbers := true
|
||||
|
||||
// Sample first few values to check if it's likely an array of numbers
|
||||
for i := 1; i <= min(length, 5); i++ {
|
||||
s.PushNumber(float64(i))
|
||||
s.GetTable(absIdx)
|
||||
|
||||
if !s.IsNumber(-1) {
|
||||
allNumbers = false
|
||||
s.Pop(1)
|
||||
break
|
||||
}
|
||||
s.Pop(1)
|
||||
}
|
||||
|
||||
if allNumbers {
|
||||
// Efficiently extract array values
|
||||
array := make([]float64, length)
|
||||
for i := 1; i <= length; i++ {
|
||||
s.PushNumber(float64(i))
|
||||
s.GetTable(absIdx)
|
||||
array[i-1] = s.ToNumber(-1)
|
||||
s.Pop(1)
|
||||
}
|
||||
|
||||
// Return array as a special table with empty key
|
||||
result := make(map[string]any, 1)
|
||||
result[""] = array
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Handle regular table with pre-allocated capacity
|
||||
table := make(map[string]any, max(length, 8))
|
||||
|
||||
// Iterate through all key-value pairs
|
||||
s.PushNil() // Start iteration with nil key
|
||||
for s.Next(absIdx) {
|
||||
// Stack now has key at -2 and value at -1
|
||||
|
||||
// Convert key to string
|
||||
var key string
|
||||
keyType := s.GetType(-2)
|
||||
switch keyType {
|
||||
case TypeString:
|
||||
key = s.ToString(-2)
|
||||
case TypeNumber:
|
||||
key = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64)
|
||||
default:
|
||||
// Skip non-string/non-number keys
|
||||
s.Pop(1) // Pop value, leave key for next iteration
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert and store the value
|
||||
value, err := s.ToValue(-1)
|
||||
if err != nil {
|
||||
s.Pop(2) // Pop both key and value
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unwrap nested array tables
|
||||
if m, ok := value.(map[string]any); ok {
|
||||
if arr, ok := m[""]; ok {
|
||||
value = arr
|
||||
}
|
||||
}
|
||||
|
||||
table[key] = value
|
||||
s.Pop(1) // Pop value, leave key for next iteration
|
||||
}
|
||||
|
||||
return table, nil
|
||||
}
|
||||
|
||||
// Helper functions for min/max operations
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
@ -19,7 +19,6 @@ func TestGetTableLength(t *testing.T) {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
// Get the table
|
||||
state.GetGlobal("t")
|
||||
length := state.GetTableLength(-1)
|
||||
if length != 5 {
|
||||
@ -32,7 +31,6 @@ func TestGetTableLength(t *testing.T) {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
// Get the table
|
||||
state.GetGlobal("t2")
|
||||
length = state.GetTableLength(-1)
|
||||
if length != 0 {
|
||||
@ -41,206 +39,234 @@ func TestGetTableLength(t *testing.T) {
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
func TestPushTable(t *testing.T) {
|
||||
func TestPushTypedArrays(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create a test table
|
||||
testTable := map[string]any{
|
||||
"int": 42,
|
||||
"float": 3.14,
|
||||
"string": "hello",
|
||||
"boolean": true,
|
||||
"nil": nil,
|
||||
// Test []int
|
||||
intArr := []int{1, 2, 3, 4, 5}
|
||||
if err := state.PushValue(intArr); err != nil {
|
||||
t.Fatalf("Failed to push int array: %v", err)
|
||||
}
|
||||
state.SetGlobal("int_arr")
|
||||
|
||||
// Push the table onto the stack
|
||||
if err := state.PushTable(testTable); err != nil {
|
||||
t.Fatalf("Failed to push table: %v", err)
|
||||
// Test []string
|
||||
stringArr := []string{"hello", "world", "test"}
|
||||
if err := state.PushValue(stringArr); err != nil {
|
||||
t.Fatalf("Failed to push string array: %v", err)
|
||||
}
|
||||
state.SetGlobal("string_arr")
|
||||
|
||||
// Execute Lua code to test the table contents
|
||||
// Test []bool
|
||||
boolArr := []bool{true, false, true}
|
||||
if err := state.PushValue(boolArr); err != nil {
|
||||
t.Fatalf("Failed to push bool array: %v", err)
|
||||
}
|
||||
state.SetGlobal("bool_arr")
|
||||
|
||||
// Test []float64
|
||||
floatArr := []float64{1.1, 2.2, 3.3}
|
||||
if err := state.PushValue(floatArr); err != nil {
|
||||
t.Fatalf("Failed to push float array: %v", err)
|
||||
}
|
||||
state.SetGlobal("float_arr")
|
||||
|
||||
// Verify arrays in Lua
|
||||
if err := state.DoString(`
|
||||
function validate_table(t)
|
||||
return t.int == 42 and
|
||||
math.abs(t.float - 3.14) < 0.0001 and
|
||||
t.string == "hello" and
|
||||
t.boolean == true and
|
||||
t["nil"] == nil
|
||||
end
|
||||
assert(int_arr[1] == 1 and int_arr[5] == 5)
|
||||
assert(string_arr[1] == "hello" and string_arr[3] == "test")
|
||||
assert(bool_arr[1] == true and bool_arr[2] == false)
|
||||
assert(math.abs(float_arr[1] - 1.1) < 0.0001)
|
||||
`); err != nil {
|
||||
t.Fatalf("Failed to create validation function: %v", err)
|
||||
t.Fatalf("Array verification failed: %v", err)
|
||||
}
|
||||
|
||||
// Call the validation function
|
||||
state.GetGlobal("validate_table")
|
||||
state.PushCopy(-2) // Copy the table to the top
|
||||
if err := state.Call(1, 1); err != nil {
|
||||
t.Fatalf("Failed to call validation function: %v", err)
|
||||
}
|
||||
|
||||
if !state.ToBoolean(-1) {
|
||||
t.Fatalf("Table validation failed")
|
||||
}
|
||||
state.Pop(2) // Pop the result and the table
|
||||
}
|
||||
|
||||
func TestToTable(t *testing.T) {
|
||||
func TestPushTypedMaps(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test regular table conversion
|
||||
if err := state.DoString(`t = {a=1, b=2.5, c="test", d=true, e=nil}`); err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
// Test map[string]string
|
||||
stringMap := map[string]string{"name": "John", "city": "NYC"}
|
||||
if err := state.PushValue(stringMap); err != nil {
|
||||
t.Fatalf("Failed to push string map: %v", err)
|
||||
}
|
||||
state.SetGlobal("string_map")
|
||||
|
||||
state.GetGlobal("t")
|
||||
table, err := state.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert table: %v", err)
|
||||
// Test map[string]int
|
||||
intMap := map[string]int{"age": 25, "score": 100}
|
||||
if err := state.PushValue(intMap); err != nil {
|
||||
t.Fatalf("Failed to push int map: %v", err)
|
||||
}
|
||||
state.Pop(1)
|
||||
state.SetGlobal("int_map")
|
||||
|
||||
expected := map[string]any{
|
||||
"a": float64(1),
|
||||
"b": 2.5,
|
||||
"c": "test",
|
||||
"d": true,
|
||||
// Test map[int]any
|
||||
intKeyMap := map[int]any{1: "first", 2: 42, 3: true}
|
||||
if err := state.PushValue(intKeyMap); err != nil {
|
||||
t.Fatalf("Failed to push int key map: %v", err)
|
||||
}
|
||||
state.SetGlobal("int_key_map")
|
||||
|
||||
for k, v := range expected {
|
||||
if table[k] != v {
|
||||
t.Fatalf("Expected table[%s] = %v, got %v", k, v, table[k])
|
||||
}
|
||||
}
|
||||
|
||||
// Test array-like table conversion
|
||||
if err := state.DoString(`arr = {10, 20, 30, 40, 50}`); err != nil {
|
||||
t.Fatalf("Failed to create test array: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("arr")
|
||||
table, err = state.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert array table: %v", err)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// For array tables, we should get a special format with an empty key
|
||||
// and the array as the value
|
||||
expectedArray := []float64{10, 20, 30, 40, 50}
|
||||
if arr, ok := table[""].([]float64); !ok {
|
||||
t.Fatalf("Expected array table to be converted with empty key, got: %v", table)
|
||||
} else if !reflect.DeepEqual(arr, expectedArray) {
|
||||
t.Fatalf("Expected %v, got %v", expectedArray, arr)
|
||||
}
|
||||
|
||||
// Test invalid table index
|
||||
_, err = state.ToTable(100)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error for invalid table index, got nil")
|
||||
}
|
||||
|
||||
// Test non-table value
|
||||
state.PushNumber(123)
|
||||
_, err = state.ToTable(-1)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error for non-table value, got nil")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test mixed array with non-numeric values
|
||||
if err := state.DoString(`mixed = {10, 20, key="value", 30}`); err != nil {
|
||||
t.Fatalf("Failed to create mixed table: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("mixed")
|
||||
table, err = state.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert mixed table: %v", err)
|
||||
}
|
||||
|
||||
// Let's print the table for debugging
|
||||
t.Logf("Table contents: %v", table)
|
||||
|
||||
state.Pop(1)
|
||||
|
||||
// Check if the array part is detected and stored with empty key
|
||||
if arr, ok := table[""]; !ok {
|
||||
t.Fatalf("Expected array-like part to be detected, got: %v", table)
|
||||
} else {
|
||||
// Verify the array contains the expected values
|
||||
expectedArr := []float64{10, 20, 30}
|
||||
actualArr := arr.([]float64)
|
||||
if len(actualArr) != len(expectedArr) {
|
||||
t.Fatalf("Expected array length %d, got %d", len(expectedArr), len(actualArr))
|
||||
}
|
||||
|
||||
for i, v := range expectedArr {
|
||||
if actualArr[i] != v {
|
||||
t.Fatalf("Expected array[%d] = %v, got %v", i, v, actualArr[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Based on the implementation, we need to create a separate test for string keys
|
||||
if err := state.DoString(`dict = {foo="bar", baz="qux"}`); err != nil {
|
||||
t.Fatalf("Failed to create dict table: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("dict")
|
||||
dictTable, err := state.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert dict table: %v", err)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Check the string keys
|
||||
if val, ok := dictTable["foo"]; !ok || val != "bar" {
|
||||
t.Fatalf("Expected dictTable[\"foo\"] = \"bar\", got: %v", val)
|
||||
}
|
||||
if val, ok := dictTable["baz"]; !ok || val != "qux" {
|
||||
t.Fatalf("Expected dictTable[\"baz\"] = \"qux\", got: %v", val)
|
||||
// Verify maps in Lua
|
||||
if err := state.DoString(`
|
||||
assert(string_map.name == "John" and string_map.city == "NYC")
|
||||
assert(int_map.age == 25 and int_map.score == 100)
|
||||
assert(int_key_map[1] == "first" and int_key_map[2] == 42 and int_key_map[3] == true)
|
||||
`); err != nil {
|
||||
t.Fatalf("Map verification failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTablePooling(t *testing.T) {
|
||||
func TestToTableTypedArrays(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create a Lua table and push it onto the stack
|
||||
if err := state.DoString(`t = {a=1, b=2}`); err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
// Test integer array detection
|
||||
if err := state.DoString("int_arr = {10, 20, 30}"); err != nil {
|
||||
t.Fatalf("Failed to create int array: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("t")
|
||||
|
||||
// First conversion - should get a table from the pool
|
||||
table1, err := state.ToTable(-1)
|
||||
state.GetGlobal("int_arr")
|
||||
result, err := state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert table (1): %v", err)
|
||||
t.Fatalf("Failed to convert int array: %v", err)
|
||||
}
|
||||
intArr, ok := result.([]int)
|
||||
if !ok {
|
||||
t.Fatalf("Expected []int, got %T", result)
|
||||
}
|
||||
expected := []int{10, 20, 30}
|
||||
if !reflect.DeepEqual(intArr, expected) {
|
||||
t.Fatalf("Expected %v, got %v", expected, intArr)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Second conversion - should get another table from the pool
|
||||
table2, err := state.ToTable(-1)
|
||||
// Test float array detection
|
||||
if err := state.DoString("float_arr = {1.5, 2.7, 3.9}"); err != nil {
|
||||
t.Fatalf("Failed to create float array: %v", err)
|
||||
}
|
||||
state.GetGlobal("float_arr")
|
||||
result, err = state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert table (2): %v", err)
|
||||
t.Fatalf("Failed to convert float array: %v", err)
|
||||
}
|
||||
|
||||
// Both tables should have the same content
|
||||
if !reflect.DeepEqual(table1, table2) {
|
||||
t.Fatalf("Tables should have the same content: %v vs %v", table1, table2)
|
||||
floatArr, ok := result.([]float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected []float64, got %T", result)
|
||||
}
|
||||
expectedFloat := []float64{1.5, 2.7, 3.9}
|
||||
if !reflect.DeepEqual(floatArr, expectedFloat) {
|
||||
t.Fatalf("Expected %v, got %v", expectedFloat, floatArr)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Clean up
|
||||
// Test string array detection
|
||||
if err := state.DoString(`string_arr = {"hello", "world"}`); err != nil {
|
||||
t.Fatalf("Failed to create string array: %v", err)
|
||||
}
|
||||
state.GetGlobal("string_arr")
|
||||
result, err = state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert string array: %v", err)
|
||||
}
|
||||
stringArr, ok := result.([]string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected []string, got %T", result)
|
||||
}
|
||||
expectedString := []string{"hello", "world"}
|
||||
if !reflect.DeepEqual(stringArr, expectedString) {
|
||||
t.Fatalf("Expected %v, got %v", expectedString, stringArr)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test bool array detection
|
||||
if err := state.DoString("bool_arr = {true, false, true}"); err != nil {
|
||||
t.Fatalf("Failed to create bool array: %v", err)
|
||||
}
|
||||
state.GetGlobal("bool_arr")
|
||||
result, err = state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert bool array: %v", err)
|
||||
}
|
||||
boolArr, ok := result.([]bool)
|
||||
if !ok {
|
||||
t.Fatalf("Expected []bool, got %T", result)
|
||||
}
|
||||
expectedBool := []bool{true, false, true}
|
||||
if !reflect.DeepEqual(boolArr, expectedBool) {
|
||||
t.Fatalf("Expected %v, got %v", expectedBool, boolArr)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
func TestToTableTypedMaps(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test string map detection
|
||||
if err := state.DoString(`string_map = {name="John", city="NYC"}`); err != nil {
|
||||
t.Fatalf("Failed to create string map: %v", err)
|
||||
}
|
||||
state.GetGlobal("string_map")
|
||||
result, err := state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert string map: %v", err)
|
||||
}
|
||||
stringMap, ok := result.(map[string]string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map[string]string, got %T", result)
|
||||
}
|
||||
expectedStringMap := map[string]string{"name": "John", "city": "NYC"}
|
||||
if !reflect.DeepEqual(stringMap, expectedStringMap) {
|
||||
t.Fatalf("Expected %v, got %v", expectedStringMap, stringMap)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test int map detection
|
||||
if err := state.DoString("int_map = {age=25, score=100}"); err != nil {
|
||||
t.Fatalf("Failed to create int map: %v", err)
|
||||
}
|
||||
state.GetGlobal("int_map")
|
||||
result, err = state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert int map: %v", err)
|
||||
}
|
||||
intMap, ok := result.(map[string]int)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map[string]int, got %T", result)
|
||||
}
|
||||
expectedIntMap := map[string]int{"age": 25, "score": 100}
|
||||
if !reflect.DeepEqual(intMap, expectedIntMap) {
|
||||
t.Fatalf("Expected %v, got %v", expectedIntMap, intMap)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test mixed map (should fallback to map[string]any)
|
||||
if err := state.DoString(`mixed_map = {name="John", age=25, active=true}`); err != nil {
|
||||
t.Fatalf("Failed to create mixed map: %v", err)
|
||||
}
|
||||
state.GetGlobal("mixed_map")
|
||||
result, err = state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert mixed map: %v", err)
|
||||
}
|
||||
mixedMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map[string]any, got %T", result)
|
||||
}
|
||||
if mixedMap["name"] != "John" || mixedMap["age"] != 25 || mixedMap["active"] != true {
|
||||
t.Fatalf("Mixed map conversion failed: %v", mixedMap)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
@ -9,70 +9,50 @@ import (
|
||||
)
|
||||
|
||||
func TestStateLifecycle(t *testing.T) {
|
||||
// Test creation
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
|
||||
// Test close
|
||||
state.Close()
|
||||
|
||||
// Test close is idempotent (doesn't crash)
|
||||
state.Close()
|
||||
state.Close() // Test idempotent close
|
||||
}
|
||||
|
||||
func TestStackManipulation(t *testing.T) {
|
||||
func TestStackOperations(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test initial stack size
|
||||
// Test stack manipulation
|
||||
if state.GetTop() != 0 {
|
||||
t.Fatalf("Expected empty stack, got %d elements", state.GetTop())
|
||||
t.Fatalf("Expected empty stack, got %d", state.GetTop())
|
||||
}
|
||||
|
||||
// Push values
|
||||
state.PushNil()
|
||||
state.PushBoolean(true)
|
||||
state.PushNumber(42)
|
||||
state.PushString("hello")
|
||||
|
||||
// Check stack size
|
||||
if state.GetTop() != 4 {
|
||||
t.Fatalf("Expected 4 elements, got %d", state.GetTop())
|
||||
}
|
||||
|
||||
// Test SetTop
|
||||
state.SetTop(2)
|
||||
if state.GetTop() != 2 {
|
||||
t.Fatalf("Expected 2 elements after SetTop, got %d", state.GetTop())
|
||||
}
|
||||
|
||||
// Test PushCopy
|
||||
state.PushCopy(2) // Copy the boolean
|
||||
state.PushCopy(2)
|
||||
if !state.IsBoolean(-1) {
|
||||
t.Fatalf("Expected boolean at top of stack")
|
||||
t.Fatal("Expected boolean at top")
|
||||
}
|
||||
|
||||
// Test Pop
|
||||
state.Pop(1)
|
||||
if state.GetTop() != 2 {
|
||||
t.Fatalf("Expected 2 elements after Pop, got %d", state.GetTop())
|
||||
}
|
||||
|
||||
// Test Remove
|
||||
state.PushNumber(99)
|
||||
state.Remove(1) // Remove the first element (nil)
|
||||
if state.GetTop() != 2 {
|
||||
t.Fatalf("Expected 2 elements after Remove, got %d", state.GetTop())
|
||||
}
|
||||
|
||||
// Verify first element is now boolean
|
||||
state.Remove(1)
|
||||
if !state.IsBoolean(1) {
|
||||
t.Fatalf("Expected boolean at index 1 after Remove")
|
||||
t.Fatal("Expected boolean at index 1 after Remove")
|
||||
}
|
||||
}
|
||||
|
||||
@ -83,52 +63,33 @@ func TestTypeChecking(t *testing.T) {
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Push values of different types
|
||||
state.PushNil()
|
||||
state.PushBoolean(true)
|
||||
state.PushNumber(42)
|
||||
state.PushString("hello")
|
||||
state.NewTable()
|
||||
|
||||
// Check types with GetType
|
||||
if state.GetType(1) != luajit.TypeNil {
|
||||
t.Fatalf("Expected nil type at index 1, got %s", state.GetType(1))
|
||||
}
|
||||
if state.GetType(2) != luajit.TypeBoolean {
|
||||
t.Fatalf("Expected boolean type at index 2, got %s", state.GetType(2))
|
||||
}
|
||||
if state.GetType(3) != luajit.TypeNumber {
|
||||
t.Fatalf("Expected number type at index 3, got %s", state.GetType(3))
|
||||
}
|
||||
if state.GetType(4) != luajit.TypeString {
|
||||
t.Fatalf("Expected string type at index 4, got %s", state.GetType(4))
|
||||
}
|
||||
if state.GetType(5) != luajit.TypeTable {
|
||||
t.Fatalf("Expected table type at index 5, got %s", state.GetType(5))
|
||||
values := []struct {
|
||||
push func()
|
||||
luaType luajit.LuaType
|
||||
checkFn func(int) bool
|
||||
}{
|
||||
{state.PushNil, luajit.TypeNil, state.IsNil},
|
||||
{func() { state.PushBoolean(true) }, luajit.TypeBoolean, state.IsBoolean},
|
||||
{func() { state.PushNumber(42) }, luajit.TypeNumber, state.IsNumber},
|
||||
{func() { state.PushString("test") }, luajit.TypeString, state.IsString},
|
||||
{state.NewTable, luajit.TypeTable, state.IsTable},
|
||||
}
|
||||
|
||||
// Test individual type checking functions
|
||||
if !state.IsNil(1) {
|
||||
t.Fatalf("IsNil failed for nil value")
|
||||
}
|
||||
if !state.IsBoolean(2) {
|
||||
t.Fatalf("IsBoolean failed for boolean value")
|
||||
}
|
||||
if !state.IsNumber(3) {
|
||||
t.Fatalf("IsNumber failed for number value")
|
||||
}
|
||||
if !state.IsString(4) {
|
||||
t.Fatalf("IsString failed for string value")
|
||||
}
|
||||
if !state.IsTable(5) {
|
||||
t.Fatalf("IsTable failed for table value")
|
||||
for i, v := range values {
|
||||
v.push()
|
||||
idx := i + 1
|
||||
if state.GetType(idx) != v.luaType {
|
||||
t.Fatalf("Type mismatch at %d: expected %s, got %s", idx, v.luaType, state.GetType(idx))
|
||||
}
|
||||
if !v.checkFn(idx) {
|
||||
t.Fatalf("Type check failed at %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
// Function test
|
||||
state.DoString("function test() return true end")
|
||||
state.GetGlobal("test")
|
||||
if !state.IsFunction(-1) {
|
||||
t.Fatalf("IsFunction failed for function value")
|
||||
t.Fatal("IsFunction failed")
|
||||
}
|
||||
}
|
||||
|
||||
@ -139,20 +100,18 @@ func TestValueConversion(t *testing.T) {
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Push values
|
||||
state.PushBoolean(true)
|
||||
state.PushNumber(42.5)
|
||||
state.PushString("hello")
|
||||
|
||||
// Test conversion
|
||||
if !state.ToBoolean(1) {
|
||||
t.Fatalf("ToBoolean failed")
|
||||
t.Fatal("ToBoolean failed")
|
||||
}
|
||||
if state.ToNumber(2) != 42.5 {
|
||||
t.Fatalf("ToNumber failed, expected 42.5, got %f", state.ToNumber(2))
|
||||
t.Fatalf("ToNumber failed: expected 42.5, got %f", state.ToNumber(2))
|
||||
}
|
||||
if state.ToString(3) != "hello" {
|
||||
t.Fatalf("ToString failed, expected 'hello', got '%s'", state.ToString(3))
|
||||
t.Fatalf("ToString failed: expected 'hello', got '%s'", state.ToString(3))
|
||||
}
|
||||
}
|
||||
|
||||
@ -163,46 +122,34 @@ func TestTableOperations(t *testing.T) {
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test CreateTable
|
||||
state.CreateTable(0, 3)
|
||||
|
||||
// Add fields using SetField
|
||||
// Set fields
|
||||
state.PushNumber(42)
|
||||
state.SetField(-2, "answer")
|
||||
|
||||
state.PushString("hello")
|
||||
state.SetField(-2, "greeting")
|
||||
|
||||
state.PushBoolean(true)
|
||||
state.SetField(-2, "flag")
|
||||
|
||||
// Test GetField
|
||||
// Get fields
|
||||
state.GetField(-1, "answer")
|
||||
if state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("GetField for 'answer' failed")
|
||||
t.Fatal("GetField failed for 'answer'")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.GetField(-1, "greeting")
|
||||
if state.ToString(-1) != "hello" {
|
||||
t.Fatalf("GetField for 'greeting' failed")
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test Next for iteration
|
||||
state.PushNil() // Start iteration
|
||||
// Test iteration
|
||||
state.PushNil()
|
||||
count := 0
|
||||
for state.Next(-2) {
|
||||
count++
|
||||
state.Pop(1) // Pop value, leave key for next iteration
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
if count != 3 {
|
||||
t.Fatalf("Expected 3 table entries, found %d", count)
|
||||
t.Fatalf("Expected 3 entries, found %d", count)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
state.Pop(1) // Pop the table
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
func TestGlobalOperations(t *testing.T) {
|
||||
@ -212,21 +159,18 @@ func TestGlobalOperations(t *testing.T) {
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Set a global value
|
||||
state.PushNumber(42)
|
||||
state.SetGlobal("answer")
|
||||
|
||||
// Get the global value
|
||||
state.GetGlobal("answer")
|
||||
if state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("GetGlobal failed, expected 42, got %f", state.ToNumber(-1))
|
||||
t.Fatalf("GetGlobal failed: expected 42, got %f", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test non-existent global (should be nil)
|
||||
state.GetGlobal("nonexistent")
|
||||
if !state.IsNil(-1) {
|
||||
t.Fatalf("Expected nil for non-existent global")
|
||||
t.Fatal("Expected nil for non-existent global")
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
@ -238,18 +182,15 @@ func TestCodeExecution(t *testing.T) {
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test LoadString
|
||||
// Test LoadString and Call
|
||||
if err := state.LoadString("return 42"); err != nil {
|
||||
t.Fatalf("LoadString failed: %v", err)
|
||||
}
|
||||
|
||||
// Test Call
|
||||
if err := state.Call(0, 1); err != nil {
|
||||
t.Fatalf("Call failed: %v", err)
|
||||
}
|
||||
|
||||
if state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("Call result incorrect, expected 42, got %f", state.ToNumber(-1))
|
||||
t.Fatalf("Call result incorrect: expected 42, got %f", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
@ -257,10 +198,9 @@ func TestCodeExecution(t *testing.T) {
|
||||
if err := state.DoString("answer = 42 + 1"); err != nil {
|
||||
t.Fatalf("DoString failed: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("answer")
|
||||
if state.ToNumber(-1) != 43 {
|
||||
t.Fatalf("DoString execution incorrect, expected 43, got %f", state.ToNumber(-1))
|
||||
t.Fatalf("DoString result incorrect: expected 43, got %f", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
@ -269,13 +209,11 @@ func TestCodeExecution(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
|
||||
if nresults != 3 {
|
||||
t.Fatalf("Execute returned %d results, expected 3", nresults)
|
||||
}
|
||||
|
||||
if state.ToNumber(-3) != 5 || state.ToNumber(-2) != 10 || state.ToNumber(-1) != 15 {
|
||||
t.Fatalf("Execute results incorrect")
|
||||
t.Fatal("Execute results incorrect")
|
||||
}
|
||||
state.Pop(3)
|
||||
|
||||
@ -284,26 +222,24 @@ func TestCodeExecution(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteWithResult failed: %v", err)
|
||||
}
|
||||
|
||||
if result != "hello" {
|
||||
t.Fatalf("ExecuteWithResult returned %v, expected 'hello'", result)
|
||||
}
|
||||
|
||||
// Test error handling
|
||||
err = state.DoString("this is not valid lua code")
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error for invalid code, got nil")
|
||||
if err := state.DoString("invalid lua code"); err == nil {
|
||||
t.Fatal("Expected error for invalid code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoFile(t *testing.T) {
|
||||
func TestFileOperations(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Create a temporary Lua file
|
||||
// Create temp file
|
||||
content := []byte("answer = 42")
|
||||
tmpfile, err := os.CreateTemp("", "test-*.lua")
|
||||
if err != nil {
|
||||
@ -312,40 +248,17 @@ func TestDoFile(t *testing.T) {
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
if _, err := tmpfile.Write(content); err != nil {
|
||||
t.Fatalf("Failed to write to temp file: %v", err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Fatalf("Failed to close temp file: %v", err)
|
||||
}
|
||||
|
||||
// Test LoadFile and DoFile
|
||||
if err := state.LoadFile(tmpfile.Name()); err != nil {
|
||||
t.Fatalf("LoadFile failed: %v", err)
|
||||
}
|
||||
|
||||
if err := state.Call(0, 0); err != nil {
|
||||
t.Fatalf("Call failed after LoadFile: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("answer")
|
||||
if state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("Incorrect result after LoadFile, expected 42, got %f", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Reset global
|
||||
if err := state.DoString("answer = nil"); err != nil {
|
||||
t.Fatalf("Failed to reset answer: %v", err)
|
||||
t.Fatalf("Failed to write temp file: %v", err)
|
||||
}
|
||||
tmpfile.Close()
|
||||
|
||||
// Test DoFile
|
||||
if err := state.DoFile(tmpfile.Name()); err != nil {
|
||||
t.Fatalf("DoFile failed: %v", err)
|
||||
}
|
||||
|
||||
state.GetGlobal("answer")
|
||||
if state.ToNumber(-1) != 42 {
|
||||
t.Fatalf("Incorrect result after DoFile, expected 42, got %f", state.ToNumber(-1))
|
||||
t.Fatalf("DoFile result incorrect: expected 42, got %f", state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
@ -357,7 +270,6 @@ func TestPackagePath(t *testing.T) {
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test SetPackagePath
|
||||
testPath := "/test/path/?.lua"
|
||||
if err := state.SetPackagePath(testPath); err != nil {
|
||||
t.Fatalf("SetPackagePath failed: %v", err)
|
||||
@ -367,12 +279,10 @@ func TestPackagePath(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get package.path: %v", err)
|
||||
}
|
||||
|
||||
if result != testPath {
|
||||
t.Fatalf("Expected package.path to be '%s', got '%s'", testPath, result)
|
||||
t.Fatalf("SetPackagePath failed: expected '%s', got '%s'", testPath, result)
|
||||
}
|
||||
|
||||
// Test AddPackagePath
|
||||
addPath := "/another/path/?.lua"
|
||||
if err := state.AddPackagePath(addPath); err != nil {
|
||||
t.Fatalf("AddPackagePath failed: %v", err)
|
||||
@ -382,92 +292,134 @@ func TestPackagePath(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get package.path: %v", err)
|
||||
}
|
||||
|
||||
expected := testPath + ";" + addPath
|
||||
if result != expected {
|
||||
t.Fatalf("Expected package.path to be '%s', got '%s'", expected, result)
|
||||
t.Fatalf("AddPackagePath failed: expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushValueAndToValue(t *testing.T) {
|
||||
func TestEnhancedTypes(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test typed arrays
|
||||
testCases := []struct {
|
||||
value any
|
||||
input any
|
||||
expected any
|
||||
}{
|
||||
{nil},
|
||||
{true},
|
||||
{false},
|
||||
{42},
|
||||
{42.5},
|
||||
{"hello"},
|
||||
{[]float64{1, 2, 3, 4, 5}},
|
||||
{[]any{1, "test", true}},
|
||||
{map[string]any{"a": 1, "b": "test", "c": true}},
|
||||
// Primitive types
|
||||
{nil, nil},
|
||||
{true, true},
|
||||
{42, 42}, // Should preserve as int
|
||||
{42.5, 42.5}, // Should be float64
|
||||
{"hello", "hello"},
|
||||
|
||||
// Typed arrays
|
||||
{[]int{1, 2, 3}, []int{1, 2, 3}},
|
||||
{[]string{"a", "b"}, []string{"a", "b"}},
|
||||
{[]bool{true, false}, []bool{true, false}},
|
||||
{[]float64{1.1, 2.2}, []float64{1.1, 2.2}},
|
||||
|
||||
// Typed maps
|
||||
{map[string]string{"name": "John"}, map[string]string{"name": "John"}},
|
||||
{map[string]int{"age": 25}, map[string]int{"age": 25}},
|
||||
{map[int]any{10: "first", 20: 42}, map[string]any{"10": "first", "20": 42}},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
// Push value
|
||||
err := state.PushValue(tc.value)
|
||||
if err != nil {
|
||||
t.Fatalf("PushValue failed for testCase %d: %v", i, err)
|
||||
// Push and retrieve value
|
||||
if err := state.PushValue(tc.input); err != nil {
|
||||
t.Fatalf("Case %d: PushValue failed: %v", i, err)
|
||||
}
|
||||
|
||||
// Check stack
|
||||
if state.GetTop() != i+1 {
|
||||
t.Fatalf("Stack size incorrect after push, expected %d, got %d", i+1, state.GetTop())
|
||||
result, err := state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Case %d: ToValue failed: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(result, tc.expected) {
|
||||
t.Fatalf("Case %d: expected %v (%T), got %v (%T)",
|
||||
i, tc.expected, tc.expected, result, result)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
// Test conversion back to Go
|
||||
for i := range testCases {
|
||||
index := len(testCases) - i
|
||||
value, err := state.ToValue(index)
|
||||
if err != nil {
|
||||
t.Fatalf("ToValue failed for index %d: %v", index, err)
|
||||
}
|
||||
|
||||
// For tables, we need special handling due to how Go types are stored
|
||||
switch expected := testCases[index-1].value.(type) {
|
||||
case []float64:
|
||||
// Arrays come back as map[string]any with empty key
|
||||
if m, ok := value.(map[string]any); ok {
|
||||
if arr, ok := m[""].([]float64); ok {
|
||||
if !reflect.DeepEqual(arr, expected) {
|
||||
t.Fatalf("Value mismatch for testCase %d: expected %v, got %v", index-1, expected, arr)
|
||||
}
|
||||
} else {
|
||||
t.Fatalf("Invalid array conversion for testCase %d", index-1)
|
||||
}
|
||||
} else {
|
||||
t.Fatalf("Expected map for array value in testCase %d, got %T", index-1, value)
|
||||
}
|
||||
case int:
|
||||
if num, ok := value.(float64); ok {
|
||||
if float64(expected) == num {
|
||||
continue // Values match after type conversion
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
// Skip detailed comparison for mixed arrays
|
||||
case map[string]any:
|
||||
// Skip detailed comparison for maps
|
||||
default:
|
||||
if !reflect.DeepEqual(value, testCases[index-1].value) {
|
||||
t.Fatalf("Value mismatch for testCase %d: expected %v, got %v",
|
||||
index-1, testCases[index-1].value, value)
|
||||
}
|
||||
}
|
||||
// Test mixed array (should become []any)
|
||||
state.DoString("mixed = {1, 'hello', true}")
|
||||
state.GetGlobal("mixed")
|
||||
result, err := state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Mixed array conversion failed: %v", err)
|
||||
}
|
||||
if _, ok := result.([]any); !ok {
|
||||
t.Fatalf("Expected []any for mixed array, got %T", result)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test mixed map (should become map[string]any)
|
||||
state.DoString("mixedMap = {name='John', age=25, active=true}")
|
||||
state.GetGlobal("mixedMap")
|
||||
result, err = state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Mixed map conversion failed: %v", err)
|
||||
}
|
||||
if _, ok := result.(map[string]any); !ok {
|
||||
t.Fatalf("Expected map[string]any for mixed map, got %T", result)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
func TestIntegerPreservation(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test that integers are preserved
|
||||
state.DoString("num = 42")
|
||||
state.GetGlobal("num")
|
||||
result, err := state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Integer conversion failed: %v", err)
|
||||
}
|
||||
if val, ok := result.(int); !ok || val != 42 {
|
||||
t.Fatalf("Expected int 42, got %T %v", result, result)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Test that floats remain floats
|
||||
state.DoString("fnum = 42.5")
|
||||
state.GetGlobal("fnum")
|
||||
result, err = state.ToValue(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("Float conversion failed: %v", err)
|
||||
}
|
||||
if val, ok := result.(float64); !ok || val != 42.5 {
|
||||
t.Fatalf("Expected float64 42.5, got %T %v", result, result)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
func TestErrorHandling(t *testing.T) {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Test unsupported type
|
||||
complex := complex(1, 2)
|
||||
err := state.PushValue(complex)
|
||||
type customStruct struct{ Field int }
|
||||
if err := state.PushValue(customStruct{Field: 42}); err == nil {
|
||||
t.Fatal("Expected error for unsupported type")
|
||||
}
|
||||
|
||||
// Test invalid stack index
|
||||
_, err := state.ToValue(100)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error for unsupported type")
|
||||
t.Fatal("Expected error for invalid index")
|
||||
}
|
||||
}
|
||||
|
338
types.go
338
types.go
@ -13,7 +13,6 @@ import (
|
||||
type LuaType int
|
||||
|
||||
const (
|
||||
// These constants match lua.h's LUA_T* values
|
||||
TypeNone LuaType = -1
|
||||
TypeNil LuaType = 0
|
||||
TypeBoolean LuaType = 1
|
||||
@ -26,7 +25,6 @@ const (
|
||||
TypeThread LuaType = 8
|
||||
)
|
||||
|
||||
// String returns the string representation of the Lua type
|
||||
func (t LuaType) String() string {
|
||||
switch t {
|
||||
case TypeNone:
|
||||
@ -54,92 +52,309 @@ func (t LuaType) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertValue converts a value to the requested type with proper type conversion
|
||||
// ConvertValue converts a value to the requested type with comprehensive type conversion
|
||||
func ConvertValue[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
|
||||
// Handle nil case
|
||||
if value == nil {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// Try direct type assertion first
|
||||
if result, ok := value.(T); ok {
|
||||
return result, true
|
||||
}
|
||||
|
||||
// Type-specific conversions
|
||||
switch any(zero).(type) {
|
||||
case string:
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return any(fmt.Sprintf("%g", v)).(T), true
|
||||
case int:
|
||||
return any(strconv.Itoa(v)).(T), true
|
||||
case bool:
|
||||
if v {
|
||||
return any("true").(T), true
|
||||
}
|
||||
return any("false").(T), true
|
||||
}
|
||||
return convertToString[T](value)
|
||||
case int:
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return any(int(v)).(T), true
|
||||
case string:
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return any(i).(T), true
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
return any(1).(T), true
|
||||
}
|
||||
return any(0).(T), true
|
||||
}
|
||||
return convertToInt[T](value)
|
||||
case float64:
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return any(float64(v)).(T), true
|
||||
case string:
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return any(f).(T), true
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
return any(1.0).(T), true
|
||||
}
|
||||
return any(0.0).(T), true
|
||||
}
|
||||
return convertToFloat[T](value)
|
||||
case bool:
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
switch v {
|
||||
case "true", "yes", "1":
|
||||
return any(true).(T), true
|
||||
case "false", "no", "0":
|
||||
return any(false).(T), true
|
||||
}
|
||||
case int:
|
||||
return any(v != 0).(T), true
|
||||
case float64:
|
||||
return any(v != 0).(T), true
|
||||
}
|
||||
return convertToBool[T](value)
|
||||
case []int:
|
||||
return convertToIntSlice[T](value)
|
||||
case []string:
|
||||
return convertToStringSlice[T](value)
|
||||
case []bool:
|
||||
return convertToBoolSlice[T](value)
|
||||
case []float64:
|
||||
return convertToFloatSlice[T](value)
|
||||
case []any:
|
||||
return convertToAnySlice[T](value)
|
||||
case map[string]string:
|
||||
return convertToStringMap[T](value)
|
||||
case map[string]int:
|
||||
return convertToIntMap[T](value)
|
||||
case map[int]any:
|
||||
return convertToIntKeyMap[T](value)
|
||||
case map[string]any:
|
||||
return convertToAnyMap[T](value)
|
||||
}
|
||||
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToString[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
if v == float64(int(v)) {
|
||||
return any(strconv.Itoa(int(v))).(T), true
|
||||
}
|
||||
return any(fmt.Sprintf("%g", v)).(T), true
|
||||
case int:
|
||||
return any(strconv.Itoa(v)).(T), true
|
||||
case bool:
|
||||
return any(strconv.FormatBool(v)).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToInt[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return any(int(v)).(T), true
|
||||
case string:
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return any(i).(T), true
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
return any(1).(T), true
|
||||
}
|
||||
return any(0).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToFloat[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return any(float64(v)).(T), true
|
||||
case string:
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return any(f).(T), true
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
return any(1.0).(T), true
|
||||
}
|
||||
return any(0.0).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToBool[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
switch v {
|
||||
case "true", "yes", "1":
|
||||
return any(true).(T), true
|
||||
case "false", "no", "0":
|
||||
return any(false).(T), true
|
||||
}
|
||||
case int:
|
||||
return any(v != 0).(T), true
|
||||
case float64:
|
||||
return any(v != 0).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToIntSlice[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
switch v := value.(type) {
|
||||
case []float64:
|
||||
result := make([]int, len(v))
|
||||
for i, f := range v {
|
||||
result[i] = int(f)
|
||||
}
|
||||
return any(result).(T), true
|
||||
case []any:
|
||||
result := make([]int, 0, len(v))
|
||||
for _, item := range v {
|
||||
if i, ok := ConvertValue[int](item); ok {
|
||||
result = append(result, i)
|
||||
} else {
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
return any(result).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToStringSlice[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
if v, ok := value.([]any); ok {
|
||||
result := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s, ok := ConvertValue[string](item); ok {
|
||||
result = append(result, s)
|
||||
} else {
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
return any(result).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToBoolSlice[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
if v, ok := value.([]any); ok {
|
||||
result := make([]bool, 0, len(v))
|
||||
for _, item := range v {
|
||||
if b, ok := ConvertValue[bool](item); ok {
|
||||
result = append(result, b)
|
||||
} else {
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
return any(result).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToFloatSlice[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
switch v := value.(type) {
|
||||
case []int:
|
||||
result := make([]float64, len(v))
|
||||
for i, n := range v {
|
||||
result[i] = float64(n)
|
||||
}
|
||||
return any(result).(T), true
|
||||
case []any:
|
||||
result := make([]float64, 0, len(v))
|
||||
for _, item := range v {
|
||||
if f, ok := ConvertValue[float64](item); ok {
|
||||
result = append(result, f)
|
||||
} else {
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
return any(result).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToAnySlice[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
switch v := value.(type) {
|
||||
case []int:
|
||||
result := make([]any, len(v))
|
||||
for i, n := range v {
|
||||
result[i] = n
|
||||
}
|
||||
return any(result).(T), true
|
||||
case []string:
|
||||
result := make([]any, len(v))
|
||||
for i, s := range v {
|
||||
result[i] = s
|
||||
}
|
||||
return any(result).(T), true
|
||||
case []bool:
|
||||
result := make([]any, len(v))
|
||||
for i, b := range v {
|
||||
result[i] = b
|
||||
}
|
||||
return any(result).(T), true
|
||||
case []float64:
|
||||
result := make([]any, len(v))
|
||||
for i, f := range v {
|
||||
result[i] = f
|
||||
}
|
||||
return any(result).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToStringMap[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
if v, ok := value.(map[string]any); ok {
|
||||
result := make(map[string]string, len(v))
|
||||
for k, val := range v {
|
||||
if s, ok := ConvertValue[string](val); ok {
|
||||
result[k] = s
|
||||
} else {
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
return any(result).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToIntMap[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
if v, ok := value.(map[string]any); ok {
|
||||
result := make(map[string]int, len(v))
|
||||
for k, val := range v {
|
||||
if i, ok := ConvertValue[int](val); ok {
|
||||
result[k] = i
|
||||
} else {
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
return any(result).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToIntKeyMap[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
if v, ok := value.(map[string]any); ok {
|
||||
result := make(map[int]any, len(v))
|
||||
for k, val := range v {
|
||||
if i, err := strconv.Atoi(k); err == nil {
|
||||
result[i] = val
|
||||
} else {
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
return any(result).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
func convertToAnyMap[T any](value any) (T, bool) {
|
||||
var zero T
|
||||
switch v := value.(type) {
|
||||
case map[string]string:
|
||||
result := make(map[string]any, len(v))
|
||||
for k, s := range v {
|
||||
result[k] = s
|
||||
}
|
||||
return any(result).(T), true
|
||||
case map[string]int:
|
||||
result := make(map[string]any, len(v))
|
||||
for k, i := range v {
|
||||
result[k] = i
|
||||
}
|
||||
return any(result).(T), true
|
||||
case map[int]any:
|
||||
result := make(map[string]any, len(v))
|
||||
for k, val := range v {
|
||||
result[strconv.Itoa(k)] = val
|
||||
}
|
||||
return any(result).(T), true
|
||||
}
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// GetTypedValue gets a value from the state with type conversion
|
||||
func GetTypedValue[T any](s *State, index int) (T, bool) {
|
||||
var zero T
|
||||
|
||||
// Get the value as any type
|
||||
value, err := s.ToValue(index)
|
||||
if err != nil {
|
||||
var zero T
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// Convert it to the requested type
|
||||
return ConvertValue[T](value)
|
||||
}
|
||||
|
||||
@ -147,6 +362,5 @@ func GetTypedValue[T any](s *State, index int) (T, bool) {
|
||||
func GetGlobalTyped[T any](s *State, name string) (T, bool) {
|
||||
s.GetGlobal(name)
|
||||
defer s.Pop(1)
|
||||
|
||||
return GetTypedValue[T](s, -1)
|
||||
}
|
||||
|
59
validation.go
Normal file
59
validation.go
Normal file
@ -0,0 +1,59 @@
|
||||
package luajit
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ArgSpec defines an argument specification for validation
|
||||
type ArgSpec struct {
|
||||
Name string
|
||||
Type string
|
||||
Required bool
|
||||
Check func(*State, int) bool
|
||||
}
|
||||
|
||||
// Common argument checkers
|
||||
var (
|
||||
CheckString = func(s *State, i int) bool { return s.IsString(i) }
|
||||
CheckNumber = func(s *State, i int) bool { return s.IsNumber(i) }
|
||||
CheckBool = func(s *State, i int) bool { return s.IsBoolean(i) }
|
||||
CheckTable = func(s *State, i int) bool { return s.IsTable(i) }
|
||||
CheckFunc = func(s *State, i int) bool { return s.IsFunction(i) }
|
||||
CheckAny = func(s *State, i int) bool { return true }
|
||||
)
|
||||
|
||||
// CheckArgs validates function arguments against specifications
|
||||
func (s *State) CheckArgs(specs ...ArgSpec) error {
|
||||
for i, spec := range specs {
|
||||
argIdx := i + 1
|
||||
if argIdx > s.GetTop() {
|
||||
if spec.Required {
|
||||
return fmt.Errorf("missing argument %d: %s", argIdx, spec.Name)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if s.IsNil(argIdx) && !spec.Required {
|
||||
continue
|
||||
}
|
||||
|
||||
if !spec.Check(s, argIdx) {
|
||||
return fmt.Errorf("argument %d (%s) must be %s", argIdx, spec.Name, spec.Type)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckMinArgs checks for minimum number of arguments
|
||||
func (s *State) CheckMinArgs(min int) error {
|
||||
if s.GetTop() < min {
|
||||
return fmt.Errorf("expected at least %d arguments, got %d", min, s.GetTop())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckExactArgs checks for exact number of arguments
|
||||
func (s *State) CheckExactArgs(count int) error {
|
||||
if s.GetTop() != count {
|
||||
return fmt.Errorf("expected exactly %d arguments, got %d", count, s.GetTop())
|
||||
}
|
||||
return nil
|
||||
}
|
806
wrapper.go
806
wrapper.go
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user