// Package luajit wraps a Lua 5.1 / LuaJIT C API state behind a small Go type.
package luajit

/*
#cgo !windows pkg-config: --static luajit
#cgo windows CFLAGS: -I/usr/local/include/luajit-2.1
#cgo windows LDFLAGS: -lluajit-5.1 -static

// NOTE(review): the five include targets were missing from the reviewed
// source; these are the headers the code below requires. Confirm against the
// original file.
#include <lua.h>
#include <lualib.h>
#include <lauxlib.h>
#include <stdlib.h>
#include <string.h>

// do_string compiles s and, if compilation succeeded, runs it in protected
// mode discarding all results. Returns the Lua status code.
static int do_string(lua_State *L, const char *s) {
	int status = luaL_loadstring(L, s);
	if (status == 0) {
		status = lua_pcall(L, 0, 0, 0);
	}
	return status;
}

// do_file is do_string for a chunk loaded from filename.
static int do_file(lua_State *L, const char *filename) {
	int status = luaL_loadfile(L, filename);
	if (status == 0) {
		status = lua_pcall(L, 0, 0, 0);
	}
	return status;
}

// execute_with_results compiles and runs code; when store_results is non-zero
// every result is left on the stack, otherwise all results are discarded.
static int execute_with_results(lua_State *L, const char *code, int store_results) {
	int status = luaL_loadstring(L, code);
	if (status != 0) return status;
	return lua_pcall(L, 0, store_results ? LUA_MULTRET : 0, 0);
}

static size_t get_table_length(lua_State *L, int index) {
	return lua_objlen(L, index);
}

// is_integer reports whether the value at index is a number with no
// fractional part.
static int is_integer(lua_State *L, int index) {
	if (!lua_isnumber(L, index)) return 0;
	lua_Number n = lua_tonumber(L, index);
	return n == (lua_Number)(lua_Integer)n;
}

// sample_array_type classifies elements 1..count of the table at index
// (which must be an absolute index, since the stack grows during the scan).
// Returns 1 = all integers, 2 = all numbers, 3 = all strings, 4 = all
// booleans, 0 = mixed.
//
// Every element is inspected: classifying from a prefix only would let a
// later element of a different type be silently coerced by the typed
// extractors on the Go side.
static int sample_array_type(lua_State *L, int index, int count) {
	int all_numbers = 1;
	int all_integers = 1;
	int all_strings = 1;
	int all_bools = 1;

	for (int i = 1; i <= count; i++) {
		lua_pushnumber(L, i);
		lua_gettable(L, index);

		int type = lua_type(L, -1);
		if (type != LUA_TNUMBER) all_numbers = all_integers = 0;
		if (type != LUA_TSTRING) all_strings = 0;
		if (type != LUA_TBOOLEAN) all_bools = 0;
		if (all_numbers && !is_integer(L, -1)) all_integers = 0;

		lua_pop(L, 1);

		// Nothing homogeneous is still possible; stop early.
		if (!all_numbers && !all_strings && !all_bools) break;
	}

	if (all_integers) return 1;
	if (all_numbers) return 2;
	if (all_strings) return 3;
	if (all_bools) return 4;
	return 0;
}
*/
import "C"

import (
	"fmt"
	"strconv"
	"strings"
	"unsafe"
)

// Stack management constants. The names mirror the C API macros.
const (
	LUA_MINSTACK      = 20
	LUA_MAXSTACK      = 1000000
	LUA_REGISTRYINDEX = -10000
	LUA_GLOBALSINDEX  = -10002
)

// State owns a single lua_State pointer.
type State struct {
	L *C.lua_State
}

// New creates a Lua state. The standard libraries are opened unless the first
// openLibs value is false. It returns nil when the C state cannot be created.
func New(openLibs ...bool) *State {
	L := C.luaL_newstate()
	if L == nil {
		return nil
	}

	if len(openLibs) == 0 || openLibs[0] {
		C.luaL_openlibs(L)
	}

	return &State{L: L}
}

// Close releases the underlying Lua state. Calling it again is a no-op.
func (s *State) Close() {
	if s.L != nil {
		C.lua_close(s.L)
		s.L = nil
	}
}

// Stack operations

// GetTop returns the index of the top stack element (the stack size).
func (s *State) GetTop() int {
	return int(C.lua_gettop(s.L))
}

// SetTop sets the stack top to index.
func (s *State) SetTop(index int) {
	C.lua_settop(s.L, C.int(index))
}

// PushCopy pushes a copy of the value at index.
func (s *State) PushCopy(index int) {
	C.lua_pushvalue(s.L, C.int(index))
}

// Pop removes the top n values from the stack.
func (s *State) Pop(n int) {
	C.lua_settop(s.L, C.int(-n-1))
}

// Remove removes the value at index.
func (s *State) Remove(index int) {
	C.lua_remove(s.L, C.int(index))
}

// absIndex converts a top-relative (negative) index to an absolute one.
// Positive indices and pseudo-indices (<= LUA_REGISTRYINDEX) are returned
// unchanged.
func (s *State) absIndex(index int) int {
	if index > 0 || index <= LUA_REGISTRYINDEX {
		return index
	}
	return s.GetTop() + index + 1
}

// Type checking

// GetType returns the Lua type of the value at index.
func (s *State) GetType(index int) LuaType {
	return LuaType(C.lua_type(s.L, C.int(index)))
}

// IsNil reports whether the value at index has type nil.
func (s *State) IsNil(index int) bool {
	return s.GetType(index) == TypeNil
}

// IsBoolean reports whether the value at index has type boolean.
func (s *State) IsBoolean(index int) bool {
	return s.GetType(index) == TypeBoolean
}

// IsNumber reports the result of lua_isnumber for the value at index.
func (s *State) IsNumber(index int) bool {
	return C.lua_isnumber(s.L, C.int(index)) != 0
}

// IsString reports the result of lua_isstring for the value at index.
func (s *State) IsString(index int) bool {
	return C.lua_isstring(s.L, C.int(index)) != 0
}

// IsTable reports whether the value at index has type table.
func (s *State) IsTable(index int) bool {
	return s.GetType(index) == TypeTable
}

// IsFunction reports whether the value at index has type function.
func (s *State) IsFunction(index int) bool {
	return s.GetType(index) == TypeFunction
}

// Value conversion

// ToBoolean returns the lua_toboolean result for the value at index.
func (s *State) ToBoolean(index int) bool {
	return C.lua_toboolean(s.L, C.int(index)) != 0
}

// ToNumber returns the lua_tonumber result for the value at index.
func (s *State) ToNumber(index int) float64 {
	return float64(C.lua_tonumber(s.L, C.int(index)))
}

// ToString returns the lua_tolstring result for the value at index as a Go
// string, or "" when lua_tolstring yields NULL.
//
// lua_tolstring converts a number on the stack to a string in place, so do
// not call this on a key while iterating a table with Next.
func (s *State) ToString(index int) string {
	var length C.size_t
	cstr := C.lua_tolstring(s.L, C.int(index), &length)
	if cstr == nil {
		return ""
	}
	return C.GoStringN(cstr, C.int(length))
}

// Push methods

// PushNil pushes nil.
func (s *State) PushNil() {
	C.lua_pushnil(s.L)
}

// PushBoolean pushes b.
func (s *State) PushBoolean(b bool) {
	C.lua_pushboolean(s.L, boolToInt(b))
}

// PushNumber pushes n.
func (s *State) PushNumber(n float64) {
	C.lua_pushnumber(s.L, C.lua_Number(n))
}

// PushString pushes str, including any embedded NUL bytes.
//
// Strings shorter than 128 bytes go through a temporary C copy. Longer
// strings hand Lua a pointer straight into the Go string's bytes to skip that
// extra copy; this relies on lua_pushlstring making its own copy before
// returning.
func (s *State) PushString(str string) {
	if len(str) < 128 {
		cstr := C.CString(str)
		defer C.free(unsafe.Pointer(cstr))
		C.lua_pushlstring(s.L, cstr, C.size_t(len(str)))
		return
	}

	// A string header is two words (data pointer, length). The overlay must
	// not be larger than the string header it reinterprets.
	header := (*struct {
		data unsafe.Pointer
		len  int
	})(unsafe.Pointer(&str))
	C.lua_pushlstring(s.L, (*C.char)(header.data), C.size_t(len(str)))
}

// Table operations

// CreateTable pushes a new table pre-sized for narr array and nrec hash slots.
func (s *State) CreateTable(narr, nrec int) {
	C.lua_createtable(s.L, C.int(narr), C.int(nrec))
}

// NewTable pushes a new empty table.
func (s *State) NewTable() {
	C.lua_createtable(s.L, 0, 0)
}

// GetTable calls lua_gettable on the table at index.
func (s *State) GetTable(index int) {
	C.lua_gettable(s.L, C.int(index))
}

// SetTable calls lua_settable on the table at index.
func (s *State) SetTable(index int) {
	C.lua_settable(s.L, C.int(index))
}

// Next calls lua_next on the table at index and reports whether another
// key/value pair was pushed.
func (s *State) Next(index int) bool {
	return C.lua_next(s.L, C.int(index)) != 0
}

// GetField pushes t[key] for the table t at index.
func (s *State) GetField(index int, key string) {
	ckey := C.CString(key)
	defer C.free(unsafe.Pointer(ckey))
	C.lua_getfield(s.L, C.int(index), ckey)
}

// SetField pops the top value and stores it as t[key] for the table t at index.
func (s *State) SetField(index int, key string) {
	ckey := C.CString(key)
	defer C.free(unsafe.Pointer(ckey))
	C.lua_setfield(s.L, C.int(index), ckey)
}

// GetTableLength returns lua_objlen for the value at index.
func (s *State) GetTableLength(index int) int {
	return int(C.get_table_length(s.L, C.int(index)))
}

// PushValue pushes the Lua equivalent of v: nil, bool, int, int64, float64,
// string, []byte, the slice and map shapes listed in the switch below, and
// any nesting of those through []any / map[...]any.
//
// On error nothing is left on the stack: a failure deep inside a nested value
// would otherwise strand the partially built tables and a pending key.
func (s *State) PushValue(v any) error {
	// A composite needs a table, a key and a value slot at this level; nested
	// values reserve their own space when PushValue recurses.
	if C.lua_checkstack(s.L, 3) == 0 {
		return fmt.Errorf("lua stack overflow while pushing %T", v)
	}
	top := s.GetTop()

	var err error
	switch val := v.(type) {
	case nil:
		s.PushNil()
	case bool:
		s.PushBoolean(val)
	case int:
		s.PushNumber(float64(val))
	case int64:
		s.PushNumber(float64(val))
	case float64:
		s.PushNumber(val)
	case string:
		s.PushString(val)
	case []byte:
		s.PushString(string(val))
	case []int:
		err = s.pushIntSlice(val)
	case []string:
		err = s.pushStringSlice(val)
	case []bool:
		err = s.pushBoolSlice(val)
	case []float64:
		err = s.pushFloatSlice(val)
	case []any:
		err = s.pushAnySlice(val)
	case []map[string]any:
		err = s.pushMapSlice(val)
	case map[string]string:
		err = s.pushStringMap(val)
	case map[string]int:
		err = s.pushIntMap(val)
	case map[int]any:
		err = s.pushIntKeyMap(val)
	case map[string]any:
		err = s.pushAnyMap(val)
	default:
		err = fmt.Errorf("unsupported type: %T", v)
	}

	if err != nil {
		s.SetTop(top)
	}
	return err
}

// pushIntSlice pushes arr as a 1-based Lua array.
func (s *State) pushIntSlice(arr []int) error {
	s.CreateTable(len(arr), 0)
	for i, v := range arr {
		s.PushNumber(float64(i + 1))
		s.PushNumber(float64(v))
		s.SetTable(-3)
	}
	return nil
}

// pushStringSlice pushes arr as a 1-based Lua array.
func (s *State) pushStringSlice(arr []string) error {
	s.CreateTable(len(arr), 0)
	for i, v := range arr {
		s.PushNumber(float64(i + 1))
		s.PushString(v)
		s.SetTable(-3)
	}
	return nil
}

// pushBoolSlice pushes arr as a 1-based Lua array.
func (s *State) pushBoolSlice(arr []bool) error {
	s.CreateTable(len(arr), 0)
	for i, v := range arr {
		s.PushNumber(float64(i + 1))
		s.PushBoolean(v)
		s.SetTable(-3)
	}
	return nil
}

// pushFloatSlice pushes arr as a 1-based Lua array.
func (s *State) pushFloatSlice(arr []float64) error {
	s.CreateTable(len(arr), 0)
	for i, v := range arr {
		s.PushNumber(float64(i + 1))
		s.PushNumber(v)
		s.SetTable(-3)
	}
	return nil
}

// pushAnySlice pushes arr as a 1-based Lua array, converting each element
// with PushValue. On error the table and pending key remain on the stack;
// the calling PushValue restores the stack.
func (s *State) pushAnySlice(arr []any) error {
	s.CreateTable(len(arr), 0)
	for i, v := range arr {
		s.PushNumber(float64(i + 1))
		if err := s.PushValue(v); err != nil {
			return err
		}
		s.SetTable(-3)
	}
	return nil
}

// pushMapSlice pushes arr as a 1-based Lua array of tables. On error the
// table and pending key remain on the stack; the calling PushValue restores
// the stack.
func (s *State) pushMapSlice(arr []map[string]any) error {
	s.CreateTable(len(arr), 0)
	for i, m := range arr {
		s.PushNumber(float64(i + 1))
		if err := s.PushValue(m); err != nil {
			return err
		}
		s.SetTable(-3)
	}
	return nil
}

// pushStringMap pushes m as a Lua table with string keys.
func (s *State) pushStringMap(m map[string]string) error {
	s.CreateTable(0, len(m))
	for k, v := range m {
		s.PushString(k)
		s.PushString(v)
		s.SetTable(-3)
	}
	return nil
}

// pushIntMap pushes m as a Lua table with string keys.
func (s *State) pushIntMap(m map[string]int) error {
	s.CreateTable(0, len(m))
	for k, v := range m {
		s.PushString(k)
		s.PushNumber(float64(v))
		s.SetTable(-3)
	}
	return nil
}

// pushIntKeyMap pushes m as a Lua table with numeric keys. On error the table
// and pending key remain on the stack; the calling PushValue restores the
// stack.
func (s *State) pushIntKeyMap(m map[int]any) error {
	s.CreateTable(0, len(m))
	for k, v := range m {
		s.PushNumber(float64(k))
		if err := s.PushValue(v); err != nil {
			return err
		}
		s.SetTable(-3)
	}
	return nil
}

// pushAnyMap pushes m as a Lua table with string keys. On error the table and
// pending key remain on the stack; the calling PushValue restores the stack.
func (s *State) pushAnyMap(m map[string]any) error {
	s.CreateTable(0, len(m))
	for k, v := range m {
		s.PushString(k)
		if err := s.PushValue(v); err != nil {
			return err
		}
		s.SetTable(-3)
	}
	return nil
}

// ToValue converts the Lua value at index to a Go value: nil, bool, int (for
// whole numbers within the 32-bit signed range), float64, string, or the
// result of ToTable. Other Lua types produce an error.
func (s *State) ToValue(index int) (any, error) {
	switch s.GetType(index) {
	case TypeNil:
		return nil, nil
	case TypeBoolean:
		return s.ToBoolean(index), nil
	case TypeNumber:
		num := s.ToNumber(index)
		// Range-check first: converting an out-of-range (or NaN) float64 to
		// int yields an implementation-dependent value.
		if num >= -2147483648 && num <= 2147483647 && num == float64(int(num)) {
			return int(num), nil
		}
		return num, nil
	case TypeString:
		return s.ToString(index), nil
	case TypeTable:
		return s.ToTable(index)
	default:
		return nil, fmt.Errorf("unsupported type: %s", s.GetType(index))
	}
}

// ToTable converts the Lua table at index to a Go value.
//
// A table whose length (lua_objlen) is positive is treated as an array and
// returned as []int, []float64, []string, []bool or []any depending on its
// elements; any non-array keys of such a table are not included. Otherwise
// the table is returned as map[string]any.
func (s *State) ToTable(index int) (any, error) {
	absIdx := s.absIndex(index)
	if !s.IsTable(absIdx) {
		return nil, fmt.Errorf("value at index %d is not a table", index)
	}

	// Extraction pushes a key and a value per nesting level.
	if C.lua_checkstack(s.L, 3) == 0 {
		return nil, fmt.Errorf("lua stack overflow while reading table at index %d", index)
	}

	length := s.GetTableLength(absIdx)
	if length == 0 {
		return s.extractAnyMap(absIdx)
	}

	switch int(C.sample_array_type(s.L, C.int(absIdx), C.int(length))) {
	case 1: // int array
		return s.extractIntArray(absIdx, length), nil
	case 2: // float array
		return s.extractFloatArray(absIdx, length), nil
	case 3: // string array
		return s.extractStringArray(absIdx, length), nil
	case 4: // bool array
		return s.extractBoolArray(absIdx, length), nil
	default: // mixed array
		return s.extractAnyArray(absIdx, length), nil
	}
}

// extractIntArray reads elements 1..length of the table at the absolute index.
func (s *State) extractIntArray(index, length int) []int {
	result := make([]int, length)
	for i := 1; i <= length; i++ {
		s.PushNumber(float64(i))
		s.GetTable(index)
		result[i-1] = int(s.ToNumber(-1))
		s.Pop(1)
	}
	return result
}

// extractFloatArray reads elements 1..length of the table at the absolute index.
func (s *State) extractFloatArray(index, length int) []float64 {
	result := make([]float64, length)
	for i := 1; i <= length; i++ {
		s.PushNumber(float64(i))
		s.GetTable(index)
		result[i-1] = s.ToNumber(-1)
		s.Pop(1)
	}
	return result
}

// extractStringArray reads elements 1..length of the table at the absolute index.
func (s *State) extractStringArray(index, length int) []string {
	result := make([]string, length)
	for i := 1; i <= length; i++ {
		s.PushNumber(float64(i))
		s.GetTable(index)
		result[i-1] = s.ToString(-1)
		s.Pop(1)
	}
	return result
}

// extractBoolArray reads elements 1..length of the table at the absolute index.
func (s *State) extractBoolArray(index, length int) []bool {
	result := make([]bool, length)
	for i := 1; i <= length; i++ {
		s.PushNumber(float64(i))
		s.GetTable(index)
		result[i-1] = s.ToBoolean(-1)
		s.Pop(1)
	}
	return result
}

// extractAnyArray reads elements 1..length of the table at the absolute index
// through ToValue. Elements that ToValue cannot convert are left as nil.
func (s *State) extractAnyArray(index, length int) []any {
	result := make([]any, length)
	for i := 1; i <= length; i++ {
		s.PushNumber(float64(i))
		s.GetTable(index)
		if val, err := s.ToValue(-1); err == nil {
			result[i-1] = val
		}
		s.Pop(1)
	}
	return result
}

// extractAnyMap reads every pair of the table at the absolute index into a
// map. String keys are used as is, numeric keys are formatted with
// strconv.FormatFloat, and pairs with any other key type are skipped. Values
// that ToValue cannot convert are skipped as well. The returned error is
// always nil.
func (s *State) extractAnyMap(index int) (map[string]any, error) {
	result := make(map[string]any)
	s.PushNil()
	for s.Next(index) {
		var key string
		switch s.GetType(-2) {
		case TypeString:
			key = s.ToString(-2)
		case TypeNumber:
			// ToNumber does not modify the key, so the traversal stays valid.
			key = strconv.FormatFloat(s.ToNumber(-2), 'g', -1, 64)
		default:
			s.Pop(1)
			continue
		}

		if value, err := s.ToValue(-1); err == nil {
			result[key] = value
		}
		s.Pop(1)
	}
	return result, nil
}

// Global operations

// GetGlobal pushes the global called name.
func (s *State) GetGlobal(name string) {
	s.GetField(LUA_GLOBALSINDEX, name)
}

// SetGlobal pops the top value and stores it as the global called name.
func (s *State) SetGlobal(name string) {
	s.SetField(LUA_GLOBALSINDEX, name)
}

// Code execution

// LoadString compiles code and leaves the resulting function on the stack.
// On failure the error message is popped and returned as an error.
func (s *State) LoadString(code string) error {
	ccode := C.CString(code)
	defer C.free(unsafe.Pointer(ccode))

	status := C.luaL_loadstring(s.L, ccode)
	if status != 0 {
		err := s.CreateLuaError(int(status), "LoadString")
		s.Pop(1)
		return err
	}
	return nil
}

// LoadFile compiles the file and leaves the resulting function on the stack.
// On failure the error message is popped and returned as an error.
func (s *State) LoadFile(filename string) error {
	cfilename := C.CString(filename)
	defer C.free(unsafe.Pointer(cfilename))

	status := C.luaL_loadfile(s.L, cfilename)
	if status != 0 {
		err := s.CreateLuaError(int(status), fmt.Sprintf("LoadFile(%s)", filename))
		s.Pop(1)
		return err
	}
	return nil
}

// Call invokes lua_pcall with nargs arguments and nresults results. On
// failure the error message is popped and returned as an error.
func (s *State) Call(nargs, nresults int) error {
	status := C.lua_pcall(s.L, C.int(nargs), C.int(nresults), 0)
	if status != 0 {
		err := s.CreateLuaError(int(status), fmt.Sprintf("Call(%d,%d)", nargs, nresults))
		s.Pop(1)
		return err
	}
	return nil
}

// DoString compiles and runs code, discarding its results.
func (s *State) DoString(code string) error {
	ccode := C.CString(code)
	defer C.free(unsafe.Pointer(ccode))

	status := C.do_string(s.L, ccode)
	if status != 0 {
		err := s.CreateLuaError(int(status), "DoString")
		s.Pop(1)
		return err
	}
	return nil
}

// DoFile compiles and runs the file, discarding its results.
func (s *State) DoFile(filename string) error {
	cfilename := C.CString(filename)
	defer C.free(unsafe.Pointer(cfilename))

	status := C.do_file(s.L, cfilename)
	if status != 0 {
		err := s.CreateLuaError(int(status), fmt.Sprintf("DoFile(%s)", filename))
		s.Pop(1)
		return err
	}
	return nil
}

// Execute compiles and runs code, leaves all of its results on the stack and
// returns how many there are.
func (s *State) Execute(code string) (int, error) {
	baseTop := s.GetTop()

	ccode := C.CString(code)
	defer C.free(unsafe.Pointer(ccode))

	status := C.execute_with_results(s.L, ccode, 1)
	if status != 0 {
		err := s.CreateLuaError(int(status), "Execute")
		s.Pop(1)
		return 0, err
	}
	return s.GetTop() - baseTop, nil
}

// ExecuteWithResult runs code and returns its first result converted with
// ToValue (nil when there are no results). The stack is restored on return.
func (s *State) ExecuteWithResult(code string) (any, error) {
	top := s.GetTop()
	defer s.SetTop(top)

	nresults, err := s.Execute(code)
	if err != nil {
		return nil, err
	}
	if nresults == 0 {
		return nil, nil
	}
	return s.ToValue(-nresults)
}

// BatchExecute runs statements as one chunk, joined with "; ".
func (s *State) BatchExecute(statements []string) error {
	return s.DoString(strings.Join(statements, "; "))
}

// Package path operations

// SetPackagePath replaces package.path with path (backslashes become slashes).
func (s *State) SetPackagePath(path string) error {
	path = strings.ReplaceAll(path, "\\", "/")
	return s.DoString(fmt.Sprintf(`package.path = %q`, path))
}

// AddPackagePath appends ";"+path to package.path (backslashes become
// slashes). The path is quoted the same way SetPackagePath quotes it, so a
// double quote inside path cannot terminate the Lua string literal early.
func (s *State) AddPackagePath(path string) error {
	path = strings.ReplaceAll(path, "\\", "/")
	return s.DoString(fmt.Sprintf(`package.path = package.path .. %q`, ";"+path))
}

// Metatable operations

// SetMetatable calls lua_setmetatable for the value at index.
func (s *State) SetMetatable(index int) {
	C.lua_setmetatable(s.L, C.int(index))
}

// GetMetatable calls lua_getmetatable for the value at index and reports
// whether a metatable was pushed.
func (s *State) GetMetatable(index int) bool {
	return C.lua_getmetatable(s.L, C.int(index)) != 0
}

// Helper functions

// boolToInt maps true to 1 and false to 0.
func boolToInt(b bool) C.int {
	if b {
		return 1
	}
	return 0
}

// GetFieldString returns t[key] as a string, or defaultVal when IsString is
// false for that field.
func (s *State) GetFieldString(index int, key string, defaultVal string) string {
	s.GetField(index, key)
	defer s.Pop(1)

	if s.IsString(-1) {
		return s.ToString(-1)
	}
	return defaultVal
}

// GetFieldNumber returns t[key] as a number, or defaultVal when IsNumber is
// false for that field.
func (s *State) GetFieldNumber(index int, key string, defaultVal float64) float64 {
	s.GetField(index, key)
	defer s.Pop(1)

	if s.IsNumber(-1) {
		return s.ToNumber(-1)
	}
	return defaultVal
}

// GetFieldBool returns t[key] as a boolean, or defaultVal when the field is
// not of type boolean.
func (s *State) GetFieldBool(index int, key string, defaultVal bool) bool {
	s.GetField(index, key)
	defer s.Pop(1)

	if s.IsBoolean(-1) {
		return s.ToBoolean(-1)
	}
	return defaultVal
}

// GetFieldTable returns t[key] converted with ToTable. The boolean is false
// when the field is not a table or the conversion failed.
func (s *State) GetFieldTable(index int, key string) (any, bool) {
	s.GetField(index, key)
	defer s.Pop(1)

	if s.IsTable(-1) {
		val, err := s.ToTable(-1)
		return val, err == nil
	}
	return nil, false
}

// ForEachTableKV calls fn for every pair of the table at index whose key and
// value both satisfy IsString (which also holds for numbers). Iteration stops
// when fn returns false. The stack is balanced on return.
func (s *State) ForEachTableKV(index int, fn func(key, value string) bool) {
	absIdx := s.absIndex(index)
	s.PushNil()
	for s.Next(absIdx) {
		if s.IsString(-2) && s.IsString(-1) {
			value := s.ToString(-1)

			// Convert a copy of the key: lua_tolstring rewrites a numeric
			// stack slot in place, and a rewritten key breaks the next
			// lua_next call.
			s.PushCopy(-2)
			key := s.ToString(-1)
			s.Pop(1)

			if !fn(key, value) {
				s.Pop(2) // value and key
				return
			}
		}
		s.Pop(1) // value; the key stays for the next iteration
	}
}

// ForEachArray calls fn for indices 1..GetTableLength of the table at index.
// During each call the element is at stack index -1. Iteration stops when fn
// returns false.
func (s *State) ForEachArray(index int, fn func(i int, state *State) bool) {
	absIdx := s.absIndex(index)
	length := s.GetTableLength(absIdx)
	for i := 1; i <= length; i++ {
		s.PushNumber(float64(i))
		s.GetTable(absIdx)
		keepGoing := fn(i, s)
		s.Pop(1)
		if !keepGoing {
			return
		}
	}
}

// SafeToString returns the value at index as a string, or an error when it is
// neither a string nor a number.
func (s *State) SafeToString(index int) (string, error) {
	if !s.IsString(index) && !s.IsNumber(index) {
		return "", fmt.Errorf("value at index %d is not a string", index)
	}
	return s.ToString(index), nil
}

// SafeToNumber returns the value at index as a number, or an error when
// IsNumber is false for it.
func (s *State) SafeToNumber(index int) (float64, error) {
	if !s.IsNumber(index) {
		return 0, fmt.Errorf("value at index %d is not a number", index)
	}
	return s.ToNumber(index), nil
}

// SafeToTable returns the value at index converted with ToTable, or an error
// when it is not a table.
func (s *State) SafeToTable(index int) (any, error) {
	if !s.IsTable(index) {
		return nil, fmt.Errorf("value at index %d is not a table", index)
	}
	return s.ToTable(index)
}

// CallGlobal calls the global function name with args (converted with
// PushValue) and returns all of its results converted with ToValue. A result
// that ToValue cannot convert is returned as nil. The stack is restored to
// its height at entry on every path.
func (s *State) CallGlobal(name string, args ...any) ([]any, error) {
	baseTop := s.GetTop()

	// Room for the function plus every argument.
	if C.lua_checkstack(s.L, C.int(len(args)+1)) == 0 {
		return nil, fmt.Errorf("lua stack overflow calling '%s' with %d arguments", name, len(args))
	}

	s.GetGlobal(name)
	if !s.IsFunction(-1) {
		s.SetTop(baseTop)
		return nil, fmt.Errorf("global '%s' is not a function", name)
	}

	for i, arg := range args {
		if err := s.PushValue(arg); err != nil {
			// Restore to the recorded height rather than counting pushes, so
			// nothing stranded by a failed nested push survives.
			s.SetTop(baseTop)
			return nil, fmt.Errorf("failed to push argument %d: %w", i+1, err)
		}
	}

	if err := s.Call(len(args), C.LUA_MULTRET); err != nil {
		return nil, err
	}

	nresults := s.GetTop() - baseTop
	results := make([]any, nresults)
	for i := 0; i < nresults; i++ {
		if val, err := s.ToValue(baseTop + i + 1); err == nil {
			results[i] = val
		}
	}

	s.SetTop(baseTop)
	return results, nil
}