LuaJIT-to-Go/wrapper.go

323 lines
7.7 KiB
Go

package luajit
/*
#cgo CFLAGS: -I${SRCDIR}/include
#cgo windows LDFLAGS: -L${SRCDIR}/lib/windows -llua51
#cgo !windows LDFLAGS: -L${SRCDIR}/lib/linux -lluajit
#include <lua.h>
#include <lualib.h>
#include <lauxlib.h>
#include <stdlib.h>
// do_string compiles the NUL-terminated chunk s with luaL_loadstring and, if
// that succeeds, runs it with lua_pcall (no arguments, LUA_MULTRET results,
// no message handler). Returns 0 on success, otherwise the nonzero status of
// whichever step failed. Returning the raw status (rather than a 0/1 flag)
// lets the Go side report it as LuaError.Code.
static int do_string(lua_State *L, const char *s) {
int status = luaL_loadstring(L, s);
// Compilation failed: return the load status without running anything.
if (status) return status;
return lua_pcall(L, 0, LUA_MULTRET, 0);
}
// do_file is the file counterpart of do_string: luaL_loadfile, then the same
// protected call. Returns 0 on success or the nonzero status of the failing step.
static int do_file(lua_State *L, const char *filename) {
int status = luaL_loadfile(L, filename);
if (status) return status;
return lua_pcall(L, 0, LUA_MULTRET, 0);
}
*/
import "C"
import (
"fmt"
"path/filepath"
"unsafe"
)
// State wraps a *lua_State. When safeStack is true, the methods in this file
// that branch on it route their work through the stack-guard helpers
// (stackGuardErr / stackGuardValue / safeCall) or call checkStack before
// pushing. When it is false, they call the Lua C API directly.
type State struct {
// L is the underlying Lua state; Close sets it to nil after lua_close.
L *C.lua_State
// safeStack selects the guarded code paths: true from NewSafe, false from New.
safeStack bool
}
// NewSafe returns a Lua state with the standard libraries opened and
// safe-stack mode enabled, so the methods that support it run through the
// stack-guard helpers. It returns nil if luaL_newstate fails.
func NewSafe() *State {
	st := &State{safeStack: true}
	if st.L = C.luaL_newstate(); st.L == nil {
		return nil
	}
	C.luaL_openlibs(st.L)
	return st
}
// New returns a Lua state with the standard libraries opened and safe-stack
// mode disabled (the zero value of safeStack), so methods call the C API
// directly. It returns nil if luaL_newstate fails.
func New() *State {
	luaState := C.luaL_newstate()
	if luaState != nil {
		C.luaL_openlibs(luaState)
		return &State{L: luaState}
	}
	return nil
}
// Close releases the Lua state with lua_close and clears s.L, so calling it
// again is a no-op. It is also safe to call on a nil *State, which NewSafe
// and New return when luaL_newstate fails (e.g. `defer st.Close()` before a
// nil check no longer panics).
func (s *State) Close() {
	if s == nil || s.L == nil {
		return
	}
	C.lua_close(s.L)
	s.L = nil
}
// DoString compiles and runs the Lua chunk str.
//
// In safe-stack mode the call runs inside stackGuardErr/safeCall.
//
// Otherwise do_string is called directly. On failure, its status and the error
// value on top of the stack are returned as a *LuaError. That error value is
// then popped, so repeated failures do not keep growing the stack. Values
// returned by a successful chunk are still left on the stack in this path.
func (s *State) DoString(str string) error {
	cstr := C.CString(str)
	defer C.free(unsafe.Pointer(cstr))
	if s.safeStack {
		return stackGuardErr(s, func() error {
			return s.safeCall(func() C.int {
				return C.do_string(s.L, cstr)
			})
		})
	}
	if status := C.do_string(s.L, cstr); status != 0 {
		err := &LuaError{Code: int(status), Message: s.ToString(-1)}
		s.Pop(1) // discard the error value left by luaL_loadstring / lua_pcall
		return err
	}
	return nil
}
// PushValue pushes the Go value v onto the Lua stack (see pushValueUnsafe
// for the supported types).
//
// In safe-stack mode the push goes through pushValueSafe (which first reserves
// one stack slot) inside stackGuardErr. Otherwise it calls pushValueUnsafe
// directly.
// NOTE(review): stackGuardErr is defined elsewhere; confirm how it treats the
// stack top on success, since this call is meant to leave a value pushed.
func (s *State) PushValue(v interface{}) error {
	if !s.safeStack {
		return s.pushValueUnsafe(v)
	}
	return stackGuardErr(s, func() error {
		return s.pushValueSafe(v)
	})
}
// pushValueSafe asks checkStack for one free slot, then delegates to
// pushValueUnsafe. A checkStack failure is wrapped with "pushing value".
func (s *State) pushValueSafe(v interface{}) error {
	err := s.checkStack(1)
	if err == nil {
		return s.pushValueUnsafe(v)
	}
	return fmt.Errorf("pushing value: %w", err)
}
// pushValueUnsafe pushes the Go value v onto the Lua stack without reserving
// stack space first (callers that want that use pushValueSafe).
//
// Conversions:
//   - nil -> nil, bool -> boolean, string -> string
//   - float32/float64 and every built-in integer type -> number (via float64,
//     so 64-bit integers above 2^53 lose precision, as int already could)
//   - []float64 and []interface{} -> table with keys 1..len (elements of a
//     []interface{} are converted recursively)
//   - map[string]interface{} -> table via pushTableUnsafe, except that a map
//     whose "" key holds a []float64 is pushed as that array instead
//
// On success exactly one value is pushed. If an element of a []interface{}
// cannot be converted, the partially built table and its pending key are
// popped before the error is returned. This assumes the failed nested push
// left nothing behind itself. That holds for the cases here, but
// pushTableUnsafe's cleanup is not visible in this file; TODO confirm.
func (s *State) pushValueUnsafe(v interface{}) error {
	switch v := v.(type) {
	case nil:
		s.PushNil()
	case bool:
		s.PushBoolean(v)
	case string:
		s.PushString(v)
	case float64:
		s.PushNumber(v)
	case float32:
		s.PushNumber(float64(v))
	case int:
		s.PushNumber(float64(v))
	case int8:
		s.PushNumber(float64(v))
	case int16:
		s.PushNumber(float64(v))
	case int32:
		s.PushNumber(float64(v))
	case int64:
		s.PushNumber(float64(v))
	case uint:
		s.PushNumber(float64(v))
	case uint8:
		s.PushNumber(float64(v))
	case uint16:
		s.PushNumber(float64(v))
	case uint32:
		s.PushNumber(float64(v))
	case uint64:
		s.PushNumber(float64(v))
	case []float64:
		s.pushFloat64Array(v)
	case []interface{}:
		s.NewTable()
		for i, elem := range v {
			s.PushNumber(float64(i + 1))
			if err := s.pushValueUnsafe(elem); err != nil {
				// Drop the pending key and the half-built table so a failed
				// conversion does not leak values onto the stack.
				s.Pop(2)
				return fmt.Errorf("array element %d: %w", i+1, err)
			}
			s.SetTable(-3)
		}
	case map[string]interface{}:
		// Special case: an array stored in the map under the empty key.
		if arr, ok := v[""].([]float64); ok {
			s.pushFloat64Array(arr)
			return nil
		}
		return s.pushTableUnsafe(v)
	default:
		return fmt.Errorf("unsupported type: %T", v)
	}
	return nil
}

// pushFloat64Array pushes arr as a new Lua table with keys 1..len(arr).
func (s *State) pushFloat64Array(arr []float64) {
	s.NewTable()
	for i, elem := range arr {
		s.PushNumber(float64(i + 1))
		s.PushNumber(elem)
		s.SetTable(-3)
	}
}
// ToValue converts the Lua value at index into a Go value using
// toValueUnsafe. In safe-stack mode the conversion runs inside
// stackGuardValue.
func (s *State) ToValue(index int) (interface{}, error) {
	convert := func() (interface{}, error) {
		return s.toValueUnsafe(index)
	}
	if !s.safeStack {
		return convert()
	}
	return stackGuardValue[interface{}](s, convert)
}
// toValueUnsafe converts the Lua value at index without stack guards:
// nil -> nil, boolean -> bool, number -> float64, string -> string, and
// table -> whatever toTableUnsafe returns. Any other Lua type yields an
// "unsupported type" error.
//
// The type is read once with lua_type and reused. In the table branch that
// type already proves the value is a table, so no separate IsTable check is
// needed.
func (s *State) toValueUnsafe(index int) (interface{}, error) {
	switch t := s.GetType(index); t {
	case TypeNil:
		return nil, nil
	case TypeBoolean:
		return s.ToBoolean(index), nil
	case TypeNumber:
		return s.ToNumber(index), nil
	case TypeString:
		return s.ToString(index), nil
	case TypeTable:
		return s.toTableUnsafe(index)
	default:
		return nil, fmt.Errorf("unsupported type: %s", t)
	}
}
// The accessors below call the Lua C API directly; none of them branch on
// safeStack.

// GetType returns the LuaType of the value at index, as reported by lua_type.
func (s *State) GetType(index int) LuaType {
	t := C.lua_type(s.L, C.int(index))
	return LuaType(t)
}

// IsFunction reports whether the value at index has type TypeFunction.
func (s *State) IsFunction(index int) bool {
	t := s.GetType(index)
	return t == TypeFunction
}

// IsTable reports whether the value at index has type TypeTable.
func (s *State) IsTable(index int) bool {
	t := s.GetType(index)
	return t == TypeTable
}

// ToBoolean converts the value at index with lua_toboolean and reports
// whether the result is nonzero.
func (s *State) ToBoolean(index int) bool {
	flag := C.lua_toboolean(s.L, C.int(index))
	return flag != 0
}

// ToNumber converts the value at index with lua_tonumber.
func (s *State) ToNumber(index int) float64 {
	n := C.lua_tonumber(s.L, C.int(index))
	return float64(n)
}
// ToString returns the value at index as converted by lua_tolstring.
//
// The length reported by Lua is used, so strings containing embedded NUL bytes
// come back in full instead of being cut at the first NUL. A value that
// lua_tolstring cannot convert (it returns NULL) yields "".
// Per the Lua 5.1 manual, lua_tolstring converts a number to a string in
// place on the stack.
func (s *State) ToString(index int) string {
	var n C.size_t
	p := C.lua_tolstring(s.L, C.int(index), &n)
	if p == nil {
		return ""
	}
	return C.GoStringN(p, C.int(n))
}
// GetTop returns the index of the top stack element, as reported by lua_gettop.
func (s *State) GetTop() int {
	top := C.lua_gettop(s.L)
	return int(top)
}

// Pop removes n values from the top of the stack by calling
// lua_settop(L, -(n+1)), the same expansion as the lua_pop macro.
func (s *State) Pop(n int) {
	C.lua_settop(s.L, C.int(-(n + 1)))
}
// Push helpers: each calls the matching lua_push* function directly, with no
// stack-space check.

// PushNil pushes nil.
func (s *State) PushNil() {
	C.lua_pushnil(s.L)
}

// PushBoolean pushes b as a Lua boolean.
func (s *State) PushBoolean(b bool) {
	flag := C.int(bool2int(b))
	C.lua_pushboolean(s.L, flag)
}

// PushNumber pushes n as a Lua number.
func (s *State) PushNumber(n float64) {
	C.lua_pushnumber(s.L, C.double(n))
}
// PushString pushes a copy of str onto the Lua stack.
//
// lua_pushlstring receives the Go length explicitly, so strings containing
// embedded NUL bytes are pushed in full rather than truncated at the first NUL.
// C.CString copies every byte of str and appends a terminator. Lua keeps its
// own copy, so the C buffer can be freed on return.
func (s *State) PushString(str string) {
	cstr := C.CString(str)
	defer C.free(unsafe.Pointer(cstr))
	C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
}
// bool2int maps true to 1 and false to 0, for C APIs that take an int flag.
func bool2int(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}
// absIndex converts a stack-relative index into an absolute one.
// Positive indices and pseudo-indices (<= LUA_REGISTRYINDEX) are returned
// unchanged. Any other index (including 0) is offset from the current top.
func (s *State) absIndex(index int) int {
	isRelative := index <= 0 && index > LUA_REGISTRYINDEX
	if !isRelative {
		return index
	}
	return index + s.GetTop() + 1
}
// SetField performs t[key] = v, where t is the table at index and v is the
// value on top of the stack (lua_setfield pops v). In safe-stack mode, a
// negative non-pseudo index is first converted to an absolute index. The
// conversion is recomputed on every call; nothing is cached.
func (s *State) SetField(index int, key string) {
	target := index
	needsAbs := s.safeStack && index < 0 && index > LUA_REGISTRYINDEX
	if needsAbs {
		target = index + s.GetTop() + 1
	}
	ckey := C.CString(key)
	defer C.free(unsafe.Pointer(ckey))
	C.lua_setfield(s.L, C.int(target), ckey)
}
// GetField pushes t[key] onto the stack, where t is the table at index.
// In safe-stack mode, a negative non-pseudo index is converted to an absolute
// index. If checkStack(1) then reports an error, nil is pushed instead of
// performing the lookup, and the error itself is discarded.
// NOTE(review): pushing nil right after checkStack(1) failed may itself need
// the slot that was just refused; confirm what checkStack guarantees.
func (s *State) GetField(index int, key string) {
	target := index
	if s.safeStack {
		if index < 0 && index > LUA_REGISTRYINDEX {
			target = index + s.GetTop() + 1
		}
		if s.checkStack(1) != nil {
			s.PushNil()
			return
		}
	}
	ckey := C.CString(key)
	defer C.free(unsafe.Pointer(ckey))
	C.lua_getfield(s.L, C.int(target), ckey)
}
// GetGlobal pushes the global variable name onto the stack (lua_getfield on
// LUA_GLOBALSINDEX). In safe-stack mode, a checkStack(1) error makes it push
// nil instead, and the error is discarded.
// NOTE(review): as in GetField, pushing nil after a failed checkStack may
// need the very slot that was refused; confirm checkStack's semantics.
func (s *State) GetGlobal(name string) {
	if s.safeStack && s.checkStack(1) != nil {
		s.PushNil()
		return
	}
	cname := C.CString(name)
	defer C.free(unsafe.Pointer(cname))
	C.lua_getfield(s.L, C.LUA_GLOBALSINDEX, cname)
}
// SetGlobal pops the value on top of the stack and stores it as the global
// variable name (lua_setfield on LUA_GLOBALSINDEX). No stack-space check is
// made because nothing is pushed.
func (s *State) SetGlobal(name string) {
	cname := C.CString(name)
	defer C.free(unsafe.Pointer(cname))
	C.lua_setfield(s.L, C.LUA_GLOBALSINDEX, cname)
}
// Remove deletes the stack element at index via lua_remove. In safe-stack
// mode, a negative non-pseudo index is first converted to an absolute index.
func (s *State) Remove(index int) {
	pos := index
	if s.safeStack && index > LUA_REGISTRYINDEX && index < 0 {
		pos += s.GetTop() + 1
	}
	C.lua_remove(s.L, C.int(pos))
}
// DoFile loads the Lua file at filename and runs it.
//
// In safe-stack mode the call runs inside stackGuardErr/safeCall.
//
// Otherwise do_file is called directly. On failure, its status and the error
// value on top of the stack are returned as a *LuaError. That error value is
// then popped, so repeated failures do not keep growing the stack. Values
// returned by a successful chunk are still left on the stack in this path.
func (s *State) DoFile(filename string) error {
	cfilename := C.CString(filename)
	defer C.free(unsafe.Pointer(cfilename))
	if s.safeStack {
		return stackGuardErr(s, func() error {
			return s.safeCall(func() C.int {
				return C.do_file(s.L, cfilename)
			})
		})
	}
	if status := C.do_file(s.L, cfilename); status != 0 {
		err := &LuaError{Code: int(status), Message: s.ToString(-1)}
		s.Pop(1) // discard the error value left by luaL_loadfile / lua_pcall
		return err
	}
	return nil
}
// SetPackagePath replaces package.path with path, after converting OS
// separators to '/'. It runs a generated Lua chunk through DoString and wraps
// any failure with "setting package.path".
// NOTE(review): %q emits Go escape syntax. `"` and `\` become escapes Lua also
// accepts, but non-printable runes become \uXXXX, which LuaJIT may reject.
// Confirm whether paths can contain such characters.
func (s *State) SetPackagePath(path string) error {
	chunk := fmt.Sprintf(`package.path = %q`, filepath.ToSlash(path))
	err := s.DoString(chunk)
	if err == nil {
		return nil
	}
	return fmt.Errorf("setting package.path: %w", err)
}
// AddPackagePath appends ";" + path (OS separators converted to '/') to
// package.path by running a generated Lua chunk through DoString. Failures are
// wrapped with "adding to package.path".
//
// The appended segment is quoted with %q, as SetPackagePath does. A path
// containing '"' or '\' therefore cannot end the Lua string literal early or
// inject code; previously it was spliced in raw.
func (s *State) AddPackagePath(path string) error {
	path = filepath.ToSlash(path)
	chunk := fmt.Sprintf(`package.path = package.path .. %q`, ";"+path)
	if err := s.DoString(chunk); err != nil {
		return fmt.Errorf("adding to package.path: %w", err)
	}
	return nil
}