package runner import ( "context" "errors" "fmt" "path/filepath" "strings" "sync" sqlite "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" "Moonshark/utils/logger" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) // SQLiteConnection tracks an active connection type SQLiteConnection struct { Conn *sqlite.Conn Pool *sqlitex.Pool } // SQLiteManager handles database connections type SQLiteManager struct { mu sync.RWMutex pools map[string]*sqlitex.Pool activeConns map[string]*SQLiteConnection dataDir string } var sqliteManager *SQLiteManager // InitSQLite initializes the SQLite manager func InitSQLite(dataDir string) { sqliteManager = &SQLiteManager{ pools: make(map[string]*sqlitex.Pool), activeConns: make(map[string]*SQLiteConnection), dataDir: dataDir, } logger.Server("SQLite initialized with data directory: %s", dataDir) } // CleanupSQLite closes all database connections func CleanupSQLite() { if sqliteManager == nil { return } sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() for id, conn := range sqliteManager.activeConns { if conn.Pool != nil { conn.Pool.Put(conn.Conn) } delete(sqliteManager.activeConns, id) } for name, pool := range sqliteManager.pools { if err := pool.Close(); err != nil { logger.Error("Failed to close database %s: %v", name, err) } } sqliteManager.pools = nil sqliteManager.activeConns = nil logger.Debug("SQLite connections closed") } // ReleaseActiveConnections returns all active connections to their pools func ReleaseActiveConnections(state *luajit.State) { if sqliteManager == nil { return } // Get active connections table from Lua state.GetGlobal("__active_sqlite_connections") if !state.IsTable(-1) { state.Pop(1) return } sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() // Iterate through active connections state.PushNil() // Start iteration for state.Next(-2) { if state.IsTable(-1) { state.GetField(-1, "id") if state.IsString(-1) { connID := state.ToString(-1) if conn, exists := sqliteManager.activeConns[connID]; exists { if conn.Pool != nil { conn.Pool.Put(conn.Conn) } delete(sqliteManager.activeConns, connID) } } state.Pop(1) // Pop connection id } state.Pop(1) // Pop value, leave key for next iteration } // Clear the active connections table state.PushNil() state.SetGlobal("__active_sqlite_connections") } // getConnection returns a connection for the database func getConnection(dbName, connID string) (*sqlite.Conn, error) { if sqliteManager == nil { return nil, errors.New("SQLite not initialized") } // Validate database name dbName = filepath.Base(dbName) if dbName == "" || dbName[0] == '.' { return nil, errors.New("invalid database name") } // Check for existing connection sqliteManager.mu.RLock() conn, exists := sqliteManager.activeConns[connID] if exists { sqliteManager.mu.RUnlock() return conn.Conn, nil } sqliteManager.mu.RUnlock() // Get or create pool under write lock sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() // Double-check if a connection was created while waiting for lock if conn, exists = sqliteManager.activeConns[connID]; exists { return conn.Conn, nil } // Get or create pool pool, exists := sqliteManager.pools[dbName] if !exists { dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db") var err error pool, err = sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } sqliteManager.pools[dbName] = pool } // Get a connection dbConn, err := pool.Take(context.Background()) if err != nil { return nil, fmt.Errorf("failed to get connection from pool: %w", err) } // Store connection sqliteManager.activeConns[connID] = &SQLiteConnection{ Conn: dbConn, Pool: pool, } return dbConn, nil } // releaseConnection returns a connection to its pool func releaseConnection(connID string) { if sqliteManager == nil { return } sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() conn, exists := sqliteManager.activeConns[connID] if !exists { return } if conn.Pool != nil { conn.Pool.Put(conn.Conn) } delete(sqliteManager.activeConns, connID) } // sqlQuery executes a SQL query and returns results func sqlQuery(state *luajit.State) int { // Get required parameters if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { state.PushString("sqlite.query: requires database name and query") return -1 } dbName := state.ToString(1) query := state.ToString(2) connID := fmt.Sprintf("temp_%p", &query) // Get connection conn, err := getConnection(dbName, connID) if err != nil { state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } defer releaseConnection(connID) // Create execution options var execOpts sqlitex.ExecOptions rows := make([]map[string]any, 0, 16) // Set up parameters if provided if state.GetTop() >= 3 && !state.IsNil(3) { if state.IsTable(3) { params, err := state.ToTable(3) if err != nil { state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error())) return -1 } // Check for array-style params if arr, ok := params[""]; ok { if arrParams, ok := arr.([]any); ok { execOpts.Args = arrParams } else if floatArr, ok := arr.([]float64); ok { args := make([]any, len(floatArr)) for i, v := range floatArr { args[i] = v } execOpts.Args = args } } else { // Named parameters named := make(map[string]any, len(params)) for k, v := range params { if len(k) > 0 && k[0] != ':' { named[":"+k] = v } else { named[k] = v } } execOpts.Named = named } } else { // Positional parameters count := state.GetTop() - 2 args := make([]any, count) for i := range count { idx := i + 3 switch state.GetType(idx) { case luajit.TypeNumber: args[i] = state.ToNumber(idx) case luajit.TypeString: args[i] = state.ToString(idx) case luajit.TypeBoolean: args[i] = state.ToBoolean(idx) case luajit.TypeNil: args[i] = nil default: val, err := state.ToValue(idx) if err != nil { state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error())) return -1 } args[i] = val } } execOpts.Args = args } } // Set up result function execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { row := make(map[string]any) colCount := stmt.ColumnCount() for i := range colCount { colName := stmt.ColumnName(i) switch stmt.ColumnType(i) { case sqlite.TypeInteger: row[colName] = stmt.ColumnInt64(i) case sqlite.TypeFloat: row[colName] = stmt.ColumnFloat(i) case sqlite.TypeText: row[colName] = stmt.ColumnText(i) case sqlite.TypeBlob: blobSize := stmt.ColumnLen(i) buf := make([]byte, blobSize) row[colName] = stmt.ColumnBytes(i, buf) case sqlite.TypeNull: row[colName] = nil } } rows = append(rows, row) return nil } // Execute query if err := sqlitex.Execute(conn, query, &execOpts); err != nil { state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } // Create result table state.NewTable() for i, row := range rows { state.PushNumber(float64(i + 1)) if err := state.PushTable(row); err != nil { state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } state.SetTable(-3) } return 1 } // sqlExec executes a SQL statement without returning results func sqlExec(state *luajit.State) int { // Get required parameters if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { state.PushString("sqlite.exec: requires database name and query") return -1 } dbName := state.ToString(1) query := state.ToString(2) connID := fmt.Sprintf("temp_%p", &query) // Get connection conn, err := getConnection(dbName, connID) if err != nil { state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } defer releaseConnection(connID) // Check if parameters are provided hasParams := state.GetTop() >= 3 && !state.IsNil(3) // Fast path for multi-statement scripts - use ExecScript if strings.Contains(query, ";") && !hasParams { if err := sqlitex.ExecScript(conn, query); err != nil { state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } state.PushNumber(float64(conn.Changes())) return 1 } // Fast path for simple queries with no parameters if !hasParams { // Use Execute for simple statements without parameters if err := sqlitex.Execute(conn, query, nil); err != nil { state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } state.PushNumber(float64(conn.Changes())) return 1 } // Create execution options for parameterized query var execOpts sqlitex.ExecOptions // Set up parameters if state.IsTable(3) { params, err := state.ToTable(3) if err != nil { state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error())) return -1 } // Check for array-style params if arr, ok := params[""]; ok { if arrParams, ok := arr.([]any); ok { execOpts.Args = arrParams } else if floatArr, ok := arr.([]float64); ok { args := make([]any, len(floatArr)) for i, v := range floatArr { args[i] = v } execOpts.Args = args } } else { // Named parameters named := make(map[string]any, len(params)) for k, v := range params { if len(k) > 0 && k[0] != ':' { named[":"+k] = v } else { named[k] = v } } execOpts.Named = named } } else { // Positional parameters count := state.GetTop() - 2 args := make([]any, count) for i := range count { idx := i + 3 switch state.GetType(idx) { case luajit.TypeNumber: args[i] = state.ToNumber(idx) case luajit.TypeString: args[i] = state.ToString(idx) case luajit.TypeBoolean: args[i] = state.ToBoolean(idx) case luajit.TypeNil: args[i] = nil default: val, err := state.ToValue(idx) if err != nil { state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error())) return -1 } args[i] = val } } execOpts.Args = args } // Execute with parameters if err := sqlitex.Execute(conn, query, &execOpts); err != nil { state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } // Return affected rows state.PushNumber(float64(conn.Changes())) return 1 } // RegisterSQLiteFunctions registers SQLite functions with the Lua state func RegisterSQLiteFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { return err } return state.RegisterGoFunction("__sqlite_exec", sqlExec) }