package sashimi import ( "context" "fmt" "io/fs" "os" "path/filepath" "regexp" "sort" "strconv" "strings" "time" "zombiezen.com/go/sqlite/sqlitex" ) type Migration struct { Number int Name string Filename string Content string } type Migrator struct { db *DB dataDir string } // migrationFileRegex matches files like "1_create_database.sql", "002_add_users.sql" var migrationFileRegex = regexp.MustCompile(`^(\d+)_(.+)\.sql$`) // NewMigrator creates a new migration manager func NewMigrator(db *DB, dataDir string) *Migrator { return &Migrator{ db: db, dataDir: dataDir, } } // ensureMigrationsTable creates the migrations tracking table if it doesn't exist func (m *Migrator) ensureMigrationsTable() error { conn, err := m.db.pool.Take(context.Background()) if err != nil { return fmt.Errorf("failed to get connection: %w", err) } defer m.db.pool.Put(conn) query := ` CREATE TABLE IF NOT EXISTS migrations ( number INTEGER PRIMARY KEY, name TEXT NOT NULL, filename TEXT NOT NULL, executed_at INTEGER NOT NULL ) ` return sqlitex.Execute(conn, query, nil) } // getExecutedMigrations returns a map of migration numbers that have been executed func (m *Migrator) getExecutedMigrations() (map[int]bool, error) { executed := make(map[int]bool) stmt, err := m.db.Query("SELECT number FROM migrations ORDER BY number") if err != nil { return executed, err } defer stmt.Finalize() for { hasRow, err := stmt.Step() if err != nil { return executed, err } if !hasRow { break } executed[stmt.ColumnInt(0)] = true } return executed, nil } // loadMigrations reads all migration files from the data directory func (m *Migrator) loadMigrations() ([]Migration, error) { var migrations []Migration err := filepath.WalkDir(m.dataDir, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } if d.IsDir() || !strings.HasSuffix(d.Name(), ".sql") { return nil } matches := migrationFileRegex.FindStringSubmatch(d.Name()) if len(matches) != 3 { return nil // Skip files that don't match migration pattern } number, err := strconv.Atoi(matches[1]) if err != nil { return fmt.Errorf("invalid migration number in %s: %w", d.Name(), err) } content, err := os.ReadFile(path) if err != nil { return fmt.Errorf("failed to read migration file %s: %w", path, err) } migrations = append(migrations, Migration{ Number: number, Name: matches[2], Filename: d.Name(), Content: string(content), }) return nil }) if err != nil { return nil, err } // Sort migrations by number sort.Slice(migrations, func(i, j int) bool { return migrations[i].Number < migrations[j].Number }) return migrations, nil } // Run executes all pending migrations func (m *Migrator) Run() error { if err := m.ensureMigrationsTable(); err != nil { return fmt.Errorf("failed to ensure migrations table: %w", err) } migrations, err := m.loadMigrations() if err != nil { return fmt.Errorf("failed to load migrations: %w", err) } executed, err := m.getExecutedMigrations() if err != nil { return fmt.Errorf("failed to get executed migrations: %w", err) } pendingMigrations := []Migration{} for _, migration := range migrations { if !executed[migration.Number] { pendingMigrations = append(pendingMigrations, migration) } } if len(pendingMigrations) == 0 { fmt.Println("No pending migrations") return nil } fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations)) return m.db.Transaction(func() error { conn, err := m.db.pool.Take(context.Background()) if err != nil { return fmt.Errorf("failed to get connection: %w", err) } defer m.db.pool.Put(conn) for _, migration := range pendingMigrations { fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name) // Execute the migration SQL if err := sqlitex.ExecuteScript(conn, migration.Content, nil); err != nil { return fmt.Errorf("failed to execute migration %d (%s): %w", migration.Number, migration.Name, err) } // Record the migration as executed _, err := m.db.Insert("migrations", map[string]any{ "number": migration.Number, "name": migration.Name, "filename": migration.Filename, "executed_at": time.Now().Unix(), }) if err != nil { return fmt.Errorf("failed to record migration %d: %w", migration.Number, err) } fmt.Printf("✓ Migration %d completed\n", migration.Number) } return nil }) } // CreateNew creates a new migration file with the given name func (m *Migrator) CreateNew(name string) error { migrations, err := m.loadMigrations() if err != nil { return fmt.Errorf("failed to load existing migrations: %w", err) } // Find the next migration number nextNumber := 1 if len(migrations) > 0 { lastMigration := migrations[len(migrations)-1] nextNumber = lastMigration.Number + 1 } // Clean the migration name cleanName := strings.ReplaceAll(name, " ", "_") cleanName = regexp.MustCompile(`[^a-zA-Z0-9_]`).ReplaceAllString(cleanName, "") cleanName = strings.ToLower(cleanName) filename := fmt.Sprintf("%d_%s.sql", nextNumber, cleanName) filepath := filepath.Join(m.dataDir, filename) // Check if file already exists if _, err := os.Stat(filepath); err == nil { return fmt.Errorf("migration file %s already exists", filename) } // Create the migration file with a template template := fmt.Sprintf(`-- Migration %d: %s -- Created: %s -- Add your SQL statements here `, nextNumber, name, time.Now().Format("2006-01-02 15:04:05")) if err := os.WriteFile(filepath, []byte(template), 0644); err != nil { return fmt.Errorf("failed to create migration file: %w", err) } fmt.Printf("Created migration: %s\n", filename) return nil } // Status shows the current migration status func (m *Migrator) Status() error { if err := m.ensureMigrationsTable(); err != nil { return fmt.Errorf("failed to ensure migrations table: %w", err) } migrations, err := m.loadMigrations() if err != nil { return fmt.Errorf("failed to load migrations: %w", err) } executed, err := m.getExecutedMigrations() if err != nil { return fmt.Errorf("failed to get executed migrations: %w", err) } if len(migrations) == 0 { fmt.Println("No migration files found") return nil } fmt.Println("Migration Status:") fmt.Println("=================") pendingCount := 0 for _, migration := range migrations { status := "PENDING" if executed[migration.Number] { status = "EXECUTED" } else { pendingCount++ } fmt.Printf("%d_%s.sql - %s\n", migration.Number, migration.Name, status) } fmt.Printf("\nTotal: %d migrations, %d pending\n", len(migrations), pendingCount) return nil }