package luajit /* #include #include #include extern int goFunctionWrapper(lua_State* L); // Helper function to access upvalues static int get_upvalue_index(int i) { return lua_upvalueindex(i); } */ import "C" import ( "fmt" "sync" "unsafe" ) // GoFunction defines the signature for Go functions callable from Lua type GoFunction func(*State) int // LuaFunction represents a Lua function callable from Go type LuaFunction struct { state *State ref int } // Static registry size reduces resizing operations const initialRegistrySize = 64 var ( // functionRegistry stores all registered Go functions functionRegistry = struct { sync.RWMutex funcs map[unsafe.Pointer]GoFunction initOnce sync.Once }{ funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize), } // luaFunctionRegistry stores Lua function references luaFunctionRegistry = struct { sync.RWMutex refs map[int]*State }{ refs: make(map[int]*State), } // statePool reuses State structs to avoid allocations statePool = sync.Pool{ New: func() any { return &State{} }, } ) //export goFunctionWrapper func goFunctionWrapper(L *C.lua_State) C.int { state := statePool.Get().(*State) state.L = L defer statePool.Put(state) ptr := C.lua_touserdata(L, C.get_upvalue_index(1)) if ptr == nil { state.PushString("error: function pointer not found") return -1 } functionRegistry.RLock() fn, ok := functionRegistry.funcs[ptr] functionRegistry.RUnlock() if !ok { state.PushString("error: function not found in registry") return -1 } return C.int(fn(state)) } // PushGoFunction wraps a Go function and pushes it onto the Lua stack func (s *State) PushGoFunction(fn GoFunction) error { // Allocate unique memory for each function ptr := C.malloc(C.size_t(unsafe.Sizeof(uintptr(0)))) if ptr == nil { return fmt.Errorf("failed to allocate memory for function pointer") } // Register the function functionRegistry.Lock() functionRegistry.funcs[ptr] = fn functionRegistry.Unlock() // Push the pointer as lightuserdata C.lua_pushlightuserdata(s.L, ptr) C.lua_pushcclosure(s.L, (*[0]byte)(C.goFunctionWrapper), 1) return nil } // RegisterGoFunction registers a Go function as a global Lua function func (s *State) RegisterGoFunction(name string, fn GoFunction) error { if err := s.PushGoFunction(fn); err != nil { return err } s.SetGlobal(name) return nil } // UnregisterGoFunction removes a global function func (s *State) UnregisterGoFunction(name string) { s.PushNil() s.SetGlobal(name) } // StoreLuaFunction stores a Lua function from the stack and returns a reference func (s *State) StoreLuaFunction(index int) (*LuaFunction, error) { if !s.IsFunction(index) { return nil, fmt.Errorf("value at index %d is not a function", index) } s.PushCopy(index) ref := int(C.luaL_ref(s.L, C.LUA_REGISTRYINDEX)) if ref == C.LUA_REFNIL { return nil, fmt.Errorf("failed to store function reference") } luaFunc := &LuaFunction{ state: s, ref: ref, } luaFunctionRegistry.Lock() luaFunctionRegistry.refs[ref] = s luaFunctionRegistry.Unlock() return luaFunc, nil } // GetLuaFunction gets a global Lua function and stores it func (s *State) GetLuaFunction(name string) (*LuaFunction, error) { s.GetGlobal(name) defer s.Pop(1) if !s.IsFunction(-1) { return nil, fmt.Errorf("global '%s' is not a function", name) } return s.StoreLuaFunction(-1) } // Call executes the Lua function with given arguments and returns results func (lf *LuaFunction) Call(args ...any) ([]any, error) { s := lf.state // Push function from registry C.lua_rawgeti(s.L, C.LUA_REGISTRYINDEX, C.int(lf.ref)) // Push arguments for i, arg := range args { if err := s.PushValue(arg); err != nil { s.Pop(i + 1) // Clean up function and pushed args return nil, fmt.Errorf("failed to push argument %d: %w", i+1, err) } } // Call function baseTop := s.GetTop() - len(args) - 1 if err := s.Call(len(args), C.LUA_MULTRET); err != nil { return nil, err } // Extract results newTop := s.GetTop() nresults := newTop - baseTop results := make([]any, nresults) for i := 0; i < nresults; i++ { val, err := s.ToValue(baseTop + i + 1) if err != nil { results[i] = nil } else { results[i] = val } } s.SetTop(baseTop) return results, nil } // CallSingle calls the function and returns only the first result func (lf *LuaFunction) CallSingle(args ...any) (any, error) { results, err := lf.Call(args...) if err != nil { return nil, err } if len(results) == 0 { return nil, nil } return results[0], nil } // CallTyped calls the function and converts the first result to the specified type func CallTyped[T any](lf *LuaFunction, args ...any) (T, error) { var zero T result, err := lf.CallSingle(args...) if err != nil { return zero, err } if result == nil { return zero, nil } if converted, ok := ConvertValue[T](result); ok { return converted, nil } return zero, fmt.Errorf("cannot convert result to %T", zero) } // Release releases the Lua function reference func (lf *LuaFunction) Release() { if lf.ref != C.LUA_NOREF && lf.ref != C.LUA_REFNIL { luaFunctionRegistry.Lock() delete(luaFunctionRegistry.refs, lf.ref) luaFunctionRegistry.Unlock() C.luaL_unref(lf.state.L, C.LUA_REGISTRYINDEX, C.int(lf.ref)) lf.ref = C.LUA_NOREF } } // IsValid checks if the function reference is still valid func (lf *LuaFunction) IsValid() bool { return lf.ref != C.LUA_NOREF && lf.ref != C.LUA_REFNIL } // ToGoFunction converts to a standard Go function signature func (lf *LuaFunction) ToGoFunction() func(...any) ([]any, error) { return func(args ...any) ([]any, error) { return lf.Call(args...) } } // CreateCallback creates a reusable callback function func (s *State) CreateCallback(luaCode string) (*LuaFunction, error) { if err := s.LoadString(luaCode); err != nil { return nil, fmt.Errorf("failed to load callback code: %w", err) } luaFunc, err := s.StoreLuaFunction(-1) s.Pop(1) // Remove function from stack return luaFunc, err } // Cleanup frees all function pointers and clears registries func (s *State) Cleanup() { // Clean up Go function registry functionRegistry.Lock() for ptr := range functionRegistry.funcs { C.free(ptr) delete(functionRegistry.funcs, ptr) } functionRegistry.Unlock() // Clean up Lua function registry for this state luaFunctionRegistry.Lock() for ref, state := range luaFunctionRegistry.refs { if state == s { C.luaL_unref(s.L, C.LUA_REGISTRYINDEX, C.int(ref)) delete(luaFunctionRegistry.refs, ref) } } luaFunctionRegistry.Unlock() } // BatchRegisterGoFunctions registers multiple Go functions at once func (s *State) BatchRegisterGoFunctions(funcs map[string]GoFunction) error { for name, fn := range funcs { if err := s.RegisterGoFunction(name, fn); err != nil { return fmt.Errorf("failed to register function '%s': %w", name, err) } } return nil } // GetAllLuaFunctions gets multiple global Lua functions by name func (s *State) GetAllLuaFunctions(names ...string) (map[string]*LuaFunction, error) { funcs := make(map[string]*LuaFunction, len(names)) for _, name := range names { if fn, err := s.GetLuaFunction(name); err == nil { funcs[name] = fn } else { // Clean up any successfully created functions for _, f := range funcs { f.Release() } return nil, fmt.Errorf("failed to get function '%s': %w", name, err) } } return funcs, nil }