LuaJIT-to-Go/wrapper.go

827 lines
19 KiB
Go

package luajit
/*
#cgo !windows pkg-config: --static luajit
#cgo windows CFLAGS: -I/usr/local/include/luajit-2.1
#cgo windows LDFLAGS: -lluajit-5.1 -static
#include <lua.h>
#include <lualib.h>
#include <lauxlib.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

// do_string compiles `s` and runs it with no arguments and no results.
// Returns 0 on success, otherwise the non-zero status of luaL_loadstring or
// lua_pcall; in that case the error message is on top of the stack.
static int do_string(lua_State *L, const char *s) {
	int status = luaL_loadstring(L, s);
	if (status == 0) {
		status = lua_pcall(L, 0, 0, 0);
	}
	return status;
}

// do_file is do_string for a file path (luaL_loadfile).
static int do_file(lua_State *L, const char *filename) {
	int status = luaL_loadfile(L, filename);
	if (status == 0) {
		status = lua_pcall(L, 0, 0, 0);
	}
	return status;
}

// execute_with_results compiles `code` and runs it. When store_results is
// non-zero every value the chunk returns stays on the stack (LUA_MULTRET);
// otherwise the results are discarded. Returns the load/pcall status.
static int execute_with_results(lua_State *L, const char *code, int store_results) {
	int status = luaL_loadstring(L, code);
	if (status != 0) return status;
	return lua_pcall(L, 0, store_results ? LUA_MULTRET : 0, 0);
}

// get_table_length returns lua_objlen of the value at `index`.
static size_t get_table_length(lua_State *L, int index) {
	return lua_objlen(L, index);
}

// is_integer reports whether the value at `index` is a number with no
// fractional part that fits in lua_Integer (ptrdiff_t in LuaJIT's luaconf.h).
// The range test runs before the cast because converting NaN, +/-Inf or an
// out-of-range double to an integer type is undefined behaviour in C.
static int is_integer(lua_State *L, int index) {
	if (!lua_isnumber(L, index)) return 0;
	lua_Number n = lua_tonumber(L, index);
	// [PTRDIFF_MIN, -PTRDIFF_MIN) is exactly representable as a double.
	// Writing the test as a negation also rejects NaN.
	if (!(n >= (lua_Number)PTRDIFF_MIN && n < -(lua_Number)PTRDIFF_MIN)) return 0;
	return n == (lua_Number)(lua_Integer)n;
}

// sample_array_type inspects t[1] .. t[min(count, 5)] of the table at the
// absolute index `index` (a relative index would be shifted by the pushed
// key) and returns 1 = all integers, 2 = all numbers, 3 = all strings,
// 4 = all booleans, 0 = mixed. Only this prefix is sampled. The stack is
// left unchanged.
static int sample_array_type(lua_State *L, int index, int count) {
	int all_numbers = 1;
	int all_integers = 1;
	int all_strings = 1;
	int all_bools = 1;
	for (int i = 1; i <= count && i <= 5; i++) {
		lua_pushnumber(L, i);
		lua_gettable(L, index);
		int type = lua_type(L, -1);
		if (type != LUA_TNUMBER) all_numbers = all_integers = 0;
		if (type != LUA_TSTRING) all_strings = 0;
		if (type != LUA_TBOOLEAN) all_bools = 0;
		if (all_numbers && !is_integer(L, -1)) all_integers = 0;
		lua_pop(L, 1);
		if (!all_numbers && !all_strings && !all_bools) break;
	}
	if (all_integers) return 1;
	if (all_numbers) return 2;
	if (all_strings) return 3;
	if (all_bools) return 4;
	return 0;
}

// is_int_key reports whether the bytes s[0..len) are an optional '+' or '-'
// followed by one or more ASCII decimal digits, i.e. text Go's strconv.Atoi
// accepts (barring overflow). Empty strings and leading whitespace are
// rejected. Using the explicit length keeps embedded NULs from being
// misjudged.
static int is_int_key(const char *s, size_t len) {
	size_t i = 0;
	if (len > 0 && (s[0] == '+' || s[0] == '-')) i = 1;
	if (i == len) return 0;
	for (; i < len; i++) {
		if (s[i] < '0' || s[i] > '9') return 0;
	}
	return 1;
}

// sample_map_type inspects up to 5 key/value pairs of the table at the
// absolute index `index` and returns:
//   4 - every sampled key is a decimal-integer string (also an empty table)
//   1 - every sampled value is a string
//   2 - every sampled value is an integral number
//   3 - anything else
// The stack is restored to its entry height before returning. The scan can
// stop with a pending key (and possibly a value) still pushed; leaving them
// there would break callers that iterate the stack with lua_next.
static int sample_map_type(lua_State *L, int index) {
	int top = lua_gettop(L);
	int all_string_vals = 1;
	int all_int_vals = 1;
	int all_int_keys = 1;
	int count = 0;
	lua_pushnil(L);
	while (count < 5 && lua_next(L, index)) {
		if (lua_type(L, -2) != LUA_TSTRING) {
			all_int_keys = 0;
		} else {
			size_t len = 0;
			const char *key = lua_tolstring(L, -2, &len);
			if (!is_int_key(key, len)) all_int_keys = 0;
		}
		int val_type = lua_type(L, -1);
		if (val_type != LUA_TSTRING) all_string_vals = 0;
		if (val_type != LUA_TNUMBER || !is_integer(L, -1)) all_int_vals = 0;
		lua_pop(L, 1);
		count++;
		if (!all_string_vals && !all_int_vals && !all_int_keys) break;
	}
	lua_settop(L, top);
	if (all_int_keys) return 4;
	if (all_string_vals) return 1;
	if (all_int_vals) return 2;
	return 3;
}
*/
import "C"
import (
"fmt"
"strconv"
"strings"
"unsafe"
)
// Stack size limits.
const (
	// LUA_MINSTACK mirrors lua.h's LUA_MINSTACK: the number of free stack
	// slots the C API guarantees on entry to a C function.
	LUA_MINSTACK = 20
	// LUA_MAXSTACK is an upper bound on stack size.
	// NOTE(review): LuaJIT's own limit (LUAI_MAXSTACK in luaconf.h) is much
	// smaller than this value — confirm what callers rely on it for.
	LUA_MAXSTACK = 1000000
)

// Pseudo-indices of the Lua 5.1 C API as used by LuaJIT.
const (
	// LUA_REGISTRYINDEX addresses the registry table.
	LUA_REGISTRYINDEX = -10000
	// LUA_GLOBALSINDEX addresses the thread environment where global
	// variables live.
	LUA_GLOBALSINDEX = -10002
)
// State wraps a raw LuaJIT interpreter state. L is nil once Close has run.
type State struct {
	L *C.lua_State
}

// New creates a fresh Lua state. The standard libraries are opened unless
// the first optional argument is false. It returns nil if LuaJIT could not
// allocate the state.
func New(openLibs ...bool) *State {
	raw := C.luaL_newstate()
	if raw == nil {
		return nil
	}
	skipLibs := len(openLibs) > 0 && !openLibs[0]
	if !skipLibs {
		C.luaL_openlibs(raw)
	}
	return &State{L: raw}
}

// Close releases the underlying lua_State and clears L, so calling Close
// again is a no-op.
func (s *State) Close() {
	if s.L == nil {
		return
	}
	C.lua_close(s.L)
	s.L = nil
}
// Stack operations

// GetTop returns the index of the top stack slot, i.e. the number of values
// on the stack (lua_gettop).
func (s *State) GetTop() int {
	return int(C.lua_gettop(s.L))
}

// SetTop sets the stack top to index (lua_settop).
func (s *State) SetTop(index int) {
	C.lua_settop(s.L, C.int(index))
}

// PushCopy pushes a copy of the value at index (lua_pushvalue).
func (s *State) PushCopy(index int) {
	C.lua_pushvalue(s.L, C.int(index))
}

// Pop removes the top n values; equivalent to the lua_pop macro.
func (s *State) Pop(n int) {
	C.lua_settop(s.L, C.int(-(n + 1)))
}

// Remove deletes the value at index, shifting the values above it down
// (lua_remove).
func (s *State) Remove(index int) {
	C.lua_remove(s.L, C.int(index))
}

// absIndex converts a stack-relative (non-positive) index into the
// equivalent positive one so it stays valid after further pushes. Positive
// indices and pseudo-indices (LUA_REGISTRYINDEX and below) are returned
// unchanged.
func (s *State) absIndex(index int) int {
	if index <= 0 && index > LUA_REGISTRYINDEX {
		return s.GetTop() + index + 1
	}
	return index
}
// Type checking

// GetType returns the Lua type tag of the value at index (lua_type).
func (s *State) GetType(index int) LuaType {
	return LuaType(C.lua_type(s.L, C.int(index)))
}

// hasType reports whether the value at index has exactly the type want.
func (s *State) hasType(index int, want LuaType) bool {
	return s.GetType(index) == want
}

// IsNil reports whether the value at index is nil.
func (s *State) IsNil(index int) bool { return s.hasType(index, TypeNil) }

// IsBoolean reports whether the value at index is a boolean.
func (s *State) IsBoolean(index int) bool { return s.hasType(index, TypeBoolean) }

// IsNumber reports whether the value at index is a number or a string
// convertible to one (lua_isnumber).
func (s *State) IsNumber(index int) bool {
	return C.lua_isnumber(s.L, C.int(index)) != 0
}

// IsString reports whether the value at index is a string or a number
// (lua_isstring accepts both).
func (s *State) IsString(index int) bool {
	return C.lua_isstring(s.L, C.int(index)) != 0
}

// IsTable reports whether the value at index is a table.
func (s *State) IsTable(index int) bool { return s.hasType(index, TypeTable) }

// IsFunction reports whether the value at index is a function.
func (s *State) IsFunction(index int) bool { return s.hasType(index, TypeFunction) }
// Value conversion

// ToBoolean returns the Lua truthiness of the value at index (lua_toboolean).
func (s *State) ToBoolean(index int) bool {
	truthy := C.lua_toboolean(s.L, C.int(index))
	return truthy != 0
}

// ToNumber returns the value at index as a float64 (lua_tonumber).
func (s *State) ToNumber(index int) float64 {
	n := C.lua_tonumber(s.L, C.int(index))
	return float64(n)
}

// ToString returns the value at index as a Go string, preserving embedded
// NULs, or "" when lua_tolstring yields NULL. Per the Lua manual,
// lua_tolstring converts a number stored at index into a string in place.
func (s *State) ToString(index int) string {
	var size C.size_t
	if p := C.lua_tolstring(s.L, C.int(index), &size); p != nil {
		return C.GoStringN(p, C.int(size))
	}
	return ""
}
// Push methods

// PushNil pushes nil.
func (s *State) PushNil() {
	C.lua_pushnil(s.L)
}

// PushBoolean pushes b as a Lua boolean.
func (s *State) PushBoolean(b bool) {
	C.lua_pushboolean(s.L, boolToInt(b))
}

// PushNumber pushes n as a lua_Number.
func (s *State) PushNumber(n float64) {
	C.lua_pushnumber(s.L, C.lua_Number(n))
}
// PushString pushes str as a Lua string using its exact byte length, so
// embedded NUL bytes are preserved.
//
// Short strings go through a temporary C.CString. For longer strings the Go
// string's bytes are handed to C directly to avoid the extra copy. This is
// safe because lua_pushlstring makes its own internal copy, and string bytes
// contain no Go pointers.
func (s *State) PushString(str string) {
	if len(str) < 128 {
		cstr := C.CString(str)
		defer C.free(unsafe.Pointer(cstr))
		C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
		return
	}
	// The first word of a Go string header is its data pointer. Reading only
	// that word stays inside the two-word header; overlaying a larger struct
	// would violate unsafe.Pointer conversion rule (1).
	data := *(*unsafe.Pointer)(unsafe.Pointer(&str))
	C.lua_pushlstring(s.L, (*C.char)(data), C.size_t(len(str)))
}
// Table operations

// CreateTable pushes a new table with space preallocated for narr array
// elements and nrec hash entries (lua_createtable).
func (s *State) CreateTable(narr, nrec int) {
	C.lua_createtable(s.L, C.int(narr), C.int(nrec))
}

// NewTable pushes a new, empty table.
func (s *State) NewTable() {
	C.lua_createtable(s.L, 0, 0)
}

// GetTable pops a key and pushes t[key] for the table t at index
// (lua_gettable).
func (s *State) GetTable(index int) {
	C.lua_gettable(s.L, C.int(index))
}

// SetTable pops a value and then a key and performs t[key] = value for the
// table t at index (lua_settable).
func (s *State) SetTable(index int) {
	C.lua_settable(s.L, C.int(index))
}

// Next pops a key and pushes the next key/value pair of the table at index,
// returning false (and pushing nothing) when the traversal is finished.
func (s *State) Next(index int) bool {
	more := C.lua_next(s.L, C.int(index))
	return more != 0
}

// GetField pushes t[key] for the table t at index (lua_getfield).
func (s *State) GetField(index int, key string) {
	name := C.CString(key)
	defer C.free(unsafe.Pointer(name))
	C.lua_getfield(s.L, C.int(index), name)
}

// SetField pops a value and assigns t[key] = value for the table t at index
// (lua_setfield).
func (s *State) SetField(index int, key string) {
	name := C.CString(key)
	defer C.free(unsafe.Pointer(name))
	C.lua_setfield(s.L, C.int(index), name)
}

// GetTableLength returns lua_objlen of the value at index.
func (s *State) GetTableLength(index int) int {
	n := C.get_table_length(s.L, C.int(index))
	return int(n)
}
// PushValue converts a Go value to its Lua equivalent and pushes it.
//
// Supported inputs:
//   - nil, bool, string
//   - every Go integer kind plus float32/float64, pushed as a lua_Number
//     (float64), so integers beyond 2^53 lose precision
//   - []int, []string, []bool, []float64, []any, []map[string]any, pushed as
//     Lua sequences
//   - map[string]string, map[string]int, map[int]any, map[string]any, pushed
//     as Lua tables
//
// Elements of []any and map values are converted recursively.
//
// On error the stack is restored to its height on entry, so no partially
// built table is left behind. The error reports the first unsupported type
// encountered.
func (s *State) PushValue(v any) error {
	top := s.GetTop()
	var err error
	switch val := v.(type) {
	case nil:
		s.PushNil()
	case bool:
		s.PushBoolean(val)
	case int:
		s.PushNumber(float64(val))
	case int8:
		s.PushNumber(float64(val))
	case int16:
		s.PushNumber(float64(val))
	case int32:
		s.PushNumber(float64(val))
	case int64:
		s.PushNumber(float64(val))
	case uint:
		s.PushNumber(float64(val))
	case uint8:
		s.PushNumber(float64(val))
	case uint16:
		s.PushNumber(float64(val))
	case uint32:
		s.PushNumber(float64(val))
	case uint64:
		s.PushNumber(float64(val))
	case float32:
		s.PushNumber(float64(val))
	case float64:
		s.PushNumber(val)
	case string:
		s.PushString(val)
	case []int:
		err = s.pushIntSlice(val)
	case []string:
		err = s.pushStringSlice(val)
	case []bool:
		err = s.pushBoolSlice(val)
	case []float64:
		err = s.pushFloatSlice(val)
	case []any:
		err = s.pushAnySlice(val)
	case []map[string]any:
		err = s.pushMapSlice(val)
	case map[string]string:
		err = s.pushStringMap(val)
	case map[string]int:
		err = s.pushIntMap(val)
	case map[int]any:
		err = s.pushIntKeyMap(val)
	case map[string]any:
		err = s.pushAnyMap(val)
	default:
		return fmt.Errorf("unsupported type: %T", v)
	}
	if err != nil {
		// A composite push can fail midway (an unsupported nested element)
		// after creating its table and pushing a key; discard all of it.
		s.SetTop(top)
	}
	return err
}
// fillSequence pushes a table pre-sized for n array elements and, for
// i = 0..n-1, assigns t[i+1] the value pushed by pushElem(i) via
// lua_settable. If pushElem fails, filling stops and the error is returned
// with the partial table and the pending index still on the stack.
func (s *State) fillSequence(n int, pushElem func(i int) error) error {
	s.CreateTable(n, 0)
	for i := 0; i < n; i++ {
		s.PushNumber(float64(i + 1))
		if err := pushElem(i); err != nil {
			return err
		}
		s.SetTable(-3)
	}
	return nil
}

// pushIntSlice pushes arr as a Lua sequence of numbers.
func (s *State) pushIntSlice(arr []int) error {
	return s.fillSequence(len(arr), func(i int) error {
		s.PushNumber(float64(arr[i]))
		return nil
	})
}

// pushStringSlice pushes arr as a Lua sequence of strings.
func (s *State) pushStringSlice(arr []string) error {
	return s.fillSequence(len(arr), func(i int) error {
		s.PushString(arr[i])
		return nil
	})
}

// pushBoolSlice pushes arr as a Lua sequence of booleans.
func (s *State) pushBoolSlice(arr []bool) error {
	return s.fillSequence(len(arr), func(i int) error {
		s.PushBoolean(arr[i])
		return nil
	})
}

// pushFloatSlice pushes arr as a Lua sequence of numbers.
func (s *State) pushFloatSlice(arr []float64) error {
	return s.fillSequence(len(arr), func(i int) error {
		s.PushNumber(arr[i])
		return nil
	})
}

// pushAnySlice pushes arr as a Lua sequence, converting each element with
// PushValue. The first conversion error aborts the push.
func (s *State) pushAnySlice(arr []any) error {
	return s.fillSequence(len(arr), func(i int) error {
		return s.PushValue(arr[i])
	})
}
// fillRecord pushes a table pre-sized for len(m) hash entries and, for each
// entry in Go map iteration order, pushes the key with pushKey and the value
// with pushVal before assigning it via lua_settable. A failing pushVal
// aborts the fill, leaving the partial table and the pending key on the
// stack.
func fillRecord[K comparable, V any](s *State, m map[K]V, pushKey func(K), pushVal func(V) error) error {
	s.CreateTable(0, len(m))
	for k, v := range m {
		pushKey(k)
		if err := pushVal(v); err != nil {
			return err
		}
		s.SetTable(-3)
	}
	return nil
}

// pushStringMap pushes m as a Lua table of string keys to string values.
func (s *State) pushStringMap(m map[string]string) error {
	return fillRecord(s, m, s.PushString, func(v string) error {
		s.PushString(v)
		return nil
	})
}

// pushIntMap pushes m as a Lua table of string keys to numbers.
func (s *State) pushIntMap(m map[string]int) error {
	return fillRecord(s, m, s.PushString, func(v int) error {
		s.PushNumber(float64(v))
		return nil
	})
}

// pushIntKeyMap pushes m as a Lua table with numeric keys; values are
// converted with PushValue.
func (s *State) pushIntKeyMap(m map[int]any) error {
	return fillRecord(s, m, func(k int) { s.PushNumber(float64(k)) }, s.PushValue)
}

// pushAnyMap pushes m as a Lua table with string keys; values are converted
// with PushValue.
func (s *State) pushAnyMap(m map[string]any) error {
	return fillRecord(s, m, s.PushString, s.PushValue)
}
// ToValue converts the value at index to a Go value:
//   - nil -> nil
//   - boolean -> bool
//   - number -> int when it is integral and within int32 range, otherwise
//     float64
//   - string -> string
//   - table -> the result of ToTable
//
// Any other Lua type (functions, userdata, threads, ...) yields an error.
func (s *State) ToValue(index int) (any, error) {
	switch typ := s.GetType(index); typ {
	case TypeNil:
		return nil, nil
	case TypeBoolean:
		return s.ToBoolean(index), nil
	case TypeNumber:
		num := s.ToNumber(index)
		// Range-check before converting: Go's float-to-int conversion of
		// NaN, ±Inf or out-of-range values is implementation-dependent. The
		// comparisons are false for NaN, so NaN falls through to float64.
		if num >= -2147483648 && num <= 2147483647 && num == float64(int(num)) {
			return int(num), nil
		}
		return num, nil
	case TypeString:
		return s.ToString(index), nil
	case TypeTable:
		return s.ToTable(index)
	default:
		return nil, fmt.Errorf("unsupported type: %s", typ)
	}
}
// ToTable converts the Lua table at index into the most specific Go type
// that its sampled contents allow.
//
// A non-empty sequence (lua_objlen > 0) becomes []int, []float64, []string,
// []bool or []any, chosen by inspecting the first (up to) 5 elements.
//
// Otherwise, up to 5 key/value pairs are sampled, and the result is one of:
//   - map[int]any when the sampled keys look like integers (an empty table
//     also lands here)
//   - map[string]string
//   - map[string]int
//   - map[string]any
//
// Because only a sample is inspected, later entries that do not fit the
// chosen type are coerced (e.g. lua_tonumber yields 0 for a non-number in
// []int) or skipped.
//
// The stack height on return equals the height on entry, even if a helper
// leaves values behind.
func (s *State) ToTable(index int) (any, error) {
	absIdx := s.absIndex(index)
	if !s.IsTable(absIdx) {
		return nil, fmt.Errorf("value at index %d is not a table", index)
	}
	// Guard the caller's stack: callers often convert values in the middle of
	// a lua_next traversal, where any leftover slot would be taken as the key.
	top := s.GetTop()
	defer s.SetTop(top)

	length := s.GetTableLength(absIdx)
	if length > 0 {
		switch int(C.sample_array_type(s.L, C.int(absIdx), C.int(length))) {
		case 1: // int array
			return s.extractIntArray(absIdx, length), nil
		case 2: // float array
			return s.extractFloatArray(absIdx, length), nil
		case 3: // string array
			return s.extractStringArray(absIdx, length), nil
		case 4: // bool array
			return s.extractBoolArray(absIdx, length), nil
		default: // mixed array
			return s.extractAnyArray(absIdx, length), nil
		}
	}
	switch int(C.sample_map_type(s.L, C.int(absIdx))) {
	case 1: // map[string]string
		return s.extractStringMap(absIdx)
	case 2: // map[string]int
		return s.extractIntMap(absIdx)
	case 4: // map[int]any
		return s.extractIntKeyMap(absIdx)
	default: // map[string]any
		return s.extractAnyMap(absIdx)
	}
}
// withArrayElem pushes t[i] for the table t at the absolute stack index
// `index` (via lua_gettable), calls read with that element on top of the
// stack, then pops it.
func (s *State) withArrayElem(index, i int, read func()) {
	s.PushNumber(float64(i))
	s.GetTable(index)
	read()
	s.Pop(1)
}

// extractIntArray reads t[1..length] as ints (lua_tonumber truncated).
func (s *State) extractIntArray(index, length int) []int {
	out := make([]int, length)
	for i := range out {
		s.withArrayElem(index, i+1, func() { out[i] = int(s.ToNumber(-1)) })
	}
	return out
}

// extractFloatArray reads t[1..length] as float64 values.
func (s *State) extractFloatArray(index, length int) []float64 {
	out := make([]float64, length)
	for i := range out {
		s.withArrayElem(index, i+1, func() { out[i] = s.ToNumber(-1) })
	}
	return out
}

// extractStringArray reads t[1..length] as strings ("" for values
// lua_tolstring cannot convert).
func (s *State) extractStringArray(index, length int) []string {
	out := make([]string, length)
	for i := range out {
		s.withArrayElem(index, i+1, func() { out[i] = s.ToString(-1) })
	}
	return out
}

// extractBoolArray reads t[1..length] by Lua truthiness.
func (s *State) extractBoolArray(index, length int) []bool {
	out := make([]bool, length)
	for i := range out {
		s.withArrayElem(index, i+1, func() { out[i] = s.ToBoolean(-1) })
	}
	return out
}

// extractAnyArray reads t[1..length] with ToValue; elements ToValue cannot
// convert are left nil.
func (s *State) extractAnyArray(index, length int) []any {
	out := make([]any, length)
	for i := range out {
		s.withArrayElem(index, i+1, func() {
			if val, err := s.ToValue(-1); err == nil {
				out[i] = val
			}
		})
	}
	return out
}
// forEachPair traverses the table at the absolute index with lua_next. It
// calls visit with the key at -2 and the value at -1, then pops the value so
// the key remains for the next lua_next call. visit must leave the stack as
// it found it.
func (s *State) forEachPair(index int, visit func()) {
	s.PushNil()
	for s.Next(index) {
		visit()
		s.Pop(1)
	}
}

// extractStringMap collects string-keyed entries, reading values with
// ToString; non-string keys are skipped. The error is always nil.
func (s *State) extractStringMap(index int) (map[string]string, error) {
	out := make(map[string]string)
	s.forEachPair(index, func() {
		if s.GetType(-2) != TypeString {
			return
		}
		k := s.ToString(-2)
		v := s.ToString(-1)
		out[k] = v
	})
	return out, nil
}

// extractIntMap collects string-keyed entries, reading values as truncated
// numbers; non-string keys are skipped. The error is always nil.
func (s *State) extractIntMap(index int) (map[string]int, error) {
	out := make(map[string]int)
	s.forEachPair(index, func() {
		if s.GetType(-2) != TypeString {
			return
		}
		k := s.ToString(-2)
		v := int(s.ToNumber(-1))
		out[k] = v
	})
	return out, nil
}

// extractIntKeyMap collects entries whose key is a number or an
// Atoi-parsable string; other keys, and values ToValue cannot convert, are
// skipped. The error is always nil.
func (s *State) extractIntKeyMap(index int) (map[int]any, error) {
	out := make(map[int]any)
	s.forEachPair(index, func() {
		var k int
		switch s.GetType(-2) {
		case TypeNumber:
			k = int(s.ToNumber(-2))
		case TypeString:
			parsed, err := strconv.Atoi(s.ToString(-2))
			if err != nil {
				return
			}
			k = parsed
		default:
			return
		}
		if v, err := s.ToValue(-1); err == nil {
			out[k] = v
		}
	})
	return out, nil
}

// extractAnyMap collects entries with string or number keys. Number keys are
// formatted with strconv.FormatFloat('g', -1) rather than lua_tolstring, so
// the key on the stack is never converted in place. Other keys, and values
// ToValue cannot convert, are skipped. The error is always nil.
func (s *State) extractAnyMap(index int) (map[string]any, error) {
	out := make(map[string]any)
	s.forEachPair(index, func() {
		var k string
		switch s.GetType(-2) {
		case TypeNumber:
			k = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64)
		case TypeString:
			k = s.ToString(-2)
		default:
			return
		}
		if v, err := s.ToValue(-1); err == nil {
			out[k] = v
		}
	})
	return out, nil
}
// Global operations

// GetGlobal pushes the value of the global variable name, read from the
// table at the LUA_GLOBALSINDEX pseudo-index.
func (s *State) GetGlobal(name string) {
	s.GetField(LUA_GLOBALSINDEX, name)
}

// SetGlobal pops a value and stores it as the global variable name.
func (s *State) SetGlobal(name string) {
	s.SetField(LUA_GLOBALSINDEX, name)
}
// Code execution

// popLuaError builds the error for a failed load or call (status != 0) with
// CreateLuaError, then pops the error message the failure left on top of
// the stack.
// NOTE(review): presumably CreateLuaError reads that message; confirm.
func (s *State) popLuaError(status C.int, op string) error {
	err := s.CreateLuaError(int(status), op)
	s.Pop(1)
	return err
}

// LoadString compiles code and pushes the resulting function without
// running it. On failure nothing is left on the stack.
func (s *State) LoadString(code string) error {
	src := C.CString(code)
	defer C.free(unsafe.Pointer(src))
	if status := C.luaL_loadstring(s.L, src); status != 0 {
		return s.popLuaError(status, "LoadString")
	}
	return nil
}

// LoadFile compiles the file at filename and pushes the resulting function
// without running it. On failure nothing is left on the stack.
func (s *State) LoadFile(filename string) error {
	path := C.CString(filename)
	defer C.free(unsafe.Pointer(path))
	if status := C.luaL_loadfile(s.L, path); status != 0 {
		return s.popLuaError(status, fmt.Sprintf("LoadFile(%s)", filename))
	}
	return nil
}

// Call runs lua_pcall with nargs arguments and nresults expected results
// (no message handler). On failure the error message is popped and returned
// as an error.
func (s *State) Call(nargs, nresults int) error {
	if status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0); status != 0 {
		return s.popLuaError(status, fmt.Sprintf("Call(%d,%d)", nargs, nresults))
	}
	return nil
}

// DoString compiles and runs code, discarding any results.
func (s *State) DoString(code string) error {
	src := C.CString(code)
	defer C.free(unsafe.Pointer(src))
	if status := C.do_string(s.L, src); status != 0 {
		return s.popLuaError(status, "DoString")
	}
	return nil
}

// DoFile compiles and runs the file at filename, discarding any results.
func (s *State) DoFile(filename string) error {
	path := C.CString(filename)
	defer C.free(unsafe.Pointer(path))
	if status := C.do_file(s.L, path); status != 0 {
		return s.popLuaError(status, fmt.Sprintf("DoFile(%s)", filename))
	}
	return nil
}
// Execute compiles and runs code, leaving every value it returns on the
// stack, and reports how many values that is. On failure nothing new is
// left on the stack.
func (s *State) Execute(code string) (int, error) {
	before := s.GetTop()
	src := C.CString(code)
	defer C.free(unsafe.Pointer(src))
	if status := C.execute_with_results(s.L, src, 1); status != 0 {
		err := s.CreateLuaError(int(status), "Execute")
		s.Pop(1)
		return 0, err
	}
	return s.GetTop() - before, nil
}

// ExecuteWithResult runs code and converts its first return value with
// ToValue (nil when it returns nothing). The stack is restored afterwards.
func (s *State) ExecuteWithResult(code string) (any, error) {
	// The argument to SetTop is evaluated now, capturing the entry height.
	defer s.SetTop(s.GetTop())
	n, err := s.Execute(code)
	switch {
	case err != nil:
		return nil, err
	case n == 0:
		return nil, nil
	default:
		return s.ToValue(-n)
	}
}
// BatchExecute runs statements as a single chunk via DoString.
//
// Statements are joined with ";\n" so that each one starts on its own line.
// This has two effects:
//   - A trailing "--" comment in one statement ends at the newline, instead
//     of silently commenting out every statement after it.
//   - For single-line statements, Lua error line numbers match the 1-based
//     position of the offending statement.
func (s *State) BatchExecute(statements []string) error {
	return s.DoString(strings.Join(statements, ";\n"))
}
// Package path operations

// SetPackagePath replaces package.path with path, after converting
// backslashes to forward slashes. The path is passed to Lua as a chunk
// argument, not spliced into source text, so quotes, backslashes or control
// characters in it cannot break or inject into the executed code.
func (s *State) SetPackagePath(path string) error {
	path = strings.ReplaceAll(path, "\\", "/")
	return s.runChunkWithArg("package.path = ...", path)
}

// AddPackagePath appends ";"+path to package.path, after converting
// backslashes to forward slashes. Like SetPackagePath, the path travels as
// a chunk argument rather than as generated Lua source.
func (s *State) AddPackagePath(path string) error {
	path = strings.ReplaceAll(path, "\\", "/")
	return s.runChunkWithArg(`package.path = package.path .. ";" .. (...)`, path)
}

// runChunkWithArg compiles chunk and calls it with arg as its single vararg
// ("..."), expecting no results. Compile errors come from LoadString and
// runtime errors from Call; both pop their error message before returning.
func (s *State) runChunkWithArg(chunk, arg string) error {
	if err := s.LoadString(chunk); err != nil {
		return err
	}
	s.PushString(arg)
	return s.Call(1, 0)
}
// Metatable operations

// SetMetatable pops a table (or nil) and sets it as the metatable of the
// value at index (lua_setmetatable).
func (s *State) SetMetatable(index int) {
	C.lua_setmetatable(s.L, C.int(index))
}

// GetMetatable pushes the metatable of the value at index and reports
// whether it had one; nothing is pushed when it returns false
// (lua_getmetatable).
func (s *State) GetMetatable(index int) bool {
	found := C.lua_getmetatable(s.L, C.int(index))
	return found != 0
}

// Helper functions

// boolToInt maps true to 1 and false to 0 as a C int.
func boolToInt(b bool) C.int {
	var v C.int
	if b {
		v = 1
	}
	return v
}
// GetFieldString returns t[key] for the table t at index when lua_isstring
// accepts it (a string, or a number returned in string form); otherwise it
// returns defaultVal. The stack is unchanged on return.
func (s *State) GetFieldString(index int, key string, defaultVal string) string {
	s.GetField(index, key)
	defer s.Pop(1)
	if !s.IsString(-1) {
		return defaultVal
	}
	return s.ToString(-1)
}

// GetFieldNumber returns t[key] for the table t at index when lua_isnumber
// accepts it; otherwise it returns defaultVal. The stack is unchanged on
// return.
func (s *State) GetFieldNumber(index int, key string, defaultVal float64) float64 {
	s.GetField(index, key)
	defer s.Pop(1)
	if !s.IsNumber(-1) {
		return defaultVal
	}
	return s.ToNumber(-1)
}

// GetFieldBool returns t[key] for the table t at index when it is a
// boolean; otherwise it returns defaultVal. The stack is unchanged on
// return.
func (s *State) GetFieldBool(index int, key string, defaultVal bool) bool {
	s.GetField(index, key)
	defer s.Pop(1)
	if !s.IsBoolean(-1) {
		return defaultVal
	}
	return s.ToBoolean(-1)
}

// GetFieldTable converts t[key] with ToTable when it is a table. The boolean
// is true only when the field is a table and the conversion succeeded. The
// stack is unchanged on return.
func (s *State) GetFieldTable(index int, key string) (any, bool) {
	s.GetField(index, key)
	defer s.Pop(1)
	if !s.IsTable(-1) {
		return nil, false
	}
	val, err := s.ToTable(-1)
	if err != nil {
		return val, false
	}
	return val, true
}
// ForEachTableKV traverses the table at index and calls fn for every pair
// whose key and value are both strings or numbers (lua_isstring). Numbers
// are passed as their lua_tolstring text. Iteration stops early when fn
// returns false. The stack is restored on return, provided fn leaves it
// balanced.
func (s *State) ForEachTableKV(index int, fn func(key, value string) bool) {
	absIdx := s.absIndex(index)
	s.PushNil()
	for s.Next(absIdx) {
		if s.IsString(-2) && s.IsString(-1) {
			// lua_tolstring converts a number in place. Doing that to the
			// live key would make the next lua_next call fail, so convert a
			// copy. (Converting the value in place is harmless.)
			s.PushCopy(-2)
			key := s.ToString(-1)
			s.Pop(1)
			if !fn(key, s.ToString(-1)) {
				s.Pop(2) // key and value
				return
			}
		}
		s.Pop(1) // value; keep key for lua_next
	}
}
// ForEachArray calls fn for t[1] .. t[lua_objlen(t)] of the table t at
// index, with the element on top of the stack during the call. The element
// is popped after each call. Iteration stops as soon as fn returns false.
func (s *State) ForEachArray(index int, fn func(i int, state *State) bool) {
	tbl := s.absIndex(index)
	n := s.GetTableLength(tbl)
	for i := 1; i <= n; i++ {
		s.PushNumber(float64(i))
		s.GetTable(tbl)
		keepGoing := fn(i, s)
		s.Pop(1)
		if !keepGoing {
			return
		}
	}
}
// SafeToString returns the value at index as a string, or an error when it
// is neither a string nor a number.
func (s *State) SafeToString(index int) (string, error) {
	if convertible := s.IsString(index) || s.IsNumber(index); !convertible {
		return "", fmt.Errorf("value at index %d is not a string", index)
	}
	return s.ToString(index), nil
}

// SafeToNumber returns the value at index as a float64, or an error when
// lua_isnumber rejects it.
func (s *State) SafeToNumber(index int) (float64, error) {
	if ok := s.IsNumber(index); !ok {
		return 0, fmt.Errorf("value at index %d is not a number", index)
	}
	return s.ToNumber(index), nil
}

// SafeToTable converts the value at index with ToTable, or returns an error
// when it is not a table.
func (s *State) SafeToTable(index int) (any, error) {
	if ok := s.IsTable(index); !ok {
		return nil, fmt.Errorf("value at index %d is not a table", index)
	}
	return s.ToTable(index)
}
// CallGlobal calls the global function name with args (converted by
// PushValue) and returns all of its results, converted by ToValue. Results
// ToValue cannot convert (functions, userdata, ...) are returned as nil.
// The stack is restored to its entry height on every path.
func (s *State) CallGlobal(name string, args ...any) ([]any, error) {
	base := s.GetTop()
	s.GetGlobal(name)
	if !s.IsFunction(-1) {
		s.SetTop(base)
		return nil, fmt.Errorf("global '%s' is not a function", name)
	}
	for i, arg := range args {
		if err := s.PushValue(arg); err != nil {
			// Drop the function, the arguments pushed so far and anything a
			// failed composite push may have left behind, regardless of how
			// many slots that is.
			s.SetTop(base)
			return nil, fmt.Errorf("failed to push argument %d: %w", i+1, err)
		}
	}
	// On failure Call pops the error message and lua_pcall has already
	// removed the function and arguments, so the stack is back at base.
	if err := s.Call(len(args), C.LUA_MULTRET); err != nil {
		return nil, err
	}
	results := make([]any, s.GetTop()-base)
	for i := range results {
		if val, err := s.ToValue(base + i + 1); err == nil {
			results[i] = val
		}
	}
	s.SetTop(base)
	return results, nil
}
// pushMapSlice pushes arr as a Lua sequence whose elements are tables built
// from each map via PushValue. The first conversion error aborts the push,
// leaving the partial table and pending index on the stack.
func (s *State) pushMapSlice(arr []map[string]any) error {
	s.CreateTable(len(arr), 0)
	for idx := range arr {
		s.PushNumber(float64(idx + 1))
		if err := s.PushValue(arr[idx]); err != nil {
			return err
		}
		s.SetTable(-3) // t[idx+1] = table for arr[idx]
	}
	return nil
}