305 lines
7.3 KiB
Go
305 lines
7.3 KiB
Go
package luajit
|
|
|
|
/*
|
|
#include <lua.h>
|
|
#include <lauxlib.h>
|
|
#include <stdlib.h>
|
|
|
|
extern int goFunctionWrapper(lua_State* L);
|
|
|
|
// Helper function to access upvalues
|
|
static int get_upvalue_index(int i) {
|
|
return lua_upvalueindex(i);
|
|
}
|
|
*/
|
|
import "C"
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"unsafe"
|
|
)
|
|
|
|
// GoFunction defines the signature for Go functions callable from Lua
|
|
type GoFunction func(*State) int
|
|
|
|
// LuaFunction represents a Lua function callable from Go
|
|
type LuaFunction struct {
|
|
state *State
|
|
ref int
|
|
}
|
|
|
|
// Static registry size reduces resizing operations
|
|
const initialRegistrySize = 64
|
|
|
|
var (
|
|
// functionRegistry stores all registered Go functions
|
|
functionRegistry = struct {
|
|
sync.RWMutex
|
|
funcs map[unsafe.Pointer]GoFunction
|
|
initOnce sync.Once
|
|
}{
|
|
funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize),
|
|
}
|
|
|
|
// luaFunctionRegistry stores Lua function references
|
|
luaFunctionRegistry = struct {
|
|
sync.RWMutex
|
|
refs map[int]*State
|
|
}{
|
|
refs: make(map[int]*State),
|
|
}
|
|
|
|
// statePool reuses State structs to avoid allocations
|
|
statePool = sync.Pool{
|
|
New: func() any {
|
|
return &State{}
|
|
},
|
|
}
|
|
)
|
|
|
|
//export goFunctionWrapper
|
|
func goFunctionWrapper(L *C.lua_State) C.int {
|
|
state := statePool.Get().(*State)
|
|
state.L = L
|
|
defer statePool.Put(state)
|
|
|
|
ptr := C.lua_touserdata(L, C.get_upvalue_index(1))
|
|
if ptr == nil {
|
|
state.PushString("error: function pointer not found")
|
|
return -1
|
|
}
|
|
|
|
functionRegistry.RLock()
|
|
fn, ok := functionRegistry.funcs[ptr]
|
|
functionRegistry.RUnlock()
|
|
|
|
if !ok {
|
|
state.PushString("error: function not found in registry")
|
|
return -1
|
|
}
|
|
|
|
return C.int(fn(state))
|
|
}
|
|
|
|
// PushGoFunction wraps a Go function and pushes it onto the Lua stack
|
|
func (s *State) PushGoFunction(fn GoFunction) error {
|
|
// Allocate unique memory for each function
|
|
ptr := C.malloc(C.size_t(unsafe.Sizeof(uintptr(0))))
|
|
if ptr == nil {
|
|
return fmt.Errorf("failed to allocate memory for function pointer")
|
|
}
|
|
|
|
// Register the function
|
|
functionRegistry.Lock()
|
|
functionRegistry.funcs[ptr] = fn
|
|
functionRegistry.Unlock()
|
|
|
|
// Push the pointer as lightuserdata
|
|
C.lua_pushlightuserdata(s.L, ptr)
|
|
C.lua_pushcclosure(s.L, (*[0]byte)(C.goFunctionWrapper), 1)
|
|
|
|
return nil
|
|
}
|
|
|
|
// RegisterGoFunction registers a Go function as a global Lua function
|
|
func (s *State) RegisterGoFunction(name string, fn GoFunction) error {
|
|
if err := s.PushGoFunction(fn); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.SetGlobal(name)
|
|
return nil
|
|
}
|
|
|
|
// UnregisterGoFunction removes a global function
|
|
func (s *State) UnregisterGoFunction(name string) {
|
|
s.PushNil()
|
|
s.SetGlobal(name)
|
|
}
|
|
|
|
// StoreLuaFunction stores a Lua function from the stack and returns a reference
|
|
func (s *State) StoreLuaFunction(index int) (*LuaFunction, error) {
|
|
if !s.IsFunction(index) {
|
|
return nil, fmt.Errorf("value at index %d is not a function", index)
|
|
}
|
|
|
|
s.PushCopy(index)
|
|
ref := int(C.luaL_ref(s.L, C.LUA_REGISTRYINDEX))
|
|
if ref == C.LUA_REFNIL {
|
|
return nil, fmt.Errorf("failed to store function reference")
|
|
}
|
|
|
|
luaFunc := &LuaFunction{
|
|
state: s,
|
|
ref: ref,
|
|
}
|
|
|
|
luaFunctionRegistry.Lock()
|
|
luaFunctionRegistry.refs[ref] = s
|
|
luaFunctionRegistry.Unlock()
|
|
|
|
return luaFunc, nil
|
|
}
|
|
|
|
// GetLuaFunction gets a global Lua function and stores it
|
|
func (s *State) GetLuaFunction(name string) (*LuaFunction, error) {
|
|
s.GetGlobal(name)
|
|
defer s.Pop(1)
|
|
|
|
if !s.IsFunction(-1) {
|
|
return nil, fmt.Errorf("global '%s' is not a function", name)
|
|
}
|
|
|
|
return s.StoreLuaFunction(-1)
|
|
}
|
|
|
|
// Call executes the Lua function with given arguments and returns results
|
|
func (lf *LuaFunction) Call(args ...any) ([]any, error) {
|
|
s := lf.state
|
|
|
|
// Push function from registry
|
|
C.lua_rawgeti(s.L, C.LUA_REGISTRYINDEX, C.int(lf.ref))
|
|
|
|
// Push arguments
|
|
for i, arg := range args {
|
|
if err := s.PushValue(arg); err != nil {
|
|
s.Pop(i + 1) // Clean up function and pushed args
|
|
return nil, fmt.Errorf("failed to push argument %d: %w", i+1, err)
|
|
}
|
|
}
|
|
|
|
// Call function
|
|
baseTop := s.GetTop() - len(args) - 1
|
|
if err := s.Call(len(args), C.LUA_MULTRET); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Extract results
|
|
newTop := s.GetTop()
|
|
nresults := newTop - baseTop
|
|
|
|
results := make([]any, nresults)
|
|
for i := 0; i < nresults; i++ {
|
|
val, err := s.ToValue(baseTop + i + 1)
|
|
if err != nil {
|
|
results[i] = nil
|
|
} else {
|
|
results[i] = val
|
|
}
|
|
}
|
|
|
|
s.SetTop(baseTop)
|
|
return results, nil
|
|
}
|
|
|
|
// CallSingle calls the function and returns only the first result
|
|
func (lf *LuaFunction) CallSingle(args ...any) (any, error) {
|
|
results, err := lf.Call(args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(results) == 0 {
|
|
return nil, nil
|
|
}
|
|
return results[0], nil
|
|
}
|
|
|
|
// CallTyped calls the function and converts the first result to the specified type
|
|
func CallTyped[T any](lf *LuaFunction, args ...any) (T, error) {
|
|
var zero T
|
|
result, err := lf.CallSingle(args...)
|
|
if err != nil {
|
|
return zero, err
|
|
}
|
|
if result == nil {
|
|
return zero, nil
|
|
}
|
|
if converted, ok := ConvertValue[T](result); ok {
|
|
return converted, nil
|
|
}
|
|
return zero, fmt.Errorf("cannot convert result to %T", zero)
|
|
}
|
|
|
|
// Release releases the Lua function reference
|
|
func (lf *LuaFunction) Release() {
|
|
if lf.ref != C.LUA_NOREF && lf.ref != C.LUA_REFNIL {
|
|
luaFunctionRegistry.Lock()
|
|
delete(luaFunctionRegistry.refs, lf.ref)
|
|
luaFunctionRegistry.Unlock()
|
|
|
|
C.luaL_unref(lf.state.L, C.LUA_REGISTRYINDEX, C.int(lf.ref))
|
|
lf.ref = C.LUA_NOREF
|
|
}
|
|
}
|
|
|
|
// IsValid checks if the function reference is still valid
|
|
func (lf *LuaFunction) IsValid() bool {
|
|
return lf.ref != C.LUA_NOREF && lf.ref != C.LUA_REFNIL
|
|
}
|
|
|
|
// ToGoFunction converts to a standard Go function signature
|
|
func (lf *LuaFunction) ToGoFunction() func(...any) ([]any, error) {
|
|
return func(args ...any) ([]any, error) {
|
|
return lf.Call(args...)
|
|
}
|
|
}
|
|
|
|
// CreateCallback creates a reusable callback function
|
|
func (s *State) CreateCallback(luaCode string) (*LuaFunction, error) {
|
|
if err := s.LoadString(luaCode); err != nil {
|
|
return nil, fmt.Errorf("failed to load callback code: %w", err)
|
|
}
|
|
|
|
luaFunc, err := s.StoreLuaFunction(-1)
|
|
s.Pop(1) // Remove function from stack
|
|
return luaFunc, err
|
|
}
|
|
|
|
// Cleanup frees all function pointers and clears registries
|
|
func (s *State) Cleanup() {
|
|
// Clean up Go function registry
|
|
functionRegistry.Lock()
|
|
for ptr := range functionRegistry.funcs {
|
|
C.free(ptr)
|
|
delete(functionRegistry.funcs, ptr)
|
|
}
|
|
functionRegistry.Unlock()
|
|
|
|
// Clean up Lua function registry for this state
|
|
luaFunctionRegistry.Lock()
|
|
for ref, state := range luaFunctionRegistry.refs {
|
|
if state == s {
|
|
C.luaL_unref(s.L, C.LUA_REGISTRYINDEX, C.int(ref))
|
|
delete(luaFunctionRegistry.refs, ref)
|
|
}
|
|
}
|
|
luaFunctionRegistry.Unlock()
|
|
}
|
|
|
|
// BatchRegisterGoFunctions registers multiple Go functions at once
|
|
func (s *State) BatchRegisterGoFunctions(funcs map[string]GoFunction) error {
|
|
for name, fn := range funcs {
|
|
if err := s.RegisterGoFunction(name, fn); err != nil {
|
|
return fmt.Errorf("failed to register function '%s': %w", name, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetAllLuaFunctions gets multiple global Lua functions by name
|
|
func (s *State) GetAllLuaFunctions(names ...string) (map[string]*LuaFunction, error) {
|
|
funcs := make(map[string]*LuaFunction, len(names))
|
|
for _, name := range names {
|
|
if fn, err := s.GetLuaFunction(name); err == nil {
|
|
funcs[name] = fn
|
|
} else {
|
|
// Clean up any successfully created functions
|
|
for _, f := range funcs {
|
|
f.Release()
|
|
}
|
|
return nil, fmt.Errorf("failed to get function '%s': %w", name, err)
|
|
}
|
|
}
|
|
return funcs, nil
|
|
}
|