278 lines
6.6 KiB
Go
278 lines
6.6 KiB
Go
package sashimi
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io/fs"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"zombiezen.com/go/sqlite/sqlitex"
|
|
)
|
|
|
|
// Migration describes one SQL migration file found in the data directory.
type Migration struct {
	Number   int    // numeric filename prefix; migrations run in ascending Number order
	Name     string // part of the filename between "<number>_" and ".sql"
	Filename string // base name of the file, e.g. "1_create_database.sql"
	Content  string // complete SQL text of the file
}
|
|
|
|
type Migrator struct {
|
|
db *DB
|
|
dataDir string
|
|
}
|
|
|
|
var (
	// migrationFileRegex recognises migration filenames such as
	// "1_create_database.sql" or "002_add_users.sql". Submatch 1 is the
	// decimal number prefix and submatch 2 the (non-empty) name.
	migrationFileRegex = regexp.MustCompile(`^(\d+)_(.+)\.sql$`)
)
|
|
|
|
// NewMigrator creates a new migration manager
|
|
func NewMigrator(db *DB, dataDir string) *Migrator {
|
|
return &Migrator{
|
|
db: db,
|
|
dataDir: dataDir,
|
|
}
|
|
}
|
|
|
|
// ensureMigrationsTable creates the migrations tracking table if it doesn't exist
|
|
func (m *Migrator) ensureMigrationsTable() error {
|
|
conn, err := m.db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get connection: %w", err)
|
|
}
|
|
defer m.db.pool.Put(conn)
|
|
|
|
query := `
|
|
CREATE TABLE IF NOT EXISTS migrations (
|
|
number INTEGER PRIMARY KEY,
|
|
name TEXT NOT NULL,
|
|
filename TEXT NOT NULL,
|
|
executed_at INTEGER NOT NULL
|
|
)
|
|
`
|
|
return sqlitex.Execute(conn, query, nil)
|
|
}
|
|
|
|
// getExecutedMigrations returns a map of migration numbers that have been executed
|
|
func (m *Migrator) getExecutedMigrations() (map[int]bool, error) {
|
|
executed := make(map[int]bool)
|
|
|
|
stmt, err := m.db.Query("SELECT number FROM migrations ORDER BY number")
|
|
if err != nil {
|
|
return executed, err
|
|
}
|
|
defer stmt.Finalize()
|
|
|
|
for {
|
|
hasRow, err := stmt.Step()
|
|
if err != nil {
|
|
return executed, err
|
|
}
|
|
if !hasRow {
|
|
break
|
|
}
|
|
executed[stmt.ColumnInt(0)] = true
|
|
}
|
|
|
|
return executed, nil
|
|
}
|
|
|
|
// loadMigrations reads all migration files from the data directory
|
|
func (m *Migrator) loadMigrations() ([]Migration, error) {
|
|
var migrations []Migration
|
|
|
|
err := filepath.WalkDir(m.dataDir, func(path string, d fs.DirEntry, err error) error {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if d.IsDir() || !strings.HasSuffix(d.Name(), ".sql") {
|
|
return nil
|
|
}
|
|
|
|
matches := migrationFileRegex.FindStringSubmatch(d.Name())
|
|
if len(matches) != 3 {
|
|
return nil // Skip files that don't match migration pattern
|
|
}
|
|
|
|
number, err := strconv.Atoi(matches[1])
|
|
if err != nil {
|
|
return fmt.Errorf("invalid migration number in %s: %w", d.Name(), err)
|
|
}
|
|
|
|
content, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read migration file %s: %w", path, err)
|
|
}
|
|
|
|
migrations = append(migrations, Migration{
|
|
Number: number,
|
|
Name: matches[2],
|
|
Filename: d.Name(),
|
|
Content: string(content),
|
|
})
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Sort migrations by number
|
|
sort.Slice(migrations, func(i, j int) bool {
|
|
return migrations[i].Number < migrations[j].Number
|
|
})
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
// Run executes all pending migrations
|
|
func (m *Migrator) Run() error {
|
|
if err := m.ensureMigrationsTable(); err != nil {
|
|
return fmt.Errorf("failed to ensure migrations table: %w", err)
|
|
}
|
|
|
|
migrations, err := m.loadMigrations()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load migrations: %w", err)
|
|
}
|
|
|
|
executed, err := m.getExecutedMigrations()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get executed migrations: %w", err)
|
|
}
|
|
|
|
pendingMigrations := []Migration{}
|
|
for _, migration := range migrations {
|
|
if !executed[migration.Number] {
|
|
pendingMigrations = append(pendingMigrations, migration)
|
|
}
|
|
}
|
|
|
|
if len(pendingMigrations) == 0 {
|
|
fmt.Println("No pending migrations")
|
|
return nil
|
|
}
|
|
|
|
fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations))
|
|
|
|
return m.db.Transaction(func() error {
|
|
conn, err := m.db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get connection: %w", err)
|
|
}
|
|
defer m.db.pool.Put(conn)
|
|
|
|
for _, migration := range pendingMigrations {
|
|
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
|
|
|
|
// Execute the migration SQL
|
|
if err := sqlitex.ExecuteScript(conn, migration.Content, nil); err != nil {
|
|
return fmt.Errorf("failed to execute migration %d (%s): %w",
|
|
migration.Number, migration.Name, err)
|
|
}
|
|
|
|
// Record the migration as executed
|
|
_, err := m.db.Insert("migrations", map[string]any{
|
|
"number": migration.Number,
|
|
"name": migration.Name,
|
|
"filename": migration.Filename,
|
|
"executed_at": time.Now().Unix(),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to record migration %d: %w", migration.Number, err)
|
|
}
|
|
|
|
fmt.Printf("✓ Migration %d completed\n", migration.Number)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// CreateNew creates a new migration file with the given name
|
|
func (m *Migrator) CreateNew(name string) error {
|
|
migrations, err := m.loadMigrations()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load existing migrations: %w", err)
|
|
}
|
|
|
|
// Find the next migration number
|
|
nextNumber := 1
|
|
if len(migrations) > 0 {
|
|
lastMigration := migrations[len(migrations)-1]
|
|
nextNumber = lastMigration.Number + 1
|
|
}
|
|
|
|
// Clean the migration name
|
|
cleanName := strings.ReplaceAll(name, " ", "_")
|
|
cleanName = regexp.MustCompile(`[^a-zA-Z0-9_]`).ReplaceAllString(cleanName, "")
|
|
cleanName = strings.ToLower(cleanName)
|
|
|
|
filename := fmt.Sprintf("%d_%s.sql", nextNumber, cleanName)
|
|
filepath := filepath.Join(m.dataDir, filename)
|
|
|
|
// Check if file already exists
|
|
if _, err := os.Stat(filepath); err == nil {
|
|
return fmt.Errorf("migration file %s already exists", filename)
|
|
}
|
|
|
|
// Create the migration file with a template
|
|
template := fmt.Sprintf(`-- Migration %d: %s
|
|
-- Created: %s
|
|
|
|
-- Add your SQL statements here
|
|
|
|
`, nextNumber, name, time.Now().Format("2006-01-02 15:04:05"))
|
|
|
|
if err := os.WriteFile(filepath, []byte(template), 0644); err != nil {
|
|
return fmt.Errorf("failed to create migration file: %w", err)
|
|
}
|
|
|
|
fmt.Printf("Created migration: %s\n", filename)
|
|
return nil
|
|
}
|
|
|
|
// Status shows the current migration status
|
|
func (m *Migrator) Status() error {
|
|
if err := m.ensureMigrationsTable(); err != nil {
|
|
return fmt.Errorf("failed to ensure migrations table: %w", err)
|
|
}
|
|
|
|
migrations, err := m.loadMigrations()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load migrations: %w", err)
|
|
}
|
|
|
|
executed, err := m.getExecutedMigrations()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get executed migrations: %w", err)
|
|
}
|
|
|
|
if len(migrations) == 0 {
|
|
fmt.Println("No migration files found")
|
|
return nil
|
|
}
|
|
|
|
fmt.Println("Migration Status:")
|
|
fmt.Println("=================")
|
|
|
|
pendingCount := 0
|
|
for _, migration := range migrations {
|
|
status := "PENDING"
|
|
if executed[migration.Number] {
|
|
status = "EXECUTED"
|
|
} else {
|
|
pendingCount++
|
|
}
|
|
fmt.Printf("%d_%s.sql - %s\n", migration.Number, migration.Name, status)
|
|
}
|
|
|
|
fmt.Printf("\nTotal: %d migrations, %d pending\n", len(migrations), pendingCount)
|
|
return nil
|
|
}
|