Implement the wrapper
This commit is contained in:
parent
2f6764aef1
commit
81501915e3
8
.idea/.gitignore
vendored
Normal file
8
.idea/.gitignore
vendored
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
9
.idea/ljtg.iml
Normal file
9
.idea/ljtg.iml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="WEB_MODULE" version="4">
|
||||||
|
<component name="Go" enabled="true" />
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
10
.idea/material_theme_project_new.xml
Normal file
10
.idea/material_theme_project_new.xml
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="MaterialThemeProjectNewConfig">
|
||||||
|
<option name="metadata">
|
||||||
|
<MTProjectMetadataState>
|
||||||
|
<option name="userId" value="3fdbffe2:19499e3ad4d:-7fff" />
|
||||||
|
</MTProjectMetadataState>
|
||||||
|
</option>
|
||||||
|
</component>
|
||||||
|
</project>
|
8
.idea/modules.xml
Normal file
8
.idea/modules.xml
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/ljtg.iml" filepath="$PROJECT_DIR$/.idea/ljtg.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
8
.idea/vcs.xml
Normal file
8
.idea/vcs.xml
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
<mapping directory="$PROJECT_DIR$/luajit" vcs="Git" />
|
||||||
|
<mapping directory="$PROJECT_DIR$/luajit/luajit" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
98
functions.go
Normal file
98
functions.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package luajit
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <lua.h>
|
||||||
|
#include <lauxlib.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
extern int goFunctionWrapper(lua_State* L);
|
||||||
|
|
||||||
|
static int get_upvalue_index(int i) {
|
||||||
|
return -10002 - i; // LUA_GLOBALSINDEX - i
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GoFunction func(*State) int
|
||||||
|
|
||||||
|
var (
|
||||||
|
functionRegistry = struct {
|
||||||
|
sync.RWMutex
|
||||||
|
funcs map[unsafe.Pointer]GoFunction
|
||||||
|
}{
|
||||||
|
funcs: make(map[unsafe.Pointer]GoFunction),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
//export goFunctionWrapper
|
||||||
|
func goFunctionWrapper(L *C.lua_State) C.int {
|
||||||
|
state := &State{L: L, safeStack: true}
|
||||||
|
|
||||||
|
// Get upvalue using standard Lua 5.1 macro
|
||||||
|
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
|
||||||
|
if ptr == nil {
|
||||||
|
state.PushString("error: function not found")
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
functionRegistry.RLock()
|
||||||
|
fn, ok := functionRegistry.funcs[ptr]
|
||||||
|
functionRegistry.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
state.PushString("error: function not found in registry")
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
result := fn(state)
|
||||||
|
return C.int(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) PushGoFunction(fn GoFunction) error {
|
||||||
|
// Push lightuserdata as upvalue and create closure
|
||||||
|
ptr := C.malloc(1)
|
||||||
|
if ptr == nil {
|
||||||
|
return fmt.Errorf("failed to allocate memory for function pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
functionRegistry.Lock()
|
||||||
|
functionRegistry.funcs[ptr] = fn
|
||||||
|
functionRegistry.Unlock()
|
||||||
|
|
||||||
|
C.lua_pushlightuserdata(s.L, ptr)
|
||||||
|
C.lua_pushcclosure(s.L, (*[0]byte)(C.goFunctionWrapper), 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) RegisterGoFunction(name string, fn GoFunction) error {
|
||||||
|
if err := s.PushGoFunction(fn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cname := C.CString(name)
|
||||||
|
defer C.free(unsafe.Pointer(cname))
|
||||||
|
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) UnregisterGoFunction(name string) {
|
||||||
|
s.PushNil()
|
||||||
|
cname := C.CString(name)
|
||||||
|
defer C.free(unsafe.Pointer(cname))
|
||||||
|
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) Cleanup() {
|
||||||
|
functionRegistry.Lock()
|
||||||
|
defer functionRegistry.Unlock()
|
||||||
|
|
||||||
|
for ptr := range functionRegistry.funcs {
|
||||||
|
C.free(ptr)
|
||||||
|
}
|
||||||
|
functionRegistry.funcs = make(map[unsafe.Pointer]GoFunction)
|
||||||
|
}
|
109
functions_test.go
Normal file
109
functions_test.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
package luajit
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGoFunctions(t *testing.T) {
|
||||||
|
for _, f := range factories {
|
||||||
|
t.Run(f.name, func(t *testing.T) {
|
||||||
|
L := f.new()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
defer L.Cleanup()
|
||||||
|
|
||||||
|
addFunc := func(s *State) int {
|
||||||
|
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := L.RegisterGoFunction("add", addFunc); err != nil {
|
||||||
|
t.Fatalf("Failed to register function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test basic function call
|
||||||
|
if err := L.DoString("result = add(40, 2)"); err != nil {
|
||||||
|
t.Fatalf("Failed to call function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
L.GetGlobal("result")
|
||||||
|
if result := L.ToNumber(-1); result != 42 {
|
||||||
|
t.Errorf("got %v, want 42", result)
|
||||||
|
}
|
||||||
|
L.Pop(1)
|
||||||
|
|
||||||
|
// Test multiple return values
|
||||||
|
multiFunc := func(s *State) int {
|
||||||
|
s.PushString("hello")
|
||||||
|
s.PushNumber(42)
|
||||||
|
s.PushBoolean(true)
|
||||||
|
return 3
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
|
||||||
|
t.Fatalf("Failed to register multi function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := `
|
||||||
|
a, b, c = multi()
|
||||||
|
result = (a == "hello" and b == 42 and c == true)
|
||||||
|
`
|
||||||
|
|
||||||
|
if err := L.DoString(code); err != nil {
|
||||||
|
t.Fatalf("Failed to call multi function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
L.GetGlobal("result")
|
||||||
|
if !L.ToBoolean(-1) {
|
||||||
|
t.Error("Multiple return values test failed")
|
||||||
|
}
|
||||||
|
L.Pop(1)
|
||||||
|
|
||||||
|
// Test error handling
|
||||||
|
errFunc := func(s *State) int {
|
||||||
|
s.PushString("test error")
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := L.RegisterGoFunction("err", errFunc); err != nil {
|
||||||
|
t.Fatalf("Failed to register error function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := L.DoString("err()"); err == nil {
|
||||||
|
t.Error("Expected error from error function")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test unregistering
|
||||||
|
L.UnregisterGoFunction("add")
|
||||||
|
if err := L.DoString("add(1, 2)"); err == nil {
|
||||||
|
t.Error("Expected error calling unregistered function")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStackSafety(t *testing.T) {
|
||||||
|
L := NewSafe()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
defer L.Cleanup()
|
||||||
|
|
||||||
|
// Test stack overflow protection
|
||||||
|
overflowFunc := func(s *State) int {
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
s.PushNumber(float64(i))
|
||||||
|
}
|
||||||
|
s.PushString("done")
|
||||||
|
return 101
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := L.RegisterGoFunction("overflow", overflowFunc); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := L.DoString("overflow()"); err != nil {
|
||||||
|
t.Logf("Got expected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
1
luajit
Submodule
1
luajit
Submodule
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit e4fd777d6ad41d338125b095abc98e4dd54c05d7
|
146
stack.go
Normal file
146
stack.go
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
package luajit
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <lua.h>
|
||||||
|
#include <lauxlib.h>
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// LuaError represents an error from the Lua state
|
||||||
|
type LuaError struct {
|
||||||
|
Code int
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *LuaError) Error() string {
|
||||||
|
return fmt.Sprintf("lua error (code=%d): %s", e.Code, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stack management constants from lua.h
|
||||||
|
const (
|
||||||
|
LUA_MINSTACK = 20 // Minimum Lua stack size
|
||||||
|
LUA_MAXSTACK = 1000000 // Maximum Lua stack size
|
||||||
|
LUA_REGISTRYINDEX = -10000 // Pseudo-index for the Lua registry
|
||||||
|
LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table
|
||||||
|
)
|
||||||
|
|
||||||
|
// checkStack ensures there is enough space on the Lua stack
|
||||||
|
func (s *State) checkStack(n int) error {
|
||||||
|
if C.lua_checkstack(s.L, C.int(n)) == 0 {
|
||||||
|
return fmt.Errorf("stack overflow (cannot allocate %d slots)", n)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeCall wraps a potentially dangerous C call with stack checking
|
||||||
|
func (s *State) safeCall(f func() C.int) error {
|
||||||
|
// Save current stack size
|
||||||
|
top := s.GetTop()
|
||||||
|
|
||||||
|
// Ensure we have enough stack space (minimum 20 slots as per Lua standard)
|
||||||
|
if err := s.checkStack(LUA_MINSTACK); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make the call
|
||||||
|
status := f()
|
||||||
|
|
||||||
|
// Check for errors
|
||||||
|
if status != 0 {
|
||||||
|
err := &LuaError{
|
||||||
|
Code: int(status),
|
||||||
|
Message: s.ToString(-1),
|
||||||
|
}
|
||||||
|
s.Pop(1) // Remove error message
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify stack integrity
|
||||||
|
newTop := s.GetTop()
|
||||||
|
if newTop < top {
|
||||||
|
return fmt.Errorf("stack underflow: %d slots lost", top-newTop)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stackGuard wraps a function with stack checking and restoration
|
||||||
|
func stackGuard[T any](s *State, f func() (T, error)) (T, error) {
|
||||||
|
// Save current stack size
|
||||||
|
top := s.GetTop()
|
||||||
|
|
||||||
|
// Run the protected function
|
||||||
|
result, err := f()
|
||||||
|
|
||||||
|
// Restore stack size
|
||||||
|
newTop := s.GetTop()
|
||||||
|
if newTop > top {
|
||||||
|
s.Pop(newTop - top)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// stackGuardValue executes a function that returns a value and error with stack protection
|
||||||
|
func stackGuardValue[T any](s *State, f func() (T, error)) (T, error) {
|
||||||
|
// Save current stack size
|
||||||
|
top := s.GetTop()
|
||||||
|
|
||||||
|
// Run the protected function
|
||||||
|
result, err := f()
|
||||||
|
|
||||||
|
// Restore stack size
|
||||||
|
newTop := s.GetTop()
|
||||||
|
if newTop > top {
|
||||||
|
s.Pop(newTop - top)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// stackGuardErr executes a function that only returns an error with stack protection
|
||||||
|
func stackGuardErr(s *State, f func() error) error {
|
||||||
|
// Save current stack size
|
||||||
|
top := s.GetTop()
|
||||||
|
|
||||||
|
// Run the protected function
|
||||||
|
err := f()
|
||||||
|
|
||||||
|
// Restore stack size
|
||||||
|
newTop := s.GetTop()
|
||||||
|
if newTop > top {
|
||||||
|
s.Pop(newTop - top)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// getStackTrace returns the current Lua stack trace
|
||||||
|
func (s *State) getStackTrace() string {
|
||||||
|
// Push debug.traceback function
|
||||||
|
s.GetGlobal("debug")
|
||||||
|
if !s.IsTable(-1) {
|
||||||
|
s.Pop(1)
|
||||||
|
return "stack trace not available (debug module not loaded)"
|
||||||
|
}
|
||||||
|
|
||||||
|
s.GetField(-1, "traceback")
|
||||||
|
if !s.IsFunction(-1) {
|
||||||
|
s.Pop(2)
|
||||||
|
return "stack trace not available (debug.traceback not found)"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call debug.traceback
|
||||||
|
if err := s.safeCall(func() C.int {
|
||||||
|
return C.lua_pcall(s.L, 0, 1, 0)
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Sprintf("error getting stack trace: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the resulting string
|
||||||
|
trace := s.ToString(-1)
|
||||||
|
s.Pop(1) // Remove the trace string
|
||||||
|
|
||||||
|
return trace
|
||||||
|
}
|
177
table.go
Normal file
177
table.go
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
package luajit
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <lua.h>
|
||||||
|
#include <lualib.h>
|
||||||
|
#include <lauxlib.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
static int get_table_length(lua_State *L, int index) {
|
||||||
|
return lua_objlen(L, index);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TableValue represents any value that can be stored in a Lua table
|
||||||
|
type TableValue interface {
|
||||||
|
~string | ~float64 | ~bool | ~int | ~map[string]interface{} | ~[]float64 | ~[]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) }
|
||||||
|
|
||||||
|
// ToTable converts a Lua table to a Go map
|
||||||
|
func (s *State) ToTable(index int) (map[string]interface{}, error) {
|
||||||
|
if s.safeStack {
|
||||||
|
return stackGuardValue[map[string]interface{}](s, func() (map[string]interface{}, error) {
|
||||||
|
if !s.IsTable(index) {
|
||||||
|
return nil, fmt.Errorf("not a table at index %d", index)
|
||||||
|
}
|
||||||
|
return s.toTableUnsafe(index)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if !s.IsTable(index) {
|
||||||
|
return nil, fmt.Errorf("not a table at index %d", index)
|
||||||
|
}
|
||||||
|
return s.toTableUnsafe(index)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) pushTableSafe(table map[string]interface{}) error {
|
||||||
|
size := 2
|
||||||
|
if err := s.checkStack(size); err != nil {
|
||||||
|
return fmt.Errorf("insufficient stack space: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.NewTable()
|
||||||
|
for k, v := range table {
|
||||||
|
if err := s.pushValueSafe(v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.SetField(-2, k)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) pushTableUnsafe(table map[string]interface{}) error {
|
||||||
|
s.NewTable()
|
||||||
|
for k, v := range table {
|
||||||
|
if err := s.pushValueUnsafe(v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.SetField(-2, k)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) toTableSafe(index int) (map[string]interface{}, error) {
|
||||||
|
if err := s.checkStack(2); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.toTableUnsafe(index)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
|
||||||
|
absIdx := s.absIndex(index)
|
||||||
|
table := make(map[string]interface{})
|
||||||
|
|
||||||
|
// Check if it's an array-like table
|
||||||
|
length := s.GetTableLength(absIdx)
|
||||||
|
if length > 0 {
|
||||||
|
array := make([]float64, length)
|
||||||
|
isArray := true
|
||||||
|
|
||||||
|
// Try to convert to array
|
||||||
|
for i := 1; i <= length; i++ {
|
||||||
|
s.PushNumber(float64(i))
|
||||||
|
s.GetTable(absIdx)
|
||||||
|
if s.GetType(-1) != TypeNumber {
|
||||||
|
isArray = false
|
||||||
|
s.Pop(1)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
array[i-1] = s.ToNumber(-1)
|
||||||
|
s.Pop(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isArray {
|
||||||
|
return map[string]interface{}{"": array}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle regular table
|
||||||
|
s.PushNil()
|
||||||
|
for C.lua_next(s.L, C.int(absIdx)) != 0 {
|
||||||
|
key := ""
|
||||||
|
valueType := C.lua_type(s.L, -2)
|
||||||
|
if valueType == C.LUA_TSTRING {
|
||||||
|
key = s.ToString(-2)
|
||||||
|
} else if valueType == C.LUA_TNUMBER {
|
||||||
|
key = fmt.Sprintf("%g", s.ToNumber(-2))
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := s.toValueUnsafe(-1)
|
||||||
|
if err != nil {
|
||||||
|
s.Pop(1)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle nested array case
|
||||||
|
if m, ok := value.(map[string]interface{}); ok {
|
||||||
|
if arr, ok := m[""]; ok {
|
||||||
|
value = arr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
table[key] = value
|
||||||
|
s.Pop(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return table, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTable creates a new table and pushes it onto the stack
|
||||||
|
func (s *State) NewTable() {
|
||||||
|
if s.safeStack {
|
||||||
|
if err := s.checkStack(1); err != nil {
|
||||||
|
// Since we can't return an error, we'll push nil instead
|
||||||
|
s.PushNil()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
C.lua_createtable(s.L, 0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTable sets a table field with cached absolute index
|
||||||
|
func (s *State) SetTable(index int) {
|
||||||
|
absIdx := index
|
||||||
|
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||||
|
absIdx = s.GetTop() + index + 1
|
||||||
|
}
|
||||||
|
C.lua_settable(s.L, C.int(absIdx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTable gets a table field with cached absolute index
|
||||||
|
func (s *State) GetTable(index int) {
|
||||||
|
absIdx := index
|
||||||
|
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||||
|
absIdx = s.GetTop() + index + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.safeStack {
|
||||||
|
if err := s.checkStack(1); err != nil {
|
||||||
|
s.PushNil()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
C.lua_gettable(s.L, C.int(absIdx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushTable pushes a Go map onto the Lua stack as a table with stack checking
|
||||||
|
func (s *State) PushTable(table map[string]interface{}) error {
|
||||||
|
if s.safeStack {
|
||||||
|
return s.pushTableSafe(table)
|
||||||
|
}
|
||||||
|
return s.pushTableUnsafe(table)
|
||||||
|
}
|
97
table_test.go
Normal file
97
table_test.go
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
package luajit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTableOperations(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data map[string]interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
data: map[string]interface{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "primitives",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"str": "hello",
|
||||||
|
"num": 42.0,
|
||||||
|
"bool": true,
|
||||||
|
"array": []float64{1.1, 2.2, 3.3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"nested": map[string]interface{}{
|
||||||
|
"value": 123.0,
|
||||||
|
"array": []float64{4.4, 5.5},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range factories {
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||||
|
L := f.new()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
if err := L.PushTable(tt.data); err != nil {
|
||||||
|
t.Fatalf("PushTable() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := L.ToTable(-1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ToTable() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tablesEqual(got, tt.data) {
|
||||||
|
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func tablesEqual(a, b map[string]interface{}) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v1 := range a {
|
||||||
|
v2, ok := b[k]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v1 := v1.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
v2, ok := v2.(map[string]interface{})
|
||||||
|
if !ok || !tablesEqual(v1, v2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case []float64:
|
||||||
|
v2, ok := v2.([]float64)
|
||||||
|
if !ok || len(v1) != len(v2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range v1 {
|
||||||
|
if math.Abs(v1[i]-v2[i]) > 1e-10 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if v1 != v2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
51
types.go
Normal file
51
types.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
package luajit
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <lua.h>
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
// LuaType represents Lua value types
|
||||||
|
type LuaType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// These constants must match lua.h's LUA_T* values
|
||||||
|
TypeNone LuaType = -1
|
||||||
|
TypeNil LuaType = 0
|
||||||
|
TypeBoolean LuaType = 1
|
||||||
|
TypeLightUserData LuaType = 2
|
||||||
|
TypeNumber LuaType = 3
|
||||||
|
TypeString LuaType = 4
|
||||||
|
TypeTable LuaType = 5
|
||||||
|
TypeFunction LuaType = 6
|
||||||
|
TypeUserData LuaType = 7
|
||||||
|
TypeThread LuaType = 8
|
||||||
|
)
|
||||||
|
|
||||||
|
// String returns the string representation of the Lua type
|
||||||
|
func (t LuaType) String() string {
|
||||||
|
switch t {
|
||||||
|
case TypeNone:
|
||||||
|
return "none"
|
||||||
|
case TypeNil:
|
||||||
|
return "nil"
|
||||||
|
case TypeBoolean:
|
||||||
|
return "boolean"
|
||||||
|
case TypeLightUserData:
|
||||||
|
return "lightuserdata"
|
||||||
|
case TypeNumber:
|
||||||
|
return "number"
|
||||||
|
case TypeString:
|
||||||
|
return "string"
|
||||||
|
case TypeTable:
|
||||||
|
return "table"
|
||||||
|
case TypeFunction:
|
||||||
|
return "function"
|
||||||
|
case TypeUserData:
|
||||||
|
return "userdata"
|
||||||
|
case TypeThread:
|
||||||
|
return "thread"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
325
wrapper.go
Normal file
325
wrapper.go
Normal file
|
@ -0,0 +1,325 @@
|
||||||
|
package luajit
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CFLAGS: -I${SRCDIR}/luajit
|
||||||
|
#cgo windows LDFLAGS: -L${SRCDIR}/luajit -llua51
|
||||||
|
#cgo !windows LDFLAGS: -L${SRCDIR}/luajit -lluajit
|
||||||
|
|
||||||
|
#include <lua.h>
|
||||||
|
#include <lualib.h>
|
||||||
|
#include <lauxlib.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
void init_dll_paths(void);
|
||||||
|
|
||||||
|
static int do_string(lua_State *L, const char *s) {
|
||||||
|
int status = luaL_loadstring(L, s);
|
||||||
|
if (status) return status;
|
||||||
|
return lua_pcall(L, 0, LUA_MULTRET, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static int do_file(lua_State *L, const char *filename) {
|
||||||
|
int status = luaL_loadfile(L, filename);
|
||||||
|
if (status) return status;
|
||||||
|
return lua_pcall(L, 0, LUA_MULTRET, 0);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// State represents a Lua state with configurable stack safety
|
||||||
|
type State struct {
|
||||||
|
L *C.lua_State
|
||||||
|
safeStack bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSafe creates a new Lua state with full stack safety guarantees
|
||||||
|
func NewSafe() *State {
|
||||||
|
L := C.luaL_newstate()
|
||||||
|
if L == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
C.luaL_openlibs(L)
|
||||||
|
return &State{L: L, safeStack: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Lua state with minimal stack checking
|
||||||
|
func New() *State {
|
||||||
|
L := C.luaL_newstate()
|
||||||
|
if L == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
C.luaL_openlibs(L)
|
||||||
|
return &State{L: L, safeStack: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the Lua state
|
||||||
|
func (s *State) Close() {
|
||||||
|
if s.L != nil {
|
||||||
|
C.lua_close(s.L)
|
||||||
|
s.L = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoString executes a Lua string with appropriate stack management
|
||||||
|
func (s *State) DoString(str string) error {
|
||||||
|
cstr := C.CString(str)
|
||||||
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
|
|
||||||
|
if s.safeStack {
|
||||||
|
return stackGuardErr(s, func() error {
|
||||||
|
return s.safeCall(func() C.int {
|
||||||
|
return C.do_string(s.L, cstr)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
status := C.do_string(s.L, cstr)
|
||||||
|
if status != 0 {
|
||||||
|
return &LuaError{
|
||||||
|
Code: int(status),
|
||||||
|
Message: s.ToString(-1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushValue pushes a Go value onto the stack
|
||||||
|
func (s *State) PushValue(v interface{}) error {
|
||||||
|
if s.safeStack {
|
||||||
|
return stackGuardErr(s, func() error {
|
||||||
|
if err := s.checkStack(1); err != nil {
|
||||||
|
return fmt.Errorf("pushing value: %w", err)
|
||||||
|
}
|
||||||
|
return s.pushValueUnsafe(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return s.pushValueUnsafe(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) pushValueSafe(v interface{}) error {
|
||||||
|
if err := s.checkStack(1); err != nil {
|
||||||
|
return fmt.Errorf("pushing value: %w", err)
|
||||||
|
}
|
||||||
|
return s.pushValueUnsafe(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) pushValueUnsafe(v interface{}) error {
|
||||||
|
switch v := v.(type) {
|
||||||
|
case nil:
|
||||||
|
s.PushNil()
|
||||||
|
case bool:
|
||||||
|
s.PushBoolean(v)
|
||||||
|
case float64:
|
||||||
|
s.PushNumber(v)
|
||||||
|
case int:
|
||||||
|
s.PushNumber(float64(v))
|
||||||
|
case string:
|
||||||
|
s.PushString(v)
|
||||||
|
case map[string]interface{}:
|
||||||
|
// Special case: handle array stored in map
|
||||||
|
if arr, ok := v[""].([]float64); ok {
|
||||||
|
s.NewTable()
|
||||||
|
for i, elem := range arr {
|
||||||
|
s.PushNumber(float64(i + 1))
|
||||||
|
s.PushNumber(elem)
|
||||||
|
s.SetTable(-3)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.pushTableUnsafe(v)
|
||||||
|
case []float64:
|
||||||
|
s.NewTable()
|
||||||
|
for i, elem := range v {
|
||||||
|
s.PushNumber(float64(i + 1))
|
||||||
|
s.PushNumber(elem)
|
||||||
|
s.SetTable(-3)
|
||||||
|
}
|
||||||
|
case []interface{}:
|
||||||
|
s.NewTable()
|
||||||
|
for i, elem := range v {
|
||||||
|
s.PushNumber(float64(i + 1))
|
||||||
|
if err := s.pushValueUnsafe(elem); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.SetTable(-3)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported type: %T", v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToValue converts a Lua value to a Go value
|
||||||
|
func (s *State) ToValue(index int) (interface{}, error) {
|
||||||
|
if s.safeStack {
|
||||||
|
return stackGuardValue[interface{}](s, func() (interface{}, error) {
|
||||||
|
return s.toValueUnsafe(index)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return s.toValueUnsafe(index)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) toValueUnsafe(index int) (interface{}, error) {
|
||||||
|
switch s.GetType(index) {
|
||||||
|
case TypeNil:
|
||||||
|
return nil, nil
|
||||||
|
case TypeBoolean:
|
||||||
|
return s.ToBoolean(index), nil
|
||||||
|
case TypeNumber:
|
||||||
|
return s.ToNumber(index), nil
|
||||||
|
case TypeString:
|
||||||
|
return s.ToString(index), nil
|
||||||
|
case TypeTable:
|
||||||
|
if !s.IsTable(index) {
|
||||||
|
return nil, fmt.Errorf("not a table at index %d", index)
|
||||||
|
}
|
||||||
|
return s.toTableUnsafe(index)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple operations remain unchanged as they don't need stack protection
|
||||||
|
|
||||||
|
func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) }
|
||||||
|
func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction }
|
||||||
|
func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable }
|
||||||
|
func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 }
|
||||||
|
func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) }
|
||||||
|
func (s *State) ToString(index int) string {
|
||||||
|
return C.GoString(C.lua_tolstring(s.L, C.int(index), nil))
|
||||||
|
}
|
||||||
|
func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) }
|
||||||
|
func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) }
|
||||||
|
|
||||||
|
// Push operations
|
||||||
|
|
||||||
|
func (s *State) PushNil() { C.lua_pushnil(s.L) }
|
||||||
|
func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, C.int(bool2int(b))) }
|
||||||
|
func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.double(n)) }
|
||||||
|
func (s *State) PushString(str string) {
|
||||||
|
cstr := C.CString(str)
|
||||||
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
|
C.lua_pushstring(s.L, cstr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
func bool2int(b bool) int {
|
||||||
|
if b {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) absIndex(index int) int {
|
||||||
|
if index > 0 || index <= LUA_REGISTRYINDEX {
|
||||||
|
return index
|
||||||
|
}
|
||||||
|
return s.GetTop() + index + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetField sets a field in a table at the given index with cached absolute index
|
||||||
|
func (s *State) SetField(index int, key string) {
|
||||||
|
absIdx := index
|
||||||
|
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||||
|
absIdx = s.GetTop() + index + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
cstr := C.CString(key)
|
||||||
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
|
C.lua_setfield(s.L, C.int(absIdx), cstr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetField gets a field from a table with cached absolute index
|
||||||
|
func (s *State) GetField(index int, key string) {
|
||||||
|
absIdx := index
|
||||||
|
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||||
|
absIdx = s.GetTop() + index + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.safeStack {
|
||||||
|
if err := s.checkStack(1); err != nil {
|
||||||
|
s.PushNil()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cstr := C.CString(key)
|
||||||
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
|
C.lua_getfield(s.L, C.int(absIdx), cstr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGlobal gets a global variable and pushes it onto the stack
|
||||||
|
func (s *State) GetGlobal(name string) {
|
||||||
|
if s.safeStack {
|
||||||
|
if err := s.checkStack(1); err != nil {
|
||||||
|
s.PushNil()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cstr := C.CString(name)
|
||||||
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
|
C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGlobal sets a global variable from the value at the top of the stack
|
||||||
|
func (s *State) SetGlobal(name string) {
|
||||||
|
// SetGlobal doesn't need stack space checking as it pops the value
|
||||||
|
cstr := C.CString(name)
|
||||||
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
|
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes element with cached absolute index
|
||||||
|
func (s *State) Remove(index int) {
|
||||||
|
absIdx := index
|
||||||
|
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||||
|
absIdx = s.GetTop() + index + 1
|
||||||
|
}
|
||||||
|
C.lua_remove(s.L, C.int(absIdx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoFile executes a Lua file with appropriate stack management
|
||||||
|
func (s *State) DoFile(filename string) error {
|
||||||
|
cfilename := C.CString(filename)
|
||||||
|
defer C.free(unsafe.Pointer(cfilename))
|
||||||
|
|
||||||
|
if s.safeStack {
|
||||||
|
return stackGuardErr(s, func() error {
|
||||||
|
return s.safeCall(func() C.int {
|
||||||
|
return C.do_file(s.L, cfilename)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
status := C.do_file(s.L, cfilename)
|
||||||
|
if status != 0 {
|
||||||
|
return &LuaError{
|
||||||
|
Code: int(status),
|
||||||
|
Message: s.ToString(-1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) SetPackagePath(path string) error {
|
||||||
|
path = filepath.ToSlash(path)
|
||||||
|
if err := s.DoString(fmt.Sprintf(`package.path = %q`, path)); err != nil {
|
||||||
|
return fmt.Errorf("setting package.path: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) AddPackagePath(path string) error {
|
||||||
|
path = filepath.ToSlash(path)
|
||||||
|
if err := s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)); err != nil {
|
||||||
|
return fmt.Errorf("adding to package.path: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
276
wrapper_test.go
Normal file
276
wrapper_test.go
Normal file
|
@ -0,0 +1,276 @@
|
||||||
|
package luajit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stateFactory struct {
|
||||||
|
name string
|
||||||
|
new func() *State
|
||||||
|
}
|
||||||
|
|
||||||
|
var factories = []stateFactory{
|
||||||
|
{"unsafe", New},
|
||||||
|
{"safe", NewSafe},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
for _, f := range factories {
|
||||||
|
t.Run(f.name, func(t *testing.T) {
|
||||||
|
L := f.new()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDoString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"simple addition", "return 1 + 1", false},
|
||||||
|
{"set global", "test = 42", false},
|
||||||
|
{"syntax error", "this is not valid lua", true},
|
||||||
|
{"runtime error", "error('test error')", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range factories {
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||||
|
L := f.new()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
err := L.DoString(tt.code)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPushAndGetValues(t *testing.T) {
|
||||||
|
values := []struct {
|
||||||
|
name string
|
||||||
|
push func(*State)
|
||||||
|
check func(*State) error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string",
|
||||||
|
push: func(L *State) { L.PushString("hello") },
|
||||||
|
check: func(L *State) error {
|
||||||
|
if got := L.ToString(-1); got != "hello" {
|
||||||
|
return fmt.Errorf("got %q, want %q", got, "hello")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "number",
|
||||||
|
push: func(L *State) { L.PushNumber(42.5) },
|
||||||
|
check: func(L *State) error {
|
||||||
|
if got := L.ToNumber(-1); got != 42.5 {
|
||||||
|
return fmt.Errorf("got %f, want %f", got, 42.5)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "boolean",
|
||||||
|
push: func(L *State) { L.PushBoolean(true) },
|
||||||
|
check: func(L *State) error {
|
||||||
|
if got := L.ToBoolean(-1); !got {
|
||||||
|
return fmt.Errorf("got %v, want true", got)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil",
|
||||||
|
push: func(L *State) { L.PushNil() },
|
||||||
|
check: func(L *State) error {
|
||||||
|
if typ := L.GetType(-1); typ != TypeNil {
|
||||||
|
return fmt.Errorf("got type %v, want TypeNil", typ)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range factories {
|
||||||
|
for _, v := range values {
|
||||||
|
t.Run(f.name+"/"+v.name, func(t *testing.T) {
|
||||||
|
L := f.new()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
v.push(L)
|
||||||
|
if err := v.check(L); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStackManipulation(t *testing.T) {
|
||||||
|
for _, f := range factories {
|
||||||
|
t.Run(f.name, func(t *testing.T) {
|
||||||
|
L := f.new()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
// Push values
|
||||||
|
values := []string{"first", "second", "third"}
|
||||||
|
for _, v := range values {
|
||||||
|
L.PushString(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check size
|
||||||
|
if top := L.GetTop(); top != len(values) {
|
||||||
|
t.Errorf("stack size = %d, want %d", top, len(values))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pop one value
|
||||||
|
L.Pop(1)
|
||||||
|
|
||||||
|
// Check new top
|
||||||
|
if str := L.ToString(-1); str != "second" {
|
||||||
|
t.Errorf("top value = %q, want 'second'", str)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check new size
|
||||||
|
if top := L.GetTop(); top != len(values)-1 {
|
||||||
|
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobals(t *testing.T) {
|
||||||
|
for _, f := range factories {
|
||||||
|
t.Run(f.name, func(t *testing.T) {
|
||||||
|
L := f.new()
|
||||||
|
if L == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
// Test via Lua
|
||||||
|
if err := L.DoString(`globalVar = "test"`); err != nil {
|
||||||
|
t.Fatalf("DoString error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the global
|
||||||
|
L.GetGlobal("globalVar")
|
||||||
|
if str := L.ToString(-1); str != "test" {
|
||||||
|
t.Errorf("global value = %q, want 'test'", str)
|
||||||
|
}
|
||||||
|
L.Pop(1)
|
||||||
|
|
||||||
|
// Set and get via API
|
||||||
|
L.PushNumber(42)
|
||||||
|
L.SetGlobal("testNum")
|
||||||
|
|
||||||
|
L.GetGlobal("testNum")
|
||||||
|
if num := L.ToNumber(-1); num != 42 {
|
||||||
|
t.Errorf("global number = %f, want 42", num)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDoFile(t *testing.T) {
|
||||||
|
L := NewSafe()
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
// Create test file
|
||||||
|
content := []byte(`
|
||||||
|
function add(a, b)
|
||||||
|
return a + b
|
||||||
|
end
|
||||||
|
result = add(40, 2)
|
||||||
|
`)
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
filename := filepath.Join(tmpDir, "test.lua")
|
||||||
|
if err := os.WriteFile(filename, content, 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to create test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := L.DoFile(filename); err != nil {
|
||||||
|
t.Fatalf("DoFile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
L.GetGlobal("result")
|
||||||
|
if result := L.ToNumber(-1); result != 42 {
|
||||||
|
t.Errorf("Expected result=42, got %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAndPackagePath(t *testing.T) {
|
||||||
|
L := NewSafe()
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create module file
|
||||||
|
moduleContent := []byte(`
|
||||||
|
local M = {}
|
||||||
|
function M.multiply(a, b)
|
||||||
|
return a * b
|
||||||
|
end
|
||||||
|
return M
|
||||||
|
`)
|
||||||
|
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "mathmod.lua"), moduleContent, 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to create module file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add module path and test require
|
||||||
|
if err := L.AddPackagePath(filepath.Join(tmpDir, "?.lua")); err != nil {
|
||||||
|
t.Fatalf("AddPackagePath failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := L.DoString(`
|
||||||
|
local math = require("mathmod")
|
||||||
|
result = math.multiply(6, 7)
|
||||||
|
`); err != nil {
|
||||||
|
t.Fatalf("Failed to require module: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
L.GetGlobal("result")
|
||||||
|
if result := L.ToNumber(-1); result != 42 {
|
||||||
|
t.Errorf("Expected result=42, got %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetPackagePath(t *testing.T) {
|
||||||
|
L := NewSafe()
|
||||||
|
defer L.Close()
|
||||||
|
|
||||||
|
customPath := "./custom/?.lua"
|
||||||
|
if err := L.SetPackagePath(customPath); err != nil {
|
||||||
|
t.Fatalf("SetPackagePath failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
L.GetGlobal("package")
|
||||||
|
L.GetField(-1, "path")
|
||||||
|
if path := L.ToString(-1); path != customPath {
|
||||||
|
t.Errorf("Expected package.path=%q, got %q", customPath, path)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user