package runner

import (
	"errors"
	"fmt"
	"path/filepath"
	"strings"
	"sync"

	"Moonshark/core/utils/logger"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
	sqlite "zombiezen.com/go/sqlite"
	"zombiezen.com/go/sqlite/sqlitex"
)

// SQLiteConnection records a connection checked out of a pool under a
// caller-supplied ID, so later calls using the same ID reuse it until it is
// released by ReleaseActiveConnections or CleanupSQLite.
type SQLiteConnection struct {
	DbName string
	Conn   *sqlite.Conn
	Pool   *sqlitex.Pool
}

// SQLiteManager owns one connection pool per database and the named
// connections currently checked out of those pools. mu guards both maps.
type SQLiteManager struct {
	mu          sync.RWMutex
	pools       map[string]*sqlitex.Pool
	activeConns map[string]*SQLiteConnection
	dataDir     string
}

// sqliteManager is the package-wide manager set by InitSQLite. A nil value
// means SQLite has not been initialized.
var sqliteManager *SQLiteManager

// InitSQLite creates the package-wide manager. Database files are opened as
// "<dataDir>/<name>.db" on first use.
func InitSQLite(dataDir string) {
	sqliteManager = &SQLiteManager{
		pools:       make(map[string]*sqlitex.Pool),
		activeConns: make(map[string]*SQLiteConnection),
		dataDir:     dataDir,
	}
	logger.Server("SQLite initialized with data directory: %s", dataDir)
}

// CleanupSQLite returns every named connection to its pool and closes all
// pools. The bookkeeping maps are emptied rather than set to nil, so a later
// call reopens pools on demand instead of panicking on a nil-map write.
func CleanupSQLite() {
	if sqliteManager == nil {
		return
	}
	sqliteManager.mu.Lock()
	defer sqliteManager.mu.Unlock()

	// Hand checked-out connections back before closing the pools that own them.
	for id := range sqliteManager.activeConns {
		releaseSQLiteConnLocked(id)
	}

	for name, pool := range sqliteManager.pools {
		if err := pool.Close(); err != nil {
			logger.Error("Failed to close database %s: %v", name, err)
		}
	}
	clear(sqliteManager.pools)
	clear(sqliteManager.activeConns)

	logger.Debug("SQLite connections closed")
}

// ReleaseActiveConnections returns to their pools all connections listed in
// the Lua global __active_sqlite_connections, then sets that global to nil.
// Each entry is expected to be a table whose "id" field is the connection ID
// that was passed to sqlite.query / sqlite.exec; other entries are skipped.
// The Lua stack is left balanced.
func ReleaseActiveConnections(state *luajit.State) {
	if sqliteManager == nil {
		return
	}
	sqliteManager.mu.Lock()
	defer sqliteManager.mu.Unlock()

	state.GetGlobal("__active_sqlite_connections")
	if !state.IsTable(-1) {
		state.Pop(1)
		return
	}

	state.PushNil() // first key for Next
	for state.Next(-2) {
		// Stack: ..., table, key, entry
		if state.IsTable(-1) {
			state.GetField(-1, "id")
			if state.IsString(-1) {
				releaseSQLiteConnLocked(state.ToString(-1))
			}
			state.Pop(1) // id
		}
		state.Pop(1) // entry; keep key for the next Next call
	}
	// Next consumed the last key; the table itself is still on the stack.
	state.Pop(1)

	state.PushNil()
	state.SetGlobal("__active_sqlite_connections")
}

// releaseSQLiteConnLocked returns the connection registered under connID to
// its pool and forgets it. Unknown IDs are ignored. The caller must hold
// sqliteManager.mu for writing.
func releaseSQLiteConnLocked(connID string) {
	entry, ok := sqliteManager.activeConns[connID]
	if !ok {
		return
	}
	if entry.Pool != nil {
		entry.Pool.Put(entry.Conn)
	}
	delete(sqliteManager.activeConns, connID)
}

// getPool returns the pool for dbName, opening "<dataDir>/<base>.db" on first
// use. Only the final path element of dbName is used, so callers cannot reach
// outside dataDir; names starting with '.' (or a bare separator) are rejected.
func getPool(dbName string) (*sqlitex.Pool, error) {
	if sqliteManager == nil {
		return nil, errors.New("SQLite not initialized")
	}

	dbName = filepath.Base(dbName)
	if dbName == "" || dbName[0] == '.' || strings.ContainsRune(dbName, filepath.Separator) {
		return nil, errors.New("invalid database name")
	}

	sqliteManager.mu.RLock()
	pool, exists := sqliteManager.pools[dbName]
	sqliteManager.mu.RUnlock()
	if exists {
		return pool, nil
	}

	sqliteManager.mu.Lock()
	defer sqliteManager.mu.Unlock()

	// Another goroutine may have opened it between the two locks.
	if pool, exists = sqliteManager.pools[dbName]; exists {
		return pool, nil
	}

	dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
	pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}
	sqliteManager.pools[dbName] = pool
	return pool, nil
}

// getConnection returns the connection registered under connID, checking a
// new one out of dbName's pool and registering it if none exists yet. The
// connection stays checked out until ReleaseActiveConnections or
// CleanupSQLite returns it.
func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, error) {
	if sqliteManager == nil {
		return nil, nil, errors.New("SQLite not initialized")
	}

	sqliteManager.mu.RLock()
	existing, ok := sqliteManager.activeConns[connID]
	sqliteManager.mu.RUnlock()
	if ok {
		return existing.Conn, existing.Pool, nil
	}

	pool, err := getPool(dbName)
	if err != nil {
		return nil, nil, err
	}

	dbConn := pool.Get(nil)
	if dbConn == nil {
		return nil, nil, errors.New("failed to get connection from pool")
	}

	sqliteManager.mu.Lock()
	defer sqliteManager.mu.Unlock()

	// If a concurrent caller registered this ID meanwhile, keep theirs and
	// hand ours back rather than overwriting (and leaking) it.
	if existing, ok := sqliteManager.activeConns[connID]; ok {
		pool.Put(dbConn)
		return existing.Conn, existing.Pool, nil
	}
	sqliteManager.activeConns[connID] = &SQLiteConnection{
		DbName: dbName,
		Conn:   dbConn,
		Pool:   pool,
	}
	return dbConn, pool, nil
}

// acquireSQLiteConn returns a connection for a single Lua call plus a release
// function the caller must invoke when finished. With an empty connID the
// connection is borrowed from the pool for this call only and release returns
// it. With a non-empty connID the registered connection is used (see
// getConnection) and release is a no-op.
func acquireSQLiteConn(dbName, connID string) (*sqlite.Conn, func(), error) {
	if connID != "" {
		conn, _, err := getConnection(dbName, connID)
		if err != nil {
			return nil, nil, err
		}
		return conn, func() {}, nil
	}

	pool, err := getPool(dbName)
	if err != nil {
		return nil, nil, err
	}
	conn := pool.Get(nil)
	if conn == nil {
		return nil, nil, errors.New("failed to get connection from pool")
	}
	return conn, func() { pool.Put(conn) }, nil
}

// sqliteCallArgs holds the Lua arguments shared by __sqlite_query and
// __sqlite_exec: (dbName, query [, params [, connID]]).
type sqliteCallArgs struct {
	dbName string
	query  string
	connID string
	params map[string]any
}

// readSQLiteArgs reads sqliteCallArgs from the Lua stack. Returned errors are
// already prefixed with prefix (e.g. "sqlite.query").
func readSQLiteArgs(state *luajit.State, prefix string) (sqliteCallArgs, error) {
	var args sqliteCallArgs

	if !state.IsString(1) {
		return args, errors.New(prefix + ": database name must be a string")
	}
	args.dbName = state.ToString(1)

	if !state.IsString(2) {
		return args, errors.New(prefix + ": query must be a string")
	}
	args.query = state.ToString(2)

	top := state.GetTop()
	if top >= 4 && state.IsString(4) {
		args.connID = state.ToString(4)
	}
	if top >= 3 && state.IsTable(3) {
		params, err := state.ToTable(3)
		if err != nil {
			return args, errors.New(prefix + ": failed to parse parameters: " + err.Error())
		}
		args.params = params
	}
	return args, nil
}

// querySQLiteRows runs query on conn and materializes each result row as a
// column-name → value map: INTEGER → int64, FLOAT → ColumnFloat's value,
// TEXT → string, BLOB → string holding the raw bytes, NULL → nil.
// params, if non-nil, are bound as sqlitex named parameters.
func querySQLiteRows(conn *sqlite.Conn, query string, params map[string]any) ([]map[string]any, error) {
	var rows []map[string]any
	err := sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
		Named: params,
		ResultFunc: func(stmt *sqlite.Stmt) error {
			cols := stmt.ColumnCount()
			row := make(map[string]any, cols)
			for i := range cols {
				name := stmt.ColumnName(i)
				switch stmt.ColumnType(i) {
				case sqlite.TypeInteger:
					row[name] = stmt.ColumnInt64(i)
				case sqlite.TypeFloat:
					row[name] = stmt.ColumnFloat(i)
				case sqlite.TypeText:
					row[name] = stmt.ColumnText(i)
				case sqlite.TypeBlob:
					// ColumnBytes copies into buf and returns the byte count,
					// not the bytes themselves.
					buf := make([]byte, stmt.ColumnLen(i))
					n := stmt.ColumnBytes(i, buf)
					row[name] = string(buf[:n])
				case sqlite.TypeNull:
					row[name] = nil
				}
			}
			rows = append(rows, row)
			return nil
		},
	})
	if err != nil {
		return nil, err
	}
	return rows, nil
}

// luaSQLQuery implements __sqlite_query(dbName, query [, params [, connID]]).
// On success it pushes an array-like table of row tables and returns 1; on
// failure it pushes an error message and returns -1.
func luaSQLQuery(state *luajit.State) int {
	args, err := readSQLiteArgs(state, "sqlite.query")
	if err != nil {
		state.PushString(err.Error())
		return -1
	}

	conn, release, err := acquireSQLiteConn(args.dbName, args.connID)
	if err != nil {
		state.PushString("sqlite.query: " + err.Error())
		return -1
	}
	defer release()

	rows, err := querySQLiteRows(conn, args.query, args.params)
	if err != nil {
		state.PushString("sqlite.query: " + err.Error())
		return -1
	}

	state.NewTable()
	for i, row := range rows {
		state.PushNumber(float64(i + 1))
		if err := state.PushTable(row); err != nil {
			state.PushString("sqlite.query: " + err.Error())
			return -1
		}
		state.SetTable(-3)
	}
	return 1
}

// luaSQLExec implements __sqlite_exec(dbName, query [, params [, connID]]).
// With params the query is run via sqlitex.Execute using named parameters;
// without them it is run via sqlitex.ExecScript. On success it pushes
// conn.Changes() and returns 1; on failure it pushes an error message and
// returns -1.
func luaSQLExec(state *luajit.State) int {
	args, err := readSQLiteArgs(state, "sqlite.exec")
	if err != nil {
		state.PushString(err.Error())
		return -1
	}

	conn, release, err := acquireSQLiteConn(args.dbName, args.connID)
	if err != nil {
		state.PushString("sqlite.exec: " + err.Error())
		return -1
	}
	defer release()

	if args.params != nil {
		err = sqlitex.Execute(conn, args.query, &sqlitex.ExecOptions{
			Named: args.params,
		})
	} else {
		err = sqlitex.ExecScript(conn, args.query)
	}
	if err != nil {
		state.PushString("sqlite.exec: " + err.Error())
		return -1
	}

	// Read before release runs so the count belongs to this call.
	state.PushNumber(float64(conn.Changes()))
	return 1
}

// RegisterSQLiteFunctions exposes __sqlite_query and __sqlite_exec to state.
func RegisterSQLiteFunctions(state *luajit.State) error {
	if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil {
		return err
	}
	if err := state.RegisterGoFunction("__sqlite_exec", luaSQLExec); err != nil {
		return err
	}
	return nil
}