diff --git a/functions.go b/functions.go index bf550f3..9c37313 100644 --- a/functions.go +++ b/functions.go @@ -22,6 +22,12 @@ import ( // GoFunction defines the signature for Go functions callable from Lua type GoFunction func(*State) int +// LuaFunction represents a Lua function callable from Go +type LuaFunction struct { + state *State + ref int +} + // Static registry size reduces resizing operations const initialRegistrySize = 64 @@ -35,6 +41,14 @@ var ( funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize), } + // luaFunctionRegistry stores Lua function references + luaFunctionRegistry = struct { + sync.RWMutex + refs map[int]*State + }{ + refs: make(map[int]*State), + } + // statePool reuses State structs to avoid allocations statePool = sync.Pool{ New: func() any { @@ -103,14 +117,188 @@ func (s *State) UnregisterGoFunction(name string) { s.SetGlobal(name) } -// Cleanup frees all function pointers and clears the registry -func (s *State) Cleanup() { - functionRegistry.Lock() - defer functionRegistry.Unlock() +// StoreLuaFunction stores a Lua function from the stack and returns a reference +func (s *State) StoreLuaFunction(index int) (*LuaFunction, error) { + if !s.IsFunction(index) { + return nil, fmt.Errorf("value at index %d is not a function", index) + } - // Free all allocated pointers + s.PushCopy(index) + ref := int(C.luaL_ref(s.L, C.LUA_REGISTRYINDEX)) + if ref == C.LUA_REFNIL { + return nil, fmt.Errorf("failed to store function reference") + } + + luaFunc := &LuaFunction{ + state: s, + ref: ref, + } + + luaFunctionRegistry.Lock() + luaFunctionRegistry.refs[ref] = s + luaFunctionRegistry.Unlock() + + return luaFunc, nil +} + +// GetLuaFunction gets a global Lua function and stores it +func (s *State) GetLuaFunction(name string) (*LuaFunction, error) { + s.GetGlobal(name) + defer s.Pop(1) + + if !s.IsFunction(-1) { + return nil, fmt.Errorf("global '%s' is not a function", name) + } + + return s.StoreLuaFunction(-1) +} + +// Call executes the Lua function with given arguments and returns results +func (lf *LuaFunction) Call(args ...any) ([]any, error) { + s := lf.state + + // Push function from registry + C.lua_rawgeti(s.L, C.LUA_REGISTRYINDEX, C.int(lf.ref)) + + // Push arguments + for i, arg := range args { + if err := s.PushValue(arg); err != nil { + s.Pop(i + 1) // Clean up function and pushed args + return nil, fmt.Errorf("failed to push argument %d: %w", i+1, err) + } + } + + // Call function + baseTop := s.GetTop() - len(args) - 1 + if err := s.Call(len(args), C.LUA_MULTRET); err != nil { + return nil, err + } + + // Extract results + newTop := s.GetTop() + nresults := newTop - baseTop + + results := make([]any, nresults) + for i := 0; i < nresults; i++ { + val, err := s.ToValue(baseTop + i + 1) + if err != nil { + results[i] = nil + } else { + results[i] = val + } + } + + s.SetTop(baseTop) + return results, nil +} + +// CallSingle calls the function and returns only the first result +func (lf *LuaFunction) CallSingle(args ...any) (any, error) { + results, err := lf.Call(args...) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, nil + } + return results[0], nil +} + +// CallTyped calls the function and converts the first result to the specified type +func CallTyped[T any](lf *LuaFunction, args ...any) (T, error) { + var zero T + result, err := lf.CallSingle(args...) + if err != nil { + return zero, err + } + if result == nil { + return zero, nil + } + if converted, ok := ConvertValue[T](result); ok { + return converted, nil + } + return zero, fmt.Errorf("cannot convert result to %T", zero) +} + +// Release releases the Lua function reference +func (lf *LuaFunction) Release() { + if lf.ref != C.LUA_NOREF && lf.ref != C.LUA_REFNIL { + luaFunctionRegistry.Lock() + delete(luaFunctionRegistry.refs, lf.ref) + luaFunctionRegistry.Unlock() + + C.luaL_unref(lf.state.L, C.LUA_REGISTRYINDEX, C.int(lf.ref)) + lf.ref = C.LUA_NOREF + } +} + +// IsValid checks if the function reference is still valid +func (lf *LuaFunction) IsValid() bool { + return lf.ref != C.LUA_NOREF && lf.ref != C.LUA_REFNIL +} + +// ToGoFunction converts to a standard Go function signature +func (lf *LuaFunction) ToGoFunction() func(...any) ([]any, error) { + return func(args ...any) ([]any, error) { + return lf.Call(args...) + } +} + +// CreateCallback creates a reusable callback function +func (s *State) CreateCallback(luaCode string) (*LuaFunction, error) { + if err := s.LoadString(luaCode); err != nil { + return nil, fmt.Errorf("failed to load callback code: %w", err) + } + + luaFunc, err := s.StoreLuaFunction(-1) + s.Pop(1) // Remove function from stack + return luaFunc, err +} + +// Cleanup frees all function pointers and clears registries +func (s *State) Cleanup() { + // Clean up Go function registry + functionRegistry.Lock() for ptr := range functionRegistry.funcs { C.free(ptr) delete(functionRegistry.funcs, ptr) } + functionRegistry.Unlock() + + // Clean up Lua function registry for this state + luaFunctionRegistry.Lock() + for ref, state := range luaFunctionRegistry.refs { + if state == s { + C.luaL_unref(s.L, C.LUA_REGISTRYINDEX, C.int(ref)) + delete(luaFunctionRegistry.refs, ref) + } + } + luaFunctionRegistry.Unlock() +} + +// BatchRegisterGoFunctions registers multiple Go functions at once +func (s *State) BatchRegisterGoFunctions(funcs map[string]GoFunction) error { + for name, fn := range funcs { + if err := s.RegisterGoFunction(name, fn); err != nil { + return fmt.Errorf("failed to register function '%s': %w", name, err) + } + } + return nil +} + +// GetAllLuaFunctions gets multiple global Lua functions by name +func (s *State) GetAllLuaFunctions(names ...string) (map[string]*LuaFunction, error) { + funcs := make(map[string]*LuaFunction, len(names)) + for _, name := range names { + if fn, err := s.GetLuaFunction(name); err == nil { + funcs[name] = fn + } else { + // Clean up any successfully created functions + for _, f := range funcs { + f.Release() + } + return nil, fmt.Errorf("failed to get function '%s': %w", name, err) + } + } + return funcs, nil }