// Package database wraps a single SQLite connection (zombiezen.com/go/sqlite)
// with small reflection-based helpers: queries with fmt-style placeholders,
// scanning rows into structs, and inserting/updating structs and field maps.
package database

import (
	"errors"
	"fmt"
	"reflect"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"unicode"

	"zombiezen.com/go/sqlite"
	"zombiezen.com/go/sqlite/sqlitex"
)

var (
	// placeholderRegex matches the fmt-style placeholders (%s and %d) that
	// Query and Exec rewrite into SQLite "?" parameters.
	placeholderRegex = regexp.MustCompile(`%[sd]`)

	// db is the package-wide connection installed by Init and returned by DB.
	//
	// NOTE(review): every helper in this package shares this one connection
	// without any locking; confirm callers never use it from more than one
	// goroutine at a time.
	db *sqlite.Conn
)

// ErrNoRows is returned by Get when the query produces no rows.
var ErrNoRows = errors.New("no rows found")

// Init opens (creating it if needed) the SQLite database at dbPath and stores
// the connection for DB and the other helpers. It then enables WAL
// journaling, sets a 64 MiB page cache and turns on foreign-key enforcement.
// If any of those pragmas fails, the connection is closed and the error is
// returned; the stored connection is left unchanged in that case.
//
// NOTE(review): calling Init a second time replaces the stored connection
// without closing the previous one.
func Init(dbPath string) error {
	conn, err := sqlite.OpenConn(dbPath, sqlite.OpenReadWrite|sqlite.OpenCreate)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}

	pragmas := []struct {
		sql  string
		desc string // completes the "failed to ..." error message
	}{
		{"PRAGMA journal_mode=WAL", "enable WAL mode"},
		// A negative cache_size is a size in KiB: 65536 KiB = 64 MiB.
		{"PRAGMA cache_size=-65536", "set cache size"},
		{"PRAGMA foreign_keys=ON", "enable foreign keys"},
	}
	for _, p := range pragmas {
		if err := sqlitex.Execute(conn, p.sql, nil); err != nil {
			conn.Close() // best effort; the pragma error is the one worth reporting
			return fmt.Errorf("failed to %s: %w", p.desc, err)
		}
	}

	db = conn
	return nil
}

// DB returns the connection installed by Init. It panics if Init has not
// been called successfully.
func DB() *sqlite.Conn {
	if db == nil {
		panic("database not initialized - call Init() first")
	}
	return db
}

// Scan copies the current row of stmt into the struct pointed to by dest.
//
// Each settable (exported) field is matched to the result column whose name
// equals toSnakeCase(field name), e.g. UserID -> user_id. Fields with no
// matching column, unexported fields and fields of unsupported kinds are left
// untouched. Signed and unsigned integer, string, float and bool kinds are
// supported; a bool is set to "column integer value != 0".
func Scan(stmt *sqlite.Stmt, dest any) error {
	v := reflect.ValueOf(dest)
	if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("dest must be a pointer to struct")
	}
	elem := v.Elem()
	typ := elem.Type()

	for i := 0; i < typ.NumField(); i++ {
		fieldValue := elem.Field(i)
		if !fieldValue.CanSet() {
			continue
		}
		colIndex := columnIndex(stmt, toSnakeCase(typ.Field(i).Name))
		if colIndex < 0 {
			continue // column not present in this result set
		}

		switch fieldValue.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			fieldValue.SetInt(stmt.ColumnInt64(colIndex))
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			// Mirrors Insert, which stores unsigned values as int64 bit patterns.
			fieldValue.SetUint(uint64(stmt.ColumnInt64(colIndex)))
		case reflect.String:
			fieldValue.SetString(stmt.ColumnText(colIndex))
		case reflect.Float32, reflect.Float64:
			fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
		case reflect.Bool:
			fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
		}
	}
	return nil
}

// columnIndex returns the index of the result column called name, or -1 if
// stmt has no such column.
func columnIndex(stmt *sqlite.Stmt, name string) int {
	for j := 0; j < stmt.ColumnCount(); j++ {
		if stmt.ColumnName(j) == name {
			return j
		}
	}
	return -1
}

// Query rewrites the fmt-style placeholders in query (%s for text, %d for
// integers) into SQLite "?" parameters, prepares the result on DB() and binds
// args to them positionally.
//
// A %s argument is bound as text; non-string values use their %v formatting.
// A %d argument is bound as an integer: integer kinds directly, floats
// truncated toward zero, bools as 1/0, and anything else parsed from its %v
// text. If that parse fails, the statement is finalized and an error is
// returned instead of silently binding 0. Arguments beyond the number of
// placeholders are ignored.
//
// The caller owns the returned statement and must finish with it (Get and
// Select call Finalize).
func Query(query string, args ...any) (*sqlite.Stmt, error) {
	convertedQuery, paramTypes := convertPlaceholders(query)
	stmt, err := DB().Prepare(convertedQuery)
	if err != nil {
		return nil, err
	}

	for i, arg := range args {
		if i >= len(paramTypes) {
			break
		}
		param := i + 1 // SQLite parameters are 1-based
		switch paramTypes[i] {
		case "s":
			if s, ok := arg.(string); ok {
				stmt.BindText(param, s)
			} else {
				stmt.BindText(param, fmt.Sprintf("%v", arg))
			}
		case "d":
			n, convErr := toInt64(arg)
			if convErr != nil {
				stmt.Finalize()
				return nil, fmt.Errorf("query argument %d: %w", param, convErr)
			}
			stmt.BindInt64(param, n)
		}
	}
	return stmt, nil
}

// toInt64 converts a %d argument to the integer bound for it. Integer kinds
// convert directly, float kinds truncate toward zero and bools map to 1/0;
// any other value is parsed as a base-10 integer from its %v text.
func toInt64(arg any) (int64, error) {
	v := reflect.ValueOf(arg)
	switch v.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int(), nil
	case reflect.Float32, reflect.Float64:
		return int64(v.Float()), nil
	case reflect.Bool:
		if v.Bool() {
			return 1, nil
		}
		return 0, nil
	}

	s := fmt.Sprintf("%v", arg)
	n, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("cannot bind %q as integer: %w", s, err)
	}
	return n, nil
}

// Get runs query (placeholder rules as in Query) and scans the first row into
// dest, which must be a pointer to a struct. It returns ErrNoRows when the
// query yields no rows. The statement is finalized before Get returns.
func Get(dest any, query string, args ...any) error {
	stmt, err := Query(query, args...)
	if err != nil {
		return err
	}
	defer stmt.Finalize()

	hasRow, err := stmt.Step()
	if err != nil {
		return err
	}
	if !hasRow {
		return ErrNoRows
	}
	return Scan(stmt, dest)
}

// Select runs query (placeholder rules as in Query) and appends one freshly
// allocated struct per result row to the slice dest points to. dest must be
// a pointer to a slice of struct pointers (e.g. *[]*User). Rows scanned
// before an error remain appended. The statement is finalized before Select
// returns.
func Select(dest any, query string, args ...any) error {
	destValue := reflect.ValueOf(dest)
	if destValue.Kind() != reflect.Pointer || destValue.Elem().Kind() != reflect.Slice {
		return fmt.Errorf("dest must be a pointer to slice")
	}
	sliceValue := destValue.Elem()
	elemType := sliceValue.Type().Elem()
	if elemType.Kind() != reflect.Pointer || elemType.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("slice elements must be pointers to structs")
	}

	stmt, err := Query(query, args...)
	if err != nil {
		return err
	}
	defer stmt.Finalize()

	for {
		hasRow, err := stmt.Step()
		if err != nil {
			return err
		}
		if !hasRow {
			return nil
		}
		newElem := reflect.New(elemType.Elem())
		if err := Scan(stmt, newElem.Interface()); err != nil {
			return err
		}
		sliceValue.Set(reflect.Append(sliceValue, newElem))
	}
}

// Exec rewrites the fmt-style placeholders in query into "?" parameters and
// executes it on DB() via sqlitex.Execute. Arguments in a %d position that
// are int or int32 are widened to int64; every other argument is passed to
// sqlitex.Execute unchanged.
func Exec(query string, args ...any) error {
	convertedQuery, paramTypes := convertPlaceholders(query)

	sqlArgs := make([]any, len(args))
	for i, arg := range args {
		sqlArgs[i] = arg
		if i >= len(paramTypes) || paramTypes[i] != "d" {
			continue
		}
		switch v := arg.(type) {
		case int:
			sqlArgs[i] = int64(v)
		case int32:
			sqlArgs[i] = int64(v)
		}
	}

	return sqlitex.Execute(DB(), convertedQuery, &sqlitex.ExecOptions{
		Args: sqlArgs,
	})
}

// Update executes "UPDATE tableName SET col = ?, ... WHERE whereField = ?",
// binding the values of fields followed by whereValue. It does nothing when
// fields is empty. Columns are emitted in sorted order so the generated SQL
// is deterministic.
//
// tableName, whereField and the keys of fields are interpolated into the SQL
// text unescaped (identifiers cannot be bound as parameters), so they must
// never come from untrusted input.
func Update(tableName string, fields map[string]any, whereField string, whereValue any) error {
	if len(fields) == 0 {
		return nil // no changes
	}

	columns := make([]string, 0, len(fields))
	for col := range fields {
		columns = append(columns, col)
	}
	sort.Strings(columns)

	setParts := make([]string, len(columns))
	args := make([]any, 0, len(columns)+1)
	for i, col := range columns {
		setParts[i] = col + " = ?"
		args = append(args, fields[col])
	}
	args = append(args, whereValue)

	query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
		tableName, strings.Join(setParts, ", "), whereField)
	return sqlitex.Execute(DB(), query, &sqlitex.ExecOptions{
		Args: args,
	})
}

// Insert inserts obj (a struct or a pointer to one) into tableName and
// returns DB().LastInsertRowID() after the insert.
//
// Every exported field becomes a column named toSnakeCase(field name), except
// those listed in excludeFields (Go field names or column names; both are
// normalized with toSnakeCase). Unexported fields are skipped, matching Scan.
// Values are bound by kind so they read back through Scan: signed and
// unsigned integers as integers, bools as 1/0, floats as reals, strings as
// text; anything else is bound as its %v text. If no columns remain, the row
// is inserted with DEFAULT VALUES.
//
// tableName is interpolated into the SQL unescaped and must be trusted.
func Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
	v := reflect.ValueOf(obj)
	if v.Kind() == reflect.Pointer {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return 0, fmt.Errorf("obj must be a struct or pointer to struct")
	}
	t := v.Type()

	exclude := make(map[string]bool, len(excludeFields))
	for _, name := range excludeFields {
		exclude[toSnakeCase(name)] = true
	}

	var columns, placeholders []string
	var values []reflect.Value
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if !field.IsExported() {
			continue // reading an unexported field via Interface() would panic
		}
		columnName := toSnakeCase(field.Name)
		if exclude[columnName] {
			continue
		}
		columns = append(columns, columnName)
		placeholders = append(placeholders, "?")
		values = append(values, v.Field(i))
	}

	var query string
	if len(columns) == 0 {
		query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", tableName)
	} else {
		query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
			tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
	}

	stmt, err := DB().Prepare(query)
	if err != nil {
		return 0, err
	}
	defer stmt.Finalize()

	for i, fv := range values {
		bindValue(stmt, i+1, fv)
	}

	if _, err := stmt.Step(); err != nil {
		return 0, err
	}
	return DB().LastInsertRowID(), nil
}

// bindValue binds fv to the 1-based parameter param of stmt according to its
// reflect.Kind, using the same representations Scan reads back.
func bindValue(stmt *sqlite.Stmt, param int, fv reflect.Value) {
	switch fv.Kind() {
	case reflect.String:
		stmt.BindText(param, fv.String())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		stmt.BindInt64(param, fv.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		// SQLite has no unsigned type; store the same 64-bit pattern as int64.
		stmt.BindInt64(param, int64(fv.Uint()))
	case reflect.Float32, reflect.Float64:
		stmt.BindFloat(param, fv.Float())
	case reflect.Bool:
		var n int64
		if fv.Bool() {
			n = 1
		}
		stmt.BindInt64(param, n)
	default:
		stmt.BindText(param, fmt.Sprintf("%v", fv.Interface()))
	}
}

// convertPlaceholders replaces every %s / %d in query with "?" and returns the
// rewritten query together with the placeholder letters ("s" or "d") in the
// order they appeared.
func convertPlaceholders(query string) (string, []string) {
	var paramTypes []string
	convertedQuery := placeholderRegex.ReplaceAllStringFunc(query, func(match string) string {
		paramTypes = append(paramTypes, match[1:]) // drop the leading '%'
		return "?"
	})
	return convertedQuery, paramTypes
}

// toSnakeCase converts a PascalCase identifier to snake_case, keeping runs of
// capitals together as one word: "UserID" -> "user_id",
// "HTTPServer" -> "http_server".
func toSnakeCase(s string) string {
	var result strings.Builder
	runes := []rune(s)
	for i, r := range runes {
		if i > 0 && unicode.IsUpper(r) {
			// Start a new word unless we are inside an acronym; an acronym ends
			// where the next rune is lowercase (the "S" in "HTTPServer").
			if !unicode.IsUpper(runes[i-1]) || (i+1 < len(runes) && unicode.IsLower(runes[i+1])) {
				result.WriteByte('_')
			}
		}
		result.WriteRune(unicode.ToLower(r))
	}
	return result.String()
}