simplify Config

This commit is contained in:
Sky Johnson 2025-03-29 09:14:32 -05:00
parent 6fadf26bea
commit 8a3515ed06
2 changed files with 76 additions and 253 deletions

View File

@ -22,14 +22,13 @@ type Config struct {
LibDirs []string LibDirs []string
// Performance settings // Performance settings
BufferSize int PoolSize int // Number of Lua states in the pool
PoolSize int // Number of Lua states in the pool
// Feature flags // Feature flags
HTTPLoggingEnabled bool HTTPLoggingEnabled bool
Watchers map[string]bool Watchers map[string]bool
// Raw values map for backward compatibility and custom values // Raw values map for all values including custom ones
values map[string]any values map[string]any
} }
@ -65,8 +64,8 @@ func New() *Config {
// Load loads configuration from a Lua file // Load loads configuration from a Lua file
func Load(filePath string) (*Config, error) { func Load(filePath string) (*Config, error) {
// Create a new Lua state // Create a new Lua state without standard libraries
state := luajit.New() state := luajit.New(false)
if state == nil { if state == nil {
return nil, errors.New("failed to create Lua state") return nil, errors.New("failed to create Lua state")
} }
@ -80,7 +79,7 @@ func Load(filePath string) (*Config, error) {
return nil, fmt.Errorf("failed to load config file: %w", err) return nil, fmt.Errorf("failed to load config file: %w", err)
} }
// Extract values from the Lua state // Extract all globals from the Lua state
if err := extractGlobals(state, config); err != nil { if err := extractGlobals(state, config); err != nil {
return nil, err return nil, err
} }
@ -90,85 +89,38 @@ func Load(filePath string) (*Config, error) {
// extractGlobals extracts global variables from the Lua state // extractGlobals extracts global variables from the Lua state
func extractGlobals(state *luajit.State, config *Config) error { func extractGlobals(state *luajit.State, config *Config) error {
// Get the globals table (_G) // Get the globals table
state.GetGlobal("_G") state.GetGlobal("_G")
if !state.IsTable(-1) { if !state.IsTable(-1) {
state.Pop(1) state.Pop(1)
return errors.New("failed to get globals table") return errors.New("failed to get globals table")
} }
// Pre-populate with standard globals for reference checking
stdGlobals := map[string]bool{
"_G": true, "_VERSION": true, "assert": true, "collectgarbage": true,
"coroutine": true, "debug": true, "dofile": true, "error": true,
"getmetatable": true, "io": true, "ipairs": true, "load": true,
"loadfile": true, "loadstring": true, "math": true, "next": true,
"os": true, "package": true, "pairs": true, "pcall": true,
"print": true, "rawequal": true, "rawget": true, "rawset": true,
"require": true, "select": true, "setmetatable": true, "string": true,
"table": true, "tonumber": true, "tostring": true, "type": true,
"unpack": true, "xpcall": true,
// LuaJIT specific globals
"jit": true, "bit": true, "ffi": true, "bit32": true,
}
// First, let's get the original globals to compare with user globals
originalGlobals := make(map[string]bool)
// Execute empty Lua state to get standard globals
emptyState := luajit.New()
if emptyState != nil {
defer emptyState.Close()
emptyState.GetGlobal("_G")
emptyState.PushNil() // Start iteration
for emptyState.Next(-2) {
if emptyState.IsString(-2) {
key := emptyState.ToString(-2)
originalGlobals[key] = true
}
emptyState.Pop(1) // Pop value, leave key for next iteration
}
emptyState.Pop(1) // Pop _G
}
// Iterate through the globals table // Iterate through the globals table
state.PushNil() // Start iteration state.PushNil() // Start iteration
for state.Next(-2) { for state.Next(-2) {
// Stack now has key at -2 and value at -1 // Stack now has key at -2 and value at -1
// Get key as string // Skip non-string keys
if !state.IsString(-2) { if !state.IsString(-2) {
state.Pop(1) // Pop value, leave key for next iteration state.Pop(1) // Pop value, leave key for next iteration
continue continue
} }
// Get key and value type
key := state.ToString(-2) key := state.ToString(-2)
// Skip standard Lua globals, but only if they're not overridden by user
// (standard globals will be functions or tables, user values usually aren't)
valueType := state.GetType(-1) valueType := state.GetType(-1)
// Skip functions, userdata, and threads regardless of origin // Skip functions, userdata, and threads - we don't need these
if valueType == luajit.TypeFunction || valueType == luajit.TypeUserData || valueType == luajit.TypeThread { if valueType == luajit.TypeFunction ||
valueType == luajit.TypeUserData ||
valueType == luajit.TypeThread {
state.Pop(1) state.Pop(1)
continue continue
} }
// For known Lua globals, we need to see if they're the original or user-defined // Process the value
if stdGlobals[key] { processConfigValue(state, config, key)
// For simple value types, assume user-defined
if valueType == luajit.TypeBoolean || valueType == luajit.TypeNumber || valueType == luajit.TypeString {
// These are probably user values with standard names
} else if originalGlobals[key] {
// If it's in the original globals and not a simple type, skip it
state.Pop(1)
continue
}
}
// Process based on key and type
processConfigValue(state, config, key, valueType)
} }
// Pop the globals table // Pop the globals table
@ -178,30 +130,10 @@ func extractGlobals(state *luajit.State, config *Config) error {
} }
// processConfigValue processes a specific config value from Lua // processConfigValue processes a specific config value from Lua
func processConfigValue(state *luajit.State, config *Config, key string, valueType luajit.LuaType) { func processConfigValue(state *luajit.State, config *Config, key string) {
// Store in the values map first (for backward compatibility) // Get the value as its natural type
var value any value, err := state.ToValue(-1)
if err != nil {
// Extract the value based on its type
switch valueType {
case luajit.TypeBoolean:
value = state.ToBoolean(-1)
case luajit.TypeNumber:
value = state.ToNumber(-1)
case luajit.TypeString:
value = state.ToString(-1)
case luajit.TypeTable:
// For tables, use the existing conversion logic
if table, err := state.ToTable(-1); err == nil {
value = table
// Special case for watchers table
if key == "watchers" {
processWatchersTable(config, table)
}
}
default:
// Skip unsupported types
state.Pop(1) state.Pop(1)
return return
} }
@ -209,112 +141,101 @@ func processConfigValue(state *luajit.State, config *Config, key string, valueTy
// Store in the values map // Store in the values map
config.values[key] = value config.values[key] = value
// Now set specific struct fields based on key // Process special cases and config fields
switch key { switch key {
case "log_level": case "log_level":
if strVal, ok := value.(string); ok { if strVal, ok := luajit.ConvertValue[string](value); ok {
config.LogLevel = strVal config.LogLevel = strVal
} }
case "port": case "port":
if numVal, ok := value.(float64); ok { if intVal, ok := luajit.ConvertValue[int](value); ok {
config.Port = int(numVal) config.Port = intVal
} }
case "debug": case "debug":
if boolVal, ok := value.(bool); ok { if boolVal, ok := luajit.ConvertValue[bool](value); ok {
config.Debug = boolVal config.Debug = boolVal
} }
case "routes_dir": case "routes_dir":
if strVal, ok := value.(string); ok { if strVal, ok := luajit.ConvertValue[string](value); ok {
config.RoutesDir = strVal config.RoutesDir = strVal
} }
case "static_dir": case "static_dir":
if strVal, ok := value.(string); ok { if strVal, ok := luajit.ConvertValue[string](value); ok {
config.StaticDir = strVal config.StaticDir = strVal
} }
case "override_dir": case "override_dir":
if strVal, ok := value.(string); ok { if strVal, ok := luajit.ConvertValue[string](value); ok {
config.OverrideDir = strVal config.OverrideDir = strVal
} }
case "pool_size": case "pool_size":
if numVal, ok := value.(float64); ok { if intVal, ok := luajit.ConvertValue[int](value); ok {
config.PoolSize = int(numVal) config.PoolSize = intVal
} }
case "http_logging_enabled": case "http_logging_enabled":
if boolVal, ok := value.(bool); ok { if boolVal, ok := luajit.ConvertValue[bool](value); ok {
config.HTTPLoggingEnabled = boolVal config.HTTPLoggingEnabled = boolVal
} }
case "watchers":
if table, ok := value.(map[string]any); ok {
for k, v := range table {
if boolVal, ok := luajit.ConvertValue[bool](v); ok {
config.Watchers[k] = boolVal
}
}
}
case "lib_dirs": case "lib_dirs":
// Handle lib_dirs array if dirs := extractStringArray(value); len(dirs) > 0 {
processLibDirs(config, value) config.LibDirs = dirs
}
} }
state.Pop(1) // Pop value, leave key for next iteration state.Pop(1) // Pop value, leave key for next iteration
} }
// processWatchersTable processes the watchers table configuration // extractStringArray extracts a string array from various possible formats
func processWatchersTable(config *Config, watchersTable map[string]any) { func extractStringArray(value any) []string {
for key, value := range watchersTable {
if boolVal, ok := value.(bool); ok {
config.Watchers[key] = boolVal
}
}
}
// processLibDirs processes the lib_dirs array configuration
func processLibDirs(config *Config, value any) {
// Check if it's a direct array // Check if it's a direct array
if arr, ok := value.([]any); ok { if arr, ok := value.([]any); ok {
result := make([]string, 0, len(arr)) result := make([]string, 0, len(arr))
for _, v := range arr { for _, v := range arr {
if str, ok := v.(string); ok { if str, ok := luajit.ConvertValue[string](v); ok {
result = append(result, str) result = append(result, str)
} }
} }
if len(result) > 0 { return result
config.LibDirs = result
}
return
} }
// Check if it's in our special array format (map with empty key) // Check if it's in special array format (map with empty key)
valueMap, ok := value.(map[string]any) valueMap, ok := value.(map[string]any)
if !ok { if !ok {
return return nil
} }
arr, ok := valueMap[""] arr, ok := valueMap[""]
if !ok { if !ok {
return return nil
} }
// Handle array format // Handle different array types
if strArray := extractStringArray(arr); len(strArray) > 0 {
config.LibDirs = strArray
}
}
// extractStringArray extracts a string array from various possible formats
func extractStringArray(arr any) []string {
// Check different possible array formats
switch arr := arr.(type) { switch arr := arr.(type) {
case []string: case []string:
return arr return arr
case []any: case []any:
result := make([]string, 0, len(arr)) result := make([]string, 0, len(arr))
for _, v := range arr { for _, v := range arr {
if str, ok := v.(string); ok { if str, ok := luajit.ConvertValue[string](v); ok {
result = append(result, str) result = append(result, str)
} }
} }
return result return result
case []float64: case []float64:
// Unlikely but handle numeric arrays too
result := make([]string, 0, len(arr)) result := make([]string, 0, len(arr))
for _, v := range arr { for _, v := range arr {
result = append(result, fmt.Sprintf("%g", v)) result = append(result, fmt.Sprintf("%g", v))
} }
return result return result
} }
return nil return nil
} }
@ -325,7 +246,7 @@ func (c *Config) Get(key string) any {
// GetString returns a string configuration value // GetString returns a string configuration value
func (c *Config) GetString(key string, defaultValue string) string { func (c *Config) GetString(key string, defaultValue string) string {
// Check for specific struct fields first // Check for specific struct fields first for better performance
switch key { switch key {
case "log_level": case "log_level":
return c.LogLevel return c.LogLevel
@ -343,22 +264,20 @@ func (c *Config) GetString(key string, defaultValue string) string {
return defaultValue return defaultValue
} }
str, ok := value.(string) result, ok := luajit.ConvertValue[string](value)
if !ok { if !ok {
return defaultValue return defaultValue
} }
return str return result
} }
// GetInt returns an integer configuration value // GetInt returns an integer configuration value
func (c *Config) GetInt(key string, defaultValue int) int { func (c *Config) GetInt(key string, defaultValue int) int {
// Check for specific struct fields first // Check for specific struct fields first for better performance
switch key { switch key {
case "port": case "port":
return c.Port return c.Port
case "buffer_size":
return c.BufferSize
case "pool_size": case "pool_size":
return c.PoolSize return c.PoolSize
} }
@ -369,15 +288,12 @@ func (c *Config) GetInt(key string, defaultValue int) int {
return defaultValue return defaultValue
} }
// Handle both int and float64 (which is what Lua numbers become in Go) result, ok := luajit.ConvertValue[int](value)
switch v := value.(type) { if !ok {
case int:
return v
case float64:
return int(v)
default:
return defaultValue return defaultValue
} }
return result
} }
// GetFloat returns a float configuration value // GetFloat returns a float configuration value
@ -387,20 +303,17 @@ func (c *Config) GetFloat(key string, defaultValue float64) float64 {
return defaultValue return defaultValue
} }
// Handle both float64 and int result, ok := luajit.ConvertValue[float64](value)
switch v := value.(type) { if !ok {
case float64:
return v
case int:
return float64(v)
default:
return defaultValue return defaultValue
} }
return result
} }
// GetBool returns a boolean configuration value // GetBool returns a boolean configuration value
func (c *Config) GetBool(key string, defaultValue bool) bool { func (c *Config) GetBool(key string, defaultValue bool) bool {
// Check for specific struct fields first // Check for specific struct fields first for better performance
switch key { switch key {
case "debug": case "debug":
return c.Debug return c.Debug
@ -409,12 +322,13 @@ func (c *Config) GetBool(key string, defaultValue bool) bool {
} }
// Special case for watcher settings // Special case for watcher settings
if key == "watchers.routes" { if len(key) > 9 && key[:9] == "watchers." {
return c.Watchers["routes"] watcherKey := key[9:]
} else if key == "watchers.static" { val, ok := c.Watchers[watcherKey]
return c.Watchers["static"] if ok {
} else if key == "watchers.modules" { return val
return c.Watchers["modules"] }
return defaultValue
} }
// Fall back to values map for other keys // Fall back to values map for other keys
@ -423,19 +337,19 @@ func (c *Config) GetBool(key string, defaultValue bool) bool {
return defaultValue return defaultValue
} }
boolValue, ok := value.(bool) result, ok := luajit.ConvertValue[bool](value)
if !ok { if !ok {
return defaultValue return defaultValue
} }
return boolValue return result
} }
// GetMap returns a map configuration value // GetMap returns a map configuration value
func (c *Config) GetMap(key string) map[string]any { func (c *Config) GetMap(key string) map[string]any {
// Special case for watchers // Special case for watchers
if key == "watchers" { if key == "watchers" {
result := make(map[string]any) result := make(map[string]any, len(c.Watchers))
for k, v := range c.Watchers { for k, v := range c.Watchers {
result[k] = v result[k] = v
} }
@ -476,7 +390,7 @@ func (c *Config) GetArray(key string) []any {
return arr return arr
} }
// Arrays in Lua might also be represented as maps with an empty string key // Arrays in Lua might be represented as maps with an empty string key
valueMap, ok := value.(map[string]any) valueMap, ok := value.(map[string]any)
if !ok { if !ok {
return nil return nil
@ -507,51 +421,6 @@ func (c *Config) GetArray(key string) []any {
return anyArr return anyArr
} }
// GetIntArray returns an array of integers from a Lua array
func (c *Config) GetIntArray(key string) []int {
value := c.Get(key)
if value == nil {
return nil
}
// Direct array case
if arr, ok := value.([]any); ok {
result := make([]int, 0, len(arr))
for _, v := range arr {
if num, ok := v.(float64); ok {
result = append(result, int(num))
}
}
return result
}
// Arrays in Lua might also be represented as maps with an empty string key
valueMap, ok := value.(map[string]any)
if !ok {
return nil
}
// The array data is stored with an empty key
arr, ok := valueMap[""]
if !ok {
return nil
}
// For numeric arrays, LuaJIT returns []float64
floatArr, ok := arr.([]float64)
if !ok {
return nil
}
// Convert to int slice
result := make([]int, len(floatArr))
for i, v := range floatArr {
result[i] = int(v)
}
return result
}
// GetStringArray returns an array of strings from a Lua array // GetStringArray returns an array of strings from a Lua array
func (c *Config) GetStringArray(key string) []string { func (c *Config) GetStringArray(key string) []string {
// Special case for lib_dirs // Special case for lib_dirs
@ -566,7 +435,7 @@ func (c *Config) GetStringArray(key string) []string {
result := make([]string, 0, len(arr)) result := make([]string, 0, len(arr))
for _, v := range arr { for _, v := range arr {
if str, ok := v.(string); ok { if str, ok := luajit.ConvertValue[string](v); ok {
result = append(result, str) result = append(result, str)
} }
} }
@ -575,52 +444,6 @@ func (c *Config) GetStringArray(key string) []string {
} }
// Values returns all configuration values // Values returns all configuration values
// Note: The returned map should not be modified
func (c *Config) Values() map[string]any { func (c *Config) Values() map[string]any {
return c.values return c.values
} }
// Set sets a configuration value
func (c *Config) Set(key string, value any) {
c.values[key] = value
// Also update the struct field if applicable
switch key {
case "log_level":
if strVal, ok := value.(string); ok {
c.LogLevel = strVal
}
case "port":
if numVal, ok := value.(float64); ok {
c.Port = int(numVal)
} else if intVal, ok := value.(int); ok {
c.Port = intVal
}
case "debug":
if boolVal, ok := value.(bool); ok {
c.Debug = boolVal
}
case "routes_dir":
if strVal, ok := value.(string); ok {
c.RoutesDir = strVal
}
case "static_dir":
if strVal, ok := value.(string); ok {
c.StaticDir = strVal
}
case "override_dir":
if strVal, ok := value.(string); ok {
c.OverrideDir = strVal
}
case "pool_size":
if numVal, ok := value.(float64); ok {
c.PoolSize = int(numVal)
} else if intVal, ok := value.(int); ok {
c.PoolSize = intVal
}
case "http_logging_enabled":
if boolVal, ok := value.(bool); ok {
c.HTTPLoggingEnabled = boolVal
}
}
}

2
luajit

@ -1 +1 @@
Subproject commit 0756cabcaaf1e33f2b8eb535e5a24e448c2501a9 Subproject commit a2b4b1c9272f849d9c1c913366f822e0be904ba2