1
0

Lua to Go function handling

This commit is contained in:
Sky Johnson 2025-07-14 16:05:34 -05:00
parent b4fe354c41
commit cba8e1b151

View File

@ -22,6 +22,12 @@ import (
// GoFunction defines the signature for Go functions callable from Lua // GoFunction defines the signature for Go functions callable from Lua
type GoFunction func(*State) int type GoFunction func(*State) int
// LuaFunction represents a Lua function callable from Go
type LuaFunction struct {
state *State
ref int
}
// Static registry size reduces resizing operations // Static registry size reduces resizing operations
const initialRegistrySize = 64 const initialRegistrySize = 64
@ -35,6 +41,14 @@ var (
funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize), funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize),
} }
// luaFunctionRegistry stores Lua function references
luaFunctionRegistry = struct {
sync.RWMutex
refs map[int]*State
}{
refs: make(map[int]*State),
}
// statePool reuses State structs to avoid allocations // statePool reuses State structs to avoid allocations
statePool = sync.Pool{ statePool = sync.Pool{
New: func() any { New: func() any {
@ -103,14 +117,188 @@ func (s *State) UnregisterGoFunction(name string) {
s.SetGlobal(name) s.SetGlobal(name)
} }
// Cleanup frees all function pointers and clears the registry // StoreLuaFunction stores a Lua function from the stack and returns a reference
func (s *State) Cleanup() { func (s *State) StoreLuaFunction(index int) (*LuaFunction, error) {
functionRegistry.Lock() if !s.IsFunction(index) {
defer functionRegistry.Unlock() return nil, fmt.Errorf("value at index %d is not a function", index)
}
// Free all allocated pointers s.PushCopy(index)
ref := int(C.luaL_ref(s.L, C.LUA_REGISTRYINDEX))
if ref == C.LUA_REFNIL {
return nil, fmt.Errorf("failed to store function reference")
}
luaFunc := &LuaFunction{
state: s,
ref: ref,
}
luaFunctionRegistry.Lock()
luaFunctionRegistry.refs[ref] = s
luaFunctionRegistry.Unlock()
return luaFunc, nil
}
// GetLuaFunction gets a global Lua function and stores it
func (s *State) GetLuaFunction(name string) (*LuaFunction, error) {
s.GetGlobal(name)
defer s.Pop(1)
if !s.IsFunction(-1) {
return nil, fmt.Errorf("global '%s' is not a function", name)
}
return s.StoreLuaFunction(-1)
}
// Call executes the Lua function with given arguments and returns results
func (lf *LuaFunction) Call(args ...any) ([]any, error) {
s := lf.state
// Push function from registry
C.lua_rawgeti(s.L, C.LUA_REGISTRYINDEX, C.int(lf.ref))
// Push arguments
for i, arg := range args {
if err := s.PushValue(arg); err != nil {
s.Pop(i + 1) // Clean up function and pushed args
return nil, fmt.Errorf("failed to push argument %d: %w", i+1, err)
}
}
// Call function
baseTop := s.GetTop() - len(args) - 1
if err := s.Call(len(args), C.LUA_MULTRET); err != nil {
return nil, err
}
// Extract results
newTop := s.GetTop()
nresults := newTop - baseTop
results := make([]any, nresults)
for i := 0; i < nresults; i++ {
val, err := s.ToValue(baseTop + i + 1)
if err != nil {
results[i] = nil
} else {
results[i] = val
}
}
s.SetTop(baseTop)
return results, nil
}
// CallSingle calls the function and returns only the first result
func (lf *LuaFunction) CallSingle(args ...any) (any, error) {
results, err := lf.Call(args...)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, nil
}
return results[0], nil
}
// CallTyped calls the function and converts the first result to the specified type
func CallTyped[T any](lf *LuaFunction, args ...any) (T, error) {
var zero T
result, err := lf.CallSingle(args...)
if err != nil {
return zero, err
}
if result == nil {
return zero, nil
}
if converted, ok := ConvertValue[T](result); ok {
return converted, nil
}
return zero, fmt.Errorf("cannot convert result to %T", zero)
}
// Release releases the Lua function reference
func (lf *LuaFunction) Release() {
if lf.ref != C.LUA_NOREF && lf.ref != C.LUA_REFNIL {
luaFunctionRegistry.Lock()
delete(luaFunctionRegistry.refs, lf.ref)
luaFunctionRegistry.Unlock()
C.luaL_unref(lf.state.L, C.LUA_REGISTRYINDEX, C.int(lf.ref))
lf.ref = C.LUA_NOREF
}
}
// IsValid checks if the function reference is still valid
func (lf *LuaFunction) IsValid() bool {
return lf.ref != C.LUA_NOREF && lf.ref != C.LUA_REFNIL
}
// ToGoFunction converts to a standard Go function signature
func (lf *LuaFunction) ToGoFunction() func(...any) ([]any, error) {
return func(args ...any) ([]any, error) {
return lf.Call(args...)
}
}
// CreateCallback creates a reusable callback function
func (s *State) CreateCallback(luaCode string) (*LuaFunction, error) {
if err := s.LoadString(luaCode); err != nil {
return nil, fmt.Errorf("failed to load callback code: %w", err)
}
luaFunc, err := s.StoreLuaFunction(-1)
s.Pop(1) // Remove function from stack
return luaFunc, err
}
// Cleanup frees all function pointers and clears registries
func (s *State) Cleanup() {
// Clean up Go function registry
functionRegistry.Lock()
for ptr := range functionRegistry.funcs { for ptr := range functionRegistry.funcs {
C.free(ptr) C.free(ptr)
delete(functionRegistry.funcs, ptr) delete(functionRegistry.funcs, ptr)
} }
functionRegistry.Unlock()
// Clean up Lua function registry for this state
luaFunctionRegistry.Lock()
for ref, state := range luaFunctionRegistry.refs {
if state == s {
C.luaL_unref(s.L, C.LUA_REGISTRYINDEX, C.int(ref))
delete(luaFunctionRegistry.refs, ref)
}
}
luaFunctionRegistry.Unlock()
}
// BatchRegisterGoFunctions registers multiple Go functions at once
func (s *State) BatchRegisterGoFunctions(funcs map[string]GoFunction) error {
for name, fn := range funcs {
if err := s.RegisterGoFunction(name, fn); err != nil {
return fmt.Errorf("failed to register function '%s': %w", name, err)
}
}
return nil
}
// GetAllLuaFunctions gets multiple global Lua functions by name
func (s *State) GetAllLuaFunctions(names ...string) (map[string]*LuaFunction, error) {
funcs := make(map[string]*LuaFunction, len(names))
for _, name := range names {
if fn, err := s.GetLuaFunction(name); err == nil {
funcs[name] = fn
} else {
// Clean up any successfully created functions
for _, f := range funcs {
f.Release()
}
return nil, fmt.Errorf("failed to get function '%s': %w", name, err)
}
}
return funcs, nil
} }