Implement the wrapper

This commit is contained in:
Sky Johnson 2025-01-24 19:53:09 -06:00
parent 2f6764aef1
commit 81501915e3
15 changed files with 1326 additions and 0 deletions

8
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

9
.idea/ljtg.iml Normal file
View File

@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go" enabled="true" />
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MaterialThemeProjectNewConfig">
<option name="metadata">
<MTProjectMetadataState>
<option name="userId" value="3fdbffe2:19499e3ad4d:-7fff" />
</MTProjectMetadataState>
</option>
</component>
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/ljtg.iml" filepath="$PROJECT_DIR$/.idea/ljtg.iml" />
</modules>
</component>
</project>

8
.idea/vcs.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/luajit" vcs="Git" />
<mapping directory="$PROJECT_DIR$/luajit/luajit" vcs="Git" />
</component>
</project>

98
functions.go Normal file
View File

@ -0,0 +1,98 @@
package luajit
/*
#include <lua.h>
#include <lauxlib.h>
#include <stdlib.h>
extern int goFunctionWrapper(lua_State* L);
static int get_upvalue_index(int i) {
return -10002 - i; // LUA_GLOBALSINDEX - i
}
*/
import "C"
import (
"fmt"
"sync"
"unsafe"
)
type GoFunction func(*State) int
var (
functionRegistry = struct {
sync.RWMutex
funcs map[unsafe.Pointer]GoFunction
}{
funcs: make(map[unsafe.Pointer]GoFunction),
}
)
//export goFunctionWrapper
func goFunctionWrapper(L *C.lua_State) C.int {
state := &State{L: L, safeStack: true}
// Get upvalue using standard Lua 5.1 macro
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
if ptr == nil {
state.PushString("error: function not found")
return -1
}
functionRegistry.RLock()
fn, ok := functionRegistry.funcs[ptr]
functionRegistry.RUnlock()
if !ok {
state.PushString("error: function not found in registry")
return -1
}
result := fn(state)
return C.int(result)
}
func (s *State) PushGoFunction(fn GoFunction) error {
// Push lightuserdata as upvalue and create closure
ptr := C.malloc(1)
if ptr == nil {
return fmt.Errorf("failed to allocate memory for function pointer")
}
functionRegistry.Lock()
functionRegistry.funcs[ptr] = fn
functionRegistry.Unlock()
C.lua_pushlightuserdata(s.L, ptr)
C.lua_pushcclosure(s.L, (*[0]byte)(C.goFunctionWrapper), 1)
return nil
}
func (s *State) RegisterGoFunction(name string, fn GoFunction) error {
if err := s.PushGoFunction(fn); err != nil {
return err
}
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
return nil
}
func (s *State) UnregisterGoFunction(name string) {
s.PushNil()
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
}
func (s *State) Cleanup() {
functionRegistry.Lock()
defer functionRegistry.Unlock()
for ptr := range functionRegistry.funcs {
C.free(ptr)
}
functionRegistry.funcs = make(map[unsafe.Pointer]GoFunction)
}

109
functions_test.go Normal file
View File

@ -0,0 +1,109 @@
package luajit
import "testing"
func TestGoFunctions(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
addFunc := func(s *State) int {
s.PushNumber(s.ToNumber(1) + s.ToNumber(2))
return 1
}
if err := L.RegisterGoFunction("add", addFunc); err != nil {
t.Fatalf("Failed to register function: %v", err)
}
// Test basic function call
if err := L.DoString("result = add(40, 2)"); err != nil {
t.Fatalf("Failed to call function: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("got %v, want 42", result)
}
L.Pop(1)
// Test multiple return values
multiFunc := func(s *State) int {
s.PushString("hello")
s.PushNumber(42)
s.PushBoolean(true)
return 3
}
if err := L.RegisterGoFunction("multi", multiFunc); err != nil {
t.Fatalf("Failed to register multi function: %v", err)
}
code := `
a, b, c = multi()
result = (a == "hello" and b == 42 and c == true)
`
if err := L.DoString(code); err != nil {
t.Fatalf("Failed to call multi function: %v", err)
}
L.GetGlobal("result")
if !L.ToBoolean(-1) {
t.Error("Multiple return values test failed")
}
L.Pop(1)
// Test error handling
errFunc := func(s *State) int {
s.PushString("test error")
return -1
}
if err := L.RegisterGoFunction("err", errFunc); err != nil {
t.Fatalf("Failed to register error function: %v", err)
}
if err := L.DoString("err()"); err == nil {
t.Error("Expected error from error function")
}
// Test unregistering
L.UnregisterGoFunction("add")
if err := L.DoString("add(1, 2)"); err == nil {
t.Error("Expected error calling unregistered function")
}
})
}
}
func TestStackSafety(t *testing.T) {
L := NewSafe()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
defer L.Cleanup()
// Test stack overflow protection
overflowFunc := func(s *State) int {
for i := 0; i < 100; i++ {
s.PushNumber(float64(i))
}
s.PushString("done")
return 101
}
if err := L.RegisterGoFunction("overflow", overflowFunc); err != nil {
t.Fatal(err)
}
if err := L.DoString("overflow()"); err != nil {
t.Logf("Got expected error: %v", err)
}
}

3
go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.sharkk.net/Sky/LuaJIT-to-Go
go 1.23.4

1
luajit Submodule

@ -0,0 +1 @@
Subproject commit e4fd777d6ad41d338125b095abc98e4dd54c05d7

146
stack.go Normal file
View File

@ -0,0 +1,146 @@
package luajit
/*
#include <lua.h>
#include <lauxlib.h>
*/
import "C"
import "fmt"
// LuaError represents an error from the Lua state
type LuaError struct {
Code int
Message string
}
func (e *LuaError) Error() string {
return fmt.Sprintf("lua error (code=%d): %s", e.Code, e.Message)
}
// Stack management constants from lua.h
const (
LUA_MINSTACK = 20 // Minimum Lua stack size
LUA_MAXSTACK = 1000000 // Maximum Lua stack size
LUA_REGISTRYINDEX = -10000 // Pseudo-index for the Lua registry
LUA_GLOBALSINDEX = -10002 // Pseudo-index for globals table
)
// checkStack ensures there is enough space on the Lua stack
func (s *State) checkStack(n int) error {
if C.lua_checkstack(s.L, C.int(n)) == 0 {
return fmt.Errorf("stack overflow (cannot allocate %d slots)", n)
}
return nil
}
// safeCall wraps a potentially dangerous C call with stack checking
func (s *State) safeCall(f func() C.int) error {
// Save current stack size
top := s.GetTop()
// Ensure we have enough stack space (minimum 20 slots as per Lua standard)
if err := s.checkStack(LUA_MINSTACK); err != nil {
return err
}
// Make the call
status := f()
// Check for errors
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1) // Remove error message
return err
}
// Verify stack integrity
newTop := s.GetTop()
if newTop < top {
return fmt.Errorf("stack underflow: %d slots lost", top-newTop)
}
return nil
}
// stackGuard wraps a function with stack checking and restoration
func stackGuard[T any](s *State, f func() (T, error)) (T, error) {
// Save current stack size
top := s.GetTop()
// Run the protected function
result, err := f()
// Restore stack size
newTop := s.GetTop()
if newTop > top {
s.Pop(newTop - top)
}
return result, err
}
// stackGuardValue executes a function that returns a value and error with stack protection
func stackGuardValue[T any](s *State, f func() (T, error)) (T, error) {
// Save current stack size
top := s.GetTop()
// Run the protected function
result, err := f()
// Restore stack size
newTop := s.GetTop()
if newTop > top {
s.Pop(newTop - top)
}
return result, err
}
// stackGuardErr executes a function that only returns an error with stack protection
func stackGuardErr(s *State, f func() error) error {
// Save current stack size
top := s.GetTop()
// Run the protected function
err := f()
// Restore stack size
newTop := s.GetTop()
if newTop > top {
s.Pop(newTop - top)
}
return err
}
// getStackTrace returns the current Lua stack trace
func (s *State) getStackTrace() string {
// Push debug.traceback function
s.GetGlobal("debug")
if !s.IsTable(-1) {
s.Pop(1)
return "stack trace not available (debug module not loaded)"
}
s.GetField(-1, "traceback")
if !s.IsFunction(-1) {
s.Pop(2)
return "stack trace not available (debug.traceback not found)"
}
// Call debug.traceback
if err := s.safeCall(func() C.int {
return C.lua_pcall(s.L, 0, 1, 0)
}); err != nil {
return fmt.Sprintf("error getting stack trace: %v", err)
}
// Get the resulting string
trace := s.ToString(-1)
s.Pop(1) // Remove the trace string
return trace
}

177
table.go Normal file
View File

@ -0,0 +1,177 @@
package luajit
/*
#include <lua.h>
#include <lualib.h>
#include <lauxlib.h>
#include <stdlib.h>
static int get_table_length(lua_State *L, int index) {
return lua_objlen(L, index);
}
*/
import "C"
import (
"fmt"
)
// TableValue represents any value that can be stored in a Lua table
type TableValue interface {
~string | ~float64 | ~bool | ~int | ~map[string]interface{} | ~[]float64 | ~[]interface{}
}
func (s *State) GetTableLength(index int) int { return int(C.get_table_length(s.L, C.int(index))) }
// ToTable converts a Lua table to a Go map
func (s *State) ToTable(index int) (map[string]interface{}, error) {
if s.safeStack {
return stackGuardValue[map[string]interface{}](s, func() (map[string]interface{}, error) {
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
}
return s.toTableUnsafe(index)
})
}
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
}
return s.toTableUnsafe(index)
}
func (s *State) pushTableSafe(table map[string]interface{}) error {
size := 2
if err := s.checkStack(size); err != nil {
return fmt.Errorf("insufficient stack space: %w", err)
}
s.NewTable()
for k, v := range table {
if err := s.pushValueSafe(v); err != nil {
return err
}
s.SetField(-2, k)
}
return nil
}
func (s *State) pushTableUnsafe(table map[string]interface{}) error {
s.NewTable()
for k, v := range table {
if err := s.pushValueUnsafe(v); err != nil {
return err
}
s.SetField(-2, k)
}
return nil
}
func (s *State) toTableSafe(index int) (map[string]interface{}, error) {
if err := s.checkStack(2); err != nil {
return nil, err
}
return s.toTableUnsafe(index)
}
func (s *State) toTableUnsafe(index int) (map[string]interface{}, error) {
absIdx := s.absIndex(index)
table := make(map[string]interface{})
// Check if it's an array-like table
length := s.GetTableLength(absIdx)
if length > 0 {
array := make([]float64, length)
isArray := true
// Try to convert to array
for i := 1; i <= length; i++ {
s.PushNumber(float64(i))
s.GetTable(absIdx)
if s.GetType(-1) != TypeNumber {
isArray = false
s.Pop(1)
break
}
array[i-1] = s.ToNumber(-1)
s.Pop(1)
}
if isArray {
return map[string]interface{}{"": array}, nil
}
}
// Handle regular table
s.PushNil()
for C.lua_next(s.L, C.int(absIdx)) != 0 {
key := ""
valueType := C.lua_type(s.L, -2)
if valueType == C.LUA_TSTRING {
key = s.ToString(-2)
} else if valueType == C.LUA_TNUMBER {
key = fmt.Sprintf("%g", s.ToNumber(-2))
}
value, err := s.toValueUnsafe(-1)
if err != nil {
s.Pop(1)
return nil, err
}
// Handle nested array case
if m, ok := value.(map[string]interface{}); ok {
if arr, ok := m[""]; ok {
value = arr
}
}
table[key] = value
s.Pop(1)
}
return table, nil
}
// NewTable creates a new table and pushes it onto the stack
func (s *State) NewTable() {
if s.safeStack {
if err := s.checkStack(1); err != nil {
// Since we can't return an error, we'll push nil instead
s.PushNil()
return
}
}
C.lua_createtable(s.L, 0, 0)
}
// SetTable sets a table field with cached absolute index
func (s *State) SetTable(index int) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
C.lua_settable(s.L, C.int(absIdx))
}
// GetTable gets a table field with cached absolute index
func (s *State) GetTable(index int) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
if s.safeStack {
if err := s.checkStack(1); err != nil {
s.PushNil()
return
}
}
C.lua_gettable(s.L, C.int(absIdx))
}
// PushTable pushes a Go map onto the Lua stack as a table with stack checking
func (s *State) PushTable(table map[string]interface{}) error {
if s.safeStack {
return s.pushTableSafe(table)
}
return s.pushTableUnsafe(table)
}

97
table_test.go Normal file
View File

@ -0,0 +1,97 @@
package luajit
import (
"math"
"testing"
)
func TestTableOperations(t *testing.T) {
tests := []struct {
name string
data map[string]interface{}
}{
{
name: "empty",
data: map[string]interface{}{},
},
{
name: "primitives",
data: map[string]interface{}{
"str": "hello",
"num": 42.0,
"bool": true,
"array": []float64{1.1, 2.2, 3.3},
},
},
{
name: "nested",
data: map[string]interface{}{
"nested": map[string]interface{}{
"value": 123.0,
"array": []float64{4.4, 5.5},
},
},
},
}
for _, f := range factories {
for _, tt := range tests {
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
if err := L.PushTable(tt.data); err != nil {
t.Fatalf("PushTable() error = %v", err)
}
got, err := L.ToTable(-1)
if err != nil {
t.Fatalf("ToTable() error = %v", err)
}
if !tablesEqual(got, tt.data) {
t.Errorf("table mismatch\ngot = %v\nwant = %v", got, tt.data)
}
})
}
}
}
func tablesEqual(a, b map[string]interface{}) bool {
if len(a) != len(b) {
return false
}
for k, v1 := range a {
v2, ok := b[k]
if !ok {
return false
}
switch v1 := v1.(type) {
case map[string]interface{}:
v2, ok := v2.(map[string]interface{})
if !ok || !tablesEqual(v1, v2) {
return false
}
case []float64:
v2, ok := v2.([]float64)
if !ok || len(v1) != len(v2) {
return false
}
for i := range v1 {
if math.Abs(v1[i]-v2[i]) > 1e-10 {
return false
}
}
default:
if v1 != v2 {
return false
}
}
}
return true
}

51
types.go Normal file
View File

@ -0,0 +1,51 @@
package luajit
/*
#include <lua.h>
*/
import "C"
// LuaType represents Lua value types
type LuaType int
const (
// These constants must match lua.h's LUA_T* values
TypeNone LuaType = -1
TypeNil LuaType = 0
TypeBoolean LuaType = 1
TypeLightUserData LuaType = 2
TypeNumber LuaType = 3
TypeString LuaType = 4
TypeTable LuaType = 5
TypeFunction LuaType = 6
TypeUserData LuaType = 7
TypeThread LuaType = 8
)
// String returns the string representation of the Lua type
func (t LuaType) String() string {
switch t {
case TypeNone:
return "none"
case TypeNil:
return "nil"
case TypeBoolean:
return "boolean"
case TypeLightUserData:
return "lightuserdata"
case TypeNumber:
return "number"
case TypeString:
return "string"
case TypeTable:
return "table"
case TypeFunction:
return "function"
case TypeUserData:
return "userdata"
case TypeThread:
return "thread"
default:
return "unknown"
}
}

325
wrapper.go Normal file
View File

@ -0,0 +1,325 @@
package luajit
/*
#cgo CFLAGS: -I${SRCDIR}/luajit
#cgo windows LDFLAGS: -L${SRCDIR}/luajit -llua51
#cgo !windows LDFLAGS: -L${SRCDIR}/luajit -lluajit
#include <lua.h>
#include <lualib.h>
#include <lauxlib.h>
#include <stdlib.h>
void init_dll_paths(void);
static int do_string(lua_State *L, const char *s) {
int status = luaL_loadstring(L, s);
if (status) return status;
return lua_pcall(L, 0, LUA_MULTRET, 0);
}
static int do_file(lua_State *L, const char *filename) {
int status = luaL_loadfile(L, filename);
if (status) return status;
return lua_pcall(L, 0, LUA_MULTRET, 0);
}
*/
import "C"
import (
"fmt"
"path/filepath"
"unsafe"
)
// State represents a Lua state with configurable stack safety
type State struct {
L *C.lua_State
safeStack bool
}
// NewSafe creates a new Lua state with full stack safety guarantees
func NewSafe() *State {
L := C.luaL_newstate()
if L == nil {
return nil
}
C.luaL_openlibs(L)
return &State{L: L, safeStack: true}
}
// New creates a new Lua state with minimal stack checking
func New() *State {
L := C.luaL_newstate()
if L == nil {
return nil
}
C.luaL_openlibs(L)
return &State{L: L, safeStack: false}
}
// Close closes the Lua state
func (s *State) Close() {
if s.L != nil {
C.lua_close(s.L)
s.L = nil
}
}
// DoString executes a Lua string with appropriate stack management
func (s *State) DoString(str string) error {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
if s.safeStack {
return stackGuardErr(s, func() error {
return s.safeCall(func() C.int {
return C.do_string(s.L, cstr)
})
})
}
status := C.do_string(s.L, cstr)
if status != 0 {
return &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
}
return nil
}
// PushValue pushes a Go value onto the stack
func (s *State) PushValue(v interface{}) error {
if s.safeStack {
return stackGuardErr(s, func() error {
if err := s.checkStack(1); err != nil {
return fmt.Errorf("pushing value: %w", err)
}
return s.pushValueUnsafe(v)
})
}
return s.pushValueUnsafe(v)
}
func (s *State) pushValueSafe(v interface{}) error {
if err := s.checkStack(1); err != nil {
return fmt.Errorf("pushing value: %w", err)
}
return s.pushValueUnsafe(v)
}
func (s *State) pushValueUnsafe(v interface{}) error {
switch v := v.(type) {
case nil:
s.PushNil()
case bool:
s.PushBoolean(v)
case float64:
s.PushNumber(v)
case int:
s.PushNumber(float64(v))
case string:
s.PushString(v)
case map[string]interface{}:
// Special case: handle array stored in map
if arr, ok := v[""].([]float64); ok {
s.NewTable()
for i, elem := range arr {
s.PushNumber(float64(i + 1))
s.PushNumber(elem)
s.SetTable(-3)
}
return nil
}
return s.pushTableUnsafe(v)
case []float64:
s.NewTable()
for i, elem := range v {
s.PushNumber(float64(i + 1))
s.PushNumber(elem)
s.SetTable(-3)
}
case []interface{}:
s.NewTable()
for i, elem := range v {
s.PushNumber(float64(i + 1))
if err := s.pushValueUnsafe(elem); err != nil {
return err
}
s.SetTable(-3)
}
default:
return fmt.Errorf("unsupported type: %T", v)
}
return nil
}
// ToValue converts a Lua value to a Go value
func (s *State) ToValue(index int) (interface{}, error) {
if s.safeStack {
return stackGuardValue[interface{}](s, func() (interface{}, error) {
return s.toValueUnsafe(index)
})
}
return s.toValueUnsafe(index)
}
func (s *State) toValueUnsafe(index int) (interface{}, error) {
switch s.GetType(index) {
case TypeNil:
return nil, nil
case TypeBoolean:
return s.ToBoolean(index), nil
case TypeNumber:
return s.ToNumber(index), nil
case TypeString:
return s.ToString(index), nil
case TypeTable:
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
}
return s.toTableUnsafe(index)
default:
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
}
}
// Simple operations remain unchanged as they don't need stack protection
func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) }
func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction }
func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable }
func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 }
func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) }
func (s *State) ToString(index int) string {
return C.GoString(C.lua_tolstring(s.L, C.int(index), nil))
}
func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) }
func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) }
// Push operations
func (s *State) PushNil() { C.lua_pushnil(s.L) }
func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, C.int(bool2int(b))) }
func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.double(n)) }
func (s *State) PushString(str string) {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
C.lua_pushstring(s.L, cstr)
}
// Helper functions
func bool2int(b bool) int {
if b {
return 1
}
return 0
}
func (s *State) absIndex(index int) int {
if index > 0 || index <= LUA_REGISTRYINDEX {
return index
}
return s.GetTop() + index + 1
}
// SetField sets a field in a table at the given index with cached absolute index
func (s *State) SetField(index int, key string) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
cstr := C.CString(key)
defer C.free(unsafe.Pointer(cstr))
C.lua_setfield(s.L, C.int(absIdx), cstr)
}
// GetField gets a field from a table with cached absolute index
func (s *State) GetField(index int, key string) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
if s.safeStack {
if err := s.checkStack(1); err != nil {
s.PushNil()
return
}
}
cstr := C.CString(key)
defer C.free(unsafe.Pointer(cstr))
C.lua_getfield(s.L, C.int(absIdx), cstr)
}
// GetGlobal gets a global variable and pushes it onto the stack
func (s *State) GetGlobal(name string) {
if s.safeStack {
if err := s.checkStack(1); err != nil {
s.PushNil()
return
}
}
cstr := C.CString(name)
defer C.free(unsafe.Pointer(cstr))
C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cstr)
}
// SetGlobal sets a global variable from the value at the top of the stack
func (s *State) SetGlobal(name string) {
// SetGlobal doesn't need stack space checking as it pops the value
cstr := C.CString(name)
defer C.free(unsafe.Pointer(cstr))
C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cstr)
}
// Remove removes element with cached absolute index
func (s *State) Remove(index int) {
absIdx := index
if s.safeStack && (index < 0 && index > LUA_REGISTRYINDEX) {
absIdx = s.GetTop() + index + 1
}
C.lua_remove(s.L, C.int(absIdx))
}
// DoFile executes a Lua file with appropriate stack management
func (s *State) DoFile(filename string) error {
cfilename := C.CString(filename)
defer C.free(unsafe.Pointer(cfilename))
if s.safeStack {
return stackGuardErr(s, func() error {
return s.safeCall(func() C.int {
return C.do_file(s.L, cfilename)
})
})
}
status := C.do_file(s.L, cfilename)
if status != 0 {
return &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
}
return nil
}
func (s *State) SetPackagePath(path string) error {
path = filepath.ToSlash(path)
if err := s.DoString(fmt.Sprintf(`package.path = %q`, path)); err != nil {
return fmt.Errorf("setting package.path: %w", err)
}
return nil
}
func (s *State) AddPackagePath(path string) error {
path = filepath.ToSlash(path)
if err := s.DoString(fmt.Sprintf(`package.path = package.path .. ";%s"`, path)); err != nil {
return fmt.Errorf("adding to package.path: %w", err)
}
return nil
}

276
wrapper_test.go Normal file
View File

@ -0,0 +1,276 @@
package luajit
import (
"fmt"
"os"
"path/filepath"
"testing"
)
type stateFactory struct {
name string
new func() *State
}
var factories = []stateFactory{
{"unsafe", New},
{"safe", NewSafe},
}
func TestNew(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
})
}
}
func TestDoString(t *testing.T) {
tests := []struct {
name string
code string
wantErr bool
}{
{"simple addition", "return 1 + 1", false},
{"set global", "test = 42", false},
{"syntax error", "this is not valid lua", true},
{"runtime error", "error('test error')", true},
}
for _, f := range factories {
for _, tt := range tests {
t.Run(f.name+"/"+tt.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
err := L.DoString(tt.code)
if (err != nil) != tt.wantErr {
t.Errorf("DoString() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
}
func TestPushAndGetValues(t *testing.T) {
values := []struct {
name string
push func(*State)
check func(*State) error
}{
{
name: "string",
push: func(L *State) { L.PushString("hello") },
check: func(L *State) error {
if got := L.ToString(-1); got != "hello" {
return fmt.Errorf("got %q, want %q", got, "hello")
}
return nil
},
},
{
name: "number",
push: func(L *State) { L.PushNumber(42.5) },
check: func(L *State) error {
if got := L.ToNumber(-1); got != 42.5 {
return fmt.Errorf("got %f, want %f", got, 42.5)
}
return nil
},
},
{
name: "boolean",
push: func(L *State) { L.PushBoolean(true) },
check: func(L *State) error {
if got := L.ToBoolean(-1); !got {
return fmt.Errorf("got %v, want true", got)
}
return nil
},
},
{
name: "nil",
push: func(L *State) { L.PushNil() },
check: func(L *State) error {
if typ := L.GetType(-1); typ != TypeNil {
return fmt.Errorf("got type %v, want TypeNil", typ)
}
return nil
},
},
}
for _, f := range factories {
for _, v := range values {
t.Run(f.name+"/"+v.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
v.push(L)
if err := v.check(L); err != nil {
t.Error(err)
}
})
}
}
}
func TestStackManipulation(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Push values
values := []string{"first", "second", "third"}
for _, v := range values {
L.PushString(v)
}
// Check size
if top := L.GetTop(); top != len(values) {
t.Errorf("stack size = %d, want %d", top, len(values))
}
// Pop one value
L.Pop(1)
// Check new top
if str := L.ToString(-1); str != "second" {
t.Errorf("top value = %q, want 'second'", str)
}
// Check new size
if top := L.GetTop(); top != len(values)-1 {
t.Errorf("stack size after pop = %d, want %d", top, len(values)-1)
}
})
}
}
func TestGlobals(t *testing.T) {
for _, f := range factories {
t.Run(f.name, func(t *testing.T) {
L := f.new()
if L == nil {
t.Fatal("Failed to create Lua state")
}
defer L.Close()
// Test via Lua
if err := L.DoString(`globalVar = "test"`); err != nil {
t.Fatalf("DoString error: %v", err)
}
// Get the global
L.GetGlobal("globalVar")
if str := L.ToString(-1); str != "test" {
t.Errorf("global value = %q, want 'test'", str)
}
L.Pop(1)
// Set and get via API
L.PushNumber(42)
L.SetGlobal("testNum")
L.GetGlobal("testNum")
if num := L.ToNumber(-1); num != 42 {
t.Errorf("global number = %f, want 42", num)
}
})
}
}
func TestDoFile(t *testing.T) {
L := NewSafe()
defer L.Close()
// Create test file
content := []byte(`
function add(a, b)
return a + b
end
result = add(40, 2)
`)
tmpDir := t.TempDir()
filename := filepath.Join(tmpDir, "test.lua")
if err := os.WriteFile(filename, content, 0644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
if err := L.DoFile(filename); err != nil {
t.Fatalf("DoFile failed: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("Expected result=42, got %v", result)
}
}
func TestRequireAndPackagePath(t *testing.T) {
L := NewSafe()
defer L.Close()
tmpDir := t.TempDir()
// Create module file
moduleContent := []byte(`
local M = {}
function M.multiply(a, b)
return a * b
end
return M
`)
if err := os.WriteFile(filepath.Join(tmpDir, "mathmod.lua"), moduleContent, 0644); err != nil {
t.Fatalf("Failed to create module file: %v", err)
}
// Add module path and test require
if err := L.AddPackagePath(filepath.Join(tmpDir, "?.lua")); err != nil {
t.Fatalf("AddPackagePath failed: %v", err)
}
if err := L.DoString(`
local math = require("mathmod")
result = math.multiply(6, 7)
`); err != nil {
t.Fatalf("Failed to require module: %v", err)
}
L.GetGlobal("result")
if result := L.ToNumber(-1); result != 42 {
t.Errorf("Expected result=42, got %v", result)
}
}
func TestSetPackagePath(t *testing.T) {
L := NewSafe()
defer L.Close()
customPath := "./custom/?.lua"
if err := L.SetPackagePath(customPath); err != nil {
t.Fatalf("SetPackagePath failed: %v", err)
}
L.GetGlobal("package")
L.GetField(-1, "path")
if path := L.ToString(-1); path != customPath {
t.Errorf("Expected package.path=%q, got %q", customPath, path)
}
}