LuaJIT-to-Go/wrapper.go

406 lines
10 KiB
Go
Raw Normal View History

2025-01-24 19:53:09 -06:00
package luajit
/*
#cgo CFLAGS: -I${SRCDIR}/vendor/luajit/include
#cgo windows LDFLAGS: -L${SRCDIR}/vendor/luajit/windows -lluajit -static
#cgo !windows LDFLAGS: -L${SRCDIR}/vendor/luajit/linux -lluajit -static

#include <lua.h>
#include <lualib.h>
#include <lauxlib.h>
#include <stdlib.h>

// load_chunk compiles s with luaL_loadstring without running it. On success
// the compiled chunk is pushed as a function; on failure the error message is
// pushed instead and a non-zero status is returned.
static int load_chunk(lua_State *L, const char *s) {
	return luaL_loadstring(L, s);
}

// protected_call is a direct wrapper around lua_pcall.
static int protected_call(lua_State *L, int nargs, int nresults, int errfunc) {
	return lua_pcall(L, nargs, nresults, errfunc);
}

// do_string runs s via luaL_dostring. Note that the macro calls lua_pcall with
// LUA_MULTRET, so any values returned by the chunk stay on the stack, and it
// joins the load and call steps with ||, so the result is 0 or 1 rather than
// the precise error status.
static int do_string(lua_State *L, const char *s) {
	return luaL_dostring(L, s);
}

// do_file runs the file via luaL_dofile. The same caveats as do_string apply:
// results are left on the stack and the status is collapsed to 0 or 1.
static int do_file(lua_State *L, const char *filename) {
	return luaL_dofile(L, filename);
}

// execute_string loads and runs s, keeping every value it returns.
// On success it returns the number of results pushed. On failure it returns
// the negated status from luaL_loadstring or lua_pcall, with the error
// message left on top of the stack.
static int execute_string(lua_State *L, const char *s) {
	int base = lua_gettop(L); // stack height before anything is pushed
	int status = luaL_loadstring(L, s);
	if (status) return -status;
	status = lua_pcall(L, 0, LUA_MULTRET, 0);
	if (status) return -status;
	return lua_gettop(L) - base;
}

// get_abs_index converts a relative (negative) stack index into an absolute
// one. Positive indices and pseudo-indices (at or below LUA_REGISTRYINDEX)
// are returned unchanged.
static int get_abs_index(lua_State *L, int idx) {
	if (idx > 0 || idx <= LUA_REGISTRYINDEX) return idx;
	return lua_gettop(L) + idx + 1;
}

// check_stack wraps lua_checkstack: ensures room for n extra slots, returning
// 0 if the stack cannot grow that far.
static int check_stack(lua_State *L, int n) {
	return lua_checkstack(L, n);
}

// remove_stack wraps lua_remove.
static void remove_stack(lua_State *L, int idx) {
	lua_remove(L, idx);
}

// get_field_helper pushes t[k] (t at idx) and returns the Lua type of the
// pushed value.
static int get_field_helper(lua_State *L, int idx, const char *k) {
	lua_getfield(L, idx, k);
	return lua_type(L, -1);
}

// set_field_helper pops the top value and assigns it to t[k] (t at idx).
static void set_field_helper(lua_State *L, int idx, const char *k) {
	lua_setfield(L, idx, k);
}
*/
import "C"
import (
"fmt"
"strings"
2025-01-24 19:53:09 -06:00
"unsafe"
)
// State represents a Lua state
2025-01-24 19:53:09 -06:00
type State struct {
L *C.lua_State
2025-01-24 19:53:09 -06:00
}
// New creates a new Lua state
2025-01-24 19:53:09 -06:00
func New() *State {
L := C.luaL_newstate()
if L == nil {
return nil
}
C.luaL_openlibs(L)
return &State{L: L}
2025-01-24 19:53:09 -06:00
}
// Close frees the underlying Lua state. Calling it again is harmless: the
// handle is cleared after the first call, so later calls do nothing.
func (s *State) Close() {
	if s.L == nil {
		return
	}
	C.lua_close(s.L)
	s.L = nil
}
// DoString executes a Lua string.
2025-01-24 19:53:09 -06:00
func (s *State) DoString(str string) error {
// Save initial stack size
top := s.GetTop()
// Load the string
if err := s.LoadString(str); err != nil {
return err
2025-01-24 19:53:09 -06:00
}
// Execute and check for errors
if err := s.Call(0, 0); err != nil {
return err
2025-01-24 19:53:09 -06:00
}
// Restore stack to initial size to clean up any leftovers
s.SetTop(top)
2025-01-24 19:53:09 -06:00
return nil
}
// PushValue pushes a Go value onto the stack
func (s *State) PushValue(v interface{}) error {
switch v := v.(type) {
case nil:
s.PushNil()
case bool:
s.PushBoolean(v)
case float64:
s.PushNumber(v)
case int:
s.PushNumber(float64(v))
case string:
s.PushString(v)
case map[string]interface{}:
// Special case: handle array stored in map
if arr, ok := v[""].([]float64); ok {
s.NewTable()
for i, elem := range arr {
s.PushNumber(float64(i + 1))
s.PushNumber(elem)
s.SetTable(-3)
}
return nil
}
return s.PushTable(v)
2025-01-24 19:53:09 -06:00
case []float64:
s.NewTable()
for i, elem := range v {
s.PushNumber(float64(i + 1))
s.PushNumber(elem)
s.SetTable(-3)
}
case []interface{}:
s.NewTable()
for i, elem := range v {
s.PushNumber(float64(i + 1))
if err := s.PushValue(elem); err != nil {
2025-01-24 19:53:09 -06:00
return err
}
s.SetTable(-3)
}
default:
return fmt.Errorf("unsupported type: %T", v)
}
return nil
}
// ToValue converts a Lua value to a Go value
func (s *State) ToValue(index int) (interface{}, error) {
switch s.GetType(index) {
case TypeNil:
return nil, nil
case TypeBoolean:
return s.ToBoolean(index), nil
case TypeNumber:
return s.ToNumber(index), nil
case TypeString:
return s.ToString(index), nil
case TypeTable:
if !s.IsTable(index) {
return nil, fmt.Errorf("not a table at index %d", index)
}
return s.ToTable(index)
2025-01-24 19:53:09 -06:00
default:
return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
}
}
// Simple operations remain unchanged as they don't need stack protection
func (s *State) GetType(index int) LuaType { return LuaType(C.lua_type(s.L, C.int(index))) }
2025-02-03 18:52:26 -06:00
func (s *State) IsString(index int) bool { return s.GetType(index) == TypeString }
2025-02-03 18:55:34 -06:00
func (s *State) IsNumber(index int) bool { return s.GetType(index) == TypeNumber }
2025-01-24 19:53:09 -06:00
func (s *State) IsFunction(index int) bool { return s.GetType(index) == TypeFunction }
func (s *State) IsTable(index int) bool { return s.GetType(index) == TypeTable }
2025-02-03 19:09:30 -06:00
func (s *State) IsNil(index int) bool { return s.GetType(index) == TypeNil }
2025-01-24 19:53:09 -06:00
func (s *State) ToBoolean(index int) bool { return C.lua_toboolean(s.L, C.int(index)) != 0 }
func (s *State) ToNumber(index int) float64 { return float64(C.lua_tonumber(s.L, C.int(index))) }
func (s *State) ToString(index int) string {
return C.GoString(C.lua_tolstring(s.L, C.int(index), nil))
}
func (s *State) GetTop() int { return int(C.lua_gettop(s.L)) }
func (s *State) Pop(n int) { C.lua_settop(s.L, C.int(-n-1)) }
func (s *State) SetTop(index int) { C.lua_settop(s.L, C.int(index)) }
2025-01-24 19:53:09 -06:00
// PushNil pushes nil onto the stack.
func (s *State) PushNil() { C.lua_pushnil(s.L) }

// PushBoolean pushes b as a Lua boolean.
func (s *State) PushBoolean(b bool) { C.lua_pushboolean(s.L, C.int(bool2int(b))) }

// PushNumber pushes n as a Lua number.
func (s *State) PushNumber(n float64) { C.lua_pushnumber(s.L, C.double(n)) }

// PushString pushes str as a Lua string.
//
// The byte length is passed explicitly to lua_pushlstring, so strings that
// contain NUL bytes are pushed in full. lua_pushstring would have stopped at
// the first NUL. Lua copies the bytes, so the temporary C buffer is freed as
// soon as the call returns.
func (s *State) PushString(str string) {
	cstr := C.CString(str)
	defer C.free(unsafe.Pointer(cstr))
	C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
}
2025-02-03 18:52:26 -06:00
func (s *State) Next(index int) bool {
return C.lua_next(s.L, C.int(index)) != 0
}
2025-01-24 19:53:09 -06:00
// bool2int maps true to 1 and false to 0, the form lua_pushboolean expects.
func bool2int(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}
// absIndex turns a relative (non-positive) stack index into the equivalent
// absolute one. Positive indices and pseudo-indices (at or below
// LUA_REGISTRYINDEX) are returned as given.
func (s *State) absIndex(index int) int {
	if index <= 0 && index > LUA_REGISTRYINDEX {
		return s.GetTop() + index + 1
	}
	return index
}
// SetField sets a field in a table at the given index
2025-01-24 19:53:09 -06:00
func (s *State) SetField(index int, key string) {
cstr := C.CString(key)
defer C.free(unsafe.Pointer(cstr))
C.lua_setfield(s.L, C.int(index), cstr)
2025-01-24 19:53:09 -06:00
}
// GetField gets a field from a table
2025-01-24 19:53:09 -06:00
func (s *State) GetField(index int, key string) {
cstr := C.CString(key)
defer C.free(unsafe.Pointer(cstr))
C.lua_getfield(s.L, C.int(index), cstr)
2025-01-24 19:53:09 -06:00
}
// GetGlobal gets a global variable and pushes it onto the stack
func (s *State) GetGlobal(name string) {
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.get_field_helper(s.L, C.LUA_GLOBALSINDEX, cname)
2025-01-24 19:53:09 -06:00
}
// SetGlobal pops the value on top of the stack and stores it in the globals
// table (LUA_GLOBALSINDEX) under name.
func (s *State) SetGlobal(name string) {
	key := C.CString(name)
	defer C.free(unsafe.Pointer(key))

	C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, key)
}
// Remove deletes the element at the given stack index and shifts the elements
// above it down to fill the gap (lua_remove). Absolute and relative
// (negative) indices are both accepted; pseudo-indices are not valid here.
func (s *State) Remove(index int) {
	C.lua_remove(s.L, C.int(index))
}
// DoFile loads and runs the Lua file at filename.
//
// On success, any values returned by the chunk are left on the stack (it is
// called with LUA_MULTRET). On failure the error message is popped, as Call
// and LoadString do, and returned in a *LuaError. The error's Code is the
// actual status from luaL_loadfile or lua_pcall.
func (s *State) DoFile(filename string) error {
	cfilename := C.CString(filename)
	defer C.free(unsafe.Pointer(cfilename))

	// Load and call as two steps instead of using luaL_dofile. That macro
	// joins the steps with ||, which collapses every failure status to 1.
	status := C.luaL_loadfile(s.L, cfilename)
	if status == 0 {
		status = C.protected_call(s.L, 0, C.LUA_MULTRET, 0)
	}
	if status != 0 {
		err := &LuaError{
			Code:    int(status),
			Message: s.ToString(-1),
		}
		s.Pop(1) // drop the error message so the stack is not left dirty
		return err
	}
	return nil
}
// SetPackagePath sets the Lua package.path
2025-01-24 19:53:09 -06:00
func (s *State) SetPackagePath(path string) error {
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
cmd := fmt.Sprintf(`package.path = %q`, path)
return s.DoString(cmd)
2025-01-24 19:53:09 -06:00
}
// AddPackagePath adds a path to package.path
2025-01-24 19:53:09 -06:00
func (s *State) AddPackagePath(path string) error {
path = strings.ReplaceAll(path, "\\", "/") // Convert Windows paths
cmd := fmt.Sprintf(`package.path = package.path .. ";%s"`, path)
return s.DoString(cmd)
}
// Call executes a function on the stack with the given number of arguments and results.
// The function and arguments should already be on the stack in the correct order
// (function first, then args from left to right).
func (s *State) Call(nargs, nresults int) error {
if !s.IsFunction(-nargs - 1) {
return fmt.Errorf("attempt to call a non-function")
}
status := C.protected_call(s.L, C.int(nargs), C.int(nresults), 0)
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1)
return err
2025-01-24 19:53:09 -06:00
}
return nil
}
2025-02-03 19:09:30 -06:00
// LoadString loads but does not execute a string of Lua code.
// The compiled code chunk is left on the stack.
func (s *State) LoadString(str string) error {
cstr := C.CString(str)
defer C.free(unsafe.Pointer(cstr))
status := C.load_chunk(s.L, cstr)
2025-02-03 19:09:30 -06:00
if status != 0 {
err := &LuaError{
Code: int(status),
Message: s.ToString(-1),
}
s.Pop(1)
return err
}
if !s.IsFunction(-1) {
s.Pop(1)
return fmt.Errorf("failed to load function")
}
2025-02-03 19:09:30 -06:00
return nil
}
// ExecuteString compiles and runs str, keeping every value it returns.
// It reports how many results were pushed; they sit on top of the stack for
// the caller to read and pop. On error it returns 0 and, because LoadString
// and Call pop their own error messages, leaves the stack as it found it.
func (s *State) ExecuteString(str string) (int, error) {
	before := s.GetTop()

	err := s.LoadString(str)
	if err == nil {
		err = s.Call(0, C.LUA_MULTRET)
	}
	if err != nil {
		return 0, err
	}
	return s.GetTop() - before, nil
}
// ExecuteStringResult runs code and converts its first return value to a Go
// value; any further results are ignored. A chunk that returns nothing
// yields (nil, nil). The stack is restored to its original height on every
// path.
func (s *State) ExecuteStringResult(code string) (interface{}, error) {
	saved := s.GetTop()
	defer s.SetTop(saved)

	count, err := s.ExecuteString(code)
	switch {
	case err != nil:
		return nil, fmt.Errorf("execution error: %w", err)
	case count == 0:
		return nil, nil
	}

	// The first result is the deepest of the count values just pushed.
	first, err := s.ToValue(-count)
	if err != nil {
		return nil, fmt.Errorf("error converting result: %w", err)
	}
	return first, nil
}
// DoStringResult runs code and requires it to return exactly one value, which
// is converted to a Go value. Returning zero values or more than one value is
// an error, which makes this stricter than ExecuteStringResult. The stack is
// restored to its original height on every path.
func (s *State) DoStringResult(code string) (interface{}, error) {
	saved := s.GetTop()
	defer s.SetTop(saved)

	count, execErr := s.ExecuteString(code)
	if execErr != nil {
		return nil, fmt.Errorf("execution error: %w", execErr)
	}
	if count != 1 {
		return nil, fmt.Errorf("expected 1 return value, got %d", count)
	}

	value, convErr := s.ToValue(-1)
	if convErr != nil {
		return nil, fmt.Errorf("error converting result: %w", convErr)
	}
	return value, nil
}