You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
1.9 KiB
75 lines
1.9 KiB
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"io/fs"
|
|
"sort"
|
|
"strings"
|
|
)
|
|
|
|
// Migrate brings database up to date by applying the numbered SQL files found
// in the "migrations" directory of migrationFS.
//
// Migration files must be named "<version>_<description>.sql", where <version>
// is a decimal integer (for example "001_create_users.sql"). Directory
// entries, files without the ".sql" suffix, and files whose name does not
// start with "<int>_" are ignored. Applied versions are recorded in the
// schema_version table, which is created if it does not already exist.
//
// Only migrations whose version is greater than the highest recorded version
// are applied, in ascending version order. When no version has been recorded
// yet, every migration (including a version 0 file) is applied. A file added
// later with a number at or below the recorded maximum is therefore never run.
//
// Each migration's SQL and the insert of its version row share a single
// transaction, so a failed migration leaves no version record behind.
// Migrate stops at the first failure; migrations committed before it stay
// applied. Two files with the same version number are reported as an error
// before any migration is run.
//
// NOTE(review): the whole file is passed to a single Exec and the version row
// uses a "?" placeholder. Whether multi-statement files and "?" are accepted
// depends on the SQL driver in use — confirm for the driver this package
// targets.
func Migrate(database *sql.DB, migrationFS fs.FS) error {
	if _, err := database.Exec(`CREATE TABLE IF NOT EXISTS schema_version (version INTEGER PRIMARY KEY)`); err != nil {
		return fmt.Errorf("create schema_version: %w", err)
	}

	subFS, err := fs.Sub(migrationFS, "migrations")
	if err != nil {
		return fmt.Errorf("sub migrations: %w", err)
	}

	files, err := fs.ReadDir(subFS, ".")
	if err != nil {
		return fmt.Errorf("read migrations: %w", err)
	}

	// MAX over an empty table yields NULL. Keeping that distinct from 0 lets a
	// "000_*.sql" migration run on a fresh database instead of being skipped.
	var current sql.NullInt64
	if err := database.QueryRow("SELECT MAX(version) FROM schema_version").Scan(&current); err != nil {
		return fmt.Errorf("get current version: %w", err)
	}

	type migration struct {
		version int
		sql     string
	}

	var pending []migration
	seen := make(map[int]string, len(files)) // version -> file name, for duplicate detection
	for _, f := range files {
		if f.IsDir() || !strings.HasSuffix(f.Name(), ".sql") {
			continue
		}
		var ver int
		if _, err := fmt.Sscanf(f.Name(), "%d_", &ver); err != nil {
			continue // not a "<version>_..." file
		}
		if prev, dup := seen[ver]; dup {
			return fmt.Errorf("duplicate migration version %d: %s and %s", ver, prev, f.Name())
		}
		seen[ver] = f.Name()

		if current.Valid && int64(ver) <= current.Int64 {
			continue // already applied; no need to read the file
		}
		content, err := fs.ReadFile(subFS, f.Name())
		if err != nil {
			return fmt.Errorf("read migration %s: %w", f.Name(), err)
		}
		pending = append(pending, migration{version: ver, sql: string(content)})
	}

	sort.Slice(pending, func(i, j int) bool { return pending[i].version < pending[j].version })

	// apply runs one migration and records its version atomically.
	apply := func(m migration) error {
		tx, err := database.Begin()
		if err != nil {
			return fmt.Errorf("begin tx for migration %d: %w", m.version, err)
		}
		// After a successful Commit, Rollback just returns sql.ErrTxDone. On the
		// error paths, the error being returned is more useful than a rollback
		// failure, so the rollback result is deliberately discarded.
		defer func() { _ = tx.Rollback() }()

		if _, err := tx.Exec(m.sql); err != nil {
			return fmt.Errorf("run migration %d: %w", m.version, err)
		}
		if _, err := tx.Exec("INSERT INTO schema_version (version) VALUES (?)", m.version); err != nil {
			return fmt.Errorf("record migration %d: %w", m.version, err)
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("commit migration %d: %w", m.version, err)
		}
		return nil
	}

	for _, m := range pending {
		if err := apply(m); err != nil {
			return err
		}
	}
	return nil
}
|
|
|