package sqlite import ( "context" "fmt" "path/filepath" "strings" "sync" "time" sqlite "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" "Moonshark/logger" "git.sharkk.net/Go/Color" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) var ( dbPools = make(map[string]*sqlitex.Pool) poolsMu sync.RWMutex dataDir string poolSize = 8 connTimeout = 5 * time.Second // Per-state connection cache stateConns = make(map[string]*stateConn) stateConnsMu sync.RWMutex ) // stateConn tracks a connection and its origin pool type stateConn struct { conn *sqlite.Conn pool *sqlitex.Pool } func InitSQLite(dir string) { dataDir = dir logger.Infof("SQLite is g2g! %s", color.Yellow(dir)) } func SetSQLitePoolSize(size int) { if size > 0 { poolSize = size } } func CleanupSQLite() { poolsMu.Lock() defer poolsMu.Unlock() // Return all cached connections to their pools stateConnsMu.Lock() for _, sc := range stateConns { if sc.pool != nil && sc.conn != nil { sc.pool.Put(sc.conn) } } stateConns = make(map[string]*stateConn) stateConnsMu.Unlock() for name, pool := range dbPools { if err := pool.Close(); err != nil { logger.Errorf("Failed to close database %s: %v", name, err) } } dbPools = make(map[string]*sqlitex.Pool) logger.Debugf("SQLite connections closed") } func getPool(dbName string) (*sqlitex.Pool, error) { dbName = filepath.Base(dbName) if dbName == "" || dbName[0] == '.' { return nil, fmt.Errorf("invalid database name") } poolsMu.RLock() pool, exists := dbPools[dbName] if exists { poolsMu.RUnlock() return pool, nil } poolsMu.RUnlock() poolsMu.Lock() defer poolsMu.Unlock() if pool, exists = dbPools[dbName]; exists { return pool, nil } dbPath := filepath.Join(dataDir, dbName+".db") pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{ PoolSize: poolSize, PrepareConn: func(conn *sqlite.Conn) error { pragmas := []string{ "PRAGMA journal_mode = WAL", "PRAGMA synchronous = NORMAL", "PRAGMA cache_size = 1000", "PRAGMA foreign_keys = ON", "PRAGMA temp_store = MEMORY", } for _, pragma := range pragmas { if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil { return err } } return nil }, }) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } dbPools[dbName] = pool logger.Debugf("Created SQLite pool for %s (size: %d)", dbName, poolSize) return pool, nil } // getStateConnection gets or creates a reusable connection for the state+db func getStateConnection(stateIndex int, dbName string) (*sqlite.Conn, error) { connKey := fmt.Sprintf("%d-%s", stateIndex, dbName) stateConnsMu.RLock() sc, exists := stateConns[connKey] stateConnsMu.RUnlock() if exists && sc.conn != nil { return sc.conn, nil } // Get new connection from pool pool, err := getPool(dbName) if err != nil { return nil, err } ctx, cancel := context.WithTimeout(context.Background(), connTimeout) defer cancel() conn, err := pool.Take(ctx) if err != nil { return nil, fmt.Errorf("connection timeout: %w", err) } // Cache it with pool reference stateConnsMu.Lock() stateConns[connKey] = &stateConn{ conn: conn, pool: pool, } stateConnsMu.Unlock() return conn, nil } func sqlQuery(state *luajit.State) int { if err := state.CheckMinArgs(3); err != nil { return state.PushError("sqlite.query: %v", err) } dbName, err := state.SafeToString(1) if err != nil { return state.PushError("sqlite.query: database name must be string") } query, err := state.SafeToString(2) if err != nil { return state.PushError("sqlite.query: query must be string") } stateIndex := int(state.ToNumber(-1)) conn, err := getStateConnection(stateIndex, dbName) if err != nil { return state.PushError("sqlite.query: %v", err) } var execOpts sqlitex.ExecOptions rows := make([]any, 0, 16) if state.GetTop() >= 4 && !state.IsNil(3) { if err := setupParams(state, 3, &execOpts); err != nil { return state.PushError("sqlite.query: %v", err) } } execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { row := make(map[string]any) colCount := stmt.ColumnCount() for i := range colCount { colName := stmt.ColumnName(i) switch stmt.ColumnType(i) { case sqlite.TypeInteger: row[colName] = stmt.ColumnInt64(i) case sqlite.TypeFloat: row[colName] = stmt.ColumnFloat(i) case sqlite.TypeText: row[colName] = stmt.ColumnText(i) case sqlite.TypeBlob: blobSize := stmt.ColumnLen(i) if blobSize > 0 { buf := make([]byte, blobSize) row[colName] = stmt.ColumnBytes(i, buf) } else { row[colName] = []byte{} } case sqlite.TypeNull: row[colName] = nil } } rows = append(rows, row) return nil } if err := sqlitex.Execute(conn, query, &execOpts); err != nil { return state.PushError("sqlite.query: %v", err) } if err := state.PushValue(rows); err != nil { return state.PushError("sqlite.query: %v", err) } return 1 } func sqlExec(state *luajit.State) int { if err := state.CheckMinArgs(3); err != nil { return state.PushError("sqlite.exec: %v", err) } dbName, err := state.SafeToString(1) if err != nil { return state.PushError("sqlite.exec: database name must be string") } query, err := state.SafeToString(2) if err != nil { return state.PushError("sqlite.exec: query must be string") } stateIndex := int(state.ToNumber(-1)) conn, err := getStateConnection(stateIndex, dbName) if err != nil { return state.PushError("sqlite.exec: %v", err) } hasParams := state.GetTop() >= 4 && !state.IsNil(3) if strings.Contains(query, ";") && !hasParams { if err := sqlitex.ExecScript(conn, query); err != nil { return state.PushError("sqlite.exec: %v", err) } state.PushNumber(float64(conn.Changes())) return 1 } if !hasParams { if err := sqlitex.Execute(conn, query, nil); err != nil { return state.PushError("sqlite.exec: %v", err) } state.PushNumber(float64(conn.Changes())) return 1 } var execOpts sqlitex.ExecOptions if err := setupParams(state, 3, &execOpts); err != nil { return state.PushError("sqlite.exec: %v", err) } if err := sqlitex.Execute(conn, query, &execOpts); err != nil { return state.PushError("sqlite.exec: %v", err) } state.PushNumber(float64(conn.Changes())) return 1 } func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error { if state.IsTable(paramIndex) { paramsAny, err := state.ToTable(paramIndex) if err != nil { return fmt.Errorf("invalid parameters: %w", err) } // Handle direct array types if arrParams, ok := paramsAny.([]any); ok { execOpts.Args = arrParams return nil } if strArr, ok := paramsAny.([]string); ok { args := make([]any, len(strArr)) for i, v := range strArr { args[i] = v } execOpts.Args = args return nil } if floatArr, ok := paramsAny.([]float64); ok { args := make([]any, len(floatArr)) for i, v := range floatArr { args[i] = v } execOpts.Args = args return nil } params, ok := paramsAny.(map[string]any) if !ok { return fmt.Errorf("unsupported parameter type: %T", paramsAny) } // Check for array-style parameters (empty string key indicates array) if arr, ok := params[""]; ok { if arrParams, ok := arr.([]any); ok { execOpts.Args = arrParams } else if floatArr, ok := arr.([]float64); ok { args := make([]any, len(floatArr)) for i, v := range floatArr { args[i] = v } execOpts.Args = args } } else { // Named parameters named := make(map[string]any, len(params)) for k, v := range params { if len(k) > 0 && k[0] != ':' { named[":"+k] = v } else { named[k] = v } } execOpts.Named = named } } else { // Multiple individual parameters count := state.GetTop() - 2 args := make([]any, count) for i := range count { idx := i + 3 val, err := state.ToValue(idx) if err != nil { return fmt.Errorf("invalid parameter %d: %w", i+1, err) } args[i] = val } execOpts.Args = args } return nil } func sqlGetOne(state *luajit.State) int { if err := state.CheckMinArgs(3); err != nil { return state.PushError("sqlite.get_one: %v", err) } dbName, err := state.SafeToString(1) if err != nil { return state.PushError("sqlite.get_one: database name must be string") } query, err := state.SafeToString(2) if err != nil { return state.PushError("sqlite.get_one: query must be string") } stateIndex := int(state.ToNumber(-1)) conn, err := getStateConnection(stateIndex, dbName) if err != nil { return state.PushError("sqlite.get_one: %v", err) } var execOpts sqlitex.ExecOptions var result map[string]any // Check if params provided (before state index) if state.GetTop() >= 4 && !state.IsNil(3) { if err := setupParams(state, 3, &execOpts); err != nil { return state.PushError("sqlite.get_one: %v", err) } } execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { if result != nil { return nil } result = make(map[string]any) colCount := stmt.ColumnCount() for i := range colCount { colName := stmt.ColumnName(i) switch stmt.ColumnType(i) { case sqlite.TypeInteger: result[colName] = stmt.ColumnInt64(i) case sqlite.TypeFloat: result[colName] = stmt.ColumnFloat(i) case sqlite.TypeText: result[colName] = stmt.ColumnText(i) case sqlite.TypeBlob: blobSize := stmt.ColumnLen(i) if blobSize > 0 { buf := make([]byte, blobSize) result[colName] = stmt.ColumnBytes(i, buf) } else { result[colName] = []byte{} } case sqlite.TypeNull: result[colName] = nil } } return nil } if err := sqlitex.Execute(conn, query, &execOpts); err != nil { return state.PushError("sqlite.get_one: %v", err) } if result == nil { state.PushNil() } else { if err := state.PushValue(result); err != nil { return state.PushError("sqlite.get_one: %v", err) } } return 1 } // CleanupStateConnection releases all connections for a specific state func CleanupStateConnection(stateIndex int) { stateConnsMu.Lock() defer stateConnsMu.Unlock() statePrefix := fmt.Sprintf("%d-", stateIndex) for key, sc := range stateConns { if strings.HasPrefix(key, statePrefix) { if sc.pool != nil && sc.conn != nil { sc.pool.Put(sc.conn) } delete(stateConns, key) } } } func RegisterSQLiteFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { return err } if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil { return err } if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil { return err } return nil }