first commit
This commit is contained in:
commit
0cf0c37c00
136
README.md
Normal file
136
README.md
Normal file
@ -0,0 +1,136 @@
|
||||
# Sashimi 🍣
|
||||
|
||||
A raw, tasty SQLite wrapper for Go built on top of [zombiezen.com/go/sqlite](https://zombiezen.com/go/sqlite).
|
||||
|
||||
## Features
|
||||
|
||||
- **Simple API** - fmt-style placeholders (`%s`, `%d`)
|
||||
- **Struct scanning** - Automatic field mapping with snake_case conversion
|
||||
- **Built-in migrations** - Numbered SQL file execution with tracking
|
||||
- **Type safety** - Reflection-based struct operations
|
||||
- **Transactions** - Easy atomic operations
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get git.sharkk.net/Sharkk/Sashimi
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"git.sharkk.net/Sharkk/Sashimi"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Connect
|
||||
db, err := sashimi.New("app.db")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Insert
|
||||
id, err := db.Insert("users", User{
|
||||
Name: "Alice",
|
||||
Email: "alice@example.com",
|
||||
}, "ID") // exclude ID field
|
||||
|
||||
// Get single record
|
||||
var user User
|
||||
err = db.Get(&user, "SELECT * FROM users WHERE id = %d", id)
|
||||
|
||||
// Get multiple records
|
||||
var users []*User
|
||||
err = db.Select(&users, "SELECT * FROM users WHERE name LIKE %s", "A%")
|
||||
|
||||
// Update
|
||||
err = db.Update("users", map[string]any{
|
||||
"email": "alice.new@example.com",
|
||||
}, "id", id)
|
||||
}
|
||||
```
|
||||
|
||||
## Migrations
|
||||
|
||||
### Setup
|
||||
```go
|
||||
migrator := sashimi.NewMigrator(db, "./migrations")
|
||||
```
|
||||
|
||||
### Commands
|
||||
```bash
|
||||
# Run pending migrations
|
||||
go run main.go migrate
|
||||
|
||||
# Create new migration
|
||||
go run main.go migrate new "create users table"
|
||||
|
||||
# Check status
|
||||
go run main.go migrate status
|
||||
```
|
||||
|
||||
### Migration Files
|
||||
Create numbered SQL files in your migrations directory:
|
||||
|
||||
```sql
|
||||
-- 1_create_users.sql
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
created_at INTEGER DEFAULT (strftime('%s', 'now'))
|
||||
);
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Database Operations
|
||||
|
||||
```go
|
||||
// Execute query with result
|
||||
stmt, err := db.Query("SELECT * FROM users WHERE active = %d", 1)
|
||||
|
||||
// Single row
|
||||
err := db.Get(&user, "SELECT * FROM users WHERE id = %d", 123)
|
||||
|
||||
// Multiple rows
|
||||
err := db.Select(&users, "SELECT * FROM users")
|
||||
|
||||
// Execute without result
|
||||
err := db.Exec("DELETE FROM users WHERE id = %d", 123)
|
||||
|
||||
// Transactions
|
||||
err := db.Transaction(func() error {
|
||||
_, err := db.Insert("users", user)
|
||||
return err
|
||||
})
|
||||
```
|
||||
|
||||
### Struct Conventions
|
||||
|
||||
Go struct fields automatically map to snake_case columns:
|
||||
- `UserID` → `user_id`
|
||||
- `FirstName` → `first_name`
|
||||
- `CreatedAt` → `created_at`
|
||||
|
||||
## Configuration
|
||||
|
||||
The database connection is configured with:
|
||||
- WAL mode for better concurrency
|
||||
- 64MB cache size
|
||||
- Foreign keys enabled
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
391
db.go
Normal file
391
db.go
Normal file
@ -0,0 +1,391 @@
|
||||
package sashimi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
)
|
||||
|
||||
var placeholderRegex = regexp.MustCompile(`%[sd]`)
|
||||
|
||||
type DB struct {
|
||||
conn *sqlite.Conn
|
||||
}
|
||||
|
||||
// New creates a new database wrapper instance
|
||||
func New(dbPath string) (*DB, error) {
|
||||
conn, err := sqlite.OpenConn(dbPath, sqlite.OpenReadWrite|sqlite.OpenCreate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
db := &DB{conn: conn}
|
||||
|
||||
// Configure database
|
||||
if err := db.configure(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// configure sets up database pragmas
|
||||
func (db *DB) configure() error {
|
||||
configs := []string{
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA cache_size=-65536", // 64MB cache
|
||||
"PRAGMA foreign_keys=ON",
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
if err := sqlitex.Execute(db.conn, config, nil); err != nil {
|
||||
return fmt.Errorf("failed to configure database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (db *DB) Close() error {
|
||||
if db.conn != nil {
|
||||
return db.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Conn returns the underlying sqlite connection
|
||||
func (db *DB) Conn() *sqlite.Conn {
|
||||
return db.conn
|
||||
}
|
||||
|
||||
// Scan scans a SQLite statement result into a struct using field names
|
||||
func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
|
||||
return fmt.Errorf("dest must be a pointer to struct")
|
||||
}
|
||||
|
||||
elem := v.Elem()
|
||||
typ := elem.Type()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
columnName := toSnakeCase(field.Name)
|
||||
|
||||
fieldValue := elem.Field(i)
|
||||
if !fieldValue.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find column index by name
|
||||
colIndex := -1
|
||||
for j := 0; j < stmt.ColumnCount(); j++ {
|
||||
if stmt.ColumnName(j) == columnName {
|
||||
colIndex = j
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if colIndex == -1 {
|
||||
continue // Column not found
|
||||
}
|
||||
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
fieldValue.SetInt(stmt.ColumnInt64(colIndex))
|
||||
case reflect.String:
|
||||
fieldValue.SetString(stmt.ColumnText(colIndex))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
|
||||
case reflect.Bool:
|
||||
fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query executes a query with fmt-style placeholders and automatically binds parameters
|
||||
func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) {
|
||||
convertedQuery, paramTypes := convertPlaceholders(query)
|
||||
|
||||
stmt, err := db.conn.Prepare(convertedQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Bind parameters with correct types
|
||||
for i, arg := range args {
|
||||
if i >= len(paramTypes) {
|
||||
break
|
||||
}
|
||||
|
||||
switch paramTypes[i] {
|
||||
case "s": // string
|
||||
if s, ok := arg.(string); ok {
|
||||
stmt.BindText(i+1, s)
|
||||
} else {
|
||||
stmt.BindText(i+1, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
case "d": // integer
|
||||
switch v := arg.(type) {
|
||||
case int:
|
||||
stmt.BindInt64(i+1, int64(v))
|
||||
case int32:
|
||||
stmt.BindInt64(i+1, int64(v))
|
||||
case int64:
|
||||
stmt.BindInt64(i+1, v)
|
||||
case float64:
|
||||
stmt.BindInt64(i+1, int64(v))
|
||||
default:
|
||||
if i64, err := strconv.ParseInt(fmt.Sprintf("%v", arg), 10, 64); err == nil {
|
||||
stmt.BindInt64(i+1, i64)
|
||||
} else {
|
||||
stmt.BindInt64(i+1, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
// Get executes a query and returns the first row
|
||||
func (db *DB) Get(dest any, query string, args ...any) error {
|
||||
stmt, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Finalize()
|
||||
|
||||
hasRow, err := stmt.Step()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasRow {
|
||||
return fmt.Errorf("no rows found")
|
||||
}
|
||||
|
||||
return db.Scan(stmt, dest)
|
||||
}
|
||||
|
||||
// Select executes a query and scans all rows into a slice
|
||||
func (db *DB) Select(dest any, query string, args ...any) error {
|
||||
destValue := reflect.ValueOf(dest)
|
||||
if destValue.Kind() != reflect.Ptr || destValue.Elem().Kind() != reflect.Slice {
|
||||
return fmt.Errorf("dest must be a pointer to slice")
|
||||
}
|
||||
|
||||
sliceValue := destValue.Elem()
|
||||
elemType := sliceValue.Type().Elem()
|
||||
|
||||
// Ensure element type is a pointer to struct
|
||||
if elemType.Kind() != reflect.Ptr || elemType.Elem().Kind() != reflect.Struct {
|
||||
return fmt.Errorf("slice elements must be pointers to structs")
|
||||
}
|
||||
|
||||
stmt, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Finalize()
|
||||
|
||||
for {
|
||||
hasRow, err := stmt.Step()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasRow {
|
||||
break
|
||||
}
|
||||
|
||||
// Create new instance of the element type
|
||||
newElem := reflect.New(elemType.Elem())
|
||||
if err := db.Scan(stmt, newElem.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sliceValue.Set(reflect.Append(sliceValue, newElem))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec executes a statement with fmt-style placeholders
|
||||
func (db *DB) Exec(query string, args ...any) error {
|
||||
convertedQuery, paramTypes := convertPlaceholders(query)
|
||||
|
||||
sqlArgs := make([]any, len(args))
|
||||
for i, arg := range args {
|
||||
if i < len(paramTypes) && paramTypes[i] == "d" {
|
||||
// Convert to int64 for integer parameters
|
||||
switch v := arg.(type) {
|
||||
case int:
|
||||
sqlArgs[i] = int64(v)
|
||||
case int32:
|
||||
sqlArgs[i] = int64(v)
|
||||
case int64:
|
||||
sqlArgs[i] = v
|
||||
default:
|
||||
sqlArgs[i] = arg
|
||||
}
|
||||
} else {
|
||||
sqlArgs[i] = arg
|
||||
}
|
||||
}
|
||||
|
||||
return sqlitex.Execute(db.conn, convertedQuery, &sqlitex.ExecOptions{
|
||||
Args: sqlArgs,
|
||||
})
|
||||
}
|
||||
|
||||
// Update updates specific fields in the database
|
||||
func (db *DB) Update(tableName string, fields map[string]any, whereField string, whereValue any) error {
|
||||
if len(fields) == 0 {
|
||||
return nil // No changes
|
||||
}
|
||||
|
||||
// Build UPDATE query
|
||||
setParts := make([]string, 0, len(fields))
|
||||
args := make([]any, 0, len(fields)+1)
|
||||
|
||||
for field, value := range fields {
|
||||
setParts = append(setParts, field+" = ?")
|
||||
args = append(args, value)
|
||||
}
|
||||
|
||||
args = append(args, whereValue)
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
|
||||
tableName, strings.Join(setParts, ", "), whereField)
|
||||
|
||||
return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{
|
||||
Args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Insert inserts a struct into the database
|
||||
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
|
||||
v := reflect.ValueOf(obj)
|
||||
if v.Kind() == reflect.Pointer {
|
||||
v = v.Elem()
|
||||
}
|
||||
t := v.Type()
|
||||
|
||||
exclude := make(map[string]bool)
|
||||
for _, field := range excludeFields {
|
||||
exclude[toSnakeCase(field)] = true
|
||||
}
|
||||
|
||||
var columns []string
|
||||
var placeholders []string
|
||||
var args []any
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
columnName := toSnakeCase(field.Name)
|
||||
if exclude[columnName] {
|
||||
continue
|
||||
}
|
||||
|
||||
columns = append(columns, columnName)
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, v.Field(i).Interface())
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||
tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
||||
|
||||
stmt, err := db.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer stmt.Finalize()
|
||||
|
||||
// Bind parameters
|
||||
for i, arg := range args {
|
||||
switch v := arg.(type) {
|
||||
case string:
|
||||
stmt.BindText(i+1, v)
|
||||
case int, int32, int64:
|
||||
stmt.BindInt64(i+1, reflect.ValueOf(v).Int())
|
||||
case float32, float64:
|
||||
stmt.BindFloat(i+1, reflect.ValueOf(v).Float())
|
||||
default:
|
||||
stmt.BindText(i+1, fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
||||
|
||||
_, err = stmt.Step()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return db.conn.LastInsertRowID(), nil
|
||||
}
|
||||
|
||||
// Transaction executes multiple operations atomically
|
||||
func (db *DB) Transaction(fn func() error) error {
|
||||
// Begin transaction
|
||||
if err := sqlitex.Execute(db.conn, "BEGIN", nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute operations
|
||||
err := fn()
|
||||
|
||||
if err != nil {
|
||||
// Rollback on error
|
||||
sqlitex.Execute(db.conn, "ROLLBACK", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit on success
|
||||
return sqlitex.Execute(db.conn, "COMMIT", nil)
|
||||
}
|
||||
|
||||
func convertPlaceholders(query string) (string, []string) {
|
||||
var paramTypes []string
|
||||
|
||||
convertedQuery := placeholderRegex.ReplaceAllStringFunc(query, func(match string) string {
|
||||
paramTypes = append(paramTypes, match[1:]) // Remove % prefix
|
||||
return "?"
|
||||
})
|
||||
|
||||
return convertedQuery, paramTypes
|
||||
}
|
||||
|
||||
// toSnakeCase converts PascalCase to snake_case
|
||||
func toSnakeCase(s string) string {
|
||||
var result strings.Builder
|
||||
runes := []rune(s)
|
||||
|
||||
for i, r := range runes {
|
||||
if i > 0 {
|
||||
prev := runes[i-1]
|
||||
|
||||
// Add underscore before digit if previous char was letter
|
||||
if unicode.IsDigit(r) && unicode.IsLetter(prev) {
|
||||
result.WriteByte('_')
|
||||
}
|
||||
// Add underscore before uppercase letter
|
||||
if unicode.IsUpper(r) {
|
||||
// Don't add if previous was also uppercase (unless end of acronym)
|
||||
if !unicode.IsUpper(prev) ||
|
||||
(i+1 < len(runes) && unicode.IsLower(runes[i+1])) {
|
||||
result.WriteByte('_')
|
||||
}
|
||||
}
|
||||
}
|
||||
result.WriteRune(unicode.ToLower(r))
|
||||
}
|
||||
return result.String()
|
||||
}
|
19
go.mod
Normal file
19
go.mod
Normal file
@ -0,0 +1,19 @@
|
||||
module git.sharkk.net/Sharkk/Sashimi
|
||||
|
||||
go 1.24.1
|
||||
|
||||
require zombiezen.com/go/sqlite v1.4.2
|
||||
|
||||
require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
modernc.org/libc v1.65.7 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.37.1 // indirect
|
||||
)
|
25
go.sum
Normal file
25
go.sum
Normal file
@ -0,0 +1,25 @@
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
modernc.org/libc v1.65.7 h1:Ia9Z4yzZtWNtUIuiPuQ7Qf7kxYrxP1/jeHZzG8bFu00=
|
||||
modernc.org/libc v1.65.7/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs=
|
||||
modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g=
|
||||
zombiezen.com/go/sqlite v1.4.2 h1:KZXLrBuJ7tKNEm+VJcApLMeQbhmAUOKA5VWS93DfFRo=
|
||||
zombiezen.com/go/sqlite v1.4.2/go.mod h1:5Kd4taTAD4MkBzT25mQ9uaAlLjyR0rFhsR6iINO70jc=
|
264
migrate.go
Normal file
264
migrate.go
Normal file
@ -0,0 +1,264 @@
|
||||
package sashimi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
)
|
||||
|
||||
type Migration struct {
|
||||
Number int
|
||||
Name string
|
||||
Filename string
|
||||
Content string
|
||||
}
|
||||
|
||||
type Migrator struct {
|
||||
db *DB
|
||||
dataDir string
|
||||
}
|
||||
|
||||
// migrationFileRegex matches files like "1_create_database.sql", "002_add_users.sql"
|
||||
var migrationFileRegex = regexp.MustCompile(`^(\d+)_(.+)\.sql$`)
|
||||
|
||||
// NewMigrator creates a new migration manager
|
||||
func NewMigrator(db *DB, dataDir string) *Migrator {
|
||||
return &Migrator{
|
||||
db: db,
|
||||
dataDir: dataDir,
|
||||
}
|
||||
}
|
||||
|
||||
// ensureMigrationsTable creates the migrations tracking table if it doesn't exist
|
||||
func (m *Migrator) ensureMigrationsTable() error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
number INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
filename TEXT NOT NULL,
|
||||
executed_at INTEGER NOT NULL
|
||||
)
|
||||
`
|
||||
return sqlitex.Execute(m.db.conn, query, nil)
|
||||
}
|
||||
|
||||
// getExecutedMigrations returns a map of migration numbers that have been executed
|
||||
func (m *Migrator) getExecutedMigrations() (map[int]bool, error) {
|
||||
executed := make(map[int]bool)
|
||||
|
||||
stmt, err := m.db.Query("SELECT number FROM migrations ORDER BY number")
|
||||
if err != nil {
|
||||
return executed, err
|
||||
}
|
||||
defer stmt.Finalize()
|
||||
|
||||
for {
|
||||
hasRow, err := stmt.Step()
|
||||
if err != nil {
|
||||
return executed, err
|
||||
}
|
||||
if !hasRow {
|
||||
break
|
||||
}
|
||||
executed[stmt.ColumnInt(0)] = true
|
||||
}
|
||||
|
||||
return executed, nil
|
||||
}
|
||||
|
||||
// loadMigrations reads all migration files from the data directory
|
||||
func (m *Migrator) loadMigrations() ([]Migration, error) {
|
||||
var migrations []Migration
|
||||
|
||||
err := filepath.WalkDir(m.dataDir, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.IsDir() || !strings.HasSuffix(d.Name(), ".sql") {
|
||||
return nil
|
||||
}
|
||||
|
||||
matches := migrationFileRegex.FindStringSubmatch(d.Name())
|
||||
if len(matches) != 3 {
|
||||
return nil // Skip files that don't match migration pattern
|
||||
}
|
||||
|
||||
number, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid migration number in %s: %w", d.Name(), err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read migration file %s: %w", path, err)
|
||||
}
|
||||
|
||||
migrations = append(migrations, Migration{
|
||||
Number: number,
|
||||
Name: matches[2],
|
||||
Filename: d.Name(),
|
||||
Content: string(content),
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sort migrations by number
|
||||
sort.Slice(migrations, func(i, j int) bool {
|
||||
return migrations[i].Number < migrations[j].Number
|
||||
})
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// Run executes all pending migrations
|
||||
func (m *Migrator) Run() error {
|
||||
if err := m.ensureMigrationsTable(); err != nil {
|
||||
return fmt.Errorf("failed to ensure migrations table: %w", err)
|
||||
}
|
||||
|
||||
migrations, err := m.loadMigrations()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load migrations: %w", err)
|
||||
}
|
||||
|
||||
executed, err := m.getExecutedMigrations()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get executed migrations: %w", err)
|
||||
}
|
||||
|
||||
pendingMigrations := []Migration{}
|
||||
for _, migration := range migrations {
|
||||
if !executed[migration.Number] {
|
||||
pendingMigrations = append(pendingMigrations, migration)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pendingMigrations) == 0 {
|
||||
fmt.Println("No pending migrations")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations))
|
||||
|
||||
return m.db.Transaction(func() error {
|
||||
for _, migration := range pendingMigrations {
|
||||
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
|
||||
|
||||
// Execute the migration SQL
|
||||
if err := sqlitex.Execute(m.db.conn, migration.Content, nil); err != nil {
|
||||
return fmt.Errorf("failed to execute migration %d (%s): %w",
|
||||
migration.Number, migration.Name, err)
|
||||
}
|
||||
|
||||
// Record the migration as executed
|
||||
_, err := m.db.Insert("migrations", map[string]any{
|
||||
"number": migration.Number,
|
||||
"name": migration.Name,
|
||||
"filename": migration.Filename,
|
||||
"executed_at": time.Now().Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record migration %d: %w", migration.Number, err)
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Migration %d completed\n", migration.Number)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// CreateNew creates a new migration file with the given name
|
||||
func (m *Migrator) CreateNew(name string) error {
|
||||
migrations, err := m.loadMigrations()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load existing migrations: %w", err)
|
||||
}
|
||||
|
||||
// Find the next migration number
|
||||
nextNumber := 1
|
||||
if len(migrations) > 0 {
|
||||
lastMigration := migrations[len(migrations)-1]
|
||||
nextNumber = lastMigration.Number + 1
|
||||
}
|
||||
|
||||
// Clean the migration name
|
||||
cleanName := strings.ReplaceAll(name, " ", "_")
|
||||
cleanName = regexp.MustCompile(`[^a-zA-Z0-9_]`).ReplaceAllString(cleanName, "")
|
||||
cleanName = strings.ToLower(cleanName)
|
||||
|
||||
filename := fmt.Sprintf("%d_%s.sql", nextNumber, cleanName)
|
||||
filepath := filepath.Join(m.dataDir, filename)
|
||||
|
||||
// Check if file already exists
|
||||
if _, err := os.Stat(filepath); err == nil {
|
||||
return fmt.Errorf("migration file %s already exists", filename)
|
||||
}
|
||||
|
||||
// Create the migration file with a template
|
||||
template := fmt.Sprintf(`-- Migration %d: %s
|
||||
-- Created: %s
|
||||
|
||||
-- Add your SQL statements here
|
||||
|
||||
`, nextNumber, name, time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
if err := os.WriteFile(filepath, []byte(template), 0644); err != nil {
|
||||
return fmt.Errorf("failed to create migration file: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Created migration: %s\n", filename)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Status shows the current migration status
|
||||
func (m *Migrator) Status() error {
|
||||
if err := m.ensureMigrationsTable(); err != nil {
|
||||
return fmt.Errorf("failed to ensure migrations table: %w", err)
|
||||
}
|
||||
|
||||
migrations, err := m.loadMigrations()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load migrations: %w", err)
|
||||
}
|
||||
|
||||
executed, err := m.getExecutedMigrations()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get executed migrations: %w", err)
|
||||
}
|
||||
|
||||
if len(migrations) == 0 {
|
||||
fmt.Println("No migration files found")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println("Migration Status:")
|
||||
fmt.Println("=================")
|
||||
|
||||
pendingCount := 0
|
||||
for _, migration := range migrations {
|
||||
status := "PENDING"
|
||||
if executed[migration.Number] {
|
||||
status = "EXECUTED"
|
||||
} else {
|
||||
pendingCount++
|
||||
}
|
||||
fmt.Printf("%d_%s.sql - %s\n", migration.Number, migration.Name, status)
|
||||
}
|
||||
|
||||
fmt.Printf("\nTotal: %d migrations, %d pending\n", len(migrations), pendingCount)
|
||||
return nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user