Implement the wrapper
This commit is contained in:
parent
2f6764aef1
commit
81501915e3
8
.idea/.gitignore
vendored
Normal file
8
.idea/.gitignore
vendored
Normal file
|
@ -0,0 +1,8 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
9
.idea/ljtg.iml
Normal file
9
.idea/ljtg.iml
Normal file
|
@ -0,0 +1,9 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="WEB_MODULE" version="4">
|
||||
<component name="Go" enabled="true" />
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
10
.idea/material_theme_project_new.xml
Normal file
10
.idea/material_theme_project_new.xml
Normal file
|
@ -0,0 +1,10 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="MaterialThemeProjectNewConfig">
|
||||
<option name="metadata">
|
||||
<MTProjectMetadataState>
|
||||
<option name="userId" value="3fdbffe2:19499e3ad4d:-7fff" />
|
||||
</MTProjectMetadataState>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
8
.idea/modules.xml
Normal file
8
.idea/modules.xml
Normal file
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/ljtg.iml" filepath="$PROJECT_DIR$/.idea/ljtg.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
8
.idea/vcs.xml
Normal file
8
.idea/vcs.xml
Normal file
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/luajit" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/luajit/luajit" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
98
functions.go
Normal file
98
functions.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package luajit
|
||||
|
||||
/*
|
||||
#include <lua.h>
|
||||
#include <lauxlib.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
extern int goFunctionWrapper(lua_State* L);
|
||||
|
||||
static int get_upvalue_index(int i) {
|
||||
return -10002 - i; // LUA_GLOBALSINDEX - i
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type GoFunction func(*State) int
|
||||
|
||||
var (
|
||||
functionRegistry = struct {
|
||||
sync.RWMutex
|
||||
funcs map[unsafe.Pointer]GoFunction
|
||||
}{
|
||||
funcs: make(map[unsafe.Pointer]GoFunction),
|
||||
}
|
||||
)
|
||||
|
||||
//export goFunctionWrapper
|
||||
func goFunctionWrapper(L *C.lua_State) C.int {
|
||||
state := &State{L: L, safeStack: true}
|
||||
|
||||
// Get upvalue using standard Lua 5.1 macro
|
||||
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
|
||||
if ptr == nil {
|
||||
state.PushString("error: function not found")
|
||||
return -1
|
||||
}
|
||||
|
||||
functionRegistry.RLock()
|
||||
fn, ok := functionRegistry.funcs[ptr]
|
||||
functionRegistry.RUnlock()
|
||||
|
||||
if !ok {
|
||||
state.PushString("error: function not found in registry")
|
||||
return -1
|
||||
}
|
||||
|
||||
result := fn(state)
|
||||
return C.int(result)
|
||||
}
|
||||
|
||||
func (s *State) PushGoFunction(fn GoFunction) error {
|
||||
// Push lightuserdata as upvalue and create closure
|
||||
ptr := C.malloc(1)
|
||||
if ptr == nil {
|
||||
return fmt.Errorf("failed to allocate memory for function pointer")
|
||||
}
|
||||
|
||||
functionRegistry.Lock()
|
||||
functionRegistry.funcs[ptr] = fn
|
||||
functionRegistry.Unlock()
|
||||
|
||||
C.lua_pushlightuserdata(s.L, ptr)
|
||||
C.lua_pushcclosure(s.L, (*[0]byte)(C.goFunctionWrapper), 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) RegisterGoFunction(name string, fn GoFunction) error {
|
||||
if err := s.PushGoFunction(fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) UnregisterGoFunction(name string) {
|
||||
s.PushNil()
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
|
||||
}
|
||||
|
||||
func (s *State) Cleanup() {
|
||||
functionRegistry.Lock()
|
||||
defer functionRegistry.Unlock()
|
||||
|
||||
for ptr := range functionRegistry.funcs {
|
||||
C.free(ptr)
|
||||
}
|
||||
functionRegistry.funcs = make(map[unsafe.Pointer]GoFunction)
|
||||
}
|
109
functions_test.go
Normal file
109
functions_test.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package luajit
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGoFunctions(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
|
||||
addFunc := func(s *State) int {
|
||||
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
|
||||
return 1
|
||||
}
|
||||
|
||||
if err := L.RegisterGoFunction("add", addFunc); err != nil {
|
||||
t.Fatalf("Failed to register function: %v", err)
|
||||
}
|
||||
|
||||
// Test basic function call
|
||||
if err := L.DoString("result = add(40, 2)"); err != nil {
|
||||
t.Fatalf("Failed to call function: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("result")
|
||||
if result := L.ToNumber(-1); result != 42 {
|
||||
t.Errorf("got %v, want 42", result)
|
||||
}
|
||||
L.Pop(1)
|
||||
|
||||
// Test multiple return values
|
||||
multiFunc := func(s *State) int {
|
||||
s.PushString("hello")
|
||||
s.PushNumber(42)
|
||||
s.PushBoolean(true)
|
||||
return 3
|
||||
}
|
||||
|
||||
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
|
||||
t.Fatalf("Failed to register multi function: %v", err)
|
||||
}
|
||||
|
||||
code := `
|
||||
a, b, c = multi()
|
||||
result = (a == "hello" and b == 42 and c == true)
|
||||
`
|
||||
|
||||
if err := L.DoString(code); err != nil {
|
||||
t.Fatalf("Failed to call multi function: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("result")
|
||||
if !L.ToBoolean(-1) {
|
||||
t.Error("Multiple return values test failed")
|
||||
}
|
||||
L.Pop(1)
|
||||
|
||||
// Test error handling
|
||||
errFunc := func(s *State) int {
|
||||
s.PushString("test error")
|
||||
return -1
|
||||
}
|
||||
|
||||
if err := L.RegisterGoFunction("err", errFunc); err != nil {
|
||||
t.Fatalf("Failed to register error function: %v", err)
|
||||
}
|
||||
|
||||
if err := L.DoString("err()"); err == nil {
|
||||
t.Error("Expected error from error function")
|
||||
}
|
||||
|
||||
// Test unregistering
|
||||
L.UnregisterGoFunction("add")
|
||||
if err := L.DoString("add(1, 2)"); err == nil {
|
||||
t.Error("Expected error calling unregistered function")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStackSafety(t *testing.T) {
|
||||
L := NewSafe()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
defer L.Cleanup()
|
||||
|
||||
// Test stack overflow protection
|
||||
overflowFunc := func(s *State) int {
|
||||
for i := 0; i < 100; i++ {
|
||||
s.PushNumber(float64(i))
|
||||
}
|
||||
s.PushString("done")
|
||||
return 101
|
||||
}
|
||||
|
||||
if err := L.RegisterGoFunction("overflow", overflowFunc); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := L.DoString("overflow()"); err != nil {
|
||||
t.Logf("Got expected error: %v", err)
|
||||
}
|
||||
}
|
1
luajit
Submodule
1
luajit
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit e4fd777d6ad41d338125b095abc98e4dd54c05d7
|
146
stack.go
Normal file
146
stack.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package luajit
|
||||
|
||||
/*
|
||||
#include <lua.h>
|
||||
#include <lauxlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import "fmt"
|
||||
|
||||
// LuaError represents an error from the Lua state
|
||||
type LuaError struct {
|
||||
Code int
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *LuaError) Error() string {
|
||||
return fmt.Sprintf("lua error (code=%d): %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Stack management constants from lua.h
|
||||
const (
|
||||
LUA_MINSTACK = 20 // Minimum Lua stack size
|
||||
LUA_MAXSTACK = 1000000 // Maximum Lua stack size
|
||||
LUA_REGISTRYINDEX = -10000 // Pseudo-index for the Lua registry
|
||||
LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table
|
||||
)
|
||||
|
||||
// checkStack ensures there is enough space on the Lua stack
|
||||
func (s *State) checkStack(n int) error {
|
||||
if C.lua_checkstack(s.L, C.int(n)) == 0 {
|
||||
return fmt.Errorf("stack overflow (cannot allocate %d slots)", n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// safeCall wraps a potentially dangerous C call with stack checking
|
||||
func (s *State) safeCall(f func() C.int) error {
|
||||
// Save current stack size
|
||||
top := s.GetTop()
|
||||
|
||||
// Ensure we have enough stack space (minimum 20 slots as per Lua standard)
|
||||
if err := s.checkStack(LUA_MINSTACK); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make the call
|
||||
status := f()
|
||||
|
||||
// Check for errors
|
||||
if status != 0 {
|
||||
err := &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
s.Pop(1) // Remove error message
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify stack integrity
|
||||
newTop := s.GetTop()
|
||||
if newTop < top {
|
||||
return fmt.Errorf("stack underflow: %d slots lost", top-newTop)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stackGuard wraps a function with stack checking and restoration
|
||||
func stackGuard[T any](s *State, f func() (T, error)) (T, error) {
|
||||
// Save current stack size
|
||||
top := s.GetTop()
|
||||
|
||||
// Run the protected function
|
||||
result, err := f()
|
||||
|
||||
// Restore stack size
|
||||
newTop := s.GetTop()
|
||||
if newTop > top {
|
||||
s.Pop(newTop - top)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// stackGuardValue executes a function that returns a value and error with stack protection
|
||||
func stackGuardValue[T any](s *State, f func() (T, error)) (T, error) {
|
||||
// Save current stack size
|
||||
top := s.GetTop()
|
||||
|
||||
// Run the protected function
|
||||
result, err := f()
|
||||
|
||||
// Restore stack size
|
||||
newTop := s.GetTop()
|
||||
if newTop > top {
|
||||
s.Pop(newTop - top)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// stackGuardErr executes a function that only returns an error with stack protection
|
||||
func stackGuardErr(s *State, f func() error) error {
|
||||
// Save current stack size
|
||||
top := s.GetTop()
|
||||
|
||||
// Run the protected function
|
||||
err := f()
|
||||
|
||||
// Restore stack size
|
||||
newTop := s.GetTop()
|
||||
if newTop > top {
|
||||
s.Pop(newTop - top)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// getStackTrace returns the current Lua stack trace
|
||||
func (s *State) getStackTrace() string {
|
||||
// Push debug.traceback function
|
||||
s.GetGlobal("debug")
|
||||
if !s.IsTable(-1) {
|
||||
s.Pop(1)
|
||||
return "stack trace not available (debug module not loaded)"
|
||||
}
|
||||
|
||||
s.GetField(-1, "traceback")
|
||||
if !s.IsFunction(-1) {
|
||||
s.Pop(2)
|
||||
return "stack trace not available (debug.traceback not found)"
|
||||
}
|
||||
|
||||
// Call debug.traceback
|
||||
if err := s.safeCall(func() C.int {
|
||||
return C.lua_pcall(s.L, 0, 1, 0)
|
||||
}); err != nil {
|
||||
return fmt.Sprintf("error getting stack trace: %v", err)
|
||||
}
|
||||
|
||||
// Get the resulting string
|
||||
trace := s.ToString(-1)
|
||||
s.Pop(1) // Remove the trace string
|
||||
|
||||
return trace
|
||||
}
|
177
table.go
Normal file
177
table.go
Normal file
|
@ -0,0 +1,177 @@
|
|||
package luajit
|
||||
|
||||
/*
|
||||
#include <lua.h>
|
||||
#include <lualib.h>
|
||||
#include <lauxlib.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
static int get_table_length(lua_State *L, int index) {
|
||||
return lua_objlen(L, index);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// TableValue represents any value that can be stored in a Lua table
|
||||
type TableValue interface {
|
||||
~string | ~float64 | ~bool | ~int | ~map[string]interface{} | ~[]float64 | ~[]interface{}
|
||||
}
|
||||
|
||||
func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) }
|
||||
|
||||
// ToTable converts a Lua table to a Go map
|
||||
func (s *State) ToTable(index int) (map[string]interface{}, error) {
|
||||
if s.safeStack {
|
||||
return stackGuardValue[map[string]interface{}](s, func() (map[string]interface{}, error) {
|
||||
if !s.IsTable(index) {
|
||||
return nil, fmt.Errorf("not a table at index %d", index)
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
})
|
||||
}
|
||||
if !s.IsTable(index) {
|
||||
return nil, fmt.Errorf("not a table at index %d", index)
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
}
|
||||
|
||||
func (s *State) pushTableSafe(table map[string]interface{}) error {
|
||||
size := 2
|
||||
if err := s.checkStack(size); err != nil {
|
||||
return fmt.Errorf("insufficient stack space: %w", err)
|
||||
}
|
||||
|
||||
s.NewTable()
|
||||
for k, v := range table {
|
||||
if err := s.pushValueSafe(v); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetField(-2, k)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) pushTableUnsafe(table map[string]interface{}) error {
|
||||
s.NewTable()
|
||||
for k, v := range table {
|
||||
if err := s.pushValueUnsafe(v); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetField(-2, k)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) toTableSafe(index int) (map[string]interface{}, error) {
|
||||
if err := s.checkStack(2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
}
|
||||
|
||||
func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
|
||||
absIdx := s.absIndex(index)
|
||||
table := make(map[string]interface{})
|
||||
|
||||
// Check if it's an array-like table
|
||||
length := s.GetTableLength(absIdx)
|
||||
if length > 0 {
|
||||
array := make([]float64, length)
|
||||
isArray := true
|
||||
|
||||
// Try to convert to array
|
||||
for i := 1; i <= length; i++ {
|
||||
s.PushNumber(float64(i))
|
||||
s.GetTable(absIdx)
|
||||
if s.GetType(-1) != TypeNumber {
|
||||
isArray = false
|
||||
s.Pop(1)
|
||||
break
|
||||
}
|
||||
array[i-1] = s.ToNumber(-1)
|
||||
s.Pop(1)
|
||||
}
|
||||
|
||||
if isArray {
|
||||
return map[string]interface{}{"": array}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Handle regular table
|
||||
s.PushNil()
|
||||
for C.lua_next(s.L, C.int(absIdx)) != 0 {
|
||||
key := ""
|
||||
valueType := C.lua_type(s.L, -2)
|
||||
if valueType == C.LUA_TSTRING {
|
||||
key = s.ToString(-2)
|
||||
} else if valueType == C.LUA_TNUMBER {
|
||||
key = fmt.Sprintf("%g", s.ToNumber(-2))
|
||||
}
|
||||
|
||||
value, err := s.toValueUnsafe(-1)
|
||||
if err != nil {
|
||||
s.Pop(1)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle nested array case
|
||||
if m, ok := value.(map[string]interface{}); ok {
|
||||
if arr, ok := m[""]; ok {
|
||||
value = arr
|
||||
}
|
||||
}
|
||||
|
||||
table[key] = value
|
||||
s.Pop(1)
|
||||
}
|
||||
|
||||
return table, nil
|
||||
}
|
||||
|
||||
// NewTable creates a new table and pushes it onto the stack
|
||||
func (s *State) NewTable() {
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
// Since we can't return an error, we'll push nil instead
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
C.lua_createtable(s.L, 0, 0)
|
||||
}
|
||||
|
||||
// SetTable sets a table field with cached absolute index
|
||||
func (s *State) SetTable(index int) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
C.lua_settable(s.L, C.int(absIdx))
|
||||
}
|
||||
|
||||
// GetTable gets a table field with cached absolute index
|
||||
func (s *State) GetTable(index int) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
C.lua_gettable(s.L, C.int(absIdx))
|
||||
}
|
||||
|
||||
// PushTable pushes a Go map onto the Lua stack as a table with stack checking
|
||||
func (s *State) PushTable(table map[string]interface{}) error {
|
||||
if s.safeStack {
|
||||
return s.pushTableSafe(table)
|
||||
}
|
||||
return s.pushTableUnsafe(table)
|
||||
}
|
97
table_test.go
Normal file
97
table_test.go
Normal file
|
@ -0,0 +1,97 @@
|
|||
package luajit
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTableOperations(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
data: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "primitives",
|
||||
data: map[string]interface{}{
|
||||
"str": "hello",
|
||||
"num": 42.0,
|
||||
"bool": true,
|
||||
"array": []float64{1.1, 2.2, 3.3},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested",
|
||||
data: map[string]interface{}{
|
||||
"nested": map[string]interface{}{
|
||||
"value": 123.0,
|
||||
"array": []float64{4.4, 5.5},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, tt := range tests {
|
||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
if err := L.PushTable(tt.data); err != nil {
|
||||
t.Fatalf("PushTable() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := L.ToTable(-1)
|
||||
if err != nil {
|
||||
t.Fatalf("ToTable() error = %v", err)
|
||||
}
|
||||
|
||||
if !tablesEqual(got, tt.data) {
|
||||
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func tablesEqual(a, b map[string]interface{}) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
for k, v1 := range a {
|
||||
v2, ok := b[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch v1 := v1.(type) {
|
||||
case map[string]interface{}:
|
||||
v2, ok := v2.(map[string]interface{})
|
||||
if !ok || !tablesEqual(v1, v2) {
|
||||
return false
|
||||
}
|
||||
case []float64:
|
||||
v2, ok := v2.([]float64)
|
||||
if !ok || len(v1) != len(v2) {
|
||||
return false
|
||||
}
|
||||
for i := range v1 {
|
||||
if math.Abs(v1[i]-v2[i]) > 1e-10 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
default:
|
||||
if v1 != v2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
51
types.go
Normal file
51
types.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package luajit
|
||||
|
||||
/*
|
||||
#include <lua.h>
|
||||
*/
|
||||
import "C"
|
||||
|
||||
// LuaType represents Lua value types
|
||||
type LuaType int
|
||||
|
||||
const (
|
||||
// These constants must match lua.h's LUA_T* values
|
||||
TypeNone LuaType = -1
|
||||
TypeNil LuaType = 0
|
||||
TypeBoolean LuaType = 1
|
||||
TypeLightUserData LuaType = 2
|
||||
TypeNumber LuaType = 3
|
||||
TypeString LuaType = 4
|
||||
TypeTable LuaType = 5
|
||||
TypeFunction LuaType = 6
|
||||
TypeUserData LuaType = 7
|
||||
TypeThread LuaType = 8
|
||||
)
|
||||
|
||||
// String returns the string representation of the Lua type
|
||||
func (t LuaType) String() string {
|
||||
switch t {
|
||||
case TypeNone:
|
||||
return "none"
|
||||
case TypeNil:
|
||||
return "nil"
|
||||
case TypeBoolean:
|
||||
return "boolean"
|
||||
case TypeLightUserData:
|
||||
return "lightuserdata"
|
||||
case TypeNumber:
|
||||
return "number"
|
||||
case TypeString:
|
||||
return "string"
|
||||
case TypeTable:
|
||||
return "table"
|
||||
case TypeFunction:
|
||||
return "function"
|
||||
case TypeUserData:
|
||||
return "userdata"
|
||||
case TypeThread:
|
||||
return "thread"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
325
wrapper.go
Normal file
325
wrapper.go
Normal file
|
@ -0,0 +1,325 @@
|
|||
package luajit
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I${SRCDIR}/luajit
|
||||
#cgo windows LDFLAGS: -L${SRCDIR}/luajit -llua51
|
||||
#cgo !windows LDFLAGS: -L${SRCDIR}/luajit -lluajit
|
||||
|
||||
#include <lua.h>
|
||||
#include <lualib.h>
|
||||
#include <lauxlib.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
void init_dll_paths(void);
|
||||
|
||||
static int do_string(lua_State *L, const char *s) {
|
||||
int status = luaL_loadstring(L, s);
|
||||
if (status) return status;
|
||||
return lua_pcall(L, 0, LUA_MULTRET, 0);
|
||||
}
|
||||
|
||||
|
||||
static int do_file(lua_State *L, const char *filename) {
|
||||
int status = luaL_loadfile(L, filename);
|
||||
if (status) return status;
|
||||
return lua_pcall(L, 0, LUA_MULTRET, 0);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// State represents a Lua state with configurable stack safety
|
||||
type State struct {
|
||||
L *C.lua_State
|
||||
safeStack bool
|
||||
}
|
||||
|
||||
// NewSafe creates a new Lua state with full stack safety guarantees
|
||||
func NewSafe() *State {
|
||||
L := C.luaL_newstate()
|
||||
if L == nil {
|
||||
return nil
|
||||
}
|
||||
C.luaL_openlibs(L)
|
||||
return &State{L: L, safeStack: true}
|
||||
}
|
||||
|
||||
// New creates a new Lua state with minimal stack checking
|
||||
func New() *State {
|
||||
L := C.luaL_newstate()
|
||||
if L == nil {
|
||||
return nil
|
||||
}
|
||||
C.luaL_openlibs(L)
|
||||
return &State{L: L, safeStack: false}
|
||||
}
|
||||
|
||||
// Close closes the Lua state
|
||||
func (s *State) Close() {
|
||||
if s.L != nil {
|
||||
C.lua_close(s.L)
|
||||
s.L = nil
|
||||
}
|
||||
}
|
||||
|
||||
// DoString executes a Lua string with appropriate stack management
|
||||
func (s *State) DoString(str string) error {
|
||||
cstr := C.CString(str)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
return s.safeCall(func() C.int {
|
||||
return C.do_string(s.L, cstr)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
status := C.do_string(s.L, cstr)
|
||||
if status != 0 {
|
||||
return &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushValue pushes a Go value onto the stack
|
||||
func (s *State) PushValue(v interface{}) error {
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
return fmt.Errorf("pushing value: %w", err)
|
||||
}
|
||||
return s.pushValueUnsafe(v)
|
||||
})
|
||||
}
|
||||
return s.pushValueUnsafe(v)
|
||||
}
|
||||
|
||||
func (s *State) pushValueSafe(v interface{}) error {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
return fmt.Errorf("pushing value: %w", err)
|
||||
}
|
||||
return s.pushValueUnsafe(v)
|
||||
}
|
||||
|
||||
func (s *State) pushValueUnsafe(v interface{}) error {
|
||||
switch v := v.(type) {
|
||||
case nil:
|
||||
s.PushNil()
|
||||
case bool:
|
||||
s.PushBoolean(v)
|
||||
case float64:
|
||||
s.PushNumber(v)
|
||||
case int:
|
||||
s.PushNumber(float64(v))
|
||||
case string:
|
||||
s.PushString(v)
|
||||
case map[string]interface{}:
|
||||
// Special case: handle array stored in map
|
||||
if arr, ok := v[""].([]float64); ok {
|
||||
s.NewTable()
|
||||
for i, elem := range arr {
|
||||
s.PushNumber(float64(i + 1))
|
||||
s.PushNumber(elem)
|
||||
s.SetTable(-3)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return s.pushTableUnsafe(v)
|
||||
case []float64:
|
||||
s.NewTable()
|
||||
for i, elem := range v {
|
||||
s.PushNumber(float64(i + 1))
|
||||
s.PushNumber(elem)
|
||||
s.SetTable(-3)
|
||||
}
|
||||
case []interface{}:
|
||||
s.NewTable()
|
||||
for i, elem := range v {
|
||||
s.PushNumber(float64(i + 1))
|
||||
if err := s.pushValueUnsafe(elem); err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetTable(-3)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported type: %T", v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToValue converts a Lua value to a Go value
|
||||
func (s *State) ToValue(index int) (interface{}, error) {
|
||||
if s.safeStack {
|
||||
return stackGuardValue[interface{}](s, func() (interface{}, error) {
|
||||
return s.toValueUnsafe(index)
|
||||
})
|
||||
}
|
||||
return s.toValueUnsafe(index)
|
||||
}
|
||||
|
||||
func (s *State) toValueUnsafe(index int) (interface{}, error) {
|
||||
switch s.GetType(index) {
|
||||
case TypeNil:
|
||||
return nil, nil
|
||||
case TypeBoolean:
|
||||
return s.ToBoolean(index), nil
|
||||
case TypeNumber:
|
||||
return s.ToNumber(index), nil
|
||||
case TypeString:
|
||||
return s.ToString(index), nil
|
||||
case TypeTable:
|
||||
if !s.IsTable(index) {
|
||||
return nil, fmt.Errorf("not a table at index %d", index)
|
||||
}
|
||||
return s.toTableUnsafe(index)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
|
||||
}
|
||||
}
|
||||
|
||||
// Simple operations remain unchanged as they don't need stack protection
|
||||
|
||||
func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) }
|
||||
func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction }
|
||||
func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable }
|
||||
func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 }
|
||||
func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) }
|
||||
func (s *State) ToString(index int) string {
|
||||
return C.GoString(C.lua_tolstring(s.L, C.int(index), nil))
|
||||
}
|
||||
func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) }
|
||||
func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) }
|
||||
|
||||
// Push operations
|
||||
|
||||
func (s *State) PushNil() { C.lua_pushnil(s.L) }
|
||||
func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, C.int(bool2int(b))) }
|
||||
func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.double(n)) }
|
||||
func (s *State) PushString(str string) {
|
||||
cstr := C.CString(str)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_pushstring(s.L, cstr)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func bool2int(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *State) absIndex(index int) int {
|
||||
if index > 0 || index <= LUA_REGISTRYINDEX {
|
||||
return index
|
||||
}
|
||||
return s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
// SetField sets a field in a table at the given index with cached absolute index
|
||||
func (s *State) SetField(index int, key string) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
cstr := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_setfield(s.L, C.int(absIdx), cstr)
|
||||
}
|
||||
|
||||
// GetField gets a field from a table with cached absolute index
|
||||
func (s *State) GetField(index int, key string) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cstr := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_getfield(s.L, C.int(absIdx), cstr)
|
||||
}
|
||||
|
||||
// GetGlobal gets a global variable and pushes it onto the stack
|
||||
func (s *State) GetGlobal(name string) {
|
||||
if s.safeStack {
|
||||
if err := s.checkStack(1); err != nil {
|
||||
s.PushNil()
|
||||
return
|
||||
}
|
||||
}
|
||||
cstr := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
||||
}
|
||||
|
||||
// SetGlobal sets a global variable from the value at the top of the stack
|
||||
func (s *State) SetGlobal(name string) {
|
||||
// SetGlobal doesn't need stack space checking as it pops the value
|
||||
cstr := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr)
|
||||
}
|
||||
|
||||
// Remove removes element with cached absolute index
|
||||
func (s *State) Remove(index int) {
|
||||
absIdx := index
|
||||
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
|
||||
absIdx = s.GetTop() + index + 1
|
||||
}
|
||||
C.lua_remove(s.L, C.int(absIdx))
|
||||
}
|
||||
|
||||
// DoFile executes a Lua file with appropriate stack management
|
||||
func (s *State) DoFile(filename string) error {
|
||||
cfilename := C.CString(filename)
|
||||
defer C.free(unsafe.Pointer(cfilename))
|
||||
|
||||
if s.safeStack {
|
||||
return stackGuardErr(s, func() error {
|
||||
return s.safeCall(func() C.int {
|
||||
return C.do_file(s.L, cfilename)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
status := C.do_file(s.L, cfilename)
|
||||
if status != 0 {
|
||||
return &LuaError{
|
||||
Code: int(status),
|
||||
Message: s.ToString(-1),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) SetPackagePath(path string) error {
|
||||
path = filepath.ToSlash(path)
|
||||
if err := s.DoString(fmt.Sprintf(`package.path = %q`, path)); err != nil {
|
||||
return fmt.Errorf("setting package.path: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) AddPackagePath(path string) error {
|
||||
path = filepath.ToSlash(path)
|
||||
if err := s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)); err != nil {
|
||||
return fmt.Errorf("adding to package.path: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
276
wrapper_test.go
Normal file
276
wrapper_test.go
Normal file
|
@ -0,0 +1,276 @@
|
|||
package luajit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type stateFactory struct {
|
||||
name string
|
||||
new func() *State
|
||||
}
|
||||
|
||||
var factories = []stateFactory{
|
||||
{"unsafe", New},
|
||||
{"safe", NewSafe},
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
wantErr bool
|
||||
}{
|
||||
{"simple addition", "return 1 + 1", false},
|
||||
{"set global", "test = 42", false},
|
||||
{"syntax error", "this is not valid lua", true},
|
||||
{"runtime error", "error('test error')", true},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, tt := range tests {
|
||||
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
err := L.DoString(tt.code)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushAndGetValues(t *testing.T) {
|
||||
values := []struct {
|
||||
name string
|
||||
push func(*State)
|
||||
check func(*State) error
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
push: func(L *State) { L.PushString("hello") },
|
||||
check: func(L *State) error {
|
||||
if got := L.ToString(-1); got != "hello" {
|
||||
return fmt.Errorf("got %q, want %q", got, "hello")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number",
|
||||
push: func(L *State) { L.PushNumber(42.5) },
|
||||
check: func(L *State) error {
|
||||
if got := L.ToNumber(-1); got != 42.5 {
|
||||
return fmt.Errorf("got %f, want %f", got, 42.5)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "boolean",
|
||||
push: func(L *State) { L.PushBoolean(true) },
|
||||
check: func(L *State) error {
|
||||
if got := L.ToBoolean(-1); !got {
|
||||
return fmt.Errorf("got %v, want true", got)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
push: func(L *State) { L.PushNil() },
|
||||
check: func(L *State) error {
|
||||
if typ := L.GetType(-1); typ != TypeNil {
|
||||
return fmt.Errorf("got type %v, want TypeNil", typ)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range factories {
|
||||
for _, v := range values {
|
||||
t.Run(f.name+"/"+v.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
v.push(L)
|
||||
if err := v.check(L); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStackManipulation(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Push values
|
||||
values := []string{"first", "second", "third"}
|
||||
for _, v := range values {
|
||||
L.PushString(v)
|
||||
}
|
||||
|
||||
// Check size
|
||||
if top := L.GetTop(); top != len(values) {
|
||||
t.Errorf("stack size = %d, want %d", top, len(values))
|
||||
}
|
||||
|
||||
// Pop one value
|
||||
L.Pop(1)
|
||||
|
||||
// Check new top
|
||||
if str := L.ToString(-1); str != "second" {
|
||||
t.Errorf("top value = %q, want 'second'", str)
|
||||
}
|
||||
|
||||
// Check new size
|
||||
if top := L.GetTop(); top != len(values)-1 {
|
||||
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobals(t *testing.T) {
|
||||
for _, f := range factories {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
L := f.new()
|
||||
if L == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer L.Close()
|
||||
|
||||
// Test via Lua
|
||||
if err := L.DoString(`globalVar = "test"`); err != nil {
|
||||
t.Fatalf("DoString error: %v", err)
|
||||
}
|
||||
|
||||
// Get the global
|
||||
L.GetGlobal("globalVar")
|
||||
if str := L.ToString(-1); str != "test" {
|
||||
t.Errorf("global value = %q, want 'test'", str)
|
||||
}
|
||||
L.Pop(1)
|
||||
|
||||
// Set and get via API
|
||||
L.PushNumber(42)
|
||||
L.SetGlobal("testNum")
|
||||
|
||||
L.GetGlobal("testNum")
|
||||
if num := L.ToNumber(-1); num != 42 {
|
||||
t.Errorf("global number = %f, want 42", num)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoFile(t *testing.T) {
|
||||
L := NewSafe()
|
||||
defer L.Close()
|
||||
|
||||
// Create test file
|
||||
content := []byte(`
|
||||
function add(a, b)
|
||||
return a + b
|
||||
end
|
||||
result = add(40, 2)
|
||||
`)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
filename := filepath.Join(tmpDir, "test.lua")
|
||||
if err := os.WriteFile(filename, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
if err := L.DoFile(filename); err != nil {
|
||||
t.Fatalf("DoFile failed: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("result")
|
||||
if result := L.ToNumber(-1); result != 42 {
|
||||
t.Errorf("Expected result=42, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAndPackagePath(t *testing.T) {
|
||||
L := NewSafe()
|
||||
defer L.Close()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create module file
|
||||
moduleContent := []byte(`
|
||||
local M = {}
|
||||
function M.multiply(a, b)
|
||||
return a * b
|
||||
end
|
||||
return M
|
||||
`)
|
||||
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "mathmod.lua"), moduleContent, 0644); err != nil {
|
||||
t.Fatalf("Failed to create module file: %v", err)
|
||||
}
|
||||
|
||||
// Add module path and test require
|
||||
if err := L.AddPackagePath(filepath.Join(tmpDir, "?.lua")); err != nil {
|
||||
t.Fatalf("AddPackagePath failed: %v", err)
|
||||
}
|
||||
|
||||
if err := L.DoString(`
|
||||
local math = require("mathmod")
|
||||
result = math.multiply(6, 7)
|
||||
`); err != nil {
|
||||
t.Fatalf("Failed to require module: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("result")
|
||||
if result := L.ToNumber(-1); result != 42 {
|
||||
t.Errorf("Expected result=42, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetPackagePath(t *testing.T) {
|
||||
L := NewSafe()
|
||||
defer L.Close()
|
||||
|
||||
customPath := "./custom/?.lua"
|
||||
if err := L.SetPackagePath(customPath); err != nil {
|
||||
t.Fatalf("SetPackagePath failed: %v", err)
|
||||
}
|
||||
|
||||
L.GetGlobal("package")
|
||||
L.GetField(-1, "path")
|
||||
if path := L.ToString(-1); path != customPath {
|
||||
t.Errorf("Expected package.path=%q, got %q", customPath, path)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user