package runner import ( "context" "errors" "fmt" "path/filepath" "strings" "sync" sqlite "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" "Moonshark/core/utils/logger" "maps" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) // SQLiteConnection tracks an active connection type SQLiteConnection struct { DbName string Conn *sqlite.Conn Pool *sqlitex.Pool } // SQLiteManager handles database connections type SQLiteManager struct { mu sync.RWMutex pools map[string]*sqlitex.Pool activeConns map[string]*SQLiteConnection dataDir string } var sqliteManager *SQLiteManager // InitSQLite initializes the SQLite manager func InitSQLite(dataDir string) { sqliteManager = &SQLiteManager{ pools: make(map[string]*sqlitex.Pool), activeConns: make(map[string]*SQLiteConnection), dataDir: dataDir, } logger.Server("SQLite initialized with data directory: %s", dataDir) } // CleanupSQLite closes all database connections func CleanupSQLite() { if sqliteManager == nil { return } sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() // Release all connections and close pools for id, conn := range sqliteManager.activeConns { if conn.Pool != nil { conn.Pool.Put(conn.Conn) } delete(sqliteManager.activeConns, id) } for name, pool := range sqliteManager.pools { if err := pool.Close(); err != nil { logger.Error("Failed to close database %s: %v", name, err) } } sqliteManager.pools = nil sqliteManager.activeConns = nil logger.Debug("SQLite connections closed") } // ReleaseActiveConnections returns all active connections to their pools func ReleaseActiveConnections(state *luajit.State) { if sqliteManager == nil { return } sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() // Get active connections table from Lua state.GetGlobal("__active_sqlite_connections") if !state.IsTable(-1) { state.Pop(1) return } // Iterate through active connections state.PushNil() // Start iteration for state.Next(-2) { if state.IsTable(-1) { state.GetField(-1, "id") if state.IsString(-1) { connID := state.ToString(-1) if conn, exists := sqliteManager.activeConns[connID]; exists { if conn.Pool != nil { conn.Pool.Put(conn.Conn) } delete(sqliteManager.activeConns, connID) } } state.Pop(1) // Pop connection id } state.Pop(1) // Pop value, leave key for next iteration } // Clear the active connections table state.PushNil() state.SetGlobal("__active_sqlite_connections") } // getPool returns a connection pool for the specified database func getPool(dbName string) (*sqlitex.Pool, error) { if sqliteManager == nil { return nil, errors.New("SQLite not initialized") } // Validate database name dbName = filepath.Base(dbName) if dbName == "" || dbName[0] == '.' { return nil, errors.New("invalid database name") } // Check for existing pool with read lock sqliteManager.mu.RLock() pool, exists := sqliteManager.pools[dbName] sqliteManager.mu.RUnlock() if exists { return pool, nil } // Create new pool with write lock sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() // Double check if another goroutine created it if pool, exists = sqliteManager.pools[dbName]; exists { return pool, nil } // Create database file path and pool dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db") pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } sqliteManager.pools[dbName] = pool return pool, nil } // getConnection returns a connection from the pool func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, error) { // Check for existing connection first sqliteManager.mu.RLock() conn, exists := sqliteManager.activeConns[connID] sqliteManager.mu.RUnlock() if exists { return conn.Conn, conn.Pool, nil } // Get the pool pool, err := getPool(dbName) if err != nil { return nil, nil, err } // Get a connection dbConn, err := pool.Take(context.Background()) if err != nil { return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err) } // Store connection sqliteManager.mu.Lock() sqliteManager.activeConns[connID] = &SQLiteConnection{ DbName: dbName, Conn: dbConn, Pool: pool, } sqliteManager.mu.Unlock() return dbConn, pool, nil } // processParams extracts parameters and connection ID from Lua state func processParams(state *luajit.State, defaultConnID string) (params any, connID string, isPositional bool, positionalParams []any, err error) { connID = defaultConnID // Check if using positional parameters if state.GetTop() >= 3 && !state.IsTable(3) { isPositional = true paramCount := state.GetTop() - 2 // Count all args after db and query // Check if last param is a connection ID lastIdx := paramCount + 2 // db(1) + query(2) + paramCount if paramCount > 0 && state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) { connID = state.ToString(lastIdx) paramCount-- // Exclude connID from param count } // Create array for positional parameters positionalParams = make([]any, paramCount) // Collect all parameters for i := 0; i < paramCount; i++ { paramIdx := i + 3 // Params start at index 3 switch state.GetType(paramIdx) { case luajit.TypeNumber: positionalParams[i] = state.ToNumber(paramIdx) case luajit.TypeString: positionalParams[i] = state.ToString(paramIdx) case luajit.TypeBoolean: positionalParams[i] = state.ToBoolean(paramIdx) case luajit.TypeNil: positionalParams[i] = nil default: val, errConv := state.ToValue(paramIdx) if errConv != nil { return nil, "", false, nil, fmt.Errorf("failed to convert parameter %d: %w", i+1, errConv) } positionalParams[i] = val } } return nil, connID, isPositional, positionalParams, nil } // Named parameter handling if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { connID = state.ToString(4) } // Get table parameters if present if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { params, err = state.ToTable(3) } return params, connID, isPositional, nil, err } // prepareExecOptions prepares SQLite execution options based on parameters func prepareExecOptions(query string, params any, isPositional bool, positionalParams []any) *sqlitex.ExecOptions { execOpts := &sqlitex.ExecOptions{} if params == nil && !isPositional { return execOpts } // Prepare parameters isArray := false var namedParams map[string]any var arrParams []any // Check for array parameters if m, ok := params.(map[string]any); ok { if arr, hasArray := m[""]; hasArray { isArray = true if slice, ok := arr.([]any); ok { arrParams = slice } else if floatSlice, ok := arr.([]float64); ok { arrParams = make([]any, len(floatSlice)) for i, v := range floatSlice { arrParams[i] = v } } } else { // Process named parameters namedParams = make(map[string]any, len(m)) for k, v := range m { if len(k) > 0 && k[0] != ':' { namedParams[":"+k] = v } else { namedParams[k] = v } } } } else if slice, ok := params.([]any); ok { isArray = true arrParams = slice } else if floatSlice, ok := params.([]float64); ok { isArray = true arrParams = make([]any, len(floatSlice)) for i, v := range floatSlice { arrParams[i] = v } } // Use positional params if explicitly provided if isPositional { arrParams = positionalParams isArray = true } // Limit positional params to actual placeholders if isArray && arrParams != nil { placeholderCount := strings.Count(query, "?") if len(arrParams) > placeholderCount { arrParams = arrParams[:placeholderCount] } execOpts.Args = arrParams } else if namedParams != nil { execOpts.Named = namedParams } return execOpts } // sqlOperation handles both query and exec operations func sqlOperation(state *luajit.State, isQuery bool) int { operation := "query" if !isQuery { operation = "exec" } // Get database name if !state.IsString(1) { state.PushString(fmt.Sprintf("sqlite.%s: database name must be a string", operation)) return -1 } dbName := state.ToString(1) // Get query if !state.IsString(2) { state.PushString(fmt.Sprintf("sqlite.%s: query must be a string", operation)) return -1 } query := state.ToString(2) // Generate a temporary connection ID if needed defaultConnID := fmt.Sprintf("temp_%p", &query) // Process parameters and get connection ID params, connID, isPositional, positionalParams, err := processParams(state, defaultConnID) if err != nil { state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) return -1 } // Get connection conn, pool, err := getConnection(dbName, connID) if err != nil { state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) return -1 } // For temporary connections, defer release if strings.HasPrefix(connID, "temp_") { defer func() { sqliteManager.mu.Lock() delete(sqliteManager.activeConns, connID) sqliteManager.mu.Unlock() pool.Put(conn) }() } // Prepare execution options execOpts := prepareExecOptions(query, params, isPositional, positionalParams) // Define rows slice outside the closure var rows []map[string]any // For queries, add result function if isQuery { execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { row := make(map[string]any) columnCount := stmt.ColumnCount() for i := range columnCount { columnName := stmt.ColumnName(i) switch stmt.ColumnType(i) { case sqlite.TypeInteger: row[columnName] = stmt.ColumnInt64(i) case sqlite.TypeFloat: row[columnName] = stmt.ColumnFloat(i) case sqlite.TypeText: row[columnName] = stmt.ColumnText(i) case sqlite.TypeBlob: blobSize := stmt.ColumnLen(i) buf := make([]byte, blobSize) row[columnName] = stmt.ColumnBytes(i, buf) case sqlite.TypeNull: row[columnName] = nil } } // Add row copy to results rowCopy := make(map[string]any, len(row)) maps.Copy(rowCopy, row) rows = append(rows, rowCopy) return nil } } // Execute query var execErr error if isQuery || execOpts.Args != nil || execOpts.Named != nil { execErr = sqlitex.Execute(conn, query, execOpts) } else { // Use ExecScript for queries without parameters execErr = sqlitex.ExecScript(conn, query) } if execErr != nil { state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, execErr.Error())) return -1 } // Return results for query, affected rows for exec if isQuery { // Create result table with rows state.NewTable() for i, row := range rows { state.PushNumber(float64(i + 1)) if err := state.PushTable(row); err != nil { state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) return -1 } state.SetTable(-3) } } else { // Return number of affected rows state.PushNumber(float64(conn.Changes())) } return 1 } // luaSQLQuery executes a SQL query and returns results to Lua func luaSQLQuery(state *luajit.State) int { return sqlOperation(state, true) } // luaSQLExec executes a SQL statement without returning results func luaSQLExec(state *luajit.State) int { return sqlOperation(state, false) } // RegisterSQLiteFunctions registers SQLite functions with the Lua state func RegisterSQLiteFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil { return err } return state.RegisterGoFunction("__sqlite_exec", luaSQLExec) }