You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

75 lines
1.9 KiB

package db
import (
"database/sql"
"fmt"
"io/fs"
"sort"
"strings"
)
// Migrate applies pending SQL migrations from the "migrations" directory of
// migrationFS to database.
//
// A migration is a regular file directly inside that directory whose name
// ends in ".sql" and starts with a decimal version followed by an underscore
// (for example "0002_add_index.sql"); every other entry is skipped. The
// schema_version table is created if it does not exist, and every migration
// whose version is greater than the highest version recorded there is run in
// ascending version order. Each migration's SQL and the insert of its version
// row are executed in one transaction, which is rolled back if either step
// fails. Migration stops at the first error; migrations committed before it
// remain applied.
//
// Versions at or below the recorded maximum are treated as already applied,
// so a file added later with a lower number than an applied one is never run.
// Two files that share a version number are reported as an error before any
// migration is executed.
func Migrate(database *sql.DB, migrationFS fs.FS) error {
	if _, err := database.Exec(`CREATE TABLE IF NOT EXISTS schema_version (version INTEGER PRIMARY KEY)`); err != nil {
		return fmt.Errorf("create schema_version: %w", err)
	}
	subFS, err := fs.Sub(migrationFS, "migrations")
	if err != nil {
		return fmt.Errorf("sub migrations: %w", err)
	}
	files, err := fs.ReadDir(subFS, ".")
	if err != nil {
		return fmt.Errorf("read migrations: %w", err)
	}
	var currentVersion int
	if err := database.QueryRow("SELECT COALESCE(MAX(version), 0) FROM schema_version").Scan(&currentVersion); err != nil {
		return fmt.Errorf("get current version: %w", err)
	}

	type migration struct {
		version int
		sql     string
	}
	// owner remembers which file claimed each version so a duplicate can be
	// reported by name. Duplicates are rejected because otherwise one file of
	// the pair is skipped (when the version is already applied) or the pair
	// runs in an unspecified order (sort.Slice is not stable).
	owner := make(map[int]string, len(files))
	var pending []migration
	for _, f := range files {
		name := f.Name()
		if f.IsDir() || !strings.HasSuffix(name, ".sql") {
			continue
		}
		var ver int
		if _, err := fmt.Sscanf(name, "%d_", &ver); err != nil {
			continue // no leading "<version>_" prefix: not a migration file
		}
		if prev, dup := owner[ver]; dup {
			return fmt.Errorf("duplicate migration version %d: %s and %s", ver, prev, name)
		}
		owner[ver] = name
		if ver <= currentVersion {
			continue // already applied; its contents are not needed
		}
		// Read every pending file up front so an unreadable file aborts the
		// run before any migration is executed.
		content, err := fs.ReadFile(subFS, name)
		if err != nil {
			return fmt.Errorf("read migration %s: %w", name, err)
		}
		pending = append(pending, migration{version: ver, sql: string(content)})
	}
	sort.Slice(pending, func(i, j int) bool { return pending[i].version < pending[j].version })

	// apply runs one migration and records its version in a single
	// transaction.
	apply := func(m migration) error {
		tx, err := database.Begin()
		if err != nil {
			return fmt.Errorf("begin tx for migration %d: %w", m.version, err)
		}
		// Rollback after a successful Commit is a no-op that returns
		// sql.ErrTxDone, so deferring it covers every failure path.
		defer tx.Rollback()
		if _, err := tx.Exec(m.sql); err != nil {
			return fmt.Errorf("run migration %d: %w", m.version, err)
		}
		if _, err := tx.Exec("INSERT INTO schema_version (version) VALUES (?)", m.version); err != nil {
			return fmt.Errorf("record migration %d: %w", m.version, err)
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("commit migration %d: %w", m.version, err)
		}
		return nil
	}
	for _, m := range pending {
		if err := apply(m); err != nil {
			return err
		}
	}
	return nil
}