From 0e12e1956788e4ffe43c181e7db44abe157cda67 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Wed, 2 Jul 2025 11:12:26 -0500 Subject: [PATCH] introduce batching --- builder.go | 69 +++++++++++++++++++++++++ wrapper.go | 147 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 210 insertions(+), 6 deletions(-) diff --git a/builder.go b/builder.go index ed8fa01..c877e22 100644 --- a/builder.go +++ b/builder.go @@ -66,6 +66,75 @@ func (tb *TableBuilder) SetArray(key string, values []any) *TableBuilder { return tb } +// BatchBuild builds a table from a map using batched operations +func (tb *TableBuilder) BatchBuild(data map[string]any) *TableBuilder { + stringFields := make(map[string]string) + numberFields := make(map[string]float64) + boolFields := make(map[string]bool) + + // Separate fields by type for batching + for key, value := range data { + switch v := value.(type) { + case string: + stringFields[key] = v + case int: + numberFields[key] = float64(v) + case int64: + numberFields[key] = float64(v) + case float64: + numberFields[key] = v + case float32: + numberFields[key] = float64(v) + case bool: + boolFields[key] = v + case []int: + if len(v) > 5 { + tb.state.BatchPushIntArray(v) + } else { + tb.state.PushValue(value) + } + tb.state.SetField(tb.index, key) + case []string: + if len(v) > 3 { + tb.state.BatchPushStringArray(v) + } else { + tb.state.PushValue(value) + } + tb.state.SetField(tb.index, key) + case []float64: + if len(v) > 5 { + tb.state.BatchPushFloatArray(v) + } else { + tb.state.PushValue(value) + } + tb.state.SetField(tb.index, key) + case []bool: + if len(v) > 5 { + tb.state.BatchPushBoolArray(v) + } else { + tb.state.PushValue(value) + } + tb.state.SetField(tb.index, key) + default: + tb.state.PushValue(value) + tb.state.SetField(tb.index, key) + } + } + + // Execute batched operations + if len(stringFields) > 0 { + tb.state.BatchSetStringFields(tb.index, stringFields) + } + if len(numberFields) > 0 { + tb.state.BatchSetNumberFields(tb.index, numberFields) + } + if len(boolFields) > 0 { + tb.state.BatchSetBoolFields(tb.index, boolFields) + } + + return tb +} + // Build finalizes the table (no-op, table is already on stack) func (tb *TableBuilder) Build() { // Table is already on the stack at tb.index diff --git a/wrapper.go b/wrapper.go index 12ad81d..10b46b4 100644 --- a/wrapper.go +++ b/wrapper.go @@ -71,6 +71,38 @@ static int sample_array_type(lua_State *L, int index, int count) { if (all_bools) return 4; return 0; } + +static int sample_map_type(lua_State *L, int index, int max_samples) { + int all_numbers = 1; + int all_integers = 1; + int all_strings = 1; + int all_bools = 1; + int samples = 0; + + lua_pushnil(L); + while (lua_next(L, index) != 0 && samples < max_samples) { + if (lua_type(L, -2) == LUA_TSTRING) { + int type = lua_type(L, -1); + if (type != LUA_TNUMBER) all_numbers = all_integers = 0; + if (type != LUA_TSTRING) all_strings = 0; + if (type != LUA_TBOOLEAN) all_bools = 0; + + if (all_numbers && !is_integer(L, -1)) all_integers = 0; + + samples++; + } + lua_pop(L, 1); + + if (!all_numbers && !all_strings && !all_bools) break; + } + + if (samples == 0) return 0; + if (all_integers) return 1; + if (all_numbers) return 2; + if (all_strings) return 3; + if (all_bools) return 4; + return 0; +} */ import "C" import ( @@ -233,6 +265,9 @@ func (s *State) PushValue(v any) error { } func (s *State) pushIntSlice(arr []int) error { + if len(arr) > 5 { + return s.BatchPushIntArray(arr) + } s.CreateTable(len(arr), 0) for i, v := range arr { s.PushNumber(float64(i + 1)) @@ -243,6 +278,9 @@ func (s *State) pushIntSlice(arr []int) error { } func (s *State) pushStringSlice(arr []string) error { + if len(arr) > 3 { + return s.BatchPushStringArray(arr) + } s.CreateTable(len(arr), 0) for i, v := range arr { s.PushNumber(float64(i + 1)) @@ -253,6 +291,9 @@ func (s *State) pushStringSlice(arr []string) error { } func (s *State) pushBoolSlice(arr []bool) error { + if len(arr) > 5 { + return s.BatchPushBoolArray(arr) + } s.CreateTable(len(arr), 0) for i, v := range arr { s.PushNumber(float64(i + 1)) @@ -263,6 +304,9 @@ func (s *State) pushBoolSlice(arr []bool) error { } func (s *State) pushFloatSlice(arr []float64) error { + if len(arr) > 5 { + return s.BatchPushFloatArray(arr) + } s.CreateTable(len(arr), 0) for i, v := range arr { s.PushNumber(float64(i + 1)) @@ -286,6 +330,9 @@ func (s *State) pushAnySlice(arr []any) error { func (s *State) pushStringMap(m map[string]string) error { s.CreateTable(0, len(m)) + if len(m) > 3 { + return s.BatchSetStringFields(-1, m) + } for k, v := range m { s.PushString(k) s.PushString(v) @@ -296,6 +343,13 @@ func (s *State) pushStringMap(m map[string]string) error { func (s *State) pushIntMap(m map[string]int) error { s.CreateTable(0, len(m)) + if len(m) > 3 { + numberFields := make(map[string]float64, len(m)) + for k, v := range m { + numberFields[k] = float64(v) + } + return s.BatchSetNumberFields(-1, numberFields) + } for k, v := range m { s.PushString(k) s.PushNumber(float64(v)) @@ -350,6 +404,54 @@ func (s *State) ToValue(index int) (any, error) { } } +func (s *State) extractStringMap(index int) map[string]string { + result := make(map[string]string) + s.PushNil() + for s.Next(index) { + if s.IsString(-2) && s.IsString(-1) { + result[s.ToString(-2)] = s.ToString(-1) + } + s.Pop(1) + } + return result +} + +func (s *State) extractIntMap(index int) map[string]int { + result := make(map[string]int) + s.PushNil() + for s.Next(index) { + if s.IsString(-2) && s.IsNumber(-1) { + result[s.ToString(-2)] = int(s.ToNumber(-1)) + } + s.Pop(1) + } + return result +} + +func (s *State) extractFloatMap(index int) map[string]float64 { + result := make(map[string]float64) + s.PushNil() + for s.Next(index) { + if s.IsString(-2) && s.IsNumber(-1) { + result[s.ToString(-2)] = s.ToNumber(-1) + } + s.Pop(1) + } + return result +} + +func (s *State) extractBoolMap(index int) map[string]bool { + result := make(map[string]bool) + s.PushNil() + for s.Next(index) { + if s.IsString(-2) && s.IsBoolean(-1) { + result[s.ToString(-2)] = s.ToBoolean(-1) + } + s.Pop(1) + } + return result +} + // ToTable converts a Lua table to optimal Go type func (s *State) ToTable(index int) (any, error) { absIdx := s.absIndex(index) @@ -362,23 +464,41 @@ func (s *State) ToTable(index int) (any, error) { if length > 0 { arrayType := int(C.sample_array_type(s.L, C.int(absIdx), C.int(length))) switch arrayType { - case 1: // int array + case 1: return s.extractIntArray(absIdx, length), nil - case 2: // float array + case 2: return s.extractFloatArray(absIdx, length), nil - case 3: // string array + case 3: return s.extractStringArray(absIdx, length), nil - case 4: // bool array + case 4: return s.extractBoolArray(absIdx, length), nil - default: // mixed array + default: return s.extractAnyArray(absIdx, length), nil } } - return s.extractAnyMap(absIdx) + mapType := int(C.sample_map_type(s.L, C.int(absIdx), C.int(5))) + switch mapType { + case 1: + return s.extractIntMap(absIdx), nil + case 2: + return s.extractFloatMap(absIdx), nil + case 3: + return s.extractStringMap(absIdx), nil + case 4: + return s.extractBoolMap(absIdx), nil + default: + result, err := s.extractAnyMap(absIdx) + return result, err + } } func (s *State) extractIntArray(index, length int) []int { + if length > 10 { + if result, err := s.BatchExtractIntArray(index, length); err == nil { + return result + } + } result := make([]int, length) for i := 1; i <= length; i++ { s.PushNumber(float64(i)) @@ -390,6 +510,11 @@ func (s *State) extractIntArray(index, length int) []int { } func (s *State) extractFloatArray(index, length int) []float64 { + if length > 10 { + if result, err := s.BatchExtractFloatArray(index, length); err == nil { + return result + } + } result := make([]float64, length) for i := 1; i <= length; i++ { s.PushNumber(float64(i)) @@ -401,6 +526,11 @@ func (s *State) extractFloatArray(index, length int) []float64 { } func (s *State) extractStringArray(index, length int) []string { + if length > 5 { + if result, err := s.BatchExtractStringArray(index, length); err == nil { + return result + } + } result := make([]string, length) for i := 1; i <= length; i++ { s.PushNumber(float64(i)) @@ -412,6 +542,11 @@ func (s *State) extractStringArray(index, length int) []string { } func (s *State) extractBoolArray(index, length int) []bool { + if length > 10 { + if result, err := s.BatchExtractBoolArray(index, length); err == nil { + return result + } + } result := make([]bool, length) for i := 1; i <= length; i++ { s.PushNumber(float64(i))