diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/ljtg.iml b/.idea/ljtg.iml new file mode 100644 index 0000000..5e764c4 --- /dev/null +++ b/.idea/ljtg.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml new file mode 100644 index 0000000..c84a72c --- /dev/null +++ b/.idea/material_theme_project_new.xml @@ -0,0 +1,10 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..185b2cc --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..b56ca89 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/functions.go b/functions.go new file mode 100644 index 0000000..c9e297e --- /dev/null +++ b/functions.go @@ -0,0 +1,98 @@ +package luajit + +/* +#include +#include +#include + +extern int goFunctionWrapper(lua_State* L); + +static int get_upvalue_index(int i) { + return -10002 - i; // LUA_GLOBALSINDEX - i +} +*/ +import "C" +import ( + "fmt" + "sync" + "unsafe" +) + +type GoFunction func(*State) int + +var ( + functionRegistry = struct { + sync.RWMutex + funcs map[unsafe.Pointer]GoFunction + }{ + funcs: make(map[unsafe.Pointer]GoFunction), + } +) + +//export goFunctionWrapper +func goFunctionWrapper(L *C.lua_State) C.int { + state := &State{L: L, safeStack: true} + + // Get upvalue using standard Lua 5.1 macro + ptr := C.lua_touserdata(L, C.get_upvalue_index(1)) + if ptr == nil { + state.PushString("error: function not found") + return -1 + } + + functionRegistry.RLock() + fn, ok := functionRegistry.funcs[ptr] + functionRegistry.RUnlock() + + if !ok { + state.PushString("error: function not found in registry") + return -1 + } + + result := fn(state) + return C.int(result) +} + +func (s *State) PushGoFunction(fn GoFunction) error { + // Push lightuserdata as upvalue and create closure + ptr := C.malloc(1) + if ptr == nil { + return fmt.Errorf("failed to allocate memory for function pointer") + } + + functionRegistry.Lock() + functionRegistry.funcs[ptr] = fn + functionRegistry.Unlock() + + C.lua_pushlightuserdata(s.L, ptr) + C.lua_pushcclosure(s.L, (*[0]byte)(C.goFunctionWrapper), 1) + return nil +} + +func (s *State) RegisterGoFunction(name string, fn GoFunction) error { + if err := s.PushGoFunction(fn); err != nil { + return err + } + + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname) + return nil +} + +func (s *State) UnregisterGoFunction(name string) { + s.PushNil() + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname) +} + +func (s *State) Cleanup() { + functionRegistry.Lock() + defer functionRegistry.Unlock() + + for ptr := range functionRegistry.funcs { + C.free(ptr) + } + functionRegistry.funcs = make(map[unsafe.Pointer]GoFunction) +} diff --git a/functions_test.go b/functions_test.go new file mode 100644 index 0000000..1a18bad --- /dev/null +++ b/functions_test.go @@ -0,0 +1,109 @@ +package luajit + +import "testing" + +func TestGoFunctions(t *testing.T) { + for _, f := range factories { + t.Run(f.name, func(t *testing.T) { + L := f.new() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + defer L.Cleanup() + + addFunc := func(s *State) int { + s.PushNumber(s.ToNumber(1) + s.ToNumber(2)) + return 1 + } + + if err := L.RegisterGoFunction("add", addFunc); err != nil { + t.Fatalf("Failed to register function: %v", err) + } + + // Test basic function call + if err := L.DoString("result = add(40, 2)"); err != nil { + t.Fatalf("Failed to call function: %v", err) + } + + L.GetGlobal("result") + if result := L.ToNumber(-1); result != 42 { + t.Errorf("got %v, want 42", result) + } + L.Pop(1) + + // Test multiple return values + multiFunc := func(s *State) int { + s.PushString("hello") + s.PushNumber(42) + s.PushBoolean(true) + return 3 + } + + if err := L.RegisterGoFunction("multi", multiFunc); err != nil { + t.Fatalf("Failed to register multi function: %v", err) + } + + code := ` + a, b, c = multi() + result = (a == "hello" and b == 42 and c == true) + ` + + if err := L.DoString(code); err != nil { + t.Fatalf("Failed to call multi function: %v", err) + } + + L.GetGlobal("result") + if !L.ToBoolean(-1) { + t.Error("Multiple return values test failed") + } + L.Pop(1) + + // Test error handling + errFunc := func(s *State) int { + s.PushString("test error") + return -1 + } + + if err := L.RegisterGoFunction("err", errFunc); err != nil { + t.Fatalf("Failed to register error function: %v", err) + } + + if err := L.DoString("err()"); err == nil { + t.Error("Expected error from error function") + } + + // Test unregistering + L.UnregisterGoFunction("add") + if err := L.DoString("add(1, 2)"); err == nil { + t.Error("Expected error calling unregistered function") + } + }) + } +} + +func TestStackSafety(t *testing.T) { + L := NewSafe() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + defer L.Cleanup() + + // Test stack overflow protection + overflowFunc := func(s *State) int { + for i := 0; i < 100; i++ { + s.PushNumber(float64(i)) + } + s.PushString("done") + return 101 + } + + if err := L.RegisterGoFunction("overflow", overflowFunc); err != nil { + t.Fatal(err) + } + + if err := L.DoString("overflow()"); err != nil { + t.Logf("Got expected error: %v", err) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b557b7f --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.sharkk.net/Sky/LuaJIT-to-Go + +go 1.23.4 diff --git a/luajit b/luajit new file mode 160000 index 0000000..e4fd777 --- /dev/null +++ b/luajit @@ -0,0 +1 @@ +Subproject commit e4fd777d6ad41d338125b095abc98e4dd54c05d7 diff --git a/stack.go b/stack.go new file mode 100644 index 0000000..5aa3cb8 --- /dev/null +++ b/stack.go @@ -0,0 +1,146 @@ +package luajit + +/* +#include +#include +*/ +import "C" +import "fmt" + +// LuaError represents an error from the Lua state +type LuaError struct { + Code int + Message string +} + +func (e *LuaError) Error() string { + return fmt.Sprintf("lua error (code=%d): %s", e.Code, e.Message) +} + +// Stack management constants from lua.h +const ( + LUA_MINSTACK = 20 // Minimum Lua stack size + LUA_MAXSTACK = 1000000 // Maximum Lua stack size + LUA_REGISTRYINDEX = -10000 // Pseudo-index for the Lua registry + LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table +) + +// checkStack ensures there is enough space on the Lua stack +func (s *State) checkStack(n int) error { + if C.lua_checkstack(s.L, C.int(n)) == 0 { + return fmt.Errorf("stack overflow (cannot allocate %d slots)", n) + } + return nil +} + +// safeCall wraps a potentially dangerous C call with stack checking +func (s *State) safeCall(f func() C.int) error { + // Save current stack size + top := s.GetTop() + + // Ensure we have enough stack space (minimum 20 slots as per Lua standard) + if err := s.checkStack(LUA_MINSTACK); err != nil { + return err + } + + // Make the call + status := f() + + // Check for errors + if status != 0 { + err := &LuaError{ + Code: int(status), + Message: s.ToString(-1), + } + s.Pop(1) // Remove error message + return err + } + + // Verify stack integrity + newTop := s.GetTop() + if newTop < top { + return fmt.Errorf("stack underflow: %d slots lost", top-newTop) + } + + return nil +} + +// stackGuard wraps a function with stack checking and restoration +func stackGuard[T any](s *State, f func() (T, error)) (T, error) { + // Save current stack size + top := s.GetTop() + + // Run the protected function + result, err := f() + + // Restore stack size + newTop := s.GetTop() + if newTop > top { + s.Pop(newTop - top) + } + + return result, err +} + +// stackGuardValue executes a function that returns a value and error with stack protection +func stackGuardValue[T any](s *State, f func() (T, error)) (T, error) { + // Save current stack size + top := s.GetTop() + + // Run the protected function + result, err := f() + + // Restore stack size + newTop := s.GetTop() + if newTop > top { + s.Pop(newTop - top) + } + + return result, err +} + +// stackGuardErr executes a function that only returns an error with stack protection +func stackGuardErr(s *State, f func() error) error { + // Save current stack size + top := s.GetTop() + + // Run the protected function + err := f() + + // Restore stack size + newTop := s.GetTop() + if newTop > top { + s.Pop(newTop - top) + } + + return err +} + +// getStackTrace returns the current Lua stack trace +func (s *State) getStackTrace() string { + // Push debug.traceback function + s.GetGlobal("debug") + if !s.IsTable(-1) { + s.Pop(1) + return "stack trace not available (debug module not loaded)" + } + + s.GetField(-1, "traceback") + if !s.IsFunction(-1) { + s.Pop(2) + return "stack trace not available (debug.traceback not found)" + } + + // Call debug.traceback + if err := s.safeCall(func() C.int { + return C.lua_pcall(s.L, 0, 1, 0) + }); err != nil { + return fmt.Sprintf("error getting stack trace: %v", err) + } + + // Get the resulting string + trace := s.ToString(-1) + s.Pop(1) // Remove the trace string + + return trace +} diff --git a/table.go b/table.go new file mode 100644 index 0000000..9f74265 --- /dev/null +++ b/table.go @@ -0,0 +1,177 @@ +package luajit + +/* +#include +#include +#include +#include + +static int get_table_length(lua_State *L, int index) { + return lua_objlen(L, index); +} +*/ +import "C" +import ( + "fmt" +) + +// TableValue represents any value that can be stored in a Lua table +type TableValue interface { + ~string | ~float64 | ~bool | ~int | ~map[string]interface{} | ~[]float64 | ~[]interface{} +} + +func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) } + +// ToTable converts a Lua table to a Go map +func (s *State) ToTable(index int) (map[string]interface{}, error) { + if s.safeStack { + return stackGuardValue[map[string]interface{}](s, func() (map[string]interface{}, error) { + if !s.IsTable(index) { + return nil, fmt.Errorf("not a table at index %d", index) + } + return s.toTableUnsafe(index) + }) + } + if !s.IsTable(index) { + return nil, fmt.Errorf("not a table at index %d", index) + } + return s.toTableUnsafe(index) +} + +func (s *State) pushTableSafe(table map[string]interface{}) error { + size := 2 + if err := s.checkStack(size); err != nil { + return fmt.Errorf("insufficient stack space: %w", err) + } + + s.NewTable() + for k, v := range table { + if err := s.pushValueSafe(v); err != nil { + return err + } + s.SetField(-2, k) + } + return nil +} + +func (s *State) pushTableUnsafe(table map[string]interface{}) error { + s.NewTable() + for k, v := range table { + if err := s.pushValueUnsafe(v); err != nil { + return err + } + s.SetField(-2, k) + } + return nil +} + +func (s *State) toTableSafe(index int) (map[string]interface{}, error) { + if err := s.checkStack(2); err != nil { + return nil, err + } + return s.toTableUnsafe(index) +} + +func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) { + absIdx := s.absIndex(index) + table := make(map[string]interface{}) + + // Check if it's an array-like table + length := s.GetTableLength(absIdx) + if length > 0 { + array := make([]float64, length) + isArray := true + + // Try to convert to array + for i := 1; i <= length; i++ { + s.PushNumber(float64(i)) + s.GetTable(absIdx) + if s.GetType(-1) != TypeNumber { + isArray = false + s.Pop(1) + break + } + array[i-1] = s.ToNumber(-1) + s.Pop(1) + } + + if isArray { + return map[string]interface{}{"": array}, nil + } + } + + // Handle regular table + s.PushNil() + for C.lua_next(s.L, C.int(absIdx)) != 0 { + key := "" + valueType := C.lua_type(s.L, -2) + if valueType == C.LUA_TSTRING { + key = s.ToString(-2) + } else if valueType == C.LUA_TNUMBER { + key = fmt.Sprintf("%g", s.ToNumber(-2)) + } + + value, err := s.toValueUnsafe(-1) + if err != nil { + s.Pop(1) + return nil, err + } + + // Handle nested array case + if m, ok := value.(map[string]interface{}); ok { + if arr, ok := m[""]; ok { + value = arr + } + } + + table[key] = value + s.Pop(1) + } + + return table, nil +} + +// NewTable creates a new table and pushes it onto the stack +func (s *State) NewTable() { + if s.safeStack { + if err := s.checkStack(1); err != nil { + // Since we can't return an error, we'll push nil instead + s.PushNil() + return + } + } + C.lua_createtable(s.L, 0, 0) +} + +// SetTable sets a table field with cached absolute index +func (s *State) SetTable(index int) { + absIdx := index + if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) { + absIdx = s.GetTop() + index + 1 + } + C.lua_settable(s.L, C.int(absIdx)) +} + +// GetTable gets a table field with cached absolute index +func (s *State) GetTable(index int) { + absIdx := index + if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) { + absIdx = s.GetTop() + index + 1 + } + + if s.safeStack { + if err := s.checkStack(1); err != nil { + s.PushNil() + return + } + } + C.lua_gettable(s.L, C.int(absIdx)) +} + +// PushTable pushes a Go map onto the Lua stack as a table with stack checking +func (s *State) PushTable(table map[string]interface{}) error { + if s.safeStack { + return s.pushTableSafe(table) + } + return s.pushTableUnsafe(table) +} diff --git a/table_test.go b/table_test.go new file mode 100644 index 0000000..70995fc --- /dev/null +++ b/table_test.go @@ -0,0 +1,97 @@ +package luajit + +import ( + "math" + "testing" +) + +func TestTableOperations(t *testing.T) { + tests := []struct { + name string + data map[string]interface{} + }{ + { + name: "empty", + data: map[string]interface{}{}, + }, + { + name: "primitives", + data: map[string]interface{}{ + "str": "hello", + "num": 42.0, + "bool": true, + "array": []float64{1.1, 2.2, 3.3}, + }, + }, + { + name: "nested", + data: map[string]interface{}{ + "nested": map[string]interface{}{ + "value": 123.0, + "array": []float64{4.4, 5.5}, + }, + }, + }, + } + + for _, f := range factories { + for _, tt := range tests { + t.Run(f.name+"/"+tt.name, func(t *testing.T) { + L := f.new() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + + if err := L.PushTable(tt.data); err != nil { + t.Fatalf("PushTable() error = %v", err) + } + + got, err := L.ToTable(-1) + if err != nil { + t.Fatalf("ToTable() error = %v", err) + } + + if !tablesEqual(got, tt.data) { + t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data) + } + }) + } + } +} + +func tablesEqual(a, b map[string]interface{}) bool { + if len(a) != len(b) { + return false + } + + for k, v1 := range a { + v2, ok := b[k] + if !ok { + return false + } + + switch v1 := v1.(type) { + case map[string]interface{}: + v2, ok := v2.(map[string]interface{}) + if !ok || !tablesEqual(v1, v2) { + return false + } + case []float64: + v2, ok := v2.([]float64) + if !ok || len(v1) != len(v2) { + return false + } + for i := range v1 { + if math.Abs(v1[i]-v2[i]) > 1e-10 { + return false + } + } + default: + if v1 != v2 { + return false + } + } + } + return true +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..a778dba --- /dev/null +++ b/types.go @@ -0,0 +1,51 @@ +package luajit + +/* +#include +*/ +import "C" + +// LuaType represents Lua value types +type LuaType int + +const ( + // These constants must match lua.h's LUA_T* values + TypeNone LuaType = -1 + TypeNil LuaType = 0 + TypeBoolean LuaType = 1 + TypeLightUserData LuaType = 2 + TypeNumber LuaType = 3 + TypeString LuaType = 4 + TypeTable LuaType = 5 + TypeFunction LuaType = 6 + TypeUserData LuaType = 7 + TypeThread LuaType = 8 +) + +// String returns the string representation of the Lua type +func (t LuaType) String() string { + switch t { + case TypeNone: + return "none" + case TypeNil: + return "nil" + case TypeBoolean: + return "boolean" + case TypeLightUserData: + return "lightuserdata" + case TypeNumber: + return "number" + case TypeString: + return "string" + case TypeTable: + return "table" + case TypeFunction: + return "function" + case TypeUserData: + return "userdata" + case TypeThread: + return "thread" + default: + return "unknown" + } +} diff --git a/wrapper.go b/wrapper.go new file mode 100644 index 0000000..2440847 --- /dev/null +++ b/wrapper.go @@ -0,0 +1,325 @@ +package luajit + +/* +#cgo CFLAGS: -I${SRCDIR}/luajit +#cgo windows LDFLAGS: -L${SRCDIR}/luajit -llua51 +#cgo !windows LDFLAGS: -L${SRCDIR}/luajit -lluajit + +#include +#include +#include +#include + +void init_dll_paths(void); + +static int do_string(lua_State *L, const char *s) { + int status = luaL_loadstring(L, s); + if (status) return status; + return lua_pcall(L, 0, LUA_MULTRET, 0); +} + + +static int do_file(lua_State *L, const char *filename) { + int status = luaL_loadfile(L, filename); + if (status) return status; + return lua_pcall(L, 0, LUA_MULTRET, 0); +} +*/ +import "C" +import ( + "fmt" + "path/filepath" + "unsafe" +) + +// State represents a Lua state with configurable stack safety +type State struct { + L *C.lua_State + safeStack bool +} + +// NewSafe creates a new Lua state with full stack safety guarantees +func NewSafe() *State { + L := C.luaL_newstate() + if L == nil { + return nil + } + C.luaL_openlibs(L) + return &State{L: L, safeStack: true} +} + +// New creates a new Lua state with minimal stack checking +func New() *State { + L := C.luaL_newstate() + if L == nil { + return nil + } + C.luaL_openlibs(L) + return &State{L: L, safeStack: false} +} + +// Close closes the Lua state +func (s *State) Close() { + if s.L != nil { + C.lua_close(s.L) + s.L = nil + } +} + +// DoString executes a Lua string with appropriate stack management +func (s *State) DoString(str string) error { + cstr := C.CString(str) + defer C.free(unsafe.Pointer(cstr)) + + if s.safeStack { + return stackGuardErr(s, func() error { + return s.safeCall(func() C.int { + return C.do_string(s.L, cstr) + }) + }) + } + + status := C.do_string(s.L, cstr) + if status != 0 { + return &LuaError{ + Code: int(status), + Message: s.ToString(-1), + } + } + return nil +} + +// PushValue pushes a Go value onto the stack +func (s *State) PushValue(v interface{}) error { + if s.safeStack { + return stackGuardErr(s, func() error { + if err := s.checkStack(1); err != nil { + return fmt.Errorf("pushing value: %w", err) + } + return s.pushValueUnsafe(v) + }) + } + return s.pushValueUnsafe(v) +} + +func (s *State) pushValueSafe(v interface{}) error { + if err := s.checkStack(1); err != nil { + return fmt.Errorf("pushing value: %w", err) + } + return s.pushValueUnsafe(v) +} + +func (s *State) pushValueUnsafe(v interface{}) error { + switch v := v.(type) { + case nil: + s.PushNil() + case bool: + s.PushBoolean(v) + case float64: + s.PushNumber(v) + case int: + s.PushNumber(float64(v)) + case string: + s.PushString(v) + case map[string]interface{}: + // Special case: handle array stored in map + if arr, ok := v[""].([]float64); ok { + s.NewTable() + for i, elem := range arr { + s.PushNumber(float64(i + 1)) + s.PushNumber(elem) + s.SetTable(-3) + } + return nil + } + return s.pushTableUnsafe(v) + case []float64: + s.NewTable() + for i, elem := range v { + s.PushNumber(float64(i + 1)) + s.PushNumber(elem) + s.SetTable(-3) + } + case []interface{}: + s.NewTable() + for i, elem := range v { + s.PushNumber(float64(i + 1)) + if err := s.pushValueUnsafe(elem); err != nil { + return err + } + s.SetTable(-3) + } + default: + return fmt.Errorf("unsupported type: %T", v) + } + return nil +} + +// ToValue converts a Lua value to a Go value +func (s *State) ToValue(index int) (interface{}, error) { + if s.safeStack { + return stackGuardValue[interface{}](s, func() (interface{}, error) { + return s.toValueUnsafe(index) + }) + } + return s.toValueUnsafe(index) +} + +func (s *State) toValueUnsafe(index int) (interface{}, error) { + switch s.GetType(index) { + case TypeNil: + return nil, nil + case TypeBoolean: + return s.ToBoolean(index), nil + case TypeNumber: + return s.ToNumber(index), nil + case TypeString: + return s.ToString(index), nil + case TypeTable: + if !s.IsTable(index) { + return nil, fmt.Errorf("not a table at index %d", index) + } + return s.toTableUnsafe(index) + default: + return nil, fmt.Errorf("unsupported type: %s", s.GetType(index)) + } +} + +// Simple operations remain unchanged as they don't need stack protection + +func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) } +func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction } +func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable } +func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 } +func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) } +func (s *State) ToString(index int) string { + return C.GoString(C.lua_tolstring(s.L, C.int(index), nil)) +} +func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) } +func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) } + +// Push operations + +func (s *State) PushNil() { C.lua_pushnil(s.L) } +func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, C.int(bool2int(b))) } +func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.double(n)) } +func (s *State) PushString(str string) { + cstr := C.CString(str) + defer C.free(unsafe.Pointer(cstr)) + C.lua_pushstring(s.L, cstr) +} + +// Helper functions +func bool2int(b bool) int { + if b { + return 1 + } + return 0 +} + +func (s *State) absIndex(index int) int { + if index > 0 || index <= LUA_REGISTRYINDEX { + return index + } + return s.GetTop() + index + 1 +} + +// SetField sets a field in a table at the given index with cached absolute index +func (s *State) SetField(index int, key string) { + absIdx := index + if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) { + absIdx = s.GetTop() + index + 1 + } + + cstr := C.CString(key) + defer C.free(unsafe.Pointer(cstr)) + C.lua_setfield(s.L, C.int(absIdx), cstr) +} + +// GetField gets a field from a table with cached absolute index +func (s *State) GetField(index int, key string) { + absIdx := index + if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) { + absIdx = s.GetTop() + index + 1 + } + + if s.safeStack { + if err := s.checkStack(1); err != nil { + s.PushNil() + return + } + } + + cstr := C.CString(key) + defer C.free(unsafe.Pointer(cstr)) + C.lua_getfield(s.L, C.int(absIdx), cstr) +} + +// GetGlobal gets a global variable and pushes it onto the stack +func (s *State) GetGlobal(name string) { + if s.safeStack { + if err := s.checkStack(1); err != nil { + s.PushNil() + return + } + } + cstr := C.CString(name) + defer C.free(unsafe.Pointer(cstr)) + C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cstr) +} + +// SetGlobal sets a global variable from the value at the top of the stack +func (s *State) SetGlobal(name string) { + // SetGlobal doesn't need stack space checking as it pops the value + cstr := C.CString(name) + defer C.free(unsafe.Pointer(cstr)) + C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr) +} + +// Remove removes element with cached absolute index +func (s *State) Remove(index int) { + absIdx := index + if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) { + absIdx = s.GetTop() + index + 1 + } + C.lua_remove(s.L, C.int(absIdx)) +} + +// DoFile executes a Lua file with appropriate stack management +func (s *State) DoFile(filename string) error { + cfilename := C.CString(filename) + defer C.free(unsafe.Pointer(cfilename)) + + if s.safeStack { + return stackGuardErr(s, func() error { + return s.safeCall(func() C.int { + return C.do_file(s.L, cfilename) + }) + }) + } + + status := C.do_file(s.L, cfilename) + if status != 0 { + return &LuaError{ + Code: int(status), + Message: s.ToString(-1), + } + } + return nil +} + +func (s *State) SetPackagePath(path string) error { + path = filepath.ToSlash(path) + if err := s.DoString(fmt.Sprintf(`package.path = %q`, path)); err != nil { + return fmt.Errorf("setting package.path: %w", err) + } + return nil +} + +func (s *State) AddPackagePath(path string) error { + path = filepath.ToSlash(path) + if err := s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)); err != nil { + return fmt.Errorf("adding to package.path: %w", err) + } + return nil +} diff --git a/wrapper_test.go b/wrapper_test.go new file mode 100644 index 0000000..d0f888c --- /dev/null +++ b/wrapper_test.go @@ -0,0 +1,276 @@ +package luajit + +import ( + "fmt" + "os" + "path/filepath" + "testing" +) + +type stateFactory struct { + name string + new func() *State +} + +var factories = []stateFactory{ + {"unsafe", New}, + {"safe", NewSafe}, +} + +func TestNew(t *testing.T) { + for _, f := range factories { + t.Run(f.name, func(t *testing.T) { + L := f.new() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + }) + } +} + +func TestDoString(t *testing.T) { + tests := []struct { + name string + code string + wantErr bool + }{ + {"simple addition", "return 1 + 1", false}, + {"set global", "test = 42", false}, + {"syntax error", "this is not valid lua", true}, + {"runtime error", "error('test error')", true}, + } + + for _, f := range factories { + for _, tt := range tests { + t.Run(f.name+"/"+tt.name, func(t *testing.T) { + L := f.new() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + + err := L.DoString(tt.code) + if (err != nil) != tt.wantErr { + t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } + } +} + +func TestPushAndGetValues(t *testing.T) { + values := []struct { + name string + push func(*State) + check func(*State) error + }{ + { + name: "string", + push: func(L *State) { L.PushString("hello") }, + check: func(L *State) error { + if got := L.ToString(-1); got != "hello" { + return fmt.Errorf("got %q, want %q", got, "hello") + } + return nil + }, + }, + { + name: "number", + push: func(L *State) { L.PushNumber(42.5) }, + check: func(L *State) error { + if got := L.ToNumber(-1); got != 42.5 { + return fmt.Errorf("got %f, want %f", got, 42.5) + } + return nil + }, + }, + { + name: "boolean", + push: func(L *State) { L.PushBoolean(true) }, + check: func(L *State) error { + if got := L.ToBoolean(-1); !got { + return fmt.Errorf("got %v, want true", got) + } + return nil + }, + }, + { + name: "nil", + push: func(L *State) { L.PushNil() }, + check: func(L *State) error { + if typ := L.GetType(-1); typ != TypeNil { + return fmt.Errorf("got type %v, want TypeNil", typ) + } + return nil + }, + }, + } + + for _, f := range factories { + for _, v := range values { + t.Run(f.name+"/"+v.name, func(t *testing.T) { + L := f.new() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + + v.push(L) + if err := v.check(L); err != nil { + t.Error(err) + } + }) + } + } +} + +func TestStackManipulation(t *testing.T) { + for _, f := range factories { + t.Run(f.name, func(t *testing.T) { + L := f.new() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + + // Push values + values := []string{"first", "second", "third"} + for _, v := range values { + L.PushString(v) + } + + // Check size + if top := L.GetTop(); top != len(values) { + t.Errorf("stack size = %d, want %d", top, len(values)) + } + + // Pop one value + L.Pop(1) + + // Check new top + if str := L.ToString(-1); str != "second" { + t.Errorf("top value = %q, want 'second'", str) + } + + // Check new size + if top := L.GetTop(); top != len(values)-1 { + t.Errorf("stack size after pop = %d, want %d", top, len(values)-1) + } + }) + } +} + +func TestGlobals(t *testing.T) { + for _, f := range factories { + t.Run(f.name, func(t *testing.T) { + L := f.new() + if L == nil { + t.Fatal("Failed to create Lua state") + } + defer L.Close() + + // Test via Lua + if err := L.DoString(`globalVar = "test"`); err != nil { + t.Fatalf("DoString error: %v", err) + } + + // Get the global + L.GetGlobal("globalVar") + if str := L.ToString(-1); str != "test" { + t.Errorf("global value = %q, want 'test'", str) + } + L.Pop(1) + + // Set and get via API + L.PushNumber(42) + L.SetGlobal("testNum") + + L.GetGlobal("testNum") + if num := L.ToNumber(-1); num != 42 { + t.Errorf("global number = %f, want 42", num) + } + }) + } +} + +func TestDoFile(t *testing.T) { + L := NewSafe() + defer L.Close() + + // Create test file + content := []byte(` + function add(a, b) + return a + b + end + result = add(40, 2) + `) + + tmpDir := t.TempDir() + filename := filepath.Join(tmpDir, "test.lua") + if err := os.WriteFile(filename, content, 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + if err := L.DoFile(filename); err != nil { + t.Fatalf("DoFile failed: %v", err) + } + + L.GetGlobal("result") + if result := L.ToNumber(-1); result != 42 { + t.Errorf("Expected result=42, got %v", result) + } +} + +func TestRequireAndPackagePath(t *testing.T) { + L := NewSafe() + defer L.Close() + + tmpDir := t.TempDir() + + // Create module file + moduleContent := []byte(` + local M = {} + function M.multiply(a, b) + return a * b + end + return M + `) + + if err := os.WriteFile(filepath.Join(tmpDir, "mathmod.lua"), moduleContent, 0644); err != nil { + t.Fatalf("Failed to create module file: %v", err) + } + + // Add module path and test require + if err := L.AddPackagePath(filepath.Join(tmpDir, "?.lua")); err != nil { + t.Fatalf("AddPackagePath failed: %v", err) + } + + if err := L.DoString(` + local math = require("mathmod") + result = math.multiply(6, 7) + `); err != nil { + t.Fatalf("Failed to require module: %v", err) + } + + L.GetGlobal("result") + if result := L.ToNumber(-1); result != 42 { + t.Errorf("Expected result=42, got %v", result) + } +} + +func TestSetPackagePath(t *testing.T) { + L := NewSafe() + defer L.Close() + + customPath := "./custom/?.lua" + if err := L.SetPackagePath(customPath); err != nil { + t.Fatalf("SetPackagePath failed: %v", err) + } + + L.GetGlobal("package") + L.GetField(-1, "path") + if path := L.ToString(-1); path != customPath { + t.Errorf("Expected package.path=%q, got %q", customPath, path) + } +}