package runner import ( "context" "errors" "fmt" "path/filepath" "strings" "sync" sqlite "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" "Moonshark/core/utils/logger" "maps" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) // SQLiteConnection tracks an active connection type SQLiteConnection struct { DbName string Conn *sqlite.Conn Pool *sqlitex.Pool } // SQLiteManager handles database connections type SQLiteManager struct { mu sync.RWMutex pools map[string]*sqlitex.Pool activeConns map[string]*SQLiteConnection dataDir string } // Global manager var sqliteManager *SQLiteManager // InitSQLite initializes the SQLite manager func InitSQLite(dataDir string) { sqliteManager = &SQLiteManager{ pools: make(map[string]*sqlitex.Pool), activeConns: make(map[string]*SQLiteConnection), dataDir: dataDir, } logger.Server("SQLite initialized with data directory: %s", dataDir) } // CleanupSQLite closes all database connections func CleanupSQLite() { if sqliteManager == nil { return } sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() // Release all active connections for id, conn := range sqliteManager.activeConns { if conn.Pool != nil { conn.Pool.Put(conn.Conn) } delete(sqliteManager.activeConns, id) } // Close all pools for name, pool := range sqliteManager.pools { if err := pool.Close(); err != nil { logger.Error("Failed to close database %s: %v", name, err) } } sqliteManager.pools = nil sqliteManager.activeConns = nil logger.Debug("SQLite connections closed") } // ReleaseActiveConnections returns all active connections to their pools func ReleaseActiveConnections(state *luajit.State) { if sqliteManager == nil { return } sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() // Get active connections table from Lua state.GetGlobal("__active_sqlite_connections") if !state.IsTable(-1) { state.Pop(1) return } // Iterate through active connections state.PushNil() // Start iteration for state.Next(-2) { // Stack now has key at -2 and value at -1 if state.IsTable(-1) { state.GetField(-1, "id") if state.IsString(-1) { connID := state.ToString(-1) // Release connection from Go side if conn, exists := sqliteManager.activeConns[connID]; exists { if conn.Pool != nil { conn.Pool.Put(conn.Conn) } delete(sqliteManager.activeConns, connID) } } state.Pop(1) // Pop connection id } state.Pop(1) // Pop value, leave key for next iteration } // Clear the active connections table state.PushNil() state.SetGlobal("__active_sqlite_connections") } // getPool returns a connection pool for the specified database func getPool(dbName string) (*sqlitex.Pool, error) { if sqliteManager == nil { return nil, errors.New("SQLite not initialized") } // Validate database name dbName = filepath.Base(dbName) if dbName == "" || dbName[0] == '.' { return nil, errors.New("invalid database name") } // Check for existing pool sqliteManager.mu.RLock() pool, exists := sqliteManager.pools[dbName] sqliteManager.mu.RUnlock() if exists { return pool, nil } // Create new pool sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() // Double check if another goroutine created it if pool, exists = sqliteManager.pools[dbName]; exists { return pool, nil } // Create database file path dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db") // Create the pool pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } sqliteManager.pools[dbName] = pool return pool, nil } // getConnection returns a connection from the pool func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, error) { // Check for existing connection first sqliteManager.mu.RLock() conn, exists := sqliteManager.activeConns[connID] sqliteManager.mu.RUnlock() if exists { return conn.Conn, conn.Pool, nil } // Get the pool pool, err := getPool(dbName) if err != nil { return nil, nil, err } // Get a connection using the newer Take API dbConn, err := pool.Take(context.Background()) if err != nil { return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err) } // Store connection sqliteManager.mu.Lock() sqliteManager.activeConns[connID] = &SQLiteConnection{ DbName: dbName, Conn: dbConn, Pool: pool, } sqliteManager.mu.Unlock() return dbConn, pool, nil } // detectParamType determines if parameters are positional or named func detectParamType(params any) (isArray bool) { if params == nil { return false } // Check if it's a map[string]any if paramsMap, ok := params.(map[string]any); ok { // Check for the empty string key which indicates an array if array, hasArray := paramsMap[""]; hasArray { // Verify it's actually an array if _, isSlice := array.([]any); isSlice { return true } if _, isFloatSlice := array.([]float64); isFloatSlice { return true } } return false } // If it's already a slice type if _, ok := params.([]any); ok { return true } if _, ok := params.([]float64); ok { return true } return false } // prepareParams processes parameters for SQLite queries func prepareParams(params any) (map[string]any, []any) { if params == nil { return nil, nil } // Handle positional parameters (array-like) if detectParamType(params) { var positional []any // Extract array from special map format if paramsMap, ok := params.(map[string]any); ok { if array, hasArray := paramsMap[""]; hasArray { if slice, ok := array.([]any); ok { positional = slice } else if floatSlice, ok := array.([]float64); ok { // Convert []float64 to []any positional = make([]any, len(floatSlice)) for i, v := range floatSlice { positional[i] = v } } } } else if slice, ok := params.([]any); ok { positional = slice } else if floatSlice, ok := params.([]float64); ok { // Convert []float64 to []any positional = make([]any, len(floatSlice)) for i, v := range floatSlice { positional[i] = v } } return nil, positional } // Handle named parameters (map-like) if paramsMap, ok := params.(map[string]any); ok { modified := make(map[string]any, len(paramsMap)) for key, value := range paramsMap { if len(key) > 0 && key[0] != ':' { modified[":"+key] = value } else { modified[key] = value } } return modified, nil } return nil, nil } // luaSQLQuery executes a SQL query and returns results to Lua func luaSQLQuery(state *luajit.State) int { // Get database name if !state.IsString(1) { state.PushString("sqlite.query: database name must be a string") return -1 } dbName := state.ToString(1) // Get query if !state.IsString(2) { state.PushString("sqlite.query: query must be a string") return -1 } query := state.ToString(2) // Check if using positional parameters isPositional := false var positionalParams []any // Get connection ID (optional) var connID string // Check if we have positional parameters instead of a params table if state.GetTop() >= 3 && !state.IsTable(3) { isPositional = true paramCount := state.GetTop() - 2 // Count all args after db and query // Adjust connection ID index if we have positional params if paramCount > 0 { // Last parameter might be connID if it's a string lastIdx := paramCount + 2 // db(1) + query(2) + paramCount if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) { connID = state.ToString(lastIdx) paramCount-- // Exclude connID from param count } } // Create array for positional parameters positionalParams = make([]any, paramCount) // Collect all parameters for i := 0; i < paramCount; i++ { paramIdx := i + 3 // Params start at index 3 // Convert to appropriate Go value var value any switch state.GetType(paramIdx) { case luajit.TypeNumber: value = state.ToNumber(paramIdx) case luajit.TypeString: value = state.ToString(paramIdx) case luajit.TypeBoolean: value = state.ToBoolean(paramIdx) case luajit.TypeNil: value = nil default: // Try to convert as generic value var err error value, err = state.ToValue(paramIdx) if err != nil { state.PushString(fmt.Sprintf("sqlite.query: failed to convert parameter %d: %s", i+1, err.Error())) return -1 } } positionalParams[i] = value } } else { // Original named parameter table handling if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { connID = state.ToString(4) } else { // Generate a temporary connection ID connID = fmt.Sprintf("temp_%p", &query) } } // Get parameters (optional for named parameters) var params any if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { var err error params, err = state.ToTable(3) if err != nil { state.PushString("sqlite.query: failed to parse parameters: " + err.Error()) return -1 } } // Get connection conn, pool, err := getConnection(dbName, connID) if err != nil { state.PushString("sqlite.query: " + err.Error()) return -1 } // For temporary connections, defer release if strings.HasPrefix(connID, "temp_") { defer func() { // Release the connection sqliteManager.mu.Lock() delete(sqliteManager.activeConns, connID) sqliteManager.mu.Unlock() pool.Put(conn) }() } // Execute query and collect results var rows []map[string]any // Prepare params based on type namedParams, positional := prepareParams(params) // If we have direct positional params from function args, use those if isPositional { positional = positionalParams } // Count actual placeholders in the query placeholderCount := strings.Count(query, "?") // Execute with appropriate parameter type execOpts := &sqlitex.ExecOptions{ ResultFunc: func(stmt *sqlite.Stmt) error { row := make(map[string]any) columnCount := stmt.ColumnCount() for i := range columnCount { columnName := stmt.ColumnName(i) columnType := stmt.ColumnType(i) switch columnType { case sqlite.TypeInteger: row[columnName] = stmt.ColumnInt64(i) case sqlite.TypeFloat: row[columnName] = stmt.ColumnFloat(i) case sqlite.TypeText: row[columnName] = stmt.ColumnText(i) case sqlite.TypeBlob: blobSize := stmt.ColumnLen(i) buf := make([]byte, blobSize) blob := stmt.ColumnBytes(i, buf) row[columnName] = blob case sqlite.TypeNull: row[columnName] = nil } } // Add row copy to results rowCopy := make(map[string]any, len(row)) maps.Copy(rowCopy, row) rows = append(rows, rowCopy) return nil }, } // Set appropriate parameter type if namedParams != nil { execOpts.Named = namedParams } else if positional != nil { // Make sure we're not passing more positional parameters than placeholders if len(positional) > placeholderCount { positional = positional[:placeholderCount] } execOpts.Args = positional } err = sqlitex.Execute(conn, query, execOpts) if err != nil { state.PushString("sqlite.query: " + err.Error()) return -1 } // Create result table state.NewTable() // Add results to the table for i, row := range rows { state.PushNumber(float64(i + 1)) if err := state.PushTable(row); err != nil { state.PushString("sqlite.query: " + err.Error()) return -1 } state.SetTable(-3) } return 1 } // luaSQLExec executes a SQL statement without returning results func luaSQLExec(state *luajit.State) int { // Get database name and query if !state.IsString(1) { state.PushString("sqlite.exec: database name must be a string") return -1 } dbName := state.ToString(1) if !state.IsString(2) { state.PushString("sqlite.exec: query must be a string") return -1 } query := state.ToString(2) // Check if using positional parameters isPositional := false var positionalParams []any // Get connection ID (optional) var connID string // Check if we have positional parameters instead of a params table if state.GetTop() >= 3 && !state.IsTable(3) { isPositional = true paramCount := state.GetTop() - 2 // Count all args after db and query // Adjust connection ID index if we have positional params if paramCount > 0 { // Last parameter might be connID if it's a string lastIdx := paramCount + 2 // db(1) + query(2) + paramCount if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) { connID = state.ToString(lastIdx) paramCount-- // Exclude connID from param count } } // Create array for positional parameters positionalParams = make([]any, paramCount) // Collect all parameters for i := 0; i < paramCount; i++ { paramIdx := i + 3 // Params start at index 3 // Convert to appropriate Go value var value any switch state.GetType(paramIdx) { case luajit.TypeNumber: value = state.ToNumber(paramIdx) case luajit.TypeString: value = state.ToString(paramIdx) case luajit.TypeBoolean: value = state.ToBoolean(paramIdx) case luajit.TypeNil: value = nil default: // Try to convert as generic value var err error value, err = state.ToValue(paramIdx) if err != nil { state.PushString(fmt.Sprintf("sqlite.exec: failed to convert parameter %d: %s", i+1, err.Error())) return -1 } } positionalParams[i] = value } } else { // Original named parameter table handling if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { connID = state.ToString(4) } else { // Generate a temporary connection ID connID = fmt.Sprintf("temp_%p", &query) } } // Get parameters (optional for named parameters) var params any if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { var err error params, err = state.ToTable(3) if err != nil { state.PushString("sqlite.exec: failed to parse parameters: " + err.Error()) return -1 } } // Get connection conn, pool, err := getConnection(dbName, connID) if err != nil { state.PushString("sqlite.exec: " + err.Error()) return -1 } // For temporary connections, defer release if strings.HasPrefix(connID, "temp_") { defer func() { // Release the connection sqliteManager.mu.Lock() delete(sqliteManager.activeConns, connID) sqliteManager.mu.Unlock() pool.Put(conn) }() } // Count actual placeholders in the query placeholderCount := strings.Count(query, "?") // Prepare params based on type namedParams, positional := prepareParams(params) // If we have direct positional params from function args, use those if isPositional { positional = positionalParams } // Ensure we don't pass more parameters than placeholders if positional != nil && len(positional) > placeholderCount { positional = positional[:placeholderCount] } // Execute with appropriate parameter type var execErr error if isPositional || positional != nil { // Execute with positional parameters execOpts := &sqlitex.ExecOptions{ Args: positional, } execErr = sqlitex.Execute(conn, query, execOpts) } else if namedParams != nil { // Execute with named parameters execOpts := &sqlitex.ExecOptions{ Named: namedParams, } execErr = sqlitex.Execute(conn, query, execOpts) } else { // Execute without parameters execErr = sqlitex.ExecScript(conn, query) } if execErr != nil { state.PushString("sqlite.exec: " + execErr.Error()) return -1 } // Return number of affected rows state.PushNumber(float64(conn.Changes())) return 1 } // RegisterSQLiteFunctions registers SQLite functions with the Lua state func RegisterSQLiteFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil { return err } if err := state.RegisterGoFunction("__sqlite_exec", luaSQLExec); err != nil { return err } return nil }