// Sashimi/db.go
//
// 528 lines
// 12 KiB
// Go

package sashimi
import (
"context"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"unicode"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
// Stmt is an alias for sqlite.Stmt. It lets code that uses this package name
// the statement type without importing zombiezen.com/go/sqlite itself.
type Stmt = sqlite.Stmt
// placeholderRegex matches the fmt-style placeholders %s and %d that
// convertPlaceholders rewrites to SQLite "?" parameters. The pattern is purely
// textual, so it also matches those two-character sequences when they occur
// inside an SQL string literal.
var placeholderRegex = regexp.MustCompile(`%[sd]`)
// DB wraps a pool of SQLite connections. Its methods take a connection from
// the pool for the duration of a call (or, for Query, until the returned
// PooledStmt is finalized).
type DB struct {
// pool is the connection pool created by New; Close treats nil as "nothing to close".
pool *sqlitex.Pool
}
// New opens the SQLite database at dbPath (creating it if it does not exist)
// and returns a DB backed by a pool of at most 10 connections.
//
// Every connection the pool opens is passed through configureConn, because
// most of the pragmas it sets (foreign_keys, busy_timeout, cache_size) are
// per-connection state: configuring only a single pooled connection would
// leave the remaining ones running with SQLite's defaults.
//
// The pool is closed again and an error is returned if the first connection
// cannot be obtained or configured.
func New(dbPath string) (*DB, error) {
	pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
		Flags:       sqlite.OpenReadWrite | sqlite.OpenCreate,
		PoolSize:    10,
		PrepareConn: configureConn,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create connection pool: %w", err)
	}
	db := &DB{pool: pool}
	// Exercise one connection right away so that a bad path or a failing
	// pragma is reported by New rather than by the first query.
	if err := db.configure(); err != nil {
		pool.Close()
		return nil, err
	}
	return db, nil
}

// configureConn applies the database pragmas to a single connection. It is
// installed as the pool's PrepareConn hook so that each pooled connection is
// set up identically.
func configureConn(conn *sqlite.Conn) error {
	configs := []string{
		"PRAGMA journal_mode=WAL",
		"PRAGMA cache_size=-65536", // negative means KiB: 64 MiB page cache
		"PRAGMA foreign_keys=ON",
		"PRAGMA busy_timeout=5000", // wait up to 5 seconds on a locked database
	}
	for _, config := range configs {
		if err := sqlitex.Execute(conn, config, nil); err != nil {
			return fmt.Errorf("failed to configure database: %w", err)
		}
	}
	return nil
}

// configure takes one connection from the pool and applies the pragmas to it.
// The pragmas are idempotent, so re-applying them to a connection that the
// pool's PrepareConn hook has already configured is harmless.
func (db *DB) configure() error {
	conn, err := db.pool.Take(context.Background())
	if err != nil {
		return fmt.Errorf("failed to get connection: %w", err)
	}
	defer db.pool.Put(conn)
	return configureConn(conn)
}
// Close shuts down the connection pool and reports the pool's close error.
// A DB without a pool has nothing to release, so Close returns nil for it.
func (db *DB) Close() error {
	if db.pool == nil {
		return nil
	}
	return db.pool.Close()
}
// Pool exposes the underlying sqlitex connection pool for callers that need
// to take connections themselves.
func (db *DB) Pool() *sqlitex.Pool {
	underlying := db.pool
	return underlying
}
// Scan copies the current row of stmt into the struct that dest points to.
//
// Each settable field is matched against the result column whose name equals
// toSnakeCase(field name). Fields without a matching column are left
// untouched, as are fields whose kind is not a signed integer, string, float
// or bool. When several result columns share a name, the first one is used.
//
// dest must be a non-nil pointer to a struct; otherwise an error is returned.
func (db *DB) Scan(stmt *PooledStmt, dest any) error {
	v := reflect.ValueOf(dest)
	if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("dest must be a pointer to struct")
	}
	elem := v.Elem()
	typ := elem.Type()

	// Resolve the column names once per call instead of rescanning every
	// column for every field.
	columnCount := stmt.ColumnCount()
	columnIndex := make(map[string]int, columnCount)
	for j := 0; j < columnCount; j++ {
		name := stmt.ColumnName(j)
		if _, seen := columnIndex[name]; !seen {
			columnIndex[name] = j
		}
	}

	for i := 0; i < typ.NumField(); i++ {
		fieldValue := elem.Field(i)
		if !fieldValue.CanSet() {
			continue // unexported field
		}
		colIndex, found := columnIndex[toSnakeCase(typ.Field(i).Name)]
		if !found {
			continue
		}
		switch fieldValue.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			fieldValue.SetInt(stmt.ColumnInt64(colIndex))
		case reflect.String:
			fieldValue.SetString(stmt.ColumnText(colIndex))
		case reflect.Float32, reflect.Float64:
			fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
		case reflect.Bool:
			// Booleans are stored as integers; any non-zero value is true.
			fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
		}
	}
	return nil
}
// Query prepares query on a pooled connection, binds args to its
// placeholders and returns the statement ready to be stepped.
//
// The fmt-style placeholders %s and %d are rewritten to "?" parameters by
// convertPlaceholders; args are bound to them positionally:
//
//   - a nil argument is bound as SQL NULL for either placeholder;
//   - %s binds strings as-is and any other value as its fmt "%v" text;
//   - %d binds Go integers as integers, truncates floats, maps bool to 0/1,
//     and otherwise parses the "%v" text of the value as a base-10 integer,
//     binding 0 when that fails.
//
// Arguments beyond the number of placeholders are ignored.
//
// The returned PooledStmt holds the connection until its Finalize method is
// called, so callers must always finalize it.
func (db *DB) Query(query string, args ...any) (*PooledStmt, error) {
	conn, err := db.pool.Take(context.Background())
	if err != nil {
		return nil, fmt.Errorf("failed to get connection: %w", err)
	}
	convertedQuery, paramTypes := convertPlaceholders(query)
	stmt, err := conn.Prepare(convertedQuery)
	if err != nil {
		db.pool.Put(conn)
		return nil, err
	}

	for i, arg := range args {
		if i >= len(paramTypes) {
			break
		}
		param := i + 1 // SQLite parameter indices are 1-based
		if arg == nil {
			// Without this, nil would be formatted as the text "<nil>"
			// (or coerced to 0 for %d) instead of being stored as NULL.
			stmt.BindNull(param)
			continue
		}
		switch paramTypes[i] {
		case "s": // string
			if s, ok := arg.(string); ok {
				stmt.BindText(param, s)
			} else {
				stmt.BindText(param, fmt.Sprintf("%v", arg))
			}
		case "d": // integer
			switch v := arg.(type) {
			case int:
				stmt.BindInt64(param, int64(v))
			case int32:
				stmt.BindInt64(param, int64(v))
			case int64:
				stmt.BindInt64(param, v)
			case float32:
				stmt.BindInt64(param, int64(v))
			case float64:
				stmt.BindInt64(param, int64(v))
			case bool:
				if v {
					stmt.BindInt64(param, 1)
				} else {
					stmt.BindInt64(param, 0)
				}
			default:
				// Best effort for the remaining types (other integer widths,
				// numeric strings, ...): parse their textual form.
				if i64, err := strconv.ParseInt(fmt.Sprintf("%v", arg), 10, 64); err == nil {
					stmt.BindInt64(param, i64)
				} else {
					stmt.BindInt64(param, 0)
				}
			}
		}
	}

	// The wrapper returns conn to the pool when the statement is finalized.
	return &PooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
}
// PooledStmt is a prepared statement bundled with the pooled connection it
// was prepared on. It embeds *sqlite.Stmt, so all statement methods are
// available; its own Finalize additionally returns the connection to the pool.
type PooledStmt struct {
*sqlite.Stmt
// pool is where conn is returned to when the statement is finalized.
pool *sqlitex.Pool
// conn is the connection the statement was prepared on.
conn *sqlite.Conn
// finalized records that Finalize already ran, making repeated calls no-ops.
finalized bool
}
// Finalize finalizes the wrapped statement and hands its connection back to
// the pool. Only the first call does any work; later calls return nil.
// The returned error is the one reported by the underlying statement.
func (ps *PooledStmt) Finalize() error {
	if ps.finalized {
		return nil
	}
	finalizeErr := ps.Stmt.Finalize()
	ps.pool.Put(ps.conn)
	ps.finalized = true
	return finalizeErr
}
// Get runs query (see Query for placeholder handling) and scans its first row
// into dest, which may point to a struct or to a primitive value. It returns
// an error when the query fails or yields no rows.
func (db *DB) Get(dest any, query string, args ...any) error {
	row, err := db.Query(query, args...)
	if err != nil {
		return err
	}
	defer row.Finalize()

	switch found, stepErr := row.Step(); {
	case stepErr != nil:
		return stepErr
	case !found:
		return fmt.Errorf("no rows found")
	}
	return db.scanValue(row, dest)
}
// Select runs query (see Query for placeholder handling) and appends one
// element per result row to the slice that dest points to.
//
// dest must be a pointer to a slice whose elements are pointers to structs
// (for example *[]*User); each row is scanned into a freshly allocated struct
// with Scan. Elements already present in the slice are kept.
//
// The slice is only updated once every row has been read successfully, so on
// error dest is left exactly as the caller passed it.
func (db *DB) Select(dest any, query string, args ...any) error {
	destValue := reflect.ValueOf(dest)
	if destValue.Kind() != reflect.Pointer || destValue.Elem().Kind() != reflect.Slice {
		return fmt.Errorf("dest must be a pointer to slice")
	}
	sliceValue := destValue.Elem()
	elemType := sliceValue.Type().Elem()
	if elemType.Kind() != reflect.Pointer || elemType.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("slice elements must be pointers to structs")
	}

	stmt, err := db.Query(query, args...)
	if err != nil {
		return err
	}
	defer stmt.Finalize()

	// Accumulate into a separate slice value; reflect.Append never modifies
	// the slice header stored in dest.
	rows := sliceValue
	for {
		hasRow, err := stmt.Step()
		if err != nil {
			return err
		}
		if !hasRow {
			break
		}
		newElem := reflect.New(elemType.Elem())
		if err := db.Scan(stmt, newElem.Interface()); err != nil {
			return err
		}
		rows = reflect.Append(rows, newElem)
	}
	sliceValue.Set(rows)
	return nil
}
// Exec runs a statement on a pooled connection and reports only its error;
// no result rows are returned to the caller.
//
// %s and %d placeholders are rewritten to "?" parameters. Arguments that sit
// in a %d position and are of type int or int32 are widened to int64; every
// other argument is handed to sqlitex.Execute unchanged.
func (db *DB) Exec(query string, args ...any) error {
	conn, err := db.pool.Take(context.Background())
	if err != nil {
		return fmt.Errorf("failed to get connection: %w", err)
	}
	defer db.pool.Put(conn)

	sqlText, kinds := convertPlaceholders(query)

	// Start from a copy of the caller's arguments, then widen the integers.
	params := make([]any, len(args))
	copy(params, args)
	for i := 0; i < len(params) && i < len(kinds); i++ {
		if kinds[i] != "d" {
			continue
		}
		switch n := params[i].(type) {
		case int:
			params[i] = int64(n)
		case int32:
			params[i] = int64(n)
		}
	}

	opts := &sqlitex.ExecOptions{Args: params}
	return sqlitex.Execute(conn, sqlText, opts)
}
// Update sets the given columns on every row of tableName whose whereField
// equals whereValue. fields maps column names to their new values; an empty
// map is a no-op that returns nil without touching the database.
//
// Values are bound as parameters, but tableName, the column names and
// whereField are inserted into the SQL text verbatim.
func (db *DB) Update(tableName string, fields map[string]any, whereField string, whereValue any) error {
	if len(fields) == 0 {
		return nil
	}

	conn, err := db.pool.Take(context.Background())
	if err != nil {
		return fmt.Errorf("failed to get connection: %w", err)
	}
	defer db.pool.Put(conn)

	// Assemble "col = ?, col = ?" while collecting the values in the same order.
	var assignments strings.Builder
	params := make([]any, 0, len(fields)+1)
	for column, value := range fields {
		if len(params) > 0 {
			assignments.WriteString(", ")
		}
		assignments.WriteString(column)
		assignments.WriteString(" = ?")
		params = append(params, value)
	}
	// The WHERE parameter comes last, after all SET parameters.
	params = append(params, whereValue)

	statement := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
		tableName, assignments.String(), whereField)
	return sqlitex.Execute(conn, statement, &sqlitex.ExecOptions{Args: params})
}
// Insert adds one row to tableName and returns the connection's last insert
// row ID.
//
// obj may be a struct, a pointer to a struct, or a map with string keys
// (typically map[string]any). Column names are the snake_case form of the
// struct field names or map keys (see toSnakeCase). Unexported struct fields
// are skipped, as is any column listed in excludeFields (also compared in
// snake_case).
//
// Values are bound by Go type: nil as NULL, strings as text, []byte as blob,
// bool as 0/1 (the representation Scan reads back), signed integers and
// floats as numbers, and everything else as its fmt "%v" text.
//
// tableName and the column names are inserted into the SQL text verbatim.
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
	exclude := make(map[string]bool, len(excludeFields))
	for _, field := range excludeFields {
		exclude[toSnakeCase(field)] = true
	}

	v := reflect.ValueOf(obj)
	if v.Kind() == reflect.Pointer {
		v = v.Elem()
	}

	var columns []string
	var args []any
	switch v.Kind() {
	case reflect.Map:
		// Iterate through reflection so that pointers to maps and map types
		// other than map[string]any are handled instead of panicking on a
		// failed type assertion.
		if v.Type().Key().Kind() != reflect.String {
			return 0, fmt.Errorf("obj must be a struct, pointer to struct, or map[string]any")
		}
		for iter := v.MapRange(); iter.Next(); {
			columnName := toSnakeCase(iter.Key().String())
			if exclude[columnName] {
				continue
			}
			columns = append(columns, columnName)
			args = append(args, iter.Value().Interface())
		}
	case reflect.Struct:
		t := v.Type()
		for i := 0; i < t.NumField(); i++ {
			field := t.Field(i)
			if !field.IsExported() {
				continue // reading an unexported field through reflection panics
			}
			columnName := toSnakeCase(field.Name)
			if exclude[columnName] {
				continue
			}
			columns = append(columns, columnName)
			args = append(args, v.Field(i).Interface())
		}
	default:
		return 0, fmt.Errorf("obj must be a struct, pointer to struct, or map[string]any")
	}
	if len(columns) == 0 {
		// "INSERT INTO t () VALUES ()" is not valid SQL; fail with a clear message.
		return 0, fmt.Errorf("no columns to insert")
	}

	placeholders := make([]string, len(columns))
	for i := range placeholders {
		placeholders[i] = "?"
	}
	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
		tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))

	// Take the connection only once the input is known to be usable.
	conn, err := db.pool.Take(context.Background())
	if err != nil {
		return 0, fmt.Errorf("failed to get connection: %w", err)
	}
	defer db.pool.Put(conn)

	stmt, err := conn.Prepare(query)
	if err != nil {
		return 0, err
	}
	defer stmt.Finalize()

	for i, arg := range args {
		bindInsertArg(stmt, i+1, arg) // SQLite parameter indices are 1-based
	}
	if _, err := stmt.Step(); err != nil {
		return 0, err
	}
	return conn.LastInsertRowID(), nil
}

// bindInsertArg binds arg to the 1-based parameter param of stmt, choosing
// the SQLite storage class from arg's Go type.
func bindInsertArg(stmt *sqlite.Stmt, param int, arg any) {
	switch v := arg.(type) {
	case nil:
		stmt.BindNull(param)
	case string:
		stmt.BindText(param, v)
	case []byte:
		stmt.BindBytes(param, v)
	case bool:
		stmt.BindBool(param, v)
	case int, int8, int16, int32, int64:
		stmt.BindInt64(param, reflect.ValueOf(v).Int())
	case float32, float64:
		stmt.BindFloat(param, reflect.ValueOf(v).Float())
	default:
		stmt.BindText(param, fmt.Sprintf("%v", v))
	}
}
// Transaction issues BEGIN on a pooled connection, calls fn, and then issues
// COMMIT if fn returned nil or ROLLBACK if it returned an error.
//
// The error from fn is returned (wrapped, together with the rollback error,
// if the rollback itself fails). If fn panics, or COMMIT fails, the
// transaction is rolled back on a best-effort basis so the connection is not
// handed back to the pool with a transaction still open.
//
// NOTE(review): fn receives no connection. Exec, Query, Insert and Update each
// take their own connection from the pool, so statements fn issues through
// those methods run on a different connection than the one holding this
// transaction and are neither committed nor rolled back by it. Making fn's
// work atomic requires passing the connection to fn, which would change this
// method's signature. Each running Transaction also holds one pooled
// connection for as long as fn runs.
func (db *DB) Transaction(fn func() error) error {
	conn, err := db.pool.Take(context.Background())
	if err != nil {
		return fmt.Errorf("failed to get connection: %w", err)
	}
	defer db.pool.Put(conn)

	if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil {
		return err
	}

	// Runs before the deferred Put above: if fn panics, undo the BEGIN
	// before the connection goes back to the pool.
	fnReturned := false
	defer func() {
		if !fnReturned {
			_ = sqlitex.Execute(conn, "ROLLBACK", nil) // best effort; the panic keeps propagating
		}
	}()

	err = fn()
	fnReturned = true
	if err != nil {
		if rbErr := sqlitex.Execute(conn, "ROLLBACK", nil); rbErr != nil {
			return fmt.Errorf("%w (rollback failed: %v)", err, rbErr)
		}
		return err
	}

	if err := sqlitex.Execute(conn, "COMMIT", nil); err != nil {
		// A failed COMMIT can leave the transaction open; close it. The
		// commit error is the one worth reporting.
		_ = sqlitex.Execute(conn, "ROLLBACK", nil)
		return err
	}
	return nil
}
// scanValue reads the current row of stmt into dest. When dest points to a
// primitive (see isPrimitiveType) the first result column is stored in it;
// when it points to a struct the row is mapped field by field via Scan.
// Any other destination is rejected with an error.
func (db *DB) scanValue(stmt *PooledStmt, dest any) error {
	ptr := reflect.ValueOf(dest)
	if ptr.Kind() != reflect.Pointer {
		return fmt.Errorf("dest must be a pointer")
	}

	target := ptr.Elem()
	switch kind := target.Kind(); {
	case isPrimitiveType(kind):
		if stmt.ColumnCount() == 0 {
			return fmt.Errorf("no columns in result")
		}
		return scanPrimitive(stmt, target, 0)
	case kind == reflect.Struct:
		return db.Scan(stmt, dest)
	default:
		return fmt.Errorf("dest must be a pointer to struct or primitive type")
	}
}
// isPrimitiveType reports whether k is one of the kinds that can be scanned
// directly from a single column: signed integers, strings, floats and bools.
func isPrimitiveType(k reflect.Kind) bool {
	switch {
	case k == reflect.String, k == reflect.Bool:
		return true
	case k == reflect.Float32, k == reflect.Float64:
		return true
	case k == reflect.Int, k == reflect.Int8, k == reflect.Int16,
		k == reflect.Int32, k == reflect.Int64:
		return true
	default:
		return false
	}
}
// scanPrimitive stores column colIndex of the current row of stmt in
// fieldValue, converting according to fieldValue's kind. Signed integers,
// strings, floats and bools are supported (a bool is true for any non-zero
// integer); every other kind yields an error.
func scanPrimitive(stmt *PooledStmt, fieldValue reflect.Value, colIndex int) error {
	kind := fieldValue.Kind()
	switch kind {
	case reflect.Bool:
		truthy := stmt.ColumnInt(colIndex) != 0
		fieldValue.SetBool(truthy)
		return nil
	case reflect.String:
		fieldValue.SetString(stmt.ColumnText(colIndex))
		return nil
	case reflect.Float32, reflect.Float64:
		fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
		return nil
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		fieldValue.SetInt(stmt.ColumnInt64(colIndex))
		return nil
	}
	return fmt.Errorf("unsupported type: %v", kind)
}
// convertPlaceholders rewrites every match of placeholderRegex in query to a
// "?" parameter marker. It returns the rewritten query together with the
// placeholder letters, without the leading '%', in order of appearance; the
// slice is nil when the query contains no placeholder.
func convertPlaceholders(query string) (string, []string) {
	spans := placeholderRegex.FindAllStringIndex(query, -1)
	if len(spans) == 0 {
		return query, nil
	}

	rewritten := make([]byte, 0, len(query))
	kinds := make([]string, 0, len(spans))
	copied := 0 // end of the part of query already emitted
	for _, span := range spans {
		start, end := span[0], span[1]
		rewritten = append(rewritten, query[copied:start]...)
		rewritten = append(rewritten, '?')
		kinds = append(kinds, query[start+1:end]) // drop the '%'
		copied = end
	}
	rewritten = append(rewritten, query[copied:]...)
	return string(rewritten), kinds
}
// toSnakeCase lower-cases s and inserts underscores at word boundaries, e.g.
// "UserID" -> "user_id", "HTTPServer" -> "http_server", "Address2" ->
// "address_2".
//
// An underscore is written before a digit that follows a letter, and before
// an upper-case letter unless it continues a run of upper-case letters. The
// last letter of such a run still gets one when a lower-case letter follows,
// because it then starts the next word.
func toSnakeCase(s string) string {
	letters := []rune(s)
	var out strings.Builder
	for i := range letters {
		cur := letters[i]
		if i > 0 {
			prev := letters[i-1]
			switch {
			case unicode.IsDigit(cur) && unicode.IsLetter(prev):
				out.WriteByte('_')
			case unicode.IsUpper(cur):
				startsNextWord := i+1 < len(letters) && unicode.IsLower(letters[i+1])
				if startsNextWord || !unicode.IsUpper(prev) {
					out.WriteByte('_')
				}
			}
		}
		out.WriteRune(unicode.ToLower(cur))
	}
	return out.String()
}