528 lines
12 KiB
Go
528 lines
12 KiB
Go
package sashimi
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"reflect"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"zombiezen.com/go/sqlite"
|
|
"zombiezen.com/go/sqlite/sqlitex"
|
|
)
|
|
|
|
// Stmt is a type alias for sqlite.Stmt to avoid zombiezen being a direct dependency
|
|
type Stmt = sqlite.Stmt
|
|
|
|
var placeholderRegex = regexp.MustCompile(`%[sd]`)
|
|
|
|
type DB struct {
|
|
pool *sqlitex.Pool
|
|
}
|
|
|
|
// New creates a new database wrapper instance with connection pooling
|
|
func New(dbPath string) (*DB, error) {
|
|
// Create connection pool with 10 connections max
|
|
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
|
|
Flags: sqlite.OpenReadWrite | sqlite.OpenCreate,
|
|
PoolSize: 10,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
|
}
|
|
|
|
db := &DB{pool: pool}
|
|
|
|
// Configure database using one connection from pool
|
|
if err := db.configure(); err != nil {
|
|
pool.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
// configure sets up database pragmas
|
|
func (db *DB) configure() error {
|
|
conn, err := db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get connection: %w", err)
|
|
}
|
|
defer db.pool.Put(conn)
|
|
|
|
configs := []string{
|
|
"PRAGMA journal_mode=WAL",
|
|
"PRAGMA cache_size=-65536", // 64MB cache
|
|
"PRAGMA foreign_keys=ON",
|
|
"PRAGMA busy_timeout=5000", // 5 second timeout
|
|
}
|
|
|
|
for _, config := range configs {
|
|
if err := sqlitex.Execute(conn, config, nil); err != nil {
|
|
return fmt.Errorf("failed to configure database: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close closes the database connection pool
|
|
func (db *DB) Close() error {
|
|
if db.pool != nil {
|
|
return db.pool.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Pool returns the underlying connection pool
|
|
func (db *DB) Pool() *sqlitex.Pool {
|
|
return db.pool
|
|
}
|
|
|
|
// Scan scans a SQLite statement result into a struct using field names
|
|
func (db *DB) Scan(stmt *PooledStmt, dest any) error {
|
|
v := reflect.ValueOf(dest)
|
|
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
|
|
return fmt.Errorf("dest must be a pointer to struct")
|
|
}
|
|
|
|
elem := v.Elem()
|
|
typ := elem.Type()
|
|
|
|
for i := 0; i < typ.NumField(); i++ {
|
|
field := typ.Field(i)
|
|
columnName := toSnakeCase(field.Name)
|
|
|
|
fieldValue := elem.Field(i)
|
|
if !fieldValue.CanSet() {
|
|
continue
|
|
}
|
|
|
|
// Find column index by name
|
|
colIndex := -1
|
|
for j := 0; j < stmt.ColumnCount(); j++ {
|
|
if stmt.ColumnName(j) == columnName {
|
|
colIndex = j
|
|
break
|
|
}
|
|
}
|
|
|
|
if colIndex == -1 {
|
|
continue // Column not found
|
|
}
|
|
|
|
switch fieldValue.Kind() {
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
fieldValue.SetInt(stmt.ColumnInt64(colIndex))
|
|
case reflect.String:
|
|
fieldValue.SetString(stmt.ColumnText(colIndex))
|
|
case reflect.Float32, reflect.Float64:
|
|
fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
|
|
case reflect.Bool:
|
|
fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Query executes a query with fmt-style placeholders and automatically binds parameters
|
|
func (db *DB) Query(query string, args ...any) (*PooledStmt, error) {
|
|
conn, err := db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get connection: %w", err)
|
|
}
|
|
|
|
convertedQuery, paramTypes := convertPlaceholders(query)
|
|
|
|
stmt, err := conn.Prepare(convertedQuery)
|
|
if err != nil {
|
|
db.pool.Put(conn)
|
|
return nil, err
|
|
}
|
|
|
|
// Bind parameters with correct types
|
|
for i, arg := range args {
|
|
if i >= len(paramTypes) {
|
|
break
|
|
}
|
|
|
|
switch paramTypes[i] {
|
|
case "s": // string
|
|
if s, ok := arg.(string); ok {
|
|
stmt.BindText(i+1, s)
|
|
} else {
|
|
stmt.BindText(i+1, fmt.Sprintf("%v", arg))
|
|
}
|
|
case "d": // integer
|
|
switch v := arg.(type) {
|
|
case int:
|
|
stmt.BindInt64(i+1, int64(v))
|
|
case int32:
|
|
stmt.BindInt64(i+1, int64(v))
|
|
case int64:
|
|
stmt.BindInt64(i+1, v)
|
|
case float64:
|
|
stmt.BindInt64(i+1, int64(v))
|
|
default:
|
|
if i64, err := strconv.ParseInt(fmt.Sprintf("%v", arg), 10, 64); err == nil {
|
|
stmt.BindInt64(i+1, i64)
|
|
} else {
|
|
stmt.BindInt64(i+1, 0)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create a wrapped statement that releases the connection when finalized
|
|
return &PooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
|
|
}
|
|
|
|
// PooledStmt wraps a statement to automatically release pool connections
|
|
type PooledStmt struct {
|
|
*sqlite.Stmt
|
|
pool *sqlitex.Pool
|
|
conn *sqlite.Conn
|
|
finalized bool
|
|
}
|
|
|
|
func (ps *PooledStmt) Finalize() error {
|
|
if !ps.finalized {
|
|
err := ps.Stmt.Finalize()
|
|
ps.pool.Put(ps.conn)
|
|
ps.finalized = true
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get executes a query and returns the first row
|
|
func (db *DB) Get(dest any, query string, args ...any) error {
|
|
stmt, err := db.Query(query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Finalize()
|
|
|
|
hasRow, err := stmt.Step()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !hasRow {
|
|
return fmt.Errorf("no rows found")
|
|
}
|
|
|
|
return db.scanValue(stmt, dest)
|
|
}
|
|
|
|
// Select executes a query and scans all rows into a slice
|
|
func (db *DB) Select(dest any, query string, args ...any) error {
|
|
destValue := reflect.ValueOf(dest)
|
|
if destValue.Kind() != reflect.Ptr || destValue.Elem().Kind() != reflect.Slice {
|
|
return fmt.Errorf("dest must be a pointer to slice")
|
|
}
|
|
|
|
sliceValue := destValue.Elem()
|
|
elemType := sliceValue.Type().Elem()
|
|
|
|
// Ensure element type is a pointer to struct
|
|
if elemType.Kind() != reflect.Ptr || elemType.Elem().Kind() != reflect.Struct {
|
|
return fmt.Errorf("slice elements must be pointers to structs")
|
|
}
|
|
|
|
stmt, err := db.Query(query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Finalize()
|
|
|
|
for {
|
|
hasRow, err := stmt.Step()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !hasRow {
|
|
break
|
|
}
|
|
|
|
// Create new instance of the element type
|
|
newElem := reflect.New(elemType.Elem())
|
|
if err := db.Scan(stmt, newElem.Interface()); err != nil {
|
|
return err
|
|
}
|
|
|
|
sliceValue.Set(reflect.Append(sliceValue, newElem))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Exec executes a statement with fmt-style placeholders
|
|
func (db *DB) Exec(query string, args ...any) error {
|
|
conn, err := db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get connection: %w", err)
|
|
}
|
|
defer db.pool.Put(conn)
|
|
|
|
convertedQuery, paramTypes := convertPlaceholders(query)
|
|
|
|
sqlArgs := make([]any, len(args))
|
|
for i, arg := range args {
|
|
if i < len(paramTypes) && paramTypes[i] == "d" {
|
|
// Convert to int64 for integer parameters
|
|
switch v := arg.(type) {
|
|
case int:
|
|
sqlArgs[i] = int64(v)
|
|
case int32:
|
|
sqlArgs[i] = int64(v)
|
|
case int64:
|
|
sqlArgs[i] = v
|
|
default:
|
|
sqlArgs[i] = arg
|
|
}
|
|
} else {
|
|
sqlArgs[i] = arg
|
|
}
|
|
}
|
|
|
|
return sqlitex.Execute(conn, convertedQuery, &sqlitex.ExecOptions{
|
|
Args: sqlArgs,
|
|
})
|
|
}
|
|
|
|
// Update updates specific fields in the database
|
|
func (db *DB) Update(tableName string, fields map[string]any, whereField string, whereValue any) error {
|
|
if len(fields) == 0 {
|
|
return nil // No changes
|
|
}
|
|
|
|
conn, err := db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get connection: %w", err)
|
|
}
|
|
defer db.pool.Put(conn)
|
|
|
|
// Build UPDATE query
|
|
setParts := make([]string, 0, len(fields))
|
|
args := make([]any, 0, len(fields)+1)
|
|
|
|
for field, value := range fields {
|
|
setParts = append(setParts, field+" = ?")
|
|
args = append(args, value)
|
|
}
|
|
|
|
args = append(args, whereValue)
|
|
|
|
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
|
|
tableName, strings.Join(setParts, ", "), whereField)
|
|
|
|
return sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
|
Args: args,
|
|
})
|
|
}
|
|
|
|
// Insert inserts a struct or map into the database
|
|
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
|
|
conn, err := db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to get connection: %w", err)
|
|
}
|
|
defer db.pool.Put(conn)
|
|
|
|
exclude := make(map[string]bool)
|
|
for _, field := range excludeFields {
|
|
exclude[toSnakeCase(field)] = true
|
|
}
|
|
|
|
var columns []string
|
|
var placeholders []string
|
|
var args []any
|
|
|
|
v := reflect.ValueOf(obj)
|
|
if v.Kind() == reflect.Pointer {
|
|
v = v.Elem()
|
|
}
|
|
|
|
switch v.Kind() {
|
|
case reflect.Map:
|
|
// Handle map[string]any
|
|
m := obj.(map[string]any)
|
|
for key, value := range m {
|
|
columnName := toSnakeCase(key)
|
|
if exclude[columnName] {
|
|
continue
|
|
}
|
|
columns = append(columns, columnName)
|
|
placeholders = append(placeholders, "?")
|
|
args = append(args, value)
|
|
}
|
|
|
|
case reflect.Struct:
|
|
// Handle struct
|
|
t := v.Type()
|
|
for i := 0; i < t.NumField(); i++ {
|
|
field := t.Field(i)
|
|
columnName := toSnakeCase(field.Name)
|
|
if exclude[columnName] {
|
|
continue
|
|
}
|
|
columns = append(columns, columnName)
|
|
placeholders = append(placeholders, "?")
|
|
args = append(args, v.Field(i).Interface())
|
|
}
|
|
|
|
default:
|
|
return 0, fmt.Errorf("obj must be a struct, pointer to struct, or map[string]any")
|
|
}
|
|
|
|
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
|
tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
|
|
|
stmt, err := conn.Prepare(query)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer stmt.Finalize()
|
|
|
|
// Bind parameters
|
|
for i, arg := range args {
|
|
switch v := arg.(type) {
|
|
case string:
|
|
stmt.BindText(i+1, v)
|
|
case int, int32, int64:
|
|
stmt.BindInt64(i+1, reflect.ValueOf(v).Int())
|
|
case float32, float64:
|
|
stmt.BindFloat(i+1, reflect.ValueOf(v).Float())
|
|
default:
|
|
stmt.BindText(i+1, fmt.Sprintf("%v", v))
|
|
}
|
|
}
|
|
|
|
_, err = stmt.Step()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return conn.LastInsertRowID(), nil
|
|
}
|
|
|
|
// Transaction executes multiple operations atomically
|
|
func (db *DB) Transaction(fn func() error) error {
|
|
conn, err := db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get connection: %w", err)
|
|
}
|
|
defer db.pool.Put(conn)
|
|
|
|
// Begin transaction
|
|
if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Execute operations
|
|
err = fn()
|
|
|
|
if err != nil {
|
|
// Rollback on error
|
|
sqlitex.Execute(conn, "ROLLBACK", nil)
|
|
return err
|
|
}
|
|
|
|
// Commit on success
|
|
return sqlitex.Execute(conn, "COMMIT", nil)
|
|
}
|
|
|
|
// scanValue scans a statement result into either a struct or primitive type
|
|
func (db *DB) scanValue(stmt *PooledStmt, dest any) error {
|
|
v := reflect.ValueOf(dest)
|
|
if v.Kind() != reflect.Pointer {
|
|
return fmt.Errorf("dest must be a pointer")
|
|
}
|
|
|
|
elem := v.Elem()
|
|
|
|
// Handle primitive types
|
|
if isPrimitiveType(elem.Kind()) {
|
|
if stmt.ColumnCount() == 0 {
|
|
return fmt.Errorf("no columns in result")
|
|
}
|
|
|
|
return scanPrimitive(stmt, elem, 0)
|
|
}
|
|
|
|
// Handle struct types
|
|
if elem.Kind() != reflect.Struct {
|
|
return fmt.Errorf("dest must be a pointer to struct or primitive type")
|
|
}
|
|
|
|
return db.Scan(stmt, dest)
|
|
}
|
|
|
|
// isPrimitiveType checks if a reflect.Kind represents a primitive type
|
|
func isPrimitiveType(k reflect.Kind) bool {
|
|
switch k {
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
|
reflect.String, reflect.Float32, reflect.Float64, reflect.Bool:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// scanPrimitive scans a column value into a primitive type
|
|
func scanPrimitive(stmt *PooledStmt, fieldValue reflect.Value, colIndex int) error {
|
|
switch fieldValue.Kind() {
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
fieldValue.SetInt(stmt.ColumnInt64(colIndex))
|
|
case reflect.String:
|
|
fieldValue.SetString(stmt.ColumnText(colIndex))
|
|
case reflect.Float32, reflect.Float64:
|
|
fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
|
|
case reflect.Bool:
|
|
fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
|
|
default:
|
|
return fmt.Errorf("unsupported type: %v", fieldValue.Kind())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func convertPlaceholders(query string) (string, []string) {
|
|
var paramTypes []string
|
|
|
|
convertedQuery := placeholderRegex.ReplaceAllStringFunc(query, func(match string) string {
|
|
paramTypes = append(paramTypes, match[1:]) // Remove % prefix
|
|
return "?"
|
|
})
|
|
|
|
return convertedQuery, paramTypes
|
|
}
|
|
|
|
// toSnakeCase converts PascalCase to snake_case
|
|
func toSnakeCase(s string) string {
|
|
var result strings.Builder
|
|
runes := []rune(s)
|
|
|
|
for i, r := range runes {
|
|
if i > 0 {
|
|
prev := runes[i-1]
|
|
|
|
// Add underscore before digit if previous char was letter
|
|
if unicode.IsDigit(r) && unicode.IsLetter(prev) {
|
|
result.WriteByte('_')
|
|
}
|
|
// Add underscore before uppercase letter
|
|
if unicode.IsUpper(r) {
|
|
// Don't add if previous was also uppercase (unless end of acronym)
|
|
if !unicode.IsUpper(prev) ||
|
|
(i+1 < len(runes) && unicode.IsLower(runes[i+1])) {
|
|
result.WriteByte('_')
|
|
}
|
|
}
|
|
}
|
|
result.WriteRune(unicode.ToLower(r))
|
|
}
|
|
return result.String()
|
|
}
|