@@ 27,7 27,7 @@ func run() error {
dbUser := flag.String("u", "", "database user")
dbHost := flag.String("h", "127.0.0.1", "database host")
dbPort := flag.Int("p", 0, "database port")
- dbType := flag.String("t", "mysql", "type of database (mysql, postgres, sqlite)")
+ dbType := flag.String("t", "mysql", "type of database (mysql, mariadb, postgres, sqlite)")
dry := flag.Bool("d", false, "dry run")
sslKey := flag.String("ssl-key", "", "path to client key pem")
sslCert := flag.String("ssl-cert", "", "path to client cert pem")
@@ 93,7 93,7 @@ func run() error {
if *dbPort == 0 {
*dbPort = 5432
}
- case "mysql":
+ case "mysql", "mariadb":
if *dbUser == "" {
*dbUser = "root"
}
@@ 123,7 123,7 @@ func run() error {
// Prepare our database-specific configs
var db migrate.Store
switch *dbType {
- case "mysql":
+ case "mysql", "mariadb":
var err error
db, err = mysql.New(*dbUser, string(password), *dbHost,
*dbName, *dbPort, *sslKey, *sslCert, *sslCA,
@@ 146,8 146,23 @@ func run() error {
return errors.Wrap(err, "open")
}
+ var dbt migrate.DBType
+ switch *dbType {
+ case "mysql":
+ dbt = migrate.DBTypeMySQL
+ case "mariadb":
+ dbt = migrate.DBTypeMariaDB
+ case "postgres":
+ dbt = migrate.DBTypePostgres
+ case "sqlite":
+ dbt = migrate.DBTypeSQLite
+ default:
+ return fmt.Errorf("unknown db type: %s", *dbType)
+ }
+
// Prepare our database for migrations and collect the relevant files.
- m, err := migrate.New(db, migrate.StdLogger{}, *migrationDir, *skip)
+ m, err := migrate.New(db, migrate.StdLogger{}, dbt, *migrationDir,
+ *skip)
if err != nil {
return err
}
@@ 157,7 172,7 @@ func run() error {
return nil
}
for i := len(m.Migrations); i < len(m.Files); i++ {
- fmt.Println("would migrate", m.Files[i].Name())
+ fmt.Println("would migrate", m.Files[i].Info.Name())
}
return nil
}
@@ 23,36 23,51 @@ var spaces = regexp.MustCompile(`\s+`)
type Migrate struct {
Migrations []Migration
- Files []os.FileInfo
+ Files []*file
db Store
log Logger
- dir string
idx int
}
+type file struct {
+ Info os.FileInfo
+ fullpath string
+}
+
type Migration struct {
Filename string
Checksum string
Content string
+ fullpath string
}
var regexNum = regexp.MustCompile(`^\d+`)
+type DBType string
+
+const (
+ DBTypeMySQL DBType = "mysql"
+ DBTypeMariaDB DBType = "mariadb"
+ DBTypePostgres DBType = "postgres"
+ DBTypeSQLite DBType = "sqlite"
+)
+
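+// A usage sketch (the directory and skip value here are illustrative, and
+// the Store construction is database-specific; cmd/migrate shows the real
+// wiring):
+//
+//	m, err := New(db, StdLogger{}, DBTypeMariaDB, "migrations", "")
+//	if err != nil {
+//		return err
+//	}
+//	migrated, err := m.Migrate()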
func New(
db Store,
log Logger,
+ dbt DBType,
dir, skip string,
) (*Migrate, error) {
- m := &Migrate{db: db, log: log, dir: dir}
+ m := &Migrate{db: db, log: log}
// Get files in migration dir and sort them
var err error
- m.Files, err = readdir(dir)
+ m.Files, err = readDir(dir, dbt)
if err != nil {
return nil, errors.Wrap(err, "get migrations")
}
- if err = sortfiles(m.Files); err != nil {
+ if err = sortFiles(m.Files); err != nil {
return nil, errors.Wrap(err, "sort")
}
@@ 101,6 116,19 @@ func New(
return nil, errors.Wrap(err, "get migrations")
}
+ // Fill in migration fullpath field based on the db type.
+ overrides, err := getOverrideSet(dir, dbt)
+ if err != nil {
+ return nil, fmt.Errorf("get override set: %w", err)
+ }
+ for i, mg := range m.Migrations {
+ override, exist := overrides[mg.Filename]
+ if exist {
+ m.Migrations[i].fullpath = override.fullpath
+ } else {
+ m.Migrations[i].fullpath = filepath.Join(dir, mg.Filename)
+ }
+ }
if err = m.validHistory(); err != nil {
return nil, err
}
@@ 112,11 140,11 @@ func New(
func (m *Migrate) Migrate() (bool, error) {
var migrated bool
for i := len(m.Migrations); i < len(m.Files); i++ {
- filename := m.Files[i].Name()
- if err := m.migrateFile(filename); err != nil {
+ fi := m.Files[i]
+ if err := m.migrateFile(fi); err != nil {
return false, errors.Wrap(err, "migrate file")
}
- m.log.Println("migrated", filename)
+ m.log.Println("migrated", fi.Info.Name())
migrated = true
}
return migrated, nil
@@ 131,11 159,12 @@ func (m *Migrate) validHistory() error {
}
for i := m.idx; i < len(m.Migrations); i++ {
mg := m.Migrations[i]
- if mg.Filename != m.Files[i].Name() {
+ if mg.Filename != m.Files[i].Info.Name() {
m.log.Printf("\n%s was added to history before %s.\n",
- m.Files[i].Name(), mg.Filename)
+ m.Files[i].Info.Name(), mg.Filename)
return errors.New("failed to migrate. migrations must be appended")
}
+ fmt.Println("MIGRATION", mg.fullpath)
if err := m.checkHash(mg); err != nil {
return errors.Wrap(err, "check hash")
}
@@ 144,7 173,7 @@ func (m *Migrate) validHistory() error {
}
func (m *Migrate) checkHash(mg Migration) error {
- fi, err := os.Open(filepath.Join(m.dir, mg.Filename))
+ fi, err := os.Open(mg.fullpath)
if err != nil {
return err
}
@@ 161,9 190,8 @@ func (m *Migrate) checkHash(mg Migration) error {
return nil
}
-func (m *Migrate) migrateFile(filename string) error {
- pth := filepath.Join(m.dir, filename)
- byt, err := ioutil.ReadFile(pth)
+func (m *Migrate) migrateFile(fi *file) error {
+ byt, err := ioutil.ReadFile(fi.fullpath)
if err != nil {
return err
}
@@ 180,11 208,11 @@ func (m *Migrate) migrateFile(filename string) error {
// Ensure that commands are present
if len(filteredCmds) == 0 {
- return fmt.Errorf("no sql statements in file: %s", filename)
+ return fmt.Errorf("no sql statements in file: %s", fi.Info.Name())
}
// Get our checkpoints, if any
- checkpoints, err := m.db.GetMetaCheckpoints(filename)
+ checkpoints, err := m.db.GetMetaCheckpoints(fi.Info.Name())
if err != nil {
return errors.Wrap(err, "get checkpoints")
}
@@ 209,7 237,7 @@ func (m *Migrate) migrateFile(filename string) error {
if checksum != checkpoints[i] {
return fmt.Errorf(
"checksum does not equal checkpoint. has %s (cmd %d) changed?",
- filename, i)
+ fi.Info.Name(), i)
}
continue
}
@@ 228,7 256,7 @@ func (m *Migrate) migrateFile(filename string) error {
_, err := m.db.Exec(cmd)
if err != nil {
m.log.Println("failed on", cmd)
- return fmt.Errorf("%s: %s", filename, err)
+ return fmt.Errorf("%s: %s", fi.Info.Name(), err)
}
// Save a checkpoint
@@ 236,7 264,7 @@ func (m *Migrate) migrateFile(filename string) error {
if err != nil {
return errors.Wrap(err, "compute checksum")
}
- err = m.db.InsertMetaCheckpoint(filename, cmd, checksum, i)
+ err = m.db.InsertMetaCheckpoint(fi.Info.Name(), cmd, checksum, i)
if err != nil {
return errors.Wrap(err, "insert checkpoint")
}
@@ 252,7 280,8 @@ func (m *Migrate) migrateFile(filename string) error {
if err != nil {
return errors.Wrap(err, "compute file checksum")
}
- if err = m.db.InsertMigration(filename, string(byt), checksum); err != nil {
+ err = m.db.InsertMigration(fi.Info.Name(), string(byt), checksum)
+ if err != nil {
return errors.Wrap(err, "insert migration")
}
return nil
@@ 265,7 294,7 @@ func (m *Migrate) skip(toFile string) (int, error) {
// Ensure the file exists
index := -1
for i, fi := range m.Files {
- if fi.Name() == toFile {
+ if fi.Info.Name() == toFile {
index = i
break
}
@@ 274,8 303,7 @@ func (m *Migrate) skip(toFile string) (int, error) {
return 0, fmt.Errorf("%s does not exist", toFile)
}
for i := 0; i <= index; i++ {
- name := m.Files[i].Name()
- fi, err := os.Open(filepath.Join(m.dir, name))
+ fi, err := os.Open(m.Files[i].fullpath)
if err != nil {
return -1, err
}
@@ 284,7 312,9 @@ func (m *Migrate) skip(toFile string) (int, error) {
fi.Close()
return -1, err
}
- if err = m.db.UpsertMigration(name, content, checksum); err != nil {
+ name := m.Files[i].Info.Name()
+ err = m.db.UpsertMigration(name, content, checksum)
+ if err != nil {
fi.Close()
return -1, err
}
@@ 307,14 337,28 @@ func computeChecksum(r io.Reader) (content string, checksum string, err error) {
return string(byt), fmt.Sprintf("%x", h.Sum(nil)), nil
}
-// readdir collects file infos from the migration directory.
-func readdir(dir string) ([]os.FileInfo, error) {
- files := []os.FileInfo{}
+// readDir collects file infos from the migration directory.
+func readDir(dir string, dbt DBType) ([]*file, error) {
+ files := []*file{}
tmp, err := ioutil.ReadDir(dir)
if err != nil {
return nil, errors.Wrap(err, "read dir")
}
+
+ // Allow for DB-specific workarounds. For instance, MySQL and
+ // MariaDB are subtly incompatible (they name CONSTRAINTs in
+ // different ways), so it's possible that later migrations work on
+ // one database but not the other, even though the two are meant to
+ // be compatible. There is no easy workaround, especially when you're
+ // on an OS with access to one database but not the other.
+ //
+ // To ease this, we crawl through secondary directories named after
+ // the DB type in use. With "migrate -t mariadb" we look for a
+ // `mariadb` folder and prefer identically named migration files in
+ // that folder over those in the main migration directory.
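+ //
+ // For example, with a hypothetical layout like:
+ //
+ //   migrations/
+ //       1_create_users.sql
+ //       2_add_constraint.sql
+ //       mariadb/
+ //           2_add_constraint.sql
+ //
+ // "migrate -t mariadb" then applies mariadb/2_add_constraint.sql in
+ // place of the top-level 2_add_constraint.sql, while every other DB
+ // type keeps using the top-level file.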
for _, fi := range tmp {
+ fullpath := filepath.Join(dir, fi.Name())
+
// Skip directories and hidden files
if fi.IsDir() || strings.HasPrefix(fi.Name(), ".") {
continue
@@ 323,32 367,74 @@ func readdir(dir string) ([]os.FileInfo, error) {
if filepath.Ext(fi.Name()) != ".sql" {
continue
}
- files = append(files, fi)
+ files = append(files, &file{Info: fi, fullpath: fullpath})
}
if len(files) == 0 {
return nil, errors.New("no sql migration files found (might be the wrong -dir)")
}
+
+ // Prefer files from the DB-specific subdirectory over those in the
+ // main migration directory.
+ overrideSet, err := getOverrideSet(dir, dbt)
+ if err != nil {
+ return nil, fmt.Errorf("get override set: %w", err)
+ }
+ for i, fi := range files {
+ if override, exist := overrideSet[fi.Info.Name()]; exist {
+ files[i] = override
+ fmt.Println("OVERRIDING", override.Info.Name())
+ }
+ }
return files, nil
}
-// sortfiles by name, ensuring that something like 1.sql, 2.sql, 10.sql is
+func getOverrideSet(dir string, dbt DBType) (map[string]*file, error) {
+ tmp, err := ioutil.ReadDir(dir)
+ if err != nil {
+ return nil, errors.Wrap(err, "read dir")
+ }
+ overrides := []*file{}
+ for _, fi := range tmp {
+ fullpath := filepath.Join(dir, fi.Name())
+ if !fi.IsDir() || fi.Name() != string(dbt) {
+ continue
+ }
+
+ // The empty DBType prevents recursive descent into structures
+ // like ./mariadb/mariadb/mariadb/...
+ overrides, err = readDir(fullpath, DBType(""))
+ if err != nil {
+ return nil, fmt.Errorf("read dir %s: %w",
+ fi.Name(), err)
+ }
+ }
+ overrideSet := make(map[string]*file, len(overrides))
+ for _, o := range overrides {
+ overrideSet[o.Info.Name()] = o
+ }
+ return overrideSet, nil
+}
+
+// sortFiles by name, ensuring that something like 1.sql, 2.sql, 10.sql is
// ordered correctly.
-func sortfiles(files []os.FileInfo) error {
+func sortFiles(files []*file) error {
var nameErr error
sort.Slice(files, func(i, j int) bool {
if nameErr != nil {
return false
}
- fiName1 := regexNum.FindString(files[i].Name())
- fiName2 := regexNum.FindString(files[j].Name())
+ fiName1 := regexNum.FindString(files[i].Info.Name())
+ fiName2 := regexNum.FindString(files[j].Info.Name())
fiNum1, err := strconv.ParseUint(fiName1, 10, 64)
if err != nil {
- nameErr = errors.Wrapf(err, "parse uint in file %s", files[i].Name())
+ nameErr = errors.Wrapf(err, "parse uint in file %s",
+ files[i].Info.Name())
return false
}
fiNum2, err := strconv.ParseUint(fiName2, 10, 64)
if err != nil {
- nameErr = errors.Wrapf(err, "parse uint in file %s", files[i].Name())
+ nameErr = errors.Wrapf(err, "parse uint in file %s",
+ files[j].Info.Name())
return false
}
if fiNum1 == fiNum2 {
@@ 362,15 448,16 @@ func sortfiles(files []os.FileInfo) error {
func migrationsFromFiles(m *Migrate) ([]Migration, error) {
ms := make([]Migration, len(m.Files))
- for i, fileInfo := range m.Files {
- filename := filepath.Join(m.dir, fileInfo.Name())
- byt, err := ioutil.ReadFile(filename)
+ for i, fi := range m.Files {
+ fmt.Println("FULLPATH", fi.fullpath)
+ byt, err := ioutil.ReadFile(fi.fullpath)
if err != nil {
return nil, errors.Wrap(err, "read file")
}
ms[i] = Migration{
- Filename: fileInfo.Name(),
+ Filename: fi.Info.Name(),
Content: string(byt),
+ fullpath: fi.fullpath,
}
}
return ms, nil