Sashimi/migrate.go

278 lines
6.6 KiB
Go

package sashimi
import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"time"
"zombiezen.com/go/sqlite/sqlitex"
)
// Migration describes a single SQL migration file discovered in the data
// directory.
type Migration struct {
	// Number is the numeric prefix of the file name; migrations are ordered
	// and tracked by this value.
	Number int
	// Name is the descriptive part of the file name (between the first "_"
	// and ".sql"), Filename is the base name of the file on disk, and
	// Content is the raw SQL text read from it.
	Name, Filename, Content string
}
// Migrator discovers, applies and reports on SQL migration files for a
// database.
type Migrator struct {
	db      *DB    // database the migrations are applied to and recorded in
	dataDir string // directory scanned for migration .sql files
}
var (
	// migrationFileRegex recognises migration file names of the form
	// "<number>_<name>.sql", such as "1_create_database.sql" or
	// "002_add_users.sql". Submatch 1 holds the digits, submatch 2 the
	// (non-empty) descriptive name.
	migrationFileRegex = regexp.MustCompile(`^(\d+)_(.+)\.sql$`)
)
// NewMigrator returns a Migrator that applies the migration files found in
// dataDir to db.
func NewMigrator(db *DB, dataDir string) *Migrator {
	mg := new(Migrator)
	mg.db = db
	mg.dataDir = dataDir
	return mg
}
// ensureMigrationsTable creates the "migrations" bookkeeping table if it is
// not already present. A pooled connection is borrowed for the single
// CREATE TABLE IF NOT EXISTS statement and returned afterwards.
func (m *Migrator) ensureMigrationsTable() error {
	const createStmt = `
CREATE TABLE IF NOT EXISTS migrations (
number INTEGER PRIMARY KEY,
name TEXT NOT NULL,
filename TEXT NOT NULL,
executed_at INTEGER NOT NULL
)
`
	conn, takeErr := m.db.pool.Take(context.Background())
	if takeErr != nil {
		return fmt.Errorf("failed to get connection: %w", takeErr)
	}
	defer m.db.pool.Put(conn)

	return sqlitex.Execute(conn, createStmt, nil)
}
// getExecutedMigrations returns the set of migration numbers recorded in the
// migrations table (number -> true).
//
// On any query or step error the returned map is nil, so callers cannot
// mistake a partially read set for the complete one.
func (m *Migrator) getExecutedMigrations() (map[int]bool, error) {
	stmt, err := m.db.Query("SELECT number FROM migrations ORDER BY number")
	if err != nil {
		return nil, err
	}
	defer stmt.Finalize()

	executed := make(map[int]bool)
	for {
		hasRow, err := stmt.Step()
		if err != nil {
			return nil, err
		}
		if !hasRow {
			return executed, nil
		}
		executed[stmt.ColumnInt(0)] = true
	}
}
// loadMigrations walks m.dataDir (including subdirectories) and returns every
// file whose base name matches migrationFileRegex, sorted by migration number.
// Files ending in ".sql" that do not match the pattern are skipped.
//
// It returns an error if the directory cannot be walked, a file cannot be
// read, a numeric prefix cannot be parsed, or two files share the same
// migration number (e.g. "1_a.sql" and "001_b.sql"). Migrations are tracked
// by number, so a duplicate could never be recorded as executed.
func (m *Migrator) loadMigrations() ([]Migration, error) {
	var migrations []Migration
	// seen maps each migration number to the path that claimed it, so a
	// duplicate can be reported together with the file it collides with.
	seen := make(map[int]string)

	err := filepath.WalkDir(m.dataDir, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if d.IsDir() || !strings.HasSuffix(d.Name(), ".sql") {
			return nil
		}
		matches := migrationFileRegex.FindStringSubmatch(d.Name())
		if len(matches) != 3 {
			return nil // Skip files that don't match migration pattern
		}
		number, err := strconv.Atoi(matches[1])
		if err != nil {
			return fmt.Errorf("invalid migration number in %s: %w", d.Name(), err)
		}
		if prev, dup := seen[number]; dup {
			return fmt.Errorf("duplicate migration number %d: %s and %s", number, prev, path)
		}
		seen[number] = path

		content, err := os.ReadFile(path)
		if err != nil {
			return fmt.Errorf("failed to read migration file %s: %w", path, err)
		}
		migrations = append(migrations, Migration{
			Number:   number,
			Name:     matches[2],
			Filename: d.Name(),
			Content:  string(content),
		})
		return nil
	})
	if err != nil {
		return nil, err
	}

	// Numbers are unique (checked above), so ordering by number is total.
	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].Number < migrations[j].Number
	})
	return migrations, nil
}
// Run applies, in ascending number order, every migration file that is not
// yet recorded in the migrations table. It creates that table first if
// needed.
//
// All pending migrations are applied by applyMigrations on a single pooled
// connection inside one savepoint. If any script or bookkeeping insert fails,
// the whole batch is rolled back, so the schema and the migrations table
// cannot drift apart.
func (m *Migrator) Run() error {
	if err := m.ensureMigrationsTable(); err != nil {
		return fmt.Errorf("failed to ensure migrations table: %w", err)
	}
	migrations, err := m.loadMigrations()
	if err != nil {
		return fmt.Errorf("failed to load migrations: %w", err)
	}
	executed, err := m.getExecutedMigrations()
	if err != nil {
		return fmt.Errorf("failed to get executed migrations: %w", err)
	}

	var pendingMigrations []Migration
	for _, migration := range migrations {
		if !executed[migration.Number] {
			pendingMigrations = append(pendingMigrations, migration)
		}
	}
	if len(pendingMigrations) == 0 {
		fmt.Println("No pending migrations")
		return nil
	}

	fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations))
	return m.applyMigrations(pendingMigrations)
}

// applyMigrations executes each migration script and inserts its row into the
// migrations table, all on one connection within a single SAVEPOINT.
//
// Running the scripts and the inserts on the same connection is what makes
// them atomic. A script executed on one connection and recorded through
// another could commit without its record (and be re-run later), or block
// on the other connection's write lock.
func (m *Migrator) applyMigrations(pending []Migration) (err error) {
	const recordSQL = `INSERT INTO migrations (number, name, filename, executed_at) VALUES (?, ?, ?, ?)`

	conn, err := m.db.pool.Take(context.Background())
	if err != nil {
		return fmt.Errorf("failed to get connection: %w", err)
	}
	defer m.db.pool.Put(conn)

	// Save opens a SAVEPOINT. The deferred release commits it when err is nil
	// on return and rolls it back otherwise. Defers run LIFO, so this
	// finishes before the connection goes back to the pool.
	defer sqlitex.Save(conn)(&err)

	for _, migration := range pending {
		fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)

		if execErr := sqlitex.ExecuteScript(conn, migration.Content, nil); execErr != nil {
			return fmt.Errorf("failed to execute migration %d (%s): %w",
				migration.Number, migration.Name, execErr)
		}

		recordErr := sqlitex.Execute(conn, recordSQL, &sqlitex.ExecOptions{
			Args: []any{migration.Number, migration.Name, migration.Filename, time.Now().Unix()},
		})
		if recordErr != nil {
			return fmt.Errorf("failed to record migration %d: %w", migration.Number, recordErr)
		}

		// Note: this is printed before the enclosing savepoint is released; a
		// later failure still rolls this migration back.
		fmt.Printf("✓ Migration %d completed\n", migration.Number)
	}
	return nil
}
// migrationNameSanitizer matches every character that may not appear in the
// descriptive part of a generated migration file name. It is compiled once
// rather than on every CreateNew call.
var migrationNameSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_]`)

// CreateNew writes a new, empty migration template to m.dataDir. The file is
// named "<next>_<clean name>.sql", where <next> is one past the highest
// existing migration number (1 if there are none). <clean name> is name with
// spaces turned into underscores, every other character outside
// [a-zA-Z0-9_] removed, and the result lower-cased.
//
// It returns an error when existing migrations cannot be loaded, when name
// contains no usable characters, when the target file already exists, or
// when the file cannot be written.
func (m *Migrator) CreateNew(name string) error {
	migrations, err := m.loadMigrations()
	if err != nil {
		return fmt.Errorf("failed to load existing migrations: %w", err)
	}

	// loadMigrations returns files sorted by number, so the last entry holds
	// the highest number in use.
	nextNumber := 1
	if len(migrations) > 0 {
		nextNumber = migrations[len(migrations)-1].Number + 1
	}

	cleanName := strings.ReplaceAll(name, " ", "_")
	cleanName = migrationNameSanitizer.ReplaceAllString(cleanName, "")
	cleanName = strings.ToLower(cleanName)
	// An empty clean name would yield "<n>_.sql", which migrationFileRegex
	// requires at least one name character to match, so the file would be
	// silently ignored by every later Run.
	if cleanName == "" {
		return fmt.Errorf("migration name %q has no usable characters (letters, digits, underscores or spaces)", name)
	}

	filename := fmt.Sprintf("%d_%s.sql", nextNumber, cleanName)
	path := filepath.Join(m.dataDir, filename)

	// Keep the human-readable name in the header, but flatten line breaks so
	// it cannot escape the "--" comment and end up executed as SQL.
	title := strings.NewReplacer("\r", " ", "\n", " ").Replace(name)
	template := fmt.Sprintf(`-- Migration %d: %s
-- Created: %s
-- Add your SQL statements here
`, nextNumber, title, time.Now().Format("2006-01-02 15:04:05"))

	// O_EXCL makes the "already exists" check and the creation one atomic
	// step, closing the race between a separate Stat and WriteFile.
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
	if err != nil {
		if os.IsExist(err) {
			return fmt.Errorf("migration file %s already exists", filename)
		}
		return fmt.Errorf("failed to create migration file: %w", err)
	}
	if _, err := f.WriteString(template); err != nil {
		f.Close()
		// Don't leave a truncated file behind; it would match the migration
		// pattern and be picked up by the next Run.
		os.Remove(path)
		return fmt.Errorf("failed to create migration file: %w", err)
	}
	if err := f.Close(); err != nil {
		os.Remove(path)
		return fmt.Errorf("failed to create migration file: %w", err)
	}

	fmt.Printf("Created migration: %s\n", filename)
	return nil
}
// Status prints every migration file found in the data directory, in number
// order. Each file is labelled EXECUTED if its number is recorded in the
// migrations table and PENDING otherwise, and a summary line with the total
// and pending counts follows. The migrations table is created first if it
// does not exist.
//
// The actual file name on disk is printed (e.g. "002_add_users.sql"), not one
// rebuilt from the parsed number, so zero-padded names are shown correctly.
func (m *Migrator) Status() error {
	if err := m.ensureMigrationsTable(); err != nil {
		return fmt.Errorf("failed to ensure migrations table: %w", err)
	}
	migrations, err := m.loadMigrations()
	if err != nil {
		return fmt.Errorf("failed to load migrations: %w", err)
	}
	executed, err := m.getExecutedMigrations()
	if err != nil {
		return fmt.Errorf("failed to get executed migrations: %w", err)
	}
	if len(migrations) == 0 {
		fmt.Println("No migration files found")
		return nil
	}

	fmt.Println("Migration Status:")
	fmt.Println("=================")
	pendingCount := 0
	for _, migration := range migrations {
		status := "EXECUTED"
		if !executed[migration.Number] {
			status = "PENDING"
			pendingCount++
		}
		fmt.Printf("%s - %s\n", migration.Filename, status)
	}
	fmt.Printf("\nTotal: %d migrations, %d pending\n", len(migrations), pendingCount)
	return nil
}