package db

import (
	"database/sql"
	"fmt"
	"io/fs"
	"sort"
	"strings"
)

// Migrate applies pending SQL migrations from the "migrations" directory of
// migrationFS to database.
//
// It first ensures the schema_version table exists, then reads the highest
// recorded version (0 when the table is empty). Every migration file whose
// version is greater than that value is executed in ascending version order.
// Each migration runs in its own transaction together with the INSERT that
// records its version, so a failure in either statement rolls both back.
//
// A directory entry is treated as a migration only if it is not a directory,
// its name ends in ".sql", and its name starts with an integer followed by an
// underscore (for example "3_add_users.sql"). Other entries are skipped
// without error.
//
// Migrate returns the first error encountered, wrapped with the step that
// failed. Migrations committed before the failure are not undone.
func Migrate(database *sql.DB, migrationFS fs.FS) error {
	if _, err := database.Exec(`CREATE TABLE IF NOT EXISTS schema_version (version INTEGER PRIMARY KEY)`); err != nil {
		return fmt.Errorf("create schema_version: %w", err)
	}

	subFS, err := fs.Sub(migrationFS, "migrations")
	if err != nil {
		return fmt.Errorf("sub migrations: %w", err)
	}

	files, err := fs.ReadDir(subFS, ".")
	if err != nil {
		return fmt.Errorf("read migrations: %w", err)
	}

	// COALESCE turns the NULL that MAX yields on an empty table into 0.
	var currentVersion int
	if err := database.QueryRow("SELECT COALESCE(MAX(version), 0) FROM schema_version").Scan(&currentVersion); err != nil {
		return fmt.Errorf("get current version: %w", err)
	}

	type migration struct {
		version int
		sql     string
	}

	var migrations []migration
	for _, f := range files {
		if f.IsDir() || !strings.HasSuffix(f.Name(), ".sql") {
			continue
		}

		// The version is the leading integer of the file name; names that do
		// not match "<int>_..." are not migrations and are skipped.
		var ver int
		if _, err := fmt.Sscanf(f.Name(), "%d_", &ver); err != nil {
			continue
		}

		content, err := fs.ReadFile(subFS, f.Name())
		if err != nil {
			return fmt.Errorf("read migration %s: %w", f.Name(), err)
		}
		migrations = append(migrations, migration{version: ver, sql: string(content)})
	}

	// Apply in ascending version order regardless of directory listing order.
	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].version < migrations[j].version
	})

	// apply executes one migration and records its version in one transaction.
	apply := func(m migration) error {
		tx, err := database.Begin()
		if err != nil {
			return fmt.Errorf("begin tx for migration %d: %w", m.version, err)
		}
		// Roll back on every early return. Once Commit has succeeded the
		// transaction is finished and Rollback only reports sql.ErrTxDone,
		// so its result is intentionally ignored.
		defer tx.Rollback()

		if _, err := tx.Exec(m.sql); err != nil {
			return fmt.Errorf("run migration %d: %w", m.version, err)
		}
		if _, err := tx.Exec("INSERT INTO schema_version (version) VALUES (?)", m.version); err != nil {
			return fmt.Errorf("record migration %d: %w", m.version, err)
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("commit migration %d: %w", m.version, err)
		}
		return nil
	}

	for _, m := range migrations {
		if m.version <= currentVersion {
			continue // already applied
		}
		if err := apply(m); err != nil {
			return err
		}
	}
	return nil
}