diff --git a/core/config/config.go b/core/config/config.go new file mode 100644 index 0000000..a2407d7 --- /dev/null +++ b/core/config/config.go @@ -0,0 +1,396 @@ +package config + +import ( + "errors" + "fmt" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// Config represents a configuration loaded from a Lua file +type Config struct { + values map[string]any +} + +// New creates a new empty configuration +func New() *Config { + return &Config{ + values: make(map[string]any), + } +} + +// Load loads configuration from a Lua file +func Load(filePath string) (*Config, error) { + // Create a new Lua state + state := luajit.New() + if state == nil { + return nil, errors.New("failed to create Lua state") + } + defer state.Close() + + // Execute the Lua file + if err := state.DoFile(filePath); err != nil { + return nil, fmt.Errorf("failed to load config file: %w", err) + } + + // Create the config instance + config := New() + + // Extract values from the Lua state + if err := extractGlobals(state, config.values); err != nil { + return nil, err + } + + return config, nil +} + +// extractGlobals extracts global variables from the Lua state +func extractGlobals(state *luajit.State, values map[string]any) error { + // Get the globals table (_G) + state.GetGlobal("_G") + if !state.IsTable(-1) { + state.Pop(1) + return errors.New("failed to get globals table") + } + + // Pre-populate with standard globals for reference checking + stdGlobals := map[string]bool{ + "_G": true, "_VERSION": true, "assert": true, "collectgarbage": true, + "coroutine": true, "debug": true, "dofile": true, "error": true, + "getmetatable": true, "io": true, "ipairs": true, "load": true, + "loadfile": true, "loadstring": true, "math": true, "next": true, + "os": true, "package": true, "pairs": true, "pcall": true, + "print": true, "rawequal": true, "rawget": true, "rawset": true, + "require": true, "select": true, "setmetatable": true, "string": true, + "table": true, "tonumber": true, "tostring": true, "type": true, + "unpack": true, "xpcall": true, + // LuaJIT specific globals + "jit": true, "bit": true, "ffi": true, "bit32": true, + } + + // First, let's get the original globals to compare with user globals + originalGlobals := make(map[string]bool) + + // Execute empty Lua state to get standard globals + emptyState := luajit.New() + if emptyState != nil { + defer emptyState.Close() + + emptyState.GetGlobal("_G") + emptyState.PushNil() // Start iteration + for emptyState.Next(-2) { + if emptyState.IsString(-2) { + key := emptyState.ToString(-2) + originalGlobals[key] = true + } + emptyState.Pop(1) // Pop value, leave key for next iteration + } + emptyState.Pop(1) // Pop _G + } + + // Iterate through the globals table + state.PushNil() // Start iteration + for state.Next(-2) { + // Stack now has key at -2 and value at -1 + + // Get key as string + if !state.IsString(-2) { + state.Pop(1) // Pop value, leave key for next iteration + continue + } + + key := state.ToString(-2) + + // Skip standard Lua globals, but only if they're not overridden by user + // (standard globals will be functions or tables, user values usually aren't) + valueType := state.GetType(-1) + + // Skip functions, userdata, and threads regardless of origin + if valueType == luajit.TypeFunction || valueType == luajit.TypeUserData || valueType == luajit.TypeThread { + state.Pop(1) + continue + } + + // For known Lua globals, we need to see if they're the original or user-defined + if stdGlobals[key] { + // For simple value types, assume user-defined + if valueType == luajit.TypeBoolean || valueType == luajit.TypeNumber || valueType == luajit.TypeString { + // These are probably user values with standard names + } else if originalGlobals[key] { + // If it's in the original globals and not a simple type, skip it + state.Pop(1) + continue + } + } + + // Handle primitive types directly + switch valueType { + case luajit.TypeBoolean: + values[key] = state.ToBoolean(-1) + state.Pop(1) + continue + case luajit.TypeNumber: + values[key] = state.ToNumber(-1) + state.Pop(1) + continue + case luajit.TypeString: + values[key] = state.ToString(-1) + state.Pop(1) + continue + case luajit.TypeTable: + // For tables, use the existing conversion logic + default: + // Skip unsupported types + state.Pop(1) + continue + } + + // Handle tables (arrays and maps) + if valueType == luajit.TypeTable { + // Check if it looks like an array first + arrLen := state.GetTableLength(-1) + if arrLen > 0 { + // Process as array + arr := make([]any, arrLen) + for i := 1; i <= arrLen; i++ { + state.PushNumber(float64(i)) + state.GetTable(-2) // Get t[i] + + switch state.GetType(-1) { + case luajit.TypeBoolean: + arr[i-1] = state.ToBoolean(-1) + case luajit.TypeNumber: + arr[i-1] = state.ToNumber(-1) + case luajit.TypeString: + arr[i-1] = state.ToString(-1) + default: + // For complex elements, try to convert + if val, err := state.ToValue(-1); err == nil { + arr[i-1] = val + } + } + + state.Pop(1) // Pop value + } + values[key] = arr + state.Pop(1) + continue + } + + // Try normal table conversion for non-array tables + if table, err := state.ToTable(-1); err == nil { + values[key] = table + } + } + + // Pop value, leave key for next iteration + state.Pop(1) + } + + // Pop the globals table + state.Pop(1) + + return nil +} + +// Get returns a configuration value by key +func (c *Config) Get(key string) any { + return c.values[key] +} + +// GetString returns a string configuration value +func (c *Config) GetString(key string, defaultValue string) string { + value, ok := c.values[key] + if !ok { + return defaultValue + } + + str, ok := value.(string) + if !ok { + return defaultValue + } + + return str +} + +// GetInt returns an integer configuration value +func (c *Config) GetInt(key string, defaultValue int) int { + value, ok := c.values[key] + if !ok { + return defaultValue + } + + // Handle both int and float64 (which is what Lua numbers become in Go) + switch v := value.(type) { + case int: + return v + case float64: + return int(v) + default: + return defaultValue + } +} + +// GetFloat returns a float configuration value +func (c *Config) GetFloat(key string, defaultValue float64) float64 { + value, ok := c.values[key] + if !ok { + return defaultValue + } + + // Handle both float64 and int + switch v := value.(type) { + case float64: + return v + case int: + return float64(v) + default: + return defaultValue + } +} + +// GetBool returns a boolean configuration value +func (c *Config) GetBool(key string, defaultValue bool) bool { + value, ok := c.values[key] + if !ok { + return defaultValue + } + + boolValue, ok := value.(bool) + if !ok { + return defaultValue + } + + return boolValue +} + +// GetMap returns a map configuration value +func (c *Config) GetMap(key string) map[string]any { + value, ok := c.values[key] + if !ok { + return nil + } + + table, ok := value.(map[string]any) + if !ok { + return nil + } + + return table +} + +// GetArray returns an array of values from a Lua array +func (c *Config) GetArray(key string) []any { + value := c.Get(key) + if value == nil { + return nil + } + + // Direct array + if arr, ok := value.([]any); ok { + return arr + } + + // Arrays in Lua might also be represented as maps with an empty string key + valueMap, ok := value.(map[string]any) + if !ok { + return nil + } + + // The array data is stored with an empty key + arr, ok := valueMap[""] + if !ok { + return nil + } + + // Check if it's a float64 array (common for Lua numeric arrays) + if floatArr, ok := arr.([]float64); ok { + // Convert to []any + result := make([]any, len(floatArr)) + for i, v := range floatArr { + result[i] = v + } + return result + } + + // Otherwise, try to return as is + anyArr, ok := arr.([]any) + if !ok { + return nil + } + + return anyArr +} + +// GetIntArray returns an array of integers from a Lua array +func (c *Config) GetIntArray(key string) []int { + value := c.Get(key) + if value == nil { + return nil + } + + // Direct array case + if arr, ok := value.([]any); ok { + result := make([]int, 0, len(arr)) + for _, v := range arr { + if num, ok := v.(float64); ok { + result = append(result, int(num)) + } + } + return result + } + + // Arrays in Lua might also be represented as maps with an empty string key + valueMap, ok := value.(map[string]any) + if !ok { + return nil + } + + // The array data is stored with an empty key + arr, ok := valueMap[""] + if !ok { + return nil + } + + // For numeric arrays, LuaJIT returns []float64 + floatArr, ok := arr.([]float64) + if !ok { + return nil + } + + // Convert to int slice + result := make([]int, len(floatArr)) + for i, v := range floatArr { + result[i] = int(v) + } + + return result +} + +// GetStringArray returns an array of strings from a Lua array +func (c *Config) GetStringArray(key string) []string { + arr := c.GetArray(key) + if arr == nil { + return nil + } + + result := make([]string, 0, len(arr)) + for _, v := range arr { + if str, ok := v.(string); ok { + result = append(result, str) + } + } + + return result +} + +// Values returns all configuration values +// Note: The returned map should not be modified +func (c *Config) Values() map[string]any { + return c.values +} + +// Set sets a configuration value +func (c *Config) Set(key string, value any) { + c.values[key] = value +} diff --git a/core/config/config_test.go b/core/config/config_test.go new file mode 100644 index 0000000..16297d7 --- /dev/null +++ b/core/config/config_test.go @@ -0,0 +1,351 @@ +package config + +import ( + "os" + "reflect" + "testing" +) + +// TestLoad verifies we can successfully load configuration values from a Lua file. +func TestLoad(t *testing.T) { + // Create a temporary config file + content := ` + -- Basic configuration values + host = "localhost" + port = 8080 + debug = true + pi = 3.14159 + ` + configFile := createTempLuaFile(t, content) + defer os.Remove(configFile) + + // Load the config + cfg, err := Load(configFile) + if err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + // Verify values were loaded correctly + if host := cfg.GetString("host", ""); host != "localhost" { + t.Errorf("Expected host to be 'localhost', got '%s'", host) + } + + if port := cfg.GetInt("port", 0); port != 8080 { + t.Errorf("Expected port to be 8080, got %d", port) + } + + if debug := cfg.GetBool("debug", false); !debug { + t.Errorf("Expected debug to be true") + } + + if pi := cfg.GetFloat("pi", 0); pi != 3.14159 { + t.Errorf("Expected pi to be 3.14159, got %f", pi) + } +} + +// TestLoadErrors ensures the package properly handles loading errors. +func TestLoadErrors(t *testing.T) { + // Test with non-existent file + _, err := Load("nonexistent.lua") + if err == nil { + t.Error("Expected error when loading non-existent file, got nil") + } + + // Test with invalid Lua + content := ` + -- This is invalid Lua + host = "localhost + port = 8080) + ` + configFile := createTempLuaFile(t, content) + defer os.Remove(configFile) + + _, err = Load(configFile) + if err == nil { + t.Error("Expected error when loading invalid Lua, got nil") + } +} + +// TestLocalVsGlobal verifies only global variables are exported, not locals. +func TestLocalVsGlobal(t *testing.T) { + // Create a temporary config file with both local and global variables + content := ` + -- Local variables should not be exported + local local_var = "hidden" + + -- Global variables should be exported + global_var = "visible" + + -- A function that uses both + function test_func() + return local_var .. " " .. global_var + end + ` + configFile := createTempLuaFile(t, content) + defer os.Remove(configFile) + + // Load the config + cfg, err := Load(configFile) + if err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + // Check that global_var exists + if globalVar := cfg.GetString("global_var", ""); globalVar != "visible" { + t.Errorf("Expected global_var to be 'visible', got '%s'", globalVar) + } + + // Check that local_var does not exist + if localVar := cfg.GetString("local_var", "default"); localVar != "default" { + t.Errorf("Expected local_var to use default, got '%s'", localVar) + } + + // Check that functions are not exported + if val := cfg.Get("test_func"); val != nil { + t.Errorf("Expected function to not be exported, got %v", val) + } +} + +// TestArrayHandling verifies correct handling of Lua arrays. +func TestArrayHandling(t *testing.T) { + // Create a temporary config file with arrays + content := ` + -- Numeric array + numbers = {10, 20, 30, 40, 50} + + -- String array + strings = {"apple", "banana", "cherry"} + + -- Mixed array + mixed = {1, "two", true, 4.5} + ` + configFile := createTempLuaFile(t, content) + defer os.Remove(configFile) + + // Load the config + cfg, err := Load(configFile) + if err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + // Test GetIntArray + intArray := cfg.GetIntArray("numbers") + expectedInts := []int{10, 20, 30, 40, 50} + if !reflect.DeepEqual(intArray, expectedInts) { + t.Errorf("Expected int array %v, got %v", expectedInts, intArray) + } + + // Test GetStringArray + strArray := cfg.GetStringArray("strings") + expectedStrs := []string{"apple", "banana", "cherry"} + if !reflect.DeepEqual(strArray, expectedStrs) { + t.Errorf("Expected string array %v, got %v", expectedStrs, strArray) + } + + // Test GetArray with mixed types + mixedArray := cfg.GetArray("mixed") + if len(mixedArray) != 4 { + t.Errorf("Expected mixed array length 4, got %d", len(mixedArray)) + // Skip further tests if array is empty to avoid panic + return + } + + // Check types - carefully to avoid panics + if len(mixedArray) > 0 { + if num, ok := mixedArray[0].(float64); !ok || num != 1 { + t.Errorf("Expected first element to be 1, got %v", mixedArray[0]) + } + } + + if len(mixedArray) > 1 { + if str, ok := mixedArray[1].(string); !ok || str != "two" { + t.Errorf("Expected second element to be 'two', got %v", mixedArray[1]) + } + } +} + +// TestComplexTable tests handling of complex nested tables. +func TestComplexTable(t *testing.T) { + // Create a temporary config file with complex tables + content := ` + -- Nested table structure + server = { + host = "localhost", + port = 8080, + settings = { + timeout = 30, + retries = 3 + } + } + + -- Table with mixed array and map elements + mixed_table = { + list = {1, 2, 3}, + mapping = { + a = "apple", + b = "banana" + } + } + ` + configFile := createTempLuaFile(t, content) + defer os.Remove(configFile) + + // Load the config + cfg, err := Load(configFile) + if err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + // Test getting nested values + serverMap := cfg.GetMap("server") + if serverMap == nil { + t.Fatal("Expected server map to exist") + } + + // Check first level values + if host, ok := serverMap["host"].(string); !ok || host != "localhost" { + t.Errorf("Expected server.host to be 'localhost', got %v", serverMap["host"]) + } + + if port, ok := serverMap["port"].(float64); !ok || port != 8080 { + t.Errorf("Expected server.port to be 8080, got %v", serverMap["port"]) + } + + // Check nested settings + settings, ok := serverMap["settings"].(map[string]any) + if !ok { + t.Fatal("Expected server.settings to be a map") + } + + if timeout, ok := settings["timeout"].(float64); !ok || timeout != 30 { + t.Errorf("Expected server.settings.timeout to be 30, got %v", settings["timeout"]) + } +} + +// TestDefaultValues verifies default values work correctly when keys don't exist. +func TestDefaultValues(t *testing.T) { + // Create a temporary config file + content := ` + -- Just one value + existing = "value" + ` + configFile := createTempLuaFile(t, content) + defer os.Remove(configFile) + + // Load the config + cfg, err := Load(configFile) + if err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + // Test defaults for non-existent keys + if val := cfg.GetString("nonexistent", "default"); val != "default" { + t.Errorf("Expected default string, got '%s'", val) + } + + if val := cfg.GetInt("nonexistent", 42); val != 42 { + t.Errorf("Expected default int 42, got %d", val) + } + + if val := cfg.GetFloat("nonexistent", 3.14); val != 3.14 { + t.Errorf("Expected default float 3.14, got %f", val) + } + + if val := cfg.GetBool("nonexistent", true); !val { + t.Errorf("Expected default bool true, got false") + } +} + +// TestModifyConfig tests the ability to modify configuration values. +func TestModifyConfig(t *testing.T) { + // Create a config manually + cfg := New() + + // Set some values + cfg.Set("host", "localhost") + cfg.Set("port", 8080) + + // Verify the values were set + if host := cfg.GetString("host", ""); host != "localhost" { + t.Errorf("Expected host to be 'localhost', got '%s'", host) + } + + if port := cfg.GetInt("port", 0); port != 8080 { + t.Errorf("Expected port to be 8080, got %d", port) + } + + // Modify a value + cfg.Set("host", "127.0.0.1") + + // Verify the change + if host := cfg.GetString("host", ""); host != "127.0.0.1" { + t.Errorf("Expected modified host to be '127.0.0.1', got '%s'", host) + } +} + +// TestTypeConversion tests the type conversion in getter methods. +func TestTypeConversion(t *testing.T) { + // Create a temporary config file with values that need conversion + content := ` + -- Numbers that can be integers + int_as_float = 42.0 + + -- Floats that should remain floats + float_val = 3.14159 + ` + configFile := createTempLuaFile(t, content) + defer os.Remove(configFile) + + // Load the config + cfg, err := Load(configFile) + if err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + // Test GetInt with a float value + if val := cfg.GetInt("int_as_float", 0); val != 42 { + t.Errorf("Expected int 42, got %d", val) + } + + // Test GetFloat with an int value + if val := cfg.GetFloat("int_as_float", 0); val != 42.0 { + t.Errorf("Expected float 42.0, got %f", val) + } + + // Test incorrect type handling + cfg.Set("string_val", "not a number") + + if val := cfg.GetInt("string_val", 99); val != 99 { + t.Errorf("Expected default int 99 for string value, got %d", val) + } + + if val := cfg.GetFloat("string_val", 99.9); val != 99.9 { + t.Errorf("Expected default float 99.9 for string value, got %f", val) + } + + if val := cfg.GetBool("float_val", false); val != false { + t.Errorf("Expected default false for non-bool value, got true") + } +} + +// Helper function to create a temporary Lua file with content +func createTempLuaFile(t *testing.T, content string) string { + t.Helper() + + tempFile, err := os.CreateTemp("", "config-test-*.lua") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + if _, err := tempFile.WriteString(content); err != nil { + os.Remove(tempFile.Name()) + t.Fatalf("Failed to write to temp file: %v", err) + } + + if err := tempFile.Close(); err != nil { + os.Remove(tempFile.Name()) + t.Fatalf("Failed to close temp file: %v", err) + } + + return tempFile.Name() +}