1
0
LuaJIT-to-Go/batch.go

523 lines
11 KiB
Go

package luajit
/*
#include <lua.h>
#include <lauxlib.h>
#include <stdlib.h>
#include <string.h>

// Convert a relative stack index (negative, but not a pseudo-index such as
// LUA_REGISTRYINDEX) into an absolute one so it keeps naming the same slot
// after further values are pushed. Same rule as abs_index in Lua 5.1 lauxlib.
static int batch_abs_index(lua_State *L, int idx) {
	return (idx > 0 || idx <= LUA_REGISTRYINDEX) ? idx : lua_gettop(L) + idx + 1;
}

// Fast array creation using lua_rawseti: pushes a new table whose keys
// 1..count hold values[0..count-1]. Each element goes through lua_pushnumber,
// i.e. it is converted to lua_Number.
void batch_push_int_array_fast(lua_State *L, int *values, int count) {
	lua_createtable(L, count, 0);
	for (int i = 0; i < count; i++) {
		lua_pushnumber(L, values[i]);
		lua_rawseti(L, -2, i + 1);
	}
}

void batch_push_float_array_fast(lua_State *L, double *values, int count) {
	lua_createtable(L, count, 0);
	for (int i = 0; i < count; i++) {
		lua_pushnumber(L, values[i]);
		lua_rawseti(L, -2, i + 1);
	}
}

// values must be NUL-terminated; lua_pushstring stops at the first NUL byte.
void batch_push_string_array_fast(lua_State *L, char **values, int count) {
	lua_createtable(L, count, 0);
	for (int i = 0; i < count; i++) {
		lua_pushstring(L, values[i]);
		lua_rawseti(L, -2, i + 1);
	}
}

// Fast field setting: table[keys[i]] = values[i] via lua_setfield.
// table_idx is normalized first because each lua_push* below shifts what a
// relative (negative) index refers to.
void batch_set_string_fields_fast(lua_State *L, int table_idx,
	char **keys, char **values, int count) {
	table_idx = batch_abs_index(L, table_idx);
	for (int i = 0; i < count; i++) {
		lua_pushstring(L, values[i]);
		lua_setfield(L, table_idx, keys[i]);
	}
}

void batch_set_number_fields_fast(lua_State *L, int table_idx,
	char **keys, double *values, int count) {
	table_idx = batch_abs_index(L, table_idx);
	for (int i = 0; i < count; i++) {
		lua_pushnumber(L, values[i]);
		lua_setfield(L, table_idx, keys[i]);
	}
}
*/
import "C"
import (
"fmt"
"sync"
"unsafe"
)
// Buffer pools for reuse
// intBufferPool recycles scratch []C.int slices for the batch push helpers,
// so large conversions need not allocate a new C-typed slice on every call.
// New slices start empty with capacity 128.
var intBufferPool = sync.Pool{
	New: func() any {
		return make([]C.int, 0, 128)
	},
}

// floatBufferPool recycles scratch []C.double slices (empty, capacity 128).
var floatBufferPool = sync.Pool{
	New: func() any {
		return make([]C.double, 0, 128)
	},
}

// charPtrPool recycles scratch []*C.char slices (empty, capacity 32).
// NOTE(review): not referenced in this part of the file; confirm it is used
// elsewhere before relying on it.
var charPtrPool = sync.Pool{
	New: func() any {
		return make([]*C.char, 0, 32)
	},
}
// BatchSetStringFields stores every key/value pair of fields into the table at
// tableIndex.
//
// Maps with fewer than 8 entries are assigned one field at a time from Go.
// Larger maps are converted to C strings and assigned by one cgo call, which
// uses lua_setfield. In that path lua_pushstring stops at the first NUL byte
// of a value, and CString keys are likewise NUL-terminated.
//
// tableIndex may be absolute or relative. It is resolved to an absolute index
// up front because every assignment pushes the value before storing it, and
// that push would shift what a relative index points at.
// Assumes absIndex leaves pseudo-indices untouched; TODO confirm.
//
// Fields are assigned in map iteration order (unspecified). The returned
// error is currently always nil.
func (s *State) BatchSetStringFields(tableIndex int, fields map[string]string) error {
	count := len(fields)
	if count == 0 {
		return nil
	}
	absIdx := s.absIndex(tableIndex)
	if count < 8 {
		for k, v := range fields {
			s.PushString(v)
			s.SetField(absIdx, k)
		}
		return nil
	}
	keys := make([]*C.char, 0, count)
	values := make([]*C.char, 0, count)
	for k, v := range fields {
		keys = append(keys, C.CString(k))
		values = append(values, C.CString(v))
	}
	C.batch_set_string_fields_fast(s.L, C.int(absIdx), &keys[0], &values[0], C.int(count))
	// The C side copies the strings into Lua, so the C allocations can go now.
	for i := range keys {
		C.free(unsafe.Pointer(keys[i]))
		C.free(unsafe.Pointer(values[i]))
	}
	return nil
}
// BatchSetNumberFields stores every key/value pair of fields into the table at
// tableIndex.
//
// Maps with fewer than 8 entries are assigned one field at a time from Go.
// Larger maps pass C-string keys plus a C double array to one cgo call, which
// uses lua_setfield.
//
// tableIndex may be absolute or relative. It is resolved to an absolute index
// up front because every assignment pushes the value before storing it, and
// that push would shift what a relative index points at.
// Assumes absIndex leaves pseudo-indices untouched; TODO confirm.
//
// The returned error is currently always nil.
func (s *State) BatchSetNumberFields(tableIndex int, fields map[string]float64) error {
	count := len(fields)
	if count == 0 {
		return nil
	}
	absIdx := s.absIndex(tableIndex)
	if count < 8 {
		for k, v := range fields {
			s.PushNumber(v)
			s.SetField(absIdx, k)
		}
		return nil
	}
	keys := make([]*C.char, 0, count)
	values := make([]C.double, 0, count)
	for k, v := range fields {
		keys = append(keys, C.CString(k))
		values = append(values, C.double(v))
	}
	C.batch_set_number_fields_fast(s.L, C.int(absIdx), &keys[0], &values[0], C.int(count))
	for _, key := range keys {
		C.free(unsafe.Pointer(key))
	}
	return nil
}
// BatchSetBoolFields stores every key/value pair of fields into the table at
// tableIndex, one field at a time.
//
// tableIndex may be absolute or relative. It is resolved to an absolute index
// first because each PushBoolean shifts what a relative index points at.
// The returned error is currently always nil.
func (s *State) BatchSetBoolFields(tableIndex int, fields map[string]bool) error {
	if len(fields) == 0 {
		return nil
	}
	absIdx := s.absIndex(tableIndex)
	for k, v := range fields {
		s.PushBoolean(v)
		s.SetField(absIdx, k)
	}
	return nil
}
// BatchPushIntArray pushes a new table onto the stack whose keys
// 1..len(values) hold the elements of values in order.
//
// Every element is stored as a Lua number via lua_pushnumber, matching the
// short (<10 elements) path, which converts through float64. Longer slices
// are staged in a pooled C double buffer and written by one cgo call. They
// are deliberately not staged as C int: C int is 32 bits on common 64-bit
// platforms, and staging through it would silently truncate values outside
// the int32 range.
//
// The returned error is currently always nil.
func (s *State) BatchPushIntArray(values []int) error {
	n := len(values)
	if n < 10 {
		// Also covers n == 0, which yields an empty table.
		s.CreateTable(n, 0)
		for i, v := range values {
			s.PushNumber(float64(v))
			C.lua_rawseti(s.L, -2, C.int(i+1))
		}
		return nil
	}
	pooled := floatBufferPool.Get().([]C.double)
	defer floatBufferPool.Put(pooled[:0])
	buf := pooled
	if cap(buf) < n {
		buf = make([]C.double, n)
	}
	buf = buf[:n]
	for i, v := range values {
		buf[i] = C.double(v)
	}
	C.batch_push_float_array_fast(s.L, &buf[0], C.int(n))
	return nil
}
// BatchPushFloatArray pushes a new table onto the stack whose keys
// 1..len(values) hold the elements of values in order.
// Slices shorter than 10 elements are filled directly from Go. Longer ones
// are copied into a pooled C double buffer and filled by a single cgo call.
// The returned error is always nil.
func (s *State) BatchPushFloatArray(values []float64) error {
	n := len(values)
	if n == 0 {
		s.CreateTable(0, 0)
		return nil
	}
	if n < 10 {
		s.CreateTable(n, 0)
		for idx := 0; idx < n; idx++ {
			s.PushNumber(values[idx])
			C.lua_rawseti(s.L, -2, C.int(idx+1))
		}
		return nil
	}
	pooled := floatBufferPool.Get().([]C.double)
	// The slice obtained from the pool is returned (emptied) on exit. A larger
	// replacement allocated below is left to the garbage collector.
	defer floatBufferPool.Put(pooled[:0])
	var scratch []C.double
	if cap(pooled) >= n {
		scratch = pooled[:n]
	} else {
		scratch = make([]C.double, n)
	}
	for idx, f := range values {
		scratch[idx] = C.double(f)
	}
	C.batch_push_float_array_fast(s.L, &scratch[0], C.int(n))
	return nil
}
// BatchPushStringArray pushes a new table onto the stack whose keys
// 1..len(values) hold the strings of values in order.
// Fewer than 6 strings are pushed directly from Go. For longer slices every
// string is copied into C memory, the table is filled by one cgo call (using
// lua_pushstring, which stops at the first NUL byte), and the copies are freed.
// NOTE(review): the short path uses PushString, which may keep embedded NULs.
// The returned error is always nil.
func (s *State) BatchPushStringArray(values []string) error {
	n := len(values)
	if n == 0 {
		s.CreateTable(0, 0)
		return nil
	}
	if n >= 6 {
		cstrs := make([]*C.char, n)
		for i := range cstrs {
			cstrs[i] = C.CString(values[i])
		}
		C.batch_push_string_array_fast(s.L, &cstrs[0], C.int(n))
		for _, p := range cstrs {
			C.free(unsafe.Pointer(p))
		}
		return nil
	}
	s.CreateTable(n, 0)
	for i, str := range values {
		s.PushString(str)
		C.lua_rawseti(s.L, -2, C.int(i+1))
	}
	return nil
}
// BatchPushBoolArray pushes a new table onto the stack whose keys
// 1..len(values) hold the booleans of values in order. An empty slice yields
// an empty table. The returned error is always nil.
func (s *State) BatchPushBoolArray(values []bool) error {
	n := len(values)
	s.CreateTable(n, 0)
	for idx := range values {
		s.PushBoolean(values[idx])
		C.lua_rawseti(s.L, -2, C.int(idx+1))
	}
	return nil
}
// BatchExtractIntArray reads t[1]..t[length] from the table t at index, using
// raw access (lua_rawgeti), and returns them as Go ints.
// Each value goes through ToNumber and is then converted to int, which
// truncates any fractional part toward zero. A length of 0 or less yields an
// empty, non-nil slice. The stack height is unchanged on return, and the
// returned error is always nil.
func (s *State) BatchExtractIntArray(index, length int) ([]int, error) {
	if length > 0 {
		tbl := C.int(s.absIndex(index))
		out := make([]int, length)
		for i := range out {
			C.lua_rawgeti(s.L, tbl, C.int(i+1))
			out[i] = int(s.ToNumber(-1))
			s.Pop(1)
		}
		return out, nil
	}
	return []int{}, nil
}
// BatchExtractFloatArray reads t[1]..t[length] from the table t at index, using
// raw access (lua_rawgeti), converting each value with ToNumber.
// A length of 0 or less yields an empty, non-nil slice. The stack height is
// unchanged on return, and the returned error is always nil.
func (s *State) BatchExtractFloatArray(index, length int) ([]float64, error) {
	if length > 0 {
		tbl := C.int(s.absIndex(index))
		out := make([]float64, length)
		for i := range out {
			C.lua_rawgeti(s.L, tbl, C.int(i+1))
			out[i] = s.ToNumber(-1)
			s.Pop(1)
		}
		return out, nil
	}
	return []float64{}, nil
}
// BatchExtractStringArray reads t[1]..t[length] from the table t at index,
// using raw access (lua_rawgeti), converting each value with ToString.
// Conversion happens on a pushed copy, so the table itself is not modified.
// A length of 0 or less yields an empty, non-nil slice. The stack height is
// unchanged on return, and the returned error is always nil.
func (s *State) BatchExtractStringArray(index, length int) ([]string, error) {
	if length > 0 {
		tbl := C.int(s.absIndex(index))
		out := make([]string, length)
		for i := range out {
			C.lua_rawgeti(s.L, tbl, C.int(i+1))
			out[i] = s.ToString(-1)
			s.Pop(1)
		}
		return out, nil
	}
	return []string{}, nil
}
// BatchExtractBoolArray reads t[1]..t[length] from the table t at index, using
// raw access (lua_rawgeti), converting each value with ToBoolean.
// A length of 0 or less yields an empty, non-nil slice. The stack height is
// unchanged on return, and the returned error is always nil.
func (s *State) BatchExtractBoolArray(index, length int) ([]bool, error) {
	if length > 0 {
		tbl := C.int(s.absIndex(index))
		out := make([]bool, length)
		for i := range out {
			C.lua_rawgeti(s.L, tbl, C.int(i+1))
			out[i] = s.ToBoolean(-1)
			s.Pop(1)
		}
		return out, nil
	}
	return []bool{}, nil
}
// BatchSetGlobals assigns each value in globals to the global variable named
// by its key. Assignment order follows map iteration (unspecified). The
// returned error is always nil.
func (s *State) BatchSetGlobals(globals map[string]string) error {
	for name := range globals {
		s.PushString(globals[name])
		s.SetGlobal(name)
	}
	return nil
}
// BatchGetGlobals pushes the value of each named global onto the stack, in
// the order given, leaving len(names) new values for the caller to consume.
// The returned error is always nil.
func (s *State) BatchGetGlobals(names []string) error {
	for i := range names {
		s.GetGlobal(names[i])
	}
	return nil
}
// BatchGetTableFields pushes t[key] for each key in keys, in order, where t is
// the table at tableIndex. The index is resolved to an absolute position
// first because every push shifts relative indices. len(keys) values are left
// on the stack for the caller. The returned error is always nil.
func (s *State) BatchGetTableFields(tableIndex int, keys []string) error {
	tbl := s.absIndex(tableIndex)
	for i := range keys {
		s.GetField(tbl, keys[i])
	}
	return nil
}
// TypeCheck is one expectation for BatchCheckTypes: the value at stack
// position Index must have the Lua type ExpectedType.
type TypeCheck struct {
// Index is the stack index to inspect (passed unchanged to GetType).
Index int
// ExpectedType is the Lua type the value at Index must have.
ExpectedType LuaType
}
// BatchCheckTypes verifies each TypeCheck in order. It returns an error
// describing the first stack index whose value has a different type than
// expected, or nil if every check passes. The stack is not modified.
func (s *State) BatchCheckTypes(checks []TypeCheck) error {
	for _, c := range checks {
		got := s.GetType(c.Index)
		if got == c.ExpectedType {
			continue
		}
		return fmt.Errorf("type mismatch at index %d: expected %s, got %s",
			c.Index, c.ExpectedType, got)
	}
	return nil
}
// BatchTableBuilder collects field assignments for a Lua table created on the
// stack by NewBatchTableBuilder. Most setters queue their value and Build
// applies the queue; SetNil assigns immediately. Setters return the builder
// so calls can be chained.
//
// Each key holds at most one queued value. Setting a key again, with any
// setter including SetNil, discards whatever was queued for it earlier, so
// the last call for a key wins regardless of the value's type.
type BatchTableBuilder struct {
	state        *State
	index        int // absolute stack index of the table being built
	stringFields map[string]string
	numberFields map[string]float64
	boolFields   map[string]bool
	otherFields  map[string]any
}

// NewBatchTableBuilder pushes a new empty table onto the stack and returns a
// builder targeting it. The table remains on the stack after Build.
func (s *State) NewBatchTableBuilder() *BatchTableBuilder {
	s.NewTable()
	return &BatchTableBuilder{
		state:        s,
		index:        s.GetTop(),
		stringFields: make(map[string]string),
		numberFields: make(map[string]float64),
		boolFields:   make(map[string]bool),
		otherFields:  make(map[string]any),
	}
}

// forget drops any value queued for key, so that a subsequent setter's value
// is the only one pending for it.
func (tb *BatchTableBuilder) forget(key string) {
	delete(tb.stringFields, key)
	delete(tb.numberFields, key)
	delete(tb.boolFields, key)
	delete(tb.otherFields, key)
}

// SetString queues table[key] = value (a Lua string).
func (tb *BatchTableBuilder) SetString(key, value string) *BatchTableBuilder {
	tb.forget(key)
	tb.stringFields[key] = value
	return tb
}

// SetNumber queues table[key] = value (a Lua number).
func (tb *BatchTableBuilder) SetNumber(key string, value float64) *BatchTableBuilder {
	tb.forget(key)
	tb.numberFields[key] = value
	return tb
}

// SetBool queues table[key] = value (a Lua boolean).
func (tb *BatchTableBuilder) SetBool(key string, value bool) *BatchTableBuilder {
	tb.forget(key)
	tb.boolFields[key] = value
	return tb
}

// SetNil assigns nil to table[key] immediately rather than at Build time. It
// also discards any value queued for key earlier, so Build does not re-assign
// it.
func (tb *BatchTableBuilder) SetNil(key string) *BatchTableBuilder {
	tb.forget(key)
	tb.state.PushNil()
	tb.state.SetField(tb.index, key)
	return tb
}

// SetTable queues table[key] = value, where value is converted by PushValue
// during Build.
func (tb *BatchTableBuilder) SetTable(key string, value any) *BatchTableBuilder {
	tb.forget(key)
	tb.otherFields[key] = value
	return tb
}

// SetArray queues table[key] = values, where values is converted by PushValue
// during Build.
func (tb *BatchTableBuilder) SetArray(key string, values []any) *BatchTableBuilder {
	tb.forget(key)
	tb.otherFields[key] = values
	return tb
}

// Build applies every queued assignment to the table and leaves the table on
// the stack. If a queued value cannot be pushed, the stack is restored to its
// prior height, so a failed push never feeds SetField the wrong value. The
// error is returned wrapped with the offending key. Fields applied before the
// failure remain set.
func (tb *BatchTableBuilder) Build() error {
	if err := tb.state.BatchSetStringFields(tb.index, tb.stringFields); err != nil {
		return fmt.Errorf("setting string fields: %w", err)
	}
	if err := tb.state.BatchSetNumberFields(tb.index, tb.numberFields); err != nil {
		return fmt.Errorf("setting number fields: %w", err)
	}
	if err := tb.state.BatchSetBoolFields(tb.index, tb.boolFields); err != nil {
		return fmt.Errorf("setting bool fields: %w", err)
	}
	for key, value := range tb.otherFields {
		top := tb.state.GetTop()
		if err := tb.state.PushValue(value); err != nil {
			tb.state.SetTop(top)
			return fmt.Errorf("pushing value for field %q: %w", key, err)
		}
		tb.state.SetField(tb.index, key)
	}
	return nil
}
// BatchTableReader reads several named fields from one Lua table. The table's
// position is captured as an absolute stack index at construction, so it stays
// valid while ReadFields pushes and pops values.
type BatchTableReader struct {
	state *State
	index int
}

// NewBatchTableReader returns a reader for the table at index, which may be
// absolute or relative.
func (s *State) NewBatchTableReader(index int) *BatchTableReader {
	r := &BatchTableReader{state: s}
	r.index = s.absIndex(index)
	return r
}

// ReadFields returns table[key] for each key, converted with ToValue. Keys
// whose value ToValue fails to convert are omitted from the map. The stack
// height is unchanged on return, and the returned error is always nil.
func (btr *BatchTableReader) ReadFields(keys []string) (map[string]any, error) {
	fields := make(map[string]any, len(keys))
	for _, key := range keys {
		btr.state.GetField(btr.index, key)
		value, convErr := btr.state.ToValue(-1)
		btr.state.Pop(1)
		if convErr != nil {
			continue
		}
		fields[key] = value
	}
	return fields, nil
}
// BatchGlobalManager queues global-variable writes and reads and performs
// them together in Execute. Queues are not cleared by Execute.
type BatchGlobalManager struct {
	state       *State
	pendingSets map[string]string
	pendingGets []string
}

// NewBatchGlobalManager returns a manager with empty queues bound to s.
func (s *State) NewBatchGlobalManager() *BatchGlobalManager {
	m := &BatchGlobalManager{
		state:       s,
		pendingSets: make(map[string]string),
		pendingGets: make([]string, 0),
	}
	return m
}

// QueueSet records that global name should be set to value (a Lua string).
// A later QueueSet for the same name replaces the earlier value.
func (bgm *BatchGlobalManager) QueueSet(name, value string) *BatchGlobalManager {
	bgm.pendingSets[name] = value
	return bgm
}

// QueueGet records that global name should be read by Execute.
func (bgm *BatchGlobalManager) QueueGet(name string) *BatchGlobalManager {
	bgm.pendingGets = append(bgm.pendingGets, name)
	return bgm
}

// Execute applies all queued sets first, so queued gets observe them, and
// then reads every queued global. It returns a map from name to the value
// converted by ToValue. Names whose value fails conversion are omitted. The
// map is nil when no gets were queued. The stack height is restored before
// returning, and the returned error is always nil.
func (bgm *BatchGlobalManager) Execute() (map[string]any, error) {
	if len(bgm.pendingSets) != 0 {
		bgm.state.BatchSetGlobals(bgm.pendingSets)
	}
	if len(bgm.pendingGets) == 0 {
		return nil, nil
	}
	base := bgm.state.GetTop()
	bgm.state.BatchGetGlobals(bgm.pendingGets)
	values := make(map[string]any, len(bgm.pendingGets))
	for offset, name := range bgm.pendingGets {
		v, err := bgm.state.ToValue(base + offset + 1)
		if err != nil {
			continue
		}
		values[name] = v
	}
	bgm.state.SetTop(base)
	return values, nil
}
// BatchValuePusher collects Go values and pushes them onto the Lua stack in
// one call.
type BatchValuePusher struct {
	state  *State
	values []any
}

// NewBatchValuePusher returns an empty pusher bound to s.
func (s *State) NewBatchValuePusher() *BatchValuePusher {
	return &BatchValuePusher{
		state:  s,
		values: make([]any, 0),
	}
}

// Add queues value to be pushed by Push and returns the pusher for chaining.
func (bvp *BatchValuePusher) Add(value any) *BatchValuePusher {
	bvp.values = append(bvp.values, value)
	return bvp
}

// Push pushes every queued value onto the stack, in the order they were added.
// If any value fails to push, the stack is restored to the height it had when
// Push was called, so no partial results are left behind. The error is then
// returned wrapped with the zero-based position of the offending value.
func (bvp *BatchValuePusher) Push() error {
	top := bvp.state.GetTop()
	for i, value := range bvp.values {
		if err := bvp.state.PushValue(value); err != nil {
			bvp.state.SetTop(top)
			return fmt.Errorf("pushing batch value %d: %w", i, err)
		}
	}
	return nil
}