package sashimi import ( "context" "fmt" "reflect" "regexp" "strconv" "strings" "unicode" "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" ) // Stmt is a type alias for sqlite.Stmt to avoid zombiezen being a direct dependency type Stmt = sqlite.Stmt var placeholderRegex = regexp.MustCompile(`%[sd]`) type DB struct { pool *sqlitex.Pool } // New creates a new database wrapper instance with connection pooling func New(dbPath string) (*DB, error) { // Create connection pool with 10 connections max pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{ Flags: sqlite.OpenReadWrite | sqlite.OpenCreate, PoolSize: 10, }) if err != nil { return nil, fmt.Errorf("failed to create connection pool: %w", err) } db := &DB{pool: pool} // Configure database using one connection from pool if err := db.configure(); err != nil { pool.Close() return nil, err } return db, nil } // configure sets up database pragmas func (db *DB) configure() error { conn, err := db.pool.Take(context.Background()) if err != nil { return fmt.Errorf("failed to get connection: %w", err) } defer db.pool.Put(conn) configs := []string{ "PRAGMA journal_mode=WAL", "PRAGMA cache_size=-65536", // 64MB cache "PRAGMA foreign_keys=ON", "PRAGMA busy_timeout=5000", // 5 second timeout } for _, config := range configs { if err := sqlitex.Execute(conn, config, nil); err != nil { return fmt.Errorf("failed to configure database: %w", err) } } return nil } // Close closes the database connection pool func (db *DB) Close() error { if db.pool != nil { return db.pool.Close() } return nil } // Pool returns the underlying connection pool func (db *DB) Pool() *sqlitex.Pool { return db.pool } // Scan scans a SQLite statement result into a struct using field names func (db *DB) Scan(stmt *PooledStmt, dest any) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct { return fmt.Errorf("dest must be a pointer to struct") } elem := v.Elem() typ := elem.Type() for i := 0; i < typ.NumField(); i++ { field := typ.Field(i) columnName := toSnakeCase(field.Name) fieldValue := elem.Field(i) if !fieldValue.CanSet() { continue } // Find column index by name colIndex := -1 for j := 0; j < stmt.ColumnCount(); j++ { if stmt.ColumnName(j) == columnName { colIndex = j break } } if colIndex == -1 { continue // Column not found } switch fieldValue.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: fieldValue.SetInt(stmt.ColumnInt64(colIndex)) case reflect.String: fieldValue.SetString(stmt.ColumnText(colIndex)) case reflect.Float32, reflect.Float64: fieldValue.SetFloat(stmt.ColumnFloat(colIndex)) case reflect.Bool: fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0) } } return nil } // Query executes a query with fmt-style placeholders and automatically binds parameters func (db *DB) Query(query string, args ...any) (*PooledStmt, error) { conn, err := db.pool.Take(context.Background()) if err != nil { return nil, fmt.Errorf("failed to get connection: %w", err) } convertedQuery, paramTypes := convertPlaceholders(query) stmt, err := conn.Prepare(convertedQuery) if err != nil { db.pool.Put(conn) return nil, err } // Bind parameters with correct types for i, arg := range args { if i >= len(paramTypes) { break } switch paramTypes[i] { case "s": // string if s, ok := arg.(string); ok { stmt.BindText(i+1, s) } else { stmt.BindText(i+1, fmt.Sprintf("%v", arg)) } case "d": // integer switch v := arg.(type) { case int: stmt.BindInt64(i+1, int64(v)) case int32: stmt.BindInt64(i+1, int64(v)) case int64: stmt.BindInt64(i+1, v) case float64: stmt.BindInt64(i+1, int64(v)) default: if i64, err := strconv.ParseInt(fmt.Sprintf("%v", arg), 10, 64); err == nil { stmt.BindInt64(i+1, i64) } else { stmt.BindInt64(i+1, 0) } } } } // Create a wrapped statement that releases the connection when finalized return &PooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil } // PooledStmt wraps a statement to automatically release pool connections type PooledStmt struct { *sqlite.Stmt pool *sqlitex.Pool conn *sqlite.Conn finalized bool } func (ps *PooledStmt) Finalize() error { if !ps.finalized { err := ps.Stmt.Finalize() ps.pool.Put(ps.conn) ps.finalized = true return err } return nil } // Get executes a query and returns the first row func (db *DB) Get(dest any, query string, args ...any) error { stmt, err := db.Query(query, args...) if err != nil { return err } defer stmt.Finalize() hasRow, err := stmt.Step() if err != nil { return err } if !hasRow { return fmt.Errorf("no rows found") } return db.Scan(stmt, dest) } // Select executes a query and scans all rows into a slice func (db *DB) Select(dest any, query string, args ...any) error { destValue := reflect.ValueOf(dest) if destValue.Kind() != reflect.Ptr || destValue.Elem().Kind() != reflect.Slice { return fmt.Errorf("dest must be a pointer to slice") } sliceValue := destValue.Elem() elemType := sliceValue.Type().Elem() // Ensure element type is a pointer to struct if elemType.Kind() != reflect.Ptr || elemType.Elem().Kind() != reflect.Struct { return fmt.Errorf("slice elements must be pointers to structs") } stmt, err := db.Query(query, args...) if err != nil { return err } defer stmt.Finalize() for { hasRow, err := stmt.Step() if err != nil { return err } if !hasRow { break } // Create new instance of the element type newElem := reflect.New(elemType.Elem()) if err := db.Scan(stmt, newElem.Interface()); err != nil { return err } sliceValue.Set(reflect.Append(sliceValue, newElem)) } return nil } // Exec executes a statement with fmt-style placeholders func (db *DB) Exec(query string, args ...any) error { conn, err := db.pool.Take(context.Background()) if err != nil { return fmt.Errorf("failed to get connection: %w", err) } defer db.pool.Put(conn) convertedQuery, paramTypes := convertPlaceholders(query) sqlArgs := make([]any, len(args)) for i, arg := range args { if i < len(paramTypes) && paramTypes[i] == "d" { // Convert to int64 for integer parameters switch v := arg.(type) { case int: sqlArgs[i] = int64(v) case int32: sqlArgs[i] = int64(v) case int64: sqlArgs[i] = v default: sqlArgs[i] = arg } } else { sqlArgs[i] = arg } } return sqlitex.Execute(conn, convertedQuery, &sqlitex.ExecOptions{ Args: sqlArgs, }) } // Update updates specific fields in the database func (db *DB) Update(tableName string, fields map[string]any, whereField string, whereValue any) error { if len(fields) == 0 { return nil // No changes } conn, err := db.pool.Take(context.Background()) if err != nil { return fmt.Errorf("failed to get connection: %w", err) } defer db.pool.Put(conn) // Build UPDATE query setParts := make([]string, 0, len(fields)) args := make([]any, 0, len(fields)+1) for field, value := range fields { setParts = append(setParts, field+" = ?") args = append(args, value) } args = append(args, whereValue) query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?", tableName, strings.Join(setParts, ", "), whereField) return sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ Args: args, }) } // Insert inserts a struct into the database func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) { conn, err := db.pool.Take(context.Background()) if err != nil { return 0, fmt.Errorf("failed to get connection: %w", err) } defer db.pool.Put(conn) v := reflect.ValueOf(obj) if v.Kind() == reflect.Pointer { v = v.Elem() } t := v.Type() exclude := make(map[string]bool) for _, field := range excludeFields { exclude[toSnakeCase(field)] = true } var columns []string var placeholders []string var args []any for i := 0; i < t.NumField(); i++ { field := t.Field(i) columnName := toSnakeCase(field.Name) if exclude[columnName] { continue } columns = append(columns, columnName) placeholders = append(placeholders, "?") args = append(args, v.Field(i).Interface()) } query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", ")) stmt, err := conn.Prepare(query) if err != nil { return 0, err } defer stmt.Finalize() // Bind parameters for i, arg := range args { switch v := arg.(type) { case string: stmt.BindText(i+1, v) case int, int32, int64: stmt.BindInt64(i+1, reflect.ValueOf(v).Int()) case float32, float64: stmt.BindFloat(i+1, reflect.ValueOf(v).Float()) default: stmt.BindText(i+1, fmt.Sprintf("%v", v)) } } _, err = stmt.Step() if err != nil { return 0, err } return conn.LastInsertRowID(), nil } // Transaction executes multiple operations atomically func (db *DB) Transaction(fn func() error) error { conn, err := db.pool.Take(context.Background()) if err != nil { return fmt.Errorf("failed to get connection: %w", err) } defer db.pool.Put(conn) // Begin transaction if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil { return err } // Execute operations err = fn() if err != nil { // Rollback on error sqlitex.Execute(conn, "ROLLBACK", nil) return err } // Commit on success return sqlitex.Execute(conn, "COMMIT", nil) } func convertPlaceholders(query string) (string, []string) { var paramTypes []string convertedQuery := placeholderRegex.ReplaceAllStringFunc(query, func(match string) string { paramTypes = append(paramTypes, match[1:]) // Remove % prefix return "?" }) return convertedQuery, paramTypes } // toSnakeCase converts PascalCase to snake_case func toSnakeCase(s string) string { var result strings.Builder runes := []rune(s) for i, r := range runes { if i > 0 { prev := runes[i-1] // Add underscore before digit if previous char was letter if unicode.IsDigit(r) && unicode.IsLetter(prev) { result.WriteByte('_') } // Add underscore before uppercase letter if unicode.IsUpper(r) { // Don't add if previous was also uppercase (unless end of acronym) if !unicode.IsUpper(prev) || (i+1 < len(runes) && unicode.IsLower(runes[i+1])) { result.WriteByte('_') } } } result.WriteRune(unicode.ToLower(r)) } return result.String() }