package runner

import (
	"context"
	"errors"
	"fmt"
	"path/filepath"
	"strings"
	"sync"
	"time"

	sqlite "zombiezen.com/go/sqlite"
	"zombiezen.com/go/sqlite/sqlitex"

	"Moonshark/utils/color"
	"Moonshark/utils/logger"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)

// paramsIndex is the Lua stack index of the first (optional) query parameter.
// Stack slots 1 and 2 hold the database name and the SQL text.
const paramsIndex = 3

var (
	dbPools     = make(map[string]*sqlitex.Pool)
	poolsMu     sync.RWMutex
	dataDir     string
	poolSize    = 8 // Default, will be set to match runner pool size
	connTimeout = 5 * time.Second
)

// connPragmas are applied, one statement at a time, to every connection a
// pool opens.
var connPragmas = []string{
	"PRAGMA journal_mode = WAL",
	"PRAGMA synchronous = NORMAL",
	"PRAGMA cache_size = 1000",
	"PRAGMA foreign_keys = ON",
	"PRAGMA temp_store = MEMORY",
}

// InitSQLite initializes the SQLite subsystem, storing database files under dir.
func InitSQLite(dir string) {
	dataDir = dir
	logger.Infof("SQLite is g2g! %s", color.Yellow(dir))
}

// SetSQLitePoolSize sets the per-database pool size (intended to match the
// runner pool size). Non-positive sizes are ignored. It only affects pools
// created after the call.
func SetSQLitePoolSize(size int) {
	if size > 0 {
		poolSize = size
	}
}

// CleanupSQLite closes every open pool and forgets them. Close failures are
// logged, not returned, so one bad database does not stop the others closing.
func CleanupSQLite() {
	poolsMu.Lock()
	defer poolsMu.Unlock()

	for name, pool := range dbPools {
		if err := pool.Close(); err != nil {
			logger.Errorf("Failed to close database %s: %v", name, err)
		}
	}

	dbPools = make(map[string]*sqlitex.Pool)
	logger.Debugf("SQLite connections closed")
}

// getPool returns the connection pool for dbName, creating it on first use.
//
// Only the final path element of dbName is used, so a caller cannot reach
// outside dataDir. Names that are empty, start with '.', or still contain a
// path separator or NUL byte are rejected. The database file is
// <dataDir>/<name>.db.
func getPool(dbName string) (*sqlitex.Pool, error) {
	name := filepath.Base(dbName)
	// filepath.Base("/") == "/", and a NUL would truncate the C path string.
	if name == "" || name[0] == '.' || strings.ContainsAny(name, "/\\\x00") {
		return nil, fmt.Errorf("invalid database name %q", dbName)
	}

	// Fast path: the pool already exists.
	poolsMu.RLock()
	pool, ok := dbPools[name]
	poolsMu.RUnlock()
	if ok {
		return pool, nil
	}

	poolsMu.Lock()
	defer poolsMu.Unlock()

	// Another goroutine may have created the pool while we waited for the lock.
	if existing, ok := dbPools[name]; ok {
		return existing, nil
	}

	dbPath := filepath.Join(dataDir, name+".db")
	pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
		PoolSize:    poolSize,
		PrepareConn: prepareConn,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	dbPools[name] = pool
	logger.Debugf("Created SQLite pool for %s (size: %d)", name, poolSize)
	return pool, nil
}

// prepareConn applies connPragmas to a newly opened pooled connection.
func prepareConn(conn *sqlite.Conn) error {
	for _, pragma := range connPragmas {
		if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
			return err
		}
	}
	return nil
}

// acquireConn validates the (database name, SQL text) arguments shared by
// every sqlite.* Lua function and checks a connection out of the matching pool.
//
// On success the caller must call release exactly once (typically via defer).
// release returns the connection to the pool and then cancels the checkout
// context. Returned errors are meant to be prefixed with the Lua function name
// by the caller.
//
// NOTE(review): zombiezen's Pool.Take appears to also install ctx as the
// connection's interrupt, so connTimeout likely bounds query execution as well
// as the wait for a free connection — confirm against the library docs.
func acquireConn(state *luajit.State) (*sqlite.Conn, string, func(), error) {
	if err := state.CheckMinArgs(2); err != nil {
		return nil, "", nil, err
	}

	dbName, err := state.SafeToString(1)
	if err != nil {
		return nil, "", nil, errors.New("database name must be string")
	}

	query, err := state.SafeToString(2)
	if err != nil {
		return nil, "", nil, errors.New("query must be string")
	}

	pool, err := getPool(dbName)
	if err != nil {
		return nil, "", nil, err
	}

	ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
	conn, err := pool.Take(ctx)
	if err != nil {
		cancel()
		return nil, "", nil, fmt.Errorf("connection timeout: %w", err)
	}

	release := func() {
		pool.Put(conn)
		cancel()
	}
	return conn, query, release, nil
}

// paramsGiven reports whether a non-nil value was passed at paramsIndex.
func paramsGiven(state *luajit.State) bool {
	return state.GetTop() >= paramsIndex && !state.IsNil(paramsIndex)
}

// scanRow converts the current result row into a column-name -> value map.
// INTEGER maps to int64, FLOAT to float64, TEXT to string and NULL to nil.
// BLOBs are returned as strings, since Lua strings are 8-bit clean and Lua has
// no byte-slice type.
func scanRow(stmt *sqlite.Stmt) map[string]any {
	colCount := stmt.ColumnCount()
	row := make(map[string]any, colCount)

	for i := range colCount {
		colName := stmt.ColumnName(i)
		switch stmt.ColumnType(i) {
		case sqlite.TypeInteger:
			row[colName] = stmt.ColumnInt64(i)
		case sqlite.TypeFloat:
			row[colName] = stmt.ColumnFloat(i)
		case sqlite.TypeText:
			row[colName] = stmt.ColumnText(i)
		case sqlite.TypeBlob:
			// ColumnBytes fills buf and returns the byte COUNT; store the bytes.
			buf := make([]byte, stmt.ColumnLen(i))
			stmt.ColumnBytes(i, buf)
			row[colName] = string(buf)
		case sqlite.TypeNull:
			row[colName] = nil
		}
	}
	return row
}

// sqlQuery implements sqlite.query(db, sql [, params...]). It pushes an array
// of row tables (possibly empty) on success, or an error.
func sqlQuery(state *luajit.State) int {
	conn, query, release, err := acquireConn(state)
	if err != nil {
		return state.PushError("sqlite.query: %v", err)
	}
	defer release()

	var execOpts sqlitex.ExecOptions
	if paramsGiven(state) {
		if err := setupParams(state, paramsIndex, &execOpts); err != nil {
			return state.PushError("sqlite.query: %v", err)
		}
	}

	rows := make([]map[string]any, 0, 16)
	execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
		rows = append(rows, scanRow(stmt))
		return nil
	}

	if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
		return state.PushError("sqlite.query: %v", err)
	}

	if err := state.PushValue(rows); err != nil {
		return state.PushError("sqlite.query: %v", err)
	}
	return 1
}

// sqlExec implements sqlite.exec(db, sql [, params...]). It runs a statement
// without collecting rows and pushes conn.Changes() as a number.
//
// Without parameters, SQL containing ';' is run as a script (possibly several
// statements) via ExecScript; otherwise a single statement is executed.
//
// NOTE(review): conn.Changes() reflects the most recent INSERT/UPDATE/DELETE
// on this pooled connection, so a statement that modifies nothing (DDL,
// SELECT) may report a stale count from an earlier call.
func sqlExec(state *luajit.State) int {
	conn, query, release, err := acquireConn(state)
	if err != nil {
		return state.PushError("sqlite.exec: %v", err)
	}
	defer release()

	withParams := paramsGiven(state)
	switch {
	case !withParams && strings.Contains(query, ";"):
		err = sqlitex.ExecScript(conn, query)
	case !withParams:
		err = sqlitex.Execute(conn, query, nil)
	default:
		var execOpts sqlitex.ExecOptions
		if perr := setupParams(state, paramsIndex, &execOpts); perr != nil {
			return state.PushError("sqlite.exec: %v", perr)
		}
		err = sqlitex.Execute(conn, query, &execOpts)
	}
	if err != nil {
		return state.PushError("sqlite.exec: %v", err)
	}

	state.PushNumber(float64(conn.Changes()))
	return 1
}

// setupParams fills execOpts with query parameters read from the Lua stack.
//
// If the value at paramIndex is a table, it is either array-style (bound
// positionally) or keyed (bound by name). Keys lacking a SQLite prefix
// (':', '@', '$') get ':' prepended. Otherwise every stack value from
// paramIndex to the top is bound positionally.
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
	if !state.IsTable(paramIndex) {
		count := max(state.GetTop()-paramIndex+1, 0)
		args := make([]any, count)
		for i := range count {
			val, err := state.ToValue(paramIndex + i)
			if err != nil {
				return fmt.Errorf("invalid parameter %d: %w", i+1, err)
			}
			args[i] = val
		}
		execOpts.Args = args
		return nil
	}

	paramsAny, err := state.SafeToTable(paramIndex)
	if err != nil {
		return fmt.Errorf("invalid parameters: %w", err)
	}

	params, ok := paramsAny.(map[string]any)
	if !ok {
		// The table may have been decoded directly as a slice.
		args, aerr := arrayArgs(paramsAny)
		if aerr != nil {
			return fmt.Errorf("parameters must be a table")
		}
		execOpts.Args = args
		return nil
	}

	// Array-style tables are presented under the "" key.
	if arr, ok := params[""]; ok {
		args, err := arrayArgs(arr)
		if err != nil {
			return err
		}
		execOpts.Args = args
		return nil
	}

	named := make(map[string]any, len(params))
	for k, v := range params {
		named[namedParamKey(k)] = v
	}
	execOpts.Named = named
	return nil
}

// namedParamKey returns k with a ':' prefix unless it already starts with one
// of SQLite's named-parameter prefixes (':', '@', '$').
func namedParamKey(k string) string {
	if k == "" || strings.IndexByte(":@$", k[0]) >= 0 {
		return k
	}
	return ":" + k
}

// arrayArgs converts a decoded array-style Lua table into positional SQL
// arguments, rejecting element types it does not recognise instead of
// silently binding nothing.
func arrayArgs(arr any) ([]any, error) {
	switch a := arr.(type) {
	case []any:
		return a, nil
	case []float64:
		return boxAll(a), nil
	case []string:
		return boxAll(a), nil
	case []int:
		return boxAll(a), nil
	default:
		return nil, fmt.Errorf("unsupported array parameter type %T", arr)
	}
}

// boxAll copies a typed slice into a fresh []any.
func boxAll[T any](s []T) []any {
	out := make([]any, len(s))
	for i, v := range s {
		out[i] = v
	}
	return out
}

// sqlGetOne implements sqlite.get_one(db, sql [, params...]). It pushes the
// first result row as a table, or nil when the query returns no rows. Any
// further rows are stepped over but ignored.
func sqlGetOne(state *luajit.State) int {
	conn, query, release, err := acquireConn(state)
	if err != nil {
		return state.PushError("sqlite.get_one: %v", err)
	}
	defer release()

	var execOpts sqlitex.ExecOptions
	if paramsGiven(state) {
		if err := setupParams(state, paramsIndex, &execOpts); err != nil {
			return state.PushError("sqlite.get_one: %v", err)
		}
	}

	var result map[string]any
	execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
		if result == nil {
			result = scanRow(stmt)
		}
		return nil
	}

	if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
		return state.PushError("sqlite.get_one: %v", err)
	}

	if result == nil {
		state.PushNil()
		return 1
	}
	if err := state.PushValue(result); err != nil {
		return state.PushError("sqlite.get_one: %v", err)
	}
	return 1
}

// RegisterSQLiteFunctions registers the __sqlite_query, __sqlite_exec and
// __sqlite_get_one Go functions with the Lua state, stopping at the first
// registration error.
func RegisterSQLiteFunctions(state *luajit.State) error {
	funcs := []struct {
		name string
		fn   func(*luajit.State) int
	}{
		{"__sqlite_query", sqlQuery},
		{"__sqlite_exec", sqlExec},
		{"__sqlite_get_one", sqlGetOne},
	}

	for _, f := range funcs {
		if err := state.RegisterGoFunction(f.name, f.fn); err != nil {
			return err
		}
	}
	return nil
}